diff --git a/docs/training.md b/docs/training.md
index 7d7b0a8256e0c7896bf1372f817ab4e263b4f45e..2cdf75d0803e6ea12e5ce6343d8471579e90bf8c 100644
--- a/docs/training.md
+++ b/docs/training.md
@@ -51,6 +51,8 @@ Enhanced Dependencies are described [here](https://universaldependencies.org/u/o
 ### Data pre-processing
 
 The organisers of [IWPT20 shared task](https://universaldependencies.org/iwpt20/data.html) distributed the data sets and a data pre-processing script `enhanced_collapse_empty_nodes.pl`. If you wish to train a model on IWPT20 data, apply this script to the training and validation data sets, before training the COMBO EUD model.
+The script is part of the [UD tools repository](https://github.com/UniversalDependencies/tools/).
+
 ```bash
 perl enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu
 ```
diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py
new file mode 100644
index 0000000000000000000000000000000000000000..5082f9dc7709bb8a2b3565c077e0f24279d27f0e
--- /dev/null
+++ b/scripts/train_iwpt21.py
@@ -0,0 +1,134 @@
+"""Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data.
+
+For possible requirements, see train_eud.py comments.
+"""
+
+import os
+import pathlib
+from typing import List
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+LANG2TREEBANK = {
+    "ar": ["Arabic-PADT"],
+    "bg": ["Bulgarian-BTB"],
+    "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
+    "nl": ["Dutch-Alpino", "Dutch-LassySmall"],
+    "en": ["English-EWT", "English-PUD", "English-GUM"],
+    "et": ["Estonian-EDT", "Estonian-EWT"],
+    "fi": ["Finnish-TDT", "Finnish-PUD"],
+    "fr": ["French-Sequoia", "French-FQB"],
+    "it": ["Italian-ISDT"],
+    "lv": ["Latvian-LVTB"],
+    "lt": ["Lithuanian-ALKSNIS"],
+    "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
+    "ru": ["Russian-SynTagRus"],
+    "sk": ["Slovak-SNK"],
+    "sv": ["Swedish-Talbanken", "Swedish-PUD"],
+    "ta": ["Tamil-TTB"],
+    "uk": ["Ukrainian-IU"],
+}
+
+FLAGS = flags.FLAGS
+flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
+                  help=f"Languages of models to train. Possible values: {list(LANG2TREEBANK.keys())}.")
+flags.DEFINE_string(name="data_dir", default="",
+                    help="Path to the IWPT'21 shared task data directory.")
+flags.DEFINE_string(name="serialization_dir", default="/tmp/",
+                    help="Model serialization dir.")
+flags.DEFINE_integer(name="cuda_device", default=-1,
+                     help="Cuda device id (-1 for cpu).")
+
+
+def path_to_str(path: pathlib.Path) -> str:
+    return str(path.resolve())
+
+
+def merge_files(files: List[str], output: pathlib.Path):
+    if not output.exists():
+        os.system(f"cat {' '.join(files)} > {output}")
+
+
+def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
+    output_path = pathlib.Path(output)
+    if not output_path.exists():
+        utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+                              f"{path_to_str(treebank_file)}", output)
+
+
+def run(_):
+    languages = FLAGS.lang
+    for lang in languages:
+        assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
+        assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."
+        data_dir = pathlib.Path(FLAGS.data_dir)
+        assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
+
+        treebanks = LANG2TREEBANK[lang]
+        train_paths = []
+        dev_paths = []
+
+        # TODO Uncomment when the IWPT'21 Shared Task ends.
+        # Test data is not available while the shared task is running.
+        test_paths = []
+        for treebank in treebanks:
+            treebank_dir = data_dir / f"UD_{treebank}"
+            assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exist."
+            for treebank_file in treebank_dir.iterdir():
+                name = treebank_file.name
+                if "conllu" in name and "fixed" not in name:
+                    output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
+                    if "train" in name:
+                        collapse_nodes(data_dir, treebank_file, output)
+                        train_paths.append(output)
+                    elif "dev" in name:
+                        collapse_nodes(data_dir, treebank_file, output)
+                        dev_paths.append(output)
+                    # elif "test" in name:
+                    #     collapse_nodes(data_dir, treebank_file, output)
+                    #     test_paths.append(output)
+
+        lang_data_dir = data_dir / lang
+        lang_data_dir.mkdir(exist_ok=True)
+
+        train_path = lang_data_dir / "train.conllu"
+        dev_path = lang_data_dir / "dev.conllu"
+        # TODO Uncomment
+        # test_path = lang_data_dir / "test.conllu"
+
+        merge_files(train_paths, output=train_path)
+        merge_files(dev_paths, output=dev_path)
+        # TODO Uncomment
+        # merge_files(test_paths, output=test_path)
+
+        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
+        serialization_dir.mkdir(exist_ok=True, parents=True)
+
+        command = f"""combo --mode train
+        --training_data {train_path}
+        --validation_data {dev_path}
+        --targets feats,upostag,xpostag,head,deprel,lemma,deps
+        --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
+        --serialization_dir {serialization_dir}
+        --cuda_device {FLAGS.cuda_device}
+        --word_batch_size 2500
+        --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
+        --notensorboard
+        """
+
+        # Datasets without XPOS
+        if lang in {"fr"}:
+            command = command + " --targets deprel,head,upostag,lemma,feats"
+
+        utils.execute_command("".join(command.splitlines()))
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
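A note on the `docs/training.md` hunk: the collapsing script comes from the UD tools repository linked there, so fetching and applying it could look roughly like the sketch below (the clone location and file names are illustrative, not prescribed by the docs):

```bash
# Clone the UD tools repository, then collapse empty nodes in each split.
git clone https://github.com/UniversalDependencies/tools.git
perl tools/enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu
perl tools/enhanced_collapse_empty_nodes.pl validation.conllu > validation.fixed.conllu
```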
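For reviewers, a minimal usage sketch of the new script, assuming it is launched from the repository root so that `scripts.utils` (providing the `LANG2TRANSFORMER` dict and the `execute_command` helper, as in `train_eud.py`) can be imported; the data path and device id are illustrative:

```bash
# Train EUD models for Polish and Finnish on the IWPT'21 data
# (--cuda_device -1 would select the CPU instead of the first GPU).
python -m scripts.train_iwpt21 \
    --lang pl,fi \
    --data_dir /path/to/iwpt21-data \
    --serialization_dir /tmp/models \
    --cuda_device 0
```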