diff --git a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py index 0fa4827987e270b0f986a7df0f158c69c124594d..c094b3c85daadce0cba9a764f20c4ddfb7d710f9 100644 --- a/combo/data/dataset_readers/universal_dependencies_dataset_reader.py +++ b/combo/data/dataset_readers/universal_dependencies_dataset_reader.py @@ -27,6 +27,15 @@ from combo.data.vocabulary import get_slices_if_not_provided from combo.utils import checks, pad_sequence_to_length +def parser_field_to_dataset_reader_field(field: str): + if field == 'upos': + return 'upostag' + elif field == 'xpos': + return 'xpostag' + else: + return field + + @Registry.register(DatasetReader, 'conllu') class UniversalDependenciesDatasetReader(DatasetReader, ABC): def __init__( @@ -70,8 +79,9 @@ class UniversalDependenciesDatasetReader(DatasetReader, ABC): fields = list(fields) fields.append("semrel") field_parsers["semrel"] = lambda line, i: line[i] + self.field_parsers = field_parsers - self.fields = tuple(fields) + self.fields = tuple([parser_field_to_dataset_reader_field(f) for f in fields]) self.__lemma_indexers = lemma_indexers self.__targets = targets diff --git a/combo/example.ipynb b/combo/example.ipynb index 1269fefdb7031762ceba5b7938592f06260d0501..f1174e33adf44f4a3b9b5f808eb39485f8819047 100644 --- a/combo/example.ipynb +++ b/combo/example.ipynb @@ -7,8 +7,8 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2023-09-29T11:59:43.374412Z", - "start_time": "2023-09-29T11:59:39.164713Z" + "end_time": "2023-09-30T06:30:32.968386Z", + "start_time": "2023-09-30T06:30:28.670080Z" } }, "outputs": [], @@ -48,8 +48,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:43.751419Z", - "start_time": "2023-09-29T11:59:43.356220Z" + "end_time": "2023-09-30T06:30:33.433633Z", + "start_time": "2023-09-30T06:30:32.965078Z" } }, "id": "7302b7d49ac2fc38" @@ -68,8 +68,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:43.754710Z", - "start_time": "2023-09-29T11:59:43.743252Z" + "end_time": "2023-09-30T06:30:33.502604Z", + "start_time": "2023-09-30T06:30:33.457202Z" } }, "id": "e0ac599d12cc33df" @@ -93,8 +93,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:43.787672Z", - "start_time": "2023-09-29T11:59:43.755206Z" + "end_time": "2023-09-30T06:30:33.516731Z", + "start_time": "2023-09-30T06:30:33.469116Z" } }, "id": "a7f687419bddd9f8" @@ -119,7 +119,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "55a97b4050364e31a0d13ab013262d15" + "model_id": "8cce881f28a94e879336221b2fab2ac3" } }, "metadata": {}, @@ -181,8 +181,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:51.813061Z", - "start_time": "2023-09-29T11:59:43.766557Z" + "end_time": "2023-09-30T06:30:47.592831Z", + "start_time": "2023-09-30T06:30:33.490375Z" } }, "id": "d74957f422f0b05b" @@ -204,8 +204,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:52.631225Z", - "start_time": "2023-09-29T11:59:51.818391Z" + "end_time": "2023-09-30T06:30:50.225703Z", + "start_time": "2023-09-30T06:30:47.479441Z" } }, "id": "fa724d362fd6bd23" @@ -223,7 +223,7 @@ }, { "data": { - "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fd958d502e0>" + "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7f8db1e44430>" }, "execution_count": 7, "metadata": {}, @@ -251,8 +251,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:52.665717Z", - "start_time": "2023-09-29T11:59:52.630665Z" + "end_time": "2023-09-30T06:30:50.244596Z", + "start_time": "2023-09-30T06:30:49.859096Z" } }, "id": "f8a10f9892005fca" @@ -265,7 +265,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']\n", + "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -396,8 +396,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:56.308564Z", - "start_time": "2023-09-29T11:59:52.662376Z" + "end_time": "2023-09-30T06:30:54.651620Z", + "start_time": "2023-09-30T06:30:49.865341Z" } }, "id": "437d12054baaffa1" @@ -412,15 +412,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:56.323194Z", - "start_time": "2023-09-29T11:59:56.275602Z" + "end_time": "2023-09-30T06:30:54.680188Z", + "start_time": "2023-09-30T06:30:54.660096Z" } }, "id": "16ae311c44073668" }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 16, "outputs": [], "source": [ "for i in data_loader.iter_instances():\n", @@ -430,15 +430,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T11:59:56.829311Z", - "start_time": "2023-09-29T11:59:56.317467Z" + "end_time": "2023-09-30T06:31:13.385199Z", + "start_time": "2023-09-30T06:31:12.864300Z" } }, "id": "e131e0ec75dc6927" }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "outputs": [], "source": [ "a = nlp(\"__START__ Hello! My name is Dog. __END__\")" @@ -446,41 +446,78 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T12:00:44.999132Z", - "start_time": "2023-09-29T12:00:44.612375Z" + "end_time": "2023-09-30T06:31:05.928119Z", + "start_time": "2023-09-30T06:31:05.519479Z" } }, "id": "3e23413c86063183" }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 12, "outputs": [ { "data": { - "text/plain": "__START__" + "text/plain": "'JJR'" }, - "execution_count": 24, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "a.tokens[0]" + "a.tokens[-2].xpostag" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-29T12:01:53.250349Z", - "start_time": "2023-09-29T12:01:53.215738Z" + "end_time": "2023-09-30T06:30:56.069386Z", + "start_time": "2023-09-30T06:30:56.021433Z" } }, "id": "d555d7f0223a624b" }, { "cell_type": "code", - "execution_count": null, - "outputs": [], + "execution_count": 17, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ID TOKEN LEMMA UPOS HEAD DEPREL \n", + "Al 6?KK>g3K>KKKKKKKKKKKKKKKKKKKKK PROPN 16 nmod:tmod \n", + "- ?33>K>>KKKKKKKKKKKKKKKKKKKKKKK PROPN 16 nmod:tmod \n", + "Zaman E3633?3>KKKKKKKKKKKKKKKKKKKKKK PROPN 28 nmod:tmod \n", + ": 3333K>>>KKKKKKKKKKKKKKKKKKKKKK PROPN 28 nmod:tmod \n", + "American {{{6K63g3uif3KKKKKKKKKKKKKKKKK PROPN 28 nmod:tmod \n", + "forces lK3K333gu>3>>KKKKKKKKKKKKKKKiK PROPN 29 nmod:tmod \n", + "killed uEE6K>>uuKK>KKKKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "Shaikh >EE3336u3K>KKKKKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "Abdullah 6K>l63K3>ui>3>KKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "al E?KK>33>>KKKKKKKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "- ?33>K>>KKKKKKKKKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "Ani u6?6u3>3KKKKKKKKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + ", ?333K3>>KKKKKKKKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "the 3K?u36>3K>KKKKKKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "preacher 3>K3363u3>u>3>>KKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "at 3?K6>>3>KKKKKKKKKKKKKKKKKKKKiK ADJ 29 nmod:tmod \n", + "the 3K?u36>3K>KKKKKKKKKKKKKKKKKKiK ADJ 29 expl \n", + "mosque |3u>K3336>3K>KKKKKKKKKKKKKKKiK ADJ 29 expl \n", + "in 3?K3>33KKKKKKKKKKKKKKKKKKKKKiK ADJ 29 expl \n", + "the 3K?u36>3K>KKKKKKKKKKKKKKKKKKiK ADJ 29 expl \n", + "town Ku3>3|>f3KKKKKKKKKKKKKKKKKKKiK ADJ 29 expl \n", + "of 3?Ku>>>KKKKKKKKKKKKKKKKKKKKKiK ADJ 29 expl \n", + "Qaim u3333u63KKKKKKKKKKKKKKKKKKKKiK SYM 29 expl \n", + ", ?333K3>>KKKKKKKKKKKKKKKKKKKKiK SYM 29 expl \n", + "near u6KuKu33K>>KKKKKKKKKKKKKKKKKiK SYM 29 expl \n", + "the 3K?u36>3K>KKKKKKKKKKKKKKKKKKiK SYM 29 expl \n", + "Syrian ui3633Ku>f3KKKKKKKKKKKKKKKKKiK ADJ 29 expl \n", + "border i33>33u33>u>>KKKKKKKKKKKKKKKiK ADJ 29 expl \n", + ". ??3>>>KKKKKKKKKKKKKKKKKKKKKKiK ADJ 0 root \n" + ] + } + ], "source": [ "print(\"{:5} {:15} {:15} {:10} {:10} {:10}\".format('ID', 'TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))\n", "for token in a.tokens:\n", @@ -489,20 +526,22 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "start_time": "2023-09-29T11:59:57.890328Z" + "end_time": "2023-09-30T06:31:15.058040Z", + "start_time": "2023-09-30T06:31:14.723267Z" } }, "id": "a68cd3861e1ceb67" }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "outputs": [], "source": [], "metadata": { "collapsed": false, "ExecuteTime": { - "start_time": "2023-09-29T11:59:57.890649Z" + "end_time": "2023-09-30T06:30:56.070718Z", + "start_time": "2023-09-30T06:30:56.038704Z" } }, "id": "7e1a6f47f3aa54b8" diff --git a/combo/predict.py b/combo/predict.py index 70b8bd36e09401cee989f401e08bf28446ee11ad..98afe53055db045775e2433400c0371a402065dc 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -13,14 +13,14 @@ from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sen from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.instance import JsonDict -from combo.predictors import PredictorModel +from combo.predictors import PredictorModule from combo.utils import download, graph logger = logging.getLogger(__name__) -@Registry.register(PredictorModel, 'combo') -class COMBO(PredictorModel): +@Registry.register(PredictorModule, 'combo') +class COMBO(PredictorModule): def __init__(self, model: models.Model, @@ -183,7 +183,6 @@ class COMBO(PredictorModel): field_value = "|".join(np.array(features)[arg_indices].tolist()) setattr(token, field_name, field_value) - #token[field_name] = field_value embeddings[token["idx"]][field_name] = predictions[f"{field_name}_token_embedding"][idx] elif field_name == "lemma": prediction = field_predictions[idx] @@ -200,7 +199,6 @@ class COMBO(PredictorModel): word_chars.append(pred_char) setattr(token, field_name, "".join(word_chars)) - #token[field_name] = "".join(word_chars) else: raise NotImplementedError(f"Unknown field name {field_name}!") diff --git a/combo/predictors/__init__.py b/combo/predictors/__init__.py index a49f08e501b26571e664f564e8248437354b0b78..967701a342d41c61f9162881de6ebf9c224a09f6 100644 --- a/combo/predictors/__init__.py +++ b/combo/predictors/__init__.py @@ -1,2 +1,2 @@ from .predictor import Predictor -from .predictor_model import PredictorModel +from .predictor_model import PredictorModule diff --git a/combo/predictors/predictor_model.py b/combo/predictors/predictor_model.py index 2247ecac3e9769c79755eab080d4d8d1b56d38c5..09604208fb7bb2b480216554f0733f8e8fd68324 100644 --- a/combo/predictors/predictor_model.py +++ b/combo/predictors/predictor_model.py @@ -15,6 +15,8 @@ from torch.utils.hooks import RemovableHandle from torch import Tensor from torch import backends +import pytorch_lightning as pl + from combo.common.util import sanitize from combo.config import FromParameters from combo.data.batch import Batch @@ -26,13 +28,14 @@ from combo.nn import util logger = logging.getLogger(__name__) -class PredictorModel(FromParameters): +class PredictorModule(pl.LightningModule, FromParameters): """ a `Predictor` is a thin wrapper around an AllenNLP model that handles JSON -> JSON predictions that can be used for serving models through the web API or making predictions in bulk. """ def __init__(self, model: Model, dataset_reader: DatasetReader, frozen: bool = True) -> None: + super().__init__() if frozen: model.eval() self._model = model