From 25ccba608b668a2f96ecc1940013a6a783c90b9f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl>
Date: Fri, 2 Feb 2024 23:14:19 +0100
Subject: [PATCH] Fix multiword prediction and a bug that made the code
 write empty predictions
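
Two changes in lambo_tokenizer.py:

* The multiword span recorded for a token's subwords always ended at
  subword_idxs[1], so any token split into more than two subwords got a
  truncated range; subword_idxs[-1] makes the span cover the whole
  split.

* _reset_idx() is now called before each sentence so token indices
  restart per sentence, and every token is emitted through
  _sentence_tokens(), with subwords that disagree with token.text
  re-sliced from the token text first.

A minimal sketch of the span fix (the ids are illustrative only):

    subword_idxs = [5, 6, 7]             # token split into three subwords
    (subword_idxs[0], subword_idxs[1])   # old: (5, 6), drops the third id
    (subword_idxs[0], subword_idxs[-1])  # new: (5, 7), spans the whole split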

---
 combo/data/tokenizers/lambo_tokenizer.py | 33 +++++++++++++++++++++++---
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/combo/data/tokenizers/lambo_tokenizer.py b/combo/data/tokenizers/lambo_tokenizer.py
index e0beb83..e88f098 100644
--- a/combo/data/tokenizers/lambo_tokenizer.py
+++ b/combo/data/tokenizers/lambo_tokenizer.py
@@ -28,7 +28,7 @@ def _sentence_tokens(token: Token,
                      split_subwords: Optional[bool] = None) -> List[Token]:
     if split_subwords and len(token.subwords) > 0:
         subword_idxs = [next(_token_idx()) for _ in range(len(token.subwords))]
-        multiword = (token.text, (subword_idxs[0], subword_idxs[1]))
+        multiword = (token.text, (subword_idxs[0], subword_idxs[-1]))
         tokens = [Token(idx=s_idx, text=subword, multiword=multiword) for (s_idx, subword)
                   in zip(subword_idxs, token.subwords)]
         return tokens
@@ -74,12 +74,14 @@ class LamboTokenizer(Tokenizer):
             for turn in document.turns:
                 sentence_tokens = []
                 for sentence in turn.sentences:
+                    _reset_idx()
                     for token in sentence.tokens:
                         sentence_tokens.extend(_sentence_tokens(token, split_subwords))
                 tokens.append(sentence_tokens)
         elif split_level.upper() == "SENTENCE":
             for turn in document.turns:
                 for sentence in turn.sentences:
+                    _reset_idx()
                     sentence_tokens = []
                     for token in sentence.tokens:
                         sentence_tokens.extend(_sentence_tokens(token, split_subwords))
@@ -87,6 +89,7 @@ class LamboTokenizer(Tokenizer):
         else:
             for turn in document.turns:
                 for sentence in turn.sentences:
+                    _reset_idx()
                     for token in sentence.tokens:
                         tokens.extend(_sentence_tokens(token, split_subwords))
             tokens = [tokens]
@@ -116,13 +119,35 @@
             if turns:
                 sentence_tokens = []
             for sentence in turn.sentences:
+                _reset_idx()
                 if not turns:
                     sentence_tokens = []
                 for token in sentence.tokens:
                     if len(token.subwords) > 0 and split_subwords:
-                        sentence_tokens.extend([s for s in token.subwords])
-                    else:
-                        sentence_tokens.append(token.text)
+                        # TODO: workaround for a Lambo model shortcoming: for longer
+                        # multiword tokens it tends to drop the last letter of the
+                        # final subword, so the subwords are re-sliced from token.text.
+
+                        # Check whether token.subwords are consistent with token.text.
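+                        # Hypothetical illustration: if token.text were "gonna" but
+                        # Lambo returned subwords ["gon", "n"], the loop below would
+                        # re-slice the last subword and yield ["gon", "na"].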
+                        if "".join(token.subwords) != token.text:
+                            fixed_subwords = []
+                            text_it = 0
+                            for i, subword in enumerate(token.subwords):
+                                if token.text[text_it:text_it + len(subword)] == subword:
+                                    if i == len(token.subwords) - 1 and (text_it + len(subword) < len(token.text)):
+                                        subword = token.text[text_it:]
+                                    fixed_subwords.append(subword)
+                                    text_it += len(subword)
+                                else:
+                                    fixed_subwords.append(token.text[text_it:text_it + len(subword)])
+                                    text_it += len(subword)
+                            token.subwords = fixed_subwords
+                    # Both multiword and plain tokens are emitted through the
+                    # same _sentence_tokens() call below.
+                    sentence_tokens.extend(_sentence_tokens(token, split_subwords))
                 if not turns:
                     sentences.append(sentence_tokens)
             if turns:
-- 
GitLab