An error occurred while loading the file. Please try again.
-
Mateusz Klimaszewski authored 4eaad34d
predict_iwpt21.py 2.83 KiB
import pathlib
from absl import app
from absl import flags
from scripts import utils
# Language codes this script knows how to predict, mapped to English names.
# run() skips every model directory whose name is not one of these keys.
# NOTE(review): presumably the IWPT 2021 shared-task languages (per the
# script name) — confirm before adding/removing entries.
CODE2LANG = dict(
    ar="Arabic",
    bg="Bulgarian",
    cs="Czech",
    nl="Dutch",
    en="English",
    et="Estonian",
    fi="Finnish",
    fr="French",
    it="Italian",
    lv="Latvian",
    lt="Lithuanian",
    pl="Polish",
    ru="Russian",
    sk="Slovak",
    sv="Swedish",
    ta="Tamil",
    uk="Ukrainian",
)
FLAGS = flags.FLAGS

# Command-line options. Each DEFINE_* call below is (name, default, help).
# Directory holding the per-language test files; predictions are written here too.
flags.DEFINE_string("data_dir", "", "Path to data directory.")
# Directory with one subdirectory per language code (see CODE2LANG).
flags.DEFINE_string("models_dir", "/tmp/", "Model serialization dir.")
# Passed to utils.quick_fix / utils.collapse_nodes.
flags.DEFINE_string("tools", "", "UD tools path.")
flags.DEFINE_integer("cuda_device", -1, "Cuda device id (-1 for cpu).")
# When set, each language directory must contain exactly one nested model dir.
flags.DEFINE_boolean("expect_prefix", True, "Whether to expect allennlp prefix.")
flags.DEFINE_integer("batch_size", 32, "Batch size.")
def _unwrap_prefix_dir(model_dir):
    """Return the single entry nested inside *model_dir*.

    Used when ``--expect_prefix`` is set: the language directory is expected
    to wrap exactly one serialization directory.

    Raises:
        ValueError: if *model_dir* does not contain exactly one entry.
    """
    entries = list(model_dir.iterdir())
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if len(entries) != 1:
        raise ValueError(f"There is incorrect count of models {entries}")
    return entries[0]


def _find_test_file(data_dir, lang):
    """Return the CoNLL-U input file for *lang* inside *data_dir*.

    ``{lang}.mwt.conllu`` is preferred when it exists; otherwise
    ``{lang}.conllu`` is used.

    Raises:
        FileNotFoundError: if neither file exists in *data_dir*.
    """
    # Try to use mwt file if it exists.
    for name in (f"{lang}.mwt.conllu", f"{lang}.conllu"):
        candidate = data_dir / name
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(f"Couldn't find test file for '{lang}' in {data_dir}.")


def run(_):
    """Predict, fix and collapse the test set of every known language.

    For each entry of ``--models_dir`` whose name is a key of ``CODE2LANG``:
    runs ``combo --mode predict`` on that language's test file from
    ``--data_dir``, writing ``{lang}_pred.conllu`` into ``--data_dir``, then
    runs ``utils.quick_fix`` (-> ``{lang}_pred.fixed.conllu``) and
    ``utils.collapse_nodes`` (-> ``{lang}_pred.collapsed.conllu``).

    Args:
        _: positional argv supplied by ``absl.app.run``; unused.

    Raises:
        ValueError: if ``--expect_prefix`` is set and a language directory
            does not contain exactly one model directory.
        FileNotFoundError: if a language has no test file in ``--data_dir``.
    """
    models_dir = pathlib.Path(FLAGS.models_dir)
    # Flag values do not change inside the loop; resolve them once.
    data_dir = pathlib.Path(FLAGS.data_dir)
    tools_dir = pathlib.Path(FLAGS.tools)

    # sorted() makes the processing order deterministic (iterdir() is not).
    for lang_dir in sorted(models_dir.iterdir()):
        lang = lang_dir.name
        if lang not in CODE2LANG:
            print("Skipping unknown directory: ", lang)
            continue

        model_dir = _unwrap_prefix_dir(lang_dir) if FLAGS.expect_prefix else lang_dir
        test_file = _find_test_file(data_dir, lang)

        output_pred = data_dir / f"{lang}_pred.conllu"
        # NOTE(review): the command spans several lines; this assumes
        # utils.execute_command tokenizes on whitespace — confirm.
        command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
            --input_file {test_file}
            --output_file {output_pred}
            --cuda_device {FLAGS.cuda_device}
            --batch_size {FLAGS.batch_size}
            --silent
            """
        utils.execute_command(command)

        # Build sibling names from ``lang`` instead of str.replace() on the full
        # path, which would also rewrite a '.conllu' inside a parent dir name.
        output_fixed = utils.path_to_str(data_dir / f"{lang}_pred.fixed.conllu")
        utils.quick_fix(tools_dir, output_pred, output_fixed)
        output_collapsed = utils.path_to_str(data_dir / f"{lang}_pred.collapsed.conllu")
        utils.collapse_nodes(tools_dir, pathlib.Path(output_fixed), output_collapsed)
def main():
    """Script entry point: absl parses the command-line flags, then calls ``run``."""
    app.run(main=run)


if __name__ == "__main__":
    main()