Skip to content
Snippets Groups Projects
Commit 2e2d1673 authored by Martyna Wiącek's avatar Martyna Wiącek
Browse files

Merge remote-tracking branch 'origin/combo-lambo' into develop

# Conflicts:
#	combo/main.py
parents 9ebeb1d7 e6eb721f
Branches
No related merge requests found
...@@ -86,8 +86,10 @@ flags.DEFINE_integer(name="batch_size", default=1, ...@@ -86,8 +86,10 @@ flags.DEFINE_integer(name="batch_size", default=1,
flags.DEFINE_boolean(name="silent", default=True, flags.DEFINE_boolean(name="silent", default=True,
help="Silent prediction to file (without printing to console).") help="Silent prediction to file (without printing to console).")
flags.DEFINE_enum(name="predictor_name", default="combo-spacy", flags.DEFINE_enum(name="predictor_name", default="combo-spacy",
enum_values=["combo", "combo-spacy"], enum_values=["combo", "combo-spacy", "combo-lambo"],
help="Use predictor with whitespace or spacy tokenizer.") help="Use predictor with whitespace, spacy or LAMBO tokenizer.")
#flags.DEFINE_string(name="lambo_model_name", default="en",
# help="LAMBO model name (if LAMBO used for segmentation).")
flags.DEFINE_string(name="save_relation_distribution_path", default=None, flags.DEFINE_string(name="save_relation_distribution_path", default=None,
help="Save relation distribution to file.") help="Save relation distribution to file.")
...@@ -179,7 +181,7 @@ def _get_predictor() -> predictors.Predictor: ...@@ -179,7 +181,7 @@ def _get_predictor() -> predictors.Predictor:
) )
return predictors.Predictor.from_archive( return predictors.Predictor.from_archive(
archive, FLAGS.predictor_name archive, FLAGS.predictor_name#, extra_args={"lambo_model_name" : FLAGS.lambo_model_name}
) )
......
...@@ -12,13 +12,14 @@ from overrides import overrides ...@@ -12,13 +12,14 @@ from overrides import overrides
from combo import data from combo import data
from combo.data import sentence2conllu, tokens2conllu, conllu2sentence from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
from combo.utils import download, graph from combo.utils import download, graph, lambo
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@predictor.Predictor.register("combo") @predictor.Predictor.register("combo")
@predictor.Predictor.register("combo-spacy", constructor="with_spacy_tokenizer") @predictor.Predictor.register("combo-spacy", constructor="with_spacy_tokenizer")
@predictor.Predictor.register("combo-lambo", constructor="with_lambo_tokenizer")
class COMBO(predictor.Predictor): class COMBO(predictor.Predictor):
def __init__(self, def __init__(self,
...@@ -58,7 +59,11 @@ class COMBO(predictor.Predictor): ...@@ -58,7 +59,11 @@ class COMBO(predictor.Predictor):
def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]): def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
if isinstance(sentence, str): if isinstance(sentence, str):
return self.predict_json({"sentence": sentence}) if isinstance(self._tokenizer,lambo.LamboTokenizer):
segmented = self._tokenizer.segment(sentence)
return self.predict(segmented)
else:
return self.predict_json({"sentence": sentence})
elif isinstance(sentence, list): elif isinstance(sentence, list):
if len(sentence) == 0: if len(sentence) == 0:
return [] return []
...@@ -239,6 +244,11 @@ class COMBO(predictor.Predictor): ...@@ -239,6 +244,11 @@ class COMBO(predictor.Predictor):
def with_spacy_tokenizer(cls, model: models.Model, def with_spacy_tokenizer(cls, model: models.Model,
dataset_reader: allen_data.DatasetReader): dataset_reader: allen_data.DatasetReader):
return cls(model, dataset_reader, tokenizers.SpacyTokenizer()) return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
@classmethod
def with_lambo_tokenizer(cls, model: models.Model,
dataset_reader: allen_data.DatasetReader, lambo_model_name : str = 'en'):
return cls(model, dataset_reader, lambo.LamboTokenizer(lambo_model_name))
@classmethod @classmethod
def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
......
from typing import List
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers.token_class import Token
from lambo.segmenter.lambo import Lambo
class LamboTokenizer(Tokenizer):
    """AllenNLP-compatible tokenizer backed by the LAMBO segmenter.

    LAMBO organises a document into turns -> sentences -> tokens; the two
    public methods below flatten that hierarchy to different depths.
    """

    def __init__(self, model: str) -> None:
        # Lambo.get loads (downloading if needed) the pretrained LAMBO
        # segmentation model by name, e.g. "en".
        self.lambo = Lambo.get(model)

    def tokenize(self, text: str) -> List[Token]:
        """Simple tokenisation: return all tokens in *text* as one flat list,
        ignoring sentence boundaries."""
        document = self.lambo.segment(text)
        return [
            Token(token.text)
            for turn in document.turns
            for sentence in turn.sentences
            for token in sentence.tokens
        ]

    def segment(self, text: str) -> List[List[str]]:
        """Full segmentation: split *text* into sentences, each returned as a
        list of token strings (sentence order follows turn order)."""
        document = self.lambo.segment(text)
        return [
            [token.text for token in sentence.tokens]
            for turn in document.turns
            for sentence in turn.sentences
        ]
\ No newline at end of file
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment