"""Run ``combo --mode predict`` for every language model directory.

For each sub-directory of ``--models_dir`` whose name is a key of
``CODE2LANG``, the script locates the matching test file in ``--data_dir``,
runs the prediction command and then passes the prediction through
``utils.quick_fix`` and ``utils.collapse_nodes``.
"""
import pathlib

from absl import app
from absl import flags

from scripts import utils

# Language codes that are accepted as model directory names.
CODE2LANG = {
    "ar": "Arabic",
    "bg": "Bulgarian",
    "cs": "Czech",
    "nl": "Dutch",
    "en": "English",
    "et": "Estonian",
    "fi": "Finnish",
    "fr": "French",
    "it": "Italian",
    "lv": "Latvian",
    "lt": "Lithuanian",
    "pl": "Polish",
    "ru": "Russian",
    "sk": "Slovak",
    "sv": "Swedish",
    "ta": "Tamil",
    "uk": "Ukrainian",
}

FLAGS = flags.FLAGS
flags.DEFINE_string(name="data_dir", default="",
                    help="Path to data directory.")
flags.DEFINE_string(name="models_dir", default="/tmp/",
                    help="Model serialization dir.")
flags.DEFINE_string(name="tools", default="",
                    help="UD tools path.")
flags.DEFINE_integer(name="cuda_device", default=-1,
                     help="Cuda device id (-1 for cpu).")
flags.DEFINE_boolean(name="expect_prefix", default=True,
                     help="Whether to expect allennlp prefix.")
flags.DEFINE_integer(name="batch_size", default=32,
                     help="Batch size.")


def _resolve_model_dir(lang_dir):
    """Return the directory expected to contain ``model.tar.gz``.

    With ``--expect_prefix`` the language directory must hold exactly one
    entry, and that entry is returned; otherwise ``lang_dir`` itself is used.

    Raises:
        AssertionError: if ``--expect_prefix`` is set and the language
            directory does not contain exactly one entry.
    """
    if not FLAGS.expect_prefix:
        return lang_dir
    candidates = list(lang_dir.iterdir())
    if len(candidates) != 1:
        # Raised explicitly (not via ``assert``) so the check survives -O.
        raise AssertionError(
            f"There is incorrect count of models {candidates}")
    return candidates[0]


def _find_test_file(data_dir, lang):
    """Return the test file for ``lang`` found directly inside ``data_dir``.

    Raises:
        AssertionError: if neither ``{lang}.mwt.conllu`` nor
            ``{lang}.conllu`` is present.
    """
    # Entry names within one directory listing are unique, so a name -> path
    # mapping cannot drop anything.
    by_name = {entry.name: entry for entry in data_dir.iterdir()}
    # Try to use mwt file if it exists
    for wanted in (f"{lang}.mwt.conllu", f"{lang}.conllu"):
        if wanted in by_name:
            return by_name[wanted]
    raise AssertionError("Couldn't find test file.")


def _swap_suffix(path_str, old_suffix, new_suffix):
    """Replace a trailing ``old_suffix`` of ``path_str`` with ``new_suffix``.

    Only the end of the string is touched, so a directory component that
    happens to contain ``old_suffix`` is left intact. The string is returned
    unchanged when it does not end with ``old_suffix``.
    """
    if not path_str.endswith(old_suffix):
        return path_str
    return path_str[:-len(old_suffix)] + new_suffix


def run(_):
    """Predict and post-process the test file of every known language.

    Args:
        _: positional arguments handed over by ``app.run`` (unused).

    Raises:
        AssertionError: if a model directory layout is unexpected or a test
            file is missing.
    """
    # Both paths are the same for every language, so build them once.
    data_dir = pathlib.Path(FLAGS.data_dir)
    tools_dir = pathlib.Path(FLAGS.tools)

    for lang_dir in pathlib.Path(FLAGS.models_dir).iterdir():
        lang = lang_dir.name
        if lang not in CODE2LANG:
            print("Skipping unknown directory: ", lang)
            continue

        model_dir = _resolve_model_dir(lang_dir)
        test_file = _find_test_file(data_dir, lang)

        output_pred = data_dir / f"{lang}_pred.conllu"
        # The trailing space after --silent is kept from the original
        # command text.
        command = (
            f"combo --mode predict"
            f" --model_path {model_dir / 'model.tar.gz'}"
            f" --input_file {test_file}"
            f" --output_file {output_pred}"
            f" --cuda_device {FLAGS.cuda_device}"
            f" --batch_size {FLAGS.batch_size}"
            f" --silent "
        )
        utils.execute_command(command)

        # quick_fix receives the prediction as a Path and the target as str.
        output_fixed = _swap_suffix(
            utils.path_to_str(output_pred), ".conllu", ".fixed.conllu")
        utils.quick_fix(tools_dir, output_pred, output_fixed)

        output_collapsed = _swap_suffix(
            output_fixed, ".fixed.conllu", ".collapsed.conllu")
        utils.collapse_nodes(
            tools_dir, pathlib.Path(output_fixed), output_collapsed)


def main():
    """Entry point: let absl parse the flags and invoke ``run``."""
    app.run(run)


if __name__ == "__main__":
    main()