Commit d711c60f authored and committed by Mateusz Klimaszewski

Add script for training UD dependency parsing models and extend the PyPI description.

parent 426d24f1
Related merge requests: !9 Enhanced dependency parsing (develop to master), !8 Enhanced dependency parsing
@@ -5,5 +5,5 @@ from .parser import DependencyRelationModel
 from .embeddings import CharacterBasedWordEmbeddings
 from .encoder import ComboEncoder
 from .lemma import LemmatizerModel
-from .model import SemanticMultitaskModel
+from .model import ComboModel
 from .morpho import MorphologicalFeatures
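
Downstream imports change accordingly. A one-line sketch, assuming this __init__.py belongs to the combo.models package (the package path is an assumption; only the class rename is confirmed by the diff):

from combo.models import ComboModel  # previously SemanticMultitaskModel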
@@ -119,13 +119,9 @@ class GraphDependencyRelationModel(base.Predictor):
                 mask: Optional[torch.BoolTensor] = None,
                 labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                 sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
-        # if mask is not None:
-        #     mask = mask[:, 1:]
         relations_labels, head_labels, enhanced_heads_labels, enhanced_deprels_labels = None, None, None, None
         if labels is not None and labels[0] is not None:
             relations_labels, head_labels, enhanced_heads_labels = labels
-            # if mask is None:
-            #     mask = head_labels.new_ones(head_labels.size())
         head_output = self.head_predictor(x, enhanced_heads_labels, mask, sample_weights)
         head_pred = head_output["probability"]
......
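
As a standalone illustration of the label handling above, a minimal runnable sketch; the tensor shapes are assumptions, and only the three-element unpacking mirrors the code:

import torch

relations_labels, head_labels, enhanced_heads_labels, enhanced_deprels_labels = None, None, None, None
labels = (torch.zeros(2, 5, dtype=torch.long),  # relation label per token (assumed shape)
          torch.zeros(2, 5, dtype=torch.long),  # head index per token (assumed shape)
          torch.zeros(2, 5, 5))                 # enhanced-graph heads (assumed layout)
if labels is not None and labels[0] is not None:
    # Only three values are consumed; enhanced_deprels_labels stays None here.
    relations_labels, head_labels, enhanced_heads_labels = labels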
@@ -12,7 +12,7 @@ from combo.utils import metrics
 @allen_models.Model.register("semantic_multitask")
-class SemanticMultitaskModel(allen_models.Model):
+class ComboModel(allen_models.Model):
     """Main COMBO model."""

     def __init__(self,
......
"""Script to train Dependency Parsing models based on UD 2.x data."""
import pathlib
from absl import app
from absl import flags
from scripts import utils
TREEBANKS = [
"UD_Afrikaans-AfriBooms",
"UD_Arabic-NYUAD",
"UD_Arabic-PADT",
"UD_Armenian-ArmTDP",
"UD_Basque-BDT",
"UD_Belarusian-HSE",
"UD_Breton-KEB",
"UD_Bulgarian-BTB",
"UD_Catalan-AnCora",
"UD_Croatian-SET",
"UD_Czech-CAC",
"UD_Czech-CLTT",
"UD_Czech-FicTree",
"UD_Czech-PDT",
"UD_Danish-DDT",
"UD_Dutch-Alpino",
"UD_Dutch-LassySmall",
"UD_English-ESL",
"UD_English-EWT",
"UD_English-GUM",
"UD_English-LinES",
"UD_English-ParTUT",
"UD_English-Pronouns",
"UD_Estonian-EDT",
"UD_Estonian-EWT",
"UD_Finnish-FTB",
"UD_Finnish-TDT",
"UD_French-FQB",
"UD_French-FTB",
"UD_French-GSD",
"UD_French-ParTUT",
"UD_French-Sequoia",
"UD_French-Spoken",
"UD_Galician-CTG",
"UD_Galician-TreeGal",
"UD_German-GSD",
"UD_German-HDT",
"UD_German-LIT",
"UD_Greek-GDT",
"UD_Hebrew-HTB",
"UD_Hindi_English-HIENCS",
"UD_Hindi-HDTB",
"UD_Hungarian-Szeged",
"UD_Indonesian-GSD",
"UD_Irish-IDT",
"UD_Italian-ISDT",
"UD_Italian-ParTUT",
"UD_Italian-PoSTWITA",
"UD_Italian-TWITTIRO",
"UD_Italian-VIT",
"UD_Japanese-BCCWJ",
"UD_Japanese-GSD",
"UD_Japanese-Modern",
"UD_Kazakh-KTB",
"UD_Korean-GSD",
"UD_Korean-Kaist",
"UD_Latin-ITTB",
"UD_Latin-Perseus",
"UD_Latin-PROIEL",
"UD_Latvian-LVTB",
"UD_Lithuanian-ALKSNIS",
"UD_Lithuanian-HSE",
"UD_Maltese-MUDT",
"UD_Marathi-UFAL",
"UD_Persian-Seraji",
"UD_Polish-LFG",
"UD_Polish-PDB",
"UD_Portuguese-Bosque",
"UD_Portuguese-GSD",
"UD_Romanian-Nonstandard",
"UD_Romanian-RRT",
"UD_Romanian-SiMoNERo",
"UD_Russian-GSD",
"UD_Russian-SynTagRus",
"UD_Russian-Taiga",
"UD_Serbian-SET",
"UD_Slovak-SNK",
"UD_Slovenian-SSJ",
"UD_Slovenian-SST",
"UD_Spanish-AnCora",
"UD_Spanish-GSD",
"UD_Swedish-LinES",
"UD_Swedish_Sign_Language-SSLC",
"UD_Swedish-Talbanken",
"UD_Tamil-TTB",
"UD_Telugu-MTG",
"UD_Turkish-GB",
"UD_Turkish-IMST",
"UD_Ukrainian-IU",
"UD_Urdu-UDTB",
"UD_Uyghur-UDT",
"UD_Vietnamese-VTB",
]

FLAGS = flags.FLAGS
flags.DEFINE_list(name="treebanks", default=TREEBANKS,
                  help=f"Treebanks to train. Possible values: {TREEBANKS}.")
flags.DEFINE_string(name="data_dir", default="",
                    help="Path to UD data directory.")
flags.DEFINE_string(name="serialization_dir", default="/tmp/",
                    help="Model serialization directory.")
flags.DEFINE_string(name="embeddings_dir", default="",
                    help="Path to embeddings directory (with languages as subdirectories).")
flags.DEFINE_integer(name="cuda_device", default=-1,
                     help="CUDA device id (-1 for CPU).")


def run(_):
    treebanks_dir = pathlib.Path(FLAGS.data_dir)
    for treebank in FLAGS.treebanks:
        assert treebank in TREEBANKS, f"Unknown treebank {treebank}."
        treebank_dir = treebanks_dir / treebank
        treebank_parts = treebank.split("_")[1].split("-")
        language = treebank_parts[0]

        files = list(treebank_dir.iterdir())

        training_file = [f for f in files if "train" in f.name and ".conllu" in f.name]
        assert len(training_file) == 1, f"Couldn't find a unique training file in {treebank_dir}."
        training_file_path = training_file[0]

        valid_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
        assert len(valid_file) == 1, f"Couldn't find a unique validation file in {treebank_dir}."
        valid_file_path = valid_file[0]

        embeddings_dir = FLAGS.embeddings_dir
        embeddings_file = None
        if embeddings_dir:
            # FLAGS.embeddings_dir is a plain string; wrap it in pathlib.Path before joining.
            embeddings_dir = pathlib.Path(embeddings_dir) / language
            embeddings_file = [f for f in embeddings_dir.iterdir()
                               if "vectors" in f.name and ".vec.gz" in f.name]
            assert len(embeddings_file) == 1, f"Couldn't find an embeddings file in {embeddings_dir}."
            embeddings_file = embeddings_file[0]

        # Re-derive the ISO language code from the CoNLL-U file name
        # (e.g. "pl" from "pl_pdb-ud-train.conllu").
        language = training_file_path.name.split("_")[0]

        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / treebank
        serialization_dir.mkdir(exist_ok=True, parents=True)

        command = f"""time combo --mode train
        --cuda_device {FLAGS.cuda_device}
        --training_data_path {training_file_path}
        --validation_data_path {valid_file_path}
        {f"--pretrained_tokens {embeddings_file}" if embeddings_dir
        else f"--pretrained_transformer_name {utils.LANG2TRANSFORMER[language]}"}
        --serialization_dir {serialization_dir}
        --config_path {pathlib.Path.cwd() / 'config.template.jsonnet'}
        --word_batch_size 2500
        --notensorboard
        """

        # Datasets without XPOS annotation: train on the remaining targets only.
        if treebank in ["UD_Hungarian-Szeged", "UD_Armenian-ArmTDP"]:
            command = command + " --targets deprel,head,upostag,lemma,feats"
        utils.execute_command(command)


def main():
    app.run(run)


if __name__ == "__main__":
    main()
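
Assuming the new file is saved as scripts/train.py (its path is not shown in this view), a typical launch might look like the line below. With an empty --embeddings_dir the script falls back to the transformer registered for the language in utils.LANG2TRANSFORMER, so that mapping must cover every trained language:

python -m scripts.train --data_dir /data/ud-treebanks --serialization_dir /tmp/models --cuda_device 0 --treebanks UD_Polish-PDB,UD_English-EWT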
@@ -9,13 +9,13 @@ conda install -c dan_blanchard perl-moosex-semiaffordanceaccessor
 import os
 import pathlib
 import subprocess
 from typing import List

 from absl import app
 from absl import flags
-FLAGS = flags.FLAGS
+from scripts import utils

 LANG2TREEBANK = {
     "ar": ["Arabic-PADT"],
     "bg": ["Bulgarian-BTB"],
@@ -36,11 +36,7 @@ LANG2TREEBANK = {
     "uk": ["Ukrainian-IU"],
 }

-LANG2TRANSFORMER = {
-    "en": "bert-base-cased",
-    "pl": "allegro/herbert-base-cased",
-}
+FLAGS = flags.FLAGS
 flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()),
                   help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.")
 flags.DEFINE_string(name="data_dir", default="",
@@ -60,19 +56,10 @@ def merge_files(files: List[str], output: pathlib.Path):
     os.system(f"cat {' '.join(files)} > {output}")

-def execute_command(command, output_file=None):
-    command = [c for c in command.split() if c.strip()]
-    if output_file:
-        with open(output_file, "w") as f:
-            subprocess.run(command, check=True, stdout=f)
-    else:
-        subprocess.run(command, check=True)

 def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str):
     output_path = pathlib.Path(output)
     if not output_path.exists():
-        execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
+        utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} "
                         f"{path_to_str(treebank_file)}", output)
@@ -80,6 +67,7 @@ def run(_):
     languages = FLAGS.lang
     for lang in languages:
         assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}."
+        assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict."

     data_dir = pathlib.Path(FLAGS.data_dir)
     assert data_dir.is_dir(), f"'{data_dir}' is not a directory!"
@@ -116,17 +104,17 @@ def run(_):
         merge_files(test_paths, output=test_path)

         serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang
-        serialization_dir.mkdir(exist_ok=True)
-        execute_command("".join(f"""combo --mode train
+        serialization_dir.mkdir(exist_ok=True, parents=True)
+        utils.execute_command("".join(f"""combo --mode train
         --training_data {train_path}
         --validation_data {dev_path}
         --targets feats,upostag,xpostag,head,deprel,lemma,deps
-        --pretrained_transformer_name {LANG2TRANSFORMER[lang]}
+        --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]}
         --serialization_dir {serialization_dir}
         --cuda_device {FLAGS.cuda_device}
         --word_batch_size 2500
         --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'}
-        --tensorboard
+        --notensorboard
         """.splitlines()))
......
"""Utils for scripts."""
import subprocess
LANG2TRANSFORMER = {
"en": "bert-base-cased",
"pl": "allegro/herbert-base-cased",
}
def execute_command(command, output_file=None):
command = [c for c in command.split() if c.strip()]
if output_file:
with open(output_file, "w") as f:
subprocess.run(command, check=True, stdout=f)
else:
subprocess.run(command, check=True)
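
A minimal usage sketch of the helper with illustrative file names. Because the command string is split on whitespace, individual arguments must not contain spaces; this is why the training scripts assemble their commands as flat flag strings:

from scripts import utils

# Stream stdout to a file; check=True raises CalledProcessError on a non-zero exit.
utils.execute_command("ls -l", output_file="listing.txt")

# Without output_file, stdout goes straight to the console.
utils.execute_command("echo done")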
[aliases]
test=pytest
[metadata]
description-file = README.md
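
The [aliases] section makes the standard setuptools test command delegate to pytest (backed by pytest-runner from setup_requires below):

python setup.py test  # runs the pytest suite via the alias above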
@@ -20,11 +20,23 @@ REQUIREMENTS = [
 setup(
     name='COMBO',
-    version='0.0.1',
+    version='1.0.0b1',
     install_requires=REQUIREMENTS,
     packages=find_packages(exclude=['tests']),
+    license="GPL-3.0",
+    url='https://gitlab.clarin-pl.eu/syntactic-tools/combo',
+    keywords="nlp natural-language-processing dependency-parsing",
     setup_requires=['pytest-runner', 'pytest-pylint'],
     tests_require=['pytest', 'pylint'],
     python_requires='>=3.6',
     entry_points={'console_scripts': ['combo = combo.main:main']},
+    classifiers=[
+        'Development Status :: 4 - Beta',
+        'Intended Audience :: Science/Research',
+        'License :: OSI Approved :: GNU General Public License v3 (GPLv3)',
+        'Topic :: Scientific/Engineering :: Artificial Intelligence',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+    ],
 )
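
For reference, the console_scripts entry point generates a `combo` executable on installation (e.g. via pip install .); the generated script is effectively this minimal sketch:

from combo.main import main

main()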