remove build
This commit is contained in:
parent
b06a9ecee0
commit
7546b34515
162
.gitignore
vendored
Normal file
@ -0,0 +1,162 @@
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
9
LICENSE
Normal file
@ -0,0 +1,9 @@
MIT License

Copyright (c) 2024 hofee

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
@ -1,21 +0,0 @@
from PytorchBoot.utils.log_util import Log

application_class = {}

def PytorchBootApplication(arg=None):
    if callable(arg):
        cls = arg
        if "default" in application_class:
            Log.error("Multiple classes annotated with default @PytorchBootApplication, require a 'name' parameter.", True)
        application_class["default"] = cls
        return cls

    else:
        name = arg
        def decorator(cls):
            if name is None:
                raise Log.error("The 'name' parameter is required when using @PytorchBootApplication with arguments.", True)
            if name in application_class:
                raise Log.error(f"Multiple classes annotated with @PytorchBootApplication with the same name '{name}' found.", True)
            application_class[name] = cls
            return cls
        return decorator
@ -1,108 +0,0 @@
import os
import sys

from PytorchBoot.application import application_class
from PytorchBoot.stereotype import get_all_component_classes, get_all_component_comments
from PytorchBoot.utils.log_util import Log
from PytorchBoot.utils.timer_util import Timer
from PytorchBoot.utils.project_util import ProjectUtil
from PytorchBoot.templates.application import template as app_template
from PytorchBoot.templates.config import template as config_template
from PytorchBoot.ui.server.app import app


def run():
    root_path = os.getcwd()
    ProjectUtil.scan_project(root_path)

    app_name = "default"
    if len(application_class) == 0:
        Log.error("No class annotated with @PytorchBootApplication found.", True)
    if len(sys.argv) < 3 and "default" not in application_class:
        Log.error("No default @PytorchBootApplication found. Please specify the 'name' parameter.", True)
    if len(sys.argv) == 3:
        app_name = sys.argv[2]

    app_cls = application_class.get(app_name)

    if app_cls is None:
        Log.error(f"No class annotated with @PytorchBootApplication found with the name '{app_name}'.", True)

    if not hasattr(app_cls, "start"):
        Log.error("The class annotated with @PytorchBootApplication should have a 'start' method.", True)

    Log.info(f"Application '{app_cls.__name__}' started.")
    timer = Timer("Application")

    timer.start()
    app_cls.start()
    timer.stop()
    Log.info(timer.get_elasped_time_str(Timer.HOURS))
    Log.success("Application finished.")


def init():
    Log.info("Initializing PytorchBoot project.")
    root_path = os.getcwd()
    if len(os.listdir(root_path)) > 0:
        Log.error("Current directory is not empty. Please provide an empty directory.")
    else:
        with open(os.path.join(root_path, "application.py"), "w") as file:
            file.write(app_template)
        with open(os.path.join(root_path, "config.yaml"), "w") as file:
            file.write(config_template)

        Log.success("PytorchBoot project initialized.")
        Log.info("Now you can create your components and run the application.")

def scan():
    root_path = os.getcwd()
    ProjectUtil.scan_project(root_path)
    comments = get_all_component_comments()
    Log.info("Components detected in the project:")
    for stereotype, classes in get_all_component_classes().items():
        Log.info(f" {stereotype}:")
        for name, cls in classes.items():
            comment = comments[stereotype].get(name)
            if comment is not None:
                Log.warning(f" - {name}: {cls.__module__}.{cls.__name__} ({comment})")
            else:
                Log.success(f" - {name}: {cls.__module__}.{cls.__name__}")

    Log.info("Applications detected in the project:")
    for app_name, app_cls in application_class.items():
        Log.success(f" - {app_name}: {app_cls.__module__}.{app_cls.__name__}")
    Log.success("Scan completed.")

def ui():
    port = 5000
    if len(sys.argv) == 3:
        port = int(sys.argv[2])
    Log.success(f"PytorchBoot UI server started at http://localhost:{port}")
    app.run(port=port, host="0.0.0.0")


def help():
    Log.info("PytorchBoot commands:")
    Log.info(" init: Initialize a new PytorchBoot project in the current directory.")
    Log.info(" run [name]: Run the PytorchBoot application with the specified name. If no name is provided, the default application will be run.")
    Log.info(" scan: Scan the project for PytorchBoot components.")
    Log.info(" ui [port]: Start the PytorchBoot UI server. If no port is provided, the default port 5000 will be used.")
    Log.info(" help: Display this help message.")

def main():
    if len(sys.argv) > 1:
        if sys.argv[1] == "init":
            init()
        elif sys.argv[1] == "run":
            run()
        elif sys.argv[1] == "scan":
            scan()
        elif sys.argv[1] == "ui":
            ui()
        elif sys.argv[1] == "help":
            help()
        else:
            Log.error("Invalid command: " + sys.argv[1] + ". Use 'pytorch-boot help' for help.")
    else:
        Log.error("Please provide a command to run the application.")
@ -1,21 +0,0 @@
from PytorchBoot.utils.log_util import Log

class Component:
    TYPE: str
    NAME: str

    def get_name(self):
        return self.NAME

    def get_type(self):
        return self.TYPE

    def get_config(self):
        return self.config

    def print(self):
        Log.blue("Component Information")
        Log.blue(f"- Type: {self.TYPE}")
        Log.blue(f"- Name: {self.NAME}")
        Log.blue(f"- Config: \n\t{self.config}")
@ -1,59 +0,0 @@
import argparse
import os.path
import shutil
import yaml
from PytorchBoot.utils.log_util import Log

class ConfigManager:
    config = None
    config_path = None

    @staticmethod
    def get(*args):
        result = ConfigManager.config
        for arg in args:
            result = result[arg]
        return result

    @staticmethod
    def load_config_with(config_file_path):
        ConfigManager.config_path = config_file_path
        if not os.path.exists(ConfigManager.config_path):
            raise ValueError(f"Config file <{config_file_path}> does not exist")
        with open(config_file_path, 'r') as file:
            ConfigManager.config = yaml.safe_load(file)

    @staticmethod
    def backup_config_to(target_config_dir, file_name, prefix="config"):
        file_name = f"__{prefix}_{file_name}.yaml"
        target_config_file_path = str(os.path.join(target_config_dir, file_name))
        shutil.copy(ConfigManager.config_path, target_config_file_path)

    @staticmethod
    def load_config():
        parser = argparse.ArgumentParser()
        parser.add_argument('--config', type=str, default='', help='config file path')
        args = parser.parse_args()
        if args.config:
            ConfigManager.load_config_with(args.config)

    @staticmethod
    def print_config(key: str = None, group: dict = None, level=0):
        table_size = 80
        if key and group:
            value = group[key]
            if type(value) is dict:
                Log.blue("\t" * level + f"+-{key}:")
                for k in value:
                    ConfigManager.print_config(k, value, level=level + 1)
            else:
                Log.blue("\t" * level + f"| {key}: {value}")
        elif key:
            ConfigManager.print_config(key, ConfigManager.config, level=level)
        else:
            Log.blue("+" + "-" * table_size + "+")
            Log.blue(f"| Configurations in <{ConfigManager.config_path}>:")
            Log.blue("+" + "-" * table_size + "+")
            for key in ConfigManager.config:
                ConfigManager.print_config(key, level=level + 1)
            Log.blue("+" + "-" * table_size + "+")
@ -1 +0,0 @@
from PytorchBoot.dataset.base_dataset import BaseDataset
@ -1,44 +0,0 @@
from abc import ABC

import numpy as np

from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Subset

from PytorchBoot.component import Component

class BaseDataset(ABC, Dataset, Component):
    def __init__(self, config):
        super(BaseDataset, self).__init__()
        self.config = config

    @staticmethod
    def process_batch(batch, device):
        for key in batch.keys():
            if isinstance(batch[key], list):
                continue
            batch[key] = batch[key].to(device)
        return batch

    def get_collate_fn(self):
        return None

    def get_loader(self, shuffle=False):
        ratio = self.config["ratio"]
        if ratio > 1 or ratio <= 0:
            raise ValueError(
                f"dataset ratio should be between (0,1], found {ratio} in {self.config['name']}"
            )
        subset_size = max(1, int(len(self) * ratio))
        indices = np.random.permutation(len(self))[:subset_size]
        subset = Subset(self, indices)
        return DataLoader(
            subset,
            batch_size=self.config["batch_size"],
            num_workers=self.config["num_workers"],
            shuffle=shuffle,
            collate_fn=self.get_collate_fn(),
        )
