Skip to content
Snippets Groups Projects
Commit 1f854b40 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Add IWPT'21 shared task training script. Add information about UD tools repo.

parent 8c81af24
No related merge requests found
Pipeline #2877 passed with stage
in 4 minutes and 32 seconds
...@@ -51,6 +51,8 @@ Enhanced Dependencies are described [here](https://universaldependencies.org/u/o ...@@ -51,6 +51,8 @@ Enhanced Dependencies are described [here](https://universaldependencies.org/u/o
### Data pre-processing
The organisers of [IWPT20 shared task](https://universaldependencies.org/iwpt20/data.html) distributed the data sets and a data pre-processing script `enhanced_collapse_empty_nodes.pl`. If you wish to train a model on IWPT20 data, apply this script to the training and validation data sets before training the COMBO EUD model.
The script is part of the [UD tools repository](https://github.com/UniversalDependencies/tools/).
```bash
perl enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu
```
......
"""Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data.
For possible requirements, see train_eud.py comments.
"""
import os
import pathlib
from typing import List
from absl import app
from absl import flags
from scripts import utils
# Mapping from ISO 639-1 language code to the IWPT'21 treebanks used for that
# language. Names are listed WITHOUT the "UD_" directory prefix: the training
# loop prepends "UD_" when resolving each treebank directory, so a prefixed
# entry here would resolve to a nonexistent "UD_UD_..." path.
LANG2TREEBANK = {
    "ar": ["Arabic-PADT"],
    "bg": ["Bulgarian-BTB"],
    "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
    "nl": ["Dutch-Alpino", "Dutch-LassySmall"],
    # Fixed: was "UD_English-GUM", which double-prefixed to "UD_UD_English-GUM".
    "en": ["English-EWT", "English-PUD", "English-GUM"],
    "et": ["Estonian-EDT", "Estonian-EWT"],
    "fi": ["Finnish-TDT", "Finnish-PUD"],
    "fr": ["French-Sequoia", "French-FQB"],
    "it": ["Italian-ISDT"],
    "lv": ["Latvian-LVTB"],
    "lt": ["Lithuanian-ALKSNIS"],
    "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
    "ru": ["Russian-SynTagRus"],
    "sk": ["Slovak-SNK"],
    "sv": ["Swedish-Talbanken", "Swedish-PUD"],
    "ta": ["Tamil-TTB"],
    "uk": ["Ukrainian-IU"],
}
FLAGS = flags.FLAGS
# Languages to train; defaults to every language configured in LANG2TREEBANK.
flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
                  help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
# Root directory holding the "UD_*" treebank directories and the UD "tools" checkout.
flags.DEFINE_string(name="data_dir", default="",
                    help="Path to 'iwpt2020stdata' directory.")
# One subdirectory per language is created under this path for model output.
flags.DEFINE_string(name="serialization_dir", default="/tmp/",
                    help="Model serialization dir.")
# Forwarded verbatim to the `combo` CLI; -1 selects CPU.
flags.DEFINE_integer(name="cuda_device", default=-1,
                     help="Cuda device id (-1 for cpu).")
def path_to_str(path: pathlib.Path) -> str:
    """Return *path* as an absolute, resolved string."""
    resolved = path.resolve()
    return str(resolved)
def merge_files(files: List[str], output: pathlib.Path):
    """Concatenate *files* (in order) into *output*.

    A no-op when *output* already exists, matching the original skip-if-present
    behavior. The previous implementation shelled out to ``cat`` via
    ``os.system``; copying in Python is portable and immune to quoting /
    injection problems with paths containing spaces or shell metacharacters.

    Args:
        files: Paths of the input files, concatenated in list order.
        output: Destination path; created only if it does not exist yet.
    """
    if not output.exists():
        # Binary mode: byte-for-byte concatenation, like `cat`.
        with output.open("wb") as sink:
            for name in files:
                with open(name, "rb") as part:
                    sink.write(part.read())
def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
    """Run UD's ``enhanced_collapse_empty_nodes.pl`` over *treebank_file*.

    The collapsed treebank is written to *output*; skipped when the output
    file already exists so re-runs are cheap.
    """
    if pathlib.Path(output).exists():
        return
    script = data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl'
    utils.execute_command(f"perl {path_to_str(script)} {path_to_str(treebank_file)}",
                          output)
def run(_):
    """Train one COMBO enhanced-dependency model per requested language.

    For each language in --lang: validate the configuration, collapse empty
    nodes in every treebank's train/dev CoNLL-U files, merge the per-treebank
    files into a single train/dev set under <data_dir>/<lang>/, then launch
    `combo --mode train` through the shell.
    """
    languages = FLAGS.lang
    for lang in languages:
        assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
        assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."
        data_dir = pathlib.Path(FLAGS.data_dir)
        assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
        treebanks = LANG2TREEBANK[lang]
        train_paths = []
        dev_paths = []
        # TODO Uncomment when IWPT'21 Shared Task ends.
        # During shared task duration test data is not available.
        test_paths = []
        for treebank in treebanks:
            treebank_dir = data_dir / f"UD_{treebank}"
            assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exists."
            for treebank_file in treebank_dir.iterdir():
                name = treebank_file.name
                # Only raw .conllu files; '.fixed.conllu' outputs of a previous
                # run are excluded by the "fixed" check.
                if "conllu" in name and "fixed" not in name:
                    output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
                    if "train" in name:
                        collapse_nodes(data_dir, treebank_file, output)
                        train_paths.append(output)
                    elif "dev" in name:
                        collapse_nodes(data_dir, treebank_file, output)
                        dev_paths.append(output)
                    # elif "test" in name:
                    #     collapse_nodes(data_dir, treebank_file, output)
                    #     test_paths.append(output)
        # Merge all collapsed treebanks of this language into one train/dev pair.
        lang_data_dir = pathlib.Path(data_dir / lang)
        lang_data_dir.mkdir(exist_ok=True)
        train_path = lang_data_dir / "train.conllu"
        dev_path = lang_data_dir / "dev.conllu"
        # TODO Uncomment
        # test_path = lang_data_dir / "test.conllu"
        merge_files(train_paths, output=train_path)
        merge_files(dev_paths, output=dev_path)
        # TODO Uncomment
        # merge_files(test_paths, output=test_path)
        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
        serialization_dir.mkdir(exist_ok=True, parents=True)
        # The indentation inside this f-string is load-bearing: splitlines() +
        # "".join below collapses it to one line, and the leading whitespace of
        # each continuation line becomes the separator between CLI arguments.
        command = f"""combo --mode train
            --training_data {train_path}
            --validation_data {dev_path}
            --targets feats,upostag,xpostag,head,deprel,lemma,deps
            --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
            --serialization_dir {serialization_dir}
            --cuda_device {FLAGS.cuda_device}
            --word_batch_size 2500
            --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
            --notensorboard
            """
        # Datasets without XPOS
        # NOTE(review): this appends a SECOND --targets flag; it presumably
        # relies on the combo CLI letting the later flag override the earlier
        # one — confirm against combo's argument parsing.
        if lang in {"fr"}:
            command = command + " --targets deprel,head,upostag,lemma,feats"
        utils.execute_command("".join(command.splitlines()))
def main():
    """Console entry point: hand flag parsing and dispatch to absl."""
    app.run(run)


if __name__ == "__main__":
    main()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment