Skip to content
Snippets Groups Projects
Select Git revision
  • 010ff5c1c4917cedeb41f82d9a28eab848ce8b05
  • master default protected
  • vertical_relations
  • lu_without_semantic_frames
  • hierarchy
  • additional-unification-filters
  • v0.1.1
  • v0.1.0
  • v0.0.9
  • v0.0.8
  • v0.0.7
  • v0.0.6
  • v0.0.5
  • v0.0.4
  • v0.0.3
  • v0.0.2
  • v0.0.1
17 results

views.py

Blame
  • predict_iwpt21.py 1.99 KiB
    import pathlib
    
    from absl import app
    from absl import flags
    
    from scripts import utils
    
    # Two-letter language codes of the IWPT'21 languages handled by this
    # script, mapped to their English names. In this file only the keys are
    # consulted: a model subdirectory is processed only if its name is a key.
    # Insertion order is kept identical to the original listing.
    CODE2LANG = dict(
        ar="Arabic",
        bg="Bulgarian",
        cs="Czech",
        nl="Dutch",
        en="English",
        et="Estonian",
        fi="Finnish",
        fr="French",
        it="Italian",
        lv="Latvian",
        lt="Lithuanian",
        pl="Polish",
        ru="Russian",
        sk="Slovak",
        sv="Swedish",
        ta="Tamil",
        uk="Ukrainian",
    )
    
    FLAGS = flags.FLAGS

    # Directory that run() searches for "<lang>.conllu" inputs and into which
    # it writes "<lang>_pred.conllu" outputs. The empty default resolves to
    # the current working directory via pathlib.Path("").
    flags.DEFINE_string(
        name="data_dir",
        default="",
        help="Path to IWPT'21 data directory.",
    )
    # Parent directory whose subdirectories are named by language code.
    flags.DEFINE_string(
        name="models_dir",
        default="/tmp/",
        help="Model serialization dir.",
    )
    # Forwarded verbatim to the `combo` command as --cuda_device.
    flags.DEFINE_integer(
        name="cuda_device",
        default=-1,
        help="Cuda device id (-1 for cpu).",
    )
    # When true, each language directory must contain exactly one entry,
    # and that entry is used as the model directory.
    flags.DEFINE_boolean(
        name="expect_prefix",
        default=True,
        help="Whether to expect allennlp prefix.",
    )
    
    
    def run(_):
        """Run COMBO prediction for every per-language model under --models_dir.

        Each entry of --models_dir is considered in sorted order. Entries
        whose name is not a key of ``CODE2LANG``, or that are not directories,
        are skipped. For language ``<lang>``:
          * the model archive is ``<model_dir>/model.tar.gz``. With
            --expect_prefix, ``<model_dir>`` is the single entry inside the
            language directory; otherwise it is the language directory itself;
          * the input is ``<data_dir>/<lang>.conllu``;
          * predictions are written to ``<data_dir>/<lang>_pred.conllu``.

        Args:
          _: Positional argv passed in by ``absl.app.run``; unused.

        Raises:
          ValueError: If --expect_prefix is set and a language directory does
            not contain exactly one entry.
          FileNotFoundError: If ``<data_dir>/<lang>.conllu`` does not exist.
        """
        models_dir = pathlib.Path(FLAGS.models_dir)
        # The data directory is the same for every language; resolve it once
        # instead of re-listing it on each iteration.
        data_dir = pathlib.Path(FLAGS.data_dir)

        # iterdir() order is filesystem-dependent; sort for reproducible runs.
        for lang_dir in sorted(models_dir.iterdir()):
            lang = lang_dir.name
            if lang not in CODE2LANG:
                print("Skipping unknown directory: ", lang)
                continue
            if not lang_dir.is_dir():
                # A stray file named like a language code cannot hold a model.
                print("Skipping non-directory entry: ", lang_dir)
                continue

            model_dir = lang_dir
            if FLAGS.expect_prefix:
                # The model is expected to be the sole entry nested inside the
                # language directory. Validate explicitly rather than with
                # `assert`, which is stripped under `python -O`.
                entries = list(lang_dir.iterdir())
                if len(entries) != 1:
                    raise ValueError(
                        f"Expected exactly one model directory in {lang_dir}, "
                        f"found {len(entries)}: {entries}")
                model_dir = entries[0]

            test_file = data_dir / f"{lang}.conllu"
            if not test_file.is_file():
                raise FileNotFoundError(f"Couldn't find test file: {test_file}")

            output_pred = data_dir / f'{lang}_pred.conllu'
            # Command text kept exactly as before, because it is handed
            # unchanged to utils.execute_command.
            command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
            --input_file {test_file}
            --output_file {output_pred}
            --cuda_device {FLAGS.cuda_device}
            --silent
            """
            utils.execute_command(command)
    
    
    def main():
        """Start the absl application, using ``run`` as its main function."""
        app.run(main=run)


    # Allow execution as a script as well as via an imported entry point.
    if __name__ == "__main__":
        main()