diff --git a/ud_script/README.md b/ud_script/README.md new file mode 100644 index 0000000000000000000000000000000000000000..08bb35812f739228382f075b9d5cc22bc85154d1 --- /dev/null +++ b/ud_script/README.md @@ -0,0 +1,30 @@ +## Description + +The script [train_ud_version.py](/combo/ud_script/train_ud_version.py) allows for training multiple combo models on specific UD treebank version. To run the script, three parameters are required: +- `output_directory` - path to the location where training results will be saved +- `treebank_id` +- `treebank_version` + +To find the `treebank_id` and `treebank_version`, visit https://universaldependencies.org/#download. The `treebank_version` is indicated at the beginning of the UD version, while the `treebank_id` is the value at the end of the link used to download it. See the attached image where both values for UD2.11 are highlighted in yellow. + + + +The script will automatically download and extract the UD data into the folder `output_directory`/ud_treebanks-`treebank_version`. Then, it creates a subfolder `output_directory/results` containing: +- `serialization_directories` - folder with training results +- `completed_training.txt` - a text file with the names of UD treebanks on which training was successfully completed +- `skipped_training.csv` - a csv file with two columns, the first containing names of UD treebanks, the second listing reasons why training failed. Possible reasons include: + - Dev or test or train file missing - it is expected that there is a .conllu file in the UD directory that contains train, dev, and test in its name. Otherwise, this error is thrown. + - Training file less than 1000 bytes - if the training file has less than 1000 bytes, training is skipped. + - Training file corrupted - number of columns is less than 10. + - Specify transformer model for language code: <lang_code> - No BERT model was assigned to the specified language. 
If the script was interrupted at some point, you can rerun it with the same command. Based on the values in `completed_training.txt`, the rerun script will omit training on UD treebanks that already have a model.

Some models need an adjusted value of `word_batch_size`; the default value will be used unless you specify a `<UD treebank>: <word_batch_size>` pair in the `UD_2_BATCH_SIZE` constant in [constants](/combo/ud_script/constants.py).

## Example usage
Terminal command:
```
python train_ud_version.py --treebank_id 1-5287 --treebank_version 2.13 --output_dir C:\Users\abc\Desktop
```
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bar.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.be.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bh.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bpy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bs.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.br.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bg.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.my.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ca.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ceb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bcl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ce.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zh.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.co.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cs.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.da.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.dv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.arz.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eml.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.myv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.eo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.et.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hif.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ka.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.de.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gom.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.el.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ht.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.he.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mrj.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.is.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.io.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ilo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.id.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ia.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ga.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.it.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ja.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.jv.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.kn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pam.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.kk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.km.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ky.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ko.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ku.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ckb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.la.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.li.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lmo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nds.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.lb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mai.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mg.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ms.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ml.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mzn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mhr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.min.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.xmf.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mwl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.mn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nah.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nap.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ne.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.new.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.frr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nso.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.no.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.nn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.oc.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.or.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.os.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pfl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ps.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pms.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.qu.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ro.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.rm.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ru.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sah.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sa.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sc.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sco.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.gd.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sh.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.scn.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sd.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.si.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.so.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.azb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.es.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.su.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sw.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.sv.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tl.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tg.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ta.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tt.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.te.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.th.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.bo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tr.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.tk.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.uk.300.vec.gz", + 
"https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.hsb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ur.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ug.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.uz.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vec.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.wa.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.war.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.cy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.vls.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fy.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.pnb.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.yi.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.yo.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.diq.300.vec.gz", + "https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.zea.300.vec.gz", +] + +CODE_2_LANG = { + "af": "Afrikaans", + "aii": "Assyrian", + "ajp": "South_Levantine_Arabic", + "akk": "Akkadian", + "am": "Amharic", + "apu": "Apurina", + "aqz": "Akuntsu", + "ar": "Arabic", + "be": "Belarusian", + "bg": "Bulgarian", + "bho": "Bhojpuri", + "bm": "Bambara", + "br": "Breton", + "bxr": "Buryat", + "ca": "Catalan", + "ckt": "Chukchi", + "cop": "Coptic", + "cs": "Czech", + "cu": "Old_Church_Slavonic", + "cy": "Welsh", + "da": "Danish", + "de": "German", + "el": "Greek", + "en": "English", + "es": "Spanish", + "et": "Estonian", + "eu": "Basque", + "fa": "Persian", + "fi": "Finnish", + "fo": "Faroese", + "fr": "French", + "fro": "Old_French", + "ga": "Irish", + "gd": 
"Scottish_Gaelic", + "gl": "Galician", + "got": "Gothic", + "grc": "Ancient_Greek", + "gsw": "Swiss_German", + "gun": "Mbya_Guarani", + "gv": "Manx", + "he": "Hebrew", + "hi": "Hindi", + "hr": "Croatian", + "hsb": "Upper_Sorbian", + "hu": "Hungarian", + "hy": "Armenian", + "id": "Indonesian", + "is": "Icelandic", + "it": "Italian", + "ja": "Japanese", + "kfm": "Khunsari", + "kk": "Kazakh", + "kmr": "Kurmanji", + "ko": "Korean", + "koi": "Komi_Permyak", + "kpv": "Komi_Zyrian", + "krl": "Karelian", + "la": "Latin", + "lt": "Lithuanian", + "lv": "Latvian", + "lzh": "Classical_Chinese", + "mdf": "Moksha", + "mr": "Marathi", + "mt": "Maltese", + "myu": "Munduruku", + "myv": "Erzya", + "nl": "Dutch", + "no": "Norwegian", + "nyq": "Nayini", + "olo": "Livvi", + "orv": "Old_Russian", + "otk": "Old_Turkish", + "pcm": "Naija", + "pl": "Polish", + "pt": "Portuguese", + "qhe": "Hindi_English", + "qtd": "Turkish_German", + "ro": "Romanian", + "ru": "Russian", + "sa": "Sanskrit", + "sk": "Slovak", + "sl": "Slovenian", + "sme": "North_Sami", + "sms": "Skolt_Sami", + "soj": "Soi", + "sq": "Albanian", + "sr": "Serbian", + "sv": "Swedish", + "swl": "Swedish_Sign_Language", + "ta": "Tamil", + "te": "Telugu", + "th": "Thai", + "tl": "Tagalog", + "tpn": "Tupinamba", + "tr": "Turkish", + "ug": "Uyghur", + "uk": "Ukrainian", + "ur": "Urdu", + "vi": "Vietnamese", + "wbp": "Warlpiri", + "wo": "Wolof", + "yo": "Yoruba", + "yue": "Cantonese", + "zh": "Chinese", +} + +UD_2_BATCH_SIZE = { + "UD_German-HDT" : 1000, + "UD_Marathi-UFAL" : 1000, + "UD_Telugu-MTG" : 500 +} + +LANG2TRANSFORMER = { + "en": "bert-base-cased", + "pl": "allegro/herbert-large-cased", + "zh": "bert-base-chinese", + "fi": "TurkuNLP/bert-base-finnish-cased-v1", + "ko": "kykim/bert-kor-base", + "de": "dbmdz/bert-base-german-cased", + "ar": "aubmindlab/bert-base-arabertv2", + "eu": "ixa-ehu/berteus-base-cased", + "tr": "dbmdz/bert-base-turkish-cased", + "bg": "xlm-roberta-large", + "nl": "xlm-roberta-large", + "fr": 
"""Train COMBO models on every treebank of a given UD release.

Downloads and extracts the requested UD release, builds the list of
treebanks that can be trained, then trains one model per treebank,
recording successes in ``results/completed_training.txt`` and failures
(with their reason) in ``results/skipped_training.csv``.
"""
from string import Template
from pathlib import Path
from utils import (
    download_file,
    extract_tgz,
    build_list_of_uds_to_train,
    download_fasttext,
    write_to_csv,
    train_one_model)
import pandas as pd
import argparse

# Command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument("--treebank_id", type=str, default="",
                    help="ID of the treebank (e.g. 1-5287)")
parser.add_argument("--treebank_version", type=str, default="",
                    help="Version of the treebank (e.g. 2.13)")
# "--output_directory" is accepted as an alias because the README documents
# that spelling; both store into args.output_dir.
parser.add_argument("--output_dir", "--output_directory", dest="output_dir",
                    type=str, default="", help="Base directory for output")
args = parser.parse_args()

url = Template("https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/$treebank_id/ud-treebanks-v$version.tgz")


