From bc8bdc77329199b0c1d20dcdd39049c7d301bfda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maja=20Jab=C5=82o=C5=84ska?= <majajjablonska@gmail.com> Date: Thu, 6 Apr 2023 17:07:03 +0200 Subject: [PATCH] Add get_slices_if_not_provided to data/dataset.py --- combo/data/dataset.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/combo/data/dataset.py b/combo/data/dataset.py index 62ca30f..de4805a 100644 --- a/combo/data/dataset.py +++ b/combo/data/dataset.py @@ -1,5 +1,6 @@ import logging +from combo import data logger = logging.getLogger(__name__) @@ -7,5 +8,30 @@ logger = logging.getLogger(__name__) class DatasetReader: pass + class UniversalDependenciesDatasetReader(DatasetReader): - pass \ No newline at end of file + pass + + +def get_slices_if_not_provided(vocab: data.Vocabulary): + if hasattr(vocab, "slices"): + return vocab.slices + + if "feats_labels" in vocab.get_namespaces(): + idx2token = vocab.get_index_to_token_vocabulary("feats_labels") + for _, v in dict(idx2token).items(): + if v not in ["_", "__PAD__"]: + empty_value = v.split("=")[0] + "=None" + vocab.add_token_to_namespace(empty_value, "feats_labels") + + slices = {} + for idx, name in vocab.get_index_to_token_vocabulary("feats_labels").items(): + # There are 2 types of features: with (Case=Acc) or without assignment (None). + # Here we group their indices by name (before assignment sign). + name = name.split("=")[0] + if name in slices: + slices[name].append(idx) + else: + slices[name] = [idx] + vocab.slices = slices + return vocab.slices -- GitLab