diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index ba8d617510440fd6f0195e4f8ba6606ea8d68f5a..ec7a1380e1cfc80b0302806e46cca4e5fc2d3568 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -5,5 +5,5 @@ from .parser import DependencyRelationModel
 from .embeddings import CharacterBasedWordEmbeddings
 from .encoder import ComboEncoder
 from .lemma import LemmatizerModel
-from .model import SemanticMultitaskModel
+from .model import ComboModel
 from .morpho import MorphologicalFeatures
diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py
index 2dc02dc98e20ff7637d2333135857ec81cec009a..edcdc2d0785dd73d91b3e79249d196ba55ec148c 100644
--- a/combo/models/graph_parser.py
+++ b/combo/models/graph_parser.py
@@ -119,13 +119,9 @@ class GraphDependencyRelationModel(base.Predictor):
                 mask: Optional[torch.BoolTensor] = None,
                 labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                 sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
-        # if mask is not None:
-        #     mask = mask[:, 1:]
         relations_labels, head_labels, enhanced_heads_labels, enhanced_deprels_labels = None, None, None, None
         if labels is not None and labels[0] is not None:
             relations_labels, head_labels, enhanced_heads_labels = labels
-            # if mask is None:
-            #     mask = head_labels.new_ones(head_labels.size())
 
         head_output = self.head_predictor(x, enhanced_heads_labels, mask, sample_weights)
         head_pred = head_output["probability"]
diff --git a/combo/models/model.py b/combo/models/model.py
index ad0df0e8be53d6470581fdda9f73279ea25327ed..9d3f81725450b61295703ac2489bcce5f87b66a0 100644
--- a/combo/models/model.py
+++ b/combo/models/model.py
@@ -12,7 +12,7 @@ from combo.utils import metrics
 
 
 @allen_models.Model.register("semantic_multitask")