@ -1,2 +0,0 @@
from PytorchBoot.factory.component_factory import ComponentFactory
from PytorchBoot.factory.optimizer_factory import OptimizerFactory
@ -1,27 +0,0 @@
from PytorchBoot.component import Component
from PytorchBoot.stereotype import *
from PytorchBoot.utils.log_util import Log
from PytorchBoot.config import ConfigManager

class ComponentFactory:
    @staticmethod
    def create(component_type: str, name: str) -> Component:
        component_classes = get_component_classes(component_type=component_type)
        if component_classes is None:
            Log.error(f"Unsupported component type: {component_type}", True)

        if component_type == namespace.Stereotype.DATASET:
            config = ConfigManager.get(component_type, name)
            cls = dataset_classes[config["source"]]
            dataset_obj = cls(config)
            dataset_obj.NAME = name
            dataset_obj.TYPE = component_type
            return dataset_obj

        if name not in component_classes:
            Log.error(f"Unsupported component name: {name}", True)

        cls = component_classes[name]
        config = ConfigManager.get(component_type, name)
        return cls(config)
@ -1,67 +0,0 @@
import torch.optim as optim

class OptimizerFactory:
    @staticmethod
    def create(config: dict, params) -> optim.Optimizer:
        optim_type = config["type"]
        lr = config.get("lr", 1e-3)

        if optim_type == "SGD":
            return optim.SGD(
                params,
                lr=lr,
                momentum=config.get("momentum", 0.9),
                weight_decay=config.get("weight_decay", 1e-4),
            )
        elif optim_type == "Adam":
            return optim.Adam(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
            )
        elif optim_type == "AdamW":
            return optim.AdamW(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 1e-2),
            )
        elif optim_type == "RMSprop":
            return optim.RMSprop(
                params,
                lr=lr,
                alpha=config.get("alpha", 0.99),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 1e-4),
                momentum=config.get("momentum", 0.9),
            )
        elif optim_type == "Adagrad":
            return optim.Adagrad(
                params,
                lr=lr,
                lr_decay=config.get("lr_decay", 0),
                weight_decay=config.get("weight_decay", 0),
            )
        elif optim_type == "Adamax":
            return optim.Adamax(
                params,
                lr=lr,
                betas=config.get("betas", (0.9, 0.999)),
                eps=config.get("eps", 1e-8),
                weight_decay=config.get("weight_decay", 0),
            )
        elif optim_type == "LBFGS":
            return optim.LBFGS(
                params,
                lr=lr,
                max_iter=config.get("max_iter", 20),
                max_eval=config.get("max_eval", None),
                tolerance_grad=config.get("tolerance_grad", 1e-7),
                tolerance_change=config.get("tolerance_change", 1e-9),
                history_size=config.get("history_size", 100),
                line_search_fn=config.get("line_search_fn", None),
            )
        else:
            raise NotImplementedError("Unknown optimizer: {}".format(optim_type))
@ -1,33 +0,0 @@
class Stereotype:
    DATASET: str = "dataset"
    MODULE: str = "module"
    PIPELINE: str = "pipeline"
    RUNNER: str = "runner"
    FACTORY: str = "factory"
    EVALUATION_METHOD: str = "evaluation_method"
    LOSS_FUNCTION: str = "loss_function"

class Mode:
    TRAIN: str = "train"
    TEST: str = "test"
    EVALUATION: str = "evaluation"

class Direcotry:
    CHECKPOINT_DIR_NAME: str = 'checkpoints'
    TENSORBOARD_DIR_NAME: str = 'tensorboard'
    LOG_DIR_NAME: str = 'log'
    RESULT_DIR_NAME: str = 'results'

class TensorBoard:
    SCALAR: str = "scalar"
    IMAGE: str = "image"
    POINT: str = "point"

class LogType:
    INFO: str = "info"
    ERROR: str = "error"
    WARNING: str = "warning"
    SUCCESS: str = "success"
    DEBUG: str = "debug"
    TERMINATE: str = "terminate"
@ -1,4 +0,0 @@
from PytorchBoot.runners.trainer import DefaultTrainer
from PytorchBoot.runners.evaluator import DefaultEvaluator
from PytorchBoot.runners.predictor import DefaultPredictor
from PytorchBoot.runners.runner import Runner
@ -1,132 +0,0 @@
import os
import json

import torch
from tqdm import tqdm

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory

from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log

@stereotype.runner("default_evaluator")
class DefaultEvaluator(Runner):
    def __init__(self, config_path):
        super().__init__(config_path)
        ''' Pipeline '''
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        self.pipeline: torch.nn.Module = self.pipeline.to(self.device)

        ''' Experiment '''
        self.model_path = self.config["experiment"]["model_path"]
        self.load_experiment("default_evaluator")

        ''' Test '''
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
        self.test_dataset_name_list = self.test_config["dataset_list"]
        self.test_set_list = []
        self.test_writer_list = []
        seen_name = set()
        for test_dataset_name in self.test_dataset_name_list:
            if test_dataset_name not in seen_name:
                seen_name.add(test_dataset_name)
            else:
                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
            self.test_set_list.append(test_set)

        self.print_info()

    def run(self):
        eval_result = self.test()
        self.save_eval_result(eval_result)

    def test(self):
        self.pipeline.eval()
        eval_result = {}
        with torch.no_grad():
            test_set: BaseDataset
            for dataset_idx, test_set in enumerate(self.test_set_list):
                test_set_config = test_set.get_config()
                eval_list = test_set_config["eval_list"]
                ratio = test_set_config["ratio"]
                test_set_name = test_set.get_name()
                output_list = []
                data_list = []
                test_loader = test_set.get_loader()
                loop = tqdm(enumerate(test_loader), total=int(len(test_loader)))
                for _, data in loop:
                    test_set.process_batch(data, self.device)
                    data["mode"] = namespace.Mode.TEST
                    output = self.pipeline(data)
                    output_list.append(output)
                    data_list.append(data)
                    loop.set_description(
                        f'Evaluating [{dataset_idx+1}/{len(self.test_set_list)}] (Test: {test_set_name}, ratio={ratio})')
                result_dict = self.eval_fn(output_list, data_list, eval_list)
                eval_result[test_set_name] = result_dict
        return eval_result

    def save_eval_result(self, eval_result):
        result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
        eval_result_path = os.path.join(result_dir, self.file_name + "_eval_result.json")
        with open(eval_result_path, "w") as f:
            json.dump(eval_result, f, indent=4)
        Log.success(f"Saved evaluation result to {eval_result_path}")

    @staticmethod
    def eval_fn(output_list, data_list, eval_list):
        collected_result = {}
        for eval_method_name in eval_list:
            eval_method = ComponentFactory.create(namespace.Stereotype.EVALUATION_METHOD, eval_method_name)
            eval_results: dict = eval_method.evaluate(output_list, data_list)
            for data_type, eval_result in eval_results.items():
                if data_type not in collected_result:
                    collected_result[data_type] = {}
                for name, value in eval_result.items():
                    collected_result[data_type][name] = value

        return collected_result

    def load_checkpoint(self):
        self.load(self.model_path)
        Log.success(f"Loaded checkpoint from {self.model_path}")

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
        self.load_checkpoint()

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)
        result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
        os.makedirs(result_dir)

    def load(self, path):
        state_dict = torch.load(path)
        self.pipeline.load_state_dict(state_dict)

    def print_info(self):
        def print_dataset(dataset: BaseDataset):
            config = dataset.get_config()
            name = dataset.get_name()
            Log.blue(f"Dataset: {name}")
            for k, v in config.items():
                Log.blue(f"\t{k}: {v}")

        super().print_info()
        table_size = 70
        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
        Log.blue(self.pipeline)
        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
        for i, test_set in enumerate(self.test_set_list):
            Log.blue(f"test dataset {i}: ")
            print_dataset(test_set)

        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
@ -1,128 +0,0 @@
import os
import json

import torch
from tqdm import tqdm

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory

from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils import Log

@stereotype.runner("default_predictor")
class DefaultPredictor(Runner):
    def __init__(self, config_path):
        super().__init__(config_path)
        ''' Pipeline '''
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.pipeline = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        self.pipeline: torch.nn.Module = self.pipeline.to(self.device)

        ''' Experiment '''
        self.model_path = self.config["experiment"]["model_path"]
        self.load_experiment("default_predictor")
        self.save_original_data = self.config["experiment"]["save_original_data"]

        ''' Testset '''
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
        self.test_dataset_name_list = self.test_config["dataset_list"]
        self.test_set_list = []
        self.test_writer_list = []
        seen_name = set()
        for test_dataset_name in self.test_dataset_name_list:
            if test_dataset_name not in seen_name:
                seen_name.add(test_dataset_name)
            else:
                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
            self.test_set_list.append(test_set)

        self.print_info()

    def run(self):
        predict_result = self.predict()
        self.save_predict_result(predict_result)

    def predict(self):
        self.pipeline.eval()
        predict_result = {}
        with torch.no_grad():
            test_set: BaseDataset
            for dataset_idx, test_set in enumerate(self.test_set_list):
                test_set_config = test_set.get_config()
                ratio = test_set_config["ratio"]
                test_set_name = test_set.get_name()
                output_list = []
                data_list = []
                test_loader = test_set.get_loader()
                loop = tqdm(enumerate(test_loader), total=int(len(test_loader)))
                for _, data in loop:
                    test_set.process_batch(data, self.device)
                    data["mode"] = namespace.Mode.TEST
                    output = self.pipeline(data)
                    output_list.append(output)
                    data_list.append(data)
                    loop.set_description(
                        f'Predicting [{dataset_idx+1}/{len(self.test_set_list)}] (Test: {test_set_name}, ratio={ratio})')
                predict_result[test_set_name] = {
                    "output": output_list,
                    "data": data_list
                }
        return predict_result

    def save_predict_result(self, predict_result):
        result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME, self.file_name+"_predict_result")
        os.makedirs(result_dir)
        for test_set_name in predict_result.keys():
            os.mkdir(os.path.join(result_dir, test_set_name))
            idx = 0
            for output, data in zip(predict_result[test_set_name]["output"], predict_result[test_set_name]["data"]):
                output_path = os.path.join(result_dir, test_set_name, f"output_{idx}.pth")
                torch.save(output, output_path)
                if self.save_original_data:
                    data_path = os.path.join(result_dir, test_set_name, f"data_{idx}.pth")
                    torch.save(data, data_path)
                idx += 1
            Log.success(f"Saved predict result of {test_set_name} to {result_dir}")
        Log.success(f"Saved all predict result to {result_dir}")

    def load_checkpoint(self):
        self.load(self.model_path)
        Log.success(f"Loaded checkpoint from {self.model_path}")

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
        self.load_checkpoint()

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)
        result_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.RESULT_DIR_NAME)
        os.makedirs(result_dir)

    def load(self, path):
        state_dict = torch.load(path)
        self.pipeline.load_state_dict(state_dict)

    def print_info(self):
        def print_dataset(dataset: BaseDataset):
            config = dataset.get_config()
            name = dataset.get_name()
            Log.blue(f"Dataset: {name}")
            for k, v in config.items():
                Log.blue(f"\t{k}: {v}")

        super().print_info()
        table_size = 70
        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
        Log.blue(self.pipeline)
        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
        for i, test_set in enumerate(self.test_set_list):
            Log.blue(f"test dataset {i}: ")
            print_dataset(test_set)

        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
@ -1,61 +0,0 @@
import os
import time
from abc import abstractmethod, ABC

import numpy as np
import torch

from PytorchBoot.config import ConfigManager
from PytorchBoot.utils.log_util import Log

class Runner(ABC):

    @abstractmethod
    def __init__(self, config_path):
        ConfigManager.load_config_with(config_path)
        ConfigManager.print_config()
        self.config = ConfigManager.get("runner")
        self.seed = self.config["general"]["seed"]
        self.device = self.config["general"]["device"]
        self.cuda_visible_devices = self.config["general"]["cuda_visible_devices"]
        os.environ["CUDA_VISIBLE_DEVICES"] = self.cuda_visible_devices
        self.experiments_config = self.config["experiment"]
        self.experiment_path = os.path.join(self.experiments_config["root_dir"], self.experiments_config["name"])
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        lt = time.localtime()
        self.file_name = f"{lt.tm_year}_{lt.tm_mon}_{lt.tm_mday}_{lt.tm_hour}h{lt.tm_min}m{lt.tm_sec}s"

    @abstractmethod
    def run(self):
        pass

    @abstractmethod
    def load_experiment(self, backup_name=None):
        if not os.path.exists(self.experiment_path):
            Log.info(f"experiments environment {self.experiments_config['name']} does not exists.")
            self.create_experiment(backup_name)
        else:
            Log.info(f"experiments environment {self.experiments_config['name']}")
            backup_config_dir = os.path.join(str(self.experiment_path), "configs")
            if not os.path.exists(backup_config_dir):
                os.makedirs(backup_config_dir)
            ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)

    @abstractmethod
    def create_experiment(self, backup_name=None):
        Log.info("creating experiment: " + self.experiments_config["name"])
        os.makedirs(self.experiment_path)
        backup_config_dir = os.path.join(str(self.experiment_path), "configs")
        os.makedirs(backup_config_dir)
        ConfigManager.backup_config_to(backup_config_dir, self.file_name, backup_name)
        log_dir = os.path.join(str(self.experiment_path), "log")
        os.makedirs(log_dir)
        cache_dir = os.path.join(str(self.experiment_path), "cache")
        os.makedirs(cache_dir)

    def print_info(self):
        table_size = 80
        Log.blue("+" + "-" * table_size + "+")
        Log.blue(f"| Experiment <{self.experiments_config['name']}>")
        Log.blue("+" + "-" * table_size + "+")
@ -1,266 +0,0 @@
import os
import json
from datetime import datetime

import torch
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from PytorchBoot.config import ConfigManager
import PytorchBoot.namespace as namespace
import PytorchBoot.stereotype as stereotype
from PytorchBoot.factory import ComponentFactory
from PytorchBoot.factory import OptimizerFactory

from PytorchBoot.dataset import BaseDataset
from PytorchBoot.runners.runner import Runner
from PytorchBoot.utils.tensorboard_util import TensorboardWriter
from PytorchBoot.stereotype import EXTERNAL_FRONZEN_MODULES
from PytorchBoot.utils import Log
from PytorchBoot.status import status_manager

@stereotype.runner("default_trainer")
class DefaultTrainer(Runner):
    def __init__(self, config_path):
        super().__init__(config_path)
        tensorboard_path = os.path.join(self.experiment_path, namespace.Direcotry.TENSORBOARD_DIR_NAME)

        ''' Pipeline '''
        self.pipeline_name = self.config[namespace.Stereotype.PIPELINE]
        self.parallel = self.config["general"]["parallel"]
        self.pipeline: torch.nn.Module = ComponentFactory.create(namespace.Stereotype.PIPELINE, self.pipeline_name)
        if self.parallel and self.device == "cuda":
            self.pipeline = torch.nn.DataParallel(self.pipeline)
        self.pipeline = self.pipeline.to(self.device)

        ''' Experiment '''
        self.current_epoch = 0
        self.current_iter = 0
        self.max_epochs = self.experiments_config["max_epochs"]
        self.test_first = self.experiments_config["test_first"]
        self.load_experiment("default_trainer")

        ''' Train '''
        self.train_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TRAIN)
        self.train_dataset_name = self.train_config["dataset"]
        self.train_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, self.train_dataset_name)
        self.optimizer = OptimizerFactory.create(self.train_config["optimizer"], self.pipeline.parameters())
        self.train_writer = SummaryWriter(
            log_dir=os.path.join(tensorboard_path, f"[{namespace.Mode.TRAIN}]{self.train_dataset_name}"))

        ''' Test '''
        self.test_config = ConfigManager.get(namespace.Stereotype.RUNNER, namespace.Mode.TEST)
        self.test_dataset_name_list = self.test_config["dataset_list"]
        self.test_set_list = []
        self.test_writer_list = []
        seen_name = set()
        for test_dataset_name in self.test_dataset_name_list:
            if test_dataset_name not in seen_name:
                seen_name.add(test_dataset_name)
            else:
                raise ValueError("Duplicate test dataset name: {}".format(test_dataset_name))
            test_set: BaseDataset = ComponentFactory.create(namespace.Stereotype.DATASET, test_dataset_name)
            test_writer = SummaryWriter(
                log_dir=os.path.join(tensorboard_path, f"[test]{test_dataset_name}"))
            self.test_set_list.append(test_set)
            self.test_writer_list.append(test_writer)

        self.print_info()

    def run(self):
        save_interval = self.experiments_config["save_checkpoint_interval"]
        if self.current_epoch != 0:
            Log.info("Continue training from epoch {}.".format(self.current_epoch))
        else:
            Log.info("Start training from initial model.")
        if self.test_first:
            Log.info("Do test first.")
            self.test()
        while self.current_epoch < self.max_epochs:
            self.current_epoch += 1
            status_manager.set_progress("train", "default_trainer", "Epoch", self.current_epoch, self.max_epochs)
            self.train()
            self.test()
            if self.current_epoch % save_interval == 0:
                self.save_checkpoint()
            self.save_checkpoint(is_last=True)

    def train(self):
        self.pipeline.train()
        train_set_name = self.train_dataset_name
        config = self.train_set.get_config()
        train_loader = self.train_set.get_loader(shuffle=True)

        total = len(train_loader)
        loop = tqdm(enumerate(train_loader), total=total)

        for i, data in loop:
            status_manager.set_progress("train", "default_trainer", f"(train) Batch[{train_set_name}]", i+1, total)
            self.train_set.process_batch(data, self.device)
            loss_dict = self.train_step(data)
            loop.set_description(
                f'Epoch [{self.current_epoch}/{self.max_epochs}] (Train: {train_set_name}, ratio={config["ratio"]})')
            loop.set_postfix(loss=loss_dict)
            for loss_name, loss in loss_dict.items():
                status_manager.set_status("train", "default_trainer", f"[loss]{loss_name}", loss)
            TensorboardWriter.write_tensorboard(self.train_writer, "iter", loss_dict, self.current_iter, simple_scalar=True)
            self.current_iter += 1

    def train_step(self, data):
        self.optimizer.zero_grad()
        data["mode"] = namespace.Mode.TRAIN
        output = self.pipeline(data)
        total_loss, loss_dict = self.loss_fn(output, data)
        total_loss.backward()
        self.optimizer.step()
        for k, v in loss_dict.items():
            loss_dict[k] = round(v, 5)
        return loss_dict

    def loss_fn(self, output, data):
        loss_name_list = self.train_config["losses"]
        loss_dict = {}
        total_loss = torch.tensor(0.0, dtype=torch.float32, device=self.device)
        for loss_name in loss_name_list:
            target_loss_fn = ComponentFactory.create(namespace.Stereotype.LOSS_FUNCTION, loss_name)
            loss = target_loss_fn.compute(output, data)
            loss_dict[loss_name] = loss.item()
            total_loss += loss

        loss_dict['total_loss'] = total_loss.item()
        return total_loss, loss_dict

    def test(self):
        self.pipeline.eval()
        with torch.no_grad():
            test_set: BaseDataset
            for dataset_idx, test_set in enumerate(self.test_set_list):
                test_set_config = test_set.get_config()
                eval_list = test_set_config["eval_list"]
                ratio = test_set_config["ratio"]
                test_set_name = test_set.get_name()
                writer = self.test_writer_list[dataset_idx]
                output_list = []
                data_list = []
                test_loader = test_set.get_loader()
                total = int(len(test_loader))
                loop = tqdm(enumerate(test_loader), total=total)
                for i, data in loop:
                    status_manager.set_progress("train", "default_trainer", f"(test) Batch[{test_set_name}]", i+1, total)
                    test_set.process_batch(data, self.device)
                    data["mode"] = namespace.Mode.TEST
                    output = self.pipeline(data)
                    output_list.append(output)
                    data_list.append(data)
                    loop.set_description(
                        f'Epoch [{self.current_epoch}/{self.max_epochs}] (Test: {test_set_name}, ratio={ratio})')
                result_dict = self.eval_fn(output_list, data_list, eval_list)
                TensorboardWriter.write_tensorboard(writer, "epoch", result_dict, self.current_epoch - 1)

    @staticmethod
    def eval_fn(output_list, data_list, eval_list):
        collected_result = {}
        for eval_method_name in eval_list:
            eval_method = ComponentFactory.create(namespace.Stereotype.EVALUATION_METHOD, eval_method_name)
            eval_results: dict = eval_method.evaluate(output_list, data_list)
            for data_type, eval_result in eval_results.items():
                if data_type not in collected_result:
                    collected_result[data_type] = {}
                for name, value in eval_result.items():
                    collected_result[data_type][name] = value
                    status_manager.set_status("train", "default_trainer", f"[eval]{name}", value)

        return collected_result

    def get_checkpoint_path(self, is_last=False):
        return os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME,
                            "Epoch_{}.pth".format(
                                self.current_epoch if self.current_epoch != -1 and not is_last else "last"))

    def load_checkpoint(self, is_last=False):
        self.load(self.get_checkpoint_path(is_last))
        Log.success(f"Loaded checkpoint from {self.get_checkpoint_path(is_last)}")
        if is_last:
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            meta_path = os.path.join(checkpoint_root, "meta.json")
            if not os.path.exists(meta_path):
                raise FileNotFoundError(
                    "No checkpoint meta.json file in the experiment {}".format(self.experiments_config["name"]))
            file_path = os.path.join(checkpoint_root, "meta.json")
            with open(file_path, "r") as f:
                meta = json.load(f)
                self.current_epoch = meta["last_epoch"]
                self.current_iter = meta["last_iter"]

    def save_checkpoint(self, is_last=False):
        self.save(self.get_checkpoint_path(is_last))
        if not is_last:
            Log.success(f"Checkpoint at epoch {self.current_epoch} saved to {self.get_checkpoint_path(is_last)}")
        else:
            meta = {
                "last_epoch": self.current_epoch,
                "last_iter": self.current_iter,
                "time": str(datetime.now())
            }
            checkpoint_root = os.path.join(self.experiment_path, namespace.Direcotry.CHECKPOINT_DIR_NAME)
            file_path = os.path.join(checkpoint_root, "meta.json")
            with open(file_path, "w") as f:
                json.dump(meta, f)

    def load_experiment(self, backup_name=None):
        super().load_experiment(backup_name)
        if self.experiments_config["use_checkpoint"]:
            self.current_epoch = self.experiments_config["epoch"]
            self.load_checkpoint(is_last=(self.current_epoch == -1))

    def create_experiment(self, backup_name=None):
        super().create_experiment(backup_name)
        ckpt_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.CHECKPOINT_DIR_NAME)
        os.makedirs(ckpt_dir)
        tensorboard_dir = os.path.join(str(self.experiment_path), namespace.Direcotry.TENSORBOARD_DIR_NAME)
        os.makedirs(tensorboard_dir)

    def load(self, path):
        state_dict = torch.load(path)
        if self.parallel:
            self.pipeline.module.load_state_dict(state_dict)
        else:
            self.pipeline.load_state_dict(state_dict)

    def save(self, path):
        if self.parallel:
            state_dict = self.pipeline.module.state_dict()
        else:
            state_dict = self.pipeline.state_dict()

        for name, module in self.pipeline.named_modules():
            if module.__class__ in EXTERNAL_FRONZEN_MODULES:
                if name in state_dict:
                    del state_dict[name]

        torch.save(state_dict, path)


    def print_info(self):
        def print_dataset(dataset: BaseDataset):
            config = dataset.get_config()
            name = dataset.get_name()
            Log.blue(f"Dataset: {name}")
            for k, v in config.items():
                Log.blue(f"\t{k}: {v}")

        super().print_info()
        table_size = 70
        Log.blue(f"{'+' + '-' * (table_size // 2)} Pipeline {'-' * (table_size // 2)}" + '+')
        Log.blue(self.pipeline)
        Log.blue(f"{'+' + '-' * (table_size // 2)} Datasets {'-' * (table_size // 2)}" + '+')
        Log.blue("train dataset: ")
        print_dataset(self.train_set)
        for i, test_set in enumerate(self.test_set_list):
            Log.blue(f"test dataset {i}: ")
            print_dataset(test_set)

        Log.blue(f"{'+' + '-' * (table_size // 2)}----------{'-' * (table_size // 2)}" + '+')
@ -1,17 +0,0 @@
from abc import abstractmethod, ABC

from PytorchBoot.config import ConfigManager
from PytorchBoot.runners import Runner
from PytorchBoot.utils import Log

class WebRunner(ABC, Runner):

    @abstractmethod
    def __init__(self, config_path):
        ConfigManager.load_config_with(config_path)

    @abstractmethod
    def run(self):
        pass
@ -1,56 +0,0 @@
class StatusManager:
    def __init__(self):
        self.running_app = {}
        self.last_status = {}
        self.curr_status = {}
        self.progress = {}
        self.log = []

    def is_running(self):
        return len(self.running_app) > 0

    def run_app(self, app_name, app):
        self.running_app[app_name] = app

    def end_app(self, app_name):
        self.running_app.pop(app_name)

    def set_status(self, app_name, runner_name, key, value):
        self.last_status = self.curr_status
        if app_name not in self.curr_status:
            self.curr_status[app_name] = {}
        if runner_name not in self.curr_status[app_name]:
            self.curr_status[app_name][runner_name] = {}
        self.curr_status[app_name][runner_name][key] = value

    def set_progress(self, app_name, runner_name, key, curr_value, max_value):
        if app_name not in self.progress:
            self.progress[app_name] = {}
        if runner_name not in self.progress[app_name]:
            self.progress[app_name][runner_name] = {}
        self.progress[app_name][runner_name][key] = (curr_value, max_value)

    def get_status(self):
        return self.curr_status

    def get_progress(self):
        return self.progress

    def add_log(self, time_str, log_type, message):
        self.log.append((time_str, log_type, message))

    def get_log(self):
        return self.log

    def get_running_apps(self):
        return list(self.running_app.keys())

    def get_last_status(self):
        return self.last_status

    def reset_status(self):
        self.last_status = {}
        self.curr_status = {}

status_manager = StatusManager()
@ -1,149 +0,0 @@
import inspect

from PytorchBoot.component import Component
from PytorchBoot.utils.log_util import Log
import PytorchBoot.namespace as namespace


def ensure_component_subclass(cls, type_name, name):
    if not issubclass(cls, Component):
        new_cls = type(cls.__name__, (Component, cls), {
            **cls.__dict__,
            "TYPE": type_name,
            "NAME": name
        })
        new_cls.__original_class__ = cls
    else:
        new_cls = cls
    for method_name, method in inspect.getmembers(cls, predicate=inspect.isfunction):
        if getattr(method, "__isabstractmethod__", False):
            Log.error(f"Component <{name}> contains abstract method <{method_name}>.", True)

    return cls


# --- Classes --- #
dataset_classes = {}
dataset_comments = {}
def dataset(dataset_name, comment=None):
    def decorator(cls):
        if not hasattr(cls, "get_loader") or not callable(getattr(cls, "get_loader")):
            Log.error(f"dataset <{cls.__name__}> must implement a 'get_loader' method", True)
        cls = ensure_component_subclass(cls, namespace.Stereotype.DATASET, dataset_name)
        dataset_comments[dataset_name] = comment
        dataset_classes[dataset_name] = cls
        return cls
    return decorator

module_classes = {}
module_comments = {}
def module(module_name, comment=None):
    def decorator(cls):
        cls = ensure_component_subclass(cls, namespace.Stereotype.MODULE, module_name)
        module_comments[module_name] = comment
        module_classes[module_name] = cls
        return cls
    return decorator

pipeline_classes = {}
pipline_comments = {}
def pipeline(pipeline_name, comment=None):
    def decorator(cls):
        if not hasattr(cls, 'forward') or not callable(getattr(cls, 'forward')):
            Log.error(f"pipeline <{cls.__name__}> must implement a 'forward' method", True)
        cls = ensure_component_subclass(cls, namespace.Stereotype.PIPELINE, pipeline_name)
        pipeline_classes[pipeline_name] = cls
        pipline_comments[pipeline_name] = comment
        return cls
    return decorator

runner_classes = {}
runner_comments = {}
def runner(runner_name, comment=None):
    def decorator(cls):
        if not hasattr(cls, 'run') or not callable(getattr(cls, 'run')):
            Log.error(f"runner <{cls.__name__}> must implement a 'run' method", True)
        cls = ensure_component_subclass(cls, namespace.Stereotype.RUNNER, runner_name)
        runner_classes[runner_name] = cls
        runner_comments[runner_name] = comment
        return cls
    return decorator

factory_classes = {}
factory_comments = {}
def factory(factory_name, comment=None):
    def decorator(cls):
        if not hasattr(cls, 'create') or not callable(getattr(cls, 'create')):
            Log.error(f"factory <{cls.__name__}> must implement a 'create' method", True)

        cls = ensure_component_subclass(cls, namespace.Stereotype.FACTORY, factory_name)
        factory_classes[factory_name] = cls
        factory_comments[factory_name] = comment
        return cls
    return decorator

loss_classes = {}
loss_comments = {}
def loss_function(loss_name, comment=None):
    def decorator(cls):
        if not hasattr(cls, 'compute') or not callable(getattr(cls, 'compute')):
            Log.error(f"loss function <{cls.__name__}> must implement a 'compute' method", True)
        cls = ensure_component_subclass(cls, namespace.Stereotype.LOSS_FUNCTION, loss_name)
        loss_classes[loss_name] = cls
        loss_comments[loss_name] = comment
        return cls
    return decorator

evaluation_classes = {}
evaluation_comments = {}
def evaluation_method(evaluation_name, comment=None):
    def decorator(cls):
        if not hasattr(cls, 'evaluate') or not callable(getattr(cls, 'evaluate')):
            Log.error(f"evaluation method <{cls.__name__}> must implement a 'evaluate' method", True)
        cls = ensure_component_subclass(cls, namespace.Stereotype.EVALUATION_METHOD, evaluation_name)
        evaluation_classes[evaluation_name] = cls
        evaluation_comments[evaluation_name] = comment
        return cls
    return decorator

# --- Others --- #
EXTERNAL_FRONZEN_MODULES = set()

def external_frozen_module(cls):
    if not hasattr(cls, 'load') or not callable(getattr(cls, 'load')):
        Log.error(f"external module <{cls.__name__}> must implement a 'load' method", True)
    EXTERNAL_FRONZEN_MODULES.add(cls)
    return cls

# --- Utils --- #

all_component_classes = {
    namespace.Stereotype.DATASET: dataset_classes,
    namespace.Stereotype.MODULE: module_classes,
    namespace.Stereotype.PIPELINE: pipeline_classes,
    namespace.Stereotype.RUNNER: runner_classes,
    namespace.Stereotype.LOSS_FUNCTION: loss_classes,
    namespace.Stereotype.EVALUATION_METHOD: evaluation_classes,
    namespace.Stereotype.FACTORY: factory_classes
}

all_component_comments = {
    namespace.Stereotype.DATASET: dataset_comments,
    namespace.Stereotype.MODULE: module_comments,
    namespace.Stereotype.PIPELINE: pipline_comments,
    namespace.Stereotype.RUNNER: runner_comments,
    namespace.Stereotype.LOSS_FUNCTION: loss_comments,
    namespace.Stereotype.EVALUATION_METHOD: evaluation_comments,
    namespace.Stereotype.FACTORY: factory_comments
}

def get_all_component_classes():
    return all_component_classes

def get_all_component_comments():
    return all_component_comments

def get_component_classes(component_type):
    return all_component_classes.get(component_type, None)
@ -1,15 +0,0 @@
template = """from PytorchBoot.application import PytorchBootApplication

@PytorchBootApplication
class Application:
    @staticmethod
    def start():
        '''
        call default or your custom runners here, code will be executed
        automatically when type "pytorch-boot run" or "ptb run" in terminal

        example:
            Trainer("path_to_your_train_config").run()
            Evaluator("path_to_your_eval_config").run()
        '''
"""
@ -1,58 +0,0 @@
template = """
runners:
  general:
    seed: 0
    device: cuda
    cuda_visible_devices: "0,1,2,3,4,5,6,7"
    parallel: False

  experiment:
    name: experiment_name
    root_dir: "experiments"
    use_checkpoint: False
    epoch: -1 # -1 stands for last epoch
    max_epochs: 5000
    save_checkpoint_interval: 1
    test_first: True

  train:
    optimizer:
      type: adam
      lr: 0.0001
    losses: # loss type : weight
      loss_type_0: 1.0
    dataset:
      name: train_set_name
      source: train_set_source_name
      ratio: 1.0
      batch_size: 1
      num_workers: 1

  test:
    frequency: 3 # test frequency
    dataset_list:
      - name: test_set_name_0
        source: train_set_source_name
        eval_list:
          - eval_func_name_0
          - eval_func_name_1
        ratio: 1.0
        batch_size: 1
        num_workers: 1

  pipeline: pipeline_name

pipelines:
  pipeline_name_0:
    - module_name_0
    - module_name_1

datasets:
  dataset_source_name_0:
  dataset_source_name_1:

modules:
  module_name_0:
  module_name_1:

"""
@ -1 +0,0 @@
<!DOCTYPE html><html><head><meta charset=utf-8><meta name=viewport content="width=device-width,initial-scale=1"><title>PyTorchBoot Project</title><link href=/static/css/app.5383ee564f9a1a656786665504aa6b98.css rel=stylesheet></head><body><div id=app></div><script type=text/javascript src=/static/js/manifest.2ae2e69a05c33dfc65f8.js></script><script type=text/javascript src=/static/js/vendor.9f7b4785a30f0533ee08.js></script><script type=text/javascript src=/static/js/app.230235873e25a72eeacb.js></script></body></html>
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
Before Width: | Height: | Size: 542 KiB
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -1,2 +0,0 @@
!function(r){var n=window.webpackJsonp;window.webpackJsonp=function(e,u,c){for(var f,i,p,a=0,l=[];a<e.length;a++)i=e[a],o[i]&&l.push(o[i][0]),o[i]=0;for(f in u)Object.prototype.hasOwnProperty.call(u,f)&&(r[f]=u[f]);for(n&&n(e,u,c);l.length;)l.shift()();if(c)for(a=0;a<c.length;a++)p=t(t.s=c[a]);return p};var e={},o={2:0};function t(n){if(e[n])return e[n].exports;var o=e[n]={i:n,l:!1,exports:{}};return r[n].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=r,t.c=e,t.d=function(r,n,e){t.o(r,n)||Object.defineProperty(r,n,{configurable:!1,enumerable:!0,get:e})},t.n=function(r){var n=r&&r.__esModule?function(){return r.default}:function(){return r};return t.d(n,"a",n),n},t.o=function(r,n){return Object.prototype.hasOwnProperty.call(r,n)},t.p="/",t.oe=function(r){throw console.error(r),r}}([]);
//# sourceMappingURL=manifest.2ae2e69a05c33dfc65f8.js.map
@ -1 +0,0 @@
{"version":3,"sources":["webpack:///webpack/bootstrap def2f39c04517bb0de2d"],"names":["parentJsonpFunction","window","chunkIds","moreModules","executeModules","moduleId","chunkId","result","i","resolves","length","installedChunks","push","Object","prototype","hasOwnProperty","call","modules","shift","__webpack_require__","s","installedModules","2","exports","module","l","m","c","d","name","getter","o","defineProperty","configurable","enumerable","get","n","__esModule","object","property","p","oe","err","console","error"],"mappings":"aACA,IAAAA,EAAAC,OAAA,aACAA,OAAA,sBAAAC,EAAAC,EAAAC,GAIA,IADA,IAAAC,EAAAC,EAAAC,EAAAC,EAAA,EAAAC,KACQD,EAAAN,EAAAQ,OAAoBF,IAC5BF,EAAAJ,EAAAM,GACAG,EAAAL,IACAG,EAAAG,KAAAD,EAAAL,GAAA,IAEAK,EAAAL,GAAA,EAEA,IAAAD,KAAAF,EACAU,OAAAC,UAAAC,eAAAC,KAAAb,EAAAE,KACAY,EAAAZ,GAAAF,EAAAE,IAIA,IADAL,KAAAE,EAAAC,EAAAC,GACAK,EAAAC,QACAD,EAAAS,OAAAT,GAEA,GAAAL,EACA,IAAAI,EAAA,EAAYA,EAAAJ,EAAAM,OAA2BF,IACvCD,EAAAY,IAAAC,EAAAhB,EAAAI,IAGA,OAAAD,GAIA,IAAAc,KAGAV,GACAW,EAAA,GAIA,SAAAH,EAAAd,GAGA,GAAAgB,EAAAhB,GACA,OAAAgB,EAAAhB,GAAAkB,QAGA,IAAAC,EAAAH,EAAAhB,IACAG,EAAAH,EACAoB,GAAA,EACAF,YAUA,OANAN,EAAAZ,GAAAW,KAAAQ,EAAAD,QAAAC,IAAAD,QAAAJ,GAGAK,EAAAC,GAAA,EAGAD,EAAAD,QAKAJ,EAAAO,EAAAT,EAGAE,EAAAQ,EAAAN,EAGAF,EAAAS,EAAA,SAAAL,EAAAM,EAAAC,GACAX,EAAAY,EAAAR,EAAAM,IACAhB,OAAAmB,eAAAT,EAAAM,GACAI,cAAA,EACAC,YAAA,EACAC,IAAAL,KAMAX,EAAAiB,EAAA,SAAAZ,GACA,IAAAM,EAAAN,KAAAa,WACA,WAA2B,OAAAb,EAAA,SAC3B,WAAiC,OAAAA,GAEjC,OADAL,EAAAS,EAAAE,EAAA,IAAAA,GACAA,GAIAX,EAAAY,EAAA,SAAAO,EAAAC,GAAsD,OAAA1B,OAAAC,UAAAC,eAAAC,KAAAsB,EAAAC,IAGtDpB,EAAAqB,EAAA,IAGArB,EAAAsB,GAAA,SAAAC,GAA8D,MAApBC,QAAAC,MAAAF,GAAoBA","file":"static/js/manifest.2ae2e69a05c33dfc65f8.js","sourcesContent":[" \t// install a JSONP callback for chunk loading\n \tvar parentJsonpFunction = window[\"webpackJsonp\"];\n \twindow[\"webpackJsonp\"] = function webpackJsonpCallback(chunkIds, moreModules, executeModules) {\n \t\t// add \"moreModules\" to the modules object,\n \t\t// then flag all \"chunkIds\" as loaded and fire callback\n \t\tvar moduleId, chunkId, i = 0, resolves = [], result;\n \t\tfor(;i < chunkIds.length; i++) {\n \t\t\tchunkId = chunkIds[i];\n \t\t\tif(installedChunks[chunkId]) {\n \t\t\t\tresolves.push(installedChunks[chunkId][0]);\n \t\t\t}\n \t\t\tinstalledChunks[chunkId] = 0;\n \t\t}\n \t\tfor(moduleId in moreModules) {\n \t\t\tif(Object.prototype.hasOwnProperty.call(moreModules, moduleId)) {\n \t\t\t\tmodules[moduleId] = moreModules[moduleId];\n \t\t\t}\n \t\t}\n \t\tif(parentJsonpFunction) parentJsonpFunction(chunkIds, moreModules, executeModules);\n \t\twhile(resolves.length) {\n \t\t\tresolves.shift()();\n \t\t}\n \t\tif(executeModules) {\n \t\t\tfor(i=0; i < executeModules.length; i++) {\n \t\t\t\tresult = __webpack_require__(__webpack_require__.s = executeModules[i]);\n \t\t\t}\n \t\t}\n \t\treturn result;\n \t};\n\n \t// The module cache\n \tvar installedModules = {};\n\n \t// objects to store loaded and loading chunks\n \tvar installedChunks = {\n \t\t2: 0\n \t};\n\n \t// The require function\n \tfunction __webpack_require__(moduleId) {\n\n \t\t// Check if module is in cache\n \t\tif(installedModules[moduleId]) {\n \t\t\treturn installedModules[moduleId].exports;\n \t\t}\n \t\t// Create a new module (and put it into the cache)\n \t\tvar module = installedModules[moduleId] = {\n \t\t\ti: moduleId,\n \t\t\tl: false,\n \t\t\texports: {}\n \t\t};\n\n \t\t// Execute the module function\n \t\tmodules[moduleId].call(module.exports, module, module.exports, __webpack_require__);\n\n \t\t// Flag the module as 
loaded\n \t\tmodule.l = true;\n\n \t\t// Return the exports of the module\n \t\treturn module.exports;\n \t}\n\n\n \t// expose the modules object (__webpack_modules__)\n \t__webpack_require__.m = modules;\n\n \t// expose the module cache\n \t__webpack_require__.c = installedModules;\n\n \t// define getter function for harmony exports\n \t__webpack_require__.d = function(exports, name, getter) {\n \t\tif(!__webpack_require__.o(exports, name)) {\n \t\t\tObject.defineProperty(exports, name, {\n \t\t\t\tconfigurable: false,\n \t\t\t\tenumerable: true,\n \t\t\t\tget: getter\n \t\t\t});\n \t\t}\n \t};\n\n \t// getDefaultExport function for compatibility with non-harmony modules\n \t__webpack_require__.n = function(module) {\n \t\tvar getter = module && module.__esModule ?\n \t\t\tfunction getDefault() { return module['default']; } :\n \t\t\tfunction getModuleExports() { return module; };\n \t\t__webpack_require__.d(getter, 'a', getter);\n \t\treturn getter;\n \t};\n\n \t// Object.prototype.hasOwnProperty.call\n \t__webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); };\n\n \t// __webpack_public_path__\n \t__webpack_require__.p = \"/\";\n\n \t// on error function for async loading\n \t__webpack_require__.oe = function(err) { console.error(err); throw err; };\n\n\n\n// WEBPACK FOOTER //\n// webpack/bootstrap def2f39c04517bb0de2d"],"sourceRoot":""}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -1,233 +0,0 @@
import os
import threading
import socket
import logging
import psutil
import GPUtil
import platform

from flask import Flask, jsonify, request, send_from_directory
from flask_cors import CORS
from tensorboard import program
from PytorchBoot.utils.project_util import ProjectUtil
from PytorchBoot.stereotype import get_all_component_classes, get_all_component_comments
from PytorchBoot.application import application_class
from PytorchBoot.status import status_manager
from PytorchBoot.utils.log_util import Log
from PytorchBoot.utils.timer_util import Timer

app = Flask(__name__, static_folder="../client")
app.logger.setLevel("WARNING")
logging.getLogger("werkzeug").disabled = True
CORS(app)
root_path = os.getcwd()
ProjectUtil.scan_project(root_path)
configs = ProjectUtil.scan_configs(root_path)
# maps a log directory to the URL of the TensorBoard instance serving it
running_tensorboard = {}


@app.route("/")
def serve_index():
    return send_from_directory(app.static_folder, "index.html")


@app.route("/<path:path>")
def serve_file(path):
    return send_from_directory(app.static_folder, path)


@app.route("/test", methods=["POST"])
def hello_world():
    return jsonify(message="Hello, World!")


@app.route("/project/structure", methods=["POST"])
def project_structure():
    component_info = {}
    for st, cls_dict in get_all_component_classes().items():
        component_info[st] = {k: v.__name__ for k, v in cls_dict.items()}
    comment_info = get_all_component_comments()
    app_info = {}
    for app_name, app_cls in application_class.items():
        app_info[app_name] = app_cls.__name__

    return jsonify(
        components=component_info,
        comments=comment_info,
        applications=app_info,
        configs=configs,
        root_path=root_path,
    )


@app.route("/project/run_app", methods=["POST"])
def run_application():
    data = request.json
    app_name = data.get("app_name")
    app_cls = application_class.get(app_name)

    if app_cls is None:
        Log.error(
            f"No class annotated with @PytorchBootApplication found with the name '{app_name}'.",
            True,
        )
        return jsonify(
            {
                "message": f"No application found with the name '{app_name}'",
                "status": "error",
            }
        )

    if not hasattr(app_cls, "start"):
        Log.error(
            "The class annotated with @PytorchBootApplication should have a 'start' method.",
            True,
        )
        return jsonify(
            {"message": "The class should have a 'start' method", "status": "error"}
        )

    # run the application in a background thread so the HTTP request returns immediately
    def run_in_background():
        Log.info(f"Application '{app_cls.__name__}' started.")
        timer = Timer("Application")
        timer.start()
        status_manager.run_app(app_name, app_cls)
        app_cls.start()
        status_manager.end_app(app_name)
        timer.stop()
        Log.info(timer.get_elapsed_time_str(Timer.HOURS))
        Log.success("Application finished.")

    threading.Thread(target=run_in_background).start()

    return jsonify(
        {"message": f"Application '{app_name}' is running now.", "status": "success"}
    )


@app.route("/project/get_status", methods=["POST"])
def get_status():
    cpu_info = {
        "model": platform.processor(),
        "usage_percent": psutil.cpu_percent(interval=1),
    }
    virtual_memory = psutil.virtual_memory()
    memory_info = {
        "used": round(virtual_memory.used / (1024**3), 3),
        "total": round(virtual_memory.total / (1024**3), 3),
    }

    gpus = GPUtil.getGPUs()
    gpu_info = []
    for gpu in gpus:
        gpu_info.append(
            {
                "name": gpu.name,
                "memory_used": gpu.memoryUsed,
                "memory_total": gpu.memoryTotal,
            }
        )

    return jsonify(
        curr_status=status_manager.get_status(),
        last_status=status_manager.get_last_status(),
        logs=status_manager.get_log(),
        progress=status_manager.get_progress(),
        running_apps=status_manager.get_running_apps(),
        cpu=cpu_info,
        memory=memory_info,
        gpus=gpu_info,
    )


@app.route("/project/set_status", methods=["POST"])
def set_status():
    status = request.json.get("status")
    progress = request.json.get("progress")
    if status:
        status_manager.set_status(
            app_name=status["app_name"],
            runner_name=status["runner_name"],
            key=status["key"],
            value=status["value"],
        )
    if progress:
        status_manager.set_progress(
            app_name=progress["app_name"],
            runner_name=progress["runner_name"],
            key=progress["key"],
            curr_value=progress["curr_value"],
            max_value=progress["max_value"],
        )
    return jsonify({"status": "success"})


@app.route("/project/add_log", methods=["POST"])
def add_log():
    log = request.json.get("log")
    Log.log(log["message"], log["log_type"])
    return jsonify({"status": "success"})


def find_free_port(start_port):
    """Find a free port starting from start_port."""
    port = start_port
    while True:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            result = sock.connect_ex(("localhost", port))
            if result != 0:
                return port
            port += 1


def start_tensorboard(log_dir, port):
    """Starts TensorBoard in a separate thread."""
    tb = program.TensorBoard()
    tb.configure(argv=[None, "--logdir", log_dir, "--port", str(port)])
    tb.launch()


@app.route("/tensorboard/run", methods=["POST"])
def run_tensorboard():
    data = request.json
    log_dir = data.get("log_dir")
    if log_dir in running_tensorboard:
        return jsonify(
            {
                "message": f"TensorBoard ({running_tensorboard[log_dir]}) is already running for <{log_dir}>",
                "url": running_tensorboard[log_dir],
                "status": "warning",
            }
        )

    if not os.path.isdir(log_dir):
        return jsonify({"message": "Log directory does not exist", "status": "error"})

    port = find_free_port(10000)

    try:
        tb_thread = threading.Thread(target=start_tensorboard, args=(log_dir, port))
        tb_thread.start()
    except Exception as e:
        return jsonify(
            {"message": f"Error starting TensorBoard: {str(e)}", "status": "error"}
        )

    url = f"http://localhost:{port}"
    running_tensorboard[log_dir] = url
    return jsonify(
        {"url": url, "message": f"TensorBoard is running at {url}", "status": "success"}
    )


@app.route("/tensorboard/dirs", methods=["POST"])
def get_tensorboard_dirs():
    tensorboard_dirs = []
    for root, dirs, _ in os.walk(root_path):
        for dir_name in dirs:
            if dir_name == "tensorboard":
                tensorboard_dirs.append(os.path.join(root, dir_name))
    return jsonify({"tensorboard_dirs": tensorboard_dirs})


@app.route("/tensorboard/running_tensorboards", methods=["POST"])
def get_running_tensorboards():
    return jsonify(running_tensorboards=running_tensorboard)
@@ -1,3 +0,0 @@
from PytorchBoot.utils.log_util import Log
from PytorchBoot.utils.tensorboard_util import TensorboardWriter
from PytorchBoot.utils.timer_util import Timer
@@ -1,81 +0,0 @@
import time
import PytorchBoot.namespace as namespace
from PytorchBoot.status import status_manager


class Log:
    MAX_TITLE_LENGTH: int = 7
    TYPE_COLOR_MAP = {
        namespace.LogType.INFO: "\033[94m",
        namespace.LogType.ERROR: "\033[91m",
        namespace.LogType.WARNING: "\033[93m",
        namespace.LogType.SUCCESS: "\033[92m",
        namespace.LogType.DEBUG: "\033[95m",
        namespace.LogType.TERMINATE: "\033[96m"
    }

    def get_time():
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

    def blue(message):
        # blue
        print(f"\033[94m{message}\033[0m")

    def red(message):
        # red
        print(f"\033[91m{message}\033[0m")

    def yellow(message):
        # yellow
        print(f"\033[93m{message}\033[0m")

    def green(message):
        # green
        print(f"\033[92m{message}\033[0m")

    def log(message, log_type: str):
        time_str = Log.get_time()
        space = ""
        if len(log_type) < Log.MAX_TITLE_LENGTH:
            space = " " * (Log.MAX_TITLE_LENGTH - len(log_type))

        print(f"\033[1m\033[4m({time_str})\033[0m \033[1m{Log.TYPE_COLOR_MAP[log_type]}[{log_type.capitalize()}]\033[0m{space} {Log.TYPE_COLOR_MAP[log_type]}{message}\033[0m")
        status_manager.add_log(time_str, log_type, message)

    def bold(message):
        print(f"\033[1m{message}\033[0m")

    def underline(message):
        print(f"\033[4m{message}\033[0m")

    def info(message):
        Log.log(message, namespace.LogType.INFO)

    def error(message, terminate=False):
        Log.log(message, namespace.LogType.ERROR)
        if terminate:
            Log.terminate("Application Terminated.")

    def warning(message):
        Log.log(message, namespace.LogType.WARNING)

    def success(message):
        Log.log(message, namespace.LogType.SUCCESS)

    def debug(message):
        Log.log(message, namespace.LogType.DEBUG)

    def terminate(message):
        Log.log(message, namespace.LogType.TERMINATE)
        exit(1)


if __name__ == "__main__":
    Log.info("This is an info message")
    Log.error("This is an error message")
    Log.warning("This is a warning message")
    Log.success("This is a success message")
    Log.debug("This is a debug message")
    Log.blue("This is a blue message")
    Log.red("This is a red message")
    Log.yellow("This is a yellow message")
    Log.green("This is a green message")

    Log.bold("This is a bold message")
    Log.underline("This is an underline message")
    Log.error("This is a terminate message", True)
@@ -1,50 +0,0 @@
import os
import sys
import yaml
import importlib


class ProjectUtil:
    @staticmethod
    def scan_project(root_path):
        sys.path.append(root_path)
        if not os.path.exists(root_path) or not os.path.isdir(root_path):
            raise ValueError(f"The provided root_path '{root_path}' is not a valid directory.")

        parent_dir = os.path.dirname(root_path)
        sys.path.insert(0, parent_dir)

        def import_all_modules(path, package_name):
            for root, dirs, files in os.walk(path):
                relative_path = os.path.relpath(root, root_path)
                if relative_path == '.':
                    module_package = package_name
                else:
                    module_package = f"{package_name}.{relative_path.replace(os.sep, '.')}"
                for file in files:
                    if file.endswith(".py") and file != "__init__.py":
                        module_name = file[:-3]
                        full_module_name = f"{module_package}.{module_name}"
                        if full_module_name not in sys.modules:
                            importlib.import_module(full_module_name)
                dirs[:] = [d for d in dirs if not d.startswith('.')]

        package_name = os.path.basename(root_path)
        import_all_modules(root_path, package_name)

    @staticmethod
    def scan_configs(root_path):
        configs = {}
        for root, dirs, files in os.walk(root_path):
            for file in files:
                if file.endswith(('.yaml', '.yml')):
                    if file.startswith('__'):
                        continue
                    file_path = os.path.join(root, file)
                    with open(file_path, 'r', encoding='utf-8') as f:
                        try:
                            content = yaml.safe_load(f)
                            configs[os.path.splitext(file)[0]] = content
                        except yaml.YAMLError as e:
                            print(f"Error reading {file_path}: {e}")

        return configs
@@ -1,44 +0,0 @@
import torch
import PytorchBoot.namespace as namespace


class TensorboardWriter:
    @staticmethod
    def write_tensorboard(writer, panel, data_dict, step, simple_scalar=False):

        if simple_scalar:
            TensorboardWriter.write_scalar_tensorboard(writer, panel, data_dict, step)

        if namespace.TensorBoard.SCALAR in data_dict:
            scalar_data_dict = data_dict[namespace.TensorBoard.SCALAR]
            TensorboardWriter.write_scalar_tensorboard(writer, panel, scalar_data_dict, step)
        if namespace.TensorBoard.IMAGE in data_dict:
            image_data_dict = data_dict[namespace.TensorBoard.IMAGE]
            TensorboardWriter.write_image_tensorboard(writer, panel, image_data_dict, step)
        if namespace.TensorBoard.POINT in data_dict:
            point_data_dict = data_dict[namespace.TensorBoard.POINT]
            TensorboardWriter.write_points_tensorboard(writer, panel, point_data_dict, step)

    @staticmethod
    def write_scalar_tensorboard(writer, panel, data_dict, step):
        for key, value in data_dict.items():
            if isinstance(value, dict):
                writer.add_scalars(f'{panel}/{key}', value, step)
            else:
                writer.add_scalar(f'{panel}/{key}', value, step)

    @staticmethod
    def write_image_tensorboard(writer, panel, data_dict, step):
        pass

    @staticmethod
    def write_points_tensorboard(writer, panel, data_dict, step):
        for key, value in data_dict.items():
            if value.shape[-1] == 3:
                colors = torch.zeros_like(value)
                vertices = torch.cat([value, colors], dim=-1)
            elif value.shape[-1] == 6:
                vertices = value
            else:
                raise ValueError(f'Unexpected value shape: {value.shape}')
            faces = None
            writer.add_mesh(f'{panel}/{key}', vertices=vertices, faces=faces, global_step=step)
@@ -1,32 +0,0 @@
import time


class Timer:
    MILLI_SECONDS = "milliseconds"
    SECONDS = "seconds"
    MINUTES = "minutes"
    HOURS = "hours"

    def __init__(self, name=None):
        self.start_time = None
        self.end_time = None
        self.name = name

    def start(self):
        self.start_time = time.time()

    def stop(self):
        self.end_time = time.time()

    def elapsed_time(self):
        return int(self.end_time - self.start_time)

    def get_elapsed_time_str(self, format):
        if format == Timer.SECONDS:
            return f"Elapsed time in <{self.name}>: {self.elapsed_time()} seconds"
        elif format == Timer.MINUTES:
            return f"Elapsed time in <{self.name}>: {self.elapsed_time() // 60} minutes, {self.elapsed_time() % 60} seconds"
        elif format == Timer.HOURS:
            return f"Elapsed time in <{self.name}>: {self.elapsed_time() // 3600} hours, {(self.elapsed_time() % 3600) // 60} minutes, {self.elapsed_time() % 60} seconds"
        elif format == Timer.MILLI_SECONDS:
            return f"Elapsed time in <{self.name}>: {(self.end_time - self.start_time) * 1000} milliseconds"
        else:
            return f"Invalid format: {format}"
@@ -1,3 +0,0 @@
Metadata-Version: 2.1
Name: pytorch-boot
Version: 0.1
@@ -1,46 +0,0 @@
setup.py
PytorchBoot/__init__.py
PytorchBoot/application.py
PytorchBoot/boot.py
PytorchBoot/component.py
PytorchBoot/config.py
PytorchBoot/namespace.py
PytorchBoot/status.py
PytorchBoot/stereotype.py
PytorchBoot/dataset/__init__.py
PytorchBoot/dataset/base_dataset.py
PytorchBoot/factory/__init__.py
PytorchBoot/factory/component_factory.py
PytorchBoot/factory/optimizer_factory.py
PytorchBoot/runners/__init__.py
PytorchBoot/runners/evaluator.py
PytorchBoot/runners/predictor.py
PytorchBoot/runners/runner.py
PytorchBoot/runners/trainer.py
PytorchBoot/runners/web_runner.py
PytorchBoot/templates/application.py
PytorchBoot/templates/config.py
PytorchBoot/ui/client/index.html
PytorchBoot/ui/client/static/css/app.5383ee564f9a1a656786665504aa6b98.css
PytorchBoot/ui/client/static/css/app.5383ee564f9a1a656786665504aa6b98.css.map
PytorchBoot/ui/client/static/fonts/ionicons.143146f.woff2
PytorchBoot/ui/client/static/fonts/ionicons.99ac330.woff
PytorchBoot/ui/client/static/fonts/ionicons.d535a25.ttf
PytorchBoot/ui/client/static/img/ionicons.a2c4a26.svg
PytorchBoot/ui/client/static/js/app.230235873e25a72eeacb.js
PytorchBoot/ui/client/static/js/app.230235873e25a72eeacb.js.map
PytorchBoot/ui/client/static/js/manifest.2ae2e69a05c33dfc65f8.js
PytorchBoot/ui/client/static/js/manifest.2ae2e69a05c33dfc65f8.js.map
PytorchBoot/ui/client/static/js/vendor.9f7b4785a30f0533ee08.js
PytorchBoot/ui/client/static/js/vendor.9f7b4785a30f0533ee08.js.map
PytorchBoot/ui/server/app.py
PytorchBoot/utils/__init__.py
PytorchBoot/utils/log_util.py
PytorchBoot/utils/project_util.py
PytorchBoot/utils/tensorboard_util.py
PytorchBoot/utils/timer_util.py
pytorch_boot.egg-info/PKG-INFO
pytorch_boot.egg-info/SOURCES.txt
pytorch_boot.egg-info/dependency_links.txt
pytorch_boot.egg-info/entry_points.txt
pytorch_boot.egg-info/top_level.txt
@@ -1 +0,0 @@

@@ -1,3 +0,0 @@
[console_scripts]
ptb = PytorchBoot.boot:main
pytorch-boot = PytorchBoot.boot:main
@@ -1 +0,0 @@
PytorchBoot