diff --git a/.idea/combolightning.iml b/.idea/combolightning.iml index 4154233b93a52a4823d13c2abe3d93af8a29772f..332f5233511b9e9c4e7140693f371f2c7a6382c6 100644 --- a/.idea/combolightning.iml +++ b/.idea/combolightning.iml @@ -2,7 +2,7 @@ <module type="PYTHON_MODULE" version="4"> <component name="NewModuleRootManager"> <content url="file://$MODULE_DIR$" /> - <orderEntry type="jdk" jdkName="Python 3.9 (combolightning)" jdkType="Python SDK" /> + <orderEntry type="jdk" jdkName="combo" jdkType="Python SDK" /> <orderEntry type="sourceFolder" forTests="false" /> </component> </module> \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml index fcbd09d95c297f4fe5aa663eb215aa696a5cf4b3..d1d59e67aa31cc624fcf2a8623e0114ab43ee103 100644 --- a/.idea/misc.xml +++ b/.idea/misc.xml @@ -1,4 +1,4 @@ <?xml version="1.0" encoding="UTF-8"?> <project version="4"> - <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.9 (combolightning)" project-jdk-type="Python SDK" /> + <component name="ProjectRootManager" version="2" project-jdk-name="combo" project-jdk-type="Python SDK" /> </project> \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..35eb1ddfbbc029bcab630581847471d7f238ec53 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="VcsDirectoryMappings"> + <mapping directory="" vcs="Git" /> + </component> +</project> \ No newline at end of file diff --git a/combo/checks.py b/combo/checks.py deleted file mode 100644 index 7736e96c0465fbd6188e84e308e86762053bcb0f..0000000000000000000000000000000000000000 --- a/combo/checks.py +++ /dev/null @@ -1,4 +0,0 @@ -class ConfigurationError(Exception): - def __init__(self, message: str): - super().__init__() - self.message = message diff --git a/combo/commands/__init__.py b/combo/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..83ff4f4b48489406a73f3a902f0ab28eb62572b1 --- /dev/null +++ b/combo/commands/__init__.py @@ -0,0 +1 @@ +from .train import FinetuningTrainModel \ No newline at end of file diff --git a/combo/commands/train.py b/combo/commands/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bb80729a4875e48574ac94d43b36660559fe39 --- /dev/null +++ b/combo/commands/train.py @@ -0,0 +1,10 @@ +from pytorch_lightning import Trainer + + +class FinetuningTrainModel(Trainer): + """ + Class made only for finetuning, + the only difference is saving vocab from concatenated + (archive and current) datasets + """ + pass \ No newline at end of file diff --git a/combo/data/__init__.py b/combo/data/__init__.py index 91426e3950386ecf787e18321e8c7765bc33a187..2d69cca8bb04168159e0c2b3557dabc4b7e18ab2 100644 --- a/combo/data/__init__.py +++ b/combo/data/__init__.py @@ -1,3 +1,4 @@ from .samplers import TokenCountBatchSampler +from .token import Token from .token_indexers import * from .api import * diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 0c34d3fbcf31850546ed752bea2af65efb817a58..62ca30f2af2c3380f311c5f0fd2accca2d225ab9 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -1,317 +1,11 @@ -import copy import logging -import pathlib -from typing import Union, List, Dict, Iterable, Optional, Any, Tuple -import conllu -import torch -from allennlp import data as allen_data -from allennlp.common import checks, util -from allennlp.data import fields as allen_fields, vocabulary -from 
conllu import parser -from dataclasses import dataclass -from overrides import overrides - -from combo.data import fields logger = logging.getLogger(__name__) -@allen_data.DatasetReader.register("conllu") -class UniversalDependenciesDatasetReader(allen_data.DatasetReader): - - def __init__( - self, - token_indexers: Dict[str, allen_data.TokenIndexer] = None, - lemma_indexers: Dict[str, allen_data.TokenIndexer] = None, - features: List[str] = None, - targets: List[str] = None, - use_sem: bool = False, - **kwargs, - ) -> None: - super().__init__(**kwargs) - if features is None: - features = ["token", "char"] - if targets is None: - targets = ["head", "deprel", "upostag", "xpostag", "lemma", "feats"] - - if "token" not in features and "char" not in features: - raise checks.ConfigurationError("There must be at least one ('char' or 'token') text-based feature!") - - if "deps" in targets and not ("head" in targets and "deprel" in targets): - raise checks.ConfigurationError("Add 'head' and 'deprel' to targets when using 'deps'!") - - intersection = set(features).intersection(set(targets)) - if len(intersection) != 0: - raise checks.ConfigurationError( - "Features and targets cannot share elements! " - "Remove {} from either features or targets.".format(intersection) - ) - self.use_sem = use_sem - - # *.conllu readers config - fields = list(parser.DEFAULT_FIELDS) - fields[1] = "token" # use 'token' instead of 'form' - field_parsers = parser.DEFAULT_FIELD_PARSERS - # Do not make it nullable - field_parsers.pop("xpostag", None) - # Ignore parsing misc - field_parsers.pop("misc", None) - if self.use_sem: - fields = list(fields) - fields.append("semrel") - field_parsers["semrel"] = lambda line, i: line[i] - self.field_parsers = field_parsers - self.fields = tuple(fields) - - self._token_indexers = token_indexers - self._lemma_indexers = lemma_indexers - self._targets = targets - self._features = features - self.generate_labels = True - # Filter out not required token indexers to avoid - # Mismatched token keys ConfigurationError - for indexer_name in list(self._token_indexers.keys()): - if indexer_name not in self._features: - del self._token_indexers[indexer_name] - - @overrides - def _read(self, file_path: str) -> Iterable[allen_data.Instance]: - file_path = [file_path] if len(file_path.split(",")) == 0 else file_path.split(",") - - for conllu_file in file_path: - file = pathlib.Path(conllu_file) - assert conllu_file and file.exists(), f"File with path '{conllu_file}' does not exists!" 
- with file.open("r", encoding="utf-8") as f: - for annotation in conllu.parse_incr(f, fields=self.fields, field_parsers=self.field_parsers): - yield self.text_to_instance(annotation) - - @overrides - def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance: - fields_: Dict[str, allen_data.Field] = {} - tree_tokens = [t for t in tree if isinstance(t["id"], int)] - tokens = [_Token(t["token"], - pos_=t.get("upostag"), - tag_=t.get("xpostag"), - lemma_=t.get("lemma"), - feats_=t.get("feats")) - for t in tree_tokens] - - # features - text_field = allen_fields.TextField(tokens, self._token_indexers) - fields_["sentence"] = text_field - - # targets - if self.generate_labels: - for target_name in self._targets: - if target_name != "sent": - target_values = [t[target_name] for t in tree_tokens] - if target_name == "lemma": - target_values = [allen_data.Token(v) for v in target_values] - fields_[target_name] = allen_fields.TextField(target_values, self._lemma_indexers) - elif target_name == "feats": - target_values = self._feat_values(tree_tokens) - fields_[target_name] = fields.SequenceMultiLabelField(target_values, - self._feats_indexer, - self._feats_as_tensor_wrapper, - text_field, - label_namespace="feats_labels") - elif target_name == "head": - target_values = [0 if v == "_" else int(v) for v in target_values] - fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field, - label_namespace=target_name + "_labels") - elif target_name == "deps": - # Graphs require adding ROOT (AdjacencyField uses sequence length from TextField). - text_field_deps = allen_fields.TextField([_Token("ROOT")] + copy.deepcopy(tokens), - self._token_indexers) - enhanced_heads: List[Tuple[int, int]] = [] - enhanced_deprels: List[str] = [] - for idx, t in enumerate(tree_tokens): - t_deps = t["deps"] - if t_deps and t_deps != "_": - for rel, head in t_deps: - # EmoryNLP skips the first edge, if there are two edges between the same - # nodes. Thanks to that one is in a tree and another in a graph. - # This snippet follows that approach. - if enhanced_heads and enhanced_heads[-1] == (idx, head): - enhanced_heads.pop() - enhanced_deprels.pop() - enhanced_heads.append((idx, head)) - enhanced_deprels.append(rel) - fields_["enhanced_heads"] = allen_fields.AdjacencyField( - indices=enhanced_heads, - sequence_field=text_field_deps, - label_namespace="enhanced_heads_labels", - padding_value=0, - ) - fields_["enhanced_deprels"] = allen_fields.AdjacencyField( - indices=enhanced_heads, - sequence_field=text_field_deps, - labels=enhanced_deprels, - # Label namespace matches regular tree parsing. 
- label_namespace="enhanced_deprel_labels", - padding_value=0, - ) - else: - fields_[target_name] = allen_fields.SequenceLabelField(target_values, text_field, - label_namespace=target_name + "_labels") - - # Restore feats fields to string representation - # parser.serialize_field doesn't handle key without value - for token in tree.tokens: - if "feats" in token: - feats = token["feats"] - if feats: - feats_values = [] - for k, v in feats.items(): - feats_values.append('='.join((k, v)) if v else k) - field = "|".join(feats_values) - else: - field = "_" - token["feats"] = field - - # metadata - fields_["metadata"] = allen_fields.MetadataField({"input": tree, - "field_names": self.fields, - "tokens": tokens}) - - return allen_data.Instance(fields_) - - @staticmethod - def _feat_values(tree: List[Dict[str, Any]]): - features = [] - for token in tree: - token_features = [] - if token["feats"] is not None: - for feat, value in token["feats"].items(): - if feat in ["_", "__ROOT__"]: - pass - else: - # Handle case where feature is binary (doesn't have associated value) - if value: - token_features.append(feat + "=" + value) - else: - token_features.append(feat) - features.append(token_features) - return features - - @staticmethod - def _feats_as_tensor_wrapper(field: fields.SequenceMultiLabelField): - def as_tensor(padding_lengths): - desired_num_tokens = padding_lengths["num_tokens"] - assert len(field._indexed_multi_labels) > 0 - classes_count = len(field._indexed_multi_labels[0]) - default_value = [0.0] * classes_count - padded_tags = util.pad_sequence_to_length(field._indexed_multi_labels, desired_num_tokens, - lambda: default_value) - tensor = torch.tensor(padded_tags, dtype=torch.long) - return tensor - - return as_tensor - - @staticmethod - def _feats_indexer(vocab: allen_data.Vocabulary): - label_namespace = "feats_labels" - vocab_size = vocab.get_vocab_size(label_namespace) - slices = get_slices_if_not_provided(vocab) - - def _m_from_n_ones_encoding(multi_label: List[str], sentence_length: int) -> List[int]: - one_hot_encoding = [0] * vocab_size - for cat, cat_indices in slices.items(): - if cat not in ["__PAD__", "_"]: - label_from_cat = [label for label in multi_label if cat == label.split("=")[0]] - if label_from_cat: - label_from_cat = label_from_cat[0] - index = vocab.get_token_index(label_from_cat, label_namespace) - else: - # Get Cat=None index - index = vocab.get_token_index(cat + "=None", label_namespace) - one_hot_encoding[index] = 1 - return one_hot_encoding - - return _m_from_n_ones_encoding - - -@allen_data.Vocabulary.register("from_instances_extended", constructor="from_instances_extended") -class Vocabulary(allen_data.Vocabulary): - - @classmethod - def from_instances_extended( - cls, - instances: Iterable[allen_data.Instance], - min_count: Dict[str, int] = None, - max_vocab_size: Union[int, Dict[str, int]] = None, - non_padded_namespaces: Iterable[str] = vocabulary.DEFAULT_NON_PADDED_NAMESPACES, - pretrained_files: Optional[Dict[str, str]] = None, - only_include_pretrained_words: bool = False, - min_pretrained_embeddings: Dict[str, int] = None, - padding_token: Optional[str] = vocabulary.DEFAULT_PADDING_TOKEN, - oov_token: Optional[str] = vocabulary.DEFAULT_OOV_TOKEN, - ) -> "Vocabulary": - """ - Extension to manually fill gaps in missing 'feats_labels'. - """ - # Load manually tokens from pretrained file (using different strategy - # - only words add all embedding file, without checking if were seen - # in any dataset. 
- tokens_to_add = None - if pretrained_files and "tokens" in pretrained_files: - pretrained_set = set(vocabulary._read_pretrained_tokens(pretrained_files["tokens"])) - tokens_to_add = {"tokens": list(pretrained_set)} - pretrained_files = None - - vocab = super().from_instances( - instances=instances, - min_count=min_count, - max_vocab_size=max_vocab_size, - non_padded_namespaces=non_padded_namespaces, - pretrained_files=pretrained_files, - only_include_pretrained_words=only_include_pretrained_words, - tokens_to_add=tokens_to_add, - min_pretrained_embeddings=min_pretrained_embeddings, - padding_token=padding_token, - oov_token=oov_token - ) - # Extending vocab with features that does not show up explicitly. - # To know all features we need to read full dataset first. - # Adding auxiliary '=None' feature for each category is needed - # to perform classification. - get_slices_if_not_provided(vocab) - return vocab - - -def get_slices_if_not_provided(vocab: allen_data.Vocabulary): - if hasattr(vocab, "slices"): - return vocab.slices - - if "feats_labels" in vocab.get_namespaces(): - idx2token = vocab.get_index_to_token_vocabulary("feats_labels") - for _, v in dict(idx2token).items(): - if v not in ["_", "__PAD__"]: - empty_value = v.split("=")[0] + "=None" - vocab.add_token_to_namespace(empty_value, "feats_labels") - - slices = {} - for idx, name in vocab.get_index_to_token_vocabulary("feats_labels").items(): - # There are 2 types features: with (Case=Acc) or without assigment (None). - # Here we group their indices by name (before assigment sign). - name = name.split("=")[0] - if name in slices: - slices[name].append(idx) - else: - slices[name] = [idx] - vocab.slices = slices - return vocab.slices - - -@dataclass(init=False, repr=False) -class _Token(allen_data.Token): - __slots__ = allen_data.Token.__slots__ + ['feats_'] - - feats_: Optional[str] +class DatasetReader: + pass - def __init__(self, text: str = None, idx: int = None, idx_end: int = None, lemma_: str = None, pos_: str = None, - tag_: str = None, dep_: str = None, ent_type_: str = None, text_id: int = None, type_id: int = None, - feats_: str = None) -> None: - super().__init__(text, idx, idx_end, lemma_, pos_, tag_, dep_, ent_type_, text_id, type_id) - self.feats_ = feats_ +class UniversalDependenciesDatasetReader(DatasetReader): + pass \ No newline at end of file diff --git a/combo/data/fields/__init__.py b/combo/data/fields/__init__.py index 4f8d3a265f506aab43e0cda78ec931cd5c44c92e..57de34489861fbde17483b3aa9097f92f28c5dc8 100644 --- a/combo/data/fields/__init__.py +++ b/combo/data/fields/__init__.py @@ -1 +1,2 @@ +from .base_field import Field from .sequence_multilabel_field import SequenceMultiLabelField diff --git a/combo/data/fields/base_field.py b/combo/data/fields/base_field.py new file mode 100644 index 0000000000000000000000000000000000000000..83ea563e6e7cb424cadbe23428045059f2cb04a5 --- /dev/null +++ b/combo/data/fields/base_field.py @@ -0,0 +1,5 @@ +from abc import ABCMeta + + +class Field(metaclass=ABCMeta): + pass diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py index c31f78e447c94d1f4ba0c91713299b0623b02dac..b7a6bbe361f17c2c8168e9415bba249107e6c82e 100644 --- a/combo/data/fields/sequence_multilabel_field.py +++ b/combo/data/fields/sequence_multilabel_field.py @@ -1,18 +1,16 @@ """Sequence multilabel field implementation.""" import logging -import textwrap from typing import Set, List, Callable, Iterator, Union, Dict import torch -from allennlp import data 
-from allennlp.common import checks -from allennlp.data import fields from overrides import overrides +from combo.data.fields import Field + logger = logging.getLogger(__name__) -class SequenceMultiLabelField(data.Field[torch.Tensor]): +class SequenceMultiLabelField(Field): """ A `SequenceMultiLabelField` is an extension of the :class:`MultiLabelField` that allows for multiple labels while keeping sequence dimension. @@ -93,43 +91,23 @@ class SequenceMultiLabelField(data.Field[torch.Tensor]): @overrides def count_vocab_items(self, counter: Dict[str, Dict[str, int]]): - if self._indexed_multi_labels is None: - for multi_label in self.multi_labels: - for label in multi_label: - counter[self._label_namespace][label] += 1 # type: ignore + pass @overrides def index(self, vocab: data.Vocabulary): - indexer = self.multi_label_indexer(vocab) - - indexed = [] - for multi_label in self.multi_labels: - indexed.append(indexer(multi_label, len(self.multi_labels))) - self._indexed_multi_labels = indexed + pass @overrides def get_padding_lengths(self) -> Dict[str, int]: - return {"num_tokens": self.sequence_field.sequence_length()} + pass @overrides def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: - return self.as_tensor_wrapper(padding_lengths) + pass @overrides def empty_field(self) -> "SequenceMultiLabelField": - empty_list: List[List[str]] = [[]] - sequence_label_field = SequenceMultiLabelField(empty_list, lambda x: lambda y: y, - lambda x: lambda y: y, - self.sequence_field.empty_field()) - sequence_label_field._indexed_labels = empty_list - return sequence_label_field + pass def __str__(self) -> str: - length = self.sequence_field.sequence_length() - formatted_labels = "".join( - "\t\t" + labels + "\n" for labels in textwrap.wrap(repr(self.multi_labels), 100) - ) - return ( - f"SequenceMultiLabelField of length {length} with " - f"labels:\n {formatted_labels} \t\tin namespace: '{self._label_namespace}'." 
- ) + pass diff --git a/combo/data/samplers/__init__.py b/combo/data/samplers/__init__.py index ab003fe4341da38983553324b0454681a715eda1..1db58a9b3221ba0442c0d0b3321be208ab2c9daa 100644 --- a/combo/data/samplers/__init__.py +++ b/combo/data/samplers/__init__.py @@ -1 +1,2 @@ +from .base_sampler import Sampler from .samplers import TokenCountBatchSampler diff --git a/combo/data/samplers/base_sampler.py b/combo/data/samplers/base_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5cd4017069e289f4661211e2d8b0aa38a653f3 --- /dev/null +++ b/combo/data/samplers/base_sampler.py @@ -0,0 +1,5 @@ +from abc import ABCMeta + + +class Sampler(metaclass=ABCMeta): + pass diff --git a/combo/data/samplers/samplers.py b/combo/data/samplers/samplers.py index 5db74d93d4fccc5de3fa72e033f152c86fb981bf..ee32a5e54ef9a58e047182594294103c501003cf 100644 --- a/combo/data/samplers/samplers.py +++ b/combo/data/samplers/samplers.py @@ -2,11 +2,10 @@ from typing import List import numpy as np -from allennlp import data as allen_data +from combo.data.samplers import Sampler -@allen_data.BatchSampler.register("token_count") -class TokenCountBatchSampler(allen_data.BatchSampler): +class TokenCountBatchSampler(Sampler): def __init__(self, dataset, word_batch_size: int = 2500, shuffle_dataset: bool = True): self._index = 0 diff --git a/combo/data/token.py b/combo/data/token.py new file mode 100644 index 0000000000000000000000000000000000000000..048fab0382b85251da06d69a09b4057120f5bee8 --- /dev/null +++ b/combo/data/token.py @@ -0,0 +1,2 @@ +class Token: + pass diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py index 550a80ba6947f51cc5e0ef9a11b095e8488805cb..d17954046afd71b2ab756c1079f820157fcd4258 100644 --- a/combo/data/token_indexers/__init__.py +++ b/combo/data/token_indexers/__init__.py @@ -1,3 +1,4 @@ +from .base_indexer import TokenIndexer from .pretrained_transformer_mismatched_indexer import PretrainedTransformerMismatchedIndexer from .token_characters_indexer import TokenCharactersIndexer from .token_features_indexer import TokenFeatsIndexer diff --git a/combo/data/token_indexers/base_indexer.py b/combo/data/token_indexers/base_indexer.py new file mode 100644 index 0000000000000000000000000000000000000000..2fb48c0c047365b26cf3c7343da1523fc5077d1f --- /dev/null +++ b/combo/data/token_indexers/base_indexer.py @@ -0,0 +1,9 @@ +from abc import ABCMeta + + +class TokenIndexer(metaclass=ABCMeta): + pass + + +class PretrainedTransformerMismatchedIndexer(TokenIndexer): + pass diff --git a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py index fc29896a2ecbb5408c9367f2d8ad5b0c1d4a5d4c..9aa3616cfd44bf7dc5baabfb6e3c98417e0070fb 100644 --- a/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py +++ b/combo/data/token_indexers/pretrained_transformer_mismatched_indexer.py @@ -1,117 +1,13 @@ -from typing import Optional, Dict, Any, List, Tuple +from combo.data import TokenIndexer -from allennlp import data -from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary -from overrides import overrides +class PretrainedTransformerMismatchedIndexer(TokenIndexer): + pass -@data.TokenIndexer.register("pretrained_transformer_mismatched_fixed") -class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer): - def __init__(self, model_name: str, namespace: str = "tags", max_length: int = None, - 
tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> None: - # The matched version v.s. mismatched - super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs) - self._matched_indexer = PretrainedTransformerIndexer( - model_name, - namespace=namespace, - max_length=max_length, - tokenizer_kwargs=tokenizer_kwargs, - **kwargs, - ) - self._allennlp_tokenizer = self._matched_indexer._allennlp_tokenizer - self._tokenizer = self._matched_indexer._tokenizer - self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens - self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens +class PretrainedTransformerIndexer(TokenIndexer): + pass - @overrides - def tokens_to_indices(self, - tokens, - vocabulary: vocabulary ) -> IndexedTokenList: - """ - Method is overridden in order to raise an error while the number of tokens needed to embed a sentence exceeds the - maximal input of a model. - """ - self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary) - wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize( - [t.ensure_text() for t in tokens]) - - if len(wordpieces) > self._tokenizer.max_len_single_sentence: - raise ValueError("Following sentence consists of more wordpiece tokens that the model can process:\n" +\ - " ".join([str(x) for x in tokens[:10]]) + " ... \n" + \ - f"Maximal input: {self._tokenizer.max_len_single_sentence}\n"+ \ - f"Current input: {len(wordpieces)}") - - offsets = [x if x is not None else (-1, -1) for x in offsets] - - output: IndexedTokenList = { - "token_ids": [t.text_id for t in wordpieces], - "mask": [True] * len(tokens), # for original tokens (i.e. word-level) - "type_ids": [t.type_id for t in wordpieces], - "offsets": offsets, - "wordpiece_mask": [True] * len(wordpieces), # for wordpieces (i.e. subword-level) - } - - return self._matched_indexer._postprocess_output(output) - - -class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer): - - def __init__( - self, - model_name: str, - namespace: str = "tags", - max_length: int = None, - tokenizer_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> None: - super().__init__(model_name, namespace, max_length, tokenizer_kwargs, **kwargs) - self._namespace = namespace - self._allennlp_tokenizer = PretrainedTransformerTokenizer( - model_name, tokenizer_kwargs=tokenizer_kwargs - ) - self._tokenizer = self._allennlp_tokenizer.tokenizer - self._added_to_vocabulary = False - - self._num_added_start_tokens = len(self._allennlp_tokenizer.single_sequence_start_tokens) - self._num_added_end_tokens = len(self._allennlp_tokenizer.single_sequence_end_tokens) - - self._max_length = max_length - if self._max_length is not None: - num_added_tokens = len(self._allennlp_tokenizer.tokenize("a")) - 1 - self._effective_max_length = ( # we need to take into account special tokens - self._max_length - num_added_tokens - ) - if self._effective_max_length <= 0: - raise ValueError( - "max_length needs to be greater than the number of special tokens inserted." 
- ) - - -class PretrainedTransformerTokenizer(tokenizers.PretrainedTransformerTokenizer): - - def _intra_word_tokenize( - self, string_tokens: List[str] - ) -> Tuple[List[data.Token], List[Optional[Tuple[int, int]]]]: - tokens: List[data.Token] = [] - offsets: List[Optional[Tuple[int, int]]] = [] - for token_string in string_tokens: - wordpieces = self.tokenizer.encode_plus( - token_string, - add_special_tokens=False, - return_tensors=None, - return_offsets_mapping=False, - return_attention_mask=False, - ) - wp_ids = wordpieces["input_ids"] - - if len(wp_ids) > 0: - offsets.append((len(tokens), len(tokens) + len(wp_ids) - 1)) - tokens.extend( - data.Token(text=wp_text, text_id=wp_id) - for wp_id, wp_text in zip(wp_ids, self.tokenizer.convert_ids_to_tokens(wp_ids)) - ) - else: - offsets.append(None) - return tokens, offsets +class PretrainedTransformerTokenizer(TokenIndexer): + pass diff --git a/combo/data/token_indexers/token_characters_indexer.py b/combo/data/token_indexers/token_characters_indexer.py index ea7a3eaab32a79758573fe6f778c59060d9bce6c..6f5dbf035ce518783aa799e7330e51c0c92a64ba 100644 --- a/combo/data/token_indexers/token_characters_indexer.py +++ b/combo/data/token_indexers/token_characters_indexer.py @@ -1,62 +1,6 @@ -"""Custom character token indexer.""" -import itertools -from typing import List, Dict +from combo.data import TokenIndexer -import torch -from allennlp import data -from allennlp.common import util -from allennlp.data import tokenizers -from allennlp.data.token_indexers import token_characters_indexer -from overrides import overrides - -@data.TokenIndexer.register("characters_const_padding") -class TokenCharactersIndexer(token_characters_indexer.TokenCharactersIndexer): +class TokenCharactersIndexer(TokenIndexer): """Wrapper around allennlp token indexer with const padding.""" - - def __init__(self, - namespace: str = "token_characters", - character_tokenizer: tokenizers.CharacterTokenizer = tokenizers.CharacterTokenizer(), - start_tokens: List[str] = None, - end_tokens: List[str] = None, - min_padding_length: int = 0, - token_min_padding_length: int = 0): - super().__init__(namespace, character_tokenizer, start_tokens, end_tokens, min_padding_length, - token_min_padding_length) - - @overrides - def get_padding_lengths(self, indexed_tokens: data.IndexedTokenList) -> Dict[str, int]: - padding_lengths = {"token_characters": len(indexed_tokens["token_characters"]), - "num_token_characters": self._min_padding_length} - return padding_lengths - - @overrides - def as_padded_tensor_dict( - self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int] - ) -> Dict[str, torch.Tensor]: - # Pad the tokens. - padded_tokens = util.pad_sequence_to_length( - tokens["token_characters"], - padding_lengths["token_characters"], - default_value=lambda: [], - ) - - # Pad the characters within the tokens. - desired_token_length = padding_lengths["num_token_characters"] - longest_token: List[int] = max(tokens["token_characters"], key=len, default=[]) # type: ignore - padding_value = 0 - if desired_token_length > len(longest_token): - # Since we want to pad to greater than the longest token, we add a - # "dummy token" so we can take advantage of the fast implementation of itertools.zip_longest. 
- padded_tokens.append([padding_value] * desired_token_length) - # pad the list of lists to the longest sublist, appending 0's - padded_tokens = list(zip(*itertools.zip_longest(*padded_tokens, fillvalue=padding_value))) - if desired_token_length > len(longest_token): - # Removes the "dummy token". - padded_tokens.pop() - # Truncates all the tokens to the desired length, and return the result. - return { - "token_characters": torch.LongTensor( - [list(token[:desired_token_length]) for token in padded_tokens] - ) - } + pass diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py index 7c591243ec19b0f27c3344cd748d17e3b9aa50f6..b6267a4377bcbe665a2716c8bf7697454b2f738f 100644 --- a/combo/data/token_indexers/token_features_indexer.py +++ b/combo/data/token_indexers/token_features_indexer.py @@ -1,75 +1,7 @@ """Features indexer.""" -import collections -from typing import List, Dict -import torch -from allennlp import data -from allennlp.common import util -from overrides import overrides +from combo.data import TokenIndexer -@data.TokenIndexer.register("feats_indexer") -class TokenFeatsIndexer(data.TokenIndexer): - - def __init__( - self, - namespace: str = "feats", - feature_name: str = "feats_", - token_min_padding_length: int = 0, - ) -> None: - super().__init__(token_min_padding_length) - self.namespace = namespace - self._feature_name = feature_name - - @overrides - def count_vocab_items(self, token: data.Token, counter: Dict[str, Dict[str, int]]): - feats = self._feat_values(token) - for feat in feats: - counter[self.namespace][feat] += 1 - - @overrides - def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList: - indices: List[List[int]] = [] - vocab_size = vocabulary.get_vocab_size(self.namespace) - for token in tokens: - token_indices = [] - feats = self._feat_values(token) - for feat in feats: - token_indices.append(vocabulary.get_token_index(feat, self.namespace)) - indices.append(util.pad_sequence_to_length(token_indices, vocab_size)) - return {"tokens": indices} - - @overrides - def get_empty_token_list(self) -> data.IndexedTokenList: - return {"tokens": [[]]} - - def _feat_values(self, token): - feats = getattr(token, self._feature_name) - if feats is None: - feats = collections.OrderedDict() - features = [] - for feat, value in feats.items(): - if feat in ["_", "__ROOT__"]: - pass - else: - # Handle case where feature is binary (doesn't have associated value) - if value: - features.append(feat + "=" + value) - else: - features.append(feat) - return features - - @overrides - def as_padded_tensor_dict( - self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int] - ) -> Dict[str, torch.Tensor]: - tensor_dict = {} - for key, val in tokens.items(): - vocab_size = len(val[0]) - tensor = torch.tensor(util.pad_sequence_to_length(val, - padding_lengths[key], - default_value=lambda: [0] * vocab_size, - ) - ) - tensor_dict[key] = tensor - return tensor_dict +class TokenFeatsIndexer(TokenIndexer): + pass diff --git a/combo/main.py b/combo/main.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4d945ea9624dbe64da0d96ed4d188a5dbb80b8 --- /dev/null +++ b/combo/main.py @@ -0,0 +1,150 @@ +import logging +import os +import pathlib +import tempfile +from typing import Dict + +import torch +from absl import app +from absl import flags + +from combo import models +from combo.models.base import Predictor +from combo.utils import checks + +logger = 
logging.getLogger(__name__) +_FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"] +_TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"] + +FLAGS = flags.FLAGS +flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"], + help="Specify COMBO mode: train or predict") + +# Common flags +flags.DEFINE_integer(name="cuda_device", default=-1, + help="CUDA device id (default -1 for CPU)") +flags.DEFINE_string(name="output_file", default="output.log", + help="Prediction results file.") + +# Training flags +flags.DEFINE_list(name="training_data_path", default=[], + help="Training data path(s)") +flags.DEFINE_alias(name="training_data", original_name="training_data_path") +flags.DEFINE_list(name="validation_data_path", default="", + help="Validation data path(s)") +flags.DEFINE_alias(name="validation_data", original_name="validation_data_path") +flags.DEFINE_string(name="pretrained_tokens", default="", + help="Pretrained token embeddings path") +flags.DEFINE_integer(name="embedding_dim", default=300, + help="Embedding dimension") +flags.DEFINE_integer(name="num_epochs", default=400, + help="Number of epochs") +flags.DEFINE_integer(name="word_batch_size", default=2500, + help="Minimum number of words per batch") +flags.DEFINE_string(name="pretrained_transformer_name", default="", + help="Pretrained transformer model name (see the HuggingFace transformers library for the list of " + "available models) used for transformer-based embeddings.") +flags.DEFINE_list(name="features", default=["token", "char"], + help=f"Features used to train the model ('token' and 'char' are required). Possible values: {_FEATURES}.") +flags.DEFINE_list(name="targets", default=["deprel", "feats", "head", "lemma", "upostag", "xpostag"], + help=f"Targets of the model (`deprel` and `head` are required). Possible values: {_TARGETS}.") +flags.DEFINE_string(name="serialization_dir", default=None, + help="Model serialization directory (default - system temp dir).") +flags.DEFINE_boolean(name="tensorboard", default=False, + help="When provided, the model will log TensorBoard metrics.") + +# Finetune after training flags +flags.DEFINE_list(name="finetuning_training_data_path", default="", + help="Finetuning training data path(s)") +flags.DEFINE_list(name="finetuning_validation_data_path", default="", + help="Finetuning validation data path(s)") +flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.template.jsonnet"), + help="Config file path.") + +# Test after training flags +flags.DEFINE_string(name="test_path", default=None, + help="Test file path.") + +# Experimental +flags.DEFINE_boolean(name="use_pure_config", default=False, + help="Ignore ext flags (experimental).") + +# Prediction flags +flags.DEFINE_string(name="model_path", default=None, + help="Pretrained model path.") +flags.DEFINE_string(name="input_file", default=None, + help="Path of the file to predict.") +flags.DEFINE_boolean(name="conllu_format", default=True, + help="Prediction based on conllu format (instead of raw text).") +flags.DEFINE_integer(name="batch_size", default=1, + help="Prediction batch size.") +flags.DEFINE_boolean(name="silent", default=True, + help="Silent prediction to file (without printing to console).") +flags.DEFINE_enum(name="predictor_name", default="combo-spacy", + enum_values=["combo", "combo-spacy"], + help="Use the predictor with a whitespace or spaCy tokenizer.") + + +def run(_): + pass + + +def _get_predictor() -> Predictor: + # Check for GPU + # allen_checks.check_for_gpu(FLAGS.cuda_device) + checks.file_exists(FLAGS.model_path) + # load model from archive + # archive = models.load_archive( + # FLAGS.model_path, + # cuda_device=FLAGS.cuda_device, + # ) + # return predictors.Predictor.from_archive( + # archive, FLAGS.predictor_name + # ) + return Predictor() + + +def _get_ext_vars(finetuning: bool = False) -> Dict: + if FLAGS.use_pure_config: + return {} + return { + "training_data_path": ( + ",".join(FLAGS.training_data_path if not finetuning else FLAGS.finetuning_training_data_path)), + "validation_data_path": ( + ",".join(FLAGS.validation_data_path if not finetuning else FLAGS.finetuning_validation_data_path)), + "pretrained_tokens": FLAGS.pretrained_tokens, + "pretrained_transformer_name": FLAGS.pretrained_transformer_name, + "features": " ".join(FLAGS.features), + "targets": " ".join(FLAGS.targets), + "type": "finetuning" if finetuning else "default", + "embedding_dim": str(FLAGS.embedding_dim), + "cuda_device": str(FLAGS.cuda_device), + "num_epochs": str(FLAGS.num_epochs), + "word_batch_size": str(FLAGS.word_batch_size), + "use_tensorboard": str(FLAGS.tensorboard), + } + + +def main(): + """Register flag validators and run COMBO.""" + flags.register_validator( + "features", + lambda values: all( + value in _FEATURES for value in values), + message="Flag --features contains unknown value(s)." + ) + flags.register_validator( + "mode", + lambda value: value is not None, + message="Flag --mode must be set to either `predict` or `train`") + flags.register_validator( + "targets", + lambda values: all( + value in _TARGETS for value in values), + message="Flag --targets contains unknown value(s)."
+ ) + app.run(run) + + +if __name__ == "__main__": + main() diff --git a/combo/models/__init__.py b/combo/models/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9ad341a12380b13d980905099b864b71266cd95e 100644 --- a/combo/models/__init__.py +++ b/combo/models/__init__.py @@ -0,0 +1,8 @@ +from .base import FeedForwardPredictor +from .graph_parser import GraphDependencyRelationModel +from .parser import DependencyRelationModel +from .embeddings import CharacterBasedWordEmbeddings +from .encoder import ComboEncoder +from .lemma import LemmatizerModel +from .model import ComboModel +from .morpho import MorphologicalFeatures diff --git a/combo/models/base.py b/combo/models/base.py index 11df601bc919523987a56bc77ae6407d5a2db2d9..016ef85da3c5642e5dca3268f0ee8bba1f17c465 100644 --- a/combo/models/base.py +++ b/combo/models/base.py @@ -4,7 +4,11 @@ import torch import torch.nn as nn import utils import combo.models.combo_nn as combo_nn -import combo.checks as checks +import combo.utils.checks as checks + + +class Model: + pass class Predictor(nn.Module): diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..e5ebff89530ffb9131cd0a154a55cd0d5287fa55 100644 --- a/combo/models/embeddings.py +++ b/combo/models/embeddings.py @@ -0,0 +1,24 @@ +class Embedding: + pass + +class TokenEmbedder: + pass + +class CharacterBasedWordEmbeddings(TokenEmbedder): + pass + + +class ProjectedWordEmbedder(TokenEmbedder): + pass + + +class PretrainedTransformerMismatchedEmbedder(TokenEmbedder): + pass + + +class TransformersWordEmbedder(PretrainedTransformerMismatchedEmbedder): + pass + + +class FeatsTokenEmbedder(TokenEmbedder): + pass \ No newline at end of file diff --git a/combo/models/encoder.py b/combo/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f25f4a1fd6b86b2183533787a0cb3b539884f648 --- /dev/null +++ b/combo/models/encoder.py @@ -0,0 +1,10 @@ +class Encoder: + pass + + +class StackedBidirectionalLstm(Encoder): + pass + + +class ComboEncoder(Encoder): + pass diff --git a/combo/models/graph_parser.py b/combo/models/graph_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..82eefde3185ba776c861504d1554dc4fb7cda58d --- /dev/null +++ b/combo/models/graph_parser.py @@ -0,0 +1,9 @@ +from combo.models.base import Predictor + + +class GraphHeadPredictionModel(Predictor): + pass + + +class GraphDependencyRelationModel(Predictor): + pass diff --git a/combo/models/lemma.py b/combo/models/lemma.py new file mode 100644 index 0000000000000000000000000000000000000000..77254f6d3dd9ce4828c7ef54d947f17bf5c46d8d --- /dev/null +++ b/combo/models/lemma.py @@ -0,0 +1,5 @@ +from combo.models.base import Predictor + + +class LemmatizerModel(Predictor): + pass diff --git a/combo/models/model.py b/combo/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..134a6068a2824f9582e23a3dd41cb9ab1bf4ce5c --- /dev/null +++ b/combo/models/model.py @@ -0,0 +1,5 @@ +from combo.models.base import Model + + +class ComboModel(Model): + pass \ No newline at end of file diff --git a/combo/models/morpho.py b/combo/models/morpho.py new file mode 100644 index 0000000000000000000000000000000000000000..4d65686aa2b71e4bca2f9bfcb087b2c4ab032394 --- /dev/null +++ b/combo/models/morpho.py @@ -0,0 +1,5 @@ +from combo.models.base import Predictor + + +class MorphologicalFeatures(Predictor): + pass diff --git a/combo/models/parser.py b/combo/models/parser.py new file mode 100644 index 
0000000000000000000000000000000000000000..46e3766cf52c1a87111c13b914ecf95bf9d540c2 --- /dev/null +++ b/combo/models/parser.py @@ -0,0 +1,9 @@ +from combo.models.base import Predictor + + +class HeadPredictionModel(Predictor): + pass + + +class DependencyRelationModel(Predictor): + pass diff --git a/combo/predict.py b/combo/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..bdf979763e349e4524b6e2ee7f8b8729275a2273 --- /dev/null +++ b/combo/predict.py @@ -0,0 +1,15 @@ +import logging +import os +import sys +from typing import List, Union, Dict, Any + +from combo import data +from combo.data import sentence2conllu, tokens2conllu, conllu2sentence +from combo.models.base import Predictor +from combo.utils import download, graph + +logger = logging.getLogger(__name__) + + +class COMBO(Predictor): + pass diff --git a/combo/training/__init__.py b/combo/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94da4e693c4ef85a0f753b2fa73ddc10eb1a52b7 --- /dev/null +++ b/combo/training/__init__.py @@ -0,0 +1,3 @@ +from .checkpointer import FinishingTrainingCheckpointer +from .scheduler import Scheduler +from .trainer import GradientDescentTrainer diff --git a/combo/training/checkpointer.py b/combo/training/checkpointer.py new file mode 100644 index 0000000000000000000000000000000000000000..9823f148ae0e5b1581762e1aab824ff5cc4abf34 --- /dev/null +++ b/combo/training/checkpointer.py @@ -0,0 +1,6 @@ +class Checkpointer: + pass + + +class FinishingTrainingCheckpointer: + pass diff --git a/combo/training/scheduler.py b/combo/training/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..eff71b90acb652b986c9e797371e50a8fdfd737b --- /dev/null +++ b/combo/training/scheduler.py @@ -0,0 +1,2 @@ +class Scheduler: + pass diff --git a/combo/training/tensorboard_writer.py b/combo/training/tensorboard_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..f56966c1bb175d29589fcb2973b3f13760651d90 --- /dev/null +++ b/combo/training/tensorboard_writer.py @@ -0,0 +1,2 @@ +class NullTensorboardWriter: + pass \ No newline at end of file diff --git a/combo/training/trainer.py b/combo/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb8ff0c3d33f593a8cf80cdd7720d41a8189cc1 --- /dev/null +++ b/combo/training/trainer.py @@ -0,0 +1,13 @@ +from pytorch_lightning import Trainer + + +class Callback: + pass + + +class TransferPatienceEpochCallback: + pass + + +class GradientDescentTrainer(Trainer): + pass diff --git a/combo/utils/__init__.py b/combo/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/combo/utils/checks.py b/combo/utils/checks.py new file mode 100644 index 0000000000000000000000000000000000000000..ea76ac1a49f53219471163f8c83126305f0cf138 --- /dev/null +++ b/combo/utils/checks.py @@ -0,0 +1,15 @@ +import torch + + +class ConfigurationError(Exception): + def __init__(self, message: str): + super().__init__() + self.message = message + + +def file_exists(*paths): + pass + + +def check_size_match(size_1: torch.Size, size_2: torch.Size, tensor_1_name: str, tensor_2_name: str): + pass diff --git a/combo/utils/download.py b/combo/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7ce6f951147e48f7ac793ce7b4e59817da9c89 --- /dev/null +++ b/combo/utils/download.py @@ -0,0 +1,78 @@ +import errno +import logging +import os + +import requests +import 
tqdm +import urllib3 +from requests import adapters, exceptions + +logger = logging.getLogger(__name__) + +DATA_TO_PATH = { + "enhanced" : "iwpt_2020", + "iwpt2021" : "iwpt_2021", + "ud25" : "ud_25", + "ud27" : "ud_27", + "ud29" : "ud_29"} +_URL = "http://s3.clarin-pl.eu/dspace/combo/{data}/{model}.tar.gz" +_HOME_DIR = os.getenv("HOME", os.curdir) +_CACHE_DIR = os.getenv("COMBO_DIR", os.path.join(_HOME_DIR, ".combo")) + + +def download_file(model_name, force=False): + _make_cache_dir() + data = model_name.split("-")[-1] + url = _URL.format(model=model_name, data=DATA_TO_PATH[data]) + local_filename = url.split("/")[-1] + location = os.path.join(_CACHE_DIR, local_filename) + if os.path.exists(location) and not force: + logger.debug("Using cached model.") + return location + chunk_size = 1024 + logger.info(url) + try: + with _requests_retry_session(retries=2).get(url, stream=True) as r: + pbar = tqdm.tqdm(unit="B", total=int(r.headers.get("content-length")), + unit_divisor=chunk_size, unit_scale=True) + with open(location, "wb") as f: + with pbar: + for chunk in r.iter_content(chunk_size): + if chunk: + f.write(chunk) + pbar.update(len(chunk)) + except exceptions.RetryError: + raise ConnectionError(f"Couldn't find or download model {model_name}.tar.gz. " + "Check if model name is correct or try again later!") + + return location + + +def _make_cache_dir(): + try: + os.makedirs(_CACHE_DIR) + logger.info(f"Making cache dir {_CACHE_DIR}") + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def _requests_retry_session( + retries=3, + backoff_factor=0.3, + status_forcelist=(404, 500, 502, 504), + session=None, +): + """Source: https://www.peterbe.com/plog/best-practice-with-retries-with-requests""" + session = session or requests.Session() + retry = urllib3.Retry( + total=retries, + read=retries, + connect=retries, + backoff_factor=backoff_factor, + status_forcelist=status_forcelist, + ) + adapter = adapters.HTTPAdapter(max_retries=retry) + session.mount("http://", adapter) + session.mount("https://", adapter) + return session diff --git a/combo/utils/graph.py b/combo/utils/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..f61a68e5b835da0c2ce3dac438425c602b084240 --- /dev/null +++ b/combo/utils/graph.py @@ -0,0 +1,149 @@ +"""Based on https://github.com/emorynlp/iwpt-shared-task-2020.""" + +import numpy as np + +_ACL_REL_CL = "acl:relcl" + + +def graph_and_tree_merge(tree_arc_scores, + tree_rel_scores, + graph_arc_scores, + graph_rel_scores, + label2idx, + idx2label, + graph_label2idx, + graph_idx2label, + tokens): + graph_arc_scores = np.copy(graph_arc_scores) + # Exclude self-loops, in-place operation. + np.fill_diagonal(graph_arc_scores, 0) + # Connection to root will be handled by tree. + graph_arc_scores[:, 0] = False + # The same with labels. 
+ root_idx = graph_label2idx["root"] + graph_rel_scores[:, :, root_idx] = -float('inf') + graph_rel_pred = graph_rel_scores.argmax(-1) + + # Add tree edges to graph + tree_heads = [0] + tree_arc_scores + graph = [[] for _ in range(len(tree_heads))] + labeled_graph = [[] for _ in range(len(tree_heads))] + for d, h in enumerate(tree_heads): + if not d: + continue + label = idx2label[tree_rel_scores[d - 1]] + # graph_label = graph_idx2label[graph_rel_pred[d - 1][h - 1]] + # if ">" in graph_label and label in graph_label: + # print("Using graph label instead of tree.") + # label = graph_label + if label != _ACL_REL_CL: + graph[h].append(d) + labeled_graph[h].append((d, label)) + + # Debug only + # Extract graph edges + graph_edges = np.argwhere(graph_arc_scores) + + # Add graph edges which aren't creating a cycle + for (d, h) in graph_edges: + if not d or not h or d in graph[h]: + continue + try: + path = next(_dfs(graph, d, h)) + except StopIteration: + # There is not path from d to h + label = graph_idx2label[graph_rel_pred[d][h]] + if label != _ACL_REL_CL: + graph[h].append(d) + labeled_graph[h].append((d, label)) + + # Add 'acl:relcl' without checking for cycles. + for d, h in enumerate(tree_heads): + if not d: + continue + label = idx2label[tree_rel_scores[d - 1]] + if label == _ACL_REL_CL: + graph[h].append(d) + labeled_graph[h].append((d, label)) + + assert len(labeled_graph[0]) == 1 + d = graph[0][0] + graph[d].append(0) + labeled_graph[d].append((0, "root")) + + parse_graph = [[] for _ in range(len(tree_heads))] + for h in range(len(tree_heads)): + for d, label in labeled_graph[h]: + parse_graph[d].append((h, label)) + parse_graph[d] = sorted(parse_graph[d]) + + for i, g in enumerate(parse_graph): + heads = np.array([x[0] for x in g]) + rels = np.array([x[1] for x in g]) + indices = rels.argsort() + heads = heads[indices].tolist() + rels = rels[indices].tolist() + deps = '|'.join(f'{h}:{r}' for h, r in zip(heads, rels)) + tokens[i - 1]["deps"] = deps + return + + +def _dfs(graph, start, end): + fringe = [(start, [])] + while fringe: + state, path = fringe.pop() + if path and state == end: + yield path + continue + for next_state in graph[state]: + if next_state in path: + continue + fringe.append((next_state, path + [next_state])) + + +def restore_collapse_edges(tree_tokens): + # https://gist.github.com/hankcs/776e7d95c19e5ff5da8469fe4e9ab050 + empty_tokens = [] + for token in tree_tokens: + deps = token["deps"].split("|") + for i, d in enumerate(deps): + if ">" in d: + # {head}:{empty_node_relation}>{current_node_relation} + # should map to + # For new, empty node: + # {head}:{empty_node_relation} + # For current node: + # {new_empty_node_id}:{current_node_relation} + # TODO consider where to put new_empty_node_id (currently at the end) + head, relation = d.split(':', 1) + ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 1}" + empty_node_relation, current_node_relation = relation.split(">", 1) + # Edge case, double > + if ">" in current_node_relation: + second_empty_node_relation, current_node_relation = current_node_relation.split(">") + deps[i] = f"{ehead}:{current_node_relation}" + second_ehead = f"{len(tree_tokens)}.{len(empty_tokens) + 2}" + empty_tokens.append( + { + "id": ehead, + "deps": f"{second_ehead}:{empty_node_relation}" + } + ) + empty_tokens.append( + { + "id": second_ehead, + "deps": f"{head}:{second_empty_node_relation}" + } + ) + + else: + deps[i] = f"{ehead}:{current_node_relation}" + empty_tokens.append( + { + "id": ehead, + "deps": 
f"{head}:{empty_node_relation}" + } + ) + deps = sorted([d.split(":", 1) for d in deps], key=lambda x: float(x[0])) + token["deps"] = "|".join([f"{k}:{v}" for k, v in deps]) + return empty_tokens diff --git a/combo/utils/metrics.py b/combo/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..9181040e5516eeee8d84fa253b7c75a0cad82581 --- /dev/null +++ b/combo/utils/metrics.py @@ -0,0 +1,17 @@ +class Metric: + pass + +class LemmaAccuracy(Metric): + pass + + +class SequenceBoolAccuracy(Metric): + pass + + +class AttachmentScores(Metric): + pass + + +class SemanticMetrics(Metric): + pass diff --git a/requirements.txt b/requirements.txt index 61f05905d44038bd0391ccf2a06ca0d9cf7f096e..718d4173d2be86ef9ba72b9f70609bff8a7965a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,11 @@ +absl-py~=1.4.0 conllu~=4.4.1 dependency-injector~=4.41.0 overrides~=7.3.1 torch~=1.13.1 torchtext~=0.14.1 numpy~=1.24.1 -pytorch-lightning~=1.9.0 \ No newline at end of file +pytorch-lightning~=1.9.0 +requests~=2.28.2 +tqdm~=4.64.1 +urllib3~=1.26.14 \ No newline at end of file