diff --git a/combo/modules/archival.py b/combo/modules/archival.py
index 496753f711f60fce9555d724fe72c7e49c7a8435..076f11969fe414a13b68db24e015bb9b688c4d2f 100644
--- a/combo/modules/archival.py
+++ b/combo/modules/archival.py
@@ -13,6 +13,7 @@ from io import BytesIO
 from tempfile import TemporaryDirectory
 
 from combo.config import resolve
+from combo.data.tokenizers import Tokenizer
 from combo.data.dataset_loaders import DataLoader
 from combo.data.dataset_readers import DatasetReader
 from combo.modules.model import Model
@@ -117,7 +118,8 @@ def extracted_archive(resolved_archive_file, cleanup=True):
 
 def load_archive(url_or_filename: Union[PathLike, str],
                  cache_dir: Union[PathLike, str] = None,
-                 cuda_device: int = -1) -> Archive:
+                 cuda_device: int = -1,
+                 overriden_tokenizer: Optional[Tokenizer] = None) -> Archive:
 
     rarchive_file = cached_path.cached_path(
         url_or_filename,
@@ -136,28 +138,11 @@ def load_archive(url_or_filename: Union[PathLike, str],
     if config.get("model_name"):
         pass_down_parameters = {"model_name": config.get("model_name")}
-
-    if 'data_loader' in config:
-        try:
-            data_loader = resolve(config['data_loader'],
-                                  pass_down_parameters=pass_down_parameters)
-        except Exception as e:
-            logger.warning(f'Error while loading Training Data Loader: {str(e)}. Setting Data Loader to None',
-                           prefix=PREFIX)
-    if 'validation_data_loader' in config:
-        try:
-            validation_data_loader = resolve(config['validation_data_loader'],
-                                             pass_down_parameters=pass_down_parameters)
-        except Exception as e:
-            logger.warning(f'Error while loading Validation Data Loader: {str(e)}. Setting Data Loader to None',
-                           prefix=PREFIX)
 
     if 'dataset_reader' in config:
-        try:
-            dataset_reader = resolve(config['dataset_reader'],
-                                     pass_down_parameters=pass_down_parameters)
-        except Exception as e:
-            logger.warning(f'Error while loading Dataset Reader: {str(e)}. Setting Dataset Reader to None',
-                           prefix=PREFIX)
+        if overriden_tokenizer:
+            config['dataset_reader']['parameters']['tokenizer'] = overriden_tokenizer.serialize()
+        dataset_reader = resolve(config['dataset_reader'],
+                                 pass_down_parameters=pass_down_parameters)
 
     return Archive(model=model,
                    config=config,
diff --git a/combo/predict.py b/combo/predict.py
index 3feacb6071d9231df02014aa3fe6f5b133b5cb07..6dab46c36e372de5cd49c7df08e90768a6f01281 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,7 +1,7 @@
 import logging
 import os
 import sys
-from typing import List, Union, Dict, Any
+from typing import List, Union, Dict, Optional, Any
 
 import numpy as np
 import torch
@@ -11,7 +11,7 @@ from combo import data
 from combo.common import util
 from combo.config import Registry
 from combo.config.from_parameters import register_arguments
-from combo.data import Instance, conllu2sentence, sentence2conllu
+from combo.data import Instance, conllu2sentence, sentence2conllu, Tokenizer
 from combo.data.dataset_loaders.dataset_loader import TensorDict
 from combo.data.dataset_readers.dataset_reader import DatasetReader
 from combo.data.instance import JsonDict
@@ -27,7 +27,6 @@ from combo.ner_modules.data.NerTokenizer import NerTokenizer
 from pathlib import Path
 from combo.ner_modules.utils.utils import move_tensors_to_device
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -56,7 +55,6 @@ class COMBO(PredictorModule):
         if ner_model is not None:
             self._load_ner_model(ner_model)
 
-
     def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]], **kwargs):
         """Depending on the input uses (or ignores) tokenizer.
         When model isn't only text-based only List[data.Sentence] is possible input.
@@ -95,10 +93,11 @@ class COMBO(PredictorModule):
             sentence = self.dataset_reader.tokenizer.tokenize(sentence, **kwargs)
         elif isinstance(sentence, list):
             if isinstance(sentence[0], str):
-                sentence = [[Token(idx=idx+1, text=t) for idx, t in enumerate(sentence)]]
+                sentence = [[Token(idx=idx + 1, text=t) for idx, t in enumerate(sentence)]]
             elif isinstance(sentence[0], list):
                 if isinstance(sentence[0][0], str):
-                    sentence = [[Token(idx=idx+1, text=t) for idx, t in enumerate(subsentence)] for subsentence in sentence]
+                    sentence = [[Token(idx=idx + 1, text=t) for idx, t in enumerate(subsentence)] for subsentence in
+                                sentence]
                 elif not isinstance(sentence[0][0], Token):
                     raise ValueError("Passed sentence must be a list (or list of lists) of strings or Token classes")
             elif not isinstance(sentence[0], Token) and not isinstance(sentence[0], data.Sentence):
@@ -176,7 +175,7 @@ class COMBO(PredictorModule):
         # TODO: tokenize EVERYTHING, even if a list is passed?
         if isinstance(sentence, str):
             tokens = [sentence]
-            #tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
+            # tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
         elif isinstance(sentence, list):
             tokens = sentence
         else:
@@ -295,13 +294,14 @@ class COMBO(PredictorModule):
             tree.tokens.extend(empty_tokens)
 
         return tree, predictions["sentence_embedding"], embeddings, \
-            deprel_tree_distribution, deprel_label_distribution
+               deprel_tree_distribution, deprel_label_distribution
 
     @classmethod
     def from_pretrained(cls,
                         path: str,
                         batch_size: int = 1024,
                         cuda_device: int = -1,
+                        tokenizer: Optional[Tokenizer] = None,
                         ner_model: str = None):
 
         if os.path.exists(path):
@@ -314,9 +314,10 @@ class COMBO(PredictorModule):
             logger.error(e)
             raise e
 
-        archive = load_archive(model_path, cuda_device=cuda_device)
+        archive = load_archive(model_path, cuda_device=cuda_device, overriden_tokenizer=tokenizer)
         model = archive.model
-        dataset_reader = archive.dataset_reader or default_ud_dataset_reader(archive.config.get("model_name"))
+        dataset_reader = archive.dataset_reader or default_ud_dataset_reader(archive.config.get("model_name"),
+                                                                             tokenizer=tokenizer)
         return cls(model, dataset_reader, batch_size, ner_model=ner_model)
 
     def _load_ner_model(self,
@@ -331,7 +332,6 @@ class COMBO(PredictorModule):
         self.ner_tokenizer = NerTokenizer.load_from_disc(folder_path=Path(ner_model),
                                                          load_lambo_tokenizer=False)
 
-
     def _predict_ner_tags(self,
                           result: List[data.Sentence]) -> List[data.Sentence]:
         """Enriches predictions with NER tags."""
@@ -343,4 +343,3 @@ class COMBO(PredictorModule):
             for token, pred in zip(sentence.tokens, self.ner_tokenizer.decode(preds)[0]):
                 token.ner_tag = pred
         return result
-
diff --git a/pyproject.toml b/pyproject.toml
index 02722ffd399e2cccc6190b95df32e5d0268441ea..4a30280e103cadbcef9e7b027af8d7b5ff71798a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@ requires = ["setuptools"]
 
 [project]
 name = "combo"
-version = "3.2.3"
+version = "3.3.0"
 authors = [
     {name = "Maja Jablonska", email = "maja.jablonska@ipipan.waw.pl"}
 ]
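
A minimal usage sketch of the tokenizer override introduced above (not part of the patch itself; `LamboTokenizer` and the model name are illustrative placeholders — any `Tokenizer` subclass that supports `serialize()` should work):

    from combo.predict import COMBO
    from combo.data.tokenizers import LamboTokenizer  # assumed concrete Tokenizer subclass

    # from_pretrained passes the tokenizer down to load_archive, which serializes it
    # into the archived dataset_reader config before the reader is resolved, so the
    # predictor tokenizes raw-string input with the override instead of the archived tokenizer.
    nlp = COMBO.from_pretrained("polish-herbert-base",  # placeholder model name/path
                                tokenizer=LamboTokenizer())
    predictions = nlp("To jest przykładowe zdanie.")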