diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 62ca30f2af2c3380f311c5f0fd2accca2d225ab9..de4805aee7ba1373e6fcb2a892605eb5a95ef553 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -1,5 +1,6 @@ import logging +from combo import data logger = logging.getLogger(__name__) @@ -7,5 +8,30 @@ logger = logging.getLogger(__name__) class DatasetReader: pass + class UniversalDependenciesDatasetReader(DatasetReader): - pass \ No newline at end of file + pass + + +def get_slices_if_not_provided(vocab: data.Vocabulary): + if hasattr(vocab, "slices"): + return vocab.slices + + if "feats_labels" in vocab.get_namespaces(): + idx2token = vocab.get_index_to_token_vocabulary("feats_labels") + for _, v in dict(idx2token).items(): + if v not in ["_", "__PAD__"]: + empty_value = v.split("=")[0] + "=None" + vocab.add_token_to_namespace(empty_value, "feats_labels") + + slices = {} + for idx, name in vocab.get_index_to_token_vocabulary("feats_labels").items(): + # There are 2 types features: with (Case=Acc) or without assigment (None). + # Here we group their indices by name (before assigment sign). + name = name.split("=")[0] + if name in slices: + slices[name].append(idx) + else: + slices[name] = [idx] + vocab.slices = slices + return vocab.slices