-class SemanticMultitaskModel(allen_models.Model):
+class ComboModel(allen_models.Model):
     """Main COMBO model."""
 
     def __init__(self,
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ca0fce656fde639c27d3087f6011ef8fb9ab142
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,172 @@
+"""Script to train Dependency Parsing models based on UD 2.x data."""
+import pathlib
+
+from absl import app
+from absl import flags
+
+from scripts import utils
+
+TREEBANKS = [
+    "UD_Afrikaans-AfriBooms",
+    "UD_Arabic-NYUAD",
+    "UD_Arabic-PADT",
+    "UD_Armenian-ArmTDP",
+    "UD_Basque-BDT",
+    "UD_Belarusian-HSE",
+    "UD_Breton-KEB",
+    "UD_Bulgarian-BTB",
+    "UD_Catalan-AnCora",
+    "UD_Croatian-SET",
+    "UD_Czech-CAC",
+    "UD_Czech-CLTT",
+    "UD_Czech-FicTree",
+    "UD_Czech-PDT",
+    "UD_Danish-DDT",
+    "UD_Dutch-Alpino",
+    "UD_Dutch-LassySmall",
+    "UD_English-ESL",
+    "UD_English-EWT",
+    "UD_English-GUM",
+    "UD_English-LinES",
+    "UD_English-ParTUT",
+    "UD_English-Pronouns",
+    "UD_Estonian-EDT",
+    "UD_Estonian-EWT",
+    "UD_Finnish-FTB",
+    "UD_Finnish-TDT",
+    "UD_French-FQB",
+    "UD_French-FTB",
+    "UD_French-GSD",
+    "UD_French-ParTUT",
+    "UD_French-Sequoia",
+    "UD_French-Spoken",
+    "UD_Galician-CTG",
+    "UD_Galician-TreeGal",
+    "UD_German-GSD",
+    "UD_German-HDT",
+    "UD_German-LIT",
+    "UD_Greek-GDT",
+    "UD_Hebrew-HTB",
+    "UD_Hindi_English-HIENCS",
+    "UD_Hindi-HDTB",
+    "UD_Hungarian-Szeged",
+    "UD_Indonesian-GSD",
+    "UD_Irish-IDT",
+    "UD_Italian-ISDT",
+    "UD_Italian-ParTUT",
+    "UD_Italian-PoSTWITA",
+    "UD_Italian-TWITTIRO",
+    "UD_Italian-VIT",
+    "UD_Japanese-BCCWJ",
+    "UD_Japanese-GSD",
+    "UD_Japanese-Modern",
+    "UD_Kazakh-KTB",
+    "UD_Korean-GSD",
+    "UD_Korean-Kaist",
+    "UD_Latin-ITTB",
"UD_Latin-Perseus", + "UD_Latin-PROIEL", + "UD_Latvian-LVTB", + "UD_Lithuanian-ALKSNIS", + "UD_Lithuanian-HSE", + "UD_Maltese-MUDT", + "UD_Marathi-UFAL", + "UD_Persian-Seraji", + "UD_Polish-LFG", + "UD_Polish-PDB", + "UD_Portuguese-Bosque", + "UD_Portuguese-GSD", + "UD_Romanian-Nonstandard", + "UD_Romanian-RRT", + "UD_Romanian-SiMoNERo", + "UD_Russian-GSD", + "UD_Russian-SynTagRus", + "UD_Russian-Taiga", + "UD_Serbian-SET", + "UD_Slovak-SNK", + "UD_Slovenian-SSJ", + "UD_Slovenian-SST", + "UD_Spanish-AnCora", + "UD_Spanish-GSD", + "UD_Swedish-LinES", + "UD_Swedish_Sign_Language-SSLC", + "UD_Swedish-Talbanken", + "UD_Tamil-TTB", + "UD_Telugu-MTG", + "UD_Turkish-GB", + "UD_Turkish-IMST", + "UD_Ukrainian-IU", + "UD_Urdu-UDTB", + "UD_Uyghur-UDT", + "UD_Vietnamese-VTB", +] + +FLAGS = flags.FLAGS +flags.DEFINE_list(name="treebanks", default=TREEBANKS, + help=f"Treebanks to train. Possible values: {TREEBANKS}.") +flags.DEFINE_string(name="data_dir", default="", + help="Path to UD data directory.") +flags.DEFINE_string(name="serialization_dir", default="/tmp/", + help="Model serialization directory.") +flags.DEFINE_string(name="embeddings_dir", default="", + help="Path to embeddings directory (with languages as subdirectories).") +flags.DEFINE_integer(name="cuda_device", default=-1, + help="Cuda device id (-1 for cpu).") + + +def run(_): + treebanks_dir = pathlib.Path(FLAGS.data_dir) + for treebank in FLAGS.treebanks: + assert treebank in TREEBANKS, f"Unknown treebank {treebank}." + treebank_dir = treebanks_dir / treebank + treebank_parts = treebank.split("_")[1].split("-") + language = treebank_parts[0] + + files = list(treebank_dir.iterdir()) + + training_file = [f for f in files if "train" in f.name and ".conllu" in f.name] + assert len(training_file) == 1, f"Couldn't find training file." + training_file_path = training_file[0] + + valid_file = [f for f in files if "dev" in f.name and ".conllu" in f.name] + assert len(valid_file) == 1, f"Couldn't find validation file." + valid_file_path = valid_file[0] + + embeddings_dir = FLAGS.embeddings_dir + embeddings_file = None + if embeddings_dir: + embeddings_dir = embeddings_dir / language + embeddings_file = [f for f in embeddings_dir.iterdir() if "vectors" in f.name and ".vec.gz" in f.name] + assert len(embeddings_file) == 1, f"Couldn't find embeddings file." 
+            embeddings_file = embeddings_file[0]
+
+        language = training_file_path.name.split("_")[0]
+
+        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / treebank
+        serialization_dir.mkdir(exist_ok=True, parents=True)
+
+        command = f"""time combo --mode train
+        --cuda_device {FLAGS.cuda_device}
+        --training_data_path {training_file_path}
+        --validation_data_path {valid_file_path}
+        {f"--pretrained_tokens {embeddings_file}" if embeddings_dir
+        else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"}
+        --serialization_dir {serialization_dir}
+        --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'}
+        --word_batch_size 2500
+        --notensorboard
+        """
+
+        # no XPOS datasets
+        if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]:
+            command = command + " --targets deprel,head,upostag,lemma,feats"
+
+        utils.execute_command(command)
+
+
+def main():
+    app.run(run)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/train_eud.py b/scripts/train_eud.py
index 7c50b6193be47e4518ba5aee73c83224b3957d5c..4904e0bff6a9d78a7d0a56bff2ed0357992b615b 100644
--- a/scripts/train_eud.py
+++ b/scripts/train_eud.py
@@ -9,13 +9,13 @@ conda install -c dan_blanchard perl-moosex-semiaffordanceaccessor
 import os
 import pathlib
