PyTorchBoot/build/lib/PytorchBoot/utils/tensorboard_util.py
2024-09-13 16:58:34 +08:00

45 lines
1.9 KiB
Python

import torch
import PytorchBoot.namespace as namespace
class TensorboardWriter:
    """Route nested logging dictionaries to a TensorBoard writer.

    All methods are static. ``writer`` must expose ``add_scalar``,
    ``add_scalars`` and ``add_mesh`` (the interface of
    ``torch.utils.tensorboard.SummaryWriter``). Every entry is tagged
    ``f'{panel}/{key}'``.
    """

    @staticmethod
    def write_tensorboard(writer, panel, data_dict, step, simple_scalar=False):
        """Write ``data_dict`` to ``writer`` under the tag prefix ``panel``.

        Args:
            writer: object exposing ``add_scalar`` / ``add_scalars`` /
                ``add_mesh``.
            panel: tag prefix; each entry is written as ``f'{panel}/{key}'``.
            data_dict: if ``simple_scalar`` is true, a flat mapping handled by
                ``write_scalar_tensorboard``. Otherwise a mapping whose optional
                ``namespace.TensorBoard.SCALAR`` / ``IMAGE`` / ``POINT`` entries
                each hold a sub-dict that is routed to the matching writer.
            step: global step recorded with every entry.
            simple_scalar: treat ``data_dict`` itself as the scalar dict.
        """
        if simple_scalar:
            TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step)
            # The whole dict has already been logged as scalars. Looking for the
            # category keys as well would write colliding entries a second time.
            return
        if namespace.TensorBoard.SCALAR in data_dict:
            scalar_data_dict = data_dict[namespace.TensorBoard.SCALAR]
            TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step)
        if namespace.TensorBoard.IMAGE in data_dict:
            image_data_dict = data_dict[namespace.TensorBoard.IMAGE]
            TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step)
        if namespace.TensorBoard.POINT in data_dict:
            point_data_dict = data_dict[namespace.TensorBoard.POINT]
            TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step)

    @staticmethod
    def write_scalar_tensorboard(writer, panel, data_dict, step):
        """Log each entry of ``data_dict`` as a scalar.

        A dict value is written with ``writer.add_scalars``, which groups
        several curves under one tag. Any other value goes to
        ``writer.add_scalar``.
        """
        for key, value in data_dict.items():
            tag = f'{panel}/{key}'
            if isinstance(value, dict):
                writer.add_scalars(tag, value, step)
            else:
                writer.add_scalar(tag, value, step)

    @staticmethod
    def write_image_tensorboard(writer, panel, data_dict, step):
        """Placeholder for image logging; currently does nothing.

        Entries under the IMAGE category are accepted and silently dropped.
        """
        # TODO: not implemented. Image entries are ignored without any warning.
        pass

    @staticmethod
    def write_points_tensorboard(writer, panel, data_dict, step):
        """Log each point cloud in ``data_dict`` as a face-less TensorBoard mesh.

        Each value's last dimension must be 3 (xyz) or 6 (xyz followed by rgb).
        ``add_mesh`` expects ``vertices`` and ``colors`` as separate
        ``(B, N, 3)`` tensors, so the two parts are split here. A 2-D
        ``(N, C)`` value is treated as a single cloud and gets a leading
        batch axis. Clouds without colour channels are logged with all-zero
        (black) colours.

        Raises:
            ValueError: if a value's last dimension is neither 3 nor 6.
        """
        for key, value in data_dict.items():
            # as_tensor returns the same tensor for tensor input (no copy) and
            # also accepts numpy arrays and nested lists.
            points = torch.as_tensor(value)
            if points.shape[-1] not in (3, 6):
                raise ValueError(f'Unexpected value shape: {points.shape}')
            if points.dim() == 2:
                points = points.unsqueeze(0)
            vertices = points[..., :3]
            if points.shape[-1] == 6:
                # NOTE(review): the rgb channels are assumed to already be in the
                # range add_mesh accepts ([0, 1] float or [0, 255] uint8).
                # This is not checked here.
                colors = points[..., 3:]
            else:
                colors = torch.zeros_like(vertices)
            writer.add_mesh(
                f'{panel}/{key}',
                vertices=vertices,
                colors=colors,
                faces=None,
                global_step=step,
            )