Skip to content
Snippets Groups Projects
train.py 8.30 KiB
"""Script to train Dependency Parsing models based on UD 2.x data."""
import pathlib

from absl import app
from absl import flags

from scripts import utils

# Treebanks this script can train, given as directory names of a UD 2.7
# release. Each name is joined onto --data_dir to locate the treebank's files;
# --treebanks may only contain names from this list (and defaults to all of it).
# The list was generated inside the release directory with:
#   ls -1 | xargs -i echo "\"{}\","
# and then edited by hand: treebanks without training/validation data (or
# without public data) are commented out together with the reason.
# NOTE(review): some remaining entries (e.g. the *-PUD treebanks) look like
# test-only treebanks in UD releases; training them would fail to find a train
# file — confirm against the actual data before relying on the default list.
TREEBANKS = [
    "UD_Afrikaans-AfriBooms",
    # "UD_Albanian-TSA", No training data
    # "UD_Amharic-ATT", No training data
    "UD_Arabic-NYUAD",
    "UD_Arabic-PADT",
    "UD_Arabic-PUD",
    "UD_Armenian-ArmTDP",
    # "UD_Assyrian-AS", No training data
    # "UD_Bambara-CRB", No training data
    "UD_Basque-BDT",
    "UD_Belarusian-HSE",
    # "UD_Breton-KEB", No training data
    "UD_Bulgarian-BTB",
    "UD_Buryat-BDT",
    "UD_Cantonese-HK",
    "UD_Catalan-AnCora",
    "UD_Chinese-CFL",
    "UD_Chinese-GSD",
    "UD_Chinese-GSDSimp",
    "UD_Chinese-HK",
    "UD_Chinese-PUD",
    "UD_Chukchi-HSE",
    "UD_Classical_Chinese-Kyoto",
    "UD_Coptic-Scriptorium",
    "UD_Croatian-SET",
    "UD_Czech-CAC",
    "UD_Czech-CLTT",
    "UD_Czech-FicTree",
    "UD_Czech-PDT",
    "UD_Czech-PUD",
    "UD_Danish-DDT",
    "UD_Dutch-Alpino",
    "UD_English-EWT",
    # "UD_Erzya-JR", No training data
    "UD_Estonian-EWT",
    "UD_Faroese-FarPaHC",
    "UD_Faroese-OFT",
    "UD_Finnish-FTB",
    "UD_Finnish-OOD",
    "UD_Finnish-PUD",
    "UD_Finnish-TDT",
    "UD_French-FQB",
    "UD_French-FTB",
    "UD_French-GSD",
    "UD_French-ParTUT",
    "UD_French-PUD",
    "UD_French-Sequoia",
    "UD_French-Spoken",
    "UD_Galician-CTG",
    "UD_Galician-TreeGal",
    "UD_German-GSD",
    "UD_German-HDT",
    "UD_German-LIT",
    "UD_German-PUD",
    "UD_Gothic-PROIEL",
    "UD_Greek-GDT",
    "UD_Hebrew-HTB",
    "UD_Hindi_English-HIENCS",
    "UD_Hindi-HDTB",
    "UD_Hindi-PUD",
    "UD_Hungarian-Szeged",
    "UD_Icelandic-IcePaHC",
    "UD_Icelandic-PUD",
    "UD_Indonesian-CSUI",
    "UD_Indonesian-GSD",
    "UD_Indonesian-PUD",
    "UD_Irish-IDT",
    "UD_Italian-ISDT",
    "UD_Italian-ParTUT",
    "UD_Italian-PoSTWITA",
    "UD_Italian-PUD",
    "UD_Italian-TWITTIRO",
    "UD_Italian-VIT",
    # "UD_Japanese-BCCWJ", No public data
    "UD_Japanese-GSD",
    "UD_Japanese-Modern",
    "UD_Japanese-PUD",
    "UD_Karelian-KKPP",
    "UD_Kazakh-KTB",
    "UD_Khunsari-AHA",
    "UD_Komi_Permyak-UH",
    "UD_Komi_Zyrian-IKDP",
    "UD_Komi_Zyrian-Lattice",
    "UD_Korean-GSD",
    "UD_Korean-Kaist",
    "UD_Korean-PUD",
    "UD_Kurmanji-MG",
    "UD_Latin-ITTB",
    "UD_Latin-LLCT",
    "UD_Latin-Perseus",
    "UD_Latin-PROIEL",
    "UD_Latvian-LVTB",
    "UD_Lithuanian-ALKSNIS",
    "UD_Lithuanian-HSE",
    "UD_Maltese-MUDT",
    # "UD_Manx-Cadhan", No training data
    "UD_Marathi-UFAL",
    "UD_Mbya_Guarani-Dooley",
    "UD_Mbya_Guarani-Thomas",
    "UD_Moksha-JR",
    "UD_Munduruku-TuDeT",
    "UD_Naija-NSC",
    "UD_Nayini-AHA",
    "UD_North_Sami-Giella",
    "UD_Norwegian-Bokmaal",
    "UD_Norwegian-Nynorsk",
    "UD_Norwegian-NynorskLIA",
    "UD_Old_Church_Slavonic-PROIEL",
    "UD_Old_French-SRCMF",
    "UD_Old_Russian-RNC",
    "UD_Old_Russian-TOROT",
    "UD_Old_Turkish-Tonqq",
    "UD_Persian-PerDT",
    "UD_Persian-Seraji",
    "UD_Polish-LFG",
    "UD_Polish-PDB",
    "UD_Polish-PUD",
    "UD_Portuguese-Bosque",
    "UD_Portuguese-GSD",
    "UD_Portuguese-PUD",
    "UD_Romanian-Nonstandard",
    "UD_Romanian-RRT",
    "UD_Romanian-SiMoNERo",
    "UD_Russian-GSD",
    "UD_Russian-PUD",
    "UD_Russian-SynTagRus",
    "UD_Russian-Taiga",
    # "UD_Sanskrit-UFAL", No training data
    "UD_Scottish_Gaelic-ARCOSG",
    "UD_Serbian-SET",
    "UD_Skolt_Sami-Giellagas",
    "UD_Slovak-SNK",
    "UD_Slovenian-SSJ",
    "UD_Slovenian-SST",
    "UD_Soi-AHA",
    "UD_South_Levantine_Arabic-MADAR",
    "UD_Spanish-AnCora",
    "UD_Spanish-GSD",
    "UD_Spanish-PUD",
    "UD_Swedish-LinES",
    # "UD_Tagalog-TRG", No training data
    # "UD_Tamil-MWTT", No training data
    "UD_Telugu-MTG",
    # "UD_Thai-PUD", No training data
    "UD_Turkish-BOUN",
    "UD_Turkish-GB",
    "UD_Turkish_German-SAGT",
    "UD_Turkish-IMST",
    "UD_Turkish-PUD",
    "UD_Ukrainian-IU",
    # "UD_Upper_Sorbian-UFAL", No validation data
    "UD_Urdu-UDTB",
    "UD_Uyghur-UDT",
    "UD_Vietnamese-VTB",
    # "UD_Welsh-CCG", No validation data
    # "UD_Yoruba-YTB", No training data
]

