45 lines
1.9 KiB
Python
45 lines
1.9 KiB
Python
import torch
|
|
import PytorchBoot.namespace as namespace
|
|
|
|
class TensorboardWriter:
    """Static helpers that route dictionaries of logged values to a TensorBoard writer.

    ``writer`` is expected to provide the ``torch.utils.tensorboard.SummaryWriter``
    methods used below (``add_scalar``, ``add_scalars``, ``add_mesh``). Every
    entry is logged under the tag ``f"{panel}/{key}"``.
    """

    @staticmethod
    def write_tensorboard(writer, panel, data_dict, step, simple_scalar=False):
        """Write all recognised sections of ``data_dict`` to ``writer``.

        Args:
            writer: TensorBoard writer (e.g. ``SummaryWriter``).
            panel: Tag prefix; each entry is logged as ``f"{panel}/{key}"``.
            data_dict: When ``simple_scalar`` is true, a flat ``{key: value}``
                mapping of scalars (or ``{key: dict}`` for ``add_scalars``).
                Otherwise a mapping whose ``namespace.TensorBoard.SCALAR``,
                ``namespace.TensorBoard.IMAGE`` and ``namespace.TensorBoard.POINT``
                entries (each optional) hold the per-kind sub-dictionaries.
            step: Global step recorded with every value.
            simple_scalar: Treat ``data_dict`` as a flat scalar mapping.
        """
        if simple_scalar:
            # A flat scalar dict has no structured sections. Stop here so its keys
            # are never re-interpreted as SCALAR/IMAGE/POINT sections.
            TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step)
            return

        if namespace.TensorBoard.SCALAR in data_dict:
            TensorboardWriter.write_scalar_tensorboard(
                writer, panel, data_dict[namespace.TensorBoard.SCALAR], step
            )

        if namespace.TensorBoard.IMAGE in data_dict:
            TensorboardWriter.write_image_tensorboard(
                writer, panel, data_dict[namespace.TensorBoard.IMAGE], step
            )

        if namespace.TensorBoard.POINT in data_dict:
            TensorboardWriter.write_points_tensorboard(
                writer, panel, data_dict[namespace.TensorBoard.POINT], step
            )

    @staticmethod
    def write_scalar_tensorboard(writer, panel, data_dict, step):
        """Log each entry of ``data_dict`` as a scalar.

        A dict value is logged as a group with ``writer.add_scalars``. Any other
        value goes to ``writer.add_scalar``.
        """
        for key, value in data_dict.items():
            tag = f'{panel}/{key}'
            if isinstance(value, dict):
                writer.add_scalars(tag, value, step)
            else:
                writer.add_scalar(tag, value, step)

    @staticmethod
    def write_image_tensorboard(writer, panel, data_dict, step):
        """Placeholder for image logging: currently a no-op.

        Images passed under ``namespace.TensorBoard.IMAGE`` are silently dropped.
        TODO(review): implement with ``writer.add_image``/``add_images`` once the
        expected image layout (CHW vs HWC, batched or not) is confirmed.
        """
        pass

    @staticmethod
    def write_points_tensorboard(writer, panel, data_dict, step):
        """Log each point cloud in ``data_dict`` with ``writer.add_mesh``.

        Each value must have a last dimension of 3 (``xyz``) or 6 (``xyz``
        followed by per-point colours). ``add_mesh`` expects ``(B, N, 3)`` tensors,
        so a 2-D ``(N, C)`` value is given a leading batch axis. Colours are passed
        through unchanged; presumably float in [0, 1] or uint8 in [0, 255], as
        ``add_mesh`` documents. Confirm this against the callers.

        Raises:
            ValueError: if a value's last dimension is neither 3 nor 6.
        """
        for key, value in data_dict.items():
            points = torch.as_tensor(value)
            if points.shape[-1] not in (3, 6):
                raise ValueError(f'Unexpected value shape: {points.shape}')
            if points.dim() == 2:
                points = points.unsqueeze(0)
            # add_mesh takes coordinates and colours as separate (B, N, 3) tensors.
            # Packing colours into the vertex tensor would corrupt the geometry.
            vertices = points[..., :3]
            colors = points[..., 3:] if points.shape[-1] == 6 else None
            writer.add_mesh(
                f'{panel}/{key}',
                vertices=vertices,
                colors=colors,
                faces=None,
                global_step=step,
            )