"""Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data. For possible requirements, see train_eud.py comments. """ import os import pathlib from typing import List from absl import app from absl import flags from scripts import utils LANG2TREEBANK = { "ar": ["Arabic-PADT"], "bg": ["Bulgarian-BTB"], "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"], "nl": ["Dutch-Alpino", "Dutch-LassySmall"], "en": ["English-EWT", "English-PUD", "English-GUM"], "et": ["Estonian-EDT", "Estonian-EWT"], "fi": ["Finnish-TDT", "Finnish-PUD"], "fr": ["French-Sequoia", "French-FQB"], "it": ["Italian-ISDT"], "lv": ["Latvian-LVTB"], "lt": ["Lithuanian-ALKSNIS"], "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"], "ru": ["Russian-SynTagRus"], "sk": ["Slovak-SNK"], "sv": ["Swedish-Talbanken", "Swedish-PUD"], "ta": ["Tamil-TTB"], "uk": ["Ukrainian-IU"], } FLAGS = flags.FLAGS flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()), help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.") flags.DEFINE_string(name="data_dir", default="", help="Path to IWPT'21 data directory.") flags.DEFINE_string(name="serialization_dir", default="/tmp/", help="Model serialization dir.") flags.DEFINE_integer(name="cuda_device", default=-1, help="Cuda device id (-1 for cpu).") def path_to_str(path: pathlib.Path) -> str: return str(path.resolve()) def merge_files(files: List[str], output: pathlib.Path): if not output.exists(): os.system(f"cat {' '.join(files)} > {output}") def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): output_path = pathlib.Path(output) if not output_path.exists(): utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} " f"{path_to_str(treebank_file)}", output) def run(_): languages = FLAGS.lang for lang in languages: assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}." assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict." data_dir = pathlib.Path(FLAGS.data_dir) assert data_dir.is_dir(), f"'{data_dir}' is not a directory!" treebanks = LANG2TREEBANK[lang] full_language = treebanks[0].split("-")[0] train_paths = [] dev_paths = [] train_raw_paths = [] dev_raw_paths = [] # TODO Uncomment when IWPT'21 Shared Task ends. # During shared task duration test data is not available. test_paths = [] for treebank in treebanks: treebank_dir = data_dir / f"UD_{treebank}" assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exists." 
            for treebank_file in treebank_dir.iterdir():
                name = treebank_file.name
                if "conllu" in name and "fixed" not in name:
                    output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
                    if "train" in name:
                        collapse_nodes(data_dir, treebank_file, output)
                        train_paths.append(output)
                    elif "dev" in name:
                        collapse_nodes(data_dir, treebank_file, output)
                        dev_paths.append(output)
                    # elif "test" in name:
                    #     collapse_nodes(data_dir, treebank_file, output)
                    #     test_paths.append(output)
                if ".txt" in name:
                    if "train" in name:
                        train_raw_paths.append(path_to_str(treebank_file))
                    elif "dev" in name:
                        dev_raw_paths.append(path_to_str(treebank_file))

        # Merge all treebanks of a language into a single UD_<Language>-IWPT dataset.
        merged_dataset_name = "IWPT"
        lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}")
        lang_data_dir.mkdir(exist_ok=True)

        suffix = f"{lang}_{merged_dataset_name}-ud".lower()
        train_path = lang_data_dir / f"{suffix}-train.conllu"
        dev_path = lang_data_dir / f"{suffix}-dev.conllu"
        test_path = lang_data_dir / f"{suffix}-test.conllu"
        train_raw_path = lang_data_dir / f"{suffix}-train.txt"
        dev_raw_path = lang_data_dir / f"{suffix}-dev.txt"
        test_raw_path = lang_data_dir / f"{suffix}-test.txt"

        merge_files(train_paths, output=train_path)
        merge_files(dev_paths, output=dev_path)
        # TODO Change to test_paths instead of dev_paths after IWPT'21.
        merge_files(dev_paths, output=test_path)

        merge_files(train_raw_paths, output=train_raw_path)
        merge_files(dev_raw_paths, output=dev_raw_path)
        # TODO Change to test_raw_paths instead of dev_raw_paths after IWPT'21.
        merge_files(dev_raw_paths, output=test_raw_path)

        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
        serialization_dir.mkdir(exist_ok=True, parents=True)
        command = f"""combo --mode train
        --training_data {train_path}
        --validation_data {dev_path}
        --targets feats,upostag,xpostag,head,deprel,lemma,deps
        --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
        --serialization_dir {serialization_dir}
        --cuda_device {FLAGS.cuda_device}
        --word_batch_size 2500
        --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'}
        --notensorboard
        """

        # Datasets without XPOS: override the targets (the later flag takes precedence).
        if lang in {"fr"}:
            command = command + " --targets deprel,head,upostag,lemma,feats"

        # Collapse the multi-line command into a single shell invocation; the leading
        # indentation of the continuation lines keeps the arguments separated.
        utils.execute_command("".join(command.splitlines()))


def main():
    app.run(run)


if __name__ == "__main__":
    main()
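
# Example invocation (a sketch only; the script filename and the data and
# serialization paths are placeholders, not part of the repository):
#
#   python scripts/train_iwpt21.py \
#       --lang pl,en \
#       --data_dir /path/to/iwpt2021-data \
#       --serialization_dir /tmp/models \
#       --cuda_device 0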