Skip to content
Snippets Groups Projects
Select Git revision
  • b6d108a4f7b23dbb33b51e3a853c8b9a8953640d
  • master default protected
  • vertical_relations
  • lu_without_semantic_frames
  • hierarchy
  • additional-unification-filters
  • v0.1.1
  • v0.1.0
  • v0.0.9
  • v0.0.8
  • v0.0.7
  • v0.0.6
  • v0.0.5
  • v0.0.4
  • v0.0.3
  • v0.0.2
  • v0.0.1
17 results

models.py

Blame
  • train_iwpt21.py 5.56 KiB
    """Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data.
    
    For possible requirements, see train_eud.py comments.
    """
    
    import os
    import pathlib
    import shutil
    from typing import List
    
    from absl import app
    from absl import flags
    
    from scripts import utils
    
    LANG2TREEBANK = {
        "ar": ["Arabic-PADT"],
        "bg": ["Bulgarian-BTB"],
        "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"],
        "nl": ["Dutch-Alpino", "Dutch-LassySmall"],
        "en": ["English-EWT", "English-PUD", "English-GUM"],
        "et": ["Estonian-EDT", "Estonian-EWT"],
        "fi": ["Finnish-TDT", "Finnish-PUD"],
        "fr": ["French-Sequoia", "French-FQB"],
        "it": ["Italian-ISDT"],
        "lv": ["Latvian-LVTB"],
        "lt": ["Lithuanian-ALKSNIS"],
        "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"],
        "ru": ["Russian-SynTagRus"],
        "sk": ["Slovak-SNK"],
        "sv": ["Swedish-Talbanken", "Swedish-PUD"],
        "ta": ["Tamil-TTB"],
        "uk": ["Ukrainian-IU"],
    }
    
    FLAGS = flags.FLAGS
    # Comma-separated list of language codes; defaults to every known language.
    # The help text joins the codes explicitly, because interpolating
    # ``LANG2TREEBANK.keys()`` would render as ``dict_keys([...])``.
    flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
                      help=f"Language of models to train. Possible values: {', '.join(LANG2TREEBANK)}.")
    # Directory containing the ``UD_<treebank>`` subdirectories of the IWPT'21 release.
    flags.DEFINE_string(name="data_dir", default="",
                        help="Path to IWPT'21 data directory.")
    # Each language's model is written to ``<serialization_dir>/<lang>``.
    flags.DEFINE_string(name="serialization_dir", default="/tmp/",
                        help="Model serialization dir.")
    flags.DEFINE_integer(name="cuda_device", default=-1,
                         help="Cuda device id (-1 for cpu).")
    
    
    def merge_files(files: List[str], output: pathlib.Path):
        """Concatenate ``files`` byte-for-byte into ``output`` unless it already exists.

        The existing-output check lets reruns skip merges that were already
        done. The copy is done in Python rather than through a shell. This
        handles paths containing spaces or shell metacharacters correctly and
        raises on unreadable inputs instead of silently ignoring them. An empty
        ``files`` list produces an empty file rather than running a bare
        ``cat`` that would block on stdin.

        Args:
            files: Paths of the input files, concatenated in the given order.
            output: Destination path (``str`` is also accepted); left untouched
                if it already exists.

        Raises:
            OSError: If an input file cannot be read or the output cannot be written.
        """
        output = pathlib.Path(output)
        if output.exists():
            return
        # Write to a temporary sibling and rename it at the end, so an
        # interrupted merge never leaves a partial file that later runs would
        # mistake for a finished one (see the exists() check above).
        tmp_output = output.with_name(output.name + ".tmp")
        with open(tmp_output, "wb") as out:
            for path in files:
                with open(path, "rb") as src:
                    # Streams in chunks; the merged files can be large.
                    shutil.copyfileobj(src, out)
        os.replace(tmp_output, output)
    
    
    def _collect_treebank_files(data_dir: pathlib.Path, treebanks: List[str]):
        """Preprocess and gather the train/dev files of the given treebanks.

        For every train/dev ``.conllu`` file in ``data_dir/UD_<treebank>`` (skipping
        previously written ``*.fixed.conllu`` outputs), ``utils.collapse_nodes``
        writes a ``*.fixed.conllu`` file next to it, and that output path is
        collected. Raw train/dev ``.txt`` files are collected as-is.

        Args:
            data_dir: IWPT'21 data directory containing ``UD_<treebank>`` dirs.
            treebanks: Treebank names such as ``"Czech-PDT"``.

        Returns:
            A ``(train_paths, dev_paths, train_raw_paths, dev_raw_paths)`` tuple
            of lists of path strings.

        Raises:
            FileNotFoundError: If a ``UD_<treebank>`` directory is missing.
        """
        train_paths, dev_paths = [], []
        train_raw_paths, dev_raw_paths = [], []
        # TODO Collect (collapsed) test files too when IWPT'21 Shared Task ends.
        # During shared task duration test data is not available.
        for treebank in treebanks:
            treebank_dir = data_dir / f"UD_{treebank}"
            if not treebank_dir.is_dir():
                raise FileNotFoundError(f"'{treebank_dir}' directory doesn't exist.")
            # Sorted so files are processed in a deterministic order across runs.
            for treebank_file in sorted(treebank_dir.iterdir()):
                name = treebank_file.name
                if "conllu" in name and "fixed" not in name:
                    output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu')
                    if "train" in name:
                        utils.collapse_nodes(data_dir, treebank_file, output)
                        train_paths.append(output)
                    elif "dev" in name:
                        utils.collapse_nodes(data_dir, treebank_file, output)
                        dev_paths.append(output)
                if ".txt" in name:
                    if "train" in name:
                        train_raw_paths.append(utils.path_to_str(treebank_file))
                    elif "dev" in name:
                        dev_raw_paths.append(utils.path_to_str(treebank_file))
        return train_paths, dev_paths, train_raw_paths, dev_raw_paths


    def _build_train_command(lang: str, train_path, dev_path, serialization_dir) -> str:
        """Return the ``combo --mode train`` command line for ``lang``.

        Arguments are joined with single spaces. ``--targets`` is passed
        exactly once, instead of being overridden by a second occurrence.
        """
        targets = ["feats", "upostag", "xpostag", "head", "deprel", "lemma", "deps"]
        # Datasets without XPOS: drop only the xpostag target. The "deps"
        # (enhanced dependencies) target must stay, since training enhanced
        # dependency parsers is the purpose of this script.
        if lang in {"fr", "ru"}:
            targets.remove("xpostag")
        # Tamil is trained with a smaller word batch than the other languages.
        word_batch_size = 500 if lang in {"ta"} else 2500
        args = [
            "combo", "--mode", "train",
            "--training_data", str(train_path),
            "--validation_data", str(dev_path),
            "--targets", ",".join(targets),
            "--pretrained_transformer_name", str(utils.LANG2TRANSFORMER[lang]),
            "--serialization_dir", str(serialization_dir),
            "--cuda_device", str(FLAGS.cuda_device),
            "--config_path", str(pathlib.Path.cwd() / "combo" / "config.graph.template.jsonnet"),
            "--notensorboard",
            "--word_batch_size", str(word_batch_size),
        ]
        return " ".join(args)


    def run(_):
        """Build merged IWPT'21 datasets and train one COMBO model per ``--lang`` entry.

        All requested languages and ``--data_dir`` are validated before any
        work starts, so a bad value cannot surface only after earlier models
        have finished training.

        Args:
            _: Remaining command-line arguments passed in by ``app.run`` (unused).

        Raises:
            ValueError: If a language is unknown or has no transformer configured.
            NotADirectoryError: If ``--data_dir`` is not a directory.
            FileNotFoundError: If a treebank directory is missing.
        """
        languages = FLAGS.lang
        for lang in languages:
            if lang not in LANG2TREEBANK:
                raise ValueError(f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}.")
            if lang not in utils.LANG2TRANSFORMER:
                raise ValueError(f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict.")
        data_dir = pathlib.Path(FLAGS.data_dir)
        if not data_dir.is_dir():
            raise NotADirectoryError(f"'{data_dir}' is not a directory!")

        for lang in languages:
            treebanks = LANG2TREEBANK[lang]
            # e.g. "Czech-FicTree" -> "Czech"
            full_language = treebanks[0].split("-")[0]
            train_paths, dev_paths, train_raw_paths, dev_raw_paths = _collect_treebank_files(
                data_dir, treebanks)

            merged_dataset_name = "IWPT"
            lang_data_dir = data_dir / f"UD_{full_language}-{merged_dataset_name}"
            lang_data_dir.mkdir(exist_ok=True)

            suffix = f"{lang}_{merged_dataset_name}-ud".lower()
            train_path = lang_data_dir / f"{suffix}-train.conllu"
            dev_path = lang_data_dir / f"{suffix}-dev.conllu"
            test_path = lang_data_dir / f"{suffix}-test.conllu"
            train_raw_path = lang_data_dir / f"{suffix}-train.txt"
            dev_raw_path = lang_data_dir / f"{suffix}-dev.txt"
            test_raw_path = lang_data_dir / f"{suffix}-test.txt"

            merge_files(train_paths, output=train_path)
            merge_files(dev_paths, output=dev_path)
            # TODO Change to test_paths instead of dev_paths after IWPT'21
            merge_files(dev_paths, output=test_path)

            merge_files(train_raw_paths, output=train_raw_path)
            merge_files(dev_raw_paths, output=dev_raw_path)
            # TODO Change to test_raw_paths instead of dev_paths after IWPT'21
            merge_files(dev_raw_paths, output=test_raw_path)

            serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
            serialization_dir.mkdir(exist_ok=True, parents=True)

            utils.execute_command(_build_train_command(lang, train_path, dev_path, serialization_dir))
    
    
    def main():
        """Console entry point: absl parses the command-line flags, then calls ``run``."""
        app.run(main=run)


    # Only start training when executed as a script, not when imported.
    if __name__ == "__main__":
        main()