Commit 1a9d7362 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski

Add IWPT'21 shared task training script. Add information about UD tools repo.

parent 6fde9305
......@@ -51,6 +51,8 @@ Enhanced Dependencies are described [here](https://universaldependencies.org/u/o
### Data pre-processing
The organisers of [IWPT20 shared task](https://universaldependencies.org/iwpt20/data.html) distributed the data sets and a data pre-processing script `enhanced_collapse_empty_nodes.pl`. If you wish to train a model on IWPT20 data, apply this script to the training and validation data sets, before training the COMBO EUD model.
The script is part of the [UD tools repository](https://github.com/UniversalDependencies/tools/).
```bash
perl enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu
```
......
"""Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data.
For possible requirements, see train_eud.py comments.
"""
import os
import pathlib
from typing import List
from absl import app
from absl import flags
from scripts import utils
LANG2TREEBANK = {
"ar": ["Arabic-PADT"],
"bg": ["Bulgarian-BTB"],
"cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
"nl": ["Dutch-Alpino", "Dutch-LassySmall"],
"en": ["English-EWT", "English-PUD", "UD_English-GUM"],
"et": ["Estonian-EDT", "Estonian-EWT"],
"fi": ["Finnish-TDT", "Finnish-PUD"],
"fr": ["French-Sequoia", "French-FQB"],
"it": ["Italian-ISDT"],
"lv": ["Latvian-LVTB"],
"lt": ["Lithuanian-ALKSNIS"],
"pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
"ru": ["Russian-SynTagRus"],
"sk": ["Slovak-SNK"],
"sv": ["Swedish-Talbanken", "Swedish-PUD"],
"ta": ["Tamil-TTB"],
"uk": ["Ukrainian-IU"],
}
FLAGS = flags.FLAGS
flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
flags.DEFINE_string(name="data_dir", default="",
help="Path to 'iwpt2020stdata' directory.")
flags.DEFINE_string(name="serialization_dir", default="/tmp/",
help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
def path_to_str(path: pathlib.Path) -> str:
return str(path.resolve())
def merge_files(files: List[str], output: pathlib.Path):
if not output.exists():
os.system(f"cat {' '.join(files)} > {output}")
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
output_path = pathlib.Path(output)
if not output_path.exists():
utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
f"{path_to_str(treebank_file)}", output)
def run(_):
languages = FLAGS.lang
for lang in languages:
assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."
data_dir = pathlib.Path(FLAGS.data_dir)
assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
treebanks = LANG2TREEBANK[lang]
train_paths = []
dev_paths = []
# TODO Uncomment when IWPT'21 Shared Task ends.
# During shared task duration test data is not available.
test_paths = []
for treebank in treebanks:
treebank_dir = data_dir / f"UD_{treebank}"
assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exists."
for treebank_file in treebank_dir.iterdir():
name = treebank_file.name
if "conllu" in name and "fixed" not in name:
output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
if "train" in name:
collapse_nodes(data_dir, treebank_file, output)
train_paths.append(output)
elif "dev" in name:
collapse_nodes(data_dir, treebank_file, output)
dev_paths.append(output)
# elif "test" in name:
# collapse_nodes(data_dir, treebank_file, output)
# test_paths.append(output)
lang_data_dir = pathlib.Path(data_dir / lang)
lang_data_dir.mkdir(exist_ok=True)
train_path = lang_data_dir / "train.conllu"
dev_path = lang_data_dir / "dev.conllu"
# TODO Uncomment
# test_path = lang_data_dir / "test.conllu"
merge_files(train_paths, output=train_path)
merge_files(dev_paths, output=dev_path)
# TODO Uncomment
# merge_files(test_paths, output=test_path)
serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
serialization_dir.mkdir(exist_ok=True, parents=True)
command = f"""combo --mode train
--training_data {train_path}
--validation_data {dev_path}
--targets feats,upostag,xpostag,head,deprel,lemma,deps
--pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
--serialization_dir {serialization_dir}
--cuda_device {FLAGS.cuda_device}
--word_batch_size 2500
--config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
--notensorboard
"""
# Datasets without XPOS
if lang in {"fr"}:
command = command + " --targets deprel,head,upostag,lemma,feats"
utils.execute_command("".join(command.splitlines()))
def main():
app.run(run)
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment