diff --git a/train-batch.py b/train-batch.py index 863870f1ee6a398c3f82c3654a6b22546a2f2ef7..4cee4787bb087729b349e253910776cae9af3701 100644 --- a/train-batch.py +++ b/train-batch.py @@ -22,7 +22,12 @@ def load_configurations(path: str) -> Dict[str, Union[str, List[str]]]: def generate_cmds_list(params: Dict[str, Union[str, List[str]]]) -> List[List[str]]: lists = [[]] for key, values in params.items(): - if isinstance(values, list): + if key.endswith("/list"): + key = key[:-5] + for ll in lists: + for value in values: + ll.extend([f"--{key}", value]) + elif isinstance(values, list): new_lists = [] for value in values: for ll in lists: @@ -33,6 +38,7 @@ def generate_cmds_list(params: Dict[str, Union[str, List[str]]]) -> List[List[st ll.append(f"--{key}") if values is not None: ll.append(str(values)) + print(lists) return lists @@ -57,7 +63,4 @@ if __name__ == '__main__': for cmd in cmds: args = parser.parse_args(cmd) print(args) - try: - train_model(args) - except KeyError as error: - print(error) + train_model(args)