# train.py (commit 46c68f26, 8.30 KiB)
# Change log: adds scripts to train new UD models with instructions; fixes
# numpy problem on Mac (M1); fixes link to download models.
"""Script to train Dependency Parsing models based on UD 2.x data."""
import pathlib
from absl import app
from absl import flags
from scripts import utils
# Generated from a UD 2.7 checkout with: ls -1 | xargs -i echo "\"{}\","
# UD 2.7
# Treebanks that cannot be trained (no training or validation data, or no
# public data) are kept in the list but commented out, with the reason.
TREEBANKS = [
"UD_Afrikaans-AfriBooms",
# "UD_Albanian-TSA", No training data
# "UD_Amharic-ATT", No training data
"UD_Arabic-NYUAD",
"UD_Arabic-PADT",
"UD_Arabic-PUD",
"UD_Armenian-ArmTDP",
# "UD_Assyrian-AS", No training data
# "UD_Bambara-CRB", No training data
"UD_Basque-BDT",
"UD_Belarusian-HSE",
# "UD_Breton-KEB", No training data
"UD_Bulgarian-BTB",
"UD_Buryat-BDT",
"UD_Cantonese-HK",
"UD_Catalan-AnCora",
"UD_Chinese-CFL",
"UD_Chinese-GSD",
"UD_Chinese-GSDSimp",
"UD_Chinese-HK",
"UD_Chinese-PUD",
"UD_Chukchi-HSE",
"UD_Classical_Chinese-Kyoto",
"UD_Coptic-Scriptorium",
"UD_Croatian-SET",
"UD_Czech-CAC",
"UD_Czech-CLTT",
"UD_Czech-FicTree",
"UD_Czech-PDT",
"UD_Czech-PUD",
"UD_Danish-DDT",
"UD_Dutch-Alpino",
"UD_English-EWT",
# "UD_Erzya-JR", No training data
"UD_Estonian-EWT",
"UD_Faroese-FarPaHC",
"UD_Faroese-OFT",
"UD_Finnish-FTB",
"UD_Finnish-OOD",
"UD_Finnish-PUD",
"UD_Finnish-TDT",
"UD_French-FQB",
"UD_French-FTB",
"UD_French-GSD",
"UD_French-ParTUT",
"UD_French-PUD",
"UD_French-Sequoia",
"UD_French-Spoken",
"UD_Galician-CTG",
"UD_Galician-TreeGal",
"UD_German-GSD",
"UD_German-HDT",
"UD_German-LIT",
"UD_German-PUD",
"UD_Gothic-PROIEL",
"UD_Greek-GDT",
"UD_Hebrew-HTB",
"UD_Hindi_English-HIENCS",
"UD_Hindi-HDTB",
"UD_Hindi-PUD",
"UD_Hungarian-Szeged",
"UD_Icelandic-IcePaHC",
"UD_Icelandic-PUD",
"UD_Indonesian-CSUI",
"UD_Indonesian-GSD",
"UD_Indonesian-PUD",
"UD_Irish-IDT",
"UD_Italian-ISDT",
"UD_Italian-ParTUT",
"UD_Italian-PoSTWITA",
"UD_Italian-PUD",
"UD_Italian-TWITTIRO",
"UD_Italian-VIT",
# "UD_Japanese-BCCWJ", No public data
"UD_Japanese-GSD",
"UD_Japanese-Modern",
"UD_Japanese-PUD",
"UD_Karelian-KKPP",
"UD_Kazakh-KTB",
"UD_Khunsari-AHA",
"UD_Komi_Permyak-UH",
"UD_Komi_Zyrian-IKDP",
"UD_Komi_Zyrian-Lattice",
"UD_Korean-GSD",
"UD_Korean-Kaist",
"UD_Korean-PUD",
"UD_Kurmanji-MG",
"UD_Latin-ITTB",
"UD_Latin-LLCT",
"UD_Latin-Perseus",
"UD_Latin-PROIEL",
"UD_Latvian-LVTB",
"UD_Lithuanian-ALKSNIS",
"UD_Lithuanian-HSE",
"UD_Maltese-MUDT",
# "UD_Manx-Cadhan", No training data
"UD_Marathi-UFAL",
"UD_Mbya_Guarani-Dooley",
"UD_Mbya_Guarani-Thomas",
"UD_Moksha-JR",
"UD_Munduruku-TuDeT",
"UD_Naija-NSC",
"UD_Nayini-AHA",
"UD_North_Sami-Giella",
"UD_Norwegian-Bokmaal",
"UD_Norwegian-Nynorsk",
"UD_Norwegian-NynorskLIA",
"UD_Old_Church_Slavonic-PROIEL",
"UD_Old_French-SRCMF",
"UD_Old_Russian-RNC",
"UD_Old_Russian-TOROT",
"UD_Old_Turkish-Tonqq",
"UD_Persian-PerDT",
"UD_Persian-Seraji",
"UD_Polish-LFG",
"UD_Polish-PDB",
"UD_Polish-PUD",
"UD_Portuguese-Bosque",
"UD_Portuguese-GSD",
"UD_Portuguese-PUD",
"UD_Romanian-Nonstandard",
"UD_Romanian-RRT",
"UD_Romanian-SiMoNERo",
"UD_Russian-GSD",
"UD_Russian-PUD",
"UD_Russian-SynTagRus",
"UD_Russian-Taiga",
# "UD_Sanskrit-UFAL", No training data
"UD_Scottish_Gaelic-ARCOSG",
"UD_Serbian-SET",
"UD_Skolt_Sami-Giellagas",
"UD_Slovak-SNK",
"UD_Slovenian-SSJ",
"UD_Slovenian-SST",
"UD_Soi-AHA",
"UD_South_Levantine_Arabic-MADAR",
"UD_Spanish-AnCora",
"UD_Spanish-GSD",
"UD_Spanish-PUD",
"UD_Swedish-LinES",
# "UD_Tagalog-TRG", No training data
# "UD_Tamil-MWTT", No training data
"UD_Telugu-MTG",
# "UD_Thai-PUD", No training data
"UD_Turkish-BOUN",
"UD_Turkish-GB",
"UD_Turkish_German-SAGT",
"UD_Turkish-IMST",
"UD_Turkish-PUD",
"UD_Ukrainian-IU",
# "UD_Upper_Sorbian-UFAL", No validation data
"UD_Urdu-UDTB",
"UD_Uyghur-UDT",
"UD_Vietnamese-VTB",
# "UD_Welsh-CCG", No validation data
# "UD_Yoruba-YTB", No training data
]
FLAGS = flags.FLAGS
# Which treebanks to train; defaults to every trainable UD 2.7 treebank.
flags.DEFINE_list(name="treebanks", default=TREEBANKS,
help=f"Treebanks to train. Possible values: {TREEBANKS}.")
# Directory containing one sub-directory per treebank with .conllu files.
flags.DEFINE_string(name="data_dir", default="",
help="Path to UD data directory.")
# Trained models are written to <serialization_dir>/<treebank>/.
flags.DEFINE_string(name="serialization_dir", default="/tmp/",
help="Model serialization directory.")
# Optional; when empty, a pretrained transformer is used instead of
# static embeddings (see run()).
flags.DEFINE_string(name="embeddings_dir", default="",
help="Path to embeddings directory (with languages as subdirectories).")
flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).")
def run(_):
treebanks_dir = pathlib.Path(FLAGS.data_dir)
for treebank in FLAGS.treebanks:
assert treebank in TREEBANKS, f"Unknown treebank {treebank}."
treebank_dir = treebanks_dir / treebank
treebank_parts = treebank[3:].split("-")
language = treebank_parts[0]
files = list(treebank_dir.iterdir())
training_file = [f for f in files if "train" in f.name and ".conllu" in f.name]
assert len(training_file) == 1, f"Couldn't find training file."
training_file_path = training_file[0]
valid_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
assert len(valid_file) == 1, f"Couldn't find validation file."
valid_file_path = valid_file[0]
embeddings_dir = FLAGS.embeddings_dir
embeddings_file = None
if embeddings_dir:
embeddings_dir = pathlib.Path(embeddings_dir) / language
embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec" in f.name]
assert len(embeddings_file) == 1, f"Couldn't find embeddings file."
embeddings_file = embeddings_file[0]
language = training_file_path.name.split("_")[0]
serialization_dir = pathlib.Path(FLAGS.serialization_dir) / treebank
serialization_dir.mkdir(exist_ok=True, parents=True)
command = f"""time combo --mode train
--cuda_device {FLAGS.cuda_device}
--training_data_path {training_file_path}
--validation_data_path {valid_file_path}
{f"--pretrained_tokens {embeddings_file}" if embeddings_dir
else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"}
--serialization_dir {serialization_dir}
--config_path {pathlib.Path.cwd() / 'config.template.jsonnet'}
--notensorboard
"""
# Datasets without XPOS
if treebank in {'UD_Danish-DDT', 'UD_Western_Armenian-ArmTDP', 'UD_Basque-BDT', 'UD_Hungarian-Szeged', 'UD_Russian-Taiga', 'UD_Portuguese-Bosque', 'UD_Norwegian-NynorskLIA', 'UD_Turkish-Penn', 'UD_French-GSD', 'UD_Armenian-ArmTDP'}:
command = command + " --targets deprel,head,upostag,lemma,feats"
# Datasets without FEATS
if treebank in {'UD_Galician-CTG', 'UD_Italian-ISDT', 'UD_Korean-Kaist', 'UD_Korean-GSD'}:
command = command + " --targets deprel,head,upostag,xpostag,lemma"
# Datasets without LEMMA
if treebank in {'UD_Old_French-SRCMF'}:
command = command + " --targets deprel,head,upostag,xpostag,feats"
# Datasets without XPOS and LEMMA
if treebank in {}:
command = command + " --targets deprel,head,upostag,feats"
# Datasets without LEMMA and FEATS
if treebank in {'UD_English-ESL', 'UD_Maltese-MUDT', 'UD_Swedish_Sign_Language-SSLC'}:
command = command + " --targets deprel,head,upostag,xpostag"
# Datasets without XPOS and FEATS
if treebank in {}:
command = command + " --targets deprel,head,upostag,lemma"
# Datasets without XPOS, FEATS and LEMMA
if treebank in {}:
command = command + " --targets deprel,head,upostag"
# Reduce word_batch_size
word_batch_size = 2500
if treebank in {"UD_German-HDT", "UD_Marathi-UFAL"}:
word_batch_size = 1000
elif treebank in {"UD_Telugu-MTG"}:
word_batch_size = 500
command = command + f" --word_batch_size {word_batch_size}"
utils.execute_command(command)
def main():
    """Command-line entry point; absl parses flags and dispatches to run()."""
    app.run(run)


if __name__ == "__main__":
    main()