Skip to content
Snippets Groups Projects
Select Git revision
  • 03c1e6697f4f2a667f4342ad5c1adc7f7de5581e
  • master default protected
  • vertical_relations
  • lu_without_semantic_frames
  • hierarchy
  • additional-unification-filters
  • v0.1.1
  • v0.1.0
  • v0.0.9
  • v0.0.8
  • v0.0.7
  • v0.0.6
  • v0.0.5
  • v0.0.4
  • v0.0.3
  • v0.0.2
  • v0.0.1
17 results

tests.py

Blame
  • evaluate_tsv_batch.py 2.79 KiB
    import argparse
    import logging
    from pathlib import Path
    from typing import List, Dict, Union, Any
    import yaml
    import os
    
    import evaluate_tsv
    
    
    def load_yaml(path: str) -> Dict[str, Union[str, List[str]]]:
        """Load a YAML document from ``path`` and return it as a dict.

        An empty document yields ``{}`` rather than ``None`` so callers can
        iterate over / test membership in the result unconditionally.

        :param path: path to the YAML file, read as UTF-8.
        :return: the parsed mapping, or ``{}`` for an empty file.
        :raises ValueError: if the file content is not valid YAML.
        """
        # Read as UTF-8 to match save_metrics, which writes with that encoding;
        # the platform default encoding may differ (e.g. on Windows).
        with open(path, encoding="utf-8") as file:
            try:
                params = yaml.safe_load(file)
            except yaml.YAMLError as exception:
                raise ValueError(exception) from exception
        # safe_load returns None for an empty document.
        return params if params is not None else {}
    
    
    def generate_configurations(params: Dict[str, Union[str, List[str]]]) -> List[Dict[str, Any]]:
        """Expand a parameter mapping into the Cartesian product of its list values.

        Every key whose value is a list contributes one axis of the grid; every
        scalar value is copied into all configurations unchanged. Earlier keys
        vary slowest, e.g. ``{"a": [1, 2], "b": "x"}`` ->
        ``[{"a": 1, "b": "x"}, {"a": 2, "b": "x"}]``.

        Keys ending in ``/list`` are not supported and are skipped with a
        warning. A key mapped to an empty list yields no configurations.

        :param params: parameter names mapped to a scalar or a list of values.
        :return: one independent dict per configuration, keys in ``params`` order.
        """
        configurations: List[Dict[str, Any]] = [{}]
        for key, values in params.items():
            if key.endswith("/list"):
                # Presumably meant to expand into repeated `--key value` flags;
                # not implemented, so make the omission visible instead of
                # silently dropping the parameter.
                logging.warning(f"Ignoring unsupported parameter {key!r}")
                continue
            if isinstance(values, list):
                # Fork every existing configuration once per value; building new
                # dicts keeps the branches independent of each other.
                configurations = [
                    {**config, key: value}
                    for config in configurations
                    for value in values
                ]
            else:
                for config in configurations:
                    config[key] = values
        return configurations
    
    
    def save_metrics(metrics: Dict[str, Any], path: str):
        """Write ``metrics`` to ``path`` as UTF-8 YAML, replacing any existing file.

        The file is closed (and therefore flushed) before this returns, so each
        incremental save is complete on disk before the caller continues.
        """
        with open(path, "w", encoding="utf-8") as file:
            yaml.dump(metrics, file)
    
    
    def configuration_to_cmd(configuration: Dict[str, Any]) -> List[str]:
        """Turn a configuration dict into an argv-style list of ``--key value`` pairs.

        Values are converted with ``str`` because YAML yields ints, floats and
        bools, while argparse requires every argv element to be a string.

        >>> configuration_to_cmd({"a": 1, "b": "x"})
        ['--a', '1', '--b', 'x']
        """
        args: List[str] = []
        for k, v in configuration.items():
            args += [f"--{k}", str(v)]
        return args
    
    
    def parse_args(argv: Union[List[str], None] = None) -> argparse.Namespace:
        """Parse the command line of this batch-evaluation script.

        :param argv: arguments to parse; ``None`` (the default) means the
            process command line, as with ``ArgumentParser.parse_args``.
        :return: namespace with a ``config`` attribute holding the YAML path.
        """
        parser = argparse.ArgumentParser(
            description='Evaluate a set of models and parameters')
        parser.add_argument('--config', required=True, metavar='PATH',
                            help='path to a YAML config file')
        return parser.parse_args(argv)
    
    
    if __name__ == '__main__':
        logging.basicConfig(level=logging.INFO)
        args = parse_args()

        params = load_yaml(args.config)
        configurations = generate_configurations(params)
        logging.info(f"{configurations}")
        logging.info(f"Number of configurations: {len(configurations)}")

        # Results are stored next to the config: foo.yaml -> foo_eval.yaml.
        path_metrics = os.path.splitext(args.config)[0] + "_eval.yaml"
        logging.info(f"Metrics will be saved to: {path_metrics}")

        # Resume from a previous run when its results file exists. An empty
        # YAML file parses to None, hence the fallback to {}.
        metrics = {}
        if Path(path_metrics).exists():
            metrics = load_yaml(path_metrics) or {}

        for configuration in configurations:
            # Sorting makes the key independent of dict insertion order.
            key = str(sorted(configuration.items()))

            # Skip configurations that already have non-empty results; entries
            # that are missing, null, or lack a "metrics" field are (re-)run.
            previous = metrics.get(key) or {}
            if previous.get("metrics"):
                logging.info(f"Skipping {key}")
                continue

            metrics[key] = {"params": configuration}

            # Separate name so the script's own CLI namespace is not shadowed.
            eval_args = evaluate_tsv.parse_args(configuration_to_cmd(configuration))
            metrics[key]["metrics"] = {}
            # NOTE(review): assumes evaluate_model yields (name, str) pairs with
            # multi-line string values — confirm against evaluate_tsv.
            for k, v in evaluate_tsv.evaluate_model(eval_args):
                metrics[key]["metrics"][k] = v.split("\n")

            # Save after every configuration so an interrupted batch can resume.
            save_metrics(metrics, path_metrics)