diff --git a/combo/default_model.py b/combo/default_model.py index 927966f175f6e9df47504a479a2d12b95f6239ef..96223570aec1fbfffd92b2150f5f9459b54b27ca 100644 --- a/combo/default_model.py +++ b/combo/default_model.py @@ -12,7 +12,7 @@ from combo.data.tokenizers import CharacterTokenizer from combo.data.vocabulary import Vocabulary from combo.combo_model import ComboModel from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM -from combo.modules.dilated_cnn import DilatedCnnEncoder +from combo.models.dilated_cnn import DilatedCnnEncoder from combo.modules.lemma import LemmatizerModel from combo.modules.morpho import MorphologicalFeatures from combo.modules.parser import DependencyRelationModel, HeadPredictionModel diff --git a/combo/main.py b/combo/main.py index d5fb9dc317c1931e3a3f6f4ebeb6934949b3500f..ecc455ceee366bc1fdfbf9521904db82e5d577dd 100755 --- a/combo/main.py +++ b/combo/main.py @@ -18,6 +18,7 @@ from combo.default_model import default_ud_dataset_reader, default_data_loader from combo.modules.archival import load_archive, archive from combo.predict import COMBO from combo.data import api +from combo.data import DatasetReader logging.setLoggerClass(ComboLogger) logger = logging.getLogger(__name__) @@ -93,14 +94,6 @@ flags.DEFINE_enum(name="predictor_name", default="combo-lambo", enum_values=["combo", "combo-spacy", "combo-lambo"], help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.") - -def get_predictor() -> COMBO: - checks.file_exists(FLAGS.model_path) - arch = load_archive(FLAGS.model_path) - dataset_reader = default_ud_dataset_reader() - return COMBO(arch.model, dataset_reader) - - def run(_): if FLAGS.mode == 'train': if not FLAGS.finetuning: @@ -211,13 +204,39 @@ def run(_): keep_semrel=dataset_reader.use_sem).serialize()) elif FLAGS.mode == 'predict': - predictor = get_predictor() - sentence = input("Sentence:") - prediction = predictor(sentence) - print("{:15} {:15} {:10} {:10} {:10}".format('TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL')) - for token in prediction.tokens: - print("{:15} {:15} {:10} {:10} {:10}".format(token.text, token.lemma, token.upostag, token.head, - token.deprel)) + prefix = 'Predicting' + logger.info('Loading the model', prefix=prefix) + model, _, _, _, dataset_reader = load_archive(FLAGS.model_path) + + if not dataset_reader: + logger.info("No dataset reader in the configuration or archive file - using a default UD dataset reader", + prefix=prefix) + dataset_reader = default_ud_dataset_reader() + + predictor = COMBO(model, dataset_reader) + + if FLAGS.input_file == '-': + print("Interactive mode.") + sentence = input("Sentence: ") + prediction = predictor(sentence) + print("{:15} {:15} {:10} {:10} {:10}".format('TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL')) + for token in prediction.tokens: + print("{:15} {:15} {:10} {:10} {:10}".format(token.text, token.lemma, token.upostag, token.head, + token.deprel)) + elif FLAGS.output_file: + checks.file_exists(FLAGS.input_file) + logger.info("Predicting examples from file", prefix=prefix) + test_trees = dataset_reader.read(FLAGS.input_file) + predictor = COMBO(model, dataset_reader) + with open(FLAGS.output_file, "w") as file: + for tree in tqdm(test_trees): + file.writelines(api.sentence2conllu(predictor.predict_instance(tree), + keep_semrel=dataset_reader.use_sem).serialize()) + + else: + msg = 'No output file for input file {input_file} specified.'.format(input_file=FLAGS.input_file) + logger.info(msg, prefix=prefix) + print(msg) def _get_ext_vars(finetuning: bool = False) -> Dict: diff --git a/combo/modules/dilated_cnn.py b/combo/modules/dilated_cnn.py deleted file mode 100644 index e694a1013362a0c06a9a6795ca65a04b864c8f32..0000000000000000000000000000000000000000 --- a/combo/modules/dilated_cnn.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -Adapted from COMBO 1.0 -Author: Mateusz Klimaszewski -""" - -from typing import List - -import torch - -from combo.config import FromParameters, Registry -from combo.config.from_parameters import register_arguments -from combo.nn.activations import Activation - - -@Registry.register('dilated_cnn') -class DilatedCnnEncoder(torch.nn.Module, FromParameters): - @register_arguments - def __init__(self, - input_dim: int, - filters: List[int], - kernel_size: List[int], - stride: List[int], - padding: List[int], - dilation: List[int], - activations: List[Activation]): - super().__init__() - conv1d_layers = [] - input_dims = [input_dim] + filters[:-1] - output_dims = filters - for idx in range(len(activations)): - conv1d_layers.append(torch.nn.Conv1d( - in_channels=input_dims[idx], - out_channels=output_dims[idx], - kernel_size=(kernel_size[idx],), - stride=(stride[idx],), - padding=padding[idx], - dilation=(dilation[idx],))) - self.conv1d_layers = torch.nn.ModuleList(conv1d_layers) - self.activations = activations - assert len(self.activations) == len(self.conv1d_layers) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - for layer, activation in zip(self.conv1d_layers, self.activations): - x = activation(layer(x)) - return x diff --git a/combo/modules/lemma.py b/combo/modules/lemma.py index 9ef9e5e28f5e7a5f67face2448fcf664a20ba40c..960f382336330d8c3ecd20dc1aa0adfe10773017 100644 --- a/combo/modules/lemma.py +++ b/combo/modules/lemma.py @@ -7,7 +7,7 @@ from overrides import overrides from combo import data from combo.config import Registry from combo.config.from_parameters import register_arguments -from combo.modules import dilated_cnn +from combo.models import dilated_cnn from combo.nn import base from combo.nn.activations import Activation from combo.nn.utils import masked_cross_entropy diff --git a/combo/modules/token_embedders/character_token_embedder.py b/combo/modules/token_embedders/character_token_embedder.py index 1967ecf9bf92214b3d3124ffa92e930cca3a2753..a6cf2f2470c05cecc72fe19eb16792c7c90c708a 100644 --- a/combo/modules/token_embedders/character_token_embedder.py +++ b/combo/modules/token_embedders/character_token_embedder.py @@ -8,7 +8,7 @@ from overrides import overrides from combo.config import Registry from combo.config.from_parameters import register_arguments from combo.data import Vocabulary -from combo.modules.dilated_cnn import DilatedCnnEncoder +from combo.models.dilated_cnn import DilatedCnnEncoder from combo.modules.token_embedders import TokenEmbedder from typing import Optional diff --git a/combo/polish_model_training.ipynb b/combo/polish_model_training.ipynb index f9787ee068f881340c23e426d86e487d6c4fa651..1d75abd32d335d498dc0b8631b2898e89e7bb92a 100644 --- a/combo/polish_model_training.ipynb +++ b/combo/polish_model_training.ipynb @@ -14,15 +14,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:28:53.129601Z", - "start_time": "2023-11-11T07:28:52.947282Z" + "end_time": "2023-11-13T07:47:15.954139Z", + "start_time": "2023-11-13T07:47:15.711912Z" } }, "id": "b28c7d8bacb08d02" }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "outputs": [], "source": [ "from combo.predict import COMBO\n", @@ -34,7 +34,7 @@ "from combo.modules.token_embedders import CharacterBasedWordEmbedder, TransformersWordEmbedder\n", "from combo.modules import FeedForwardPredictor\n", "from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation\n", - "from combo.modules.dilated_cnn import DilatedCnnEncoder\n", + "from combo.models.dilated_cnn import DilatedCnnEncoder\n", "from combo.data.tokenizers import LamboTokenizer, CharacterTokenizer\n", "from combo.data.token_indexers import PretrainedTransformerIndexer, TokenConstPaddingCharactersIndexer, TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer\n", "from combo.data.dataset_readers import UniversalDependenciesDatasetReader\n", @@ -51,15 +51,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:29:25.986145Z", - "start_time": "2023-11-11T07:29:25.671527Z" + "end_time": "2023-11-13T07:47:22.233317Z", + "start_time": "2023-11-13T07:47:15.766709Z" } }, "id": "initial_id" }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "outputs": [ { "name": "stdout", @@ -77,7 +77,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "0df1c8c352a14e9993691edf5626968f" + "model_id": "d318d4f50da14b76a14eb20cb877ee67" } }, "metadata": {}, @@ -89,7 +89,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "8ac307b142074e38afe3d55e00b3c203" + "model_id": "dbab946d82ab4d0ead64fc02796c2a9f" } }, "metadata": {}, @@ -101,7 +101,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "f60a4d838fc849c3aff01901f531d91c" + "model_id": "2f3d9306cb2b463eb080c922fe775b02" } }, "metadata": {}, @@ -169,15 +169,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:29:54.208157Z", - "start_time": "2023-11-11T07:29:25.685934Z" + "end_time": "2023-11-13T07:47:42.601537Z", + "start_time": "2023-11-13T07:47:22.243325Z" } }, "id": "d74957f422f0b05b" }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "outputs": [], "source": [ "seq_encoder = ComboEncoder(layer_dropout_probability=0.33,\n", @@ -192,15 +192,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:29:55.768681Z", - "start_time": "2023-11-11T07:29:54.231728Z" + "end_time": "2023-11-13T07:47:44.068445Z", + "start_time": "2023-11-13T07:47:42.595098Z" } }, "id": "fa724d362fd6bd23" }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "outputs": [ { "name": "stdout", @@ -211,9 +211,9 @@ }, { "data": { - "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7faf9b7f6820>" + "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fdd1e0a0c80>" }, - "execution_count": 7, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -239,15 +239,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:29:55.840836Z", - "start_time": "2023-11-11T07:29:55.773085Z" + "end_time": "2023-11-13T07:47:44.196484Z", + "start_time": "2023-11-13T07:47:44.034821Z" } }, "id": "f8a10f9892005fca" }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "outputs": [ { "name": "stderr", @@ -263,21 +263,21 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:29:55.859643Z", - "start_time": "2023-11-11T07:29:55.837113Z" + "end_time": "2023-11-13T07:47:44.197075Z", + "start_time": "2023-11-13T07:47:44.055240Z" } }, "id": "14413692656b68ac" }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.sso.sso_relationship.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.weight']\n", + "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.sso.sso_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -411,15 +411,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:30:01.875464Z", - "start_time": "2023-11-11T07:29:55.851182Z" + "end_time": "2023-11-13T07:47:48.599708Z", + "start_time": "2023-11-13T07:47:44.063606Z" } }, "id": "437d12054baaffa1" }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "outputs": [], "source": [ "data_loader.index_with(vocabulary)\n", @@ -430,15 +430,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:30:46.834983Z", - "start_time": "2023-11-11T07:30:01.904158Z" + "end_time": "2023-11-13T07:48:26.090634Z", + "start_time": "2023-11-13T07:47:48.622684Z" } }, "id": "e131e0ec75dc6927" }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "outputs": [], "source": [ "val_data_loader.index_with(vocabulary)" @@ -446,15 +446,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:30:51.486798Z", - "start_time": "2023-11-11T07:30:46.839866Z" + "end_time": "2023-11-13T07:48:32.052740Z", + "start_time": "2023-11-13T07:48:26.077694Z" } }, "id": "195c71fcf8170ff" }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "outputs": [ { "name": "stderr", @@ -481,15 +481,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:30:51.879326Z", - "start_time": "2023-11-11T07:30:51.500543Z" + "end_time": "2023-11-13T07:48:32.321842Z", + "start_time": "2023-11-13T07:48:32.056903Z" } }, "id": "cefc5173154d1605" }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "outputs": [ { "name": "stderr", @@ -512,7 +512,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "7b4d5f7d5cec41b98aaf4c5fa3fff0d8" + "model_id": "027b704c9899478bb71021e074ad29bf" } }, "metadata": {}, @@ -534,7 +534,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "71a7264ff03d4f17828db6f8a7893e39" + "model_id": "4af02d76668645ae9213db79ae97d36f" } }, "metadata": {}, @@ -546,7 +546,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "b93d9100d1bc4a04ab706a9a3840644c" + "model_id": "55eda5299f554849aba6bd2781608ed2" } }, "metadata": {}, @@ -566,15 +566,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:32:36.809443Z", - "start_time": "2023-11-11T07:30:51.816554Z" + "end_time": "2023-11-13T07:49:35.721377Z", + "start_time": "2023-11-13T07:48:32.278875Z" } }, "id": "e5af131bae4b1a33" }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "outputs": [], "source": [ "predictor = COMBO(model, dataset_reader)" @@ -582,15 +582,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:32:37.095367Z", - "start_time": "2023-11-11T07:32:32.550627Z" + "end_time": "2023-11-13T07:49:35.728679Z", + "start_time": "2023-11-13T07:49:35.696749Z" } }, "id": "3e23413c86063183" }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "outputs": [], "source": [ "a = predictor(\"Cześć, jestem psem.\")" @@ -598,24 +598,24 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:32:37.333871Z", - "start_time": "2023-11-11T07:32:32.625348Z" + "end_time": "2023-11-13T07:49:35.972167Z", + "start_time": "2023-11-13T07:49:35.711714Z" } }, "id": "d555d7f0223a624b" }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 14, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TOKEN LEMMA UPOS HEAD DEPREL \n", - "Cześć, ????? NOUN 2 punct \n", - "jestem ????? NOUN 0 root \n", - "psem. ???? NOUN 2 punct \n" + "Cześć, ?????? NOUN 0 root \n", + "jestem ?????? NOUN 1 punct \n", + "psem. ????? NOUN 1 punct \n" ] } ], @@ -627,15 +627,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:32:38.854144Z", - "start_time": "2023-11-11T07:32:35.324424Z" + "end_time": "2023-11-13T07:49:35.973153Z", + "start_time": "2023-11-13T07:49:35.929034Z" } }, "id": "a68cd3861e1ceb67" }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 15, "outputs": [], "source": [ "from modules.archival import archive" @@ -643,21 +643,21 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:32:41.112617Z", - "start_time": "2023-11-11T07:32:35.502093Z" + "end_time": "2023-11-13T07:49:35.973436Z", + "start_time": "2023-11-13T07:49:35.931941Z" } }, "id": "d0f43f4493218b5" }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 16, "outputs": [ { "data": { "text/plain": "'/Users/majajablonska/Documents/combo'" }, - "execution_count": 18, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -668,19 +668,23 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-11-11T07:34:18.278208Z", - "start_time": "2023-11-11T07:32:35.783931Z" + "end_time": "2023-11-13T07:51:13.077831Z", + "start_time": "2023-11-13T07:49:35.950950Z" } }, "id": "ec92aa5bb5bb3605" }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "outputs": [], "source": [], "metadata": { - "collapsed": false + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-11-13T07:51:13.123575Z", + "start_time": "2023-11-13T07:51:13.067631Z" + } }, "id": "5ad8a827586f65e3" }