Skip to content
Snippets Groups Projects
Select Git revision
  • f726d29ab4be698af76eb57f94ca8cdd6112eab0
  • main default protected
  • change_data_model
  • feature/add_auth_asr_service
  • fix/incorrect_import
  • feature/change_registry_clarin
  • feature/add_base_asr_service
  • feature/add_poetry
  • feature/add_word_ids
  • feature/add_sziszapangma
10 results

gold_transcript_task.py

Blame
  • train-batch.py NaN GiB
    import argparse
    import logging
    from typing import List, Dict, Union
    import yaml
    
    from poldeepner2.utils.train_utils import add_train_args
    from train import train_model
    
    
    def load_configurations(path: str) -> Dict[str, Union[str, List[str]]]:
        """Load the batch-training parameter space from a YAML file.

        Args:
            path: path to the YAML configuration file.

        Returns:
            The top-level mapping of the YAML document (option name -> value
            or list of values).

        Raises:
            ValueError: if the file is not valid YAML, or if its top level is
                not a mapping (e.g. the file is empty or contains a list).
        """
        with open(path, encoding="utf-8") as file:
            try:
                params = yaml.safe_load(file)
            except yaml.YAMLError as exception:
                # Keep the original parser error as the cause for debugging.
                raise ValueError(exception) from exception

        if not isinstance(params, dict):
            # safe_load returns None for an empty document; callers iterate
            # params.items(), so report the problem here with a clear message.
            raise ValueError(
                f"Expected a mapping at the top level of {path}, "
                f"got {type(params).__name__}")
        return params
    
    
    def generate_cmds_list(params: Dict[str, Union[str, List[str]]]) -> List[List[str]]:
        """Expand a parameter space into one argument list per configuration.

        Every key of ``params`` becomes a ``--key`` option, interpreted as:

        * ``key/list: [v1, v2]`` -- the option is repeated once per value in
          every generated command (``--key v1 --key v2``); a scalar value is
          treated as a one-element list;
        * ``key: [v1, v2]`` -- a sweep dimension: the set of commands is
          multiplied by the number of values (Cartesian product over all such
          keys, earlier keys varying fastest);
        * ``key: value`` -- a single ``--key value`` added to every command;
        * ``key: null`` -- a bare ``--key`` flag added to every command.

        All values are converted to ``str`` so each result can be passed
        directly to ``argparse.ArgumentParser.parse_args``.

        Args:
            params: mapping loaded from the YAML configuration.

        Returns:
            A list of argument lists, one per configuration. An empty sweep list
            produces no configurations at all.
        """
        list_suffix = "/list"
        cmds: List[List[str]] = [[]]
        for key, values in params.items():
            if key.endswith(list_suffix):
                option = f"--{key[:-len(list_suffix)]}"
                # Without this, a scalar string would be iterated per character.
                items = values if isinstance(values, list) else [values]
                for cmd in cmds:
                    for value in items:
                        cmd.extend([option, str(value)])
            elif isinstance(values, list):
                # Combine every existing command with every value; the right-hand
                # side reads the old ``cmds`` before the name is rebound.
                cmds = [cmd + [f"--{key}", str(value)] for value in values for cmd in cmds]
            else:
                for cmd in cmds:
                    cmd.append(f"--{key}")
                    if values is not None:
                        cmd.append(str(values))
        logging.info("Generated commands: %s", cmds)
        return cmds
    
    
    def parse_args():
        """Read the command line of the batch-training script.

        Returns:
            An ``argparse.Namespace`` whose ``config`` attribute holds the path
            to the YAML file describing the parameter space (required).
        """
        arg_parser = argparse.ArgumentParser(
            description='Batch training for given space of parameters',
        )
        arg_parser.add_argument(
            '--config',
            metavar='PATH',
            required=True,
            help='path to an YAML config file',
        )
        namespace = arg_parser.parse_args()
        return namespace
    
    
    if __name__ == '__main__':
        logging.basicConfig(level=logging.INFO)
        script_args = parse_args()

        params = load_configurations(script_args.config)
        cmds = generate_cmds_list(params)
        logging.info("Number of configurations: %d", len(cmds))

        train_parser = argparse.ArgumentParser()
        train_parser = add_train_args(train_parser)

        # Parse every configuration before training any of them, so an invalid
        # entry in the YAML aborts the batch immediately rather than after the
        # earlier (possibly long) training runs have already completed.
        train_args_list = [train_parser.parse_args(cmd) for cmd in cmds]

        for index, train_args in enumerate(train_args_list, start=1):
            logging.info("Training configuration %d/%d: %s",
                         index, len(train_args_list), train_args)
            train_model(train_args)