Skip to content
Snippets Groups Projects
Commit 2e2d1673 authored by Martyna Wiącek's avatar Martyna Wiącek
Browse files

Merge remote-tracking branch 'origin/combo-lambo' into develop

# Conflicts:
#	combo/main.py
parents 9ebeb1d7 e6eb721f
No related branches found
No related tags found
No related merge requests found
......@@ -86,8 +86,10 @@ flags.DEFINE_integer(name="batch_size", default=1,
flags.DEFINE_boolean(name="silent", default=True,
help="Silent prediction to file (without printing to console).")
flags.DEFINE_enum(name="predictor_name", default="combo-spacy",
enum_values=["combo", "combo-spacy"],
help="Use predictor with whitespace or spacy tokenizer.")
enum_values=["combo", "combo-spacy", "combo-lambo"],
help="Use predictor with whitespace, spacy or LAMBO tokenizer.")
#flags.DEFINE_string(name="lambo_model_name", default="en",
# help="LAMBO model name (if LAMBO used for segmentation).")
flags.DEFINE_string(name="save_relation_distribution_path", default=None,
help="Save relation distribution to file.")
......@@ -179,7 +181,7 @@ def _get_predictor() -> predictors.Predictor:
)
return predictors.Predictor.from_archive(
archive, FLAGS.predictor_name
archive, FLAGS.predictor_name#, extra_args={"lambo_model_name" : FLAGS.lambo_model_name}
)
......
......@@ -12,13 +12,14 @@ from overrides import overrides
from combo import data
from combo.data import sentence2conllu, tokens2conllu, conllu2sentence
from combo.utils import download, graph
from combo.utils import download, graph, lambo
logger = logging.getLogger(__name__)
@predictor.Predictor.register("combo")
@predictor.Predictor.register("combo-spacy", constructor="with_spacy_tokenizer")
@predictor.Predictor.register("combo-lambo", constructor="with_lambo_tokenizer")
class COMBO(predictor.Predictor):
def __init__(self,
......@@ -58,6 +59,10 @@ class COMBO(predictor.Predictor):
def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
if isinstance(sentence, str):
if isinstance(self._tokenizer,lambo.LamboTokenizer):
segmented = self._tokenizer.segment(sentence)
return self.predict(segmented)
else:
return self.predict_json({"sentence": sentence})
elif isinstance(sentence, list):
if len(sentence) == 0:
......@@ -240,6 +245,11 @@ class COMBO(predictor.Predictor):
dataset_reader: allen_data.DatasetReader):
return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
@classmethod
def with_lambo_tokenizer(cls, model: models.Model,
                         dataset_reader: allen_data.DatasetReader,
                         lambo_model_name: str = 'en'):
    """Alternate constructor: build a COMBO predictor that segments input with LAMBO.

    Args:
        model: Trained COMBO model used for prediction.
        dataset_reader: Reader that converts raw text into model instances.
        lambo_model_name: Name of the pretrained LAMBO segmentation model
            to load (defaults to the English model, 'en').

    Returns:
        A predictor instance whose tokenizer is a ``lambo.LamboTokenizer``.
    """
    return cls(model, dataset_reader, lambo.LamboTokenizer(lambo_model_name))
@classmethod
def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
batch_size: int = 1024,
......
from typing import List
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data.tokenizers.token_class import Token
from lambo.segmenter.lambo import Lambo
class LamboTokenizer(Tokenizer):
    """Tokenizer backed by a pretrained LAMBO segmentation model.

    Wraps ``lambo.segmenter.lambo.Lambo`` so it can be used wherever an
    AllenNLP ``Tokenizer`` is expected. LAMBO segments a document into
    turns -> sentences -> tokens; the two methods below flatten that
    hierarchy to different depths.
    """

    def __init__(self, model: str) -> None:
        """Load the pretrained LAMBO model identified by *model* (e.g. 'en')."""
        # Lambo.get downloads/loads the named pretrained model.
        self.lambo = Lambo.get(model)

    def tokenize(self, text: str) -> List[Token]:
        """Simple tokenisation: return all tokens in order, ignoring sentence splits."""
        document = self.lambo.segment(text)
        return [
            Token(token.text)
            for turn in document.turns
            for sentence in turn.sentences
            for token in sentence.tokens
        ]

    def segment(self, text: str) -> List[List[str]]:
        """Full segmentation: return one list of token texts per sentence."""
        document = self.lambo.segment(text)
        return [
            [token.text for token in sentence.tokens]
            for turn in document.turns
            for sentence in turn.sentences
        ]
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment