27 lines
838 B
Python
27 lines
838 B
Python
import torch
|
|
import PytorchBoot.stereotype as stereotype
|
|
|
|
@stereotype.loss_function("gf_loss")
class GFLoss:
    """Std-squared-weighted squared-error loss between two score tensors.

    Decorated with ``stereotype.loss_function("gf_loss")`` (presumably this
    registers the class under that name -- confirm in PytorchBoot).

    For each sample, ``std**2 * (estimated_score - target_score)**2`` is
    summed over all non-batch elements. The per-sample sums are then averaged
    over the batch.
    """

    def __init__(self, _):
        # The constructor argument is accepted for interface compatibility
        # but is not used.
        pass

    def compute(self, output, _):
        """Return the scalar weighted squared-error loss.

        Args:
            output: Mapping holding tensors under ``'estimated_score'``,
                ``'target_score'`` and ``'std'``. ``std`` must broadcast
                against the score tensors (assumed ``(batch, 1)`` or similar --
                TODO confirm against the caller).
            _: Unused second argument, kept for interface compatibility.

        Returns:
            0-dim tensor: the batch mean of the per-sample sum of
            ``std**2 * (estimated_score - target_score)**2``.
        """
        estimated_score = output['estimated_score']
        target_score = output['target_score']
        std = output['std']
        bs = estimated_score.shape[0]

        # Each squared error is weighted by std**2.
        loss_weighting = std ** 2
        weighted_sq_err = loss_weighting * (estimated_score - target_score) ** 2

        # reshape rather than view: element-wise ops preserve the stride
        # layout of permuted/sliced inputs, and view() raises on tensors whose
        # strides cannot be merged. reshape returns a view whenever view()
        # would have succeeded, so results are unchanged in that case.
        per_sample = torch.sum(weighted_sq_err.reshape(bs, -1), dim=-1)
        return torch.mean(per_sample)
|
|
|
|
@stereotype.loss_function("mse_loss")
class MSELoss:
    """Squared-error loss between ``output["pred"]`` and ``output["gt"]``.

    Despite the name, this is not a plain element-wise mean. The squared
    differences are first summed along the last dimension, and the result is
    then averaged over every remaining element.
    """

    def __init__(self, _):
        # Constructor argument is accepted but ignored.
        pass

    def compute(self, output, _):
        """Return the scalar loss for ``output["pred"]`` versus ``output["gt"]``.

        Args:
            output: Mapping holding tensors under ``"pred"`` and ``"gt"``.
            _: Unused second argument, kept for interface compatibility.

        Returns:
            0-dim tensor: the mean, over all leading positions, of the sum of
            squared differences along the last dimension.
        """
        residual = output["pred"] - output["gt"]
        per_row = (residual ** 2).sum(dim=-1)
        return per_row.mean()