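"""Train GraspNet.

Builds the training dataset and dataloader, constructs the network and the
Adam optimizer, optionally resumes from a checkpoint, then trains with an
exponentially decaying learning rate, logging scalar statistics to
TensorBoard and saving a checkpoint after every epoch.
"""
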
import argparse
import os
import sys
from datetime import datetime

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Make the project submodules importable before the project imports below.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(ROOT_DIR, 'pointnet2'))
sys.path.append(os.path.join(ROOT_DIR, 'utils'))
sys.path.append(os.path.join(ROOT_DIR, 'models'))
sys.path.append(os.path.join(ROOT_DIR, 'dataset'))

from models.graspnet import GraspNet
from models.loss import get_loss
from dataset.graspnet_dataset import GraspNetDataset, minkowski_collate_fn, load_grasp_labels

parser = argparse.ArgumentParser()
parser.add_argument('--dataset_root', default=None, required=True, help='Dataset root directory')
parser.add_argument('--camera', default='kinect', help='Camera split [realsense/kinect]')
parser.add_argument('--checkpoint_path', help='Model checkpoint path', default=None)
parser.add_argument('--model_name', type=str, default=None, required=True, help='Prefix for saved checkpoint files')
parser.add_argument('--log_dir', default='logs/log', help='Directory for logs and checkpoints')
parser.add_argument('--num_point', type=int, default=15000, help='Point number [default: 15000]')
parser.add_argument('--seed_feat_dim', default=512, type=int, help='Point-wise feature dim')
parser.add_argument('--voxel_size', type=float, default=0.005, help='Voxel size for processing point clouds')
parser.add_argument('--max_epoch', type=int, default=10, help='Epochs to run [default: 10]')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size during training [default: 4]')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
parser.add_argument('--resume', action='store_true', default=False, help='Whether to resume from checkpoint')
cfgs = parser.parse_args()
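
# Example invocation (assumes this file is saved as train.py; the paths and
# the model name are illustrative, not fixed by the project):
#   python train.py --dataset_root /path/to/graspnet --camera kinect \
#       --model_name graspnet_kn --log_dir logs/log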

# ------------------------------------------------------------------------- GLOBAL CONFIG BEG
EPOCH_CNT = 0
# Resume only when a checkpoint path was supplied AND --resume is set.
CHECKPOINT_PATH = cfgs.checkpoint_path if cfgs.checkpoint_path is not None and cfgs.resume else None
os.makedirs(cfgs.log_dir, exist_ok=True)

LOG_FOUT = open(os.path.join(cfgs.log_dir, 'log_train.txt'), 'a')
LOG_FOUT.write(str(cfgs) + '\n')


def log_string(out_str):
    # Write to the persistent run log and echo to stdout.
    LOG_FOUT.write(out_str + '\n')
    LOG_FOUT.flush()
    print(out_str)


# Init datasets and dataloaders
def my_worker_init_fn(worker_id):
    # Give each dataloader worker a distinct NumPy seed so random
    # augmentation does not repeat across workers.
    np.random.seed(np.random.get_state()[1][0] + worker_id)


grasp_labels = load_grasp_labels(cfgs.dataset_root)
TRAIN_DATASET = GraspNetDataset(cfgs.dataset_root, grasp_labels=grasp_labels, camera=cfgs.camera, split='train',
                                num_points=cfgs.num_point, voxel_size=cfgs.voxel_size,
                                remove_outlier=True, augment=True, load_label=True)
print('train dataset length: ', len(TRAIN_DATASET))
TRAIN_DATALOADER = DataLoader(TRAIN_DATASET, batch_size=cfgs.batch_size, shuffle=True,
                              num_workers=0, worker_init_fn=my_worker_init_fn, collate_fn=minkowski_collate_fn)
print('train dataloader length: ', len(TRAIN_DATALOADER))

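# NOTE: with num_workers=0 all loading happens in the main process and
# worker_init_fn is never invoked; raising num_workers parallelizes loading,
# and my_worker_init_fn then seeds each worker differently.
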
net = GraspNet(seed_feat_dim=cfgs.seed_feat_dim, is_training=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)
# Load the Adam optimizer
optimizer = optim.Adam(net.parameters(), lr=cfgs.learning_rate)
start_epoch = 0
if CHECKPOINT_PATH is not None and os.path.isfile(CHECKPOINT_PATH):
    # map_location keeps resuming robust when the checkpoint was saved on a different device.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
    log_string("-> loaded checkpoint %s (epoch: %d)" % (CHECKPOINT_PATH, start_epoch))
# TensorBoard Visualizers
TRAIN_WRITER = SummaryWriter(os.path.join(cfgs.log_dir, 'train'))
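# Training curves can be inspected with: tensorboard --logdir <cfgs.log_dir>

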
def get_current_lr(epoch):
    # Exponential decay: lr = initial_lr * 0.95 ** epoch.
    lr = cfgs.learning_rate
    lr = lr * (0.95 ** epoch)
    return lr

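
# Note: this matches torch.optim.lr_scheduler.ExponentialLR with gamma=0.95
# when stepped once per epoch, but recomputing the rate from the epoch index
# keeps resumed runs on the same schedule without restoring scheduler state.
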

def adjust_learning_rate(optimizer, epoch):
    # Apply the scheduled rate to every parameter group.
    lr = get_current_lr(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def train_one_epoch():
    stat_dict = {}  # collect statistics
    adjust_learning_rate(optimizer, EPOCH_CNT)
    net.train()
    batch_interval = 50
    for batch_idx, batch_data_label in enumerate(tqdm(TRAIN_DATALOADER)):
        # Move every tensor in the batch to the training device; keys
        # containing 'list' hold nested lists of tensors.
        for key in batch_data_label:
            if 'list' in key:
                for i in range(len(batch_data_label[key])):
                    for j in range(len(batch_data_label[key][i])):
                        batch_data_label[key][i][j] = batch_data_label[key][i][j].to(device)
            else:
                batch_data_label[key] = batch_data_label[key].to(device)
        # Forward pass and optimization step. Zeroing gradients after step()
        # is equivalent to zeroing before backward() in this loop.
        end_points = net(batch_data_label)
        loss, end_points = get_loss(end_points)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Accumulate the scalar statistics reported by the loss function.
        for key in end_points:
            if 'loss' in key or 'acc' in key or 'prec' in key or 'recall' in key or 'count' in key:
                if key not in stat_dict:
                    stat_dict[key] = 0
                stat_dict[key] += end_points[key].item()

        # Every batch_interval batches, log running means to TensorBoard
        # (x-axis is the number of samples seen) and reset the accumulators.
        if (batch_idx + 1) % batch_interval == 0:
            log_string(' ----epoch: %03d ---- batch: %03d ----' % (EPOCH_CNT, batch_idx + 1))
            for key in sorted(stat_dict.keys()):
                TRAIN_WRITER.add_scalar(key, stat_dict[key] / batch_interval,
                                        (EPOCH_CNT * len(TRAIN_DATALOADER) + batch_idx) * cfgs.batch_size)
                log_string('mean %s: %f' % (key, stat_dict[key] / batch_interval))
                stat_dict[key] = 0


def train(start_epoch):
    global EPOCH_CNT
    for epoch in range(start_epoch, cfgs.max_epoch):
        EPOCH_CNT = epoch
        log_string('**** EPOCH %03d ****' % epoch)
        log_string('Current learning rate: %f' % (get_current_lr(epoch)))
        log_string(str(datetime.now()))
        # Reset numpy seed.
        # REF: https://github.com/pytorch/pytorch/issues/5059
        np.random.seed()
        train_one_epoch()

        # Save a resumable checkpoint after every epoch.
        save_dict = {'epoch': epoch + 1, 'optimizer_state_dict': optimizer.state_dict(),
                     'model_state_dict': net.state_dict()}
        torch.save(save_dict, os.path.join(cfgs.log_dir, cfgs.model_name + '_epoch' + str(epoch + 1).zfill(2) + '.tar'))

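
# To resume training (illustrative paths; assumes this file is train.py and
# the default --log_dir of logs/log):
#   python train.py --dataset_root /path/to/graspnet --model_name graspnet_kn \
#       --resume --checkpoint_path logs/log/graspnet_kn_epoch05.tar
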
if __name__ == '__main__':
    train(start_epoch)