22 lines
597 B
Python
22 lines
597 B
Python
from torch import nn
|
|
import PytorchBoot.stereotype as stereotype
|
|
|
|
@stereotype.module("pose_encoder")
class PoseEncoder(nn.Module):
    """Two-layer MLP that embeds a pose vector into an ``out_dim`` feature.

    Architecture:
    ``Linear(pose_dim, out_dim) -> ReLU -> Linear(out_dim, out_dim) -> ReLU``.
    Because the final layer is a ReLU, every output feature is non-negative.

    NOTE(review): ``@stereotype.module("pose_encoder")`` presumably registers
    this class with PytorchBoot under the name ``"pose_encoder"`` — confirm
    against the PytorchBoot documentation.
    """

    def __init__(self, config):
        """Build the encoder from a config mapping.

        Args:
            config: mapping that provides integer sizes under the keys
                ``"pose_dim"`` (input feature size) and ``"out_dim"`` (size of
                both the hidden layer and the output). A missing key raises
                whatever the mapping raises on lookup (``KeyError`` for a
                plain dict). The mapping is kept on ``self.config``.
        """
        super().__init__()
        self.config = config
        pose_dim = config["pose_dim"]
        out_dim = config["out_dim"]

        # One in-place ReLU instance is reused at both activation positions.
        # This is safe because ReLU holds no parameters or buffers. It is kept
        # as ``self.act`` so existing attribute access keeps working.
        self.act = nn.ReLU(True)

        self.pose_encoder = nn.Sequential(
            nn.Linear(pose_dim, out_dim),
            self.act,
            nn.Linear(out_dim, out_dim),
            self.act,
        )

    def encode_pose(self, pose):
        """Run ``pose`` through the MLP.

        Args:
            pose: tensor whose last dimension has size ``pose_dim``.
                ``nn.Linear`` acts on the last dimension, so any leading
                (e.g. batch) dimensions pass through unchanged.

        Returns:
            Tensor with the same leading dimensions as ``pose`` and a last
            dimension of size ``out_dim``. All values are non-negative.
        """
        return self.pose_encoder(pose)

    def forward(self, pose):
        """Standard ``nn.Module`` entry point; delegates to :meth:`encode_pose`.

        The base ``nn.Module.forward`` is unimplemented. This override lets
        the encoder be called directly (``encoder(pose)``), which also runs
        any registered forward hooks. Existing ``encode_pose`` callers are
        unaffected.
        """
        return self.encode_pose(pose)
|