nbv_grasping/losses/gf_loss.py
2024-10-09 16:13:22 +00:00

12 lines
364 B
Python
Executable File

import torch
def compute_loss(output, data):
    """Compute the std**2-weighted squared error between estimated and target scores.

    Args:
        output: Mapping that provides the tensors ``'estimated_score'``,
            ``'target_score'`` and ``'std'``. The first dimension of
            ``estimated_score`` is treated as the batch dimension.
        data: Unused by this function. Presumably accepted so all loss
            functions share one ``(output, data)`` signature — confirm.

    Returns:
        Scalar tensor. For each batch element, the sum over all remaining
        elements of ``std**2 * (estimated_score - target_score)**2`` is
        computed, and these per-sample sums are averaged over the batch.
    """
    estimated_score = output['estimated_score']
    target_score = output['target_score']
    std = output['std']
    bs = estimated_score.shape[0]

    # NOTE(review): std must broadcast against the score tensors. It is
    # presumably shaped (bs, 1, ..., 1), i.e. one noise level per sample —
    # confirm. A 1-D (bs,) std would broadcast against the LAST dimension
    # instead of the batch dimension.
    # The std**2 factor is presumably the usual sigma^2 weighting of
    # denoising score matching — confirm against the training setup.
    loss_weighting = std ** 2

    weighted_sq_err = loss_weighting * (estimated_score - target_score) ** 2
    # reshape rather than view: elementwise ops keep the memory layout of
    # permuted/non-contiguous inputs. In that case the trailing dims may not
    # be flattenable in place, and .view would raise. reshape copies only
    # when it has to.
    per_sample = torch.sum(weighted_sq_err.reshape(bs, -1), dim=-1)
    loss = torch.mean(per_sample)
    return loss