NOTE(review): the line breaks, blank lines and indentation of this patch were
reconstructed from the hunk line counts and from Python syntax. The indentation
depth of the context lines is an assumption -- confirm it against the target
files, or apply with `git apply --ignore-whitespace`.
NOTE(review): in scripts/train_eud.py the "fr" branch appends a second
`--targets` list that omits `deps` as well as `xpostag`. Presumably the later
flag value wins; confirm that dropping `deps` is intended for this script.

diff --git a/scripts/train.py b/scripts/train.py
index accca4a7f2ee8e2d969c21b7966e18e8628da48a..950ee82184c4dbb4b3a4d9ce31b551ab81439375 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -229,6 +229,10 @@ def run(_):
                         "UD_Marathi-UFAL", "UD_Norwegian-Bokmaal"}:
             command = command + " --targets deprel,head,upostag,lemma,feats"
 
+        # Datasets without FEATS
+        if treebank in {"UD_Japanese-GSD", "UD_Korean-Kaist"}:
+            command = command + " --targets deprel,head,upostag,xpostag,lemma"
+
         # Datasets without LEMMA and FEATS
         if treebank in {"UD_Maltese-MUDT"}:
             command = command + " --targets deprel,head,upostag,xpostag"
diff --git a/scripts/train_eud.py b/scripts/train_eud.py
index 4904e0bff6a9d78a7d0a56bff2ed0357992b615b..ba13a272a78c127b88ee524a733138d2adc81459 100644
--- a/scripts/train_eud.py
+++ b/scripts/train_eud.py
@@ -105,7 +105,8 @@ def run(_):
 
         serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
         serialization_dir.mkdir(exist_ok=True, parents=True)
-        utils.execute_command("".join(f"""combo --mode train
+
+        command = f"""combo --mode train
         --training_data {train_path}
         --validation_data {dev_path}
         --targets feats,upostag,xpostag,head,deprel,lemma,deps
@@ -115,7 +116,13 @@ def run(_):
         --word_batch_size 2500
         --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
         --notensorboard
-        """.splitlines()))
+        """
+
+        # Datasets without XPOS
+        if lang in {"fr"}:
+            command = command + " --targets deprel,head,upostag,lemma,feats"
+
+        utils.execute_command("".join(command.splitlines()))
 
 
 def main():
diff --git a/scripts/utils.py b/scripts/utils.py
index 5dda2b89693fc1431b810709d9e7002d9f5f8071..ebfec3e22111996ed4ea2151e24ae03f77bd2891 100644
--- a/scripts/utils.py
+++ b/scripts/utils.py
@@ -4,6 +4,13 @@ import subprocess
 LANG2TRANSFORMER = {
     "en": "bert-base-cased",
     "pl": "allegro/herbert-base-cased",
+    "zh": "bert-base-chinese",
+    "fi": "TurkuNLP/bert-base-finnish-cased-v1",
+    "ja": "cl-tohoku/bert-base-japanese",
+    "ko": "kykim/bert-kor-base",
+    "de": "dbmdz/bert-base-german-cased",
+    "ar": "aubmindlab/bert-base-arabertv2",
+    "eu": "ixa-ehu/berteus-base-cased"
 }
 
 