Skip to content
Snippets Groups Projects
Commit a17f6c4b authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add prediction script for IWPT'21.

parent ee3d3dcf
No related merge requests found
import pathlib
from absl import app
from absl import flags
from scripts import utils
CODE2LANG = {
"ar": "Arabic",
"bg": "Bulgarian",
"cs": "Czech",
"nl": "Dutch",
"en": "English",
"et": "Estonian",
"fi": "Finnish",
"fr": "French",
"it": "Italian",
"lv": "Latvian",
"lt": "Lithuanian",
"pl": "Polish",
"ru": "Russian",
"sk": "Slovak",
"sv": "Swedish",
"ta": "Tamil",
"uk": "Ukrainian",
}
FLAGS = flags.FLAGS
flags.DEFINE_string(name="data_dir", default="",
help="Path to IWPT'21 data directory.")
flags.DEFINE_string(name="models_dir", default="/tmp/",
help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
flags.DEFINE_boolean(name="expect_prefix", default=True,
help="Whether to expect allennlp prefix.")
def run(_):
models_dir = pathlib.Path(FLAGS.models_dir)
for model_dir in models_dir.iterdir():
lang = model_dir.name
if lang not in CODE2LANG:
print("Skipping unknown directory: ", lang)
continue
if FLAGS.expect_prefix:
model_dir = list(model_dir.iterdir())
assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}"
model_dir = model_dir[0]
data_dir = pathlib.Path(FLAGS.data_dir)
files = list(data_dir.iterdir())
test_file = [f for f in files if f"{lang}.conllu" == f.name]
assert len(test_file) == 1, f"Couldn't find test file."
test_file = test_file[0]
output_pred = data_dir / f'{lang}_pred.conllu'
command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
--input_file {test_file}
--output_file {output_pred}
--cuda_device {FLAGS.cuda_device}
--silent
"""
utils.execute_command(command)
def main():
app.run(run)
if __name__ == "__main__":
main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment