Commit 426d24f1 authored by Mateusz Klimaszewski, committed by Mateusz Klimaszewski

Add script for training enhanced dependency parsing models based on IWPT'20 Shared Task data.

parent 23e0c9ce
Merge requests: !9 Enhanced dependency parsing develop to master, !8 Enhanced dependency parsing
"""Script to train Enhanced Dependency Parsing models based on IWPT'20 Shared Task data.
Might require:
conda install -c bioconda perl-list-moreutils
conda install -c bioconda perl-namespace-autoclean
conda install -c bioconda perl-moose
conda install -c dan_blanchard perl-moosex-semiaffordanceaccessor
"""
import os
import pathlib
import subprocess
from typing import List
from absl import app
from absl import flags
FLAGS = flags.FLAGS
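# IWPT'20 Shared Task languages mapped to their UD treebank names
# (directory suffixes under the shared task data directory).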
LANG2TREEBANK = {
"ar": ["Arabic-PADT"],
"bg": ["Bulgarian-BTB"],
"cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
"nl": ["Dutch-Alpino", "Dutch-LassySmall"],
"en": ["English-EWT", "English-PUD"],
"et": ["Estonian-EDT", "Estonian-EWT"],
"fi": ["Finnish-TDT", "Finnish-PUD"],
"fr": ["French-Sequoia", "French-FQB"],
"it": ["Italian-ISDT"],
"lv": ["Latvian-LVTB"],
"lt": ["Lithuanian-ALKSNIS"],
"pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
"ru": ["Russian-SynTagRus"],
"sk": ["Slovak-SNK"],
"sv": ["Swedish-Talbanken", "Swedish-PUD"],
"ta": ["Tamil-TTB"],
"uk": ["Ukrainian-IU"],
}
LANG2TRANSFORMER = {
"en": "bert-base-cased",
"pl": "allegro/herbert-base-cased",
}
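# NOTE: only English and Polish have a pretrained transformer configured;
# training any other language from LANG2TREEBANK requires adding an entry
# to LANG2TRANSFORMER first.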
flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
help=f"Languages of models to train. Possible values: {list(LANG2TREEBANK.keys())}.")
flags.DEFINE_string(name="data_dir", default="",
help="Path to 'iwpt2020stdata' directory.")
flags.DEFINE_string(name="serialization_dir", default="/tmp/",
help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
def path_to_str(path: pathlib.Path) -> str:
return str(path.resolve())
def merge_files(files: List[str], output: pathlib.Path):
if not output.exists():
os.system(f"cat {' '.join(files)} > {output}")
def execute_command(command, output_file=None):
command = [c for c in command.split() if c.strip()]
if output_file:
with open(output_file, "w") as f:
subprocess.run(command, check=True, stdout=f)
else:
subprocess.run(command, check=True)
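# Collapse empty nodes in the enhanced dependency graph with the shared task's
# enhanced_collapse_empty_nodes.pl tool (presumably the step behind the Perl
# dependencies noted in the module docstring).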
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
output_path = pathlib.Path(output)
if not output_path.exists():
execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
f"{path_to_str(treebank_file)}", output)
def run(_):
languages = FLAGS.lang
for lang in languages:
assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
data_dir = pathlib.Path(FLAGS.data_dir)
assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
treebanks = LANG2TREEBANK[lang]
train_paths = []
dev_paths = []
test_paths = []
for treebank in treebanks:
treebank_dir = data_dir / f"UD_{treebank}"
assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exist."
for treebank_file in treebank_dir.iterdir():
name = treebank_file.name
if "conllu" in name and "fixed" not in name:
output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
if "train" in name:
collapse_nodes(data_dir, treebank_file, output)
train_paths.append(output)
elif "dev" in name:
collapse_nodes(data_dir, treebank_file, output)
dev_paths.append(output)
elif "test" in name:
collapse_nodes(data_dir, treebank_file, output)
test_paths.append(output)
lang_data_dir = pathlib.Path(data_dir / lang)
lang_data_dir.mkdir(exist_ok=True)
train_path = lang_data_dir / "train.conllu"
dev_path = lang_data_dir / "dev.conllu"
test_path = lang_data_dir / "test.conllu"
merge_files(train_paths, output=train_path)
merge_files(dev_paths, output=dev_path)
merge_files(test_paths, output=test_path)
serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
serialization_dir.mkdir(exist_ok=True)
execute_command("".join(f"""combo --mode train
--training_data {train_path}
--validation_data {dev_path}
--targets feats,upostag,xpostag,head,deprel,lemma,deps
--pretrained_transformer_name {LANG2TRANSFORMER[lang]}
--serialization_dir {serialization_dir}
--cuda_device {FLAGS.cuda_device}
--word_batch_size 2500
--config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
--tensorboard
""".splitlines()))
def main():
app.run(run)
if __name__ == "__main__":
main()