diff --git a/core/loss.py b/core/loss.py new file mode 100644 index 0000000..1de8b8d --- /dev/null +++ b/core/loss.py @@ -0,0 +1,16 @@ +import torch +import PytorchBoot.stereotype as stereotype + +@stereotype.loss_function("gf_loss") +class GFLoss: + def __init__(self, config): + pass + + def compute(self, output, _): + estimated_score = output['estimated_score'] + target_score = output['target_score'] + std = output['std'] + bs = estimated_score.shape[0] + loss_weighting = std ** 2 + loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1)) + return loss \ No newline at end of file diff --git a/losses/gf_loss.py b/losses/gf_loss.py deleted file mode 100644 index a4320a2..0000000 --- a/losses/gf_loss.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -import PytorchBoot.stereotype as stereotype - -@stereotype.loss_function("gf_loss") -def compute_loss(output, data): - estimated_score = output['estimated_score'] - target_score = output['target_score'] - std = output['std'] - bs = estimated_score.shape[0] - loss_weighting = std ** 2 - loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1)) - return loss