from torch import nn

import PytorchBoot.stereotype as stereotype


@stereotype.module("pose_encoder")
class PoseEncoder(nn.Module):
    """Two-layer MLP that maps a pose vector to a feature embedding.

    Decorated with ``stereotype.module("pose_encoder")``. This presumably
    registers the class with PytorchBoot under that name; confirm against
    the PytorchBoot docs.

    Config keys read:
        pose_dim: input feature size of the first ``nn.Linear``.
        out_dim: feature size of both the hidden layer and the output.
    """

    def __init__(self, config):
        super(PoseEncoder, self).__init__()
        self.config = config
        pose_dim = config["pose_dim"]
        out_dim = config["out_dim"]
        # A single in-place ReLU instance is reused after both linear layers.
        # This is safe because ReLU holds no parameters or state.
        self.act = nn.ReLU(inplace=True)
        self.pose_encoder = nn.Sequential(
            nn.Linear(pose_dim, out_dim),
            self.act,
            nn.Linear(out_dim, out_dim),
            self.act,
        )

    def encode_pose(self, pose):
        """Return the embedding of ``pose``.

        The last dimension of ``pose`` must equal ``config["pose_dim"]``.
        The output has the same leading dimensions and a last dimension of
        ``config["out_dim"]``. It passes through a final ReLU, so it is
        non-negative.
        """
        return self.pose_encoder(pose)

    def forward(self, pose):
        """Delegate to :meth:`encode_pose` so ``encoder(pose)`` works.

        Without this override, ``nn.Module.__call__`` would raise
        ``NotImplementedError``.
        """
        return self.encode_pose(pose)