{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ff9142be-4752-491b-bf9d-456763a2d7a5",
   "metadata": {},
   "source": [
    "# NER model training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d1ba70b-c57d-421c-9298-65b7a2f06ef7",
   "metadata": {},
   "source": [
    "Specify model directory. There will be 4 files saved there: \n",
    "- config.json - json config file specifying the model architecture\n",
    "- char_to_id.json - mapping between characters and ids\n",
    "- label_to_id.json - mapping between NER tags and ids\n",
    "- best_model.ckpt - model weights"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3009202-3e5d-4f1c-9b7a-016a3d8694d3",
   "metadata": {},
   "source": [
    "To demonstrate how to do it, we will use a very small subset of an NER dataset for Polish."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "69f208e7-f717-4452-a60c-5730e7e48323",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "serialization_directory = Path(\"./models/notebook_example\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fc9c342a-4e41-4c9b-b0ae-e6c460d17404",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\lpsze\\anaconda3\\envs\\combo_ner_integration\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from combo.ner_modules.utils.utils import fix_common_warnings\n",
    "fix_common_warnings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8972f473-19e4-49c3-8365-5c59ec5a811d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Decide on the device once so that `accelerator` and `devices` always agree.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "\n",
    "# Full training configuration: data loading, model architecture, loss,\n",
    "# optimisation and the keyword arguments later passed to pl.Trainer.\n",
    "config = {\n",
    "  \"data\": {\n",
    "    # forward slashes resolve on Windows as well as on POSIX systems\n",
    "    \"path_data\" : \"./example_data\",\n",
    "    \"use_char_level_embeddings\": True,\n",
    "    \"use_start_end_token\": True,\n",
    "    \"tokenize_entities\": True,\n",
    "    \"batch_size\": 32,\n",
    "    \"encoding\": \"utf-8\",\n",
    "    \"num_workers\": 1\n",
    "  },\n",
    "\n",
    "  \"model\": {\n",
    "    \"bert_embedder\": {\n",
    "        \"pretrained_model_name\": \"allegro/herbert-base-cased\",\n",
    "        \"pretrained_model_type\": \"AutoModel\",\n",
    "        \"projection_dimension\": None,\n",
    "        \"freeze_bert\": True,\n",
    "        \"token_pooling\": True,\n",
    "        \"pooling_strategy\": \"max\"\n",
    "                     },\n",
    "    \"char_embedder\": {\"type\" : \"combo\",\n",
    "                      \"char_embedding_dim\":  64\n",
    "                     },\n",
    "    \"classifier\": {\"type\" : \"vanilla\",\n",
    "                   \"to_tag_space\" :  \"linear\"},\n",
    "    \"dropout\": 0\n",
    "            },\n",
    "  \"loss\": \"ce\",\n",
    "  \"learning_rate\": 0.0007585775750,\n",
    "  \"callbacks\": {\"FixedProgressBar\": True},\n",
    "  \"trainer\": {\n",
    "          # GPU index 0 when CUDA is present; the CPU accelerator needs a positive int instead\n",
    "          \"devices\" : [0] if use_cuda else 1,\n",
    "          \"max_epochs\": 2,\n",
    "          \"accelerator\": \"cuda\" if use_cuda else \"cpu\",\n",
    "          \"log_every_n_steps\": 10}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41e3ec1a-661e-44a3-bb61-74b2477771a9",
   "metadata": {},
   "source": [
    "# Training using config file"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a36614a6-8e76-4006-8c0b-3d704189a721",
   "metadata": {},
   "source": [
    "## create vocabularies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9f402941-42ea-47ed-a2e1-757aaa48d1f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.data.utils import create_tag2id, create_char2id\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "aea6de16-4670-4e78-a3f7-0b2f667d8d68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both vocabularies are derived from the same training file.\n",
    "data_cfg = config[\"data\"]\n",
    "train_file = Path(data_cfg[\"path_data\"]) / \"train.txt\"\n",
    "\n",
    "char_to_id = create_char2id(file_path=train_file)\n",
    "label_to_id = create_tag2id(\n",
    "    file_path=train_file,\n",
    "    encoding=data_cfg[\"encoding\"],\n",
    "    include_special_tokens=data_cfg[\"use_start_end_token\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d71e974e-2f89-4a6b-aa77-31908ad28a6d",
   "metadata": {},
   "source": [
    "## create tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "83529f2e-34bb-4d42-9378-cec4e2e84c83",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.utils.constructors import construct_tokenizer_from_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "efb8125c-7b0a-4701-8d67-ae238874b8f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using model LAMBO-UD_Polish-PDB\n"
     ]
    }
   ],
   "source": [
    "tokenizer = construct_tokenizer_from_config(config=config,\n",
    "                                            char_to_id_map=char_to_id,\n",
    "                                            label_to_id_map=label_to_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4eba05f3-0847-45f6-8d8e-98222b0aee8e",
   "metadata": {},
   "source": [
    "## create pytorch lightning datamodule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8f643dd5-279b-4e29-b6db-abf65e282741",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.utils.constructors import construct_data_module_from_config\n",
    "data_module = construct_data_module_from_config(config=config,\n",
    "                                                tokenizer=tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9df229a-ef28-4453-b188-ce62221523b6",
   "metadata": {},
   "source": [
    "## create loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "fa9235de-fddd-4343-a0d6-fef35e180b3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.utils.constructors import construct_loss_from_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0183ecab-56a1-40e5-ab6f-52c2652cf2b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = construct_loss_from_config(config=config,\n",
    "                                  label_to_id=label_to_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9cebfb2-e570-4fa1-a100-c1cd07eeab55",
   "metadata": {},
   "source": [
    "## saving data to serialization directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d80fff2a-089e-4fa0-b14b-0c6b2037b88e",
   "metadata": {},
   "outputs": [],
   "source": [
    "serialization_directory.mkdir(parents=True, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ff58efd1-9341-4ac3-b381-512784f1f3fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Write both vocabularies and the config into the serialization directory.\n",
    "artifacts = {\n",
    "    \"char_to_id.json\": char_to_id,\n",
    "    \"label_to_id.json\": label_to_id,\n",
    "    \"config.json\": config,\n",
    "}\n",
    "for file_name, payload in artifacts.items():\n",
    "    with open(serialization_directory / file_name, \"w+\") as f:\n",
    "        json.dump(payload, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f44ea8a-010f-40df-bc68-5403631b0c01",
   "metadata": {},
   "source": [
    "## creating model instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aea14060-0844-4178-954e-407cab7f00a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.NerModel import NerModel "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5ba60867-3187-4b39-a79a-5b4e4077944b",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = NerModel(loss_fn=loss,\n",
    "                 char_to_id_map=char_to_id,\n",
    "                 label_to_id_map=label_to_id,\n",
    "                 config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e7d6b32-b256-4a50-bc7e-f0ad00c067e9",
   "metadata": {},
   "source": [
    "## training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d93dc81c-e148-4f5d-b8f4-b6f8a80ccf5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "from combo.ner_modules.utils.constructors import construct_callbacks_from_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "6f91e688-d197-41b5-b929-0925af024846",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name          | Type              | Params\n",
      "----------------------------------------------------\n",
      "0 | bert_embedder | BertEmbedder      | 124 M \n",
      "1 | char_embedder | ComboCharEmbedder | 546 K \n",
      "2 | classifier    | VanillaClassifier | 65.8 K\n",
      "3 | dropout       | Dropout           | 0     \n",
      "4 | loss_fn       | CrossEntropyLoss  | 0     \n",
      "----------------------------------------------------\n",
      "612 K     Trainable params\n",
      "124 M     Non-trainable params\n",
      "125 M     Total params\n",
      "500.223   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|█████████████████████████████████████████████| 20/20 [00:19<00:00,  1.03it/s, v_num=11, train_loss=1.200]\n",
      "Epoch 1: 100%|█| 20/20 [00:19<00:00,  1.03it/s, v_num=11, train_loss=0.630, validation_loss=0.480, validation_precision\n",
      "Epoch 1: 100%|█| 20/20 [00:23<00:00,  1.17s/it, v_num=11, train_loss=0.630, validation_loss=0.251, validation_precision"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|█| 20/20 [00:28<00:00,  1.43s/it, v_num=11, train_loss=0.630, validation_loss=0.251, validation_precision\n"
     ]
    }
   ],
   "source": [
    "# Work on a copy of the trainer settings: the callback objects and the Path\n",
    "# added below are not JSON-serializable and must not leak into `config`,\n",
    "# which is also held by the model and written out with json.dump.\n",
    "params = dict(config[\"trainer\"])\n",
    "params[\"callbacks\"] = construct_callbacks_from_config(config.get(\"callbacks\", {}))\n",
    "params[\"default_root_dir\"] = serialization_directory\n",
    "trainer = pl.Trainer(**params)\n",
    "\n",
    "# start training\n",
    "trainer.fit(model,\n",
    "            datamodule=data_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74a69251-7df3-474d-96de-59caaf41467b",
   "metadata": {},
   "source": [
    "## Evaluate on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b659848a-53ad-4367-87db-4d9f6f69fbb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Restoring states from the checkpoint path at models\\notebook_example\\lightning_logs\\version_11\\checkpoints\\epoch=2-step=40.ckpt\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "Loaded model weights from the checkpoint at models\\notebook_example\\lightning_logs\\version_11\\checkpoints\\epoch=2-step=40.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing: 0it [00:00, ?it/s]                      precision    recall  f1-score   support\n",
      "\n",
      "     nam_adj_country       1.00      0.00      0.00         3\n",
      "      nam_liv_person       1.00      0.00      0.00         3\n",
      "    nam_loc_gpe_city       1.00      0.00      0.00         1\n",
      " nam_loc_gpe_country       1.00      0.00      0.00         8\n",
      "     nam_org_company       1.00      0.00      0.00        10\n",
      "      nam_org_nation       1.00      0.00      0.00         3\n",
      "nam_org_organization       1.00      0.00      0.00         3\n",
      "    nam_oth_currency       1.00      0.00      0.00         2\n",
      "        nam_oth_tech       1.00      0.00      0.00         5\n",
      "       nam_pro_brand       1.00      0.00      0.00         2\n",
      "    nam_pro_software       1.00      0.00      0.00        37\n",
      "\n",
      "           micro avg       1.00      0.00      0.00        77\n",
      "           macro avg       1.00      0.00      0.00        77\n",
      "        weighted avg       1.00      0.00      0.00        77\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">           epoch           </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            3.0            </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">          test_f1          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            0.0            </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_precision       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            1.0            </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_recall        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            0.0            </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m          epoch          \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           3.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m         test_f1         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           0.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_precision      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           1.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_recall       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           0.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results = trainer.test(verbose=True,\n",
    "                       ckpt_path='best',\n",
    "                        datamodule=data_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "033aac6e-1747-4b39-9956-1cafebb04b6b",
   "metadata": {},
   "source": [
    "# Training with as little configuration as possible"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6099ba5d-5702-45b9-a7af-d79945b06d33",
   "metadata": {},
   "source": [
    "## create vocabularies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "a7ff832e-2fc8-4fba-95b4-416c169779b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.data.utils import create_tag2id, create_char2id\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3c86f465-998e-4949-8b11-547ed83c3deb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward slashes keep the path valid on POSIX systems as well as on Windows.\n",
    "training_data_path = Path(\"./example_data/train.txt\")\n",
    "char_to_id = create_char2id(file_path=training_data_path)\n",
    "label_to_id = create_tag2id(file_path=training_data_path,\n",
    "                            encoding=\"utf-8\",\n",
    "                            include_special_tokens=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79e1507c-e0e6-42e9-ac30-c899edafaa59",
   "metadata": {},
   "source": [
    "## create tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a6d3cfed-266d-487c-a2d8-82a154bcc5b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.data.NerTokenizer import NerTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "aedfd9ae-52e7-447e-a70a-7cf3f1d8ca79",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using model LAMBO-UD_Polish-PDB\n"
     ]
    }
   ],
   "source": [
    "tokenizer = NerTokenizer(pretrained_model_type=\"AutoModel\",\n",
    "                         pretrained_model_name=\"allegro/herbert-base-cased\",\n",
    "                         char_to_id_map=char_to_id,\n",
    "                         label_to_id_map=label_to_id,\n",
    "                         use_char_level_embeddings=True,\n",
    "                         use_start_end_token=True,\n",
    "                         tokenize_entities=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a021776-2a19-4c9c-9dd1-b113fac8cdd4",
   "metadata": {},
   "source": [
    "## create pytorch lightning datamodule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "4e5374a1-d83f-4cb6-8bd4-ac95ebb75886",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.NerDataModule import NerDataModule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "8793da2d-39a7-49e0-85c1-56e9761725be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Forward slashes keep the path valid on POSIX systems as well as on Windows.\n",
    "data_path = Path(\"./example_data\")\n",
    "data_module = NerDataModule(path_data=data_path,\n",
    "                            tokenizer=tokenizer,\n",
    "                            batch_size=32,\n",
    "                            encoding=\"utf-8\",\n",
    "                            num_workers=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20a27e33-103e-406d-b621-a47546483389",
   "metadata": {},
   "source": [
    "## create loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "556d2b85-2e16-4488-90c8-a3cbc6452c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "9accfeae-9f16-4b0b-8113-1e2352ebebd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a8a90aa-100b-4ed8-b5c5-5f02ddf770c5",
   "metadata": {},
   "source": [
    "## create model instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "df74c687-d1d7-4805-ae13-456fb117e1f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.NerModel import NerModel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d57fa876-b467-404b-ab1d-095a9961310a",
   "metadata": {},
   "source": [
    "A minimal config should contain information about the model architecture, the learning rate, whether to use start and end tokens, and whether to use character-level embeddings."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "a9941c27-8bb5-4e8d-b9a5-f9a66cae4dad",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "  \"data\": {\n",
    "    \"use_char_level_embeddings\": True,\n",
    "    \"use_start_end_token\": True},\n",
    "    \n",
    "  \"model\": {\n",
    "    \"bert_embedder\": {\n",
    "        \"pretrained_model_name\": \"allegro/herbert-base-cased\",\n",
    "        \"pretrained_model_type\": \"AutoModel\",\n",
    "        \"projection_dimension\": None,\n",
    "        \"freeze_bert\": True,\n",
    "        \"token_pooling\": True,\n",
    "        \"pooling_strategy\": \"max\"\n",
    "                     },\n",
    "    \"char_embedder\": {\"type\" : \"combo\",\n",
    "                      \"char_embedding_dim\":  64\n",
    "                     },\n",
    "    \"classifier\": {\"type\" : \"vanilla\",\n",
    "                   \"to_tag_space\" :  \"linear\"},\n",
    "    \"dropout\": 0\n",
    "            },\n",
    "    \"learning_rate\": 0.0007585775750}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "5eb269ee-7b5d-48a9-b0f1-4695741625a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = NerModel(loss_fn=loss,\n",
    "                 char_to_id_map=char_to_id,\n",
    "                 label_to_id_map=label_to_id,\n",
    "                 config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a5da2da-2945-4489-98b1-9e5d472ab06d",
   "metadata": {},
   "source": [
    "## train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "65c97e08-9740-4943-8498-02747e3d53d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from combo.ner_modules.callbacks.FixedProgressBar import FixedProgressBar\n",
    "import pytorch_lightning as pl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "02c4c729-e763-4f25-94b8-4166a0e1c231",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "\n",
      "  | Name          | Type              | Params\n",
      "----------------------------------------------------\n",
      "0 | bert_embedder | BertEmbedder      | 124 M \n",
      "1 | char_embedder | ComboCharEmbedder | 546 K \n",
      "2 | classifier    | VanillaClassifier | 65.8 K\n",
      "3 | dropout       | Dropout           | 0     \n",
      "4 | loss_fn       | CrossEntropyLoss  | 0     \n",
      "----------------------------------------------------\n",
      "612 K     Trainable params\n",
      "124 M     Non-trainable params\n",
      "125 M     Total params\n",
      "500.223   Total estimated model params size (MB)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                                                                                       "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\lpsze\\anaconda3\\envs\\combo_ner_integration\\lib\\site-packages\\pytorch_lightning\\loops\\fit_loop.py:281: PossibleUserWarning: The number of training batches (20) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "  rank_zero_warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████████████████████████████████████████| 20/20 [00:19<00:00,  1.05it/s, v_num=0, train_loss=1.020]\n",
      "Epoch 1: 100%|█| 20/20 [00:19<00:00,  1.03it/s, v_num=0, train_loss=0.128, validation_loss=0.448, validation_precision=\n",
      "Epoch 1: 100%|█| 20/20 [00:23<00:00,  1.17s/it, v_num=0, train_loss=0.128, validation_loss=0.256, validation_precision="
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|█| 20/20 [00:28<00:00,  1.42s/it, v_num=0, train_loss=0.128, validation_loss=0.256, validation_precision=\n"
     ]
    }
   ],
   "source": [
    "# Fall back to CPU when no GPU is visible; the CPU accelerator takes a\n",
    "# positive int for `devices` rather than a list of GPU indices.\n",
    "use_cuda = torch.cuda.is_available()\n",
    "callbacks = [FixedProgressBar()]\n",
    "trainer = pl.Trainer(devices=[0] if use_cuda else 1,\n",
    "                     accelerator=\"cuda\" if use_cuda else \"cpu\",\n",
    "                     max_epochs=2,\n",
    "                     # the example data gives only 20 batches per epoch, below the default interval of 50\n",
    "                     log_every_n_steps=10,\n",
    "                     callbacks=callbacks)\n",
    "trainer.fit(model,\n",
    "            datamodule=data_module)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5367be2f-83dc-41fe-9b73-7051abb7b5d9",
   "metadata": {},
   "source": [
    "## evaluate on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "5ffc539c-d4b1-4ba1-8876-de65cafa1cf9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Restoring states from the checkpoint path at L:\\combo-lightning\\docs\\ner_docs\\lightning_logs\\version_0\\checkpoints\\epoch=2-step=40.ckpt\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n",
      "Loaded model weights from the checkpoint at L:\\combo-lightning\\docs\\ner_docs\\lightning_logs\\version_0\\checkpoints\\epoch=2-step=40.ckpt\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing: 0it [00:00, ?it/s]                      precision    recall  f1-score   support\n",
      "\n",
      "     nam_adj_country       1.00      0.00      0.00         3\n",
      "      nam_liv_person       1.00      0.00      0.00         3\n",
      "    nam_loc_gpe_city       1.00      0.00      0.00         1\n",
      " nam_loc_gpe_country       1.00      0.00      0.00         8\n",
      "     nam_org_company       1.00      0.00      0.00        10\n",
      "      nam_org_nation       1.00      0.00      0.00         3\n",
      "nam_org_organization       1.00      0.00      0.00         3\n",
      "    nam_oth_currency       1.00      0.00      0.00         2\n",
      "        nam_oth_tech       1.00      0.00      0.00         5\n",
      "       nam_pro_brand       1.00      0.00      0.00         2\n",
      "    nam_pro_software       1.00      0.00      0.00        37\n",
      "\n",
      "           micro avg       1.00      0.00      0.00        77\n",
      "           macro avg       1.00      0.00      0.00        77\n",
      "        weighted avg       1.00      0.00      0.00        77\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">           epoch           </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            3.0            </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">          test_f1          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            0.0            </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_precision       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            1.0            </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_recall        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">            0.0            </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m          epoch          \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           3.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m         test_f1         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           0.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_precision      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           1.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_recall       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m           0.0           \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Test with the best checkpoint of the trainer fitted above; storing the outcome in\n",
    "# `results` replaces the value left over from the config-driven section.\n",
    "results = trainer.test(verbose=True,\n",
    "                       ckpt_path='best',\n",
    "                       datamodule=data_module)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "combo_ner_integration",
   "language": "python",
   "name": "combo_ner_integration"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}