Skip to content
Snippets Groups Projects
create_performance_table.py 4.80 KiB
from conll18_ud_eval import *
from absl import app
from absl import flags
import pathlib
import csv

# This script requires the conll18_ud_eval.py file to be in the same directory. It is available at
# https://universaldependencies.org/conll18/

# Command-line configuration. The default paths are machine-specific
# (a personal home directory and a shared cluster mount); pass the flags
# explicitly when running elsewhere.
FLAGS = flags.FLAGS

# Input locations.
flags.DEFINE_string(
    "pred_dir",
    r"/home/pszenny/Desktop/IPI_PAN/evaluate_UD/predictions_UD_29/pred",
    "Path to directory with predictions on test sets.",
)
flags.DEFINE_string(
    "ud_dir",
    r"/home/pszenny/Desktop/IPI_PAN/evaluate_UD/predictions_UD_29/ud_files",
    "Path to directory with UD datasets up to UD_treebank/files .",
)
flags.DEFINE_string(
    "models_dir",
    r"/tmp/lustre_shared/lukasz/models_UD_2.9",
    "Path to directory with trained models treebank/allennlp_folder/files.",
)

# Release metadata used to name the model archives and build the table links.
flags.DEFINE_string(
    "UD_version",
    "29",
    "UD version number.",
)
flags.DEFINE_string(
    "URL_download",
    "http://s3.clarin-pl.eu/dspace/combo/ud_29/{model}.tar.gz",
    "template URL to download model with {model} where model name should be placed.",
)
flags.DEFINE_string(
    "URL_licence",
    "https://github.com/UniversalDependencies/{treebank}/blob/r2.9/LICENSE.txt",
    "template URL to license.txt with {treebank} where treebank name should be placed.",
)


def evaluate_wrapper(gold_file, system_file):
    """Evaluate a predicted CoNLL-U file against a gold CoNLL-U file.

    Replaces the same-named helper that the ``*`` import from
    conll18_ud_eval.py brings into scope; this version takes two file paths.

    Args:
        gold_file: path to the gold-standard CoNLL-U file.
        system_file: path to the system (predicted) CoNLL-U file.

    Returns:
        The result of conll18_ud_eval's ``evaluate``: a mapping from metric
        name (e.g. "UPOS", "LAS") to a score object.
    """
    # Open the files here so they are closed deterministically; in the
    # reference conll18_ud_eval.py, load_conllu_file opens its file without
    # ever closing it.
    with open(gold_file, mode="r", encoding="utf-8") as gold_fh:
        gold_ud = load_conllu(gold_fh)
    with open(system_file, mode="r", encoding="utf-8") as system_fh:
        system_ud = load_conllu(system_fh)
    return evaluate(gold_ud, system_ud)


# Metric columns reported for every treebank, in table order.
_METRICS = ("UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX")


def _split_treebank_name(treebank):
    """Split a UD folder name into (name without the "UD_" prefix, language).

    "UD_Polish-PDB" -> ("Polish-PDB", "Polish");
    "UD_Old_French-SRCMF" -> ("Old_French-SRCMF", "Old_French").
    """
    # Split off only the first "_" so multi-word languages (Old_French,
    # Ancient_Greek, ...) are not truncated to their first word.
    name = treebank.split("_", 1)[-1]
    return name, name.split("-")[0]


def _rename_models(models_dir, ud_version):
    """Rename each treebank's model.tar.gz to its release name; return {treebank folder name: file name}."""
    from collections import Counter

    treebank_dirs = sorted(p for p in models_dir.iterdir() if p.is_dir())
    # Exact per-language counts; a substring test over paths would e.g. count
    # "Greek" inside "UD_Ancient_Greek-PROIEL".
    language_counts = Counter(_split_treebank_name(p.name)[1] for p in treebank_dirs)

    treebank_model_name = {}
    for treebank_dir in treebank_dirs:
        allen_folders = list(treebank_dir.iterdir())
        if len(allen_folders) != 1:
            raise ValueError(
                f"Expected exactly one AllenNLP serialization folder in {treebank_dir}, "
                f"found {len(allen_folders)}."
            )
        allen_folder = allen_folders[0]

        name, language = _split_treebank_name(treebank_dir.name)
        # A language with a single treebank is named "<language>-ud<version>.tar.gz",
        # otherwise "<language>-<treebank>-ud<version>.tar.gz".
        stem = language if language_counts[language] == 1 else name
        new_name = f"{stem.lower()}-ud{ud_version}.tar.gz"

        model_path = allen_folder / "model.tar.gz"
        renamed_path = allen_folder / new_name
        if model_path.exists():
            model_path.rename(renamed_path)
        elif not renamed_path.exists():
            # No model here, neither fresh nor renamed by a previous run: skip.
            continue
        # Key by the folder name (a str) so lookups with names derived from
        # prediction files match.
        treebank_model_name[treebank_dir.name] = new_name
    return treebank_model_name


def _evaluate_predictions(pred_dir, ud_dir, treebank_model_name, url_download, url_licence):
    """Evaluate every prediction file in pred_dir against its gold test file; return one row per file."""
    rows = []
    for path_to_predictions in sorted(pred_dir.iterdir()):
        # Prediction file names are expected to be "<UD treebank folder>predictions_test.conllu".
        treebank = path_to_predictions.name.replace("predictions_test.conllu", "")
        ud_folder = ud_dir / treebank
        test_files = [f for f in ud_folder.iterdir() if "test" in f.name and ".conllu" in f.name]
        if len(test_files) != 1:
            raise ValueError(
                f"Expected exactly one test .conllu file in {ud_folder}, found {len(test_files)}."
            )
        if treebank not in treebank_model_name:
            raise KeyError(f"No model found in --models_dir for treebank {treebank!r}.")
        model_name = treebank_model_name[treebank]

        evaluation = evaluate_wrapper(str(test_files[0]), str(path_to_predictions))
        row = [treebank, model_name, url_download.format(model=model_name)]
        # NOTE(review): precision is reported; the CoNLL 2018 shared task ranks by F1
        # (equal to precision only when tokenization matches gold) -- confirm intent.
        row.extend(round(100 * evaluation[metric].precision, 2) for metric in _METRICS)
        row.append(url_licence.format(treebank=treebank))
        rows.append(row)
    return rows


def _write_google_sheet(rows, path="google_sheet.csv"):
    """Write the table as CSV, for import into a Google Sheet."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        csv.writer(f).writerows(rows)


def _write_gitlab_table(rows, path="performance_git.txt"):
    """Write the table as GitLab markdown, with the model name linked to its download URL."""
    lines = []
    for row in rows:
        # Metric cells are floats, so stringify everything before joining.
        cells = [str(row[0]), f"[{row[1]}]({row[2]})"] + [str(value) for value in row[3:]]
        lines.append("|" + "|".join(cells) + "|")
    if lines:
        # Markdown tables need a delimiter row right after the header to render.
        n_columns = len(rows[0]) - 1  # "Model name" and "Model link" share one column
        lines.insert(1, "|" + "|".join(["---"] * n_columns) + "|")
    with open(path, "w", encoding="utf-8") as fo:
        fo.write("\n".join(lines))


def run(_):
    """Rename trained models, evaluate test-set predictions and write performance tables.

    1. In every treebank folder of --models_dir, rename
       <allennlp_folder>/model.tar.gz to "<language>-ud<UD_version>.tar.gz", or
       "<language>-<treebank>-ud<UD_version>.tar.gz" when the language has
       several treebanks. Rerunning after the rename is safe.
    2. Evaluate every prediction file in --pred_dir against the single test
       .conllu file of the matching treebank folder in --ud_dir.
    3. Write google_sheet.csv and performance_git.txt (markdown) to the current
       working directory.

    Args:
        _: argv forwarded by absl.app.run (unused).

    Raises:
        ValueError: a model folder does not hold exactly one serialization
            folder, or a UD folder does not hold exactly one test file.
        KeyError: a prediction file has no corresponding model.
    """
    treebank_model_name = _rename_models(pathlib.Path(FLAGS.models_dir), FLAGS.UD_version)

    header = ["Treebank", "Model name", "Model link", *_METRICS, "LICENSE"]
    all_result = [header] + _evaluate_predictions(
        pathlib.Path(FLAGS.pred_dir),
        pathlib.Path(FLAGS.ud_dir),
        treebank_model_name,
        FLAGS.URL_download,
        FLAGS.URL_licence,
    )

    _write_google_sheet(all_result)
    _write_gitlab_table(all_result)


def main():
    app.run(run)


if __name__ == "__main__":
    main()