From 18627a235b00399e738f1a7e78c59d6377036d96 Mon Sep 17 00:00:00 2001
From: Maja Jablonska <majajjablonska@gmail.com>
Date: Wed, 9 Aug 2023 16:21:39 +0200
Subject: [PATCH] Add LAMBO as a tokenizer
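
LAMBO segments text into turns, sentences and tokens, and can split
multi-word tokens into subwords; the detected subwords are now stored on
`Token`. The patch also brings in a `DataLoader` base class and a
`SimpleDataLoader` adapted from AllenNLP.

A minimal usage sketch for the tokenizer, based on the new tests (requires
the `lambo` package added to requirements.txt and LAMBO model data for the
chosen language):

    from combo.data import LamboTokenizer

    tokenizer = LamboTokenizer("English")   # wraps Lambo.get("English")
    tokens = tokenizer.tokenize("I don't like apples.")
    [t.text for t in tokens]       # ['I', "don't", 'like', 'apples', '.']
    [t.subwords for t in tokens]   # [[], ['do', "n't"], [], [], []]
    tokenizer.segment("Hello cats. I love you")   # one list of token texts per sentence

A sketch of the simple loader, assuming `instances` is a list of `Instance`s
and `vocab` a `Vocabulary` (placeholder names, not part of this patch):

    from combo.data.dataset_loaders.simple_data_loader import SimpleDataLoader

    loader = SimpleDataLoader(instances, batch_size=16, shuffle=True, vocab=vocab)
    for batch in loader:   # each batch is a TensorDict
        ...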

---
 combo/data/__init__.py                        |   2 +-
 combo/data/dataset_loaders/__init__.py        |   0
 combo/data/dataset_loaders/dataset_loader.py  |  54 ++++++++++
 .../dataset_loaders/simple_data_loader.py     | 102 ++++++++++++++++++
 combo/data/tokenizers/__init__.py             |   1 +
 combo/data/tokenizers/lambo_tokenizer.py      |  53 +++++++++
 combo/data/tokenizers/token.py                |   5 +
 requirements.txt                              |   1 +
 tests/data/data_readers/test_conll.py         |   6 ++
 tests/data/tokenizers/test_lambo_tokenizer.py |  21 ++++
 10 files changed, 244 insertions(+), 1 deletion(-)
 create mode 100644 combo/data/dataset_loaders/__init__.py
 create mode 100644 combo/data/dataset_loaders/dataset_loader.py
 create mode 100644 combo/data/dataset_loaders/simple_data_loader.py
 create mode 100644 combo/data/tokenizers/lambo_tokenizer.py
 create mode 100644 tests/data/tokenizers/test_lambo_tokenizer.py

diff --git a/combo/data/__init__.py b/combo/data/__init__.py
index 8f6789b..f75b7d8 100644
--- a/combo/data/__init__.py
+++ b/combo/data/__init__.py
@@ -4,6 +4,6 @@ from .samplers import TokenCountBatchSampler
 from .instance import Instance
 from .token_indexers import (SingleIdTokenIndexer, TokenIndexer, TokenFeatsIndexer)
 from .tokenizers import (Tokenizer, Token, CharacterTokenizer, PretrainedTransformerTokenizer,
-                         SpacyTokenizer, WhitespaceTokenizer)
+                         SpacyTokenizer, WhitespaceTokenizer, LamboTokenizer)
 from .dataset_readers import (ConllDatasetReader, DatasetReader,
                               TextClassificationJSONReader, UniversalDependenciesDatasetReader)
diff --git a/combo/data/dataset_loaders/__init__.py b/combo/data/dataset_loaders/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/combo/data/dataset_loaders/dataset_loader.py b/combo/data/dataset_loaders/dataset_loader.py
new file mode 100644
index 0000000..27a7fda
--- /dev/null
+++ b/combo/data/dataset_loaders/dataset_loader.py
@@ -0,0 +1,54 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/data_loaders/data_loader.py
+"""
+
+from typing import Dict, Union, Iterator
+
+import torch
+
+from combo.data import Vocabulary, Instance
+
+TensorDict = Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]
+"""
+`TensorDict` is the type we use for batches.
+"""
+
+
+class DataLoader:
+    """
+    A `DataLoader` is responsible for generating batches of instances from a
+    [`DatasetReader`](/api/data/dataset_readers/dataset_reader/#datasetreader),
+    or another source of data.
+
+    This is purely an abstract base class. All concrete subclasses must provide
+    implementations of the following methods:
+
+      - [`__iter__()`](#__iter__) that creates an iterable of `TensorDict`s,
+      - [`iter_instances()`](#iter_instances) that creates an iterable of `Instance`s,
+      - [`index_with()`](#index_with) that should index the data with a vocabulary, and
+      - [`set_target_device()`](#set_target_device), which updates the device that batch
+        tensors should be put on when they are generated in `__iter__()`.
+
+    Additionally, this class should also implement `__len__()` when possible.
+
+    The default implementation is
+    [`MultiProcessDataLoader`](../multiprocess_data_loader/#multiprocessdataloader).
+    """
+
+    default_implementation = "multiprocess"
+
+    def __len__(self) -> int:
+        raise TypeError
+
+    def __iter__(self) -> Iterator[TensorDict]:
+        raise NotImplementedError
+
+    def iter_instances(self) -> Iterator[Instance]:
+        raise NotImplementedError
+
+    def index_with(self, vocab: Vocabulary) -> None:
+        raise NotImplementedError
+
+    def set_target_device(self, device: torch.device) -> None:
+        raise NotImplementedError
diff --git a/combo/data/dataset_loaders/simple_data_loader.py b/combo/data/dataset_loaders/simple_data_loader.py
new file mode 100644
index 0000000..613b514
--- /dev/null
+++ b/combo/data/dataset_loaders/simple_data_loader.py
@@ -0,0 +1,102 @@
+"""
+Adapted from AllenNLP
+https://github.com/allenai/allennlp/blob/main/allennlp/data/data_loaders/simple_data_loader.py
+"""
+
+import math
+import random
+from typing import Optional, List, Iterator
+
+
+import torch
+
+from allennlp.common.util import lazy_groups_of
+from allennlp.common.tqdm import Tqdm
+from allennlp.data.data_loaders.data_collator import DefaultDataCollator
+from allennlp.data.dataset_readers import DatasetReader
+from allennlp.data.instance import Instance
+from allennlp.data.vocabulary import Vocabulary
+import allennlp.nn.util as nn_util
+
+from combo.data.dataset_loaders.dataset_loader import DataLoader, TensorDict
+
+
+class SimpleDataLoader(DataLoader):
+    """
+    A very simple `DataLoader` that is mostly used for testing.
+    """
+
+    def __init__(
+        self,
+        instances: List[Instance],
+        batch_size: int,
+        *,
+        shuffle: bool = False,
+        batches_per_epoch: Optional[int] = None,
+        vocab: Optional[Vocabulary] = None,
+    ) -> None:
+        self.instances = instances
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.batches_per_epoch = batches_per_epoch
+        self.vocab = vocab
+        self.cuda_device: Optional[torch.device] = None
+        self._batch_generator: Optional[Iterator[TensorDict]] = None
+        self.collate_fn = DefaultDataCollator()
+
+    def __len__(self) -> int:
+        if self.batches_per_epoch is not None:
+            return self.batches_per_epoch
+        return math.ceil(len(self.instances) / self.batch_size)
+
+    def __iter__(self) -> Iterator[TensorDict]:
+        if self.batches_per_epoch is None:
+            yield from self._iter_batches()
+        else:
+            if self._batch_generator is None:
+                self._batch_generator = self._iter_batches()
+            for i in range(self.batches_per_epoch):
+                try:
+                    yield next(self._batch_generator)
+                except StopIteration:  # data_generator is exhausted
+                    self._batch_generator = self._iter_batches()  # so refresh it
+                    yield next(self._batch_generator)
+
+    def _iter_batches(self) -> Iterator[TensorDict]:
+        if self.shuffle:
+            random.shuffle(self.instances)
+        for batch in lazy_groups_of(self.iter_instances(), self.batch_size):
+            tensor_dict = self.collate_fn(batch)
+            if self.cuda_device is not None:
+                tensor_dict = nn_util.move_to_device(tensor_dict, self.cuda_device)
+            yield tensor_dict
+
+    def iter_instances(self) -> Iterator[Instance]:
+        for instance in self.instances:
+            if self.vocab is not None:
+                instance.index_fields(self.vocab)
+            yield instance
+
+    def index_with(self, vocab: Vocabulary) -> None:
+        self.vocab = vocab
+        for instance in self.instances:
+            instance.index_fields(self.vocab)
+
+    def set_target_device(self, device: torch.device) -> None:
+        self.cuda_device = device
+
+    @classmethod
+    def from_dataset_reader(
+        cls,
+        reader: DatasetReader,
+        data_path: str,
+        batch_size: int,
+        shuffle: bool = False,
+        batches_per_epoch: Optional[int] = None,
+        quiet: bool = False,
+    ) -> "SimpleDataLoader":
+        instance_iter = reader.read(data_path)
+        if not quiet:
+            instance_iter = Tqdm.tqdm(instance_iter, desc="loading instances")
+        instances = list(instance_iter)
+        return cls(instances, batch_size, shuffle=shuffle, batches_per_epoch=batches_per_epoch)
diff --git a/combo/data/tokenizers/__init__.py b/combo/data/tokenizers/__init__.py
index 1e404f0..0afccdd 100644
--- a/combo/data/tokenizers/__init__.py
+++ b/combo/data/tokenizers/__init__.py
@@ -4,3 +4,4 @@ from .pretrained_transformer_tokenizer import PretrainedTransformerTokenizer
 from .spacy_tokenizer import SpacyTokenizer
 from .sentence_splitter import SentenceSplitter, SpacySentenceSplitter
 from .whitespace_tokenizer import WhitespaceTokenizer
+from .lambo_tokenizer import LamboTokenizer
diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
new file mode 100644
index 0000000..43a5ac6
--- /dev/null
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -0,0 +1,53 @@
+from typing import List, Dict, Any
+
+from combo.data.tokenizers.token import Token
+from combo.data.tokenizers.tokenizer import Tokenizer
+from lambo.segmenter.lambo import Lambo
+
+
+class LamboTokenizer(Tokenizer):
+
+    def __init__(
+            self,
+            language: str = "English"
+    ):
+        self._language = language
+        self.__tokenizer = Lambo.get(language)
+
+    def tokenize(self, text: str) -> List[Token]:
+        """
+        Simple tokenization - ignoring the sentence splits
+        :param text:
+        :return:
+        """
+        document = self.__tokenizer.segment(text)
+        tokens = []
+
+        for turn in document.turns:
+            for sentence in turn.sentences:
+                for token in sentence.tokens:
+                    tokens.append(Token(token.text, subwords=token.subwords))
+
+        return tokens
+
+    def segment(self, text: str) -> List[List[str]]:
+        """
+        Full segmentation - segment into sentences
+        :param text:
+        :return:
+        """
+
+        document = self.__tokenizer.segment(text)
+        sentences = []
+
+        for turn in document.turns:
+            for sentence in turn.sentences:
+                sentences.append([t.text for t in sentence.tokens])
+
+        return sentences
+
+    def _to_params(self) -> Dict[str, Any]:
+        return {
+            "type": "lambo",
+            "language": self._language
+        }
diff --git a/combo/data/tokenizers/token.py b/combo/data/tokenizers/token.py
index 76c205a..d8f2590 100644
--- a/combo/data/tokenizers/token.py
+++ b/combo/data/tokenizers/token.py
@@ -28,6 +28,7 @@ class Token:
         "deprel",
         "deps",
         "misc",
+        "subwords",
         "semrel",
         "embeddings",
         "text_id"
@@ -44,6 +45,7 @@ class Token:
     deprel: Optional[str]  # dep_ ?
     deps: Optional[str]
     misc: Optional[str]
+    subwords: Optional[List[str]]
     semrel: Optional[str]
     embeddings: Dict[str, List[float]]
     text_id: Optional[int]
@@ -60,6 +62,7 @@ class Token:
                  deprel: str = None,
                  deps: str = None,
                  misc: str = None,
+                 subwords: List[str] = None,
                  semrel: str = None,
                  embeddings: Dict[str, List[float]] = None,
                  text_id: int = None) -> None:
@@ -76,6 +79,7 @@ class Token:
         self.deprel = deprel
         self.deps = deps
         self.misc = misc
+        self.subwords = subwords if subwords else []
         self.semrel = semrel
 
         if embeddings is None:
@@ -105,6 +109,7 @@ class Token:
             f"(deprel: {self.deprel}) "
             f"(deps: {self.deps}) "
             f"(misc: {self.misc}) "
+            f"(subwords: {','.join(self.subwords)})"
             f"(semrel: {self.semrel}) "
             f"(embeddings: {self.embeddings}) "
             f"(text_id: {self.text_id})"
diff --git a/requirements.txt b/requirements.txt
index f7222b4..1a928d8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,6 +9,7 @@ importlib-resources~=5.12.0
 overrides~=7.3.1
 torch~=2.0.0
 torchtext~=0.15.1
+lambo~=2.0.0
 numpy~=1.24.1
 pytorch-lightning~=2.0.01
 requests~=2.28.2
diff --git a/tests/data/data_readers/test_conll.py b/tests/data/data_readers/test_conll.py
index f516a24..134d0a9 100644
--- a/tests/data/data_readers/test_conll.py
+++ b/tests/data/data_readers/test_conll.py
@@ -1,6 +1,7 @@
 import unittest
 
 from combo.data import ConllDatasetReader
+from torch.utils.data import DataLoader
 
 
 class ConllDatasetReaderTest(unittest.TestCase):
@@ -9,6 +10,11 @@ class ConllDatasetReaderTest(unittest.TestCase):
         tokens = [token for token in reader('conll_test_file.txt')]
         self.assertEqual(len(tokens), 6)
 
+    def test_read_all_tokens_data_loader(self):
+        reader = ConllDatasetReader(coding_scheme='IOB2')
+        loader = DataLoader(reader('conll_test_file.txt'), batch_size=16)
+        self.assertIsNotNone(next(iter(loader)))
+
     def test_tokenize_correct_tokens(self):
         reader = ConllDatasetReader(coding_scheme='IOB2')
         token = next(iter(reader('conll_test_file.txt')))
diff --git a/tests/data/tokenizers/test_lambo_tokenizer.py b/tests/data/tokenizers/test_lambo_tokenizer.py
new file mode 100644
index 0000000..7c49cc5
--- /dev/null
+++ b/tests/data/tokenizers/test_lambo_tokenizer.py
@@ -0,0 +1,21 @@
+import unittest
+
+from combo.data import LamboTokenizer
+
+
+class LamboTokenizerTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.lambo_tokenizer = LamboTokenizer()
+
+    def test_tokenize_sentence(self):
+        tokens = self.lambo_tokenizer.tokenize('Hello cats. I love you')
+        self.assertListEqual([t.text for t in tokens],
+                             ['Hello', 'cats', '.', 'I', 'love', 'you'])
+
+    def test_tokenize_sentence_with_multiword(self):
+        tokens = self.lambo_tokenizer.tokenize('I don\'t like apples.')
+        self.assertListEqual([t.text for t in tokens],
+                             ['I', 'don\'t', 'like', 'apples', '.'])
+        self.assertListEqual([t.subwords for t in tokens],
+                             [[], ['do', 'n\'t'], [], [], []])
-- 
GitLab