FLAGS = flags.FLAGS

# Treebank directory names to train (comma-separated on the command line);
# every value must appear in TREEBANKS.
flags.DEFINE_list(
    "treebanks",
    TREEBANKS,
    f"Treebanks to train. Possible values: {TREEBANKS}.",
)
# Root of the UD release; each treebank lives in a subdirectory named after it.
flags.DEFINE_string(
    "data_dir",
    "",
    "Path to UD data directory.",
)
# Each trained model is written to <serialization_dir>/<treebank>.
flags.DEFINE_string(
    "serialization_dir",
    "/tmp/",
    "Model serialization directory.",
)
# Optional; when empty, a pretrained transformer is used instead of word vectors.
flags.DEFINE_string(
    "embeddings_dir",
    "",
    "Path to embeddings directory (with languages as subdirectories).",
)
# Forwarded verbatim to combo's --cuda_device option.
flags.DEFINE_integer(
    "cuda_device",
    -1,
    "Cuda device id (-1 for cpu).",
)


# Output targets for treebanks whose data lacks some annotation layers. Each
# entry pairs a set of treebank names with the ``--targets`` value handed to
# combo; the sets are disjoint, so at most one entry applies to a treebank.
# Treebanks not listed here get no ``--targets`` flag at all.
_REDUCED_TARGETS = (
    # Datasets without XPOS.
    (frozenset({
        "UD_Danish-DDT", "UD_Western_Armenian-ArmTDP", "UD_Basque-BDT",
        "UD_Hungarian-Szeged", "UD_Russian-Taiga", "UD_Portuguese-Bosque",
        "UD_Norwegian-NynorskLIA", "UD_Turkish-Penn", "UD_French-GSD",
        "UD_Armenian-ArmTDP",
    }), "deprel,head,upostag,lemma,feats"),
    # Datasets without FEATS.
    (frozenset({
        "UD_Galician-CTG", "UD_Italian-ISDT", "UD_Korean-Kaist", "UD_Korean-GSD",
    }), "deprel,head,upostag,xpostag,lemma"),
    # Datasets without LEMMA.
    (frozenset({"UD_Old_French-SRCMF"}), "deprel,head,upostag,xpostag,feats"),
    # Datasets without LEMMA and FEATS.
    (frozenset({
        "UD_English-ESL", "UD_Maltese-MUDT", "UD_Swedish_Sign_Language-SSLC",
    }), "deprel,head,upostag,xpostag"),
    # Other combinations (e.g. without XPOS and LEMMA ->
    # "deprel,head,upostag,feats") can be added here when a treebank needs them.
)