def main():
    """Run the whole download -> extract -> train pipeline."""
    treebank_id = args.treebank_id
    treebank_version = args.treebank_version
    output_dir = Path(args.output_dir)
    results_dir = output_dir / "results"
    # Path of the downloaded archive (a .tgz file, not a directory).
    tgz_path = output_dir / f"ud-treebanks-v{treebank_version}.tgz"
    # Directory the archive extracts into.
    ud_dir = output_dir / f"ud-treebanks-v{treebank_version}"

    # Make the output directories if they don't exist yet.
    output_dir.mkdir(parents=True, exist_ok=True)
    results_dir.mkdir(parents=True, exist_ok=True)

    ud_url = url.substitute(treebank_id=treebank_id,
                            version=treebank_version)

    # Download the UD release unless it is already on disk.
    if tgz_path.exists():
        print("UD treebank already downloaded")
    else:
        print("Downloading UD treebank")
        download_file(url=ud_url, local_filename=str(tgz_path))

    # Extract it unless it was already extracted by a previous run.
    if ud_dir.exists():
        print("UD treebank already extracted")
    else:
        print("Extracting UD treebank")
        extract_tgz(tgz_path=str(tgz_path), extract_path=str(output_dir))

    print("Checking UD data")
    # Maps treebank name -> "--targets ..." command fragment; also (re)creates
    # results/skipped_training.csv with its header row.
    training_queue = build_list_of_uds_to_train(UD_dir=str(ud_dir))

    # Treebanks whose training already finished in a previous run.
    completed_training_path = results_dir / "completed_training.txt"
    completed_training_path.touch(exist_ok=True)
    with open(completed_training_path, "r", encoding="utf-8") as file:
        completed_training = {line.strip() for line in file}

    skipped_csv_path = results_dir / "skipped_training.csv"

    # Train one model per remaining treebank; a failure is logged and the
    # loop keeps going so one bad treebank does not abort the whole batch.
    for ud_treebank, targets in training_queue.items():
        if ud_treebank in completed_training:
            continue
        print(f"Training {ud_treebank}")
        try:
            train_one_model(treebank=ud_treebank,
                            treebanks_dir=str(ud_dir),
                            output_dir=str(output_dir),
                            targets=targets,
                            use_fasttext=False)

            # Record the success so a rerun can skip this treebank.
            with open(completed_training_path, "a", encoding="utf-8") as file:
                file.write(ud_treebank + "\n")

            # Drop any stale failure entry for this treebank. Exact match
            # instead of the original regex-based `str.contains`, which could
            # also remove rows for other treebanks sharing a substring.
            df = pd.read_csv(skipped_csv_path)
            df[df.iloc[:, 0] != ud_treebank].to_csv(skipped_csv_path,
                                                    index=False)
        except Exception as e:
            # Keep going; persist the reason for the post-mortem report.
            write_to_csv(treebank=ud_treebank,
                         reason=str(e),
                         csv_file_path=skipped_csv_path)


if __name__ == "__main__":
    main()
"""Helpers for downloading UD data / fastText vectors and training COMBO models."""
import requests
from tqdm import tqdm
import tarfile
from pathlib import Path
import csv
import shutil
from constants import (
    LINKS,
    CODE_2_LANG,
    UD_2_BATCH_SIZE,
    LANG2TRANSFORMER,
    LANG_UD2FASTTEXT)
import subprocess
import os


