From 5d55691a198b02a39850227ea0250dacb66e1f52 Mon Sep 17 00:00:00 2001 From: Maja Jablonska <majajjablonska@gmail.com> Date: Wed, 26 Apr 2023 17:28:13 +0200 Subject: [PATCH] Add ClassificationTextfileDatasetReader --- .../classification_textfile_dataset_reader.py | 20 +++++++++++++------ combo/data/dataset_readers/dataset_reader.py | 8 ++++---- combo/data/fields/adjacency_field.py | 4 ++-- combo/data/fields/label_field.py | 2 +- combo/data/fields/list_field.py | 2 +- combo/data/fields/sequence_label_field.py | 2 +- .../data/fields/sequence_multilabel_field.py | 2 +- combo/data/instance.py | 2 +- requirements.txt | 2 ++ requirements_no_deps.txt | 1 + 10 files changed, 28 insertions(+), 17 deletions(-) create mode 100644 requirements_no_deps.txt diff --git a/combo/data/dataset_readers/classification_textfile_dataset_reader.py b/combo/data/dataset_readers/classification_textfile_dataset_reader.py index 9ec7ca9..a535a89 100644 --- a/combo/data/dataset_readers/classification_textfile_dataset_reader.py +++ b/combo/data/dataset_readers/classification_textfile_dataset_reader.py @@ -1,8 +1,8 @@ from typing import Dict, Iterable, Optional -from .dataset_reader import DatasetReader, DatasetReaderInput from overrides import overrides +from .dataset_reader import DatasetReader from .. import Instance, Tokenizer, TokenIndexer from ..fields.label_field import LabelField from ..fields.text_field import TextField @@ -10,12 +10,10 @@ from ..fields.text_field import TextField class ClassificationTextfileDatasetReader(DatasetReader): def __init__(self, - file_path: Optional[DatasetReaderInput] = None, tokenizer: Optional[Tokenizer] = None, - token_indexers: Optional[Dict[str, TokenIndexer]] = None, - separator: str = ',') -> None: - super().__init__(file_path, tokenizer, token_indexers) - self.__separator = separator + token_indexers: Optional[Dict[str, TokenIndexer]] = None) -> None: + super().__init__(tokenizer, token_indexers) + self.__separator = None @property def separator(self) -> str: @@ -27,6 +25,11 @@ class ClassificationTextfileDatasetReader(DatasetReader): @overrides def _read(self) -> Iterable[Instance]: + if self.file_path is None: + raise ValueError('File path is None') + elif self.separator is None: + raise ValueError('Separator is None') + with open(self.file_path, 'r') as lines: for line in lines: text, label = line.strip().split(self.separator) @@ -35,3 +38,8 @@ class ClassificationTextfileDatasetReader(DatasetReader): label_field = LabelField(label) fields = {'text': text_field, 'label': label_field} yield Instance(fields) + + def __call__(self, file_path: str, separator: str): + self.file_path = file_path + self.separator = separator + return self diff --git a/combo/data/dataset_readers/dataset_reader.py b/combo/data/dataset_readers/dataset_reader.py index 6a02274..0b1ffce 100644 --- a/combo/data/dataset_readers/dataset_reader.py +++ b/combo/data/dataset_readers/dataset_reader.py @@ -9,7 +9,8 @@ from typing import Iterable, Iterator, Optional, Union, TypeVar, Dict, List from overrides import overrides from torch.utils.data import IterableDataset -from combo.data import Instance, Tokenizer +from combo.data.instance import Instance +from combo.data.tokenizers import Tokenizer from combo.data.token_indexers import TokenIndexer logger = logging.getLogger(__name__) @@ -25,11 +26,10 @@ class DatasetReader(IterableDataset): of `Instance`s. """ def __init__(self, - file_path: Optional[DatasetReaderInput] = None, tokenizer: Optional[Tokenizer] = None, token_indexers: Optional[Dict[str, TokenIndexer]] = None) -> None: super(DatasetReader).__init__() - self.__file_path = file_path + self.__file_path = None self.__tokenizer = tokenizer self.__token_indexers = token_indexers @@ -50,7 +50,7 @@ class DatasetReader(IterableDataset): return self.__token_indexers @overrides - def __getitem__(self, item) -> Instance: + def __getitem__(self, item, **kwargs) -> Instance: raise NotImplementedError @overrides diff --git a/combo/data/fields/adjacency_field.py b/combo/data/fields/adjacency_field.py index 492b5b9..a0ac7ab 100644 --- a/combo/data/fields/adjacency_field.py +++ b/combo/data/fields/adjacency_field.py @@ -9,8 +9,8 @@ import textwrap import torch -from combo.data import Vocabulary -from combo.data.fields import Field +from combo.data.vocabulary import Vocabulary +from combo.data.fields.field import Field from combo.data.fields.sequence_field import SequenceField from combo.utils import ConfigurationError diff --git a/combo/data/fields/label_field.py b/combo/data/fields/label_field.py index 12bad9a..3ba6097 100644 --- a/combo/data/fields/label_field.py +++ b/combo/data/fields/label_field.py @@ -9,7 +9,7 @@ import logging import torch -from combo.data import Vocabulary +from combo.data.vocabulary import Vocabulary from combo.data.fields import Field from combo.utils import ConfigurationError diff --git a/combo/data/fields/list_field.py b/combo/data/fields/list_field.py index 24e57f0..63d7cd5 100644 --- a/combo/data/fields/list_field.py +++ b/combo/data/fields/list_field.py @@ -4,7 +4,7 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/fields/list_field.py """ from typing import Dict, List, Iterator, Sequence, Any -from combo.data import Vocabulary +from combo.data.vocabulary import Vocabulary from combo.data.fields.field import DataArray, Field from combo.data.fields.sequence_field import SequenceField from combo.utils import pad_sequence_to_length diff --git a/combo/data/fields/sequence_label_field.py b/combo/data/fields/sequence_label_field.py index 4d0299d..2a33940 100644 --- a/combo/data/fields/sequence_label_field.py +++ b/combo/data/fields/sequence_label_field.py @@ -10,7 +10,7 @@ import textwrap import torch -from combo.data import Vocabulary +from combo.data.vocabulary import Vocabulary from combo.data.fields import Field from combo.data.fields.sequence_field import SequenceField from combo.utils import ConfigurationError, pad_sequence_to_length diff --git a/combo/data/fields/sequence_multilabel_field.py b/combo/data/fields/sequence_multilabel_field.py index a1e6683..a13a499 100644 --- a/combo/data/fields/sequence_multilabel_field.py +++ b/combo/data/fields/sequence_multilabel_field.py @@ -11,7 +11,7 @@ from typing import Set, List, Callable, Iterator, Union, Dict import torch from overrides import overrides -from combo.data import Vocabulary +from combo.data.vocabulary import Vocabulary from combo.data.fields import Field from combo.data.fields.sequence_field import SequenceField from combo.utils import ConfigurationError diff --git a/combo/data/instance.py b/combo/data/instance.py index b4c008a..34771ae 100644 --- a/combo/data/instance.py +++ b/combo/data/instance.py @@ -4,7 +4,7 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/instance.py """ from typing import Dict, MutableMapping, Mapping, Any -from combo.data import Vocabulary +from combo.data.vocabulary import Vocabulary from combo.data.fields import Field from combo.data.fields.field import DataArray diff --git a/requirements.txt b/requirements.txt index 119654b..f7222b4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,10 @@ absl-py~=1.4.0 base58~=2.1.1 cached-path~=1.3.3 conllu~=4.4.1 +conllutils~=1.1.4 dependency-injector~=4.41.0 dill~=0.3.6 +importlib-resources~=5.12.0 overrides~=7.3.1 torch~=2.0.0 torchtext~=0.15.1 diff --git a/requirements_no_deps.txt b/requirements_no_deps.txt new file mode 100644 index 0000000..1609845 --- /dev/null +++ b/requirements_no_deps.txt @@ -0,0 +1 @@ +lambo @ git+https://gitlab.clarin-pl.eu/syntactic-tools/lambo.git \ No newline at end of file -- GitLab