Commit baf626c4 authored by Maja Jabłońska

Add TextClassificationJSONReader

parent 9aed5097
1 merge request: !46 Merge COMBO 3.0 into master
combo/data/__init__.py
@@ -5,4 +5,4 @@ from .instance import Instance
 from .token_indexers import (SingleIdTokenIndexer, TokenIndexer, TokenFeatsIndexer)
 from .tokenizers import (Tokenizer, TokenizerToken, CharacterTokenizer, PretrainedTransformerTokenizer,
                          SpacyTokenizer, WhitespaceTokenizer)
-from .dataset_readers import DatasetReader, ClassificationTextfileDatasetReader
+from .dataset_readers import DatasetReader, TextClassificationJSONReader
combo/data/dataset_readers/__init__.py
 from .dataset_reader import DatasetReader
 from .classification_textfile_dataset_reader import ClassificationTextfileDatasetReader
+from .text_classification_json_reader import TextClassificationJSONReader
combo/data/dataset_readers/classification_textfile_dataset_reader.py
from typing import Dict, Iterable, Optional
from overrides import overrides
from .dataset_reader import DatasetReader
from .. import Instance, Tokenizer, TokenIndexer
from ..fields.label_field import LabelField
from ..fields.text_field import TextField
class ClassificationTextfileDatasetReader(DatasetReader):
def __init__(self,
tokenizer: Optional[Tokenizer] = None,
token_indexers: Optional[Dict[str, TokenIndexer]] = None) -> None:
super().__init__(tokenizer, token_indexers)
self.__separator = None
@property
def separator(self) -> str:
return self.__separator
@separator.setter
def separator(self, new_separator: str):
self.__separator = new_separator
@overrides
def _read(self) -> Iterable[Instance]:
if self.file_path is None:
raise ValueError('File path is None')
elif self.separator is None:
raise ValueError('Separator is None')
        with open(self.file_path, 'r') as lines:
            for line in lines:
                line = line.strip()
                if not line:  # skip blank lines, e.g. a trailing newline at EOF
                    continue
                text, label = line.split(self.separator)
text_field = TextField(self.tokenizer.tokenize(text),
self.token_indexers)
label_field = LabelField(label)
fields = {'text': text_field, 'label': label_field}
yield Instance(fields)
def __call__(self, file_path: str, separator: str):
self.file_path = file_path
self.separator = separator
return self
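For reference, a minimal usage sketch of the reader above (the file name and separator are hypothetical); DatasetReader instances are iterable, as the tests further down demonstrate:

from combo.data.dataset_readers import ClassificationTextfileDatasetReader

# Hypothetical tab-separated file with one "text<TAB>label" pair per line.
reader = ClassificationTextfileDatasetReader()
for instance in reader('train.tsv', '\t'):
    print(instance.fields['text'], instance.fields['label'])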
combo/data/dataset_readers/dataset_reader.py
@@ -9,6 +9,7 @@ from typing import Iterable, Iterator, Optional, Union, TypeVar, Dict, List
 from overrides import overrides
 from torch.utils.data import IterableDataset
+from combo.data import SpacyTokenizer, SingleIdTokenIndexer
 from combo.data.instance import Instance
 from combo.data.tokenizers import Tokenizer
 from combo.data.token_indexers import TokenIndexer
@@ -30,8 +31,8 @@ class DatasetReader(IterableDataset):
                  token_indexers: Optional[Dict[str, TokenIndexer]] = None) -> None:
         super(DatasetReader).__init__()
         self.__file_path = None
-        self.__tokenizer = tokenizer
-        self.__token_indexers = token_indexers
+        self.__tokenizer = tokenizer or SpacyTokenizer()
+        self.__token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

     @property
     def file_path(self) -> DatasetReaderInput:
......
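The effect of the change above, sketched briefly: a DatasetReader constructed without arguments now falls back to working defaults instead of storing None (assuming a tokenizer property analogous to file_path, which the readers in this commit rely on via self.tokenizer):

from combo.data.dataset_readers import DatasetReader

reader = DatasetReader()
# The tokenizer defaults to SpacyTokenizer() and the token indexers to
# {"tokens": SingleIdTokenIndexer()}, so subclasses can rely on both.
assert reader.tokenizer is not None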
"""
Partially adapted from AllenNLP
https://github.com/allenai/allennlp/blob/main/allennlp/data/dataset_readers/text_classification_json.py
"""
import json
from typing import Dict, Iterable, Optional, Union, List
from overrides import overrides
from .dataset_reader import DatasetReader
from .utils import MalformedFileException
from .. import Instance, Tokenizer, TokenIndexer
from ..fields import Field, ListField
from ..fields.label_field import LabelField
from ..fields.text_field import TextField
from ...utils import ConfigurationError
def _is_sentence_segmenter(sentence_segmenter: Optional[Tokenizer]) -> bool:
split_sentences_method = getattr(sentence_segmenter, "split_sentences", None)
return callable(split_sentences_method)
class TextClassificationJSONReader(DatasetReader):
def __init__(self,
tokenizer: Optional[Tokenizer] = None,
token_indexers: Optional[Dict[str, TokenIndexer]] = None,
sentence_segmenter: Optional[Tokenizer] = None,
max_sequence_length: Optional[int] = None,
skip_label_indexing: bool = False,
text_key: str = "text",
label_key: str = "label") -> None:
        if ((sentence_segmenter is not None) and
                (not _is_sentence_segmenter(sentence_segmenter))):
            raise ConfigurationError('Passed sentence segmenter has no '
                                     'split_sentences method!')
super().__init__(tokenizer, token_indexers)
self.__sentence_segmenter = sentence_segmenter
self.__max_sequence_length = max_sequence_length
self.__skip_label_indexing = skip_label_indexing
self.__text_key = text_key
self.__label_key = label_key
@property
def sentence_segmenter(self) -> Optional[Tokenizer]:
return self.__sentence_segmenter
@sentence_segmenter.setter
def sentence_segmenter(self, value: Optional[Tokenizer]):
        if ((value is not None) and
                (not _is_sentence_segmenter(value))):
            raise ConfigurationError('Passed sentence segmenter has no '
                                     'split_sentences method!')
self.__sentence_segmenter = value
@property
def max_sequence_length(self) -> Optional[int]:
return self.__max_sequence_length
@max_sequence_length.setter
def max_sequence_length(self, value: Optional[int]):
self.__max_sequence_length = value
@property
def skip_label_indexing(self) -> bool:
return self.__skip_label_indexing
@skip_label_indexing.setter
def skip_label_indexing(self, value: bool):
self.__skip_label_indexing = value
@property
def text_key(self) -> str:
return self.__text_key
@text_key.setter
def text_key(self, value: str):
self.__text_key = value
@property
def label_key(self) -> str:
return self.__label_key
@label_key.setter
def label_key(self, value: str):
self.__label_key = value
@overrides
def _read(self) -> Iterable[Instance]:
if self.file_path is None:
raise ValueError('File path is None')
        with open(self.file_path, "r") as data_file:
            for line in data_file:
                line = line.strip()
                if not line:
                    continue
                items = json.loads(line)
                text = items.get(self.text_key)
                if text is None:
                    raise MalformedFileException(f'Line has no "{self.text_key}" (text) key')
label = items.get(self.label_key)
if label is not None:
if self.skip_label_indexing:
try:
label = int(label)
except ValueError:
raise MalformedFileException("Labels must be integers if skip_label_indexing is True.")
else:
label = str(label)
yield self.text_to_instance(text, label)
def text_to_instance(self,
text: str,
label: Optional[Union[str, int]] = None) -> Instance:
"""
:param text: the text to classify
:param label: the label for the text
:return: Instance containing the following fields:
- tokens ('TextField')
- label ('LabelField')
"""
fields: Dict[str, Field] = {}
if self.sentence_segmenter is not None:
sentences: List[Field] = []
# TODO: some subclass for sentence segmenter for tokenizers?
sentence_splits = self.sentence_segmenter.split_sentences(text)
for sentence in sentence_splits:
word_tokens = self.tokenizer.tokenize(sentence)
if self.max_sequence_length is not None:
word_tokens = self._truncate(word_tokens)
sentences.append(TextField(word_tokens))
fields["tokens"] = ListField(sentences)
else:
tokens = self.tokenizer.tokenize(text)
if self.max_sequence_length is not None:
tokens = self._truncate(tokens)
fields["tokens"] = TextField(tokens)
if label is not None:
fields["label"] = LabelField(label,
skip_indexing=self.skip_label_indexing)
return Instance(fields)
    def _truncate(self, tokens):
        """
        Truncate a list of tokens to max_sequence_length.
        """
if len(tokens) > self.max_sequence_length:
tokens = tokens[: self.max_sequence_length]
return tokens
def __call__(self, file_path: str):
self.file_path = file_path
return self
@overrides
def apply_token_indexers(self, instance: Instance) -> None:
if self.sentence_segmenter is not None:
for text_field in instance.fields["tokens"]: # type: ignore
text_field._token_indexers = self.token_indexers
else:
instance.fields["tokens"]._token_indexers = self.token_indexers # type: ignore
combo/data/dataset_readers/utils.py
@@ -6,3 +6,8 @@ def is_distributed() -> bool:
     Checks if the distributed process group is available and has been initialized
     """
     return dist.is_available() and dist.is_initialized()
+
+
+class MalformedFileException(Exception):
+    def __init__(self, message):
+        super().__init__(message)
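MalformedFileException is raised by the JSON reader above when a line lacks the text key or a label cannot be parsed; callers can catch it, e.g. (file name hypothetical):

from combo.data.dataset_readers import TextClassificationJSONReader
from combo.data.dataset_readers.utils import MalformedFileException

reader = TextClassificationJSONReader()
try:
    instances = list(reader('broken.jsonl'))  # hypothetical malformed file
except MalformedFileException as e:
    print(f'Bad input: {e}')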
import unittest
from combo.data.dataset_readers import TextClassificationJSONReader
from combo.data.fields import LabelField, TextField, ListField
from combo.data.tokenizers import SpacySentenceSplitter
class TextClassificationJSONReaderTest(unittest.TestCase):
    def test_read_two_examples(self):
        reader = TextClassificationJSONReader()
        instances = list(reader('text_classification_json_reader.json'))
        self.assertEqual(len(instances), 2)

    def test_read_two_examples_fields_without_sentence_splitting(self):
        reader = TextClassificationJSONReader()
        instances = list(reader('text_classification_json_reader.json'))
        self.assertEqual(len(instances[0].fields.items()), 2)
        self.assertIsInstance(instances[0].fields.get('label'), LabelField)
        self.assertEqual(instances[0].fields.get('label').label, 'label1')
        self.assertEqual(len(instances[1].fields.items()), 2)
        self.assertIsInstance(instances[1].fields.get('label'), LabelField)
        self.assertEqual(instances[1].fields.get('label').label, 'label2')

    def test_read_two_examples_tokens_without_sentence_splitting(self):
        reader = TextClassificationJSONReader()
        instances = list(reader('text_classification_json_reader.json'))
        self.assertEqual(len(instances[0].fields.items()), 2)
        self.assertIsInstance(instances[0].fields.get('tokens'), TextField)
        self.assertEqual(len(instances[0].fields.get('tokens').tokens), 13)
        self.assertEqual(len(instances[1].fields.items()), 2)
        self.assertIsInstance(instances[1].fields.get('tokens'), TextField)
        self.assertEqual(len(instances[1].fields.get('tokens').tokens), 6)

    def test_read_two_examples_tokens_with_sentence_splitting(self):
        reader = TextClassificationJSONReader(sentence_segmenter=SpacySentenceSplitter())
        instances = list(reader('text_classification_json_reader.json'))
        self.assertEqual(len(instances[0].fields.items()), 2)
        self.assertIsInstance(instances[0].fields.get('tokens'), ListField)
        self.assertEqual(len(instances[0].fields.get('tokens').field_list), 2)
        self.assertEqual(len(instances[1].fields.items()), 2)
        self.assertIsInstance(instances[1].fields.get('tokens'), ListField)
        self.assertEqual(len(instances[1].fields.get('tokens').field_list), 1)
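The suite should run under the standard unittest runner, e.g. (module name hypothetical, from the directory containing the test file and the JSON fixture):

python -m unittest test_text_classification_json_reader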
{"text": "A green cat jumped the bed. Then he went to sleep.", "label": "label1"}
{"text": "A black cat jumped the bed", "label": "label2"}
\ No newline at end of file