From 06efc0439c91efcbb463c7acbb5ba27a627bca6d Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Tue, 6 Apr 2021 13:01:19 +0200 Subject: [PATCH 01/28] Add pos, deprel and feats embeddings as an additional output. --- combo/data/api.py | 11 ++-- combo/models/base.py | 115 ++++++++++++++++++++++++++++++++++++++--- combo/models/model.py | 9 +++- combo/models/morpho.py | 10 ++-- combo/models/parser.py | 1 + combo/predict.py | 13 +++-- 6 files changed, 137 insertions(+), 22 deletions(-) diff --git a/combo/data/api.py b/combo/data/api.py index 4ab7f1a..bfec5ee 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -21,12 +21,13 @@ class Token: deps: Optional[str] = None misc: Optional[str] = None semrel: Optional[str] = None + embeddings: Dict[str, List[float]] = field(default_factory=list, repr=False) @dataclass class Sentence: tokens: List[Token] = field(default_factory=list) - sentence_embedding: List[float] = field(default_factory=list) + sentence_embedding: List[float] = field(default_factory=list, repr=False) metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict) def to_json(self): @@ -77,14 +78,16 @@ def tokens2conllu(tokens: List[str]) -> conllu.TokenList: def conllu2sentence(conllu_sentence: conllu.TokenList, - sentence_embedding=None) -> Sentence: + sentence_embedding=None, embeddings=None) -> Sentence: + if embeddings is None: + embeddings = {} if sentence_embedding is None: sentence_embedding = [] tokens = [] - for token in conllu_sentence.tokens: + for idx, token in enumerate(conllu_sentence.tokens): tokens.append( Token( - **token + **token, embeddings=embeddings[idx] ) ) return Sentence( diff --git a/combo/models/base.py b/combo/models/base.py index a5cb5fe..234fbca 100644 --- a/combo/models/base.py +++ b/combo/models/base.py @@ -1,11 +1,10 @@ -from typing import Dict, Optional, List, Union +from typing import Dict, Optional, List, Union, Tuple import torch import torch.nn as nn from allennlp import common, data from allennlp import nn as allen_nn from allennlp.common import checks -from allennlp.modules import feedforward from allennlp.nn import Activation from combo.models import utils @@ -51,7 +50,7 @@ class Linear(nn.Linear, common.FromParams): class FeedForwardPredictor(Predictor): """Feedforward predictor. Should be used on top of Seq2Seq encoder.""" - def __init__(self, feedforward_network: feedforward.FeedForward): + def __init__(self, feedforward_network: "FeedForward"): super().__init__() self.feedforward_network = feedforward_network @@ -63,10 +62,11 @@ class FeedForwardPredictor(Predictor): if mask is None: mask = x.new_ones(x.size()[:-1]) - x = self.feedforward_network(x) + x, feature_maps = self.feedforward_network(x) output = { "prediction": x.argmax(-1), - "probability": x + "probability": x, + "embedding": feature_maps[-1], } if labels is not None: @@ -109,9 +109,112 @@ class FeedForwardPredictor(Predictor): f"There is not {vocab_namespace} in created vocabs, check if this field has any values to predict!" hidden_dims = hidden_dims + [vocab.get_vocab_size(vocab_namespace)] - return cls(feedforward.FeedForward( + return cls(FeedForward( input_dim=input_dim, num_layers=num_layers, hidden_dims=hidden_dims, activations=activations, dropout=dropout)) + + +class FeedForward(torch.nn.Module, common.FromParams): + """ + Modified copy of allennlp.modules.feedforward.FeedForward + + This `Module` is a feed-forward neural network, just a sequence of `Linear` layers with + activation functions in between. 
+ + # Parameters + + input_dim : `int`, required + The dimensionality of the input. We assume the input has shape `(batch_size, input_dim)`. + num_layers : `int`, required + The number of `Linear` layers to apply to the input. + hidden_dims : `Union[int, List[int]]`, required + The output dimension of each of the `Linear` layers. If this is a single `int`, we use + it for all `Linear` layers. If it is a `List[int]`, `len(hidden_dims)` must be + `num_layers`. + activations : `Union[Activation, List[Activation]]`, required + The activation function to use after each `Linear` layer. If this is a single function, + we use it after all `Linear` layers. If it is a `List[Activation]`, + `len(activations)` must be `num_layers`. Activation must have torch.nn.Module type. + dropout : `Union[float, List[float]]`, optional (default = `0.0`) + If given, we will apply this amount of dropout after each layer. Semantics of `float` + versus `List[float]` is the same as with other parameters. + + # Examples + + ```python + FeedForward(124, 2, [64, 32], torch.nn.ReLU(), 0.2) + #> FeedForward( + #> (_activations): ModuleList( + #> (0): ReLU() + #> (1): ReLU() + #> ) + #> (_linear_layers): ModuleList( + #> (0): Linear(in_features=124, out_features=64, bias=True) + #> (1): Linear(in_features=64, out_features=32, bias=True) + #> ) + #> (_dropout): ModuleList( + #> (0): Dropout(p=0.2, inplace=False) + #> (1): Dropout(p=0.2, inplace=False) + #> ) + #> ) + ``` + """ + + def __init__( + self, + input_dim: int, + num_layers: int, + hidden_dims: Union[int, List[int]], + activations: Union[Activation, List[Activation]], + dropout: Union[float, List[float]] = 0.0, + ) -> None: + + super().__init__() + if not isinstance(hidden_dims, list): + hidden_dims = [hidden_dims] * num_layers # type: ignore + if not isinstance(activations, list): + activations = [activations] * num_layers # type: ignore + if not isinstance(dropout, list): + dropout = [dropout] * num_layers # type: ignore + if len(hidden_dims) != num_layers: + raise checks.ConfigurationError( + "len(hidden_dims) (%d) != num_layers (%d)" % (len(hidden_dims), num_layers) + ) + if len(activations) != num_layers: + raise checks.ConfigurationError( + "len(activations) (%d) != num_layers (%d)" % (len(activations), num_layers) + ) + if len(dropout) != num_layers: + raise checks.ConfigurationError( + "len(dropout) (%d) != num_layers (%d)" % (len(dropout), num_layers) + ) + self._activations = torch.nn.ModuleList(activations) + input_dims = [input_dim] + hidden_dims[:-1] + linear_layers = [] + for layer_input_dim, layer_output_dim in zip(input_dims, hidden_dims): + linear_layers.append(torch.nn.Linear(layer_input_dim, layer_output_dim)) + self._linear_layers = torch.nn.ModuleList(linear_layers) + dropout_layers = [torch.nn.Dropout(p=value) for value in dropout] + self._dropout = torch.nn.ModuleList(dropout_layers) + self._output_dim = hidden_dims[-1] + self.input_dim = input_dim + + def get_output_dim(self): + return self._output_dim + + def get_input_dim(self): + return self.input_dim + + def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + + output = inputs + feature_maps = [] + for layer, activation, dropout in zip( + self._linear_layers, self._activations, self._dropout + ): + feature_maps.append(output) + output = dropout(activation(layer(output))) + return output, feature_maps diff --git a/combo/models/model.py b/combo/models/model.py index 9866bcb..c648453 100644 --- a/combo/models/model.py +++ b/combo/models/model.py @@ -129,6 +129,11 
@@ class ComboModel(allen_models.Model): "enhanced_head": enhanced_head_pred, "enhanced_deprel": enhanced_relations_pred, "sentence_embedding": torch.max(encoder_emb, dim=1)[0], + "upostag_token_embedding": upos_output["embedding"], + "xpostag_token_embedding": xpos_output["embedding"], + "semrel_token_embedding": semrel_output["embedding"], + "feats_token_embedding": morpho_output["embedding"], + "deprel_token_embedding": parser_output["embedding"], } if "rel_probability" in enhanced_parser_output: @@ -196,8 +201,8 @@ class ComboModel(allen_models.Model): if callable_model: return callable_model(*args, **kwargs) if returns_tuple: - return {"prediction": (None, None), "loss": (None, None)} - return {"prediction": None, "loss": None} + return {"prediction": (None, None), "loss": (None, None), "embedding": (None, None)} + return {"prediction": None, "loss": None, "embedding": None} @staticmethod def _clean(output): diff --git a/combo/models/morpho.py b/combo/models/morpho.py index ea3451d..b0d3079 100644 --- a/combo/models/morpho.py +++ b/combo/models/morpho.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Union import torch from allennlp import data from allennlp.common import checks -from allennlp.modules import feedforward from allennlp.nn import Activation from combo.data import dataset @@ -15,7 +14,7 @@ from combo.models import base, utils class MorphologicalFeatures(base.Predictor): """Morphological features predicting model.""" - def __init__(self, feedforward_network: feedforward.FeedForward, slices: Dict[str, List[int]]): + def __init__(self, feedforward_network: base.FeedForward, slices: Dict[str, List[int]]): super().__init__() self.feedforward_network = feedforward_network self.slices = slices @@ -28,7 +27,7 @@ class MorphologicalFeatures(base.Predictor): if mask is None: mask = x.new_ones(x.size()[:-1]) - x = self.feedforward_network(x) + x, feature_maps = self.feedforward_network(x) prediction = [] for _, cat_indices in self.slices.items(): @@ -36,7 +35,8 @@ class MorphologicalFeatures(base.Predictor): output = { "prediction": torch.stack(prediction, dim=-1), - "probability": x + "probability": x, + "embedding": feature_maps[-1], } if labels is not None: @@ -92,7 +92,7 @@ class MorphologicalFeatures(base.Predictor): slices = dataset.get_slices_if_not_provided(vocab) return cls( - feedforward_network=feedforward.FeedForward( + feedforward_network=base.FeedForward( input_dim=input_dim, num_layers=num_layers, hidden_dims=hidden_dims, diff --git a/combo/models/parser.py b/combo/models/parser.py index 511edff..b16f0ad 100644 --- a/combo/models/parser.py +++ b/combo/models/parser.py @@ -153,6 +153,7 @@ class DependencyRelationModel(base.Predictor): dep_rel_pred = torch.cat((dep_rel_pred, dep_rel_emb), dim=-1) relation_prediction = self.relation_prediction_layer(dep_rel_pred) output = head_output + output["embedding"] = dep_rel_pred if self.training: output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"]) diff --git a/combo/predict.py b/combo/predict.py index e528a18..f580c01 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -82,8 +82,8 @@ class COMBO(predictor.Predictor): sentences = [] predictions = super().predict_batch_instance(instances) for prediction, instance in zip(predictions, instances): - tree, sentence_embedding = self._predictions_as_tree(prediction, instance) - sentence = conllu2sentence(tree, sentence_embedding) + tree, sentence_embedding, embeddings = self._predictions_as_tree(prediction, instance) + sentence = 
conllu2sentence(tree, sentence_embedding, embeddings) sentences.append(sentence) return sentences @@ -96,8 +96,8 @@ class COMBO(predictor.Predictor): @overrides def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> data.Sentence: predictions = super().predict_instance(instance) - tree, sentence_embedding = self._predictions_as_tree(predictions, instance) - return conllu2sentence(tree, sentence_embedding) + tree, sentence_embedding, embeddings = self._predictions_as_tree(predictions, instance, ) + return conllu2sentence(tree, sentence_embedding, embeddings) @overrides def predict_json(self, inputs: common.JsonDict) -> data.Sentence: @@ -141,6 +141,7 @@ class COMBO(predictor.Predictor): tree = instance.fields["metadata"]["input"] field_names = instance.fields["metadata"]["field_names"] tree_tokens = [t for t in tree if isinstance(t["id"], int)] + embeddings = [{} for _ in range(len(tree_tokens))] for field_name in field_names: if field_name not in predictions: continue @@ -149,6 +150,7 @@ class COMBO(predictor.Predictor): if field_name in {"xpostag", "upostag", "semrel", "deprel"}: value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels") token[field_name] = value + embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx] elif field_name == "head": token[field_name] = int(field_predictions[idx]) elif field_name == "deps": @@ -174,6 +176,7 @@ class COMBO(predictor.Predictor): field_value = "|".join(np.array(features)[arg_indices].tolist()) token[field_name] = field_value + embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx] elif field_name == "lemma": prediction = field_predictions[idx] word_chars = [] @@ -206,7 +209,7 @@ class COMBO(predictor.Predictor): empty_tokens = graph.restore_collapse_edges(tree_tokens) tree.tokens.extend(empty_tokens) - return tree, predictions["sentence_embedding"] + return tree, predictions["sentence_embedding"], embeddings @classmethod def with_spacy_tokenizer(cls, model: models.Model, -- GitLab From 9c283e8c4affb8e74f929e9e510abfa4d6a7d456 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 7 Apr 2021 07:46:33 +0200 Subject: [PATCH 02/28] Fix embeddings mapping during the evaluation step. --- combo/data/api.py | 5 +++-- combo/predict.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/combo/data/api.py b/combo/data/api.py index bfec5ee..308e9e4 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -55,6 +55,7 @@ def sentence2conllu(sentence: Sentence, keep_semrel: bool = True) -> conllu.Toke # Remove semrel to have default conllu format. 
if not keep_semrel: del token_dict["semrel"] + del token_dict["embeddings"] tokens.append(token_dict) # Range tokens must be tuple not list, this is conllu library requirement for t in tokens: @@ -84,10 +85,10 @@ def conllu2sentence(conllu_sentence: conllu.TokenList, if sentence_embedding is None: sentence_embedding = [] tokens = [] - for idx, token in enumerate(conllu_sentence.tokens): + for token in conllu_sentence.tokens: tokens.append( Token( - **token, embeddings=embeddings[idx] + **token, embeddings=embeddings[token["id"]] ) ) return Sentence( diff --git a/combo/predict.py b/combo/predict.py index f580c01..5945e6d 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -141,7 +141,7 @@ class COMBO(predictor.Predictor): tree = instance.fields["metadata"]["input"] field_names = instance.fields["metadata"]["field_names"] tree_tokens = [t for t in tree if isinstance(t["id"], int)] - embeddings = [{} for _ in range(len(tree_tokens))] + embeddings = {t["id"]: {} for t in tree} for field_name in field_names: if field_name not in predictions: continue @@ -150,7 +150,7 @@ class COMBO(predictor.Predictor): if field_name in {"xpostag", "upostag", "semrel", "deprel"}: value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels") token[field_name] = value - embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx] + embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx] elif field_name == "head": token[field_name] = int(field_predictions[idx]) elif field_name == "deps": @@ -176,7 +176,7 @@ class COMBO(predictor.Predictor): field_value = "|".join(np.array(features)[arg_indices].tolist()) token[field_name] = field_value - embeddings[idx][field_name] = predictions[f"{field_name}_token_embedding"][idx] + embeddings[token["id"]][field_name] = predictions[f"{field_name}_token_embedding"][idx] elif field_name == "lemma": prediction = field_predictions[idx] word_chars = [] -- GitLab From eb3de82c30273d432fdd4004b524a3ac9dd2e8ed Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Mon, 12 Apr 2021 14:36:13 +0200 Subject: [PATCH 03/28] Remove training comments. --- scripts/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 950ee82..b75bbed 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -41,7 +41,6 @@ TREEBANKS = [ "UD_Czech-PUD", "UD_Danish-DDT", "UD_Dutch-Alpino", - #END OF FIRST RUN "UD_English-EWT", # "UD_Erzya-JR", No training data "UD_Estonian-EWT", @@ -104,7 +103,6 @@ TREEBANKS = [ "UD_Latvian-LVTB", "UD_Lithuanian-ALKSNIS", "UD_Lithuanian-HSE", - # end batch 2 "UD_Maltese-MUDT", # "UD_Manx-Cadhan", No training data "UD_Marathi-UFAL", -- GitLab From 6fde9305db5188f81eaccd3e6fe396cb3dc11513 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Mon, 12 Apr 2021 15:41:52 +0200 Subject: [PATCH 04/28] Add default BERTs for IWPT'21 languages. 
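The patch below fills in `LANG2TRANSFORMER` entries for the IWPT'21 languages, leaving languages without a ready HuggingFace model as commented-out download URLs for now. As a rough sketch of how such a mapping is typically consumed, and in the spirit of the later "Use XLM-R for low resource languages" patch, a lookup with a multilingual fallback could look like this (the `transformer_for` helper and the `xlm-roberta-base` fallback are illustrative assumptions, not code from the repository):

```python
# Excerpt of the mapping added by this patch; the helper and the fallback are
# illustrative assumptions, not part of the repository.
LANG2TRANSFORMER = {
    "fr": "camembert-base",
    "it": "dbmdz/bert-base-italian-cased",
    "sv": "KB/bert-base-swedish-cased",
}

def transformer_for(lang: str, fallback: str = "xlm-roberta-base") -> str:
    """Return a language-specific transformer name, or a multilingual fallback."""
    return LANG2TRANSFORMER.get(lang, fallback)

print(transformer_for("fr"))  # camembert-base
print(transformer_for("et"))  # xlm-roberta-base (Estonian has no entry in this patch yet)
```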
--- scripts/utils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/scripts/utils.py b/scripts/utils.py index 6ce5a8a..0ec8725 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -10,7 +10,20 @@ LANG2TRANSFORMER = { "de": "dbmdz/bert-base-german-cased", "ar": "aubmindlab/bert-base-arabertv2", "eu": "ixa-ehu/berteus-base-cased", - "tr": "dbmdz/bert-base-turkish-cased" + "tr": "dbmdz/bert-base-turkish-cased", + "bg": "iarfmoose/roberta-base-bulgarian", + "nl": "GroNLP/bert-base-dutch-cased", + "fr": "camembert-base", + "it": "dbmdz/bert-base-italian-cased", + "ru": "blinoff/roberta-base-russian-v0", + "sv": "KB/bert-base-swedish-cased", + # "uk": http://dl.turkunlp.org/wikibert/wikibert-base-uk-cased/ + # "ta": http://dl.turkunlp.org/wikibert/wikibert-base-ta-cased/ + # "sk": http://dl.turkunlp.org/wikibert/wikibert-base-sl-cased/ + # "lt": http://dl.turkunlp.org/wikibert/wikibert-base-lt-cased/ + # "lv": http://dl.turkunlp.org/wikibert/wikibert-base-lv-cased/ + # "et": http://dl.turkunlp.org/estonian-bert/etwiki-bert/pytorch/ + # "cs": https://github.com/kiv-air/Czert https://arxiv.org/pdf/2103.13031.pdf } -- GitLab From 1a9d7362b7df464433a36c14b9f5dec421bcdf58 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 14 Apr 2021 11:06:29 +0200 Subject: [PATCH 05/28] Add IWPT'21 shared task training script. Add information about UD tools repo. --- docs/training.md | 2 + scripts/train_iwpt21.py | 134 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+) create mode 100644 scripts/train_iwpt21.py diff --git a/docs/training.md b/docs/training.md index 7d7b0a8..2cdf75d 100644 --- a/docs/training.md +++ b/docs/training.md @@ -51,6 +51,8 @@ Enhanced Dependencies are described [here](https://universaldependencies.org/u/o ### Data pre-processing The organisers of [IWPT20 shared task](https://universaldependencies.org/iwpt20/data.html) distributed the data sets and a data pre-processing script `enhanced_collapse_empty_nodes.pl`. If you wish to train a model on IWPT20 data, apply this script to the training and validation data sets, before training the COMBO EUD model. +The script is part of the [UD tools repository](https://github.com/UniversalDependencies/tools/). + ```bash perl enhanced_collapse_empty_nodes.pl training.conllu > training.fixed.conllu ``` diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py new file mode 100644 index 0000000..5082f9d --- /dev/null +++ b/scripts/train_iwpt21.py @@ -0,0 +1,134 @@ +"""Script to train Enhanced Dependency Parsing models based on IWPT'21 Shared Task data. + +For possible requirements, see train_eud.py comments. 
+""" + +import os +import pathlib +from typing import List + +from absl import app +from absl import flags + +from scripts import utils + +LANG2TREEBANK = { + "ar": ["Arabic-PADT"], + "bg": ["Bulgarian-BTB"], + "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"], + "nl": ["Dutch-Alpino", "Dutch-LassySmall"], + "en": ["English-EWT", "English-PUD", "UD_English-GUM"], + "et": ["Estonian-EDT", "Estonian-EWT"], + "fi": ["Finnish-TDT", "Finnish-PUD"], + "fr": ["French-Sequoia", "French-FQB"], + "it": ["Italian-ISDT"], + "lv": ["Latvian-LVTB"], + "lt": ["Lithuanian-ALKSNIS"], + "pl": ["Polish-LFG", "Polish-PDB", "Polish-PUD"], + "ru": ["Russian-SynTagRus"], + "sk": ["Slovak-SNK"], + "sv": ["Swedish-Talbanken", "Swedish-PUD"], + "ta": ["Tamil-TTB"], + "uk": ["Ukrainian-IU"], +} + +FLAGS = flags.FLAGS +flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()), + help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.") +flags.DEFINE_string(name="data_dir", default="", + help="Path to 'iwpt2020stdata' directory.") +flags.DEFINE_string(name="serialization_dir", default="/tmp/", + help="Model serialization dir.") +flags.DEFINE_integer(name="cuda_device", default=-1, + help="Cuda device id (-1 for cpu).") + + +def path_to_str(path: pathlib.Path) -> str: + return str(path.resolve()) + + +def merge_files(files: List[str], output: pathlib.Path): + if not output.exists(): + os.system(f"cat {' '.join(files)} > {output}") + + +def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): + output_path = pathlib.Path(output) + if not output_path.exists(): + utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} " + f"{path_to_str(treebank_file)}", output) + + +def run(_): + languages = FLAGS.lang + for lang in languages: + assert lang in LANG2TREEBANK, f"'{lang}' must be one of {list(LANG2TREEBANK.keys())}." + assert lang in utils.LANG2TRANSFORMER, f"Transformer for '{lang}' isn't defined. See 'LANG2TRANSFORMER' dict." + data_dir = pathlib.Path(FLAGS.data_dir) + assert data_dir.is_dir(), f"'{data_dir}' is not a directory!" + + treebanks = LANG2TREEBANK[lang] + train_paths = [] + dev_paths = [] + + # TODO Uncomment when IWPT'21 Shared Task ends. + # During shared task duration test data is not available. + test_paths = [] + for treebank in treebanks: + treebank_dir = data_dir / f"UD_{treebank}" + assert treebank_dir.exists() and treebank_dir.is_dir(), f"'{treebank_dir}' directory doesn't exists." 
+ for treebank_file in treebank_dir.iterdir(): + name = treebank_file.name + if "conllu" in name and "fixed" not in name: + output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu') + if "train" in name: + collapse_nodes(data_dir, treebank_file, output) + train_paths.append(output) + elif "dev" in name: + collapse_nodes(data_dir, treebank_file, output) + dev_paths.append(output) + # elif "test" in name: + # collapse_nodes(data_dir, treebank_file, output) + # test_paths.append(output) + + lang_data_dir = pathlib.Path(data_dir / lang) + lang_data_dir.mkdir(exist_ok=True) + + train_path = lang_data_dir / "train.conllu" + dev_path = lang_data_dir / "dev.conllu" + # TODO Uncomment + # test_path = lang_data_dir / "test.conllu" + + merge_files(train_paths, output=train_path) + merge_files(dev_paths, output=dev_path) + # TODO Uncomment + # merge_files(test_paths, output=test_path) + + serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang + serialization_dir.mkdir(exist_ok=True, parents=True) + + command = f"""combo --mode train + --training_data {train_path} + --validation_data {dev_path} + --targets feats,upostag,xpostag,head,deprel,lemma,deps + --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]} + --serialization_dir {serialization_dir} + --cuda_device {FLAGS.cuda_device} + --word_batch_size 2500 + --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'} + --notensorboard + """ + + # Datasets without XPOS + if lang in {"fr"}: + command = command + " --targets deprel,head,upostag,lemma,feats" + + utils.execute_command("".join(command.splitlines())) + + +def main(): + app.run(run) + + +if __name__ == "__main__": + main() -- GitLab From 34dd4929200d52592260960ee30d8fe5f57f7c2d Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Fri, 16 Apr 2021 12:12:30 +0200 Subject: [PATCH 06/28] Fix treebank name typo. --- scripts/train_iwpt21.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index 5082f9d..c6310ea 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -17,7 +17,7 @@ LANG2TREEBANK = { "bg": ["Bulgarian-BTB"], "cs": ["Czech-FicTree", "Czech-CAC", "Czech-PDT", "Czech-PUD"], "nl": ["Dutch-Alpino", "Dutch-LassySmall"], - "en": ["English-EWT", "English-PUD", "UD_English-GUM"], + "en": ["English-EWT", "English-PUD", "English-GUM"], "et": ["Estonian-EDT", "Estonian-EWT"], "fi": ["Finnish-TDT", "Finnish-PUD"], "fr": ["French-Sequoia", "French-FQB"], -- GitLab From 57c432432b1b562a04de3239dccd47ed9ce38754 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mklimasz@e-science.pl> Date: Thu, 29 Apr 2021 08:08:18 +0200 Subject: [PATCH 07/28] Change tensor constructor. 
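The one-line change below swaps the legacy `torch.LongTensor(...)` constructor for the generic `torch.tensor(...)` factory with an explicit dtype; both yield an int64 tensor from the padded tag lists. A minimal sanity check of that equivalence (the example data is made up, not taken from the patch):

```python
import torch

padded_tags = [[1, 0, 0], [2, 3, 0]]  # made-up example data

old = torch.LongTensor(padded_tags)                # legacy type-specific constructor
new = torch.tensor(padded_tags, dtype=torch.long)  # preferred factory with explicit dtype

assert old.dtype == new.dtype == torch.int64
assert torch.equal(old, new)
```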
--- combo/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 870659f..29b1f91 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -204,7 +204,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): default_value = [0.0] * classes_count padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens, lambda: default_value) - tensor = torch.LongTensor(padded_tags) + tensor = torch.tensor(padded_tags, dtype=torch.long) return tensor return as_tensor -- GitLab From 3123cced1da128c3296aabfde22a7940125d984e Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Fri, 16 Apr 2021 12:25:46 +0200 Subject: [PATCH 08/28] Local paths for transformers models. --- scripts/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/utils.py b/scripts/utils.py index 0ec8725..f1d03fe 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -17,9 +17,15 @@ LANG2TRANSFORMER = { "it": "dbmdz/bert-base-italian-cased", "ru": "blinoff/roberta-base-russian-v0", "sv": "KB/bert-base-swedish-cased", + "uk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-uk-cased/", + "ta": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-ta-cased/", + "sk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-sk-cased/", + "lt": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lt-cased/", + "cs": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-cs-cased/", + "et": "/tmp/lustre_shared/mklimasz/transformers/etwiki-bert/", # "uk": http://dl.turkunlp.org/wikibert/wikibert-base-uk-cased/ # "ta": http://dl.turkunlp.org/wikibert/wikibert-base-ta-cased/ - # "sk": http://dl.turkunlp.org/wikibert/wikibert-base-sl-cased/ + # "sk": http://dl.turkunlp.org/wikibert/wikibert-base-sk-cased/ # "lt": http://dl.turkunlp.org/wikibert/wikibert-base-lt-cased/ # "lv": http://dl.turkunlp.org/wikibert/wikibert-base-lv-cased/ # "et": http://dl.turkunlp.org/estonian-bert/etwiki-bert/pytorch/ -- GitLab From a3774ab6734891874c203d013eb150b99dba196b Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Fri, 30 Apr 2021 07:48:20 +0200 Subject: [PATCH 09/28] Tree and graph merging algorithm. 
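The new `graph_and_tree_merge` below seeds the enhanced-dependency graph with the predicted tree edges and then adds extra graph-predicted edges only when no path already runs from the dependent back to the head, i.e. when the edge would not close a cycle. A simplified standalone sketch of that reachability test (illustration only; the `has_path` name and the toy graph are assumptions, not the patch's `_dfs` implementation):

```python
def has_path(adjacency, start, end):
    """Iterative DFS: is `end` reachable from `start` via the child adjacency lists?"""
    stack, seen = [start], set()
    while stack:
        node = stack.pop()
        if node == end:
            return True
        if node in seen:
            continue
        seen.add(node)
        stack.extend(adjacency[node])
    return False

adjacency = [[1], [2], []]        # edges 0 -> 1 and 1 -> 2
print(has_path(adjacency, 2, 0))  # False: no path 2 -> 0, so adding edge 0 -> 2 is safe
print(has_path(adjacency, 0, 2))  # True: 0 -> 1 -> 2 exists, so adding edge 2 -> 0 would close a cycle
```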
--- combo/predict.py | 21 +++++++++--- combo/utils/graph.py | 76 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 5 deletions(-) diff --git a/combo/predict.py b/combo/predict.py index 5945e6d..5f1aaed 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -201,11 +201,22 @@ class COMBO(predictor.Predictor): h = np.concatenate((h[-1:], h[:-1])) r = np.array(predictions["enhanced_deprel_prob"]) r = np.concatenate((r[-1:], r[:-1])) - graph.sdp_to_dag_deps(arc_scores=h, - rel_scores=r, - tree_tokens=tree_tokens, - root_idx=self.vocab.get_token_index("root", "deprel_labels"), - vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) + + graph.graph_and_tree_merge( + tree_arc_scores=predictions["head"], + tree_rel_scores=predictions["deprel"], + graph_arc_scores=h, + graph_rel_scores=r, + idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), + label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"), + tokens=tree_tokens + ) + + # graph.sdp_to_dag_deps(arc_scores=h, + # rel_scores=r, + # tree_tokens=tree_tokens, + # root_idx=self.vocab.get_token_index("root", "deprel_labels"), + # vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) empty_tokens = graph.restore_collapse_edges(tree_tokens) tree.tokens.extend(empty_tokens) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index 3352625..c9ad07e 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -3,6 +3,82 @@ from typing import List import numpy as np +_ACL_REL_CL = "acl:relcl" + + +def graph_and_tree_merge(tree_arc_scores, + tree_rel_scores, + graph_arc_scores, + graph_rel_scores, + label2idx, + idx2label, + tokens): + graph_arc_scores = np.copy(graph_arc_scores) + # Exclude self-loops, in-place operation. + np.fill_diagonal(graph_arc_scores, 0) + # Connection to root will be handled by tree. + graph_arc_scores[:, 0] = False + # The same with labels. + root_idx = label2idx["root"] + graph_rel_scores[:, :, root_idx] = -float('inf') + graph_rel_pred = graph_rel_scores.argmax(-1) + + # Add tree edges to graph + tree_heads = [0] + tree_arc_scores + graph = [[] for _ in range(len(tree_heads))] + labeled_graph = [[] for _ in range(len(tree_heads))] + for d, h in enumerate(tree_heads): + if not d: + continue + label = idx2label[tree_rel_scores[d - 1]] + if label != _ACL_REL_CL: + graph[h].append(d) + labeled_graph[h].append((d, label)) + + # Debug only + # Extract graph edges + graph_edges = np.argwhere(graph_arc_scores) + + # Add graph edges which aren't creating a cycle + for (d, h) in graph_edges: + if not d or not h or d in graph[h]: + continue + try: + path = next(_dfs(graph, d, h)) + except StopIteration: + # There is not path from d to h + label = idx2label[graph_rel_pred[d][h]] + if label != _ACL_REL_CL: + graph[h].append(d) + labeled_graph[h].append((d, label)) + + # Add 'acl:relcl' without checking for cycles. 
+ for d, h in enumerate(tree_heads): + if not d: + continue + label = idx2label[tree_rel_scores[d - 1]] + if label == _ACL_REL_CL: + graph[h].append(d) + labeled_graph[h].append((d, label)) + + assert len(labeled_graph[0]) == 1 + d = graph[0][0] + graph[d].append(0) + labeled_graph[d].append((0, "root")) + + parse_graph = [[] for _ in range(len(tree_heads))] + for h in range(len(tree_heads)): + for d, label in labeled_graph[h]: + parse_graph[d].append((h, label)) + parse_graph[d] = sorted(parse_graph[d]) + + for i, g in enumerate(parse_graph): + heads = [x[0] for x in g] + rels = [x[1] for x in g] + deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) + tokens[i - 1]["deps"] = deps + return + def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None: # adding ROOT -- GitLab From d3fe22d2b55d954bc5ba5387a1741e322da1764c Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Fri, 30 Apr 2021 09:50:40 +0200 Subject: [PATCH 10/28] Split vocabulary for graph and tree parsing. --- combo/config.graph.template.jsonnet | 2 +- combo/data/dataset.py | 2 +- combo/predict.py | 2 ++ combo/utils/graph.py | 6 ++++-- scripts/train_iwpt21.py | 2 +- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet index c0c4696..c72a057 100644 --- a/combo/config.graph.template.jsonnet +++ b/combo/config.graph.template.jsonnet @@ -303,7 +303,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't }, enhanced_dependency_relation: if in_targets("deps") then { type: "combo_graph_dependency_parsing_from_vocab", - vocab_namespace: 'deprel_labels', + vocab_namespace: 'enhanced_deprel_labels', head_predictor: { local projection_dim = 512, cycle_loss_n: cycle_loss_n, diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 29b1f91..bdc8b20 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -149,7 +149,7 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader): sequence_field=text_field_deps, labels=enhanced_deprels, # Label namespace matches regular tree parsing. - label_namespace="deprel_labels", + label_namespace="enhanced_deprel_labels", padding_value=0, ) else: diff --git a/combo/predict.py b/combo/predict.py index 5f1aaed..18a63a4 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -209,6 +209,8 @@ class COMBO(predictor.Predictor): graph_rel_scores=r, idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), label2idx=self.vocab.get_token_to_index_vocabulary("deprel_labels"), + graph_idx2label=self.vocab.get_index_to_token_vocabulary("enhanced_deprel_labels"), + graph_label2idx=self.vocab.get_token_to_index_vocabulary("enhanced_deprel_labels"), tokens=tree_tokens ) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index c9ad07e..86dd98f 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -12,6 +12,8 @@ def graph_and_tree_merge(tree_arc_scores, graph_rel_scores, label2idx, idx2label, + graph_label2idx, + graph_idx2label, tokens): graph_arc_scores = np.copy(graph_arc_scores) # Exclude self-loops, in-place operation. @@ -19,7 +21,7 @@ def graph_and_tree_merge(tree_arc_scores, # Connection to root will be handled by tree. graph_arc_scores[:, 0] = False # The same with labels. 
- root_idx = label2idx["root"] + root_idx = graph_label2idx["root"] graph_rel_scores[:, :, root_idx] = -float('inf') graph_rel_pred = graph_rel_scores.argmax(-1) @@ -47,7 +49,7 @@ def graph_and_tree_merge(tree_arc_scores, path = next(_dfs(graph, d, h)) except StopIteration: # There is not path from d to h - label = idx2label[graph_rel_pred[d][h]] + label = graph_idx2label[graph_rel_pred[d][h]] if label != _ACL_REL_CL: graph[h].append(d) labeled_graph[h].append((d, label)) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index c6310ea..8b077c3 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -115,7 +115,7 @@ def run(_): --serialization_dir {serialization_dir} --cuda_device {FLAGS.cuda_device} --word_batch_size 2500 - --config_path {pathlib.Path.cwd() / 'config.graph.template.jsonnet'} + --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'} --notensorboard """ -- GitLab From d12e5ec70ae5d778076be90c30d7f28a73fc027a Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 5 May 2021 15:05:28 +0200 Subject: [PATCH 11/28] Add lv local model. Merge raw txt files. --- scripts/train_iwpt21.py | 34 +++++++++++++++++++++++++--------- scripts/utils.py | 1 + 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index 8b077c3..737a7b8 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -36,7 +36,7 @@ FLAGS = flags.FLAGS flags.DEFINE_list(name="lang", default=list(LANG2TREEBANK.keys()), help=f"Language of models to train. Possible values: {LANG2TREEBANK.keys()}.") flags.DEFINE_string(name="data_dir", default="", - help="Path to 'iwpt2020stdata' directory.") + help="Path to IWPT'21 data directory.") flags.DEFINE_string(name="serialization_dir", default="/tmp/", help="Model serialization dir.") flags.DEFINE_integer(name="cuda_device", default=-1, @@ -68,9 +68,11 @@ def run(_): assert data_dir.is_dir(), f"'{data_dir}' is not a directory!" treebanks = LANG2TREEBANK[lang] + full_language = treebanks[0].split("-")[0] train_paths = [] dev_paths = [] - + train_raw_paths = [] + dev_raw_paths = [] # TODO Uncomment when IWPT'21 Shared Task ends. # During shared task duration test data is not available. 
test_paths = [] @@ -90,19 +92,33 @@ def run(_): # elif "test" in name: # collapse_nodes(data_dir, treebank_file, output) # test_paths.append(output) + if ".txt" in name: + if "train" in name: + train_raw_paths.append(path_to_str(treebank_file)) + elif "dev" in name: + dev_raw_paths.append(path_to_str(treebank_file)) - lang_data_dir = pathlib.Path(data_dir / lang) + merged_dataset_name = "IWPT" + lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}") lang_data_dir.mkdir(exist_ok=True) - train_path = lang_data_dir / "train.conllu" - dev_path = lang_data_dir / "dev.conllu" - # TODO Uncomment - # test_path = lang_data_dir / "test.conllu" + suffix = f"{lang}_{merged_dataset_name}-ud".lower() + train_path = lang_data_dir / f"{suffix}-train.conllu" + dev_path = lang_data_dir / f"{suffix}-dev.conllu" + test_path = lang_data_dir / f"{suffix}-test.conllu" + train_raw_path = lang_data_dir / f"{suffix}-train.txt" + dev_raw_path = lang_data_dir / f"{suffix}-dev.txt" + test_raw_path = lang_data_dir / f"{suffix}-test.txt" merge_files(train_paths, output=train_path) merge_files(dev_paths, output=dev_path) - # TODO Uncomment - # merge_files(test_paths, output=test_path) + # TODO Change to test_paths instead of dev_paths after IWPT'21 + merge_files(dev_paths, output=test_path) + + merge_files(train_raw_paths, output=train_raw_path) + merge_files(dev_raw_paths, output=dev_raw_path) + # TODO Change to test_raw_paths instead of dev_paths after IWPT'21 + merge_files(dev_raw_paths, output=test_raw_path) serialization_dir = pathlib.Path(FLAGS.serialization_dir) / lang serialization_dir.mkdir(exist_ok=True, parents=True) diff --git a/scripts/utils.py b/scripts/utils.py index f1d03fe..bbbe2fe 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -21,6 +21,7 @@ LANG2TRANSFORMER = { "ta": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-ta-cased/", "sk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-sk-cased/", "lt": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lt-cased/", + "lv": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lv-cased/", "cs": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-cs-cased/", "et": "/tmp/lustre_shared/mklimasz/transformers/etwiki-bert/", # "uk": http://dl.turkunlp.org/wikibert/wikibert-base-uk-cased/ -- GitLab From 942205828377bc3abfd65f6212b0e051b3f1c787 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 6 May 2021 08:12:49 +0200 Subject: [PATCH 12/28] Exclude xpos in ru models. --- scripts/train_iwpt21.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index 737a7b8..17737c9 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -136,7 +136,7 @@ def run(_): """ # Datasets without XPOS - if lang in {"fr"}: + if lang in {"fr", "ru"}: command = command + " --targets deprel,head,upostag,lemma,feats" utils.execute_command("".join(command.splitlines())) -- GitLab From 0f6faf2af50399e7c9ea8fb775db6bfb0563c656 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 6 May 2021 08:14:58 +0200 Subject: [PATCH 13/28] Remove emorynlp merging algorithm. 
--- combo/predict.py | 5 -- combo/utils/graph.py | 71 ------------------------- tests/utils/test_graph.py | 106 -------------------------------------- 3 files changed, 182 deletions(-) delete mode 100644 tests/utils/test_graph.py diff --git a/combo/predict.py b/combo/predict.py index 18a63a4..c710224 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -214,11 +214,6 @@ class COMBO(predictor.Predictor): tokens=tree_tokens ) - # graph.sdp_to_dag_deps(arc_scores=h, - # rel_scores=r, - # tree_tokens=tree_tokens, - # root_idx=self.vocab.get_token_index("root", "deprel_labels"), - # vocab_index=self.vocab.get_index_to_token_vocabulary("deprel_labels")) empty_tokens = graph.restore_collapse_edges(tree_tokens) tree.tokens.extend(empty_tokens) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index 86dd98f..3995f4f 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -1,5 +1,4 @@ """Based on https://github.com/emorynlp/iwpt-shared-task-2020.""" -from typing import List import numpy as np @@ -82,76 +81,6 @@ def graph_and_tree_merge(tree_arc_scores, return -def sdp_to_dag_deps(arc_scores, rel_scores, tree_tokens: List, root_idx=0, vocab_index=None) -> None: - # adding ROOT - tree_heads = [0] + [t["head"] for t in tree_tokens] - graph = adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, - root_idx) - for i, (t, g) in enumerate(zip(tree_heads, graph)): - if not i: - continue - rels = [vocab_index.get(x[1], "root") if vocab_index else x[1] for x in g] - heads = [x[0] for x in g] - head = tree_tokens[i - 1]["head"] - index = heads.index(head) - deprel = tree_tokens[i - 1]["deprel"] - deprel = deprel.split('>')[-1] - # TODO - Consider if there should be a condition, - # It doesn't seem to make any sense as DEPS should contain DEPREL - # (although sometimes with different/more detailed label) - # if len(heads) >= 2: - # heads.pop(index) - # rels.pop(index) - deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) - tree_tokens[i - 1]["deps"] = deps - tree_tokens[i - 1]["deprel"] = deprel - return - - -def adjust_root_score_then_add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx): - if len(arc_scores) != tree_heads: - arc_scores = arc_scores[:len(tree_heads)][:len(tree_heads)] - rel_scores = rel_scores[:len(tree_heads)][:len(tree_heads)] - # Self-loops aren't allowed, mask with 0. This is an in-place operation. 
- np.fill_diagonal(arc_scores, 0) - parse_preds = np.array(arc_scores) > 0 - parse_preds[:, 0] = False # set heads to False - rel_scores[:, :, root_idx] = -float('inf') - return add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds) - - -def add_secondary_arcs(arc_scores, rel_scores, tree_heads, root_idx, parse_preds): - if not isinstance(tree_heads, np.ndarray): - tree_heads = np.array(tree_heads) - dh = np.argwhere(parse_preds) - sdh = sorted([(arc_scores[x[0]][x[1]], list(x)) for x in dh], reverse=True) - graph = [[] for _ in range(len(tree_heads))] - rel_pred = np.argmax(rel_scores, axis=-1) - for d, h in enumerate(tree_heads): - if d: - graph[h].append(d) - for s, (d, h) in sdh: - if not d or not h or d in graph[h]: - continue - try: - path = next(_dfs(graph, d, h)) - except StopIteration: - # no path from d to h - graph[h].append(d) - parse_graph = [[] for _ in range(len(tree_heads))] - num_root = 0 - for h in range(len(tree_heads)): - for d in graph[h]: - rel = rel_pred[d][h] - if h == 0: - rel = root_idx - assert num_root == 0 - num_root += 1 - parse_graph[d].append((h, rel)) - parse_graph[d] = sorted(parse_graph[d]) - return parse_graph - - def _dfs(graph, start, end): fringe = [(start, [])] while fringe: diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py deleted file mode 100644 index 74e3744..0000000 --- a/tests/utils/test_graph.py +++ /dev/null @@ -1,106 +0,0 @@ -import unittest -import combo.utils.graph as graph - -import conllu -import numpy as np - - -class GraphTest(unittest.TestCase): - - def test_adding_empty_graph_with_the_same_labels(self): - tree = conllu.TokenList( - tokens=[ - {"head": 0, "deprel": "root", "form": "word1"}, - {"head": 3, "deprel": "yes", "form": "word2"}, - {"head": 1, "deprel": "yes", "form": "word3"}, - ] - ) - vocab_index = {0: "root", 1: "yes", 2: "yes", 3: "yes"} - empty_graph = np.zeros((4, 4)) - graph_labels = np.zeros((4, 4, 4)) - expected_deps = ["0:root", "3:yes", "1:yes"] - - # when - graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) - actual_deps = [t["deps"] for t in tree.tokens] - - # then - self.assertEqual(expected_deps, actual_deps) - - def test_adding_empty_graph_with_different_labels(self): - tree = conllu.TokenList( - tokens=[ - {"head": 0, "deprel": "root", "form": "word1"}, - {"head": 3, "deprel": "tree_label", "form": "word2"}, - {"head": 1, "deprel": "tree_label", "form": "word3"}, - ] - ) - vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"} - empty_graph = np.zeros((4, 4)) - graph_labels = np.zeros((4, 4, 3)) - graph_labels[2][3][2] = 10e10 - graph_labels[3][1][2] = 10e10 - expected_deps = ["0:root", "3:graph_label", "1:graph_label"] - - # when - graph.sdp_to_dag_deps(empty_graph, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) - actual_deps = [t["deps"] for t in tree.tokens] - - # then - self.assertEqual(actual_deps, expected_deps) - - def test_extending_tree_with_graph(self): - # given - tree = conllu.TokenList( - tokens=[ - {"head": 0, "deprel": "root", "form": "word1"}, - {"head": 1, "deprel": "tree_label", "form": "word2"}, - {"head": 2, "deprel": "tree_label", "form": "word3"}, - ] - ) - vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"} - arc_scores = np.array([ - [0, 0, 0, 0], - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 1, 1, 0], - ]) - graph_labels = np.zeros((4, 4, 3)) - graph_labels[3][1][2] = 10e10 - expected_deps = ["0:root", "1:tree_label", "1:graph_label|2:tree_label"] - - # when - 
graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) - actual_deps = [t["deps"] for t in tree.tokens] - - # then - self.assertEqual(actual_deps, expected_deps) - - def test_extending_tree_with_self_loop_edge_shouldnt_add_edge(self): - # given - tree = conllu.TokenList( - tokens=[ - {"head": 0, "deprel": "root", "form": "word1"}, - {"head": 1, "deprel": "tree_label", "form": "word2"}, - {"head": 2, "deprel": "tree_label", "form": "word3"}, - ] - ) - vocab_index = {0: "root", 1: "tree_label", 2: "graph_label"} - arc_scores = np.array([ - [0, 0, 0, 0], - [1, 0, 0, 0], - [0, 1, 0, 0], - [0, 0, 1, 1], - ]) - graph_labels = np.zeros((4, 4, 3)) - graph_labels[3][3][2] = 10e10 - expected_deps = ["0:root", "1:tree_label", "2:tree_label"] - # TODO current actual, adds self-loop - # actual_deps = ["0:root", "1:tree_label", "2:tree_label|3:graph_label"] - - # when - graph.sdp_to_dag_deps(arc_scores, graph_labels, tree.tokens, root_idx=0, vocab_index=vocab_index) - actual_deps = [t["deps"] for t in tree.tokens] - - # then - self.assertEqual(expected_deps, actual_deps) -- GitLab From 6900d4f0fd51dd9e6718d7f8091b09146ba3dff8 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 6 May 2021 10:57:09 +0200 Subject: [PATCH 14/28] Add IWPT'21 evaluation script. --- scripts/evaluate_iwpt21.py | 87 ++++++++++++++++++++++++++++++++++++++ scripts/train_iwpt21.py | 21 +++------ scripts/utils.py | 12 ++++++ 3 files changed, 104 insertions(+), 16 deletions(-) create mode 100644 scripts/evaluate_iwpt21.py diff --git a/scripts/evaluate_iwpt21.py b/scripts/evaluate_iwpt21.py new file mode 100644 index 0000000..b67541f --- /dev/null +++ b/scripts/evaluate_iwpt21.py @@ -0,0 +1,87 @@ +import pathlib + +from absl import app +from absl import flags + +from scripts import utils + +CODE2LANG = { + "ar": "Arabic", + "bg": "Bulgarian", + "cs": "Czech", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "it": "Italian", + "lv": "Latvian", + "lt": "Lithuanian", + "pl": "Polish", + "ru": "Russian", + "sk": "Slovak", + "sv": "Swedish", + "ta": "Tamil", + "uk": "Ukrainian", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string(name="data_dir", default="", + help="Path to IWPT'21 data directory.") +flags.DEFINE_string(name="models_dir", default="/tmp/", + help="Model serialization dir.") +flags.DEFINE_integer(name="cuda_device", default=-1, + help="Cuda device id (-1 for cpu).") +flags.DEFINE_string(name="evaluate_script_path", default="iwpt21_xud_eval.py", + help="Path to 'iwpt21_xud_eval.py' eval script.") +flags.DEFINE_boolean(name="expect_prefix", default=True, + help="Whether to expect allennlp prefix.") + + +def run(_): + models_dir = pathlib.Path(FLAGS.models_dir) + for model_dir in models_dir.iterdir(): + if model_dir.name not in CODE2LANG: + print("Skipping unknown directory: ", model_dir.name) + continue + + treebank_name = f"UD_{CODE2LANG[model_dir.name]}-IWPT" + + if FLAGS.expect_prefix: + model_dir = list(model_dir.iterdir()) + assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}" + model_dir = model_dir[0] + + treebank_dir = pathlib.Path(FLAGS.data_dir) / treebank_name + files = list(treebank_dir.iterdir()) + + test_file = [f for f in files if "dev" in f.name and ".conllu" in f.name] + assert len(test_file) == 1, f"Couldn't find test file." 
+ test_file = test_file[0] + + if not (model_dir / "results.txt").exists(): + output_pred = model_dir / 'predictions.conllu' + command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'} + --input_file {test_file} + --output_file {output_pred} + --cuda_device {FLAGS.cuda_device} + --silent + """ + utils.execute_command(command) + + output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu') + utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed) + + command = f"""python {FLAGS.evaluate_script_path} -v + {test_file} + {output_collapsed} + """ + utils.execute_command(command, output_file=model_dir / "results.txt") + + +def main(): + app.run(run) + + +if __name__ == "__main__": + main() diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index 17737c9..e4705f7 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -43,22 +43,11 @@ flags.DEFINE_integer(name="cuda_device", default=-1, help="Cuda device id (-1 for cpu).") -def path_to_str(path: pathlib.Path) -> str: - return str(path.resolve()) - - def merge_files(files: List[str], output: pathlib.Path): if not output.exists(): os.system(f"cat {' '.join(files)} > {output}") -def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): - output_path = pathlib.Path(output) - if not output_path.exists(): - utils.execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} " - f"{path_to_str(treebank_file)}", output) - - def run(_): languages = FLAGS.lang for lang in languages: @@ -82,21 +71,21 @@ def run(_): for treebank_file in treebank_dir.iterdir(): name = treebank_file.name if "conllu" in name and "fixed" not in name: - output = path_to_str(treebank_file).replace('.conllu', '.fixed.conllu') + output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu') if "train" in name: - collapse_nodes(data_dir, treebank_file, output) + utils.collapse_nodes(data_dir, treebank_file, output) train_paths.append(output) elif "dev" in name: - collapse_nodes(data_dir, treebank_file, output) + utils.collapse_nodes(data_dir, treebank_file, output) dev_paths.append(output) # elif "test" in name: # collapse_nodes(data_dir, treebank_file, output) # test_paths.append(output) if ".txt" in name: if "train" in name: - train_raw_paths.append(path_to_str(treebank_file)) + train_raw_paths.append(utils.path_to_str(treebank_file)) elif "dev" in name: - dev_raw_paths.append(path_to_str(treebank_file)) + dev_raw_paths.append(utils.path_to_str(treebank_file)) merged_dataset_name = "IWPT" lang_data_dir = pathlib.Path(data_dir / f"UD_{full_language}-{merged_dataset_name}") diff --git a/scripts/utils.py b/scripts/utils.py index bbbe2fe..19808ad 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -1,4 +1,5 @@ """Utils for scripts.""" +import pathlib import subprocess LANG2TRANSFORMER = { @@ -41,3 +42,14 @@ def execute_command(command, output_file=None): subprocess.run(command, check=True, stdout=f) else: subprocess.run(command, check=True) + + +def path_to_str(path: pathlib.Path) -> str: + return str(path.resolve()) + + +def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): + output_path = pathlib.Path(output) + if not output_path.exists(): + execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} " + f"{path_to_str(treebank_file)}", output) -- GitLab From a9ad17547d73eb54a8ff07de0711aaf172523a6c Mon Sep 17 00:00:00 2001 From: Mateusz 
Klimaszewski <mk.klimaszewski@gmail.com> Date: Fri, 7 May 2021 11:22:16 +0200 Subject: [PATCH 15/28] Change batch size for tamil. --- scripts/train_iwpt21.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index e4705f7..b5838c4 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -119,7 +119,6 @@ def run(_): --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]} --serialization_dir {serialization_dir} --cuda_device {FLAGS.cuda_device} - --word_batch_size 2500 --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'} --notensorboard """ @@ -128,6 +127,11 @@ def run(_): if lang in {"fr", "ru"}: command = command + " --targets deprel,head,upostag,lemma,feats" + if lang in {"ta"}: + command = command + " --word_batch_size 500" + else: + command = command + " --word_batch_size 2500" + utils.execute_command("".join(command.splitlines())) -- GitLab From c082b65055c8c728cf4279a7c19e9d0409c3000c Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Fri, 7 May 2021 14:45:53 +0200 Subject: [PATCH 16/28] Add prediction script for IWPT'21. --- scripts/predict_iwpt21.py | 73 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 scripts/predict_iwpt21.py diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py new file mode 100644 index 0000000..dff3594 --- /dev/null +++ b/scripts/predict_iwpt21.py @@ -0,0 +1,73 @@ +import pathlib + +from absl import app +from absl import flags + +from scripts import utils + +CODE2LANG = { + "ar": "Arabic", + "bg": "Bulgarian", + "cs": "Czech", + "nl": "Dutch", + "en": "English", + "et": "Estonian", + "fi": "Finnish", + "fr": "French", + "it": "Italian", + "lv": "Latvian", + "lt": "Lithuanian", + "pl": "Polish", + "ru": "Russian", + "sk": "Slovak", + "sv": "Swedish", + "ta": "Tamil", + "uk": "Ukrainian", +} + +FLAGS = flags.FLAGS +flags.DEFINE_string(name="data_dir", default="", + help="Path to IWPT'21 data directory.") +flags.DEFINE_string(name="models_dir", default="/tmp/", + help="Model serialization dir.") +flags.DEFINE_integer(name="cuda_device", default=-1, + help="Cuda device id (-1 for cpu).") +flags.DEFINE_boolean(name="expect_prefix", default=True, + help="Whether to expect allennlp prefix.") + + +def run(_): + models_dir = pathlib.Path(FLAGS.models_dir) + for model_dir in models_dir.iterdir(): + lang = model_dir.name + if lang not in CODE2LANG: + print("Skipping unknown directory: ", lang) + continue + + if FLAGS.expect_prefix: + model_dir = list(model_dir.iterdir()) + assert len(model_dir) == 1, f"There is incorrect count of models {model_dir}" + model_dir = model_dir[0] + + data_dir = pathlib.Path(FLAGS.data_dir) + files = list(data_dir.iterdir()) + test_file = [f for f in files if f"{lang}.conllu" == f.name] + assert len(test_file) == 1, f"Couldn't find test file." + test_file = test_file[0] + + output_pred = data_dir / f'{lang}_pred.conllu' + command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'} + --input_file {test_file} + --output_file {output_pred} + --cuda_device {FLAGS.cuda_device} + --silent + """ + utils.execute_command(command) + + +def main(): + app.run(run) + + +if __name__ == "__main__": + main() -- GitLab From 2941cf0645739c590a0872fbfafb66ffbbaa3481 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Fri, 7 May 2021 16:07:20 +0200 Subject: [PATCH 17/28] Add deps sorting. 
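The change below reorders each token's enhanced heads and relation labels with `argsort` before joining them into the DEPS column, so the two arrays stay aligned while being sorted by relation label. A small self-contained illustration of the same co-sorting pattern (the example heads and labels are made up):

```python
import numpy as np

heads = np.array([5, 2, 7])
rels = np.array(["obl", "conj", "advcl"])

indices = rels.argsort()          # ordering induced by the relation labels
heads = heads[indices].tolist()
rels = rels[indices].tolist()

print("|".join(f"{h}:{r}" for h, r in zip(heads, rels)))  # 7:advcl|2:conj|5:obl
```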
--- combo/utils/graph.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index 3995f4f..ed2c1ff 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -74,8 +74,11 @@ def graph_and_tree_merge(tree_arc_scores, parse_graph[d] = sorted(parse_graph[d]) for i, g in enumerate(parse_graph): - heads = [x[0] for x in g] - rels = [x[1] for x in g] + heads = np.array([x[0] for x in g]) + rels = np.array([x[1] for x in g]) + indices = rels.argsort() + heads = heads[indices].tolist() + rels = rels[indices].tolist() deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) tokens[i - 1]["deps"] = deps return -- GitLab From 00c388ed8c24d36c1dbba5c3d38ad03b8d53eade Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Mon, 10 May 2021 14:07:09 +0200 Subject: [PATCH 18/28] Fix fr and ru models training. --- scripts/train_iwpt21.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index b5838c4..39054e4 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -125,8 +125,9 @@ def run(_): # Datasets without XPOS if lang in {"fr", "ru"}: - command = command + " --targets deprel,head,upostag,lemma,feats" + command = command + " --targets deprel,head,upostag,lemma,feats,deps" + # Smaller dataset if lang in {"ta"}: command = command + " --word_batch_size 500" else: -- GitLab From ce2d7be688db2ce28be47b734238671467ab5e28 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Tue, 11 May 2021 18:16:03 +0200 Subject: [PATCH 19/28] Add conllu-quick-fix call. Change model for Tamil. --- scripts/evaluate_iwpt21.py | 2 +- scripts/predict_iwpt21.py | 22 ++++++++++++++++++---- scripts/train_iwpt21.py | 4 ++-- scripts/utils.py | 11 +++++++++-- 4 files changed, 30 insertions(+), 9 deletions(-) diff --git a/scripts/evaluate_iwpt21.py b/scripts/evaluate_iwpt21.py index b67541f..bfd24eb 100644 --- a/scripts/evaluate_iwpt21.py +++ b/scripts/evaluate_iwpt21.py @@ -70,7 +70,7 @@ def run(_): utils.execute_command(command) output_collapsed = utils.path_to_str(output_pred).replace('.conllu', '.collapsed.conllu') - utils.collapse_nodes(pathlib.Path(FLAGS.data_dir), output_pred, output_collapsed) + utils.collapse_nodes(pathlib.Path(FLAGS.data_dir) / 'tools', output_pred, output_collapsed) command = f"""python {FLAGS.evaluate_script_path} -v {test_file} diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py index dff3594..513ce3b 100644 --- a/scripts/predict_iwpt21.py +++ b/scripts/predict_iwpt21.py @@ -27,9 +27,11 @@ CODE2LANG = { FLAGS = flags.FLAGS flags.DEFINE_string(name="data_dir", default="", - help="Path to IWPT'21 data directory.") + help="Path to data directory.") flags.DEFINE_string(name="models_dir", default="/tmp/", help="Model serialization dir.") +flags.DEFINE_string(name="tools", default="", + help="UD tools path.") flags.DEFINE_integer(name="cuda_device", default=-1, help="Cuda device id (-1 for cpu).") flags.DEFINE_boolean(name="expect_prefix", default=True, @@ -51,9 +53,15 @@ def run(_): data_dir = pathlib.Path(FLAGS.data_dir) files = list(data_dir.iterdir()) - test_file = [f for f in files if f"{lang}.conllu" == f.name] - assert len(test_file) == 1, f"Couldn't find test file." 
- test_file = test_file[0] + test_file = [f for f in files if f"{lang}.mwt.conllu" == f.name] + # Try to use mwt file if it exists + if test_file: + assert len(test_file) == 1, f"Should be exactly one {lang}.mwt.conllu file." + test_file = test_file[0] + else: + test_file = [f for f in files if f"{lang}.conllu" == f.name] + assert len(test_file) == 1, f"Couldn't find test file." + test_file = test_file[0] output_pred = data_dir / f'{lang}_pred.conllu' command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'} @@ -64,6 +72,12 @@ def run(_): """ utils.execute_command(command) + output_fixed = utils.path_to_str(output_pred).replace('.conllu', '.fixed.conllu') + utils.quick_fix(pathlib.Path(FLAGS.tools), output_pred, output_fixed) + + output_collapsed = output_fixed.replace('.fixed.conllu', '.collapsed.conllu') + utils.collapse_nodes(pathlib.Path(FLAGS.tools), pathlib.Path(output_fixed), output_collapsed) + def main(): app.run(run) diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index 39054e4..e1e427f 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -73,10 +73,10 @@ def run(_): if "conllu" in name and "fixed" not in name: output = utils.path_to_str(treebank_file).replace('.conllu', '.fixed.conllu') if "train" in name: - utils.collapse_nodes(data_dir, treebank_file, output) + utils.collapse_nodes(data_dir / 'tools', treebank_file, output) train_paths.append(output) elif "dev" in name: - utils.collapse_nodes(data_dir, treebank_file, output) + utils.collapse_nodes(data_dir / 'tools', treebank_file, output) dev_paths.append(output) # elif "test" in name: # collapse_nodes(data_dir, treebank_file, output) diff --git a/scripts/utils.py b/scripts/utils.py index 19808ad..2c66205 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -19,7 +19,7 @@ LANG2TRANSFORMER = { "ru": "blinoff/roberta-base-russian-v0", "sv": "KB/bert-base-swedish-cased", "uk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-uk-cased/", - "ta": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-ta-cased/", + "ta": "xlm-roberta-large", "sk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-sk-cased/", "lt": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lt-cased/", "lv": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lv-cased/", @@ -51,5 +51,12 @@ def path_to_str(path: pathlib.Path) -> str: def collapse_nodes(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): output_path = pathlib.Path(output) if not output_path.exists(): - execute_command(f"perl {path_to_str(data_dir / 'tools' / 'enhanced_collapse_empty_nodes.pl')} " + execute_command(f"perl {path_to_str(data_dir / 'enhanced_collapse_empty_nodes.pl')} " + f"{path_to_str(treebank_file)}", output) + + +def quick_fix(data_dir: pathlib.Path, treebank_file: pathlib.Path, output: str): + output_path = pathlib.Path(output) + if not output_path.exists(): + execute_command(f"perl {path_to_str(data_dir / 'conllu-quick-fix.pl')} " f"{path_to_str(treebank_file)}", output) -- GitLab From 4eaad34d264781eaa6f8859b63cd150ef507033d Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Wed, 12 May 2021 17:01:30 +0200 Subject: [PATCH 20/28] Fix batch predictions for DEPS. 
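In batch prediction the enhanced-head and enhanced-deprel score tensors come back padded to a common batch length, so they have to be cut down to the real sentence length before the tree/graph merge. A schematic sketch of that truncation (sizes below are invented for illustration):

```python
import numpy as np

batch_padded_len, sentence_length, n_labels = 12, 7, 40   # invented sizes

enhanced_head = np.zeros((batch_padded_len, batch_padded_len))                    # padded arc scores
enhanced_deprel_prob = np.zeros((batch_padded_len, batch_padded_len, n_labels))   # padded label probabilities

h = enhanced_head[:sentence_length, :sentence_length]
r = enhanced_deprel_prob[:sentence_length, :sentence_length, :]

assert h.shape == (sentence_length, sentence_length)
assert r.shape == (sentence_length, sentence_length, n_labels)
```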
--- combo/predict.py | 9 +++++---- scripts/predict_iwpt21.py | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/combo/predict.py b/combo/predict.py index c710224..01a0837 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -197,14 +197,15 @@ class COMBO(predictor.Predictor): if "enhanced_head" in predictions and predictions["enhanced_head"]: # TODO off-by-one hotfix, refactor - h = np.array(predictions["enhanced_head"]) + sentence_length = len(tree_tokens) + h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length] h = np.concatenate((h[-1:], h[:-1])) - r = np.array(predictions["enhanced_deprel_prob"]) + r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :] r = np.concatenate((r[-1:], r[:-1])) graph.graph_and_tree_merge( - tree_arc_scores=predictions["head"], - tree_rel_scores=predictions["deprel"], + tree_arc_scores=predictions["head"][:sentence_length], + tree_rel_scores=predictions["deprel"][:sentence_length], graph_arc_scores=h, graph_rel_scores=r, idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py index 513ce3b..61b9cf7 100644 --- a/scripts/predict_iwpt21.py +++ b/scripts/predict_iwpt21.py @@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1, help="Cuda device id (-1 for cpu).") flags.DEFINE_boolean(name="expect_prefix", default=True, help="Whether to expect allennlp prefix.") +flags.DEFINE_integer(name="batch_size", default=32, + help="Batch size.") def run(_): @@ -68,6 +70,7 @@ def run(_): --input_file {test_file} --output_file {output_pred} --cuda_device {FLAGS.cuda_device} + --batch_size {FLAGS.batch_size} --silent """ utils.execute_command(command) -- GitLab From 4b3e1cafeae01b20512c146a3dac9b72d4987a6a Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 20 May 2021 12:37:07 +0200 Subject: [PATCH 21/28] Handle double > cases. --- combo/utils/graph.py | 32 +++++++++++++++++++++++++------- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/combo/utils/graph.py b/combo/utils/graph.py index ed2c1ff..b4c9632 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -114,13 +114,31 @@ def restore_collapse_edges(tree_tokens): head, relation = d.split(':', 1) ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 1}" empty_node_relation, current_node_relation = relation.split(">", 1) - deps[i] = f"{ehead}:{current_node_relation}" - empty_tokens.append( - { - "id": ehead, - "deps": f"{head}:{empty_node_relation}" - } - ) + # Edge case, double > + if ">" in current_node_relation: + second_empty_node_relation, current_node_relation = current_node_relation.split(">") + deps[i] = f"{ehead}:{current_node_relation}" + empty_tokens.append( + { + "id": ehead, + "deps": f"{head}:{empty_node_relation}" + } + ) + empty_tokens.append( + { + "id": f"{len(tree_tokens)}.{len(empty_tokens) + 1}", + "deps": f"{ehead}:{second_empty_node_relation}" + } + ) + + else: + deps[i] = f"{ehead}:{current_node_relation}" + empty_tokens.append( + { + "id": ehead, + "deps": f"{head}:{empty_node_relation}" + } + ) deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0])) token["deps"] = "|".join([f"{k}:{v}" for k, v in deps]) return empty_tokens -- GitLab From 7dcf0e5f5e76999b721041cf34f2539823460ad9 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 20 May 2021 12:38:13 +0200 Subject: [PATCH 22/28] Use XLM-R for low resource languages. 
--- scripts/utils.py | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/scripts/utils.py b/scripts/utils.py index 2c66205..925f13b 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -18,20 +18,13 @@ LANG2TRANSFORMER = { "it": "dbmdz/bert-base-italian-cased", "ru": "blinoff/roberta-base-russian-v0", "sv": "KB/bert-base-swedish-cased", - "uk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-uk-cased/", + "uk": "xlm-roberta-large", "ta": "xlm-roberta-large", - "sk": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-sk-cased/", - "lt": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lt-cased/", - "lv": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-lv-cased/", - "cs": "/tmp/lustre_shared/mklimasz/transformers/wikibert-base-cs-cased/", - "et": "/tmp/lustre_shared/mklimasz/transformers/etwiki-bert/", - # "uk": http://dl.turkunlp.org/wikibert/wikibert-base-uk-cased/ - # "ta": http://dl.turkunlp.org/wikibert/wikibert-base-ta-cased/ - # "sk": http://dl.turkunlp.org/wikibert/wikibert-base-sk-cased/ - # "lt": http://dl.turkunlp.org/wikibert/wikibert-base-lt-cased/ - # "lv": http://dl.turkunlp.org/wikibert/wikibert-base-lv-cased/ - # "et": http://dl.turkunlp.org/estonian-bert/etwiki-bert/pytorch/ - # "cs": https://github.com/kiv-air/Czert https://arxiv.org/pdf/2103.13031.pdf + "sk": "xlm-roberta-large", + "lt": "xlm-roberta-large", + "lv": "xlm-roberta-large", + "cs": "xlm-roberta-large", + "et": "xlm-roberta-large", } -- GitLab From 5545d96805d5a4f08854350d7e7dd94c6604b975 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 20 May 2021 14:02:09 +0200 Subject: [PATCH 23/28] Enable weighted average of LM embeddings. --- combo/config.graph.template.jsonnet | 3 ++- combo/models/embeddings.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet index c72a057..708cfa3 100644 --- a/combo/config.graph.template.jsonnet +++ b/combo/config.graph.template.jsonnet @@ -49,7 +49,7 @@ local lemma_char_dim = 64; # Character embedding dim, int local char_dim = 64; # Word embedding projection dim, int -local projected_embedding_dim = 100; +local projected_embedding_dim = 768; # Loss weights, dict[str, int] local loss_weights = { xpostag: 0.05, @@ -202,6 +202,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't }, token: if use_transformer then { type: "transformers_word_embeddings", + last_layer_only: false, model_name: pretrained_transformer_name, projection_dim: projected_embedding_dim, tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert") diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py index d8e9d7a..d8c3d71 100644 --- a/combo/models/embeddings.py +++ b/combo/models/embeddings.py @@ -111,10 +111,12 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm projection_activation: Optional[allen_nn.Activation] = lambda x: x, projection_dropout_rate: Optional[float] = 0.0, freeze_transformer: bool = True, + last_layer_only: bool = True, tokenizer_kwargs: Optional[Dict[str, Any]] = None, transformer_kwargs: Optional[Dict[str, Any]] = None): super().__init__(model_name, train_parameters=not freeze_transformer, + last_layer_only=last_layer_only, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs) if projection_dim: -- GitLab From 7b545ee51dbf7ba61824ee630507cef2859d245b Mon Sep 17 00:00:00 2001 
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Mon, 24 May 2021 07:49:02 +0200 Subject: [PATCH 24/28] Change few transformer models. Make position_ids not necessary. Update to pytorch 1.7. --- combo/config.graph.template.jsonnet | 10 +++------- combo/models/embeddings.py | 2 ++ combo/utils/graph.py | 11 ++++++++--- scripts/utils.py | 12 ++++++------ setup.py | 2 +- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/combo/config.graph.template.jsonnet b/combo/config.graph.template.jsonnet index 708cfa3..a472560 100644 --- a/combo/config.graph.template.jsonnet +++ b/combo/config.graph.template.jsonnet @@ -112,10 +112,8 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't use_sem: if in_targets("semrel") then true else false, token_indexers: { token: if use_transformer then { - type: "pretrained_transformer_mismatched_fixed", - model_name: pretrained_transformer_name, - tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert") - then {use_fast: false} else {}, + type: "pretrained_transformer_mismatched", + model_name: pretrained_transformer_name } else { # SingleIdTokenIndexer, token as single int type: "single_id", @@ -204,9 +202,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't type: "transformers_word_embeddings", last_layer_only: false, model_name: pretrained_transformer_name, - projection_dim: projected_embedding_dim, - tokenizer_kwargs: if std.startsWith(pretrained_transformer_name, "allegro/herbert") - then {use_fast: false} else {}, + projection_dim: projected_embedding_dim } else { type: "embeddings_projected", embedding_dim: embedding_dim, diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py index d8c3d71..c8499ba 100644 --- a/combo/models/embeddings.py +++ b/combo/models/embeddings.py @@ -105,6 +105,8 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm Tested with Bert (but should work for other models as well). 
""" + authorized_missing_keys = [r"position_ids$"] + def __init__(self, model_name: str, projection_dim: int = 0, diff --git a/combo/utils/graph.py b/combo/utils/graph.py index b4c9632..f61a68e 100644 --- a/combo/utils/graph.py +++ b/combo/utils/graph.py @@ -32,6 +32,10 @@ def graph_and_tree_merge(tree_arc_scores, if not d: continue label = idx2label[tree_rel_scores[d - 1]] + # graph_label = graph_idx2label[graph_rel_pred[d - 1][h - 1]] + # if ">" in graph_label and label in graph_label: + # print("Using graph label instead of tree.") + # label = graph_label if label != _ACL_REL_CL: graph[h].append(d) labeled_graph[h].append((d, label)) @@ -118,16 +122,17 @@ def restore_collapse_edges(tree_tokens): if ">" in current_node_relation: second_empty_node_relation, current_node_relation = current_node_relation.split(">") deps[i] = f"{ehead}:{current_node_relation}" + second_ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 2}" empty_tokens.append( { "id": ehead, - "deps": f"{head}:{empty_node_relation}" + "deps": f"{second_ehead}:{empty_node_relation}" } ) empty_tokens.append( { - "id": f"{len(tree_tokens)}.{len(empty_tokens) + 1}", - "deps": f"{ehead}:{second_empty_node_relation}" + "id": second_ehead, + "deps": f"{head}:{second_empty_node_relation}" } ) diff --git a/scripts/utils.py b/scripts/utils.py index 925f13b..09f7591 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -4,7 +4,7 @@ import subprocess LANG2TRANSFORMER = { "en": "bert-base-cased", - "pl": "allegro/herbert-base-cased", + "pl": "allegro/herbert-large-cased", "zh": "bert-base-chinese", "fi": "TurkuNLP/bert-base-finnish-cased-v1", "ko": "kykim/bert-kor-base", @@ -12,12 +12,12 @@ LANG2TRANSFORMER = { "ar": "aubmindlab/bert-base-arabertv2", "eu": "ixa-ehu/berteus-base-cased", "tr": "dbmdz/bert-base-turkish-cased", - "bg": "iarfmoose/roberta-base-bulgarian", - "nl": "GroNLP/bert-base-dutch-cased", + "bg": "xlm-roberta-large", + "nl": "xlm-roberta-large", "fr": "camembert-base", - "it": "dbmdz/bert-base-italian-cased", - "ru": "blinoff/roberta-base-russian-v0", - "sv": "KB/bert-base-swedish-cased", + "it": "xlm-roberta-large", + "ru": "xlm-roberta-large", + "sv": "xlm-roberta-large", "uk": "xlm-roberta-large", "ta": "xlm-roberta-large", "sk": "xlm-roberta-large", diff --git a/setup.py b/setup.py index e1354b7..06717ba 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ REQUIREMENTS = [ 'scipy<1.6.0;python_version<"3.7"', # SciPy 1.6.0 works for 3.7+ 'spacy==2.3.2', 'scikit-learn<=0.23.2', - 'torch==1.6.0', + 'torch==1.7.0', 'tqdm==4.43.0', 'transformers==4.0.1', 'urllib3==1.25.11', -- GitLab From 7ddc4b0ee23a42601b27c97435cf51c5ac5068d0 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 5 Aug 2021 10:43:50 +0200 Subject: [PATCH 25/28] Add postprocessing EUD script. --- scripts/postprocessing.py | 454 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 454 insertions(+) create mode 100644 scripts/postprocessing.py diff --git a/scripts/postprocessing.py b/scripts/postprocessing.py new file mode 100644 index 0000000..2f4da16 --- /dev/null +++ b/scripts/postprocessing.py @@ -0,0 +1,454 @@ +# TODO lemma remove punctuation - ukrainian +# TODO lemma remove punctuation - russian +# TODO consider handling multiple 'case' +import sys + +import conllu + +from re import * + +rus = compile(u'^из-за$') +expand = compile('^\d+\.\d+$') + +''' +A script correcting automatically predicted enhanced dependency graphs. 
+Running the script: python postprocessing.py cs + +You have to modify the paths to the input CoNLL-U file and the output file. + +The last argument (e.g. cs) corresponds to the language symbol. +All language symbols: +ar (Arabic), bg (Bulgarian), cs (Czech), nl (Dutch), en (English), et (Estonian), fi (Finnish) +fr (French), it (Italian), lv (Latvian), lt (Lithuanian), pl (Polish), ru (Russian) +sk (Slovak), sv (Swedish), ta (Tamil), uk (Ukrainian) + +There are two main rules: +1) the first one adds case information to the following labels: nmod, obl, acl, advcl. +The case information comes from case/mark dependents of the current token and from the morphological feature Case. +Depending on the language, not all information is added. +In some languages ('en', 'it', 'nl', 'sv') the lemma of the coordinating conjunction (cc) is appended to the conjunct label (conj). +Functions: fix_nmod_deps, fix_obl_deps, fix_acl_deps, fix_advcl_deps and fix_conj_deps + +2) the second rule corrects enhanced edges coming into function words labelled ref, mark, punct, root, case, det, cc, cop, aux. +They should not be assigned other functions. For example, if a token, e.g. "and", is labelled cc (coordinating conjunction), +it cannot simultaneously be a subject (nsubj); if such a wrong enhanced edge exists, it should be removed from the graph. + +There is one additional rule for Estonian: +if the label is nsubj:cop or csubj:cop, the cop sublabel is removed and we have nsubj and csubj, respectively. +''' + + +def fix_nmod_deps(dep, token, sentence, relation): + """ + This function modifies enhanced edges labelled 'nmod' + """ + label: str + label, head = dep + + # All labels starting with 'relation' are checked + if not label.startswith(relation): + return dep + + # case_lemma is a (complex) preposition labelled 'case' e.g. 'po' in nmod:po:loc + # or a (complex) subordinating conjunction labelled 'mark' + case_lemma = None + case_tokens = [] + for t in sentence: + if t["deprel"] in ["case", "mark"] and t["head"] == token["id"]: + case_tokens.append(t) + break + + if case_tokens: + fixed_tokens = [] + for t in sentence: + for c in case_tokens: + if t["deprel"] == "fixed" and t["head"] == c["id"]: + fixed_tokens.append(t) + + if fixed_tokens: + case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens + fixed_tokens)) + else: + case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens)) + + # case_val is a value of Case, e.g. 'gen' in nmod:gen and 'loc' in nmod:po:loc + case_val = None + if token['feats'] is not None: + if 'Case' in token["feats"]: + case_val = token["feats"]['Case'].lower() + + #TODO: check for other languages + if language in ['fi'] and label not in ['nmod', 'nmod:poss']: + return dep + elif language not in ['fi'] and label not in ['nmod']: + return dep + else: + label_lst = [label] + if case_lemma: + label_lst.append(case_lemma) + if case_val: + #TODO: check for other languages + if language not in ['bg', 'en', 'nl', 'sv']: + label_lst.append(case_val) + label = ":".join(label_lst) + + # print(label, sentence.metadata["sent_id"]) + return label, head + + +def fix_obl_deps(dep, token, sentence, relation): + """ + This function modifies enhanced edges labelled 'obl', 'obl:arg', 'obl:agent' + """ + label: str + label, head = dep + + if not label.startswith(relation): + return dep + + # case_lemma is a (complex) preposition labelled 'case' e.g.
'pod' in obl:pod:loc + # or a (complex) subordinating conjunction labelled 'mark' + case_lemma = None + case_tokens = [] + for t in sentence: + if t["deprel"] in ["case", "mark"] and t["head"] == token["id"]: + case_tokens.append(t) + break + + if case_tokens: + # fixed_token is the lemma of a complex preposition, e.g. 'przypadek' in obl:w_przypadku:gen + fixed_tokens = [] + for t in sentence: + for c in case_tokens: + if t["deprel"] == "fixed" and t["head"] == c["id"]: + fixed_tokens.append(t) + + if fixed_tokens: + case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens + fixed_tokens)) + else: + case_lemma = "_".join(rus.sub('изза', f["lemma"]) for f in quicksort(case_tokens)) + + # case_val is a value of Case feature, e.g. 'loc' in obl:pod:loc + case_val = None + if token['feats'] is not None: + if 'Case' in token["feats"]: + case_val = token["feats"]['Case'].lower() + + if label not in ['obl', 'obl:arg', 'obl:agent']: + return dep + else: + label_lst = [label] + if case_lemma: + label_lst.append(case_lemma) + if case_val: + # TODO: check for other languages + if language not in ['bg', 'en', 'lv', 'nl', 'sv']: + label_lst.append(case_val) + # TODO: check it for other languages + if language not in ['pl', 'sv']: + if case_val and not case_lemma: + if label == token['deprel']: + label_lst.append(case_val) + label = ":".join(label_lst) + + # print(label, sentence.metadata["sent_id"]) + return label, head + + +def fix_acl_deps(dep, acl_token, sentence, acl, lang): + """ + This function modifies enhanced edges labelled 'acl' + """ + label: str + label, head = dep + + if not label.startswith(acl): + return dep + + if label.startswith("acl:relcl"): + if lang not in ['uk']: + return dep + + case_lemma = None + case_tokens = [] + for token in sentence: + if token["deprel"] == "mark" and token["head"] == acl_token["id"]: + case_tokens.append(token) + break + + if case_tokens: + fixed_tokens = [] + for token in sentence: + if token["deprel"] == "fixed" and token["head"] == quicksort(case_tokens)[0]["id"]: + fixed_tokens.append(token) + + if fixed_tokens: + case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)]) + else: + case_lemma = quicksort(case_tokens)[0]["lemma"] + + if lang in ['uk']: + if label not in ['acl', 'acl:relcl']: + return dep + else: + label_lst = [label] + if case_lemma: + label_lst.append(case_lemma) + label = ":".join(label_lst) + else: + if label not in ['acl']: + return dep + else: + label_lst = [label] + if case_lemma: + label_lst.append(case_lemma) + label = ":".join(label_lst) + + # print(label, sentence.metadata["sent_id"]) + return label, head + +def fix_advcl_deps(dep, advcl_token, sentence, advcl): + """ + This function modifies enhanced edges labelled 'advcl' + """ + label: str + label, head = dep + + if not label.startswith(advcl): + return dep + + case_lemma = None + case_tokens = [] + # TODO: check for other languages + if language in ['bg', 'lt']: + for token in sentence: + if token["deprel"] in ["mark", "case"] and token["head"] == advcl_token["id"]: + case_tokens.append(token) + else: + for token in sentence: + if token["deprel"] == "mark" and token["head"] == advcl_token["id"]: + case_tokens.append(token) + + if case_tokens: + fixed_tokens = [] + # TODO: check for other languages + if language not in ['bg', 'nl']: + for token in sentence: + for case in quicksort(case_tokens): + if token["deprel"] == "fixed" and token["head"] == case["id"]: + fixed_tokens.append(token) + + if fixed_tokens: + case_lemma = 
"_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)]) + else: + case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens)]) + + if label not in ['advcl']: + return dep + else: + label_lst = [label] + if case_lemma: + label_lst.append(case_lemma) + label = ":".join(label_lst) + + # print(label, sentence.metadata["sent_id"]) + return label, head + + +def fix_conj_deps(dep, conj_token, sentence, conj): + """ + This function modifies enhanced edges labelled 'conj' which should be assined the lemma of cc as sublabel + """ + label: str + label, head = dep + + if not label.startswith(conj): + return dep + + case_lemma = None + case_tokens = [] + for token in sentence: + if token["deprel"] == "cc" and token["head"] == conj_token["id"]: + case_tokens.append(token) + + if case_tokens: + fixed_tokens = [] + for token in sentence: + for case in quicksort(case_tokens): + if token["deprel"] == "fixed" and token["head"] == case["id"]: + fixed_tokens.append(token) + + if fixed_tokens: + case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens + fixed_tokens)]) + else: + case_lemma = "_".join([t["lemma"] for t in quicksort(case_tokens)]) + + if label not in ['conj']: + return dep + else: + label_lst = [label] + if case_lemma: + label_lst.append(case_lemma) + label = ":".join(label_lst) + + # print(label, sentence.metadata["sent_id"]) + return label, head + + + +def quicksort(tokens): + if len(tokens) <= 1: + return tokens + else: + return quicksort([x for x in tokens[1:] if int(x["id"]) < int(tokens[0]["id"])]) \ + + [tokens[0]] \ + + quicksort([y for y in tokens[1:] if int(y["id"]) >= int(tokens[0]["id"])]) + + +language = sys.argv[1] +errors = 0 + +input_file = f"./token_test/{language}_pred.fixed.conllu" +output_file = f"./token_test/{language}.nofixed.conllu" +with open(input_file) as fh: + with open(output_file, "w") as oh: + for sentence in conllu.parse_incr(fh): + for token in sentence: + deps = token["deps"] + if deps: + if language not in ['fr']: + for idx, dep in enumerate(deps): + assert len(dep) == 2, dep + new_dep = fix_obl_deps(dep, token, sentence, "obl") + token["deps"][idx] = new_dep + if new_dep[0] != dep[0]: + errors += 1 + if language not in ['fr']: + for idx, dep in enumerate(deps): + assert len(dep) == 2, dep + new_dep = fix_nmod_deps(dep, token, sentence, "nmod") + token["deps"][idx] = new_dep + if new_dep[0] != dep[0]: + errors += 1 + # TODO: check for other languages + if language not in ['fr', 'lv']: + for idx, dep in enumerate(deps): + assert len(dep) == 2, dep + new_dep = fix_acl_deps(dep, token, sentence, "acl", language) + token["deps"][idx] = new_dep + if new_dep[0] != dep[0]: + errors += 1 + + # TODO: check for other languages + if language not in ['fr', 'lv']: + for idx, dep in enumerate(deps): + assert len(dep) == 2, dep + new_dep = fix_advcl_deps(dep, token, sentence, "advcl") + token["deps"][idx] = new_dep + if new_dep[0] != dep[0]: + errors += 1 + # TODO: check for other languages + if language in ['en', 'it', 'nl', 'sv']: + for idx, dep in enumerate(deps): + assert len(dep) == 2, dep + new_dep = fix_conj_deps(dep, token, sentence, "conj") + token["deps"][idx] = new_dep + if new_dep[0] != dep[0]: + errors += 1 + # TODO: check for other languages + if language in ['et']: + for idx, dep in enumerate(deps): + assert len(dep) == 2, dep + if token['deprel'] == 'nsubj:cop' and dep[0] == 'nsubj:cop': + new_dep = ('nsubj', dep[1]) + token["deps"][idx] = new_dep + if new_dep[0] != dep[0]: + errors += 1 + if token['deprel'] == 'csubj:cop' and 
dep[0] == 'csubj:cop': + new_dep = ('csubj', dep[1]) + token["deps"][idx] = new_dep + if new_dep[0] != dep[0]: + errors += 1 + # BELOW ARE THE RULES FOR CORRECTION OF THE FUNCTION WORDS + # labelled ref, mark, punct, root, case, det, cc, cop, aux + # They should not be assinged other functions + #TODO: to check for other languages + if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ru']: + refs = [s for s in deps if s[0] == 'ref'] + if refs: + token["deps"] = refs + #TODO: to check for other languages + if language in ['ar', 'bg', 'en', 'et', 'fi', 'it', 'lt', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']: + marks = [s for s in deps if s[0] == 'mark'] + if marks and token['deprel'] == 'mark': + token["deps"] = marks + #TODO: to check for other languages + if language in ['ar', 'bg', 'cs', 'en', 'et', 'fi', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']: + puncts = [s for s in deps if s[0] == 'punct' and s[1] == token['head']] + if puncts and token['deprel'] == 'punct': + token["deps"] = puncts + #TODO: to check for other languages + if language in ['ar', 'lt', 'pl']: + roots = [s for s in deps if s[0] == 'root'] + if roots and token['deprel'] == 'root': + token["deps"] = roots + #TODO: to check for other languages + if language in ['en', 'ar', 'bg', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr']: + cases = [s for s in deps if s[0] == 'case'] + if cases and token['deprel'] == 'case': + token["deps"] = cases + #TODO: to check for other languages + if language in ['en', 'ar', 'et', 'fi', 'it', 'lt', 'lv', 'nl', 'pl', 'sk', 'sv', 'ta', 'uk', 'fr', 'ru']: + dets = [s for s in deps if s[0] == 'det'] + if dets and token['deprel'] == 'det': + token["deps"] = dets + #TODO: to check for other languages + if language in ['et', 'fi', 'it', 'lv', 'nl', 'pl', 'sk', 'sv', 'uk', 'fr', 'ar', 'ru', 'ta']: + ccs = [s for s in deps if s[0] == 'cc'] + if ccs and token['deprel'] == 'cc': + token["deps"] = ccs + #TODO: to check for other languages + if language in ['bg', 'fi','et', 'it', 'sk', 'sv', 'uk', 'nl', 'fr', 'ru']: + cops = [s for s in deps if s[0] == 'cop'] + if cops and token['deprel'] == 'cop': + token["deps"] = cops + #TODO: to check for other languages + if language in ['bg', 'et', 'fi', 'it', 'lv', 'pl', 'sv']: + auxs = [s for s in deps if s[0] == 'aux'] + if auxs and token['deprel'] == 'aux': + token["deps"] = auxs + + #TODO: to check for other languages + if language in ['ar', 'bg', 'cs', 'et', 'fi', 'fr', 'lt', 'lv', 'pl', 'sk', 'sv', 'uk', 'ru', 'ta']: + conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']] + other = [s for s in deps if s[0] != 'conj'] + if conjs and token['deprel'] == 'conj': + token["deps"] = conjs+other + + #TODO: to check for other languages + # EXTRA rule 1 + if language in ['cs', 'et', 'fi', 'lv', 'pl', 'uk']: #ar nl ru + # not use for: lt, bg, fr, sk, ta, sv, en + deprel = [s for s in deps if s[0] == token['deprel'] and s[1] == token['head']] + other_exp = [s for s in deps if type(s[1]) == tuple] + other_noexp = [s for s in deps if s[1] != token['head'] and type(s[1]) != tuple] + if other_exp: + token["deps"] = other_exp+other_noexp + + # EXTRA rule 2 + if language in ['cs', 'lt', 'pl', 'sk', 'uk']: #ar nl ru + conjs = [s for s in deps if s[0] == 'conj' and s[1] == token['head']] + if conjs and len(deps) == 1 and len(conjs) == 1: + for t in sentence: + if t['id'] == conjs[0][1] and t['deprel'] == 'root': + conjs.append((t['deprel'], t['head'])) + token["deps"] = conjs + + if 
language in ['ta']: + if token['deprel'] != 'conj': + conjs = [s for s in deps if s[0] == 'conj'] + if conjs: + new_dep = [s for s in deps if s[1] == token['head']] + token["deps"] = new_dep + + oh.write(sentence.serialize()) +print(errors) -- GitLab From 5e5e461b6cc35dba693cd458bc72f265da603b0e Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 5 Aug 2021 11:04:15 +0200 Subject: [PATCH 26/28] Add citation bibtex. --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.md b/README.md index 732262e..58cda99 100644 --- a/README.md +++ b/README.md @@ -39,3 +39,23 @@ We encourage you to use the [beginner's tutorial](https://colab.research.google. - [**Training**](docs/training.md) - [**Prediction**](docs/prediction.md) - [**Model performance**](docs/performance.md) + +## Citing + +If you use EUD in your research, please cite [COMBO: A New Module for EUD Parsing](https://aclanthology.org/2021.iwpt-1.16/) +```bibtex +@inproceedings{klimaszewski-wroblewska-2021-combo, + title = "{COMBO}: A New Module for {EUD} Parsing", + author = "Klimaszewski, Mateusz and + Wr{\'o}blewska, Alina", + booktitle = "Proceedings of the 17th International Conference on Parsing Technologies and the IWPT 2021 Shared Task on Parsing into Enhanced Universal Dependencies (IWPT 2021)", + month = aug, + year = "2021", + address = "Online", + publisher = "Association for Computational Linguistics", + url = "https://aclanthology.org/2021.iwpt-1.16", + doi = "10.18653/v1/2021.iwpt-1.16", + pages = "158--166", + abstract = "We introduce the COMBO-based approach for EUD parsing and its implementation, which took part in the IWPT 2021 EUD shared task. The goal of this task is to parse raw texts in 17 languages into Enhanced Universal Dependencies (EUD). The proposed approach uses COMBO to predict UD trees and EUD graphs. These structures are then merged into the final EUD graphs. Some EUD edge labels are extended with case information using a single language-independent expansion rule. In the official evaluation, the solution ranked fourth, achieving an average ELAS of 83.79{\%}. The source code is available at https://gitlab.clarin-pl.eu/syntactic-tools/combo.", +} +``` -- GitLab From d297c48c1ef9567294dacd38bda8e583921d0fe3 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 16 Sep 2021 16:22:02 +0100 Subject: [PATCH 27/28] Add EMNLP announcement. --- README.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 58cda99..e22b9a8 100644 --- a/README.md +++ b/README.md @@ -42,7 +42,21 @@ We encourage you to use the [beginner's tutorial](https://colab.research.google. 
## Citing -If you use EUD in your research, please cite [COMBO: A New Module for EUD Parsing](https://aclanthology.org/2021.iwpt-1.16/) +### Accepted at EMNLP'21 demo session :tada: :fire: + +If you use COMBO in your research, please cite [COMBO: State-of-the-Art Morphosyntactic Analysis](https://arxiv.org/abs/2109.05361) +```bibtex +@misc{klimaszewski2021combo, + title={COMBO: State-of-the-Art Morphosyntactic Analysis}, + author={Mateusz Klimaszewski and Alina Wróblewska}, + year={2021}, + eprint={2109.05361}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` + +If you use an EUD module in your research, please cite [COMBO: A New Module for EUD Parsing](https://aclanthology.org/2021.iwpt-1.16/) ```bibtex @inproceedings{klimaszewski-wroblewska-2021-combo, title = "{COMBO}: A New Module for {EUD} Parsing", -- GitLab From 4a27855947d788dc3df971bf5000f3771bcce088 Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 16 Sep 2021 16:38:50 +0100 Subject: [PATCH 28/28] Release 1.0.4. --- README.md | 2 +- docs/installation.md | 4 ++-- setup.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e22b9a8..76e436b 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ Clone this repository and install COMBO (we suggest creating a virtualenv/conda environment with Python 3.6+, as a bundle of required packages will be installed): ```bash pip install -U pip setuptools wheel -pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.3 +pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.4 ``` Run the following commands in your Python console to make predictions with a pre-trained model: ```python diff --git a/docs/installation.md b/docs/installation.md index 6aba7f7..422bed2 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -2,7 +2,7 @@ Clone this repository and install COMBO (we suggest using virtualenv/conda with Python 3.6+): ```bash pip install -U pip setuptools wheel -pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.3 +pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.4 combo --helpfull ``` @@ -11,7 +11,7 @@ combo --helpfull python -m venv venv source venv/bin/activate pip install -U pip setuptools wheel -pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.3 +pip install --index-url https://pypi.clarin-pl.eu/simple combo==1.0.4 ``` ### Conda example: diff --git a/setup.py b/setup.py index 06717ba..876909d 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ REQUIREMENTS = [ setup( name='combo', - version='1.0.3', + version='1.0.4', author='Mateusz Klimaszewski', author_email='M.Klimaszewski@ii.pw.edu.pl', install_requires=REQUIREMENTS, -- GitLab
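Taken together, the prediction patches above (scripts/predict_iwpt21.py plus the helpers in scripts/utils.py) amount to a three-step flow per language: run the combo CLI, apply conllu-quick-fix.pl, then collapse empty nodes. A condensed sketch of that flow for one language follows; the language code and all paths below are placeholders, and the helper names are taken from scripts/utils.py as patched above:

```python
import pathlib

from scripts import utils  # execute_command, quick_fix, collapse_nodes, path_to_str

lang = "fi"                                        # placeholder language code
model_dir = pathlib.Path("/tmp/models/fi/run1")    # placeholder serialization dir
data_dir = pathlib.Path("/data/iwpt21")            # placeholder data dir
tools = pathlib.Path("/data/iwpt21/tools")         # UD tools dir with the perl scripts

test_file = data_dir / f"{lang}.conllu"
output_pred = data_dir / f"{lang}_pred.conllu"

# Step 1: predict with the combo CLI (flags as in scripts/predict_iwpt21.py).
command = f"""combo --mode predict --model_path {model_dir / 'model.tar.gz'}
--input_file {test_file}
--output_file {output_pred}
--cuda_device -1
--batch_size 32
--silent
"""
utils.execute_command(command)

# Step 2: run conllu-quick-fix.pl on the raw predictions.
output_fixed = utils.path_to_str(output_pred).replace(".conllu", ".fixed.conllu")
utils.quick_fix(tools, output_pred, output_fixed)

# Step 3: collapse empty nodes with enhanced_collapse_empty_nodes.pl.
output_collapsed = output_fixed.replace(".fixed.conllu", ".collapsed.conllu")
utils.collapse_nodes(tools, pathlib.Path(output_fixed), output_collapsed)
```

The EUD postprocessing script from patch 25 (python postprocessing.py <lang>) is then run per language on the resulting *_pred.fixed.conllu files, using the hard-coded directory in that script.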