import torch
import PytorchBoot.stereotype as stereotype


@stereotype.loss_function("gf_loss")
class GFLoss:
    """Squared-error loss between estimated and target scores, weighted by std**2.

    Registered with PytorchBoot under the name ``"gf_loss"``. The objective is::

        mean over batch ( sum over all non-batch elements ( std**2 * (est - target)**2 ) )

    NOTE(review): this has the shape of a sigma**2-weighted denoising
    score-matching objective; confirm against the model that fills ``output``.
    """

    def __init__(self, _):
        # The single constructor argument (presumably a config object passed
        # by the PytorchBoot registry -- TODO confirm) is intentionally unused.
        pass

    def compute(self, output, _):
        """Return the std**2-weighted squared error, summed per sample and averaged over the batch.

        Args:
            output: Mapping providing the tensors ``'estimated_score'``,
                ``'target_score'`` and ``'std'``. The first dimension of
                ``'estimated_score'`` is treated as the batch dimension. ``'std'``
                must broadcast against the score tensors (assumed to be shaped
                like ``(bs, 1, ..., 1)`` -- TODO confirm with the producer).
            _: Unused second argument (presumably the input batch -- verify
                against the caller).

        Returns:
            A 0-dim tensor holding the loss.
        """
        estimated_score = output['estimated_score']
        target_score = output['target_score']
        std = output['std']
        bs = estimated_score.shape[0]

        # Each squared residual is scaled by the (broadcast) variance std**2.
        loss_weighting = std ** 2
        weighted_sq_err = loss_weighting * (estimated_score - target_score) ** 2

        # Use reshape rather than view: elementwise results inherit the memory
        # layout of their inputs (e.g. transposed or channels_last tensors), and
        # Tensor.view raises on such non-contiguous layouts. reshape returns a
        # view whenever possible and copies only when it must.
        per_sample = torch.sum(weighted_sq_err.reshape(bs, -1), dim=-1)
        return torch.mean(per_sample)