update loss
This commit is contained in:
parent
f977fd4b8e
commit
913d4e521d
16
core/loss.py
Normal file
16
core/loss.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
import torch
|
||||||
|
import PytorchBoot.stereotype as stereotype
|
||||||
|
|
||||||
|
@stereotype.loss_function("gf_loss")
|
||||||
|
class GFLoss:
|
||||||
|
def __init__(self, config):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def compute(self, output, _):
|
||||||
|
estimated_score = output['estimated_score']
|
||||||
|
target_score = output['target_score']
|
||||||
|
std = output['std']
|
||||||
|
bs = estimated_score.shape[0]
|
||||||
|
loss_weighting = std ** 2
|
||||||
|
loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1))
|
||||||
|
return loss
|
@ -1,12 +0,0 @@
|
|||||||
import torch
|
|
||||||
import PytorchBoot.stereotype as stereotype
|
|
||||||
|
|
||||||
@stereotype.loss_function("gf_loss")
|
|
||||||
def compute_loss(output, data):
|
|
||||||
estimated_score = output['estimated_score']
|
|
||||||
target_score = output['target_score']
|
|
||||||
std = output['std']
|
|
||||||
bs = estimated_score.shape[0]
|
|
||||||
loss_weighting = std ** 2
|
|
||||||
loss = torch.mean(torch.sum((loss_weighting * (estimated_score - target_score) ** 2).view(bs, -1), dim=-1))
|
|
||||||
return loss
|
|
Loading…
x
Reference in New Issue
Block a user