diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index 8b077c3cdded0115505ad95c703357ef6f8057a1..737a7b816a070c00ded66168e2705e56156ac0da 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -36,7 +36,7 @@ FLAGS = flags.FLAGS flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()), help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.") flags.DEFINE_string(name="data_dir", default="", - help="Path to 'iwpt2020stdata' directory.") + help="Path to IWPT'21 data directory.") flags.DEFINE_string(name="serialization_dir", default="/tmp/", help="Model serialization dir.") flags.DEFINE_integer(name="cuda_device", default=-1, @@ -68,9 +68,11 @@ def run(_): assert data_dir.is_dir(), f"'{data_dir}' is not a directory!" treebanks = LANG2TREEBANK[lang] + full_language = treebanks[0].split("-")[0] train_paths = [] dev_paths = [] - + train_raw_paths = [] + dev_raw_paths = [] # TODO Uncomment when IWPT'21 Shared Task ends. # During shared task duration test data is not available. test_paths = [] @@ -90,19 +92,33 @@ def run(_): # elif "test" in name: # collapse_nodes(data_dir, treebank_file, output) # test_paths.append(output) + if ".txt" in name: + if "train" in name: + train_raw_paths.append(path_to_str(treebank_file)) + elif "dev" in name: + dev_raw_paths.append(path_to_str(treebank_file)) - lang_data_dir = pathlib.Path(data_dir / lang) + merged_dataset_name = "IWPT" + lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}") lang_data_dir.mkdir(exist_ok=True) - train_path = lang_data_dir / "train.conllu" - dev_path = lang_data_dir / "dev.conllu" - # TODO Uncomment - # test_path = lang_data_dir / "test.conllu" + suffix = f"{lang}_{merged_dataset_name}-ud".lower() + train_path = lang_data_dir / f"{suffix}-train.conllu" + dev_path = lang_data_dir / f"{suffix}-dev.conllu" + test_path = lang_data_dir / f"{suffix}-test.conllu" + train_raw_path = lang_data_dir / f"{suffix}-train.txt" + dev_raw_path = lang_data_dir / f"{suffix}-dev.txt" + test_raw_path = lang_data_dir / f"{suffix}-test.txt" merge_files(train_paths, output=train_path) merge_files(dev_paths, output=dev_path) - # TODO Uncomment - # merge_files(test_paths, output=test_path) + # TODO Change to test_paths instead of dev_paths after IWPT'21 + merge_files(dev_paths, output=test_path) + + merge_files(train_raw_paths, output=train_raw_path) + merge_files(dev_raw_paths, output=dev_raw_path) + # TODO Change to test_raw_paths instead of dev_paths after IWPT'21 + merge_files(dev_raw_paths, output=test_raw_path) serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang serialization_dir.mkdir(exist_ok=True, parents=True) diff --git a/scripts/utils.py b/scripts/utils.py index f1d03feb973a69f90662574b0803eccd37b36dba..bbbe2fec838c2aadef2c98aa2f30f5686b08f218 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -21,6 +21,7 @@ LANG2TRANSFORMER = { "ta": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-ta-cased/", "sk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-sk-cased/", "lt": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lt-cased/", + "lv": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lv-cased/", "cs": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-cs-cased/", "et": "/tmp/lustre_shared/mklimasz/transformers/etwiki-bert/", # "uk": http://dl.turkunlp.org/wikibert/wikibert-base-uk-cased/