diff --git a/combo/main.py b/combo/main.py index f319ae1512631aade30901f32f8d17377eda39a1..dd0abff53d15cae7de30dc98f9df6f11fac98ad9 100644 --- a/combo/main.py +++ b/combo/main.py @@ -86,8 +86,10 @@ flags.DEFINE_integer(name="batch_size", default=1, flags.DEFINE_boolean(name="silent", default=True, help="Silent prediction to file (without printing to console).") flags.DEFINE_enum(name="predictor_name", default="combo-spacy", - enum_values=["combo", "combo-spacy"], - help="Use predictor with whitespace or spacy tokenizer.") + enum_values=["combo", "combo-spacy", "combo-lambo"], + help="Use predictor with whitespace, spacy or LAMBO tokenizer.") +#flags.DEFINE_string(name="lambo_model_name", default="en", +# help="LAMBO model name (if LAMBO used for segmentation).") flags.DEFINE_string(name="save_relation_distribution_path", default=None, help="Save relation distribution to file.") @@ -179,7 +181,7 @@ def _get_predictor() -> predictors.Predictor: ) return predictors.Predictor.from_archive( - archive, FLAGS.predictor_name + archive, FLAGS.predictor_name#, extra_args={"lambo_model_name" : FLAGS.lambo_model_name} ) diff --git a/combo/predict.py b/combo/predict.py index fe7eb1a9b3ef66756c63662dd49903a3b13e0b46..90915a89bc9ab5101bef1bac4f249f061098dba1 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -12,13 +12,14 @@ from overrides import overrides from combo import data from combo.data import sentence2conllu, tokens2conllu, conllu2sentence -from combo.utils import download, graph +from combo.utils import download, graph, lambo logger = logging.getLogger(__name__) @predictor.Predictor.register("combo") @predictor.Predictor.register("combo-spacy", constructor="with_spacy_tokenizer") +@predictor.Predictor.register("combo-lambo", constructor="with_lambo_tokenizer") class COMBO(predictor.Predictor): def __init__(self, @@ -58,7 +59,11 @@ class COMBO(predictor.Predictor): def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]): if 
isinstance(sentence, str): - return self.predict_json({"sentence": sentence}) + if isinstance(self._tokenizer, lambo.LamboTokenizer): + segmented = self._tokenizer.segment(sentence) + return self.predict(segmented) + else: + return self.predict_json({"sentence": sentence}) elif isinstance(sentence, list): if len(sentence) == 0: return [] @@ -239,6 +244,11 @@ class COMBO(predictor.Predictor): def with_spacy_tokenizer(cls, model: models.Model, dataset_reader: allen_data.DatasetReader): return cls(model, dataset_reader, tokenizers.SpacyTokenizer()) + + @classmethod + def with_lambo_tokenizer(cls, model: models.Model, + dataset_reader: allen_data.DatasetReader, lambo_model_name: str = 'en'): + return cls(model, dataset_reader, lambo.LamboTokenizer(lambo_model_name)) @classmethod def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(), diff --git a/combo/utils/lambo.py b/combo/utils/lambo.py new file mode 100644 index 0000000000000000000000000000000000000000..92ef8e1399f95ea8caf0bbf8f0763325c6c7a355 --- /dev/null +++ b/combo/utils/lambo.py @@ -0,0 +1,32 @@ +from typing import List + +from allennlp.data.tokenizers.tokenizer import Tokenizer +from allennlp.data.tokenizers.token_class import Token +from lambo.segmenter.lambo import Lambo + +class LamboTokenizer(Tokenizer): + + def __init__(self, model: str) -> None: + self.lambo = Lambo.get(model) + + # Simple tokenisation: ignoring sentence split + def tokenize(self, text: str) -> List[Token]: + result = [] + document = self.lambo.segment(text) + for turn in document.turns: + for sentence in turn.sentences: + for token in sentence.tokens: + result.append(Token(token.text)) + return result + + # Full segmentation: divide into sentences and tokens + def segment(self, text: str) -> List[List[str]]: + result = [] + document = self.lambo.segment(text) + for turn in document.turns: + for sentence in turn.sentences: + resultS = [] + for token in sentence.tokens: + resultS.append(token.text) + result.append(resultS) + 
return result