Skip to content
Snippets Groups Projects
Commit c082b650 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add prediction script for IWPT'21.

parent a9ad1754
Branches
Tags
2 merge requests!37Release 1.0.4.,!36Release 1.0.4
This commit is part of merge request !36. Comments created here will be created in the context of that merge request.
import pathlib
from absl import app
from absl import flags
from scripts import utils
CODE2LANG = {
"ar": "Arabic",
"bg": "Bulgarian",
"cs": "Czech",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"it": "Italian",
"lv": "Latvian",
"lt": "Lithuanian",
"pl": "Polish",
"ru": "Russian",
"sk": "Slovak",
"sv": "Swedish",
"ta": "Tamil",
"uk": "Ukrainian",
}
FLAGS = flags.FLAGS
flags.DEFINE_string(name="data_dir", default="",
help="Path to IWPT'21 data directory.")
flags.DEFINE_string(name="models_dir", default="/tmp/",
help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
flags.DEFINE_boolean(name="expect_prefix", default=True,
help="Whether to expect allennlp prefix.")
def run(_):
models_dir = pathlib.Path(FLAGS.models_dir)
for model_dir in models_dir.iterdir():
lang = model_dir.name
if lang not in CODE2LANG:
print("Skipping unknown directory: ", lang)
continue
if FLAGS.expect_prefix:
model_dir = list(model_dir.iterdir())
assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}"
model_dir = model_dir[0]
data_dir = pathlib.Path(FLAGS.data_dir)
files = list(data_dir.iterdir())
test_file = [f for f in files if f"{lang}.conllu" == f.name]
assert len(test_file) == 1, f"Couldn't find test file."
test_file = test_file[0]
output_pred = data_dir / f'{lang}_pred.conllu'
command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
--input_file {test_file}
--output_file {output_pred}
--cuda_device {FLAGS.cuda_device}
--silent
"""
utils.execute_command(command)
def main():
app.run(run)
if __name__ == "__main__":
main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment