from torch import nn

import PytorchBoot.stereotype as stereotype


@stereotype.module("pts_num_encoder")
class PointsNumEncoder(nn.Module):
    """Embed a scalar count feature into an ``out_dim``-dimensional vector.

    The encoder is a two-layer MLP:
    ``Linear(1, out_dim) -> ReLU -> Linear(out_dim, out_dim) -> ReLU``.

    Args:
        config: Mapping that must provide ``"out_dim"``, the embedding width.
            The full mapping is kept on ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        out_dim = config["out_dim"]
        # A single ReLU instance is reused at both positions. It holds no
        # parameters, so sharing it is harmless. inplace=True is safe here
        # because each ReLU follows a Linear, whose backward pass does not
        # need its own output.
        self.act = nn.ReLU(True)
        # Keep this exact layer order and attribute name: they determine the
        # state_dict keys (pts_num_encoder.0.*, pts_num_encoder.2.*) that
        # saved checkpoints rely on.
        self.pts_num_encoder = nn.Sequential(
            nn.Linear(1, out_dim),
            self.act,
            nn.Linear(out_dim, out_dim),
            self.act,
        )

    def encode_pts_num(self, num_seq):
        """Encode ``num_seq`` with the MLP.

        Args:
            num_seq: Float tensor whose last dimension has size 1, as required
                by the first ``nn.Linear(1, out_dim)``. Presumably it holds
                per-item point counts (from the name); confirm with callers.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``out_dim``.
        """
        return self.pts_num_encoder(num_seq)

    def forward(self, num_seq):
        """Make the module callable as ``encoder(num_seq)``.

        This delegates to :meth:`encode_pts_num`. Without it, ``nn.Module``
        raises ``NotImplementedError`` when the module is called.
        """
        return self.encode_pts_num(num_seq)