Skip to content
Snippets Groups Projects
Select Git revision
  • c082b65055c8c728cf4279a7c19e9d0409c3000c
  • main default protected
  • ud_training_script
  • fix_seed
  • merged-with-ner
  • multiword_fix_transformer
  • transformer_encoder
  • combo3
  • save_deprel_matrix_to_npz
  • master protected
  • combo-lambo
  • lambo-sent-attributes
  • adding_lambo
  • develop
  • update_allenlp2
  • develop_tmp
  • tokens_truncation
  • LR_test
  • eud_iwpt
  • iob
  • eud_iwpt_shared_task_bert_finetuning
  • 3.3.1
  • list
  • 3.2.1
  • 3.0.3
  • 3.0.1
  • 3.0.0
  • v1.0.6
  • v1.0.5
  • v1.0.4
  • v1.0.3
  • v1.0.2
  • v1.0.1
  • v1.0.0
34 results

predict_iwpt21.py

Blame
  • predict_iwpt21.py 1.99 KiB
    import pathlib
    
    from absl import app
    from absl import flags
    
    from scripts import utils
    
    # Two-letter language codes (used below as model directory names and as
    # the stem of the ``<code>.conllu`` data files) mapped to English names.
    CODE2LANG = dict(
        ar="Arabic",
        bg="Bulgarian",
        cs="Czech",
        nl="Dutch",
        en="English",
        et="Estonian",
        fi="Finnish",
        fr="French",
        it="Italian",
        lv="Latvian",
        lt="Lithuanian",
        pl="Polish",
        ru="Russian",
        sk="Slovak",
        sv="Swedish",
        ta="Tamil",
        uk="Ukrainian",
    )
    
    # Command-line flags of the script; values are read through FLAGS.
    FLAGS = flags.FLAGS

    # Directory holding the ``<lang>.conllu`` inputs; predictions go there too.
    flags.DEFINE_string("data_dir", "", "Path to IWPT'21 data directory.")
    # Directory with one sub-directory per language code.
    flags.DEFINE_string("models_dir", "/tmp/", "Model serialization dir.")
    # Forwarded unchanged to the ``combo`` command line.
    flags.DEFINE_integer("cuda_device", -1, "Cuda device id (-1 for cpu).")
    # When true, each language directory must contain exactly one entry,
    # and that entry is used as the model directory.
    flags.DEFINE_boolean(
        "expect_prefix", True, "Whether to expect allennlp prefix."
    )
    
    
    def run(_):
        """Run ``combo`` prediction for every known language model.

        Every entry of ``FLAGS.models_dir`` whose name is a key of
        ``CODE2LANG`` is treated as one language. For each such language the
        archive ``model.tar.gz`` is used to process
        ``<data_dir>/<lang>.conllu`` and the output is written to
        ``<data_dir>/<lang>_pred.conllu``. Entries with other names are
        reported and skipped.

        Args:
            _: Positional arguments passed in by ``absl.app.run``; unused.

        Raises:
            AssertionError: If ``FLAGS.expect_prefix`` is set and a language
                directory does not contain exactly one entry, or if the test
                file of a language does not exist.
        """
        models_dir = pathlib.Path(FLAGS.models_dir)
        # Loop-invariant: build the data path once, not once per language.
        data_dir = pathlib.Path(FLAGS.data_dir)

        for lang_dir in models_dir.iterdir():
            lang = lang_dir.name
            if lang not in CODE2LANG:
                print("Skipping unknown directory: ", lang)
                continue

            model_dir = lang_dir
            if FLAGS.expect_prefix:
                # The model files sit one level deeper, inside the single
                # entry of the language directory.
                candidates = list(lang_dir.iterdir())
                # Raised explicitly instead of via ``assert`` so the check is
                # not stripped when Python runs with -O.
                if len(candidates) != 1:
                    raise AssertionError(
                        f"There is incorrect count of models {candidates}"
                    )
                model_dir = candidates[0]

            # A direct existence check replaces listing the whole data
            # directory and filtering it by name for every language.
            test_file = data_dir / f"{lang}.conllu"
            if not test_file.exists():
                raise AssertionError(
                    f"Couldn't find test file {test_file} for language {lang!r}."
                )

            output_pred = data_dir / f'{lang}_pred.conllu'
            command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
            --input_file {test_file}
            --output_file {output_pred}
            --cuda_device {FLAGS.cuda_device}
            --silent
            """
            utils.execute_command(command)
    
    
    def main():
        """Script entry point: hand ``run`` over to the absl application runner."""
        app.run(main=run)
    
    
    # Start the prediction run only when executed as a script, not on import.
    if __name__ == "__main__":
        main()