{ "cells": [ { "cell_type": "markdown", "id": "ff9142be-4752-491b-bf9d-456763a2d7a5", "metadata": {}, "source": [ "# NER model training" ] }, { "cell_type": "markdown", "id": "3d1ba70b-c57d-421c-9298-65b7a2f06ef7", "metadata": {}, "source": [ "Specify the model directory. There will be 4 files saved there: \n", "- config.json - JSON config file specifying the model architecture\n", "- char_to_id.json - mapping between characters and ids\n", "- label_to_id.json - mapping between NER tags and ids\n", "- best_model.ckpt - model weights" ] }, { "cell_type": "markdown", "id": "a3009202-3e5d-4f1c-9b7a-016a3d8694d3", "metadata": {}, "source": [ "To demonstrate how to do it, we will use a very small subset of an NER dataset for Polish." ] }, { "cell_type": "code", "execution_count": 1, "id": "69f208e7-f717-4452-a60c-5730e7e48323", "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "serialization_directory = Path(\"./models/notebook_example\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "fc9c342a-4e41-4c9b-b0ae-e6c460d17404", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\lpsze\\anaconda3\\envs\\combo_ner_integration\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from combo.ner_modules.utils.utils import fix_common_warnings\n", "fix_common_warnings()" ] }, { "cell_type": "code", "execution_count": 3, "id": "8972f473-19e4-49c3-8365-5c59ec5a811d", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# Every value below must stay JSON-serializable: the dict is dumped to config.json later on.\n", "config = {\n", "    \"data\": {\n", "        # Forward slashes resolve on Windows, Linux and macOS; a backslash is a plain filename character on POSIX.\n", "        \"path_data\": \"./example_data\",\n", "        \"use_char_level_embeddings\": True,\n", "        \"use_start_end_token\": True,\n", "        \"tokenize_entities\": True,\n", "        \"batch_size\": 32,\n", "        \"encoding\": \"utf-8\",\n", "        \"num_workers\": 1\n", "    },\n", "\n", "    \"model\": {\n", "        \"bert_embedder\": {\n", "            \"pretrained_model_name\": \"allegro/herbert-base-cased\",\n", "            \"pretrained_model_type\": \"AutoModel\",\n", "            \"projection_dimension\": None,\n", "            \"freeze_bert\": True,\n", "            \"token_pooling\": True,\n", "            \"pooling_strategy\": \"max\"\n", "        },\n", "        \"char_embedder\": {\n", "            \"type\": \"combo\",\n", "            \"char_embedding_dim\": 64\n", "        },\n", "        \"classifier\": {\n", "            \"type\": \"vanilla\",\n", "            \"to_tag_space\": \"linear\"\n", "        },\n", "        \"dropout\": 0\n", "    },\n", "    \"loss\": \"ce\",\n", "    \"learning_rate\": 0.0007585775750,\n", "    \"callbacks\": {\"FixedProgressBar\": True},\n", "    \"trainer\": {\n", "        # The CUDA accelerator takes a list of GPU indices, while the CPU accelerator only\n", "        # accepts a positive device count, so `devices` must follow the same condition as `accelerator`.\n", "        \"devices\": [0] if torch.cuda.is_available() else 1,\n", "        \"max_epochs\": 2,\n", "        \"accelerator\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n", "        \"log_every_n_steps\": 10\n", "    }\n", "}" ] }, { "cell_type": "markdown", "id": "41e3ec1a-661e-44a3-bb61-74b2477771a9", "metadata": {}, "source": [ "# Training using config file" ] }, { "cell_type": "markdown", "id": "a36614a6-8e76-4006-8c0b-3d704189a721", "metadata": {}, "source": [ "## create vocabularies" ] }, { "cell_type": "code", "execution_count": 4, "id": "9f402941-42ea-47ed-a2e1-757aaa48d1f6", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.data.utils import create_tag2id, create_char2id\n", "from pathlib import Path" ] 
}, { "cell_type": "code", "execution_count": 5, "id": "aea6de16-4670-4e78-a3f7-0b2f667d8d68", "metadata": {}, "outputs": [], "source": [ "# Both vocabularies are built from the same training file.\n", "train_file = Path(config[\"data\"][\"path_data\"]) / \"train.txt\"\n", "\n", "# NOTE(review): create_char2id is called without `encoding`, unlike create_tag2id - confirm its default matches config[\"data\"][\"encoding\"].\n", "char_to_id = create_char2id(file_path=train_file)\n", "label_to_id = create_tag2id(\n", "    file_path=train_file,\n", "    encoding=config[\"data\"][\"encoding\"],\n", "    include_special_tokens=config[\"data\"][\"use_start_end_token\"],\n", ")" ] }, { "cell_type": "markdown", "id": "d71e974e-2f89-4a6b-aa77-31908ad28a6d", "metadata": {}, "source": [ "## create tokenizer" ] }, { "cell_type": "code", "execution_count": 6, "id": "83529f2e-34bb-4d42-9378-cec4e2e84c83", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.utils.constructors import construct_tokenizer_from_config" ] }, { "cell_type": "code", "execution_count": 7, "id": "efb8125c-7b0a-4701-8d67-ae238874b8f4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using model LAMBO-UD_Polish-PDB\n" ] } ], "source": [ "tokenizer = construct_tokenizer_from_config(\n", "    config=config,\n", "    char_to_id_map=char_to_id,\n", "    label_to_id_map=label_to_id,\n", ")" ] }, { "cell_type": "markdown", "id": "4eba05f3-0847-45f6-8d8e-98222b0aee8e", "metadata": {}, "source": [ "## create pytorch lightning datamodule" ] }, { "cell_type": "code", "execution_count": 8, "id": "8f643dd5-279b-4e29-b6db-abf65e282741", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.utils.constructors import construct_data_module_from_config\n", "data_module = construct_data_module_from_config(config=config, tokenizer=tokenizer)" ] }, { "cell_type": "markdown", "id": "e9df229a-ef28-4453-b188-ce62221523b6", "metadata": {}, "source": [ "## create loss" ] }, { "cell_type": "code", "execution_count": 9, "id": "fa9235de-fddd-4343-a0d6-fef35e180b3f", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.utils.constructors import construct_loss_from_config" ] }, { "cell_type": "code", 
"execution_count": 10, "id": "0183ecab-56a1-40e5-ab6f-52c2652cf2b9", "metadata": {}, "outputs": [], "source": [ "loss = construct_loss_from_config(config=config, label_to_id=label_to_id)" ] }, { "cell_type": "markdown", "id": "a9cebfb2-e570-4fa1-a100-c1cd07eeab55", "metadata": {}, "source": [ "## saving data to serialization directory" ] }, { "cell_type": "code", "execution_count": 11, "id": "d80fff2a-089e-4fa0-b14b-0c6b2037b88e", "metadata": {}, "outputs": [], "source": [ "serialization_directory.mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 12, "id": "ff58efd1-9341-4ac3-b381-512784f1f3fe", "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "# Store both vocabularies and the config in the serialization directory.\n", "artifacts = {\n", "    \"char_to_id.json\": char_to_id,\n", "    \"label_to_id.json\": label_to_id,\n", "    \"config.json\": config,\n", "}\n", "for file_name, payload in artifacts.items():\n", "    with open(serialization_directory / file_name, \"w\", encoding=\"utf-8\") as f:\n", "        json.dump(payload, f)" ] }, { "cell_type": "markdown", "id": "8f44ea8a-010f-40df-bc68-5403631b0c01", "metadata": {}, "source": [ "## creating model instance" ] }, { "cell_type": "code", "execution_count": 13, "id": "aea14060-0844-4178-954e-407cab7f00a5", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.NerModel import NerModel" ] }, { "cell_type": "code", "execution_count": 14, "id": "5ba60867-3187-4b39-a79a-5b4e4077944b", "metadata": {}, "outputs": [], "source": [ "model = NerModel(\n", "    loss_fn=loss,\n", "    char_to_id_map=char_to_id,\n", "    label_to_id_map=label_to_id,\n", "    config=config,\n", ")" ] }, { "cell_type": "markdown", "id": "4e7d6b32-b256-4a50-bc7e-f0ad00c067e9", "metadata": {}, "source": [ "## training" ] }, { "cell_type": "code", "execution_count": 15, "id": "d93dc81c-e148-4f5d-b8f4-b6f8a80ccf5b", "metadata": {}, "outputs": [], "source": [ "import pytorch_lightning as pl\n", "from combo.ner_modules.utils.constructors import construct_callbacks_from_config" ] }, { 
"cell_type": "code", "execution_count": 16, "id": "6f91e688-d197-41b5-b929-0925af024846", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", " | Name | Type | Params\n", "----------------------------------------------------\n", "0 | bert_embedder | BertEmbedder | 124 M \n", "1 | char_embedder | ComboCharEmbedder | 546 K \n", "2 | classifier | VanillaClassifier | 65.8 K\n", "3 | dropout | Dropout | 0 \n", "4 | loss_fn | CrossEntropyLoss | 0 \n", "----------------------------------------------------\n", "612 K Trainable params\n", "124 M Non-trainable params\n", "125 M Total params\n", "500.223 Total estimated model params size (MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|█████████████████████████████████████████████| 20/20 [00:19<00:00, 1.03it/s, v_num=11, train_loss=1.200]\n", "Epoch 1: 100%|█| 20/20 [00:19<00:00, 1.03it/s, v_num=11, train_loss=0.630, validation_loss=0.480, validation_precision\n", "Epoch 1: 100%|█| 20/20 [00:23<00:00, 1.17s/it, v_num=11, train_loss=0.630, validation_loss=0.251, validation_precision" ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=2` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1: 100%|█| 20/20 [00:28<00:00, 1.43s/it, v_num=11, train_loss=0.630, validation_loss=0.251, validation_precision\n" ] } ], "source": [ "# Work on a copy: writing callback objects and a Path into config[\"trainer\"] itself would\n", "# leave `config` non-JSON-serializable, so re-running the cell that dumps config.json would fail.\n", "params = dict(config[\"trainer\"])\n", "params[\"callbacks\"] = construct_callbacks_from_config(config.get(\"callbacks\", {}))\n", "params[\"default_root_dir\"] = serialization_directory\n", "trainer = pl.Trainer(**params)\n", "\n", "# start training\n", "trainer.fit(model, datamodule=data_module)" ] }, { "cell_type": "markdown", "id": 
"74a69251-7df3-474d-96de-59caaf41467b", "metadata": {}, "source": [ "## Evaluate on test data" ] }, { "cell_type": "code", "execution_count": 18, "id": "b659848a-53ad-4367-87db-4d9f6f69fbb6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Restoring states from the checkpoint path at models\\notebook_example\\lightning_logs\\version_11\\checkpoints\\epoch=2-step=40.ckpt\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "Loaded model weights from the checkpoint at models\\notebook_example\\lightning_logs\\version_11\\checkpoints\\epoch=2-step=40.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Testing: 0it [00:00, ?it/s] precision recall f1-score support\n", "\n", " nam_adj_country 1.00 0.00 0.00 3\n", " nam_liv_person 1.00 0.00 0.00 3\n", " nam_loc_gpe_city 1.00 0.00 0.00 1\n", " nam_loc_gpe_country 1.00 0.00 0.00 8\n", " nam_org_company 1.00 0.00 0.00 10\n", " nam_org_nation 1.00 0.00 0.00 3\n", "nam_org_organization 1.00 0.00 0.00 3\n", " nam_oth_currency 1.00 0.00 0.00 2\n", " nam_oth_tech 1.00 0.00 0.00 5\n", " nam_pro_brand 1.00 0.00 0.00 2\n", " nam_pro_software 1.00 0.00 0.00 37\n", "\n", " micro avg 1.00 0.00 0.00 77\n", " macro avg 1.00 0.00 0.00 77\n", " weighted avg 1.00 0.00 0.00 77\n", "\n" ] }, { "data": { "text/html": [ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│<span style=\"color: #008080; text-decoration-color: #008080\"> epoch </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 3.0 </span>│\n", "│<span style=\"color: #008080; text-decoration-color: #008080\"> test_f1 </span>│<span style=\"color: #800080; text-decoration-color: 
#800080\"> 0.0 </span>│\n", "│<span style=\"color: #008080; text-decoration-color: #008080\"> test_precision </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 1.0 </span>│\n", "│<span style=\"color: #008080; text-decoration-color: #008080\"> test_recall </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.0 </span>│\n", "└───────────────────────────┴───────────────────────────┘\n", "</pre>\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│\u001b[36m \u001b[0m\u001b[36m epoch \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.0 \u001b[0m\u001b[35m \u001b[0m│\n", "│\u001b[36m \u001b[0m\u001b[36m test_f1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", "│\u001b[36m \u001b[0m\u001b[36m test_precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0 \u001b[0m\u001b[35m \u001b[0m│\n", "│\u001b[36m \u001b[0m\u001b[36m test_recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "results = trainer.test(verbose=True,\n", " ckpt_path='best',\n", " datamodule=data_module)" ] }, { "cell_type": "markdown", "id": "033aac6e-1747-4b39-9956-1cafebb04b6b", "metadata": {}, "source": [ "# Training using as little config file as possible" ] }, { "cell_type": "markdown", "id": "6099ba5d-5702-45b9-a7af-d79945b06d33", "metadata": {}, "source": [ "## create vocabularies" ] }, { "cell_type": "code", "execution_count": 19, "id": "a7ff832e-2fc8-4fba-95b4-416c169779b0", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.data.utils import 
create_tag2id, create_char2id\n", "from pathlib import Path" ] }, { "cell_type": "code", "execution_count": 20, "id": "3c86f465-998e-4949-8b11-547ed83c3deb", "metadata": {}, "outputs": [], "source": [ "# Join path parts with pathlib instead of hard-coding backslashes, which are only separators on Windows.\n", "training_data_path = Path(\"example_data\") / \"train.txt\"\n", "char_to_id = create_char2id(file_path=training_data_path)\n", "label_to_id = create_tag2id(\n", "    file_path=training_data_path,\n", "    encoding=\"utf-8\",\n", "    include_special_tokens=True,\n", ")" ] }, { "cell_type": "markdown", "id": "79e1507c-e0e6-42e9-ac30-c899edafaa59", "metadata": {}, "source": [ "## create tokenizer" ] }, { "cell_type": "code", "execution_count": 21, "id": "a6d3cfed-266d-487c-a2d8-82a154bcc5b6", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.data.NerTokenizer import NerTokenizer" ] }, { "cell_type": "code", "execution_count": 22, "id": "aedfd9ae-52e7-447e-a70a-7cf3f1d8ca79", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using model LAMBO-UD_Polish-PDB\n" ] } ], "source": [ "tokenizer = NerTokenizer(\n", "    pretrained_model_type=\"AutoModel\",\n", "    pretrained_model_name=\"allegro/herbert-base-cased\",\n", "    char_to_id_map=char_to_id,\n", "    label_to_id_map=label_to_id,\n", "    use_char_level_embeddings=True,\n", "    use_start_end_token=True,\n", "    tokenize_entities=True,\n", ")" ] }, { "cell_type": "markdown", "id": "7a021776-2a19-4c9c-9dd1-b113fac8cdd4", "metadata": {}, "source": [ "## create pytorch lightning datamodule" ] }, { "cell_type": "code", "execution_count": 23, "id": "4e5374a1-d83f-4cb6-8bd4-ac95ebb75886", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.NerDataModule import NerDataModule" ] }, { "cell_type": "code", "execution_count": 24, "id": "8793da2d-39a7-49e0-85c1-56e9761725be", "metadata": {}, "outputs": [], "source": [ "# Relative, separator-free path so the cell runs unchanged on Windows, Linux and macOS.\n", "data_path = Path(\"example_data\")\n", "data_module = NerDataModule(\n", "    path_data=data_path,\n", "    tokenizer=tokenizer,\n", "    batch_size=32,\n", "    encoding=\"utf-8\",\n", "    num_workers=1,\n", ")" ] }, { 
"cell_type": "markdown", "id": "20a27e33-103e-406d-b621-a47546483389", "metadata": {}, "source": [ "## create loss function" ] }, { "cell_type": "code", "execution_count": 25, "id": "556d2b85-2e16-4488-90c8-a3cbc6452c61", "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": 26, "id": "9accfeae-9f16-4b0b-8113-1e2352ebebd8", "metadata": {}, "outputs": [], "source": [ "loss = torch.nn.CrossEntropyLoss()" ] }, { "cell_type": "markdown", "id": "1a8a90aa-100b-4ed8-b5c5-5f02ddf770c5", "metadata": {}, "source": [ "## create model instance" ] }, { "cell_type": "code", "execution_count": 27, "id": "df74c687-d1d7-4805-ae13-456fb117e1f8", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.NerModel import NerModel" ] }, { "cell_type": "markdown", "id": "d57fa876-b467-404b-ab1d-095a9961310a", "metadata": {}, "source": [ "The minimal config should contain the model architecture, the learning rate, and the flags that control whether start and end tokens and character-level embeddings are used." ] }, { "cell_type": "code", "execution_count": 28, "id": "a9941c27-8bb5-4e8d-b9a5-f9a66cae4dad", "metadata": {}, "outputs": [], "source": [ "config = {\n", " \"data\": {\n", " \"use_char_level_embeddings\": True,\n", " \"use_start_end_token\": True},\n", " \n", " \"model\": {\n", " \"bert_embedder\": {\n", " \"pretrained_model_name\": \"allegro/herbert-base-cased\",\n", " \"pretrained_model_type\": \"AutoModel\",\n", " \"projection_dimension\": None,\n", " \"freeze_bert\": True,\n", " \"token_pooling\": True,\n", " \"pooling_strategy\": \"max\"\n", " },\n", " \"char_embedder\": {\"type\" : \"combo\",\n", " \"char_embedding_dim\": 64\n", " },\n", " \"classifier\": {\"type\" : \"vanilla\",\n", " \"to_tag_space\" : \"linear\"},\n", " \"dropout\": 0\n", " },\n", " \"learning_rate\": 0.0007585775750}" ] }, { "cell_type": "code", "execution_count": 29, "id": "5eb269ee-7b5d-48a9-b0f1-4695741625a1", "metadata": {}, "outputs": [], 
"source": [ "model = NerModel(loss_fn=loss,\n", " char_to_id_map=char_to_id,\n", " label_to_id_map=label_to_id,\n", " config=config)" ] }, { "cell_type": "markdown", "id": "1a5da2da-2945-4489-98b1-9e5d472ab06d", "metadata": {}, "source": [ "## train" ] }, { "cell_type": "code", "execution_count": 30, "id": "65c97e08-9740-4943-8498-02747e3d53d0", "metadata": {}, "outputs": [], "source": [ "from combo.ner_modules.callbacks.FixedProgressBar import FixedProgressBar\n", "import pytorch_lightning as pl" ] }, { "cell_type": "code", "execution_count": 31, "id": "02c4c729-e763-4f25-94b8-4166a0e1c231", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "\n", " | Name | Type | Params\n", "----------------------------------------------------\n", "0 | bert_embedder | BertEmbedder | 124 M \n", "1 | char_embedder | ComboCharEmbedder | 546 K \n", "2 | classifier | VanillaClassifier | 65.8 K\n", "3 | dropout | Dropout | 0 \n", "4 | loss_fn | CrossEntropyLoss | 0 \n", "----------------------------------------------------\n", "612 K Trainable params\n", "124 M Non-trainable params\n", "125 M Total params\n", "500.223 Total estimated model params size (MB)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " " ] }, { "name": "stderr", "output_type": "stream", "text": [ "C:\\Users\\lpsze\\anaconda3\\envs\\combo_ner_integration\\lib\\site-packages\\pytorch_lightning\\loops\\fit_loop.py:281: PossibleUserWarning: The number of training batches (20) is smaller than the logging interval Trainer(log_every_n_steps=50). 
Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n", " rank_zero_warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0: 100%|██████████████████████████████████████████████| 20/20 [00:19<00:00, 1.05it/s, v_num=0, train_loss=1.020]\n", "Epoch 1: 100%|█| 20/20 [00:19<00:00, 1.03it/s, v_num=0, train_loss=0.128, validation_loss=0.448, validation_precision=\n", "Epoch 1: 100%|█| 20/20 [00:23<00:00, 1.17s/it, v_num=0, train_loss=0.128, validation_loss=0.256, validation_precision=" ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=2` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 1: 100%|█| 20/20 [00:28<00:00, 1.42s/it, v_num=0, train_loss=0.128, validation_loss=0.256, validation_precision=\n" ] } ], "source": [ "# Fall back to CPU when no GPU is present; the CPU accelerator expects a device count\n", "# rather than a list of GPU indices.\n", "use_cuda = torch.cuda.is_available()\n", "callbacks = [FixedProgressBar()]\n", "trainer = pl.Trainer(\n", "    devices=[0] if use_cuda else 1,\n", "    accelerator=\"cuda\" if use_cuda else \"cpu\",\n", "    max_epochs=2,\n", "    # An epoch has only 20 batches here, fewer than the default logging interval of 50.\n", "    log_every_n_steps=10,\n", "    callbacks=callbacks,\n", ")\n", "trainer.fit(model, datamodule=data_module)" ] }, { "cell_type": "markdown", "id": "5367be2f-83dc-41fe-9b73-7051abb7b5d9", "metadata": {}, "source": [ "## evaluate on test data" ] }, { "cell_type": "code", "execution_count": 32, "id": "5ffc539c-d4b1-4ba1-8876-de65cafa1cf9", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Restoring states from the checkpoint path at L:\\combo-lightning\\docs\\ner_docs\\lightning_logs\\version_0\\checkpoints\\epoch=2-step=40.ckpt\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "Loaded model weights from the checkpoint at L:\\combo-lightning\\docs\\ner_docs\\lightning_logs\\version_0\\checkpoints\\epoch=2-step=40.ckpt\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Testing: 0it [00:00, ?it/s] precision recall f1-score support\n", "\n", " nam_adj_country 1.00 0.00 0.00 3\n", " nam_liv_person 1.00 0.00 0.00 3\n", " nam_loc_gpe_city 1.00 0.00 0.00 1\n", " nam_loc_gpe_country 1.00 0.00 
0.00 8\n", " nam_org_company 1.00 0.00 0.00 10\n", " nam_org_nation 1.00 0.00 0.00 3\n", "nam_org_organization 1.00 0.00 0.00 3\n", " nam_oth_currency 1.00 0.00 0.00 2\n", " nam_oth_tech 1.00 0.00 0.00 5\n", " nam_pro_brand 1.00 0.00 0.00 2\n", " nam_pro_software 1.00 0.00 0.00 37\n", "\n", " micro avg 1.00 0.00 0.00 77\n", " macro avg 1.00 0.00 0.00 77\n", " weighted avg 1.00 0.00 0.00 77\n", "\n" ] }, { "data": { "text/html": [ "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃<span style=\"font-weight: bold\"> Test metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│<span style=\"color: #008080; text-decoration-color: #008080\"> epoch </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 3.0 </span>│\n", "│<span style=\"color: #008080; text-decoration-color: #008080\"> test_f1 </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.0 </span>│\n", "│<span style=\"color: #008080; text-decoration-color: #008080\"> test_precision </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 1.0 </span>│\n", "│<span style=\"color: #008080; text-decoration-color: #008080\"> test_recall </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 0.0 </span>│\n", "└───────────────────────────┴───────────────────────────┘\n", "</pre>\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│\u001b[36m \u001b[0m\u001b[36m epoch \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.0 \u001b[0m\u001b[35m \u001b[0m│\n", "│\u001b[36m 
\u001b[0m\u001b[36m test_f1 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", "│\u001b[36m \u001b[0m\u001b[36m test_precision \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 1.0 \u001b[0m\u001b[35m \u001b[0m│\n", "│\u001b[36m \u001b[0m\u001b[36m test_recall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0 \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Evaluate the best checkpoint of this run on the test split.\n", "results = trainer.test(verbose=True,\n", "                       ckpt_path='best',\n", "                       datamodule=data_module)" ] } ], "metadata": { "kernelspec": { "display_name": "combo_ner_integration", "language": "python", "name": "combo_ner_integration" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }