Select Git revision
FindGlib.cmake
train-batch.py NaN GiB
import argparse
import logging
from typing import List, Dict, Union
import yaml
from poldeepner2.utils.train_utils import add_train_args
from train import train_model
def load_configurations(path: str) -> Dict[str, Union[str, List[str]]]:
    """Load the parameter space for a batch of training runs from a YAML file.

    The file must contain a top-level mapping from option names to values;
    ``generate_cmds_list`` describes how each value is interpreted.

    :param path: path to the YAML config file
    :return: the mapping parsed from the file
    :raises ValueError: if the file is not valid YAML, or its top level is not
        a mapping (e.g. the file is empty, or holds a list or a scalar)
    :raises OSError: if the file cannot be opened
    """
    with open(path, encoding="utf-8") as file:
        try:
            params = yaml.safe_load(file)
        except yaml.YAMLError as exception:
            # Keep the original parser error attached as the cause.
            raise ValueError(exception) from exception
    # safe_load returns None for an empty file and may return a list/scalar;
    # fail here with a clear message rather than later on `.items()`.
    if not isinstance(params, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping of parameter names to values, "
            f"got {type(params).__name__}")
    return params
def generate_cmds_list(params: Dict[str, Union[str, List[str]]]) -> List[List[str]]:
    """Expand a parameter space into one argv-style list per configuration.

    Every key becomes a ``--key`` option. Values are interpreted as follows:

    * ``key/list: [a, b]`` -- every command gets ``--key a --key b`` (the
      ``/list`` suffix is stripped). A non-list value is treated as a
      one-element list.
    * ``key: [a, b]`` -- a grid axis: each existing command is copied once per
      value, giving the cartesian product over all such keys.
    * ``key: null`` -- a bare ``--key`` with no value.
    * ``key: value`` -- ``--key value`` on every command.

    All values are converted with ``str`` so the result can be handed to
    ``argparse.ArgumentParser.parse_args``.

    :param params: mapping of option names to values (as loaded from YAML)
    :return: one argument list per configuration; ``[[]]`` for an empty
        mapping and ``[]`` if any grid axis has no values
    """
    list_suffix = "/list"
    commands: List[List[str]] = [[]]
    for key, values in params.items():
        if key.endswith(list_suffix):
            # NOTE(review): the option is repeated once per value; presumably
            # the matching option in add_train_args accumulates them - verify.
            option = f"--{key[:-len(list_suffix)]}"
            # Wrap scalars so a plain string is not iterated char by char.
            items = values if isinstance(values, list) else [values]
            for command in commands:
                for value in items:
                    command.extend([option, str(value)])
        elif isinstance(values, list):
            # Cartesian product: copy every existing command once per value
            # (outer loop over values keeps the original ordering).
            commands = [command + [f"--{key}", str(value)]
                        for value in values
                        for command in commands]
        else:
            option = f"--{key}"
            for command in commands:
                command.append(option)
                if values is not None:
                    command.append(str(values))
    logging.info("Generated configurations: %s", commands)
    return commands
def parse_args():
    """Read the batch trainer's own command line.

    :return: an ``argparse.Namespace`` whose ``config`` attribute is the path
        of the YAML file describing the parameter space
    """
    cli = argparse.ArgumentParser(description='Batch training for given space of parameters')
    cli.add_argument(
        '--config',
        help='path to an YAML config file',
        metavar='PATH',
        required=True,
    )
    namespace = cli.parse_args()
    return namespace
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    batch_args = parse_args()
    params = load_configurations(batch_args.config)
    cmds = generate_cmds_list(params)
    logging.info("Number of configurations: %d", len(cmds))

    parser = argparse.ArgumentParser()
    parser = add_train_args(parser)
    # Parse every generated command line before starting any training, so an
    # invalid value in the YAML config aborts the batch immediately (argparse
    # exits on the first bad command line) instead of after earlier, possibly
    # long, training runs have already completed.
    train_args_list = [parser.parse_args(cmd) for cmd in cmds]

    for index, train_args in enumerate(train_args_list, start=1):
        logging.info("Configuration %d/%d: %s", index, len(train_args_list), train_args)
        train_model(train_args)