From 6884e6eb10c2a0396eeea076742e792dc4e7b9db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl>
Date: Sun, 4 Feb 2024 16:28:15 +0100
Subject: [PATCH] added option for passing kwargs when predicting + fixed
 default_model to be consistent with config.template.json

---
 combo/default_model.py | 66 ++++++++++++++++++++++++++++++------------
 combo/predict.py       |  4 +--
 2 files changed, 49 insertions(+), 21 deletions(-)

diff --git a/combo/default_model.py b/combo/default_model.py
index f02786b..4ec705d 100644
--- a/combo/default_model.py
+++ b/combo/default_model.py
@@ -55,33 +55,61 @@ def default_ud_dataset_reader(pretrained_transformer_name: str,
         targets=["deprel", "head", "upostag", "lemma", "feats", "xpostag"],
         token_indexers={
             "char": default_character_indexer(),
-            "feats": TokenFeatsIndexer(),
-            "lemma": default_character_indexer(),
+            # "feats": TokenFeatsIndexer(),
+            # "lemma": default_character_indexer(),
             "token": PretrainedTransformerFixedMismatchedIndexer(pretrained_transformer_name),
-            "upostag": SingleIdTokenIndexer(
-                feature_name="pos_",
-                namespace="upostag"
-            ),
-            "xpostag": SingleIdTokenIndexer(
-                feature_name="tag_",
-                namespace="xpostag"
-            )
+            # "upostag": SingleIdTokenIndexer(
+            #     feature_name="pos_",
+            #     namespace="upostag"
+            # ),
+            # "xpostag": SingleIdTokenIndexer(
+            #     feature_name="tag_",
+            #     namespace="xpostag"
+            # )
         },
         use_sem=False,
         tokenizer=tokenizer
     )
 
 
-def default_data_loader(dataset_reader: DatasetReader,
+def default_data_loader(
+                        dataset_reader: DatasetReader,
                         file_path: str,
-                        batch_size: int = 16,
-                        batches_per_epoch: int = 4) -> SimpleDataLoader:
-    return SimpleDataLoader.from_dataset_reader(dataset_reader,
-                                                data_path=file_path,
-                                                batch_size=batch_size,
-                                                batches_per_epoch=batches_per_epoch,
-                                                shuffle=True,
-                                                collate_fn=lambda instances: Batch(instances).as_tensor_dict())
+                        batch_size: int = 1,
+                        batches_per_epoch: int = 64) -> SimpleDataLoader:
+    # tokenizer = tokenizer or LamboTokenizer()
+    # reader = UniversalDependenciesDatasetReader(
+    #     features=["token", "char"],
+    #     lemma_indexers={
+    #         "char": default_character_indexer("lemma_characters")
+    #     },
+    #     targets=["deprel", "head", "upostag", "lemma", "feats", "xpostag"],
+    #     token_indexers={
+    #         "char": default_character_indexer(),
+    #         # "feats": TokenFeatsIndexer(),
+    #         # "lemma": default_character_indexer(),
+    #         "token": PretrainedTransformerFixedMismatchedIndexer(pretrained_transformer_name),
+    #         # "upostag": SingleIdTokenIndexer(
+    #         #     feature_name="pos_",
+    #         #     namespace="upostag"
+    #         # ),
+    #         # "xpostag": SingleIdTokenIndexer(
+    #         #     feature_name="tag_",
+    #         #     namespace="xpostag"
+    #         # )
+    #     },
+    #     use_sem=False,
+    #     tokenizer=tokenizer
+    # )
+
+    return SimpleDataLoader.from_dataset_reader(
+        dataset_reader,
+        data_path=file_path,
+        batch_size=batch_size,
+        batches_per_epoch=batches_per_epoch,
+        shuffle=True,
+        quiet=False,
+        collate_fn=lambda instances: Batch(instances).as_tensor_dict())
 
 
 def default_vocabulary(data_loader: DataLoader) -> Vocabulary:
diff --git a/combo/predict.py b/combo/predict.py
index c4f507b..84e9aed 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -42,7 +42,7 @@ class COMBO(PredictorModule):
         self.without_sentence_embedding = False
         self.line_to_conllu = line_to_conllu
 
-    def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
+    def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]], **kwargs):
         """Depending on the input uses (or ignores) tokenizer.
         When model isn't only text-based only List[data.Sentence] is possible input.
 
@@ -55,7 +55,7 @@ class COMBO(PredictorModule):
         :return: Sentence or List[Sentence] depending on the input
         """
         try:
-            return self.predict(sentence)
+            return self.predict(sentence, **kwargs)
         except Exception as e:
             logger.error(e)
             logger.error('Exiting.')
-- 
GitLab