def download_file(url: str,
                  local_filename: str,
                  show_progress: bool = True):
    """
    Downloads a file from the given URL and saves it locally.

    Args:
        url (str): The URL of the file to download.
        local_filename (str): The path and filename to save the downloaded file locally.
        show_progress (bool, optional): Whether to show a progress bar during the download.
                                        Defaults to True.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    # Stream the response so large archives are never held fully in memory.
    with requests.get(url, stream=True) as response:
        response.raise_for_status()

        # Content-Length may be absent; tqdm copes with total=0.
        total_size_in_bytes = int(response.headers.get('content-length', 0))

        with open(local_filename, 'wb') as file:
            if show_progress:
                with tqdm(
                    desc=local_filename,
                    total=total_size_in_bytes,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as progress_bar:
                    for chunk in response.iter_content(chunk_size=1024):
                        progress_bar.update(file.write(chunk))
            else:
                for chunk in response.iter_content(chunk_size=1024):
                    file.write(chunk)

    print(f"File downloaded and saved as {local_filename}")


def download_fasttext(output_dir: str):
    """
    Downloads FastText vectors for different languages.

    Args:
        output_dir (str): The directory where the FastText vectors will be downloaded.
    """
    output_dir = Path(output_dir)
    for link in LINKS:
        # URL shape: .../cc.<lang_code>.300.vec.gz -> code is 4th from the end.
        lang_code = link.split(".")[-4]

        if lang_code not in CODE_2_LANG:
            print(f"Unknown code {lang_code}.")
            continue

        language_dir = output_dir / CODE_2_LANG[lang_code]
        language_dir.mkdir(exist_ok=True, parents=True)
        vectors_path = language_dir / 'vectors.vec.gz'
        if vectors_path.exists():
            continue  # already downloaded on a previous run

        download_file(url=link,
                      local_filename=str(vectors_path),
                      show_progress=False)


def extract_tgz(tgz_path: str,
                extract_path: str = "."):
    """
    Extracts the contents of a .tgz file to the specified extract path.

    Args:
        tgz_path (str): The path to the .tgz file.
        extract_path (str, optional): The path where the contents will be extracted.
                                      Defaults to the current directory.
    """
    with tarfile.open(tgz_path, "r:gz") as tar:
        members = tar.getmembers()
        with tqdm(total=len(members), unit='files', desc="Extracting") as progress_bar:
            # Extract member by member so the progress bar advances per file.
            for member in members:
                tar.extract(member, path=extract_path)
                progress_bar.update(1)
    print(f"Files have been extracted to: {extract_path}")


def write_to_csv(treebank: str,
                 reason: str,
                 csv_file_path: Path):
    """Append a (treebank, reason) row to the skipped-training CSV."""
    # utf-8 so non-ASCII treebank names round-trip regardless of platform.
    with open(csv_file_path, mode='a', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow([treebank, reason])


def build_list_of_uds_to_train(UD_dir: str):
    """
    Builds a dictionary of UD treebanks to train with their respective targets
    training commands. It omits UDs that have a training file smaller than
    1000 bytes, corrupted files (fewer than 10 columns) or missing files.
    All skip reasons are saved in skipped_training.csv.

    Args:
        UD_dir (str): The directory path containing the UD treebanks.

    Returns:
        dict: A dictionary mapping UD treebank names to their training commands.
    """
    UD_dir = Path(UD_dir)
    csv_file_path = UD_dir.parent / "results" / "skipped_training.csv"
    UD_to_targets = {}

    # (Re)create the skipped-treebanks report with its header row.
    with open(csv_file_path, mode='w', newline='', encoding='utf-8') as file:
        csv.writer(file).writerow(["Directory Name", "Reason"])

    for item in tqdm(UD_dir.iterdir()):
        if not item.is_dir():
            continue

        # .conllu files only (ignoring sub-directories).
        conllu_files = [f for f in item.iterdir()
                        if f.is_file() and f.suffix == ".conllu"]

        has_dev = any("dev" in f.name for f in conllu_files)
        has_test = any("test" in f.name for f in conllu_files)
        train_files = [f for f in conllu_files if "train" in f.name]

        if not (train_files and has_dev and has_test):
            write_to_csv(treebank=item.name,
                         reason="Dev or test or train file missing",
                         csv_file_path=csv_file_path)
            continue

        train_file = train_files[0]
        # BUGFIX: the size check now targets the training file itself; the
        # original inspected whichever file iterdir() happened to yield last.
        if train_file.stat().st_size < 1000:
            write_to_csv(treebank=item.name,
                         reason="Training file less than 1000 bytes",
                         csv_file_path=csv_file_path)
            continue

        lemmas, upos, xpos, feats = [], [], [], []
        skipped = False

        # Sample the first 10 sentences to see which columns carry data.
        with train_file.open('r', encoding='utf-8') as rf:
            sentences_processed = 0
            for line in rf:
                # Omit comment lines.
                if line.startswith("#"):
                    continue

                # Blank line = sentence boundary; stop after 10 sentences.
                if line == "\n":
                    if sentences_processed == 10:
                        break
                    sentences_processed += 1
                    continue

                words = line.split("\t")
                if len(words) != 10:
                    write_to_csv(treebank=item.name,
                                 reason="Training file corrupted - number of columns is less than 10.",
                                 csv_file_path=csv_file_path)
                    skipped = True
                    break
                lemmas.append(words[2])
                upos.append(words[3])
                xpos.append(words[4])
                feats.append(words[5])

        if skipped:
            continue

        # head/deprel are always predicted; add the other targets only when
        # the sampled column holds real values.
        # BUGFIX: the original compared a set to the string "-" (always
        # unequal), so every target was unconditionally enabled.
        # NOTE(review): CoNLL-U uses "_" for empty columns; confirm whether
        # "-" is really the placeholder intended here.
        command = " --targets head,deprel"
        if set(lemmas) != {"-"}:
            command += ",lemma"
        if set(upos) != {"-"}:
            command += ",upostag"
        if set(xpos) != {"-"}:
            command += ",xpostag"
        if set(feats) != {"-"}:
            command += ",feats"

        UD_to_targets[item.name] = command

    return UD_to_targets


def execute_command(command: str, output_file=None):
    """
    Executes a command in the system shell.

    Args:
        command (str): The command to be executed.
        output_file (str, optional): The file to redirect the command output to.
                                     Defaults to None.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.

    NOTE(review): naive whitespace split — an argument (e.g. a path)
    containing spaces will be broken into pieces.
    """
    env = os.environ.copy()
    argv = [c for c in command.split() if c.strip()]
    if output_file:
        with open(output_file, "w") as f:
            subprocess.run(argv, check=True, stdout=f, env=env)
    else:
        subprocess.run(argv, check=True, env=env)


def train_one_model(treebank: str,
                    treebanks_dir: str,
                    output_dir: str,
                    targets: str,
                    use_fasttext: bool = True):
    """
    Train a single COMBO model for one UD treebank.

    Args:
        treebank (str): UD treebank directory name, e.g. "UD_Polish-PDB".
        treebanks_dir (str): Directory containing all UD treebank folders.
        output_dir (str): Base output directory (results / fasttext live here).
        targets (str): "--targets ..." command fragment for this treebank.
        use_fasttext (bool, optional): Use fastText vectors instead of a
                                       transformer. Defaults to True.

    Raises:
        FileNotFoundError: If fastText vectors for the language are missing.
        ValueError: If no transformer is configured for the language code.
    """
    # Invert CODE_2_LANG to look up the language code from the language name.
    LANG_2_CODE = {value: key for key, value in CODE_2_LANG.items()}

    treebank_dir = Path(treebanks_dir) / treebank
    # "UD_Polish-PDB" -> language name "Polish".
    language = treebank[3:].split("-")[0]
    if language in LANG_UD2FASTTEXT:
        language = LANG_UD2FASTTEXT[language]

    files = list(treebank_dir.iterdir())

    training_file = [f for f in files if "train" in f.name and ".conllu" in f.name]
    assert len(training_file) == 1, \
        f"Expected exactly one training file, found {len(training_file)}."
    training_file_path = training_file[0]

    valid_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
    assert len(valid_file) == 1, \
        f"Expected exactly one validation file, found {len(valid_file)}."
    valid_file_path = valid_file[0]

    # Resolve embeddings: fastText vectors file, or a transformer name.
    if use_fasttext:
        embeddings_dir = Path(output_dir) / "fasttext" / language
        if not embeddings_dir.exists():
            raise FileNotFoundError(f"Fasttext for {language} does not exist. If you want to use fasttext for different language, specify mapping in LANG_UD2FASTTEXT.")
        embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec" in f.name]
        assert len(embeddings_file) == 1, f"Couldn't find embeddings file."
        embeddings_file = embeddings_file[0]
    else:
        if LANG_2_CODE[language] not in LANG2TRANSFORMER:
            raise ValueError(f"Specify transformer model for language code: {LANG_2_CODE[language]}.")

    # Fresh serialization directory: remove leftovers from an aborted run.
    serialization_dir = Path(output_dir) / "results" / "serialization_directories" / treebank
    if serialization_dir.exists() and serialization_dir.is_dir():
        shutil.rmtree(serialization_dir)
    serialization_dir.mkdir(exist_ok=True, parents=True)

    # Assemble the training command.
    training_data_path = f"--training_data_path {training_file_path} "
    validation_data_path = f"--validation_data_path {valid_file_path} "
    if use_fasttext:
        # BUGFIX: pass the vectors *file*; the original passed the directory
        # (embeddings_file was computed but never used).
        embeddings = f"--pretrained_tokens {embeddings_file} "
    else:
        transformer_name = LANG2TRANSFORMER[LANG_2_CODE[language]]
        embeddings = f"--pretrained_transformer_name {transformer_name} "
    serialization_dir = f"--serialization_dir {serialization_dir} "
    tokenizer_language = f"--tokenizer_language={language} "

    command = "combo --mode train " + training_data_path + validation_data_path + embeddings + serialization_dir + tokenizer_language
    command += targets

    if treebank in UD_2_BATCH_SIZE:
        command += f" --word_batch_size {UD_2_BATCH_SIZE[treebank]}"

    # Start training.
    execute_command(command)