# --word_batch_size passed to combo, reduced for treebanks listed below.
_DEFAULT_WORD_BATCH_SIZE = 2500
_WORD_BATCH_SIZE_OVERRIDES = {
    "UD_German-HDT": 1000,
    "UD_Marathi-UFAL": 1000,
    "UD_Telugu-MTG": 500,
}


def _find_single_file(files, substrings, description, directory):
    """Return the only path in ``files`` whose name contains all ``substrings``.

    Args:
        files: candidate ``pathlib.Path`` objects.
        substrings: substrings that must all occur in the file name.
        description: human-readable name of the file, used in error messages.
        directory: directory the candidates came from, used in error messages.

    Raises:
        FileNotFoundError: if no candidate matches.
        ValueError: if more than one candidate matches (ambiguous choice).
    """
    matches = [f for f in files if all(s in f.name for s in substrings)]
    if not matches:
        raise FileNotFoundError(f"Couldn't find {description} in {directory}.")
    if len(matches) > 1:
        names = sorted(f.name for f in matches)
        raise ValueError(
            f"Found several candidates for {description} in {directory}: {names}.")
    return matches[0]


def _reduced_targets(treebank):
    """Return the ``--targets`` value for ``treebank``, or None if none applies."""
    for treebank_names, targets in _REDUCED_TARGETS:
        if treebank in treebank_names:
            return targets
    return None


def run(_):
    """Train one combo model for every treebank given in --treebanks.

    For each treebank this locates its train/dev ``.conllu`` files under
    --data_dir, picks either a word-vector file from --embeddings_dir or (when
    that flag is empty) a transformer looked up in ``utils.LANG2TRANSFORMER``,
    builds the ``combo --mode train`` command line and runs it through
    ``utils.execute_command``.

    Args:
        _: positional arguments left over after absl flag parsing (unused).

    Raises:
        ValueError: if --treebanks names a treebank not in TREEBANKS, or a
            required file cannot be chosen unambiguously.
        FileNotFoundError: if a treebank/embeddings directory or a required
            file is missing.
    """
    # Validate all names up front: a typo should fail before any (potentially
    # hours-long) training of the earlier treebanks has started.
    unknown = [t for t in FLAGS.treebanks if t not in TREEBANKS]
    if unknown:
        raise ValueError(f"Unknown treebank(s): {unknown}.")

    treebanks_dir = pathlib.Path(FLAGS.data_dir)
    for treebank in FLAGS.treebanks:
        treebank_dir = treebanks_dir / treebank
        # "UD_<Language>-<Name>" -> "<Language>"; names the embeddings subdirectory.
        language_name = treebank[3:].split("-")[0]

        files = list(treebank_dir.iterdir())
        training_file_path = _find_single_file(
            files, ("train", ".conllu"), "training file", treebank_dir)
        valid_file_path = _find_single_file(
            files, ("dev", ".conllu"), "validation file", treebank_dir)

        embeddings_file = None
        if FLAGS.embeddings_dir:
            language_embeddings_dir = pathlib.Path(FLAGS.embeddings_dir) / language_name
            embeddings_file = _find_single_file(
                list(language_embeddings_dir.iterdir()), ("vectors", ".vec"),
                "embeddings file", language_embeddings_dir)

        # Prefix of the training file name before the first "_" (presumably the
        # UD language code); this is the key into utils.LANG2TRANSFORMER.
        language_code = training_file_path.name.split("_")[0]

        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / treebank
        serialization_dir.mkdir(exist_ok=True, parents=True)

        if embeddings_file is not None:
            pretrained_arg = f"--pretrained_tokens {embeddings_file}"
        else:
            pretrained_arg = (
                f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language_code]}")

        command = f"""time combo --mode train
        --cuda_device {FLAGS.cuda_device}
        --training_data_path {training_file_path}
        --validation_data_path {valid_file_path}
        {pretrained_arg}
        --serialization_dir {serialization_dir}
        --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'}
        --notensorboard
        """

        targets = _reduced_targets(treebank)
        if targets is not None:
            command += f" --targets {targets}"

        word_batch_size = _WORD_BATCH_SIZE_OVERRIDES.get(treebank, _DEFAULT_WORD_BATCH_SIZE)
        command += f" --word_batch_size {word_batch_size}"

        utils.execute_command(command)


def main():
    """Console entry point: let absl parse the flags, then hand control to ``run``."""
    app.run(main=run)


if __name__ == "__main__":
    main()