Skip to content
Snippets Groups Projects
Commit 6884e6eb authored by Martyna Wiącek's avatar Martyna Wiącek
Browse files

added option for passing kwargs when predicting + fixed default_model to be...

added option for passing kwargs when predicting + fixed default_model to be consistent with config.template.json
parent e007a791
Branches
Tags
1 merge request!47Fixed multiword prediction + bug that made the code write empty predictions
...@@ -55,33 +55,61 @@ def default_ud_dataset_reader(pretrained_transformer_name: str, ...@@ -55,33 +55,61 @@ def default_ud_dataset_reader(pretrained_transformer_name: str,
targets=["deprel", "head", "upostag", "lemma", "feats", "xpostag"], targets=["deprel", "head", "upostag", "lemma", "feats", "xpostag"],
token_indexers={ token_indexers={
"char": default_character_indexer(), "char": default_character_indexer(),
"feats": TokenFeatsIndexer(), # "feats": TokenFeatsIndexer(),
"lemma": default_character_indexer(), # "lemma": default_character_indexer(),
"token": PretrainedTransformerFixedMismatchedIndexer(pretrained_transformer_name), "token": PretrainedTransformerFixedMismatchedIndexer(pretrained_transformer_name),
"upostag": SingleIdTokenIndexer( # "upostag": SingleIdTokenIndexer(
feature_name="pos_", # feature_name="pos_",
namespace="upostag" # namespace="upostag"
), # ),
"xpostag": SingleIdTokenIndexer( # "xpostag": SingleIdTokenIndexer(
feature_name="tag_", # feature_name="tag_",
namespace="xpostag" # namespace="xpostag"
) # )
}, },
use_sem=False, use_sem=False,
tokenizer=tokenizer tokenizer=tokenizer
) )
def default_data_loader(
        dataset_reader: DatasetReader,
        file_path: str,
        batch_size: int = 1,
        batches_per_epoch: int = 64) -> SimpleDataLoader:
    """Build a shuffling ``SimpleDataLoader`` over ``file_path``.

    The defaults (``batch_size=1``, ``batches_per_epoch=64``) were chosen to
    stay consistent with ``config.template.json`` — verify against that file
    if the template changes.

    :param dataset_reader: reader used to parse the file into instances
    :param file_path: path to the dataset file to load
    :param batch_size: number of instances per batch
    :param batches_per_epoch: number of batches drawn per epoch
    :return: a ``SimpleDataLoader`` that shuffles instances and collates each
        batch into a tensor dictionary via ``Batch(...).as_tensor_dict()``
    """
    # NOTE(review): a ~25-line block of commented-out, stale reader-construction
    # code was removed here — it duplicated default_ud_dataset_reader and was
    # dead weight; recover it from version control if ever needed.
    return SimpleDataLoader.from_dataset_reader(
        dataset_reader,
        data_path=file_path,
        batch_size=batch_size,
        batches_per_epoch=batches_per_epoch,
        shuffle=True,
        quiet=False,  # keep loader progress/log output visible
        collate_fn=lambda instances: Batch(instances).as_tensor_dict())
def default_vocabulary(data_loader: DataLoader) -> Vocabulary: def default_vocabulary(data_loader: DataLoader) -> Vocabulary:
......
...@@ -42,7 +42,7 @@ class COMBO(PredictorModule): ...@@ -42,7 +42,7 @@ class COMBO(PredictorModule):
self.without_sentence_embedding = False self.without_sentence_embedding = False
self.line_to_conllu = line_to_conllu self.line_to_conllu = line_to_conllu
def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]): def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]], **kwargs):
"""Depending on the input uses (or ignores) tokenizer. """Depending on the input uses (or ignores) tokenizer.
When model isn't only text-based only List[data.Sentence] is possible input. When model isn't only text-based only List[data.Sentence] is possible input.
...@@ -55,7 +55,7 @@ class COMBO(PredictorModule): ...@@ -55,7 +55,7 @@ class COMBO(PredictorModule):
:return: Sentence or List[Sentence] depending on the input :return: Sentence or List[Sentence] depending on the input
""" """
try: try:
return self.predict(sentence) return self.predict(sentence, **kwargs)
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
logger.error('Exiting.') logger.error('Exiting.')
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment