Select Git revision
evaluate_tsv_batch.py 2.79 KiB
import argparse
import logging
from pathlib import Path
from typing import List, Dict, Union, Any
import yaml
import os
import evaluate_tsv
def load_yaml(path: str) -> Dict[str, Any]:
    """Load a YAML file and return its parsed top-level content.

    Args:
        path: Location of the YAML file to read.

    Returns:
        The document as parsed by ``yaml.safe_load``. An empty file yields
        an empty dict instead of ``None`` so callers can use ``.items()``
        and ``in`` without a separate check.

    Raises:
        ValueError: If the file is not valid YAML; the underlying
            ``yaml.YAMLError`` is chained as the cause.
    """
    # Read as UTF-8 explicitly: save_metrics writes UTF-8, and the platform
    # default encoding is not guaranteed to match.
    with open(path, encoding="utf-8") as file:
        try:
            params = yaml.safe_load(file)
        except yaml.YAMLError as exception:
            raise ValueError(exception) from exception
    # safe_load returns None for an empty document.
    return {} if params is None else params
def generate_configurations(
    params: Union[Dict[str, Any], None],
) -> List[Dict[str, Any]]:
    """Expand a parameter grid into one configuration dict per combination.

    Each key of ``params`` maps either to a single value, which is applied
    to every configuration, or to a list of values, each of which spawns
    its own set of configurations (Cartesian product over all list-valued
    keys). Keys ending in ``"/list"`` are currently ignored.

    Args:
        params: Parameter grid, typically loaded from a YAML file. ``None``
            or an empty mapping yields a single empty configuration.

    Returns:
        A list of configurations. Keys appear in the order of ``params``;
        the first list-valued key varies slowest, the last one fastest.
        A key whose value is an empty list produces no configurations.
    """
    configurations: List[Dict[str, Any]] = [{}]
    if not params:
        return configurations

    for key, values in params.items():
        if key.endswith("/list"):
            # Support for "/list" keys is not implemented; skip them.
            continue
        # Treat a scalar as a one-element choice list so that both cases
        # share the same expansion step.
        choices = values if isinstance(values, list) else [values]
        configurations = [
            {**config, key: choice}
            for config in configurations
            for choice in choices
        ]
    return configurations
def save_metrics(metrics: Dict[str, Any], path: str) -> None:
    """Write ``metrics`` to ``path`` as YAML, overwriting any existing file.

    Args:
        metrics: Data to serialise with ``yaml.dump``.
        path: Destination file; it is written as UTF-8.
    """
    # Use a context manager so the handle is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open(path, "w", encoding="utf-8") as file:
        yaml.dump(metrics, file)
def configuration_to_cmd(configuration: Dict[str, Any]) -> List[str]:
    """Convert a configuration dict into a flat command-line argument list.

    Each ``key: value`` pair becomes the two tokens ``--key`` and ``value``,
    in the iteration order of the dict.

    Args:
        configuration: Mapping of option names (without leading dashes) to
            their values.

    Returns:
        A list of strings suitable for ``ArgumentParser.parse_args``.
    """
    args: List[str] = []
    for option, value in configuration.items():
        # YAML scalars may be ints, floats or bools, but argparse only
        # accepts strings in its argument list.
        args += [f"--{option}", str(value)]
    return args
def parse_args(argv: Union[List[str], None] = None) -> argparse.Namespace:
    """Parse the command-line options of the batch evaluation script.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            argparse reads them from ``sys.argv[1:]``.

    Returns:
        Namespace with a single attribute ``config``: the path given via
        the required ``--config`` option.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate a set of models and parameters')
    parser.add_argument('--config', required=True, metavar='PATH', help='path to an YAML config file')
    return parser.parse_args(argv)
def main() -> None:
    """Evaluate every configuration described by the ``--config`` YAML file.

    The parameter grid is expanded with ``generate_configurations`` and each
    configuration is passed to ``evaluate_tsv``. Results are collected in
    ``<config stem>_eval.yaml`` next to the config file. Configurations that
    already have non-empty metrics in that file are skipped.
    """
    logging.basicConfig(level=logging.INFO)
    cli_args = parse_args()
    params = load_yaml(cli_args.config)
    configurations = generate_configurations(params)
    logging.info("%s", configurations)
    logging.info("Number of configurations: %s", len(configurations))

    path_metrics = os.path.splitext(cli_args.config)[0] + "_eval.yaml"
    logging.info("Metrics will be saved to: %s", path_metrics)

    metrics = {}
    if Path(path_metrics).exists():
        # An empty metrics file parses to None; treat it as "no results yet".
        metrics = load_yaml(path_metrics) or {}

    for configuration in configurations:
        # Sorting makes the key independent of the option order.
        key = str(sorted(configuration.items()))
        previous = metrics.get(key) or {}
        if previous.get("metrics"):
            logging.info("Skipping %s", key)
            continue

        entry = {"params": configuration, "metrics": {}}
        metrics[key] = entry
        # Use a separate name so the CLI namespace above is not shadowed.
        model_args = evaluate_tsv.parse_args(configuration_to_cmd(configuration))
        for name, report in evaluate_tsv.evaluate_model(model_args):
            entry["metrics"][name] = report.split("\n")
        # Persist after every configuration so that an interrupted batch
        # can resume through the skip check above.
        save_metrics(metrics, path_metrics)


if __name__ == '__main__':
    main()