Commit ada9e8fd authored by Michał Marcińczuk

Batch evaluation.

parent 5118ae5b
1 merge request: !41 Dev v07
Pipeline #6345 failed with stage in 2 minutes and 2 seconds
--- a/evaluate_tsv.py
+++ b/evaluate_tsv.py
 from __future__ import absolute_import, division, print_function
 import argparse
+import logging
 import os
 import time
+from typing import List, Tuple
 
 import poldeepner2
 from poldeepner2.utils.data_utils import read_tsv
@@ -12,23 +14,22 @@ from poldeepner2.utils.sequence_labeling import classification_report
 from poldeepner2.utils.sequences import FeatureGeneratorFactory
 
 
-def main(args):
-    print("Loading the NER model ...")
+def evaluate_model(args) -> List[Tuple[str, str]]:
+    logging.info(f"Loading the NER model {args.model}...")
     ner = poldeepner2.load(args.model, device=args.device)
-    for param in ["device", "max_seq_length", "sequence_generator"]:
+    for param in ["device", "max_seq_length", "sequence_generator", "output_top_k"]:
         value = args.__dict__.get(param, None)
         if value is not None:
             value_default = ner.model.config.__dict__.get(param)
             if str(value) != str(value_default):
-                print(f"Forced change of the parameter: {param} '{value_default}' => '{value}'")
+                logging.info(f"Forced change of the parameter: {param} '{value_default}' => '{value}'")
                 ner.model.config.__dict__[param] = value
 
     if args.seed is not None:
         setup_seed(args.seed)
 
-    print("Processing ...")
+    logging.info("Processing ...")
     sentences_labels = read_tsv(os.path.join(args.input), True)
     sentences = [sentence[0] for sentence in sentences_labels]
     labels = [sentence[1] for sentence in sentences_labels]
@@ -42,16 +43,25 @@ def main(args):
     time_processing = time.clock() - t0
 
     report = classification_report(labels, predictions, digits=4)
-    print(report)
-
-    print(f"Total time : {time_processing:>8.4} second(s)")
-    print(f"Data size: : {data_size/1000000:>8.4} M characters")
-    print(f"Speed: : {data_size / 1000000 / (time_processing/60):>8.4} M characters/minute")
-    print(f"Number of token labels : {len(ner.model.config.labels):>8} ")
-    print(f"Stats : {str(stats)}")
+    return [
+        ("Report", report),
+        ("Total time", f"{time_processing:>8.4} second(s)"),
+        ("Data size", f"{data_size/1000000:>8.4} M characters"),
+        ("Speed", f"{data_size / 1000000 / (time_processing/60):>8.4} M characters/minute"),
+        ("Number of token labels", f"{len(ner.model.config.labels):>8}"),
+        ("Stats", str(stats))
+    ]
 
 
-def parse_args():
+def parse_args(args=None):
     parser = argparse.ArgumentParser(
         description='Process a single TSV with a NER model')
     parser.add_argument('--input', required=True, metavar='PATH', help='path to a file with a list of files')
@@ -64,12 +74,23 @@ def parse_args():
                         help="method of sequence generation", default=None, required=False)
     parser.add_argument('--seed', required=False, default=None, metavar='N', type=int,
                         help='a seed used to initialize a number generator')
-    return parser.parse_args()
+    parser.add_argument('--output-top-k', required=False, default=None, metavar='N', type=int,
+                        help='output top k labels for each token')
+    return parser.parse_args(args)
 
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
     cliargs = parse_args()
     try:
-        main(cliargs)
+        output = evaluate_model(cliargs)
+        for k, v in output:
+            if k == "Report":
+                print(v)
+            else:
+                print(f"{k:10} {v}")
     except ValueError as er:
         print("[ERROR] %s" % er)
import argparse
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Union

import yaml

import evaluate_tsv


def load_yaml(path: str) -> Dict[str, Union[str, List[str]]]:
    with open(path) as file:
        try:
            return yaml.safe_load(file)
        except yaml.YAMLError as exception:
            raise ValueError(exception)


def generate_configurations(params: Dict[str, Union[str, List[str]]]) -> List[Dict[str, str]]:
    configurations = [{}]
    for key, values in params.items():
        if key.endswith("/list"):
            # Keys with the "/list" suffix are not handled yet.
            # key = key[:-5]
            # for ll in lists:
            #     for value in values:
            #         ll.extend([f"--{key}", value])
            pass
        elif isinstance(values, list):
            # A list of values multiplies the existing configurations
            # (Cartesian product over all list-valued parameters).
            new_configurations = []
            for config in configurations:
                for value in values:
                    new_config = config.copy()
                    new_config[key] = value
                    new_configurations.append(new_config)
            configurations = new_configurations
        else:
            # A scalar value is shared by every configuration.
            for config in configurations:
                config[key] = values
    return configurations
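
# Illustration (hypothetical values): a YAML config loaded as
#   {"input": "data/test.tsv", "model": ["model-a", "model-b"], "max_seq_length": [128, 256]}
# expands to the Cartesian product over the list-valued keys, i.e. 2 x 2 = 4 configurations
# such as {"input": "data/test.tsv", "model": "model-a", "max_seq_length": 128}.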


def save_metrics(metrics: Dict[str, Any], path: str):
    # Use a context manager so the file handle is closed after dumping.
    with open(path, "w", encoding="utf-8") as file:
        yaml.dump(metrics, file)


def configuration_to_cmd(configuration: Dict[str, str]) -> List[str]:
    args = []
    for k, v in configuration.items():
        # Cast to str: YAML may parse numeric values as int/float, but argparse expects strings.
        args += [f"--{k}", str(v)]
    return args
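
# For example (hypothetical values), {"model": "model-a", "input": "data/test.tsv"} becomes
# ["--model", "model-a", "--input", "data/test.tsv"], the argument list consumed by
# evaluate_tsv.parse_args in the main loop below.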


def parse_args():
    parser = argparse.ArgumentParser(
        description='Evaluate a set of models and parameters')
    parser.add_argument('--config', required=True, metavar='PATH', help='path to a YAML config file')
    return parser.parse_args()


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()

    params = load_yaml(args.config)
    configurations = generate_configurations(params)
    logging.info(f"{configurations}")
    logging.info(f"Number of configurations: {len(configurations)}")

    # Metrics are written next to the config file, e.g. config.yaml -> config_eval.yaml.
    path_metrics = os.path.splitext(args.config)[0] + "_eval.yaml"
    logging.info(f"Metrics will be saved to: {path_metrics}")

    # Resume from an existing metrics file so finished configurations are not re-evaluated.
    metrics = {}
    if Path(path_metrics).exists():
        metrics = load_yaml(path_metrics)

    for configuration in configurations:
        key = str(sorted(configuration.items()))
        if key in metrics and metrics[key]["metrics"]:
            logging.info(f"Skipping {key}")
            continue
        metrics[key] = {}
        metrics[key]["params"] = configuration
        cmd = configuration_to_cmd(configuration)
        args = evaluate_tsv.parse_args(cmd)
        metrics[key]["metrics"] = {}
        for k, v in evaluate_tsv.evaluate_model(args):
            # Store each value as a list of lines (the classification report spans many lines).
            metrics[key]["metrics"][k] = v.split("\n")
        # Save after every configuration so partial results survive interruptions.
        save_metrics(metrics, path_metrics)
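
Each run of the batch script appends to the metrics file stored next to the config, keyed by the sorted parameter list, which is how finished configurations are detected and skipped. A sketch of the resulting <config>_eval.yaml content, expressed as the Python dict passed to yaml.dump (all values are purely illustrative):

example_metrics = {
    "[('input', 'data/test.tsv'), ('model', 'model-a')]": {
        "params": {"input": "data/test.tsv", "model": "model-a"},
        "metrics": {
            "Report": ["              precision    recall  f1-score   support", "..."],
            "Total time": ["    12.34 second(s)"],
        },
    },
}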