81 lines
3.3 KiB
Python
Executable File
81 lines
3.3 KiB
Python
Executable File
import torch.nn as nn
|
|
import torch
|
|
|
|
|
|
def get_loss(end_points):
    """Combine every stage loss into one weighted training objective.

    The component losses are evaluated in a fixed order: objectness,
    graspness, view graspness, grasp score, then grasp width. Each helper
    returns ``(loss, end_points)``; the updated dict is threaded into the
    next helper.

    The overall objective is
    ``objectness + 10*graspness + 100*view + 15*score + 10*width``.
    It is also stored in ``end_points['loss/overall_loss']``.

    Returns:
        Tuple ``(overall_loss, end_points)``.
    """
    obj_term, end_points = compute_objectness_loss(end_points)
    graspness_term, end_points = compute_graspness_loss(end_points)
    view_term, end_points = compute_view_graspness_loss(end_points)
    score_term, end_points = compute_score_loss(end_points)
    width_term, end_points = compute_width_loss(end_points)

    # Fixed stage weights; summed left to right exactly as before.
    overall = (obj_term
               + 10 * graspness_term
               + 100 * view_term
               + 15 * score_term
               + 10 * width_term)

    end_points['loss/overall_loss'] = overall
    return overall, end_points
|
|
|
|
|
|
def compute_objectness_loss(end_points):
    """Cross-entropy objectness loss plus simple classification metrics.

    Reads ``end_points['objectness_score']``: logits with the class
    dimension at dim 1. Reads ``end_points['objectness_label']``: integer
    class targets.

    Writes:
        'loss/stage1_objectness_loss': the mean cross-entropy loss.
        'stage1_objectness_acc': fraction of points whose argmax class
            matches the label.
        'stage1_objectness_prec': accuracy over points predicted as class 1.
        'stage1_objectness_recall': accuracy over points labelled class 1.

    Precision and recall are NaN when their selection is empty.

    Returns:
        Tuple ``(loss, end_points)``.
    """
    logits = end_points['objectness_score']
    target = end_points['objectness_label']

    loss = nn.CrossEntropyLoss(reduction='mean')(logits, target)
    end_points['loss/stage1_objectness_loss'] = loss

    prediction = logits.argmax(dim=1)
    hits = prediction == target.long()
    end_points['stage1_objectness_acc'] = hits.float().mean()
    # Class 1 is treated as the positive class for precision/recall.
    end_points['stage1_objectness_prec'] = hits[prediction == 1].float().mean()
    end_points['stage1_objectness_recall'] = hits[target == 1].float().mean()
    return loss, end_points
|
|
|
|
|
|
def compute_graspness_loss(end_points):
    """Smooth-L1 graspness regression loss over object points only.

    Reads:
        'graspness_score': predictions; dim 1 is squeezed (presumably
            shape (B, 1, N) -- TODO confirm).
        'graspness_label': targets; the last dim is squeezed (presumably
            shape (B, N, 1) -- TODO confirm).
        'objectness_label': cast to bool and used as the loss mask, so only
            points with a non-zero objectness label contribute.

    Writes:
        'loss/stage1_graspness_loss': the masked mean Smooth-L1 loss.
        'stage1_graspness_acc_rank_error': a detached metric. Scores and
            labels are clamped to [0, 0.99], binned into 20 bins of width
            0.05, and the mean absolute bin difference (divided by 20) is
            reported.

    If the mask selects no points, both values are 0 rather than NaN. The
    loss stays attached to the autograd graph, so an all-background batch
    cannot turn the overall training loss into NaN.

    Returns:
        Tuple ``(loss, end_points)``.
    """
    criterion = nn.SmoothL1Loss(reduction='none')
    graspness_score = end_points['graspness_score'].squeeze(1)
    graspness_label = end_points['graspness_label'].squeeze(-1)
    loss_mask = end_points['objectness_label'].bool()

    masked_loss = criterion(graspness_score, graspness_label)[loss_mask]
    # mean() of an empty tensor is NaN; sum() of it is a graph-attached 0.
    loss = masked_loss.mean() if masked_loss.numel() > 0 else masked_loss.sum()

    graspness_score_c = graspness_score.detach().clone()[loss_mask]
    graspness_label_c = graspness_label.detach().clone()[loss_mask]
    graspness_score_c = torch.clamp(graspness_score_c, 0., 0.99)
    graspness_label_c = torch.clamp(graspness_label_c, 0., 0.99)
    # 20 bins of width 0.05; the 0.99 clamp keeps 1.0 out of a 21st bin.
    bin_errors = torch.abs(torch.trunc(graspness_score_c * 20)
                           - torch.trunc(graspness_label_c * 20)) / 20.
    rank_error = bin_errors.mean() if bin_errors.numel() > 0 else bin_errors.sum()
    end_points['stage1_graspness_acc_rank_error'] = rank_error

    end_points['loss/stage1_graspness_loss'] = loss
    return loss, end_points
|
|
|
|
|
|
def compute_view_graspness_loss(end_points):
    """Mean Smooth-L1 loss between predicted and target view graspness.

    Compares ``end_points['view_score']`` against
    ``end_points['batch_grasp_view_graspness']`` and stores the result in
    ``end_points['loss/stage2_view_loss']``.

    Returns:
        Tuple ``(loss, end_points)``.
    """
    smooth_l1 = nn.SmoothL1Loss(reduction='mean')
    loss = smooth_l1(end_points['view_score'],
                     end_points['batch_grasp_view_graspness'])
    end_points['loss/stage2_view_loss'] = loss
    return loss, end_points
|
|
|
|
|
|
def compute_score_loss(end_points):
    """Mean Smooth-L1 loss between predicted and target grasp scores.

    Compares ``end_points['grasp_score_pred']`` against
    ``end_points['batch_grasp_score']`` and stores the result in
    ``end_points['loss/stage3_score_loss']``.

    Returns:
        Tuple ``(loss, end_points)``.
    """
    smooth_l1 = nn.SmoothL1Loss(reduction='mean')
    loss = smooth_l1(end_points['grasp_score_pred'],
                     end_points['batch_grasp_score'])
    end_points['loss/stage3_score_loss'] = loss
    return loss, end_points
|
|
|
|
|
|
def compute_width_loss(end_points):
    """Smooth-L1 grasp-width regression loss over positively scored grasps.

    Reads:
        'grasp_width_pred': predicted widths.
        'batch_grasp_width': target widths, multiplied by 10 before
            comparison. The predictions are therefore in that scaled unit;
            the physical unit is not visible here -- TODO confirm.
        'batch_grasp_score': grasp score labels. Only entries with a score
            > 0 contribute to the loss.

    Writes 'loss/stage3_width_loss'.

    If no entry has a positive score label, the loss is 0 rather than NaN.
    It stays attached to the autograd graph, so such a batch cannot turn the
    overall training loss into NaN.

    Returns:
        Tuple ``(loss, end_points)``.
    """
    criterion = nn.SmoothL1Loss(reduction='none')
    grasp_width_pred = end_points['grasp_width_pred']
    grasp_width_label = end_points['batch_grasp_width'] * 10
    loss = criterion(grasp_width_pred, grasp_width_label)

    grasp_score_label = end_points['batch_grasp_score']
    loss_mask = grasp_score_label > 0
    masked_loss = loss[loss_mask]
    # mean() of an empty tensor is NaN; sum() of it is a graph-attached 0.
    loss = masked_loss.mean() if masked_loss.numel() > 0 else masked_loss.sum()

    end_points['loss/stage3_width_loss'] = loss
    return loss, end_points
|