Skip to content
Snippets Groups Projects
Commit 1a9d7362 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add IWPT'21 shared task training script. Add information about UD tools repo.

parent 6fde9305
Branches
2 merge requests!37Release 1.0.4.,!36Release 1.0.4
......@@ -51,6 +51,8 @@ Enhanced Dependencies are described [here](https://universaldependencies.org/u/o
### Data pre-processing
The organisers of [IWPT20 shared task](https://universaldependencies.org/iwpt20/data.html) distributed the data sets and a data pre-processing script `enhanced_collapse_empty_nodes.pl`. If you wish to train a model on IWPT20 data, apply this script to the training and validation data sets, before training the COMBO EUD model.
The script is part of the [UD tools repository](https://github.com/UniversalDependencies/tools/).
```bash
perl enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu
```
......
"""Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data.
For possible requirements, see train_eud.py comments.
"""
import os
import pathlib
from typing import List
from absl import app
from absl import flags
from scripts import utils
LANG2TREEBANK = {
"ar": ["Arabic-PADT"],
"bg": ["Bulgarian-BTB"],
"cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
"nl": ["Dutch-Alpino", "Dutch-LassySmall"],
"en": ["English-EWT", "English-PUD", "UD_English-GUM"],
"et": ["Estonian-EDT", "Estonian-EWT"],
"fi": ["Finnish-TDT", "Finnish-PUD"],
"fr": ["French-Sequoia", "French-FQB"],
"it": ["Italian-ISDT"],
"lv": ["Latvian-LVTB"],
"lt": ["Lithuanian-ALKSNIS"],
"pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
"ru": ["Russian-SynTagRus"],
"sk": ["Slovak-SNK"],
"sv": ["Swedish-Talbanken", "Swedish-PUD"],
"ta": ["Tamil-TTB"],
"uk": ["Ukrainian-IU"],
}
FLAGS = flags.FLAGS
flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
flags.DEFINE_string(name="data_dir", default="",
help="Path to 'iwpt2020stdata' directory.")
flags.DEFINE_string(name="serialization_dir", default="/tmp/",
help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
def path_to_str(path: pathlib.Path) -> str:
return str(path.resolve())
def merge_files(files: List[str], output: pathlib.Path):
if not output.exists():
os.system(f"cat {' '.join(files)} > {output}")
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
output_path = pathlib.Path(output)
if not output_path.exists():
utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
f"{path_to_str(treebank_file)}", output)
def run(_):
languages = FLAGS.lang
for lang in languages:
assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."
data_dir = pathlib.Path(FLAGS.data_dir)
assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
treebanks = LANG2TREEBANK[lang]
train_paths = []
dev_paths = []
# TODO Uncomment when IWPT'21 Shared Task ends.
# During shared task duration test data is not available.
test_paths = []
for treebank in treebanks:
treebank_dir = data_dir / f"UD_{treebank}"
assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exists."
for treebank_file in treebank_dir.iterdir():
name = treebank_file.name
if "conllu" in name and "fixed" not in name:
output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
if "train" in name:
collapse_nodes(data_dir, treebank_file, output)
train_paths.append(output)
elif "dev" in name:
collapse_nodes(data_dir, treebank_file, output)
dev_paths.append(output)
# elif "test" in name:
# collapse_nodes(data_dir, treebank_file, output)
# test_paths.append(output)
lang_data_dir = pathlib.Path(data_dir / lang)
lang_data_dir.mkdir(exist_ok=True)
train_path = lang_data_dir / "train.conllu"
dev_path = lang_data_dir / "dev.conllu"
# TODO Uncomment
# test_path = lang_data_dir / "test.conllu"
merge_files(train_paths, output=train_path)
merge_files(dev_paths, output=dev_path)
# TODO Uncomment
# merge_files(test_paths, output=test_path)
serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
serialization_dir.mkdir(exist_ok=True, parents=True)
command = f"""combo --mode train
--training_data {train_path}
--validation_data {dev_path}
--targets feats,upostag,xpostag,head,deprel,lemma,deps
--pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
--serialization_dir {serialization_dir}
--cuda_device {FLAGS.cuda_device}
--word_batch_size 2500
--config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
--notensorboard
"""
# Datasets without XPOS
if lang in {"fr"}:
command = command + " --targets deprel,head,upostag,lemma,feats"
utils.execute_command("".join(command.splitlines()))
def main():
app.run(run)
if __name__ == "__main__":
main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment