2024-10-09 16:13:22 +00:00

149 lines
6.2 KiB
Python
Executable File

import os
import sys
import numpy as np
from datetime import datetime
import argparse
import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Directory that holds this script; its bundled sub-packages are appended to
# sys.path (in this order) so the project-local imports below resolve.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.extend(os.path.join(ROOT_DIR, sub_dir) for sub_dir in ('pointnet2', 'utils', 'models', 'dataset'))
from models.graspnet import GraspNet
from models.loss import get_loss
from dataset.graspnet_dataset import GraspNetDataset, minkowski_collate_fn, load_grasp_labels
# Command-line configuration. Help strings use %(default)s so the advertised
# defaults can never drift from the real ones again.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', default=None, required=True, help='GraspNet dataset root directory')
parser.add_argument('--camera', default='kinect', help='Camera split [realsense/kinect]')
parser.add_argument('--checkpoint_path', default=None,
                    help='Model checkpoint path (only used together with --resume)')
parser.add_argument('--model_name', type=str, default=None,
                    help='Filename prefix for the per-epoch checkpoints written to --log_dir')
parser.add_argument('--log_dir', default='logs/log',
                    help='Directory for log_train.txt, TensorBoard events and checkpoints [default: %(default)s]')
parser.add_argument('--num_point', type=int, default=15000, help='Point Number [default: %(default)s]')
parser.add_argument('--seed_feat_dim', default=512, type=int,
                    help='Point-wise feature dim [default: %(default)s]')
parser.add_argument('--voxel_size', type=float, default=0.005,
                    help='Voxel size used to process point clouds [default: %(default)s]')
parser.add_argument('--max_epoch', type=int, default=10, help='Epoch to run [default: %(default)s]')
parser.add_argument('--batch_size', type=int, default=4, help='Batch Size during training [default: %(default)s]')
parser.add_argument('--learning_rate', type=float, default=0.001,
                    help='Initial learning rate [default: %(default)s]')
parser.add_argument('--resume', action='store_true', default=False,
                    help='Resume training from --checkpoint_path')
cfgs = parser.parse_args()
# ------------------------------------------------------------------------- GLOBAL CONFIG BEG
# Epoch currently being trained; train() updates it and train_one_epoch() reads it.
EPOCH_CNT = 0
# A checkpoint is restored only when --resume is set AND --checkpoint_path is given.
CHECKPOINT_PATH = cfgs.checkpoint_path if cfgs.checkpoint_path is not None and cfgs.resume else None
# exist_ok=True avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(cfgs.log_dir, exist_ok=True)
# Append mode: repeated / resumed runs accumulate in the same text log,
# each run starting with a dump of its parsed configuration.
LOG_FOUT = open(os.path.join(cfgs.log_dir, 'log_train.txt'), 'a')
LOG_FOUT.write(str(cfgs) + '\n')
def log_string(out_str):
    """Record ``out_str`` as one line in log_train.txt and echo it to stdout.

    The file is flushed on every call so the on-disk log stays current even if
    training dies part-way through an epoch.
    """
    line = out_str + '\n'
    LOG_FOUT.write(line)
    LOG_FOUT.flush()
    print(out_str)
# Init datasets and dataloaders
def my_worker_init_fn(worker_id):
    """Give each DataLoader worker process its own NumPy RNG seed.

    The seed is the first word of the current NumPy MT19937 key plus
    ``worker_id``. That word can be as large as 2**32 - 1, and
    ``np.random.seed`` rejects values outside [0, 2**32), so the sum is wrapped
    modulo 2**32. Without the wrap, seeding would fail with ValueError.

    Args:
        worker_id: index of the DataLoader worker, supplied by PyTorch.
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % (2 ** 32))
# Training split with outlier removal and augmentation enabled, plus labels.
grasp_labels = load_grasp_labels(cfgs.dataset_root)
TRAIN_DATASET = GraspNetDataset(cfgs.dataset_root, grasp_labels=grasp_labels, camera=cfgs.camera, split='train',
                                num_points=cfgs.num_point, voxel_size=cfgs.voxel_size,
                                remove_outlier=True, augment=True, load_label=True)
# Report sizes through log_string (not bare print) so they are kept in log_train.txt.
log_string('train dataset length: %d' % len(TRAIN_DATASET))
# NOTE: with num_workers=0 loading runs in the main process, so
# my_worker_init_fn is never invoked; it only matters if workers are enabled.
TRAIN_DATALOADER = DataLoader(TRAIN_DATASET, batch_size=cfgs.batch_size, shuffle=True,
                              num_workers=0, worker_init_fn=my_worker_init_fn, collate_fn=minkowski_collate_fn)
log_string('train dataloader length: %d' % len(TRAIN_DATALOADER))
# Build the network in training mode and place it on the first GPU if present.
net = GraspNet(seed_feat_dim=cfgs.seed_feat_dim, is_training=True)
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
net = net.to(device)
# Adam over every network parameter, starting from the configured base rate.
optimizer = optim.Adam(params=net.parameters(), lr=cfgs.learning_rate)
# Optionally restore model + optimizer state and the epoch to continue from.
start_epoch = 0
if CHECKPOINT_PATH is not None:
    if os.path.isfile(CHECKPOINT_PATH):
        # map_location lets a checkpoint written on a GPU machine be restored on
        # a CPU-only one (and vice versa) instead of failing inside torch.load.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # The stored 'epoch' is already the next epoch to run (train() saves epoch + 1).
        start_epoch = checkpoint['epoch']
        log_string("-> loaded checkpoint %s (epoch: %d)" % (CHECKPOINT_PATH, start_epoch))
    else:
        # --resume was requested but the file is missing: say so rather than
        # silently starting over from epoch 0.
        log_string("-> checkpoint %s not found, training from scratch" % CHECKPOINT_PATH)
# TensorBoard visualizer: event files are written under <log_dir>/train.
TRAIN_WRITER = SummaryWriter(log_dir=os.path.join(cfgs.log_dir, 'train'))
def get_current_lr(epoch, decay=0.95, base_lr=None):
    """Return the learning rate for ``epoch`` under exponential decay.

    lr = base_lr * decay ** epoch

    Args:
        epoch: epoch index (train() passes 0-based epochs).
        decay: per-epoch multiplicative factor; the default 0.95 is the
            schedule this script has always used.
        base_lr: starting learning rate; ``None`` means ``cfgs.learning_rate``.

    Returns:
        The learning rate as a float.
    """
    if base_lr is None:
        base_lr = cfgs.learning_rate
    return base_lr * (decay ** epoch)
def adjust_learning_rate(optimizer, epoch):
    """Overwrite the 'lr' of every param group in ``optimizer`` with the scheduled rate for ``epoch``."""
    scheduled = get_current_lr(epoch)
    for group in optimizer.param_groups:
        group.update(lr=scheduled)
def _batch_to_device(batch_data_label):
    """Move every tensor of a collated batch dict onto ``device``, in place.

    Keys containing 'list' hold a list of per-sample sequences of tensors; every
    other key is expected to hold a single tensor. (Structure presumably comes
    from minkowski_collate_fn — confirm against the dataset module.)
    """
    for key, value in batch_data_label.items():
        if 'list' in key:
            for inner in value:
                for j in range(len(inner)):
                    inner[j] = inner[j].to(device)
        else:
            # Re-assigning an existing key while iterating is safe (dict size is unchanged).
            batch_data_label[key] = value.to(device)
    return batch_data_label


def _flush_stats(stat_dict, num_batches, batch_idx):
    """Log the per-batch mean of every accumulated statistic, then reset it to 0.

    Args:
        stat_dict: statistic name -> sum over the last ``num_batches`` batches.
        num_batches: number of batches accumulated since the previous flush.
        batch_idx: 0-based index of the last batch included in the sums.
    """
    log_string(' ----epoch: %03d ---- batch: %03d ----' % (EPOCH_CNT, batch_idx + 1))
    # TensorBoard x-axis is measured in samples: global batch index * batch_size.
    global_step = (EPOCH_CNT * len(TRAIN_DATALOADER) + batch_idx) * cfgs.batch_size
    for key in sorted(stat_dict.keys()):
        mean_value = stat_dict[key] / num_batches
        TRAIN_WRITER.add_scalar(key, mean_value, global_step)
        log_string('mean %s: %f' % (key, mean_value))
        stat_dict[key] = 0


def train_one_epoch():
    """Run one optimization pass over TRAIN_DATALOADER, updating ``net`` in place.

    The learning rate is first set for the global EPOCH_CNT. Every
    ``batch_interval`` batches the per-batch means of all end_points entries
    whose name contains loss/acc/prec/recall/count are written to the text log
    and TensorBoard. Any leftover batches at the end of the epoch are flushed
    too, instead of being silently discarded.
    """
    stat_dict = {}  # statistic name -> running sum since last flush
    adjust_learning_rate(optimizer, EPOCH_CNT)
    net.train()
    batch_interval = 50
    stat_tags = ('loss', 'acc', 'prec', 'recall', 'count')
    batches_since_flush = 0
    batch_idx = -1
    for batch_idx, batch_data_label in enumerate(tqdm(TRAIN_DATALOADER)):
        _batch_to_device(batch_data_label)

        end_points = net(batch_data_label)
        loss, end_points = get_loss(end_points)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        for key in end_points:
            if any(tag in key for tag in stat_tags):
                stat_dict[key] = stat_dict.get(key, 0) + end_points[key].item()
        batches_since_flush += 1

        if (batch_idx + 1) % batch_interval == 0:
            _flush_stats(stat_dict, batches_since_flush, batch_idx)
            batches_since_flush = 0

    # Report the tail of the epoch, averaged over the batches it actually contains.
    if batches_since_flush:
        _flush_stats(stat_dict, batches_since_flush, batch_idx)
def train(start_epoch):
    """Train from ``start_epoch`` up to ``cfgs.max_epoch`` (exclusive), checkpointing each epoch.

    After every epoch the model and optimizer state plus ``epoch + 1`` (the
    epoch to resume from) are saved as ``<model_name>_epochNN.tar`` in
    ``cfgs.log_dir``.

    Args:
        start_epoch: first epoch index to run (0, or the value restored from a checkpoint).
    """
    global EPOCH_CNT
    # --model_name defaults to None; concatenating it used to raise TypeError
    # only after a full epoch had been trained, so fall back to a fixed prefix.
    model_name = cfgs.model_name if cfgs.model_name is not None else 'checkpoint'
    for epoch in range(start_epoch, cfgs.max_epoch):
        EPOCH_CNT = epoch
        log_string('**** EPOCH %03d ****' % epoch)
        log_string('Current learning rate: %f' % (get_current_lr(epoch)))
        log_string(str(datetime.now()))
        # Re-seed NumPy from OS entropy every epoch so data-loading randomness
        # is not replayed identically across epochs.
        # REF: https://github.com/pytorch/pytorch/issues/5059
        np.random.seed()
        train_one_epoch()
        save_dict = {'epoch': epoch + 1, 'optimizer_state_dict': optimizer.state_dict(),
                     'model_state_dict': net.state_dict()}
        save_path = os.path.join(cfgs.log_dir, '%s_epoch%02d.tar' % (model_name, epoch + 1))
        torch.save(save_dict, save_path)
        log_string('Saved checkpoint to %s' % save_path)
if __name__ == '__main__':
    try:
        train(start_epoch)
    finally:
        # Flush buffered TensorBoard events and close the text log even when
        # training raises or is interrupted (e.g. Ctrl-C).
        TRAIN_WRITER.close()
        LOG_FOUT.close()