-import subprocess
 from typing import List
 
 from absl import app
 from absl import flags
 
-FLAGS = flags.FLAGS
+from scripts import utils
+
 
 LANG2TREEBANK = {
     "ar": ["Arabic-PADT"],
     "bg": ["Bulgarian-BTB"],
@@ -36,11 +36,7 @@ LANG2TREEBANK = {
     "uk": ["Ukrainian-IU"],
 }
 
-LANG2TRANSFORMER = {
-    "en": "bert-base-cased",
-    "pl": "allegro/herbert-base-cased",
-}
-
+FLAGS = flags.FLAGS
 flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
                   help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
 flags.DEFINE_string(name="data_dir", default="",
@@ -60,26 +56,18 @@ def merge_files(files: List[str], output: pathlib.Path):
     os.system(f"cat {' '.join(files)} > {output}")
 
 
-def execute_command(command, output_file=None):
-    command = [c for c in command.split() if c.strip()]
-    if output_file:
-        with open(output_file, "w") as f:
-            subprocess.run(command, check=True, stdout=f)
-    else:
-        subprocess.run(command, check=True)
-
-
 def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
     output_path = pathlib.Path(output)
     if not output_path.exists():
-        execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
-                        f"{path_to_str(treebank_file)}", output)
+        utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+                              f"{path_to_str(treebank_file)}", output)
 
 
 def run(_):
     languages = FLAGS.lang
     for lang in languages:
         assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
+        assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."
         data_dir = pathlib.Path(FLAGS.data_dir)
         assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
@@ -116,17 +104,17 @@ def run(_):
         merge_files(test_paths, output=test_path)
 
         serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
-        serialization_dir.mkdir(exist_ok=True)
-        execute_command("".join(f"""combo --mode train
+        serialization_dir.mkdir(exist_ok=True, parents=True)
+        utils.execute_command("".join(f"""combo --mode train
         --training_data {train_path}
         --validation_data {dev_path}
         --targets feats,upostag,xpostag,head,deprel,lemma,deps
-        --pretrained_transformer_name {LANG2TRANSFORMER[lang]}
+        --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
         --serialization_dir {serialization_dir}
         --cuda_device {FLAGS.cuda_device}
         --word_batch_size 2500
         --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
-        --tensorboard
+        --notensorboard
         """.splitlines()))
diff --git a/scripts/utils.py b/scripts/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dda2b89693fc1431b810709d9e7002d9f5f8071
--- /dev/null
+++ b/scripts/utils.py
@@ -0,0 +1,16 @@
+"""Utils for scripts."""
+import subprocess
+
+LANG2TRANSFORMER = {
+    "en": "bert-base-cased",
+    "pl": "allegro/herbert-base-cased",
+}
+
+
+def execute_command(command, output_file=None):
+    command = [c for c in command.split() if c.strip()]
+    if output_file:
+        with open(output_file, "w") as f:
+            subprocess.run(command, check=True, stdout=f)
+    else:
+        subprocess.run(command, check=True)
diff --git a/setup.cfg b/setup.cfg
index b7e478982ccf9ab1963c74e1084dfccb6e42c583..6876d0d7447015400e616dbd7479de01d19c2948 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,5 @@
 [aliases]
 test=pytest
+
+[metadata]
+description-file = README.md
diff --git a/setup.py b/setup.py
index 74540e259212f32d5407fea598e0b9e3560e1f6f..5c833d558fa63b9d74cb317fe4e38d875313692c 100644
--- a/setup.py
+++ b/setup.py
@@ -20,11 +20,23 @@ REQUIREMENTS = [
 
 setup(
     name='COMBO',
-    version='0.0.1',
+    version='1.0.0b1',
     install_requires=REQUIREMENTS,
     packages=find_packages(exclude=['tests']),
+    license="GPL-3.0",
+    url='https://gitlab.clarin-pl.eu/syntactic-tools/combo',
+    keywords="nlp natural-language-processing dependency-parsing",
     setup_requires=['pytest-runner', 'pytest-pylint'],
     tests_require=['pytest', 'pylint'],
     python_requires='>=3.6',
     entry_points={'console_scripts': ['combo = combo.main:main']},
+    classifiers=[
+        'Development Status :: 4 - Beta',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+    ]
 )