From 6884e6eb10c2a0396eeea076742e792dc4e7b9db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl> Date: Sun, 4 Feb 2024 16:28:15 +0100 Subject: [PATCH] added option for passing kwargs when predicting + fixed default_model to be consistent with config.template.json --- combo/default_model.py | 66 ++++++++++++++++++++++++++++++------------ combo/predict.py | 4 +-- 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/combo/default_model.py b/combo/default_model.py index f02786b..4ec705d 100644 --- a/combo/default_model.py +++ b/combo/default_model.py @@ -55,33 +55,61 @@ def default_ud_dataset_reader(pretrained_transformer_name: str, targets=["deprel", "head", "upostag", "lemma", "feats", "xpostag"], token_indexers={ "char": default_character_indexer(), - "feats": TokenFeatsIndexer(), - "lemma": default_character_indexer(), + # "feats": TokenFeatsIndexer(), + # "lemma": default_character_indexer(), "token": PretrainedTransformerFixedMismatchedIndexer(pretrained_transformer_name), - "upostag": SingleIdTokenIndexer( - feature_name="pos_", - namespace="upostag" - ), - "xpostag": SingleIdTokenIndexer( - feature_name="tag_", - namespace="xpostag" - ) + # "upostag": SingleIdTokenIndexer( + # feature_name="pos_", + # namespace="upostag" + # ), + # "xpostag": SingleIdTokenIndexer( + # feature_name="tag_", + # namespace="xpostag" + # ) }, use_sem=False, tokenizer=tokenizer ) -def default_data_loader(dataset_reader: DatasetReader, +def default_data_loader( + dataset_reader: DatasetReader, file_path: str, - batch_size: int = 16, - batches_per_epoch: int = 4) -> SimpleDataLoader: - return SimpleDataLoader.from_dataset_reader(dataset_reader, - data_path=file_path, - batch_size=batch_size, - batches_per_epoch=batches_per_epoch, - shuffle=True, - collate_fn=lambda instances: Batch(instances).as_tensor_dict()) + batch_size: int = 1, + batches_per_epoch: int = 64) -> SimpleDataLoader: + # tokenizer = tokenizer or LamboTokenizer() 
+ # reader = UniversalDependenciesDatasetReader( + # features=["token", "char"], + # lemma_indexers={ + # "char": default_character_indexer("lemma_characters") + # }, + # targets=["deprel", "head", "upostag", "lemma", "feats", "xpostag"], + # token_indexers={ + # "char": default_character_indexer(), + # # "feats": TokenFeatsIndexer(), + # # "lemma": default_character_indexer(), + # "token": PretrainedTransformerFixedMismatchedIndexer(pretrained_transformer_name), + # # "upostag": SingleIdTokenIndexer( + # # feature_name="pos_", + # # namespace="upostag" + # # ), + # # "xpostag": SingleIdTokenIndexer( + # # feature_name="tag_", + # # namespace="xpostag" + # # ) + # }, + # use_sem=False, + # tokenizer=tokenizer + # ) + + return SimpleDataLoader.from_dataset_reader( + dataset_reader, + data_path=file_path, + batch_size=batch_size, + batches_per_epoch=batches_per_epoch, + shuffle=True, + quiet=False, + collate_fn=lambda instances: Batch(instances).as_tensor_dict()) def default_vocabulary(data_loader: DataLoader) -> Vocabulary: diff --git a/combo/predict.py b/combo/predict.py index c4f507b..84e9aed 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -42,7 +42,7 @@ class COMBO(PredictorModule): self.without_sentence_embedding = False self.line_to_conllu = line_to_conllu - def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]): + def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]], **kwargs): """Depending on the input uses (or ignores) tokenizer. When model isn't only text-based only List[data.Sentence] is possible input. @@ -55,7 +55,7 @@ class COMBO(PredictorModule): :return: Sentence or List[Sentence] depending on the input """ try: - return self.predict(sentence) + return self.predict(sentence, **kwargs) except Exception as e: logger.error(e) logger.error('Exiting.') -- GitLab