Skip to content
Snippets Groups Projects
Commit 8bffc5a2 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Enable tokenizer overriding in from_parameters

parent c90a5b44
No related merge requests found
Pipeline #17349 passed with stage
in 48 seconds
......@@ -13,6 +13,7 @@ from io import BytesIO
from tempfile import TemporaryDirectory
from combo.config import resolve
from combo.data.tokenizers import Tokenizer
from combo.data.dataset_loaders import DataLoader
from combo.data.dataset_readers import DatasetReader
from combo.modules.model import Model
......@@ -117,7 +118,8 @@ def extracted_archive(resolved_archive_file, cleanup=True):
def load_archive(url_or_filename: Union[PathLike, str],
cache_dir: Union[PathLike, str] = None,
cuda_device: int = -1) -> Archive:
cuda_device: int = -1,
overriden_tokenizer: Optional[Tokenizer] = None) -> Archive:
rarchive_file = cached_path.cached_path(
url_or_filename,
......@@ -136,28 +138,11 @@ def load_archive(url_or_filename: Union[PathLike, str],
if config.get("model_name"):
pass_down_parameters = {"model_name": config.get("model_name")}
if 'data_loader' in config:
try:
data_loader = resolve(config['data_loader'],
pass_down_parameters=pass_down_parameters)
except Exception as e:
logger.warning(f'Error while loading Training Data Loader: {str(e)}. Setting Data Loader to None',
prefix=PREFIX)
if 'validation_data_loader' in config:
try:
validation_data_loader = resolve(config['validation_data_loader'],
pass_down_parameters=pass_down_parameters)
except Exception as e:
logger.warning(f'Error while loading Validation Data Loader: {str(e)}. Setting Data Loader to None',
prefix=PREFIX)
if 'dataset_reader' in config:
try:
dataset_reader = resolve(config['dataset_reader'],
pass_down_parameters=pass_down_parameters)
except Exception as e:
logger.warning(f'Error while loading Dataset Reader: {str(e)}. Setting Dataset Reader to None',
prefix=PREFIX)
if overriden_tokenizer:
config['dataset_reader']['parameters']['tokenizer'] = overriden_tokenizer.serialize()
dataset_reader = resolve(config['dataset_reader'],
pass_down_parameters=pass_down_parameters)
return Archive(model=model,
config=config,
......
import logging
import os
import sys
from typing import List, Union, Dict, Any
from typing import List, Union, Dict, Optional, Any
import numpy as np
import torch
......@@ -11,7 +11,7 @@ from combo import data
from combo.common import util
from combo.config import Registry
from combo.config.from_parameters import register_arguments
from combo.data import Instance, conllu2sentence, sentence2conllu
from combo.data import Instance, conllu2sentence, sentence2conllu, Tokenizer
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.data.dataset_readers.dataset_reader import DatasetReader
from combo.data.instance import JsonDict
......@@ -27,7 +27,6 @@ from combo.ner_modules.data.NerTokenizer import NerTokenizer
from pathlib import Path
from combo.ner_modules.utils.utils import move_tensors_to_device
logger = logging.getLogger(__name__)
......@@ -56,7 +55,6 @@ class COMBO(PredictorModule):
if ner_model is not None:
self._load_ner_model(ner_model)
def __call__(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]], **kwargs):
"""Depending on the input uses (or ignores) tokenizer.
When model isn't only text-based only List[data.Sentence] is possible input.
......@@ -95,10 +93,11 @@ class COMBO(PredictorModule):
sentence = self.dataset_reader.tokenizer.tokenize(sentence, **kwargs)
elif isinstance(sentence, list):
if isinstance(sentence[0], str):
sentence = [[Token(idx=idx+1, text=t) for idx, t in enumerate(sentence)]]
sentence = [[Token(idx=idx + 1, text=t) for idx, t in enumerate(sentence)]]
elif isinstance(sentence[0], list):
if isinstance(sentence[0][0], str):
sentence = [[Token(idx=idx+1, text=t) for idx, t in enumerate(subsentence)] for subsentence in sentence]
sentence = [[Token(idx=idx + 1, text=t) for idx, t in enumerate(subsentence)] for subsentence in
sentence]
elif not isinstance(sentence[0][0], Token):
raise ValueError("Passed sentence must be a list (or list of lists) of strings or Token classes")
elif not isinstance(sentence[0], Token) and not isinstance(sentence[0], data.Sentence):
......@@ -176,7 +175,7 @@ class COMBO(PredictorModule):
# TODO: tokenize EVERYTHING, even if a list is passed?
if isinstance(sentence, str):
tokens = [sentence]
#tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
# tokens = [t.text for t in self.tokenizer.tokenize(json_dict["sentence"])]
elif isinstance(sentence, list):
tokens = sentence
else:
......@@ -295,13 +294,14 @@ class COMBO(PredictorModule):
tree.tokens.extend(empty_tokens)
return tree, predictions["sentence_embedding"], embeddings, \
deprel_tree_distribution, deprel_label_distribution
deprel_tree_distribution, deprel_label_distribution
@classmethod
def from_pretrained(cls,
path: str,
batch_size: int = 1024,
cuda_device: int = -1,
tokenizer: Optional[Tokenizer] = None,
ner_model: str = None):
if os.path.exists(path):
......@@ -314,9 +314,10 @@ class COMBO(PredictorModule):
logger.error(e)
raise e
archive = load_archive(model_path, cuda_device=cuda_device)
archive = load_archive(model_path, cuda_device=cuda_device, overriden_tokenizer=tokenizer)
model = archive.model
dataset_reader = archive.dataset_reader or default_ud_dataset_reader(archive.config.get("model_name"))
dataset_reader = archive.dataset_reader or default_ud_dataset_reader(archive.config.get("model_name"),
tokenizer=tokenizer)
return cls(model, dataset_reader, batch_size, ner_model=ner_model)
def _load_ner_model(self,
......@@ -331,7 +332,6 @@ class COMBO(PredictorModule):
self.ner_tokenizer = NerTokenizer.load_from_disc(folder_path=Path(ner_model),
load_lambo_tokenizer=False)
def _predict_ner_tags(self,
result: List[data.Sentence]) -> List[data.Sentence]:
"""Enriches predictions with NER tags."""
......@@ -343,4 +343,3 @@ class COMBO(PredictorModule):
for token, pred in zip(sentence.tokens, self.ner_tokenizer.decode(preds)[0]):
token.ner_tag = pred
return result
......@@ -3,7 +3,7 @@ requires = ["setuptools"]
[project]
name = "combo"
version = "3.2.3"
version = "3.3.0"
authors = [
{name = "Maja Jablonska", email = "maja.jablonska@ipipan.waw.pl"}
]
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment