# predict_iwpt21.py
import pathlib

from absl import app
from absl import flags

from scripts import utils

# Two-letter language code -> English language name.
# ``run`` only processes model directories whose name is one of these codes;
# any other entry in the models directory is skipped.
CODE2LANG = dict(
    ar="Arabic",
    bg="Bulgarian",
    cs="Czech",
    nl="Dutch",
    en="English",
    et="Estonian",
    fi="Finnish",
    fr="French",
    it="Italian",
    lv="Latvian",
    lt="Lithuanian",
    pl="Polish",
    ru="Russian",
    sk="Slovak",
    sv="Swedish",
    ta="Tamil",
    uk="Ukrainian",
)

# Command-line flags; ``run`` reads their parsed values through FLAGS.
FLAGS = flags.FLAGS

# Directory searched for "<lang>.mwt.conllu" / "<lang>.conllu" test files;
# the prediction outputs are written next to them.
flags.DEFINE_string("data_dir", "", "Path to data directory.")

# Directory whose entries are named after the language codes in CODE2LANG.
flags.DEFINE_string("models_dir", "/tmp/", "Model serialization dir.")

# Path handed to utils.quick_fix and utils.collapse_nodes.
flags.DEFINE_string("tools", "", "UD tools path.")

# Forwarded unchanged to the predictor's --cuda_device option.
flags.DEFINE_integer("cuda_device", -1, "Cuda device id (-1 for cpu).")

# When true, each language directory must contain exactly one entry, and
# that entry is the directory holding model.tar.gz.
flags.DEFINE_boolean("expect_prefix", True, "Whether to expect allennlp prefix.")

# Forwarded unchanged to the predictor's --batch_size option.
flags.DEFINE_integer("batch_size", 32, "Batch size.")


def _resolve_model_dir(lang_dir, expect_prefix):
    """Return the directory that directly contains ``model.tar.gz``.

    With ``expect_prefix`` the language directory must hold exactly one
    entry, and that entry is returned; otherwise ``lang_dir`` itself is.

    Raises:
        ValueError: ``expect_prefix`` is set and ``lang_dir`` does not
            contain exactly one entry.
    """
    if not expect_prefix:
        return lang_dir

    entries = list(lang_dir.iterdir())
    if len(entries) != 1:
        raise ValueError(f"There is incorrect count of models {entries}")
    return entries[0]


def _find_test_file(data_dir, lang):
    """Return the test file for ``lang``, preferring the ``.mwt`` variant.

    Raises:
        FileNotFoundError: neither ``<lang>.mwt.conllu`` nor
            ``<lang>.conllu`` exists in ``data_dir``.
    """
    for file_name in (f"{lang}.mwt.conllu", f"{lang}.conllu"):
        candidate = data_dir / file_name
        if candidate.exists():
            return candidate
    raise FileNotFoundError(
        f"Couldn't find test file for '{lang}' in {data_dir}."
    )


def run(_):
    """Predict, fix and collapse the test file of every known language.

    For each entry of ``--models_dir`` whose name is a key of ``CODE2LANG``:

    1. locate the model directory (see ``--expect_prefix``),
    2. pick ``<lang>.mwt.conllu`` from ``--data_dir`` if present, else
       ``<lang>.conllu``,
    3. run ``combo --mode predict`` into ``<lang>_pred.conllu``,
    4. pass the prediction to ``utils.quick_fix`` with
       ``<lang>_pred.fixed.conllu`` as target,
    5. pass the fixed file to ``utils.collapse_nodes`` with
       ``<lang>_pred.collapsed.conllu`` as target.

    Entries with any other name are reported and skipped.

    Args:
        _: Positional arguments supplied by ``app.run``; unused.

    Raises:
        ValueError: a language directory does not contain exactly one entry
            while ``--expect_prefix`` is set.
        FileNotFoundError: no test file exists for a language.
    """
    # None of these depend on the language, so build them once.
    models_dir = pathlib.Path(FLAGS.models_dir)
    data_dir = pathlib.Path(FLAGS.data_dir)
    tools_dir = pathlib.Path(FLAGS.tools)

    # iterdir() yields entries in arbitrary order; sort so that runs are
    # reproducible and logs are comparable.
    for lang_dir in sorted(models_dir.iterdir()):
        lang = lang_dir.name
        if lang not in CODE2LANG:
            print("Skipping unknown directory: ", lang)
            continue

        model_dir = _resolve_model_dir(lang_dir, FLAGS.expect_prefix)
        test_file = _find_test_file(data_dir, lang)

        output_pred = data_dir / f'{lang}_pred.conllu'
        command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
        --input_file {test_file}
        --output_file {output_pred}
        --cuda_device {FLAGS.cuda_device}
        --batch_size {FLAGS.batch_size}
        --silent
        """
        utils.execute_command(command)

        # Derive sibling names by swapping only the final suffix. A plain
        # str.replace over the whole path would also rewrite any parent
        # directory whose name happens to contain ".conllu".
        # NOTE(review): assumes utils.path_to_str is a pure Path -> str
        # conversion, so converting the derived path is equivalent to
        # editing the converted string -- confirm against scripts/utils.
        output_fixed = utils.path_to_str(output_pred.with_suffix('.fixed.conllu'))
        utils.quick_fix(tools_dir, output_pred, output_fixed)

        output_collapsed = utils.path_to_str(output_pred.with_suffix('.collapsed.conllu'))
        utils.collapse_nodes(tools_dir, pathlib.Path(output_fixed), output_collapsed)


def main():
    """Script entry point: hand ``run`` to ``app.run`` as the main callback."""
    app.run(main=run)


# Only start the pipeline when executed as a script, not when imported.
if __name__ == "__main__":
    main()