diff --git a/combo/data/vocabulary.py b/combo/data/vocabulary.py index d9f6124f40daa338a211015aad6dc1f6d02fcfda..c5aa1c2a75643bc5391a93232eac76bcfc0a2e36 100644 --- a/combo/data/vocabulary.py +++ b/combo/data/vocabulary.py @@ -141,6 +141,17 @@ class Vocabulary(FromParameters): self._oov_token) self._retained_counter: Optional[Dict[str, Dict[str, int]]] = None + self._extend( + counter, + min_count, + max_vocab_size, + non_padded_namespaces, + pretrained_files, + only_include_pretrained_words, + tokens_to_add, + min_pretrained_embeddings + ) + def _extend(self, counter: Dict[str, Dict[str, int]] = None, min_count: Dict[str, int] = None, diff --git a/combo/dataset_reader.ipynb b/combo/dataset_reader.ipynb deleted file mode 100644 index 7ece112745cf515d42901f597d0ec2d5c71a450d..0000000000000000000000000000000000000000 --- a/combo/dataset_reader.ipynb +++ /dev/null @@ -1,202 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 10, - "id": "initial_id", - "metadata": { - "collapsed": true, - "ExecuteTime": { - "end_time": "2023-09-24T07:02:40.432822Z", - "start_time": "2023-09-24T07:02:40.415807Z" - } - }, - "outputs": [], - "source": [ - "from combo.data.dataset_readers import UniversalDependenciesDatasetReader\n", - "from combo.data.tokenizers import CharacterTokenizer\n", - "from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, TokenFeatsIndexer, PretrainedTransformerFixedMismatchedIndexer, SingleIdTokenIndexer\n", - "from combo.data.dataset_loaders import SimpleDataLoader\n", - "from combo.data.vocabulary import FromInstancesVocabulary" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "outputs": [], - "source": [ - "def default_const_character_indexer():\n", - " return TokenConstPaddingCharactersIndexer(\n", - " tokenizer=CharacterTokenizer(end_tokens=[\"__END__\"],\n", - " start_tokens=[\"__START__\"]),\n", - " min_padding_length=32,\n", - " namespace=\"lemma_characters\"\n", - " )\n", - "\n", - "dataset_reader = UniversalDependenciesDatasetReader(\n", - " features=[\"token\", \"char\"],\n", - " lemma_indexers={\n", - " \"char\": default_const_character_indexer()\n", - " },\n", - " targets=[\"deprel\", \"head\", \"upostag\", \"lemma\", \"feats\", \"xpostag\"],\n", - " token_indexers={\n", - " \"char\": default_const_character_indexer(),\n", - " \"feats\": TokenFeatsIndexer(),\n", - " \"lemma\": default_const_character_indexer(),\n", - " \"token\": PretrainedTransformerFixedMismatchedIndexer(\"bert-base-cased\"),\n", - " \"upostag\": SingleIdTokenIndexer(\n", - " feature_name=\"pos_\",\n", - " namespace=\"upostag\"\n", - " ),\n", - " \"xpostag\": SingleIdTokenIndexer(\n", - " feature_name=\"tag_\",\n", - " namespace=\"xpostag\"\n", - " )\n", - " },\n", - " use_sem=False\n", - ")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-24T06:53:51.498706Z", - "start_time": "2023-09-24T06:53:49.212209Z" - } - }, - "id": "abb6ce33c2e461e6" - }, - { - "cell_type": "code", - "execution_count": 3, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", - "To disable this warning, you can either:\n", - "\t- Avoid using `tokenizers` before the fork if possible\n", - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" - ] - }, - { - "data": { - "text/plain": "loading instances: 0it [00:00, ?it/s]", - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "56d2edd36b9d42429d1629cdb6031126" - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "FILE_PATH = '/Users/majajablonska/Documents/train.conllu'\n", - "data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,\n", - " data_path=FILE_PATH,\n", - " batch_size=4)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-24T06:53:58.585298Z", - "start_time": "2023-09-24T06:53:51.497953Z" - } - }, - "id": "3519b6753622def0" - }, - { - "cell_type": "code", - "execution_count": 7, - "outputs": [], - "source": [ - "for i in data_loader.iter_instances():\n", - " break" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-24T06:54:56.208254Z", - "start_time": "2023-09-24T06:54:56.188568Z" - } - }, - "id": "eb23ae8415cb52c2" - }, - { - "cell_type": "code", - "execution_count": 11, - "outputs": [ - { - "data": { - "text/plain": "building vocabulary: 0it [00:00, ?it/s]", - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "5eb70beb73944090a9b054a4235d9df6" - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "vocabulary = FromInstancesVocabulary.from_instances_extended(\n", - " data_loader.iter_instances(),\n", - " non_padded_namespaces=['head_labels'],\n", - " only_include_pretrained_words=True,\n", - " oov_token='_',\n", - " padding_token='__PAD__'\n", - ")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-24T07:03:52.410405Z", - "start_time": "2023-09-24T07:03:45.701901Z" - } - }, - "id": "834f448f90453d03" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false - }, - "id": "82d4c789c15866ab" - }, - { - "cell_type": "markdown", - "source": [], - "metadata": { - "collapsed": false - }, - "id": "9a4de0a90632538" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/combo/example.ipynb b/combo/example.ipynb deleted file mode 100644 index 6e6de39aeffbb624f112cb9a9f20d92e691935bf..0000000000000000000000000000000000000000 --- a/combo/example.ipynb +++ /dev/null @@ -1,613 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "initial_id", - "metadata": { - "collapsed": true, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:28.998851Z", - "start_time": "2023-09-30T10:14:46.034419Z" - } - }, - "outputs": [], - "source": [ - "from combo.models.combo_model import ComboModel\n", - "from combo.predict import COMBO\n", - "from combo.data.vocabulary import Vocabulary\n", - "from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM\n", - "from combo.modules.text_field_embedders import BasicTextFieldEmbedder\n", - "from combo.nn.base import Linear\n", - "from combo.modules.token_embedders import CharacterBasedWordEmbedder, TransformersWordEmbedder\n", - "from combo.nn.base import FeedForwardPredictor\n", - "from combo.nn.activations import ReLUActivation, TanhActivation, LinearActivation\n", - "from combo.modules.dilated_cnn import DilatedCnnEncoder\n", - "from combo.data.tokenizers import LamboTokenizer, CharacterTokenizer\n", - "from combo.data.token_indexers import PretrainedTransformerIndexer, TokenConstPaddingCharactersIndexer, TokenFeatsIndexer, SingleIdTokenIndexer, PretrainedTransformerFixedMismatchedIndexer\n", - "from combo.data.dataset_readers import UniversalDependenciesDatasetReader\n", - "import torch\n", - "from combo.data.dataset_loaders import SimpleDataLoader\n", - "from combo.modules.parser import DependencyRelationModel, HeadPredictionModel\n", - "from combo.modules.lemma import LemmatizerModel\n", - "from combo.modules.morpho import MorphologicalFeatures\n", - "from combo.nn.regularizers import RegularizerApplicator\n", - "from combo.nn.regularizers.regularizers import L2Regularizer" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "outputs": [], - "source": [ - "weights = torch.load('/Users/majajablonska/.combo/english-bert-base-ud29/weights.th',\n", - " map_location=torch.device('cpu'))" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:29.023086Z", - "start_time": "2023-09-30T10:14:50.480597Z" - } - }, - "id": "7302b7d49ac2fc38" - }, - { - "cell_type": "code", - "execution_count": 3, - "outputs": [], - "source": [ - "vocabulary = Vocabulary.from_files(\n", - " '/Users/majajablonska/.combo/english-bert-base-ud29/vocabulary',\n", - " oov_token='_',\n", - " padding_token='__PAD__'\n", - ")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:29.027147Z", - "start_time": "2023-09-30T10:14:50.877437Z" - } - }, - "id": "e0ac599d12cc33df" - }, - { - "cell_type": "code", - "execution_count": 4, - "outputs": [ - { - "data": { - "text/plain": "{'deprel_labels',\n 'feats_labels',\n 'lemma_characters',\n 'token_characters',\n 'upostag_labels',\n 'xpostag_labels'}" - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "vocabulary.get_namespaces()" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:29.077462Z", - "start_time": "2023-09-30T10:14:50.894360Z" - } - }, - "id": "a7f687419bddd9f8" - }, - { - "cell_type": "code", - "execution_count": 5, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", - "To disable this warning, you can either:\n", - "\t- Avoid using `tokenizers` before the fork if possible\n", - "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" - ] - }, - { - "data": { - "text/plain": "loading instances: 0it [00:00, ?it/s]", - "application/vnd.jupyter.widget-view+json": { - "version_major": 2, - "version_minor": 0, - "model_id": "4cac7984914e46549d14b628edff6a19" - } - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from combo.data.batch import Batch\n", - "from combo.data.dataset_loaders.data_collator import allennlp_collate\n", - "\n", - "\n", - "def default_const_character_indexer(namespace = None):\n", - " if namespace:\n", - " return TokenConstPaddingCharactersIndexer(\n", - " tokenizer=CharacterTokenizer(end_tokens=[\"__END__\"],\n", - " start_tokens=[\"__START__\"]),\n", - " min_padding_length=32,\n", - " namespace=namespace\n", - " )\n", - " else:\n", - " return TokenConstPaddingCharactersIndexer(\n", - " tokenizer=CharacterTokenizer(end_tokens=[\"__END__\"],\n", - " start_tokens=[\"__START__\"]),\n", - " min_padding_length=32\n", - " )\n", - "\n", - "dataset_reader = UniversalDependenciesDatasetReader(\n", - " features=[\"token\", \"char\"],\n", - " lemma_indexers={\n", - " \"char\": default_const_character_indexer(\"lemma_characters\")\n", - " },\n", - " targets=[\"deprel\", \"head\", \"upostag\", \"lemma\", \"feats\", \"xpostag\"],\n", - " token_indexers={\n", - " \"char\": default_const_character_indexer(),\n", - " \"feats\": TokenFeatsIndexer(),\n", - " \"lemma\": default_const_character_indexer(),\n", - " \"token\": PretrainedTransformerFixedMismatchedIndexer(\"bert-base-cased\"),\n", - " \"upostag\": SingleIdTokenIndexer(\n", - " feature_name=\"pos_\",\n", - " namespace=\"upostag\"\n", - " ),\n", - " \"xpostag\": SingleIdTokenIndexer(\n", - " feature_name=\"tag_\",\n", - " namespace=\"xpostag\"\n", - " )\n", - " },\n", - " use_sem=False\n", - ")\n", - "\n", - "FILE_PATH = '/Users/majajablonska/Documents/train.conllu'\n", - "data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,\n", - " data_path=FILE_PATH,\n", - " batch_size=4,\n", - " collate_fn=lambda instances: Batch(instances))\n", - "\n", - "# vocabulary = FromInstancesVocabulary.from_instances_extended(\n", - "# data_loader.iter_instances(),\n", - "# non_padded_namespaces=['head_labels'],\n", - "# only_include_pretrained_words=False,\n", - "# oov_token='_',\n", - "# padding_token='__PAD__'\n", - "# )" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:29.087017Z", - "start_time": "2023-09-30T10:14:50.909170Z" - } - }, - "id": "d74957f422f0b05b" - }, - { - "cell_type": "code", - "execution_count": 6, - "outputs": [], - "source": [ - "seq_encoder = ComboEncoder(layer_dropout_probability=0.33,\n", - " stacked_bilstm=ComboStackedBidirectionalLSTM(\n", - " hidden_size=512,\n", - " input_size=164,\n", - " layer_dropout_probability=0.33,\n", - " num_layers=2,\n", - " recurrent_dropout_probability=0.33\n", - " ))" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:29.093211Z", - "start_time": "2023-09-30T10:14:58.819470Z" - } - }, - "id": "fa724d362fd6bd23" - }, - { - "cell_type": "code", - "execution_count": 7, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using model LAMBO-UD_English-EWT\n" - ] - }, - { - "data": { - "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7f82ef8fc4a0>" - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "char_words_embedder = CharacterBasedWordEmbedder(\n", - " dilated_cnn_encoder = DilatedCnnEncoder(\n", - " input_dim=64,\n", - " kernel_size=[3, 3, 3],\n", - " padding=[1, 2, 4],\n", - " stride=[1, 1, 1],\n", - " filters=[512, 256, 64],\n", - " dilation=[1, 2, 4],\n", - " activations=[ReLUActivation(), ReLUActivation(), LinearActivation()]\n", - " ),\n", - " embedding_dim=64,\n", - " vocabulary=vocabulary\n", - ")\n", - "tokenizer = LamboTokenizer()\n", - "indexer = PretrainedTransformerIndexer('bert-base-cased')\n", - "data_loader.iter_instances()" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:29.093908Z", - "start_time": "2023-09-30T10:14:59.488644Z" - } - }, - "id": "f8a10f9892005fca" - }, - { - "cell_type": "code", - "execution_count": 8, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']\n", - "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", - "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" - ] - } - ], - "source": [ - "model = ComboModel(\n", - " vocabulary=vocabulary,\n", - " dependency_relation=DependencyRelationModel(\n", - " vocab=vocabulary,\n", - " dependency_projection_layer=Linear(\n", - " activation=TanhActivation(),\n", - " dropout_rate=0.25,\n", - " in_features=1024,\n", - " out_features=128\n", - " ),\n", - " head_predictor=HeadPredictionModel(\n", - " cycle_loss_n=0,\n", - " dependency_projection_layer=Linear(\n", - " activation=TanhActivation(),\n", - " in_features=1024,\n", - " out_features=512\n", - " ),\n", - " head_projection_layer=Linear(\n", - " activation=TanhActivation(),\n", - " in_features=1024,\n", - " out_features=512\n", - " )\n", - " ),\n", - " head_projection_layer=Linear(\n", - " activation=TanhActivation(),\n", - " dropout_rate=0.25,\n", - " in_features=1024,\n", - " out_features=128\n", - " ),\n", - " vocab_namespace=\"deprel_labels\"\n", - " ),\n", - " lemmatizer=LemmatizerModel(\n", - " vocab=vocabulary,\n", - " activations=[ReLUActivation(), ReLUActivation(), ReLUActivation(), LinearActivation()],\n", - " char_vocab_namespace=\"token_characters\",\n", - " dilation=[1, 2, 4, 1],\n", - " embedding_dim=256, \n", - " filters=[256, 256, 256],\n", - " input_projection_layer=Linear(\n", - " activation=TanhActivation(),\n", - " dropout_rate=0.25,\n", - " in_features=1024,\n", - " out_features=32\n", - " ),\n", - " kernel_size=[3, 3, 3, 1],\n", - " lemma_vocab_namespace=\"lemma_characters\",\n", - " padding=[1, 2, 4, 0],\n", - " stride=[1, 1, 1, 1]\n", - " ),\n", - " loss_weights={\n", - " \"deprel\": 0.8,\n", - " \"feats\": 0.2,\n", - " \"head\": 0.2,\n", - " \"lemma\": 0.05,\n", - " \"semrel\": 0.05,\n", - " \"upostag\": 0.05,\n", - " \"xpostag\": 0.05\n", - " },\n", - " morphological_feat=MorphologicalFeatures(\n", - " vocab=vocabulary,\n", - " activations=[TanhActivation(), LinearActivation()],\n", - " dropout=[0.25, 0.],\n", - " hidden_dims=[128],\n", - " input_dim=1024,\n", - " num_layers=2,\n", - " vocab_namespace=\"feats_labels\"\n", - " ),\n", - " regularizer=RegularizerApplicator([\n", - " (\".*conv1d.*\", L2Regularizer(1e-6)),\n", - " (\".*forward.*\", L2Regularizer(1e-6)),\n", - " (\".*backward.*\", L2Regularizer(1e-6)),\n", - " (\".*char_embed.*\", L2Regularizer(1e-5))\n", - " ]),\n", - " seq_encoder=ComboEncoder(\n", - " layer_dropout_probability=0.33,\n", - " stacked_bilstm=ComboStackedBidirectionalLSTM(\n", - " hidden_size=512,\n", - " input_size=164,\n", - " layer_dropout_probability=0.33,\n", - " num_layers=2,\n", - " recurrent_dropout_probability=0.33\n", - " )\n", - " ),\n", - " text_field_embedder=BasicTextFieldEmbedder(\n", - " token_embedders={\n", - " \"char\": CharacterBasedWordEmbedder(\n", - " vocabulary=vocabulary,\n", - " dilated_cnn_encoder=DilatedCnnEncoder(\n", - " activations=[ReLUActivation(), ReLUActivation(), LinearActivation()],\n", - " dilation=[1, 2, 4],\n", - " filters=[512, 256, 64],\n", - " input_dim=64,\n", - " kernel_size=[3, 3, 3],\n", - " padding=[1, 2, 4],\n", - " stride=[1, 1, 1],\n", - " ),\n", - " embedding_dim=64\n", - " ),\n", - " \"token\": TransformersWordEmbedder(\"bert-base-cased\", projection_dim=100)\n", - " }\n", - " ),\n", - " upos_tagger=FeedForwardPredictor.from_vocab(\n", - " vocab=vocabulary,\n", - " activations=[TanhActivation(), LinearActivation()],\n", - " dropout=[0.25, 0.],\n", - " hidden_dims=[64],\n", - " input_dim=1024,\n", - " num_layers=2,\n", - " vocab_namespace=\"upostag_labels\"\n", - " ),\n", - " xpos_tagger=FeedForwardPredictor.from_vocab(\n", - " vocab=vocabulary,\n", - " activations=[TanhActivation(), LinearActivation()],\n", - " dropout=[0.25, 0.],\n", - " hidden_dims=[64],\n", - " input_dim=1024,\n", - " num_layers=2,\n", - " vocab_namespace=\"xpostag_labels\"\n", - " )\n", - ")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:29.094482Z", - "start_time": "2023-09-30T10:14:59.529195Z" - } - }, - "id": "437d12054baaffa1" - }, - { - "cell_type": "code", - "execution_count": 9, - "outputs": [], - "source": [ - "nlp = COMBO(model, dataset_reader)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:22:29.094568Z", - "start_time": "2023-09-30T10:15:01.838557Z" - } - }, - "id": "16ae311c44073668" - }, - { - "cell_type": "code", - "execution_count": 20, - "outputs": [], - "source": [ - "data_loader.index_with(vocabulary)\n", - "for i in data_loader:\n", - " break\n", - "a = nlp.predict_instance(next(data_loader.iter_instances()))" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:27:52.401373Z", - "start_time": "2023-09-30T10:27:51.315384Z" - } - }, - "id": "e131e0ec75dc6927" - }, - { - "cell_type": "code", - "execution_count": 22, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ID TOKEN LEMMA UPOS HEAD DEPREL \n", - "Al kIkkIIg~Iggggggggggggggggggggr INTJ 0 amod \n", - "- Ikkrrg~ggggggggggggggggggggggr PRON 0 flat \n", - "Zaman rrIIkg@kDggggggggggggggggggggr PRON 0 flat \n", - ": gkkIrggggggggggggggggggggggggr X 0 flat \n", - "American gDIfIrIIgkkIgggggggggggggggggg PRON 0 nmod:poss \n", - "forces rI~rfIgrkkg~kggggggggggggggggr PRON 0 flat \n", - "killed gI~rkkIkrIIgIggggggggggggggggr PRON 0 flat \n", - "Shaikh rIDIIBIIkrI~gggggggggggggggggr PRON 0 aux:pass \n", - "Abdullah IBIIDIJIQrkII~gggggggggggggggr PRON 0 aux:pass \n", - "al kkkkkDD~gggggggggggggggggggggr VERB 0 aux:pass \n", - "- Ikkrrggggggggggggggggggggggggr PRON 0 flat \n", - "Ani rkgIkkrI~ggggggggggggggggggggr PRON 0 aux:pass \n", - ", kkkkrg~Igggggggggggggggggggggr PRON 0 amod \n", - "the rgIkIgr~~kgggggggggggggggggggr VERB 0 nmod:poss \n", - "preacher DgDIIrkIIkgIIrkggggggggggggggr VERB 0 aux:pass \n", - "at @kkkkrgggggggggggggggggggggggr PRON 0 xcomp \n", - "the rgIkIgrI~kgggggggggggggggggggr VERB 0 xcomp \n", - "mosque gI?IIIgkkrIgkggggggggggggggggr PRON 0 nmod:poss \n", - "in ?IkIkDrggggggggggggggggggggggr VERB 0 nmod:poss \n", - "the rgIkIgrI~kgggggggggggggggggggr VERB 0 xcomp \n", - "town grkIgIkkgggggggggggggggggggggr VERB 0 nmod:poss \n", - "of rkkIrIgggggggggggggggggggggggr VERB 0 aux:pass \n", - "Qaim frkIkIgkIggggggggggggggggggggr PRON 0 xcomp \n", - ", Ikkkrg~Igggggggggggggggggggggr VERB 0 nmod:poss \n", - "near rrkkIrIIIgIggggggggggggggggggr VERB 0 nmod:poss \n", - "the rgIkIgrI~kgggggggggggggggggggr VERB 0 xcomp \n", - "Syrian r?XrbIrIkIIggggggggggggggggggr VERB 0 amod \n", - "border gIIrkDIrgIIgIggggggggggggggggr VERB 0 aux:pass \n", - ". Ikkrkggggggggggggggggggggggggr VERB 0 aux:pass \n" - ] - } - ], - "source": [ - "a = nlp.predict_batch_instance(i)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:27:57.287841Z", - "start_time": "2023-09-30T10:27:57.267783Z" - } - }, - "id": "dfb5ee72353e8c7f" - }, - { - "cell_type": "code", - "execution_count": 14, - "outputs": [ - { - "data": { - "text/plain": "[tensor(9.0285, grad_fn=<DivBackward0>),\n tensor(8.9160, grad_fn=<DivBackward0>),\n tensor(8.9323, grad_fn=<DivBackward0>),\n tensor(8.8981, grad_fn=<DivBackward0>),\n tensor(8.9450, grad_fn=<DivBackward0>),\n tensor(8.9228, grad_fn=<DivBackward0>),\n tensor(8.9216, grad_fn=<DivBackward0>),\n tensor(8.9147, grad_fn=<DivBackward0>),\n tensor(8.9706, grad_fn=<DivBackward0>),\n tensor(8.9243, grad_fn=<DivBackward0>)]" - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "nlp.get_gradients(i)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:23:34.046212Z", - "start_time": "2023-09-30T10:23:34.028117Z" - } - }, - "id": "cefc5173154d1605" - }, - { - "cell_type": "code", - "execution_count": 18, - "outputs": [ - { - "data": { - "text/plain": "[tensor(8.8657, grad_fn=<DivBackward0>),\n tensor(8.9624, grad_fn=<DivBackward0>),\n tensor(8.9071, grad_fn=<DivBackward0>),\n tensor(8.9303, grad_fn=<DivBackward0>),\n tensor(8.9620, grad_fn=<DivBackward0>),\n tensor(8.9274, grad_fn=<DivBackward0>),\n tensor(9.0054, grad_fn=<DivBackward0>),\n tensor(8.8900, grad_fn=<DivBackward0>),\n tensor(8.9158, grad_fn=<DivBackward0>),\n tensor(8.9557, grad_fn=<DivBackward0>),\n tensor(8.9092, grad_fn=<DivBackward0>),\n tensor(8.8874, grad_fn=<DivBackward0>),\n tensor(8.9916, grad_fn=<DivBackward0>),\n tensor(9.0001, grad_fn=<DivBackward0>),\n tensor(8.8875, grad_fn=<DivBackward0>),\n tensor(9.0140, grad_fn=<DivBackward0>),\n tensor(8.8741, grad_fn=<DivBackward0>),\n tensor(8.9413, grad_fn=<DivBackward0>),\n tensor(8.8852, grad_fn=<DivBackward0>),\n tensor(8.9800, grad_fn=<DivBackward0>),\n tensor(8.8693, grad_fn=<DivBackward0>),\n tensor(8.9082, grad_fn=<DivBackward0>),\n tensor(8.9906, grad_fn=<DivBackward0>),\n tensor(8.9462, grad_fn=<DivBackward0>),\n tensor(8.9875, grad_fn=<DivBackward0>),\n tensor(8.9254, grad_fn=<DivBackward0>),\n tensor(8.9305, grad_fn=<DivBackward0>),\n tensor(8.9767, grad_fn=<DivBackward0>),\n tensor(8.9332, grad_fn=<DivBackward0>),\n tensor(8.9140, grad_fn=<DivBackward0>),\n tensor(9.0070, grad_fn=<DivBackward0>),\n tensor(9.0090, grad_fn=<DivBackward0>),\n tensor(8.8530, grad_fn=<DivBackward0>),\n tensor(8.9505, grad_fn=<DivBackward0>),\n tensor(8.9805, grad_fn=<DivBackward0>),\n tensor(9.0016, grad_fn=<DivBackward0>),\n tensor(8.9162, grad_fn=<DivBackward0>),\n tensor(8.9593, grad_fn=<DivBackward0>),\n tensor(8.9915, grad_fn=<DivBackward0>),\n tensor(8.9792, grad_fn=<DivBackward0>),\n tensor(8.9610, grad_fn=<DivBackward0>),\n tensor(8.9677, grad_fn=<DivBackward0>),\n tensor(8.9177, grad_fn=<DivBackward0>),\n tensor(8.8837, grad_fn=<DivBackward0>),\n tensor(9.0080, grad_fn=<DivBackward0>),\n tensor(8.9786, grad_fn=<DivBackward0>),\n tensor(8.9738, grad_fn=<DivBackward0>),\n tensor(8.9820, grad_fn=<DivBackward0>),\n tensor(8.8889, grad_fn=<DivBackward0>),\n tensor(8.9882, grad_fn=<DivBackward0>),\n tensor(8.9031, grad_fn=<DivBackward0>),\n tensor(8.8290, grad_fn=<DivBackward0>),\n tensor(8.9471, grad_fn=<DivBackward0>),\n tensor(8.9244, grad_fn=<DivBackward0>),\n tensor(8.8250, grad_fn=<DivBackward0>),\n tensor(8.9150, grad_fn=<DivBackward0>),\n tensor(8.9734, grad_fn=<DivBackward0>),\n tensor(9.0532, grad_fn=<DivBackward0>),\n tensor(8.9509, grad_fn=<DivBackward0>),\n tensor(8.9648, grad_fn=<DivBackward0>),\n tensor(8.9128, grad_fn=<DivBackward0>),\n tensor(8.9677, grad_fn=<DivBackward0>),\n tensor(8.8848, grad_fn=<DivBackward0>),\n tensor(8.8822, grad_fn=<DivBackward0>),\n tensor(8.8899, grad_fn=<DivBackward0>),\n tensor(8.9628, grad_fn=<DivBackward0>),\n tensor(8.8192, grad_fn=<DivBackward0>),\n tensor(8.8811, grad_fn=<DivBackward0>),\n tensor(8.9823, grad_fn=<DivBackward0>),\n tensor(8.9336, grad_fn=<DivBackward0>),\n tensor(8.9759, grad_fn=<DivBackward0>),\n tensor(8.9934, grad_fn=<DivBackward0>),\n tensor(8.9040, grad_fn=<DivBackward0>),\n tensor(8.8672, grad_fn=<DivBackward0>),\n tensor(8.9404, grad_fn=<DivBackward0>),\n tensor(8.9907, grad_fn=<DivBackward0>),\n tensor(8.9484, grad_fn=<DivBackward0>),\n tensor(8.9674, grad_fn=<DivBackward0>),\n tensor(8.9314, grad_fn=<DivBackward0>),\n tensor(8.9879, grad_fn=<DivBackward0>),\n tensor(8.8454, grad_fn=<DivBackward0>),\n tensor(8.8618, grad_fn=<DivBackward0>),\n tensor(8.9582, grad_fn=<DivBackward0>),\n tensor(8.8988, grad_fn=<DivBackward0>),\n tensor(8.9234, grad_fn=<DivBackward0>),\n tensor(8.8517, grad_fn=<DivBackward0>),\n tensor(8.8494, grad_fn=<DivBackward0>),\n tensor(8.9398, grad_fn=<DivBackward0>),\n tensor(8.8712, grad_fn=<DivBackward0>),\n tensor(8.9619, grad_fn=<DivBackward0>),\n tensor(8.8822, grad_fn=<DivBackward0>),\n tensor(8.9798, grad_fn=<DivBackward0>),\n tensor(8.9261, grad_fn=<DivBackward0>),\n tensor(8.8694, grad_fn=<DivBackward0>),\n tensor(8.8845, grad_fn=<DivBackward0>),\n tensor(8.9664, grad_fn=<DivBackward0>),\n tensor(8.9568, grad_fn=<DivBackward0>),\n tensor(8.9002, grad_fn=<DivBackward0>),\n tensor(8.9201, grad_fn=<DivBackward0>),\n tensor(8.8687, grad_fn=<DivBackward0>)]" - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a = nlp(\"__START__ Hello! My name is Dog. __END__\")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T10:27:20.316545Z", - "start_time": "2023-09-30T10:27:20.256760Z" - } - }, - "id": "3e23413c86063183" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "a.tokens[-2].xpostag" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "start_time": "2023-09-30T10:15:20.056860Z" - } - }, - "id": "d555d7f0223a624b" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "print(\"{:5} {:15} {:15} {:10} {:10} {:10}\".format('ID', 'TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))\n", - "for token in a[0].tokens:\n", - " print(\"{:15} {:15} {:10} {:10} {:10}\".format(token.text, token.lemma, token.upostag, token.head, token.deprel))" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "start_time": "2023-09-30T10:15:20.057770Z" - } - }, - "id": "a68cd3861e1ceb67" - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "start_time": "2023-09-30T10:15:20.058736Z" - } - }, - "id": "7e1a6f47f3aa54b8" - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/combo/training.ipynb b/combo/polish_model_training.ipynb similarity index 69% rename from combo/training.ipynb rename to combo/polish_model_training.ipynb index dd5ffdd49a9929d19e396144aa0b5127852d942f..dbadb6f1a146c3cfa3a7b86562c8da930834e9a5 100644 --- a/combo/training.ipynb +++ b/combo/polish_model_training.ipynb @@ -7,14 +7,14 @@ "metadata": { "collapsed": true, "ExecuteTime": { - "end_time": "2023-09-30T12:50:11.910551Z", - "start_time": "2023-09-30T12:50:05.607909Z" + "end_time": "2023-10-01T10:41:14.747107Z", + "start_time": "2023-10-01T10:41:06.896956Z" } }, "outputs": [], "source": [ "from combo.models.combo_model import ComboModel\n", - "from combo.data.vocabulary import Vocabulary\n", + "from combo.data.vocabulary import Vocabulary, FromInstancesVocabulary\n", "from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM\n", "from combo.modules.text_field_embedders import BasicTextFieldEmbedder\n", "from combo.nn.base import Linear\n", @@ -39,26 +39,6 @@ { "cell_type": "code", "execution_count": 2, - "outputs": [], - "source": [ - "vocabulary = Vocabulary.from_files(\n", - " '/Users/majajablonska/.combo/english-bert-base-ud29/vocabulary',\n", - " oov_token='_',\n", - " padding_token='__PAD__'\n", - ")" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T12:50:12.091155Z", - "start_time": "2023-09-30T12:50:11.905639Z" - } - }, - "id": "e0ac599d12cc33df" - }, - { - "cell_type": "code", - "execution_count": 3, "outputs": [ { "data": { @@ -66,7 +46,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "070af12f14f74b45bbd86e1c98a7ac37" + "model_id": "df8b72b251b8485ca22453990bdc528e" } }, "metadata": {}, @@ -88,7 +68,19 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "1651821eb8c24eb78ac87d52eae4f5ca" + "model_id": "65f2fb4930db47c8b53536b94f5fd35b" + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "building vocabulary: 0it [00:00, ?it/s]", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "0f463b87e8ce470995f42f5c2d1b9aaa" } }, "metadata": {}, @@ -98,7 +90,6 @@ "source": [ "from combo.data.batch import Batch\n", "\n", - "\n", "def default_const_character_indexer(namespace = None):\n", " if namespace:\n", " return TokenConstPaddingCharactersIndexer(\n", @@ -137,40 +128,41 @@ " use_sem=False\n", ")\n", "\n", - "FILE_PATH = '/Users/majajablonska/Documents/train.conllu'\n", + "TRAIN_FILE_PATH = '/Users/majajablonska/Downloads/PDBUD-master-85167180bcbe0565a09269257456961365cf6ff3/PDB-UD/PDB-UD/PDBUD_train.conllu'\n", + "VAL_FILE_PATH = '/Users/majajablonska/Downloads/PDBUD-master-85167180bcbe0565a09269257456961365cf6ff3/PDB-UD/PDB-UD/PDBUD_train.conllu'\n", "data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,\n", - " data_path=FILE_PATH,\n", + " data_path=TRAIN_FILE_PATH,\n", " batch_size=16,\n", " batches_per_epoch=4,\n", " shuffle=True,\n", " collate_fn=lambda instances: Batch(instances).as_tensor_dict())\n", "val_data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,\n", - " data_path='/Users/majajablonska/Documents/test.conllu',\n", + " data_path=VAL_FILE_PATH,\n", " batch_size=16,\n", " batches_per_epoch=4,\n", " shuffle=True,\n", " collate_fn=lambda instances: Batch(instances).as_tensor_dict())\n", "\n", - "# vocabulary = FromInstancesVocabulary.from_instances_extended(\n", - "# data_loader.iter_instances(),\n", - "# non_padded_namespaces=['head_labels'],\n", - "# only_include_pretrained_words=False,\n", - "# oov_token='_',\n", - "# padding_token='__PAD__'\n", - "# )" + "vocabulary = FromInstancesVocabulary.from_instances_extended(\n", + " data_loader.iter_instances(),\n", + " non_padded_namespaces=['head_labels'],\n", + " only_include_pretrained_words=False,\n", + " oov_token='_',\n", + " padding_token='__PAD__'\n", + ")" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:50:29.370955Z", - "start_time": "2023-09-30T12:50:12.086509Z" + "end_time": "2023-10-01T10:41:44.737934Z", + "start_time": "2023-10-01T10:41:14.717845Z" } }, "id": "d74957f422f0b05b" }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "outputs": [], "source": [ "seq_encoder = ComboEncoder(layer_dropout_probability=0.33,\n", @@ -185,15 +177,15 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:50:30.257858Z", - "start_time": "2023-09-30T12:50:29.373620Z" + "end_time": "2023-10-01T10:41:45.478247Z", + "start_time": "2023-10-01T10:41:44.735503Z" } }, "id": "fa724d362fd6bd23" }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "outputs": [ { "name": "stdout", @@ -204,9 +196,9 @@ }, { "data": { - "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7fbabd8e7820>" + "text/plain": "<generator object SimpleDataLoader.iter_instances at 0x7f7c49bd14a0>" }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -232,21 +224,21 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:50:30.323025Z", - "start_time": "2023-09-30T12:50:30.262420Z" + "end_time": "2023-10-01T10:41:45.524880Z", + "start_time": "2023-10-01T10:41:45.478555Z" } }, "id": "f8a10f9892005fca" }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']\n", + "Some weights of the model checkpoint at allegro/herbert-base-cased were not used when initializing BertModel: ['cls.sso.sso_relationship.bias', 'cls.sso.sso_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -351,7 +343,7 @@ " ),\n", " embedding_dim=64\n", " ),\n", - " \"token\": TransformersWordEmbedder(\"bert-base-cased\", projection_dim=100)\n", + " \"token\": TransformersWordEmbedder(\"allegro/herbert-base-cased\", projection_dim=100)\n", " }\n", " ),\n", " upos_tagger=FeedForwardPredictor.from_vocab(\n", @@ -377,8 +369,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:50:33.692952Z", - "start_time": "2023-09-30T12:50:30.306121Z" + "end_time": "2023-10-01T10:41:49.351137Z", + "start_time": "2023-10-01T10:41:45.520206Z" } }, "id": "437d12054baaffa1" @@ -387,43 +379,24 @@ "cell_type": "code", "execution_count": 7, "outputs": [], - "source": [ - "#nlp = COMBO(model, dataset_reader)\n", - "\n", - "nlp = TrainableCombo(model, torch.optim.SGD)" - ], - "metadata": { - "collapsed": false, - "ExecuteTime": { - "end_time": "2023-09-30T12:50:33.719046Z", - "start_time": "2023-09-30T12:50:33.508698Z" - } - }, - "id": "16ae311c44073668" - }, - { - "cell_type": "code", - "execution_count": 8, - "outputs": [], "source": [ "data_loader.index_with(vocabulary)\n", "a = 0\n", "for i in data_loader:\n", - " break\n", - "output = nlp.forward(i)" + " break" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:50:53.315959Z", - "start_time": "2023-09-30T12:50:33.543824Z" + "end_time": "2023-10-01T10:42:26.427350Z", + "start_time": "2023-10-01T10:41:49.390846Z" } }, "id": "e131e0ec75dc6927" }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "outputs": [ { "name": "stderr", @@ -435,45 +408,45 @@ }, { "data": { - "text/plain": "tensor(33.0798, grad_fn=<AddBackward0>)" + "text/plain": "tensor(35.0692, grad_fn=<AddBackward0>)" }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "#val_data_loader.index_with(vocabulary)\n", - "nlp.training_step(i, 0)" + "val_data_loader.index_with(vocabulary)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:50:54.804317Z", - "start_time": "2023-09-30T12:50:53.310005Z" + "end_time": "2023-10-01T10:43:04.756241Z", + "start_time": "2023-10-01T10:42:26.425151Z" } }, "id": "195c71fcf8170ff" }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 18, "outputs": [], "source": [ - "import pytorch_lightning as pl" + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks.early_stopping import EarlyStopping" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:50:54.806552Z", - "start_time": "2023-09-30T12:50:54.800111Z" + "end_time": "2023-10-01T10:59:18.843745Z", + "start_time": "2023-10-01T10:59:18.758351Z" } }, "id": "dfb5ee72353e8c7f" }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 19, "outputs": [ { "name": "stderr", @@ -482,62 +455,124 @@ "GPU available: False, used: False\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", - "HPU available: False, using: 0 HPUs\n", - "/Users/majajablonska/miniconda/envs/combo/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", - " warning_cache.warn(\n" + "HPU available: False, using: 0 HPUs\n" ] } ], "source": [ - "trainer = pl.Trainer(max_epochs=10)" + "nlp = TrainableCombo(model, torch.optim.Adam,\n", + " optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},\n", + " validation_metrics=['EM'])\n", + "\n", + "trainer = pl.Trainer(max_epochs=10,\n", + " default_root_dir='/Users/majajablonska/Documents/Workspace/combo_training',\n", + " gradient_clip_val=5,\n", + " callbacks=[EarlyStopping(monitor='EM', mode='max', patience=1)])" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:50:55.230032Z", - "start_time": "2023-09-30T12:50:54.809818Z" + "end_time": "2023-10-01T11:01:10.767276Z", + "start_time": "2023-10-01T11:01:10.518908Z" } }, "id": "cefc5173154d1605" }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 20, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/majajablonska/miniconda/envs/combo/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:612: UserWarning: Checkpoint directory /Users/majajablonska/PycharmProjects/combo-lightning/combo/lightning_logs/version_1/checkpoints exists and is not empty.\n", - " rank_zero_warn(f\"Checkpoint directory {dirpath} exists and is not empty.\")\n", "\n", " | Name | Type | Params\n", "-------------------------------------\n", - "0 | model | ComboModel | 120 M \n", + "0 | model | ComboModel | 136 M \n", "-------------------------------------\n", - "12.0 M Trainable params\n", - "108 M Non-trainable params\n", - "120 M Total params\n", - "481.267 Total estimated model params size (MB)\n", - "`Trainer.fit` stopped: `max_epochs=10` reached.\n" + "12.1 M Trainable params\n", + "124 M Non-trainable params\n", + "136 M Total params\n", + "546.107 Total estimated model params size (MB)\n" ] + }, + { + "data": { + "text/plain": "Sanity Checking: 0it [00:00, ?it/s]", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "329c37bd528043099a374e1371d19f5d" + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "Training: 0it [00:00, ?it/s]", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "8c397abae96b4789926ef45ecf96312e" + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "Validation: 0it [00:00, ?it/s]", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "db5fae0f7a3f402992b650822a433540" + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "Validation: 0it [00:00, ?it/s]", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "bccd02f9e4804e4ab7474cf9a2a7b4a1" + } + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "Validation: 0it [00:00, ?it/s]", + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "569b1d7bc1c140fba2a22b5b0bb6748f" + } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "trainer.fit(model=nlp, train_dataloaders=data_loader)#, val_dataloaders=val_data_loader)" + "trainer.fit(model=nlp, train_dataloaders=data_loader, val_dataloaders=val_data_loader)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:54:40.221842Z", - "start_time": "2023-09-30T12:54:40.016457Z" + "end_time": "2023-10-01T11:03:40.933833Z", + "start_time": "2023-10-01T11:01:12.101776Z" } }, "id": "e5af131bae4b1a33" }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 21, "outputs": [], "source": [ "from combo.predict import COMBO\n", @@ -546,71 +581,91 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:54:18.575330Z", - "start_time": "2023-09-30T12:54:18.411379Z" + "end_time": "2023-10-01T11:04:49.785003Z", + "start_time": "2023-10-01T11:04:49.700277Z" } }, "id": "3e23413c86063183" }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 22, "outputs": [], "source": [ - "a = predictor(\"Hello, I am Dog\")" + "a = predictor(\"Cześć, jestem psem.\")" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:54:19.712612Z", - "start_time": "2023-09-30T12:54:19.386496Z" + "end_time": "2023-10-01T11:04:51.187878Z", + "start_time": "2023-10-01T11:04:50.656049Z" } }, "id": "d555d7f0223a624b" }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 23, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "ID TOKEN LEMMA UPOS HEAD DEPREL \n", - "Hello, NOUN 2 det \n", - "I NOUN 0 root \n", - "am NOUN 2 punct \n", - "Dog NOUN 2 punct \n" + "TOKEN LEMMA UPOS HEAD DEPREL \n", + "Cześć, cześć, NOUN 2 amod \n", + "jestem byst NOUN 0 root \n", + "psem. psem NOUN 2 obj \n" ] } ], "source": [ - "print(\"{:5} {:15} {:15} {:10} {:10} {:10}\".format('ID', 'TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))\n", + "print(\"{:15} {:15} {:10} {:10} {:10}\".format('TOKEN', 'LEMMA', 'UPOS', 'HEAD', 'DEPREL'))\n", "for token in a.tokens:\n", " print(\"{:15} {:15} {:10} {:10} {:10}\".format(token.text, token.lemma, token.upostag, token.head, token.deprel))" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:54:20.307495Z", - "start_time": "2023-09-30T12:54:20.275818Z" + "end_time": "2023-10-01T11:04:52.437027Z", + "start_time": "2023-10-01T11:04:52.370216Z" } }, "id": "a68cd3861e1ceb67" }, { "cell_type": "code", - "execution_count": 15, - "outputs": [], - "source": [], + "execution_count": 24, + "outputs": [ + { + "data": { + "text/plain": "{'UPOS_ACC': 0.3262315006792021,\n 'XPOS_ACC': 0.15857582040466148,\n 'SEMREL_ACC': 0.0,\n 'LEMMA_ACC': 0.27010795738900406,\n 'FEATS_ACC': 0.05201258311289054,\n 'EM': 0.0036900369450449944,\n 'UAS': 0.12608136126403088,\n 'LAS': 0.0319224994637878,\n 'UEM': 0.006818181818181818,\n 'LEM': 0.0011363636363636363,\n 'EUAS': 0.0,\n 'ELAS': 0.0,\n 'EUEM': 0.0,\n 'ELEM': 0.0,\n 'partial_loss/upostag_loss': 4.990175724029541,\n 'partial_loss/xpostag_loss': 12.692313194274902,\n 'partial_loss/feats_loss': 23.47402000427246,\n 'partial_loss/lemma_loss': 1.5392004251480103,\n 'partial_loss/head_loss': 7.714131832122803,\n 'partial_loss/deprel_loss': 19.59114646911621,\n 'partial_loss/cycle_loss': 0.0}" + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.get_metrics()" + ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-30T12:54:08.125735Z", - "start_time": "2023-09-30T12:54:08.083618Z" + "end_time": "2023-10-01T11:04:56.344221Z", + "start_time": "2023-10-01T11:04:56.286489Z" } }, - "id": "7e1a6f47f3aa54b8" + "id": "d6578197e2403037" + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "6391caeb9e843c0b" } ], "metadata": { diff --git a/combo/training/scheduler.py b/combo/training/scheduler.py index eff71b90acb652b986c9e797371e50a8fdfd737b..8334ff1afb601c7c46a0d02ea777eef8e20b873d 100644 --- a/combo/training/scheduler.py +++ b/combo/training/scheduler.py @@ -1,2 +1,43 @@ -class Scheduler: - pass +import torch +from typing import Callable, List, Union +from overrides import overrides + + +class Scheduler(torch.optim.lr_scheduler.LambdaLR): + def __init__(self, + optimizer: torch.optim.Optimizer, + patience: int = 6, + decreases: int = 2, + threshold: float = 1e-3, + last_epoch: int = -1, + verbose: bool = False): + super().__init__(optimizer, [self._lr_lambda], last_epoch, verbose) + self.patience = patience + self.decreases = decreases + self.threshold = threshold + self.start_patience = patience + self.best_score = 0.0 + + @staticmethod + def _lr_lambda(idx: int) -> float: + return 1.0 / (1.0 + idx * 1e-4) + + def step(self, metric: float = None) -> None: + super().step() + + if metric is not None: + if metric - self.best_score > self.threshold: + self.best_score = metric if metric > self.best_score else self.best_score + self.patience = self.start_patience + else: + if self.patience <= 1: + if self.decreases == 0: + # The Trainer should trigger early stopping + self.patience = 0 + else: + self.patience = self.start_patience + self.decreases -= 1 + self.threshold /= 2 + self.base_lrs = [x / 2 for x in self.base_lrs] + else: + self.patience -= 1 diff --git a/combo/training/trainable_combo.py b/combo/training/trainable_combo.py index 554176809614bde04717872a4bdf717d03ffbcb9..286a0762164fa5e9c93fc095c17cb88b0a0f1128 100644 --- a/combo/training/trainable_combo.py +++ b/combo/training/trainable_combo.py @@ -1,22 +1,33 @@ -from typing import Optional, Type +from typing import Any, Dict, List, Optional, Type import pytorch_lightning as pl +import torch from torch import Tensor from combo.config import FromParameters from combo.data.dataset_loaders.dataset_loader import TensorDict from combo.modules.model import Model +from combo.training import Scheduler class TrainableCombo(pl.LightningModule, FromParameters): def __init__(self, model: Model, - optimizer_type: Type, - learning_rate: float = 0.1): + optimizer_type: Type = torch.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + scheduler_type: Type = Scheduler, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + validation_metrics: List[str] = None): super().__init__() self.model = model self._optimizer_type = optimizer_type - self._lr = learning_rate + self._optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {} + + self._scheduler_type = scheduler_type + self._scheduler_kwargs = scheduler_kwargs if scheduler_kwargs else {} + + self._validation_metrics = validation_metrics if validation_metrics else [] + def forward(self, batch: TensorDict) -> TensorDict: return self.model.batch_outputs(batch, self.model.training) @@ -28,8 +39,17 @@ class TrainableCombo(pl.LightningModule, FromParameters): def validation_step(self, batch: TensorDict, batch_idx: int) -> Tensor: output = self.forward(batch) - self.log("validation_loss", output['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=True) + metrics = self.model.get_metrics() + for k in metrics.keys(): + if k in self._validation_metrics: + self.log(k, metrics[k], on_epoch=True, prog_bar=True, logger=True) return output["loss"] + def lr_scheduler_step(self, scheduler: torch.optim.lr_scheduler, metric: Optional[Any]) -> None: + scheduler.step(metric=metric) + def configure_optimizers(self): - return self._optimizer_type(self.model.parameters(), lr=self._lr) + optimizer = self._optimizer_type(self.model.parameters(), **self._optimizer_kwargs) + return ([optimizer], + [{'scheduler': self._scheduler_type(optimizer, **self._scheduler_kwargs), + 'interval': 'epoch'}])