diff --git a/Dockerfile b/Dockerfiles/base/Dockerfile
similarity index 50%
rename from Dockerfile
rename to Dockerfiles/base/Dockerfile
index 6c48d1eead2e5cdb41b1a58e1943330e4cdecfae..40e67b491f591a43099e7b36eccc90f717f47c40 100644
--- a/Dockerfile
+++ b/Dockerfiles/base/Dockerfile
@@ -18,11 +18,20 @@ RUN apt-get install -y python3.6 python3-pip
 RUN pip3 install pip --upgrade
 RUN pip3 install wheel
 
-# Copy repository and install requirements
-COPY . ./poldeepner2
+# Install requirements and spacy
 WORKDIR "/poldeepner2"
+ADD ./requirements.txt /poldeepner2/requirements.txt
 RUN pip3 install -r requirements.txt
+RUN python3.6 -m spacy download pl_core_news_sm
+RUN python3.6 -m nltk.downloader punkt
 
-# download the necessary using config
-# run sever
-# expose port
\ No newline at end of file
+RUN apt-get install -y wget
+RUN apt-get install -y unzip
+
+# Download and unzip roberta_base_fairseq
+RUN mkdir -p models/roberta_base_fairseq
+RUN wget https://github.com/sdadas/polish-roberta/releases/download/models/roberta_base_fairseq.zip
+RUN unzip roberta_base_fairseq.zip -d models/roberta_base_fairseq
+RUN rm roberta_base_fairseq.zip
+
+COPY . .
diff --git a/Dockerfiles/kpwr_n82_base/Dockerfile b/Dockerfiles/kpwr_n82_base/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..509714a39b1681b2b4069db24e86754ce14c4b25
--- /dev/null
+++ b/Dockerfiles/kpwr_n82_base/Dockerfile
@@ -0,0 +1,11 @@
+FROM poldeepner2
+
+# Download and unzip kpwr_n82 model
+RUN mkdir -p models/kpwr_n82_base/kpwr_n82_base
+RUN wget https://minio.clarin-pl.eu/public/models/poldeepner2/kpwr_n82_base.zip
+RUN unzip kpwr_n82_base.zip -d models/kpwr_n82_base
+RUN rm kpwr_n82_base.zip
+
+EXPOSE 8000
+
+CMD python3.6 server.py --model models/kpwr_n82_base/kpwr_n82_base --pretrained_path xlmr:models/roberta_base_fairseq
diff --git a/Dockerfiles/kpwr_n82_large/Dockerfile b/Dockerfiles/kpwr_n82_large/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..e7759a7010478dfa8deb8131158361cae6fba07b
--- /dev/null
+++ b/Dockerfiles/kpwr_n82_large/Dockerfile
@@ -0,0 +1,11 @@
+FROM poldeepner2
+
+# Download and unzip the kpwr_n82_large model
+# NOTE: the model URL is assumed by analogy with kpwr_n82_base
+RUN mkdir -p models/kpwr_n82_large/kpwr_n82_large
+RUN wget https://minio.clarin-pl.eu/public/models/poldeepner2/kpwr_n82_large.zip
+RUN unzip kpwr_n82_large.zip -d models/kpwr_n82_large
+RUN rm kpwr_n82_large.zip
+
+# Download and unzip roberta_large_fairseq
+RUN mkdir -p models/roberta_large_fairseq
+RUN wget https://github.com/sdadas/polish-roberta/releases/download/models/roberta_large_fairseq.zip
+RUN unzip roberta_large_fairseq.zip -d models/roberta_large_fairseq
+RUN rm roberta_large_fairseq.zip
+
+EXPOSE 8000
+
+CMD python3.6 server.py --model models/kpwr_n82_large/kpwr_n82_large --pretrained_path xlmr:models/roberta_large_fairseq
diff --git a/Dockerfiles/merged-base/Dockerfile b/Dockerfiles/merged-base/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..0d30902faefeaca4ea9c782464f7d77a7a9e1b45
--- /dev/null
+++ b/Dockerfiles/merged-base/Dockerfile
@@ -0,0 +1,47 @@
+FROM nvidia/cuda:11.1-cudnn8-runtime-ubuntu18.04
+LABEL maintainer="Michał Marcińczuk <marcinczuk@gmail.com>"
+
+RUN apt-get clean && apt-get update
+
+# Set the locale
+RUN apt-get install locales
+RUN locale-gen en_US.UTF-8
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:en
+ENV LC_ALL en_US.UTF-8
+
+# Python 3.6
+RUN apt-get install -y software-properties-common vim
+RUN apt-get install -y python3.6 python3-pip
+
+# update pip
+RUN pip3 install pip --upgrade
+RUN pip3 install wheel
+
+# Install requirements and spacy
+WORKDIR "/poldeepner2"
+ADD ./requirements.txt /poldeepner2/requirements.txt
+RUN pip3 install -r requirements.txt
+RUN python3.6 -m spacy download pl_core_news_sm
+RUN python3.6 -m nltk.downloader punkt
+
+RUN apt-get install -y wget
+RUN apt-get install -y unzip
+
+# Download and unzip kpwr_n82 model
+RUN mkdir -p models/kpwr_n82_base/kpwr_n82_base
+RUN wget https://minio.clarin-pl.eu/public/models/poldeepner2/kpwr_n82_base.zip
+RUN unzip kpwr_n82_base.zip -d models/kpwr_n82_base
+RUN rm kpwr_n82_base.zip
+
+# Download and unzip roberta_base_fairseq
+RUN mkdir -p models/roberta_base_fairseq
+RUN wget https://github.com/sdadas/polish-roberta/releases/download/models/roberta_base_fairseq.zip
+RUN unzip roberta_base_fairseq.zip -d models/roberta_base_fairseq
+RUN rm roberta_base_fairseq.zip
+
+COPY . .
+
+EXPOSE 8000
+
+CMD python3.6 server.py --model models/kpwr_n82_base/kpwr_n82_base --pretrained_path xlmr:models/roberta_base_fairseq
diff --git a/Dockerfiles/merged-large/Dockerfile b/Dockerfiles/merged-large/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..9c9387facb29375ab2701872ba09f96612456044
--- /dev/null
+++ b/Dockerfiles/merged-large/Dockerfile
@@ -0,0 +1,47 @@
+FROM nvidia/cuda:11.1-cudnn8-runtime-ubuntu18.04
+LABEL maintainer="Michał Marcińczuk <marcinczuk@gmail.com>"
+
+RUN apt-get clean && apt-get update
+
+# Set the locale
+RUN apt-get install locales
+RUN locale-gen en_US.UTF-8
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:en
+ENV LC_ALL en_US.UTF-8
+
+# Python 3.6
+RUN apt-get install -y software-properties-common vim
+RUN apt-get install -y python3.6 python3-pip
+
+# update pip
+RUN pip3 install pip --upgrade
+RUN pip3 install wheel
+
+# Install requirements and spacy
+WORKDIR "/poldeepner2"
+ADD ./requirements.txt /poldeepner2/requirements.txt
+RUN pip3 install -r requirements.txt
+RUN python3.6 -m spacy download pl_core_news_sm
+RUN python3.6 -m nltk.downloader punkt
+
+RUN apt-get install -y wget
+RUN apt-get install -y unzip
+
+# Download and unzip the kpwr_n82_large model
+# NOTE: the model URL is assumed by analogy with kpwr_n82_base
+RUN mkdir -p models/kpwr_n82_large/kpwr_n82_large
+RUN wget https://minio.clarin-pl.eu/public/models/poldeepner2/kpwr_n82_large.zip
+RUN unzip kpwr_n82_large.zip -d models/kpwr_n82_large
+RUN rm kpwr_n82_large.zip
+
+# Download and unzip roberta_large_fairseq
+RUN mkdir -p models/roberta_large_fairseq
+RUN wget https://github.com/sdadas/polish-roberta/releases/download/models/roberta_large_fairseq.zip
+RUN unzip roberta_large_fairseq.zip -d models/roberta_large_fairseq
+RUN rm roberta_large_fairseq.zip
+
+# Download and unzip roberta_base_fairseq
+RUN mkdir -p models/roberta_base_fairseq
+RUN wget https://github.com/sdadas/polish-roberta/releases/download/models/roberta_base_fairseq.zip
+RUN unzip roberta_base_fairseq.zip -d models/roberta_base_fairseq
+RUN rm roberta_base_fairseq.zip
+
+COPY . .
+
+EXPOSE 8000
+
+CMD python3.6 server.py --model models/kpwr_n82_large/kpwr_n82_large --pretrained_path xlmr:models/roberta_large_fairseq
diff --git a/Dockerfiles/nkjp_base/Dockerfile b/Dockerfiles/nkjp_base/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..c44168190ccb10b0e03c0d5f711b899685d3c144
--- /dev/null
+++ b/Dockerfiles/nkjp_base/Dockerfile
@@ -0,0 +1,11 @@
+FROM poldeepner2
+
+# Download and unzip nkjp_base model
+RUN mkdir -p models/nkjp_base/nkjp_base
+RUN wget https://minio.clarin-pl.eu/public/models/poldeepner2/nkjp_base.zip
+RUN unzip nkjp_base.zip -d models/nkjp_base
+RUN rm nkjp_base.zip
+
+EXPOSE 8000
+
+CMD python3.6 server.py --model models/nkjp_base/nkjp_base --pretrained_path xlmr:models/roberta_base_fairseq
diff --git a/Dockerfiles/nkjp_base_sq/Dockerfile b/Dockerfiles/nkjp_base_sq/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..adc8f5cdf244b98db2048fbe506b02610b9db9a2
--- /dev/null
+++ b/Dockerfiles/nkjp_base_sq/Dockerfile
@@ -0,0 +1,11 @@
+FROM poldeepner2
+
+# Download and unzip nkjp_base_sq
+RUN mkdir -p models/nkjp_base_sq/nkjp_base_sq
+RUN wget https://minio.clarin-pl.eu/public/models/poldeepner2/nkjp_base_sq.zip
+RUN unzip nkjp_base_sq.zip -d models/nkjp_base_sq
+RUN rm nkjp_base_sq.zip
+
+EXPOSE 8000
+
+CMD python3.6 server.py --model models/nkjp_base_sq/nkjp_base_sq --pretrained_path xlmr:models/roberta_base_fairseq
diff --git a/README.md b/README.md
index 93bd7dd7eb5e628b8855f93db4b04a54f245e38d..7cd3add65b7c5320ae22560be535397e66abfbb0 100644
--- a/README.md
+++ b/README.md
@@ -375,3 +375,39 @@ time python main.py  \
       --dropout 0.3 \
       --squeeze
 ```
+
+### Docker
+
+The Docker setup is split into a base image (Python dependencies plus the `roberta_base_fairseq` embeddings) and model-specific images built on top of it.
+
+To build the base image:
+
+```bash
+docker build -f Dockerfiles/base/Dockerfile . --tag poldeepner2
+```
+
+To build an image with a specific model on top of the base image, e.g. `nkjp_base_sq`:
+
+```bash
+docker build -f Dockerfiles/nkjp_base_sq/Dockerfile . --tag poldeepner2_nkjp_base_sq
+```
+
+To run a container with the chosen model:
+
+```bash
+docker run --publish 8000:8000 poldeepner2_nkjp_base_sq
+```
+
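+Each model image starts `server.py`, which serves a `/predict` endpoint on the exposed port 8000. A minimal request sketch, assuming the container is running locally (the sentence below is just an arbitrary example):
+
+```bash
+curl -X POST http://localhost:8000/predict \
+     -H "Content-Type: application/json" \
+     -d '{"text": "Jan Kowalski mieszka w Warszawie.", "tokenization": "spacy"}'
+```
+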
+### HerBERT
+```bash
+time python main.py  \
+      --data_dir=data/nkjp-nested-full-aug/  \
+      --task_name=ner \
+      --output_dir=models/nkjp_base_sq/   \
+      --max_seq_length=256   \
+      --num_train_epochs 10  \
+      --do_eval \
+      --warmup_proportion=0.0 \
+      --pretrained_path automodel:allegro/herbert-large-cased \
+      --learning_rate 6e-5 \
+      --gradient_accumulation_steps 4 \
+      --do_train \
+      --eval_on test \
+      --train_batch_size 32 \
+      --dropout 0.3 \
+      --model=Herbert \
+      --squeeze
+```
+
diff --git a/core/poldeepner.py b/core/poldeepner.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d08aec58fcb04cd7d8ff9e880e35da5da54ab91
--- /dev/null
+++ b/core/poldeepner.py
@@ -0,0 +1,116 @@
+import codecs
+import os
+import torch
+import tqdm
+from torch.utils.data.dataloader import DataLoader
+
+from core.model.xlmr_for_token_classification import XLMRForTokenClassification
+from core.utils.data_utils import InputExample, convert_examples_to_features, create_dataset, read_params, \
+    wrap_annotations, align_tokens_with_text
+from core.utils.tokenization import TokenizerSpaces
+
+
+class PolDeepNer2:
+
+    def __init__(self, model_path, pretrained_path,
+                 device="cpu", squeeze=False, max_seq_length=256, tokenizer=TokenizerSpaces()):
+        if not os.path.exists(model_path):
+            raise ValueError("Model not found on path '%s'" % model_path)
+
+        if not os.path.exists(pretrained_path):
+            raise ValueError("RoBERTa language model not found on path '%s'" % pretrained_path)
+
+        dropout, num_labels, label_list = read_params(model_path)
+        self.label_list = label_list
+        model = XLMRForTokenClassification(pretrained_path=pretrained_path,
+                                           n_labels=len(self.label_list) + 1,
+                                           dropout_p=dropout,
+                                           device=device,
+                                           hidden_size=768 if 'base' in pretrained_path else 1024)
+        state_dict = torch.load(open(os.path.join(model_path, 'model.pt'), 'rb'))
+        model.load_state_dict(state_dict)
+        model.eval()
+        model.to(device)
+        self.model = model
+        self.device = device
+        self.squeeze = squeeze
+        self.max_seq_length = max_seq_length
+        self.tokenizer = tokenizer
+
+    @staticmethod
+    def load_labels(path):
+        return [line.strip() for line in codecs.open(path, "r", "utf8").readlines() if len(line.strip()) > 0]
+
+    def process(self, sentences):
+        """
+        @param sentences -- array of arrays of words, e.g. [['Jan', 'z', 'Warszawy'], ['IBM', 'i', 'Apple']]
+
+        return -- an array of arrays of predicted labels, one label per input word
+        """
+        examples = []
+        for idx, tokens in enumerate(sentences):
+            guid = str(idx)
+            text_a = ' '.join(tokens)
+            label = ["O"] * len(tokens)
+            examples.append(InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
+
+        eval_features = convert_examples_to_features(examples, self.label_list, self.max_seq_length,
+                                                     self.model.encode_word, self.squeeze)
+        eval_dataset = create_dataset(eval_features)
+        eval_dataloader = DataLoader(eval_dataset, batch_size=1)
+
+        y_pred = []
+        sum_pred = []
+        label_map = {i: label for i, label in enumerate(self.label_list, 1)}
+
+        for input_ids, label_ids, l_mask, valid_ids in eval_dataloader:
+            input_ids = input_ids.to(self.device)
+            label_ids = label_ids.to(self.device)
+            valid_ids = valid_ids.to(self.device)
+
+            with torch.no_grad():
+                logits = self.model(input_ids, labels=None, labels_mask=None, valid_mask=valid_ids)
+
+            logits = torch.argmax(logits, dim=2)
+            logits = logits.detach().cpu().numpy()
+            label_ids = label_ids.cpu().numpy()
+            for i, cur_label in enumerate(label_ids):
+                temp_1 = []
+                temp_2 = []
+                for j, m in enumerate(cur_label):
+                    if valid_ids[i][j]:
+                        temp_1.append(label_map[m])
+                        temp_2.append(label_map[logits[i][j]])
+                assert len(temp_1) == len(temp_2)
+                if self.squeeze:
+                    sum_pred.extend(temp_2)
+                else:
+                    y_pred.append(temp_2)
+        if self.squeeze:
+            pointer = 0
+            for sentence in sentences:
+                y_pred.append(sum_pred[pointer: (pointer + len(sentence))])
+                pointer += len(sentence)
+        return y_pred
+
+    def process_text(self, text: str):
+        """
+        @param text: a text to process, e.g.
+                     "John lives in New York. Mary lives in Chicago"
+
+        return: [(PER, 0, 4, "John"), (LOC, 14, 22, "New York"), (PER, 24, 28, "Mary"), (LOC, 38, 45, "Chicago")]
+        """
+        sentences = self.tokenizer.tokenize([text])
+        predictions = self.process(sentences)
+        annotations = wrap_annotations(predictions)
+        return align_tokens_with_text(text, sentences, annotations)
+
+    def process_tokenized(self, tokens: [[str]], text: str):
+        """
+        @param tokens: tokenized sentences, e.g.
+                       [["John", "lives", "in", "New", "York"], ["Mary", "lives", "in", "Chicago"]]
+        @param text: the original text the tokens were produced from
+
+        return: a list of annotations aligned with the text, as returned by process_text
+        """
+        predictions = self.process(tokens)
+        annotations = wrap_annotations(predictions)
+        return align_tokens_with_text(text, tokens, annotations)
diff --git a/main.py b/main.py
index a279fd2198fd51baa865e0970c67311429fe567b..c91f759f647857dd162ea1d600e9ce589be1474f 100644
--- a/main.py
+++ b/main.py
@@ -13,6 +13,7 @@ from pytorch_transformers import AdamW, WarmupLinearSchedule
 from torch.utils.data import DataLoader, RandomSampler
 
 from poldeepner2.model.xlmr_for_token_classification import XLMRForTokenClassification
+from poldeepner2.model.herbert_for_token_calssification import AutoTokenizerForTokenClassification
 from poldeepner2.utils.train_utils import add_xlmr_args, evaluate_model, predict_model
 from poldeepner2.utils.data_utils import NerProcessor, create_dataset, convert_examples_to_features, save_params
 
@@ -28,16 +29,31 @@ def main():
     parser = argparse.ArgumentParser()
     parser = add_xlmr_args(parser)
     args = parser.parse_args()
-    
-    #if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
-    #    raise ValueError("Output directory (%s) already exists and is not empty." % args.output_dir)
+
+    if args.wandb:
+        import wandb
+        wandb.init(project=args.wandb,
+                   config={
+                       "epochs": args.num_train_epochs,
+                       "language_model": args.pretrained_path,
+                       "batch_size": args.train_batch_size,
+                       #"trainig_dataset": args.data_train,
+                       #"tuning_dataset": args.data_tune,
+                       "max_seq_length": args.max_seq_length,
+                       "warmup_proportion": args.warmup_proportion,
+                       "learning_rate": args.learning_rate,
+                       "gradient_accumulation_steps": args.gradient_accumulation_steps,
+                       "squeeze": args.squeeze
+                   })
+
+    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
+        raise ValueError("Output directory (%s) already exists and is not empty." % args.output_dir)
 
     if not os.path.exists(args.output_dir):
         os.makedirs(args.output_dir)
 
     logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
-                        datefmt='%m/%d/%Y %H:%M:%S',
-                        level=logging.INFO,
+                        datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO,
                         filename=os.path.join(args.output_dir, "log.txt"))
     logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
     logger = logging.getLogger(__name__)
@@ -57,8 +73,6 @@ def main():
 
     processor = NerProcessor()
     label_list = processor.get_labels(args.data_dir)
-    print(*label_list, sep="\n")
-    print(len(label_list))
     num_labels = len(label_list) + 1  # add one for IGNORE label
 
     train_examples = None
@@ -73,11 +87,26 @@ def main():
     hidden_size = 768 if 'base' in args.pretrained_path else 1024
     device = 'cuda:0' if (torch.cuda.is_available() and not args.no_cuda) else 'cpu'
     logger.info(device)
-    model = XLMRForTokenClassification(pretrained_path=args.pretrained_path,
-                                       n_labels=num_labels, hidden_size=hidden_size,
-                                       dropout_p=args.dropout, device=device)
+
+    if args.pretrained_path.startswith("automodel:"):
+        pretrained_dir = args.pretrained_path.split(':')[1]
+        model = AutoTokenizerForTokenClassification(
+            pretrained_path=pretrained_dir, n_labels=num_labels, hidden_size=hidden_size, dropout_p=args.dropout,
+            device=device)
+    else:
+        pretrained_dir = args.pretrained_path
+        if ":" in pretrained_dir:
+            pretrained_dir = pretrained_dir.split(':')[1]
+        if not os.path.exists(pretrained_dir):
+            raise ValueError("RoBERTa language model not found on path '%s'" % pretrained_dir)
+        model = XLMRForTokenClassification(
+            pretrained_path=pretrained_dir, n_labels=num_labels, hidden_size=hidden_size, dropout_p=args.dropout,
+            device=device)
 
     model.to(device)
+    if args.wandb:
+        wandb.watch(model)
+
     no_decay = ['bias', 'final_layer_norm.weight']
 
     params = list(model.named_parameters())
@@ -148,7 +177,6 @@ def main():
                 batch = tuple(t.to(device) for t in batch)
                 input_ids, label_ids, l_mask, valid_ids, = batch
                 loss = model(input_ids, label_ids, l_mask, valid_ids)
-
                 if args.gradient_accumulation_steps > 1:
                     loss = loss / args.gradient_accumulation_steps
                 
@@ -165,6 +193,10 @@ def main():
                 tr_loss += loss.item()
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
+
+                if args.wandb:
+                    wandb.log({"loss": loss}, commit=False)
+
                 if step % 1000 == 0:
                     logger.info('Step = %d/%d; Loss = %.4f' % (step+1, steps, tr_loss / (step+1)))
 
@@ -175,6 +207,10 @@ def main():
 
             logger.info("\nTesting on validation set...")
             f1, report = evaluate_model(model, val_data, label_list, args.eval_batch_size, device)
+
+            if args.wandb:
+                wandb.log({"epoch": epoch_no, "validation_F1": f1})
+
             if f1 > best_val_f1:
                 best_val_f1 = f1
                 logger.info("\nFound better f1=%.4f on validation set. Saving model\n" % f1)
diff --git a/poldeepner2/model/herbert_for_token_calssification.py b/poldeepner2/model/herbert_for_token_calssification.py
new file mode 100644
index 0000000000000000000000000000000000000000..729c4c71055596554a7b9151dd27af5c2fe2eaa4
--- /dev/null
+++ b/poldeepner2/model/herbert_for_token_calssification.py
@@ -0,0 +1,74 @@
+from transformers import AutoModel, AutoTokenizer
+import torch.nn as nn
+import torch.nn.functional as F
+import torch
+
+class AutoTokenizerForTokenClassification(nn.Module):
+
+    def __init__(self, pretrained_path, n_labels, hidden_size=768, dropout_p=0.2, label_ignore_idx=0,
+                head_init_range=0.04, device='cuda'):
+        super().__init__()
+
+        self.n_labels = n_labels
+        
+        self.linear_1 = nn.Linear(hidden_size, hidden_size)
+        self.classification_head = nn.Linear(hidden_size, n_labels)
+        
+        self.label_ignore_idx = label_ignore_idx
+        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
+
+        self.model = AutoModel.from_pretrained(pretrained_path)
+
+        self.dropout = nn.Dropout(dropout_p)
+        
+        self.device = device
+
+        # initializing classification head
+        self.classification_head.weight.data.normal_(mean=0.0, std=head_init_range)
+
+
+    def forward(self, inputs_ids, labels, labels_mask, valid_mask):
+        '''
+        Computes a forward pass through the sequence tagging model.
+        Args:
+            inputs_ids: tensor of size (bsz, max_seq_len). padding idx = 1
+            labels: tensor of size (bsz, max_seq_len)
+            labels_mask and valid_mask: indicate where loss gradients should be propagated and where 
+            labels should be ignored
+
+        Returns:
+            the cross-entropy loss between labels and logits if labels are given,
+            otherwise the unnormalized logits
+
+        '''
+        self.model.train()
+
+        transformer_out  = self.model(inputs_ids, return_dict=True)[0]
+        out_1 = F.relu(self.linear_1(transformer_out))
+        out_1 = self.dropout(out_1)
+        logits = self.classification_head(out_1)
+
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_ignore_idx)
+            # Only keep active parts of the loss
+            if labels_mask is not None:
+                active_loss = valid_mask.view(-1) == 1
+
+                active_logits = logits.view(-1, self.n_labels)[active_loss]
+                active_labels = labels.view(-1)[active_loss]
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(
+                    logits.view(-1, self.n_labels), labels.view(-1))
+            return loss
+        else:
+            return logits
+
+
+    def encode_word(self, s):
+        """
+        takes a string and returns a list of token ids
+        """
+        tensor_ids = self.tokenizer.encode(s)
+        # remove <s> and </s> ids
+        return tensor_ids[1:-1]
diff --git a/poldeepner2/model/xlmr_for_token_classification.py b/poldeepner2/model/xlmr_for_token_classification.py
index e412c83f922ecdc3feb591d51eb500766834eddc..ff574d4d3b5a2092cab019226960b6f53198e7f8 100644
--- a/poldeepner2/model/xlmr_for_token_classification.py
+++ b/poldeepner2/model/xlmr_for_token_classification.py
@@ -44,7 +44,6 @@ class XLMRForTokenClassification(nn.Module):
         out_1 = F.relu(self.linear_1(transformer_out))
         out_1 = self.dropout(out_1)
         logits = self.classification_head(out_1)
-
         if labels is not None:
             loss_fct = nn.CrossEntropyLoss(ignore_index=self.label_ignore_idx)
             # Only keep active parts of the loss
diff --git a/poldeepner2/models.py b/poldeepner2/models.py
index 06fb0e82b87b901043cbe47afd16abd99301c7f8..300df0409eb1b3af7a0aa820a73855d7946a30d9 100644
--- a/poldeepner2/models.py
+++ b/poldeepner2/models.py
@@ -10,7 +10,7 @@ from poldeepner2.utils.tokenization import TokenizerSpaces
 from poldeepner2.utils.data_utils import InputExample, convert_examples_to_features, create_dataset, read_params, \
     wrap_annotations, align_tokens_with_text
 from poldeepner2.model.xlmr_for_token_classification import XLMRForTokenClassification
-
+from poldeepner2.model.herbert_for_token_calssification import AutoTokenizerForTokenClassification
 from torch.utils.data.dataloader import DataLoader
 
 
@@ -54,22 +54,35 @@ resources = {
 
 class PolDeepNer2:
 
-    def __init__(self, model_path, roberta_embeddings_path,
+    def __init__(self, model_path, pretrained_path,
                  device="cpu", squeeze=False, max_seq_length=256, tokenizer=TokenizerSpaces()):
+
         if not os.path.exists(model_path):
             raise ValueError("Model not found on path '%s'" % model_path)
 
-        if not os.path.exists(roberta_embeddings_path):
-            raise ValueError("RoBERTa language model not found on path '%s'" % roberta_embeddings_path)
-
         dropout, num_labels, label_list = read_params(model_path)
         self.label_list = label_list
-        model = XLMRForTokenClassification(pretrained_path=roberta_embeddings_path,
-                                           n_labels=len(self.label_list) + 1,
-                                           dropout_p=dropout,
-                                           device=device,
-                                           hidden_size=768 if 'base' in roberta_embeddings_path else 1024)
-        state_dict = torch.load(open(os.path.join(model_path, 'model.pt'), 'rb'))
+
+        hidden_size = 768 if 'base' in pretrained_path else 1024
+
+        if pretrained_path.startswith('automodel:'):
+            pretrained_dir = pretrained_path[len('automodel:'):]
+            model = AutoTokenizerForTokenClassification(
+                pretrained_path=pretrained_dir, n_labels=num_labels, hidden_size=hidden_size, dropout_p=dropout,
+                device=device)
+        else:
+            pretrained_dir = pretrained_path
+            if ":" in pretrained_dir:
+                pretrained_dir = pretrained_dir.split(":")[1]
+            if not os.path.exists(pretrained_dir):
+                raise ValueError("RoBERTa language model not found on path '%s'" % pretrained_dir)
+            model = XLMRForTokenClassification(
+                pretrained_path=pretrained_dir, n_labels=num_labels, dropout_p=dropout, device=device,
+                hidden_size=hidden_size)
+        if device == "cpu":
+            state_dict = torch.load(open(os.path.join(model_path, 'model.pt'), 'rb'), map_location='cpu')
+        else:
+            state_dict = torch.load(open(os.path.join(model_path, 'model.pt'), 'rb'))
         model.load_state_dict(state_dict)
         model.eval()
         model.to(device)
diff --git a/poldeepner2/utils/data_utils.py b/poldeepner2/utils/data_utils.py
index cecd4aa0e5e7e5f08a96acc876f0b68013fe56ea..bb994ed2c8021c6e9e76c834782b39ac0081d47b 100644
--- a/poldeepner2/utils/data_utils.py
+++ b/poldeepner2/utils/data_utils.py
@@ -311,14 +311,12 @@ def convert_examples_to_features_nosq(examples, label_list, max_seq_length, enco
         assert len(label_ids) == max_seq_length
         assert len(valid) == max_seq_length
         assert len(label_mask) == max_seq_length
-
         features.append(
             InputFeatures(input_ids=token_ids,
                           input_mask=input_mask,
                           label_id=label_ids,
                           valid_ids=valid,
                           label_mask=label_mask))
-
     return features
 
 
@@ -465,6 +463,7 @@ def read_tsv(filename, with_labels=False):
     return data
 
 
+
 def save_tsv(output_path, sentences, predictions):
     with codecs.open(output_path, "w", "utf8") as fout:
         assert len(sentences) == len(predictions)
@@ -474,6 +473,7 @@ def save_tsv(output_path, sentences, predictions):
             fout.write("\n")
 
 
+
 def get_dict_for_record(json_ann):
     token_dict = {}
     derives = 0
@@ -482,6 +482,7 @@ def get_dict_for_record(json_ann):
             if ann.strip() != '':
                 annotation = ann.split('\t')[1].split(' ')[0]
                 token = ann.split('\t')[-1]
+
                 if token in token_dict.keys():
                     token_dict[token] = ''.join([token_dict[token],'#',annotation])
                 else:
@@ -510,3 +511,44 @@ def map_json_to_iob(json_ann, iob):
     failed_to_add = len(token_dict) - successfully_added
     return out_iob, successfully_added, failed_to_add, derives
 
+
+def is_continued(annotation, next_annotations):
+    if next_annotations == ['O']:
+        return False
+    searched_ann = 'I-{0}'.format(annotation)
+    return searched_ann in next_annotations
+
+
+def iob2_to_iob(iob2_text):
+    iob2_list = []
+    iob1_list = []
+    for line in iob2_text.split('\n'):
+        split = line.split(' ')
+        iob2_list.append((split[0], split[1].split('#')))
+    for i, (token, annotations) in enumerate(iob2_list):
+        is_last = i == len(iob2_list) - 1
+        current_ann = []
+        for ann in annotations:
+            prefix, _, ann_type = ann.partition('-')
+            if prefix == 'B' and (is_last or not is_continued(ann_type, iob2_list[i + 1][1])):
+                current_ann.append('I-{0}'.format(ann_type))
+            else:
+                current_ann.append(ann)
+        iob1_list.append((token, '#'.join(current_ann)))
+    return '\n'.join('{} {}'.format(token, anns) for token, anns in iob1_list)
diff --git a/poldeepner2/utils/train_utils.py b/poldeepner2/utils/train_utils.py
index c02cf1bbbbfc169bb5115c53c8e7280715a7feba..5ffff88a042e87740bbda21b1e43ff61b5099475 100644
--- a/poldeepner2/utils/train_utils.py
+++ b/poldeepner2/utils/train_utils.py
@@ -14,7 +14,7 @@ def add_xlmr_args(parser):
                          required=True,
                          help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
      parser.add_argument("--pretrained_path", default=None, type=str, required=True,
-                         help="pretrained XLM-Roberta model path")
+                         help="path to the pretrained language model; the model type can be given as a prefix, e.g. automodel:allegro/herbert-large-cased or xlmr:models/roberta_base_fairseq")
      parser.add_argument("--task_name",
                          default=None,
                          type=str,
@@ -110,10 +110,13 @@ def add_xlmr_args(parser):
                          help="save model for every epoch")
      parser.add_argument('--squeeze', default=False, action="store_true",
                          help="try to squeeze multiple examples into one Input Feature")
+     parser.add_argument('--wandb',
+                         type=str,
+                         help="Wandb project id. If present the training data will be logged using wandb api.")
      return parser
 
 
-def evaluate_model(model, eval_dataset, label_list, batch_size, device):
+def evaluate_model(model, eval_dataset, label_list, batch_size, device, model_name='Roberta'):
      """
      Evaluates an NER model on the eval_dataset provided.
      Returns:
@@ -142,8 +145,11 @@ def evaluate_model(model, eval_dataset, label_list, batch_size, device):
           l_mask = l_mask.to(device)
 
           with torch.no_grad():
-               logits = model(input_ids, labels=None, labels_mask=None,
-                              valid_mask=valid_ids)
+               if model_name == 'Roberta':
+                    logits = model(input_ids, labels=None, labels_mask=None,
+                                   valid_mask=valid_ids)
+               else:
+                    logits = model(input_ids, return_dict=True).logits
 
           logits = torch.argmax(logits, dim=2)
           logits = logits.detach().cpu().numpy()
diff --git a/process_poleval.py b/process_poleval.py
index 46c181dd8bb092dcb5b09cfe50a596501e510731..3f17ca8031f905b39157df782b905772f0a94a4e 100644
--- a/process_poleval.py
+++ b/process_poleval.py
@@ -7,12 +7,61 @@ import codecs
 import os
 import json
 
+from core.poldeepner import PolDeepNer2
+from core.utils.data_utils import get_poleval_dict, read_tsv, wrap_annotations
+from core.utils.file_utils import show_download_menu, check_for_data, download_missing
+
+
+def get_id(ini_file):
+    for line in codecs.open(ini_file, "r", "utf8"):
+        if 'id = ' in line:
+            return line.replace('id = ', '')
+
+
+def split_hashtags(tokens):
+    output = []
+    i = 0
+    while i < len(tokens):
+        if tokens[i] == "#" and i+1 < len(tokens) and re.fullmatch(r"([A-Z][a-z]+)([A-Z][a-z]+)+", tokens[i+1]):
+            output.append("#")
+            for m in re.findall(r"([A-Z][a-z]+)", tokens[i+1]):
+                output.append(str(m))
+            i += 2
+        else:
+            output.append(tokens[i])
+            i += 1
+    return output
+
+
+def split_leading_name(tokens):
+    if len(tokens) > 1 and re.fullmatch(r"([A-Z][a-z]+)([A-Z][a-z]+)+", tokens[0]) and tokens[1] == ":":
+        output = []
+        for m in re.findall(r"([A-Z][a-z]+)", tokens[0]):
+            output.append(str(m))
+        output.extend(tokens[1:])
+        return output
+    else:
+        return tokens
+
+
+def load_document(abs_path):
+    namext = os.path.basename(abs_path)
+    name = os.path.splitext(namext)[0]
+    path = os.path.dirname(abs_path)
+    text = codecs.open(os.path.join(path, name + ".txt"), "r", "utf8").read()
+    doc_id = get_id(os.path.join(path, name + ".ini"))
+    sentences_labels = read_tsv(os.path.join(path, name + ".iob"))
+    sentences = [sentence[0] for sentence in sentences_labels]
+    return doc_id, text, sentences
+
+
 from poldeepner2.models import PolDeepNer2
 from poldeepner2.utils import tokenization
 from poldeepner2.utils.data_utils import get_poleval_dict, wrap_annotations
 from poldeepner2.utils.preprocess import split_hashtags, split_leading_name
 
 
+
 def main(args):
     print("Loading the NER model ...")
     ner = PolDeepNer2(args.model, args.pretrained_path, device=args.device, max_seq_length=args.max_seq_length,
diff --git a/process_tsv.py b/process_tsv.py
index e4e0f36e3f5e271817a072aac430f5cfc7a2b576..109a668e90c703847930fe7f27a8bfb8abcf6def 100644
--- a/process_tsv.py
+++ b/process_tsv.py
@@ -1,11 +1,14 @@
 from __future__ import absolute_import, division, print_function
 
 import argparse
+
 import logging
 import os
 
-from poldeepner2.models import PolDeepNer2
-from poldeepner2.utils.data_utils import read_tsv, save_tsv
+from core.utils.file_utils import show_download_menu, check_for_data, download_missing
+from core import pdn2
+from core.poldeepner import PolDeepNer2
+from core.utils.data_utils import read_tsv, save_tsv
 
 
 def main(args):
@@ -23,6 +26,7 @@ def main(args):
     logging.info("done.")
 
 
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description='Process a single TSV with a NER model')
diff --git a/requirements.txt b/requirements.txt
index e8c4ae231a4ea24b29a789887460ffab477fe75f..29d2846dfb6456290ff65dcb5478b3f8630421d6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,8 +2,11 @@ fairseq==0.9.0
 pytorch-transformers==1.2.0
 seqeval==0.0.12
 pytest~=6.0.1
-flask==1.1.2
+tqdm
+console-menu
+fastapi==0.61.1
+uvicorn==0.12.2
 pandas==1.1.1
 nltk==3.5
 spacy==2.3.2
-tqdm
\ No newline at end of file
+wandb==0.10.7
diff --git a/server.py b/server.py
index f82986d00135ee1aab1d8f2f60985700c77c9a3b..f1a3607c395f843f9dd9dc2af93d6a206448d8d5 100644
--- a/server.py
+++ b/server.py
@@ -1,23 +1,39 @@
 from __future__ import absolute_import, division, print_function
-
+import uvicorn
 import argparse
-from flask import Flask, jsonify, request
+from fastapi import FastAPI
+from typing import Dict, List, Optional
 from poldeepner2.models import PolDeepNer2
-from poldeepner2.utils import tokenization
+from poldeepner2.utils.tokenization import TokenizerSpaces, load, names
+from pydantic import BaseModel
 
 
-class Server:
-    
-    app = Flask(__name__)
+class PredictionReq(BaseModel):
+    text: str
+    tokenization: Optional[str] = 'spacy'
 
-    @app.route('/predict', methods=['POST'])
-    def predict():
-        text = request.get_data().decode('utf-8')
-        entities = ner.process_text(text)
-        return jsonify({"text": text, "entities": [entity.dict() for entity in entities]})
 
-    def run(self, host, port, threaded, processes):
-        self.app.run(host, port, threaded, processes)
+class Prediction(BaseModel):
+    text: str
+    entities: List[List[str]]
+
+
+class Server:
+    app = FastAPI()
+    global spacyTokenizer, spacyNltk
+    spacyTokenizer = load('spacy')
+    spacyNltk = load('nltk')
+
+    @app.post('/predict', response_model=Prediction)
+    async def predict(pred_req: PredictionReq):
+        text = pred_req.text
+        sentences = text.split('\n')
+        if pred_req.tokenization == 'spacy':
+            tokens = spacyTokenizer.tokenize(sentences)
+        else:
+            tokens = spacyNltk.tokenize(sentences)
+        output = ner.process_tokenized(tokens)
+        return {"text": text, "entities": output}
 
 
 def parse_args():
@@ -28,9 +44,10 @@ def parse_args():
                         help='device type used for processing')
     parser.add_argument('--max_seq_length', required=False, default=256, metavar='N', type=int,
                         help='the maximum total input sequence length after WordPiece tokenization.')
-    parser.add_argument('--pretrained_path', required=True, metavar='PATH', help='path to a pretrained RoBERTa model')
+    parser.add_argument('--pretrained_path', required=True, metavar='PATH',
+                        help='path to the pretrained language model; the model type can be given as a prefix, e.g. automodel:allegro/herbert-large-cased or xlmr:models/roberta_base_fairseq')
     parser.add_argument('--processes', help='number of processes', default=1)
-    parser.add_argument('--tokenization', required=False, default="spacy-ext", choices=tokenization.names,
+    parser.add_argument('--tokenization', required=False, default="spacy-ext", choices=names,
                         help='Type of tokenization, nltk or spacy')
     parser.add_argument('--squeeze', required=False, default=False, action="store_true",
                         help='try to squeeze multiple examples into one Input Feature')
@@ -39,13 +56,16 @@ def parse_args():
     return parser.parse_args()
 
 
+server = Server()
+
 if __name__ == "__main__":
     cliargs = parse_args()
     try:
         global ner
         ner = PolDeepNer2(cliargs.model, cliargs.pretrained_path, device=cliargs.device,
-                          max_seq_length=cliargs.max_seq_length, squeeze=cliargs.squeeze, )
-        server = Server()
-        server.run(host=cliargs.host, port=cliargs.port, threaded=True, processes=cliargs.processes)
+                          max_seq_length=cliargs.max_seq_length, squeeze=cliargs.squeeze, tokenizer=TokenizerSpaces())
+        
+        # threaded=True, processes=cliargs.processes
+        uvicorn.run(server.app, host=cliargs.host, port=cliargs.port, log_level="info")
     except ValueError as er:
         print("[ERROR] %s" % er)
diff --git a/tests/unit/utils/test_iob2_to_iob.py b/tests/unit/utils/test_iob2_to_iob.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c0c9bf2843460eb0fe0631d6208a49bb867b40a
--- /dev/null
+++ b/tests/unit/utils/test_iob2_to_iob.py
@@ -0,0 +1,34 @@
+import pytest
+import sys
+import pathlib
+
+sys.path.append(str(pathlib.Path(__file__).absolute().parents[3].resolve()))
+from poldeepner2.utils.data_utils import iob2_to_iob
+
+
+@pytest.mark.parametrize(
+    "iob2_input, expected_output", [
+        ('Alex B-PER\nis O\ngoing O\nto O\nLos B-LOC\nAngeles I-LOC',
+         'Alex I-PER\nis O\ngoing O\nto O\nLos I-LOC\nAngeles I-LOC'),
+        ('Alex B-PER',
+         'Alex I-PER'),
+        ('Alex B-PER\nAngeles I-PER\nis O\ngoing O\nto O\nLos B-LOC\nAngeles I-LOC',
+         'Alex I-PER\nAngeles I-PER\nis O\ngoing O\nto O\nLos I-LOC\nAngeles I-LOC'),
+        ('is O\ngoing O\nAlex B-PER\nAngeles I-PER\nis O\ngoing O\nto O\nLos B-LOC\nAngeles I-LOC',
+         'is O\ngoing O\nAlex I-PER\nAngeles I-PER\nis O\ngoing O\nto O\nLos I-LOC\nAngeles I-LOC'),
+        ('Alex B-PER\nis O\ngoing O\nAlex B-PER\nto O\nLos B-LOC\nAngeles I-LOC',
+         'Alex I-PER\nis O\ngoing O\nAlex I-PER\nto O\nLos I-LOC\nAngeles I-LOC'),
+        ('Alex B-PER\nis O\ngoing O\nAlex B-PER\nto O\nLos B-LOC\nAngeles I-LOC\nAlex B-PER',
+         'Alex I-PER\nis O\ngoing O\nAlex I-PER\nto O\nLos I-LOC\nAngeles I-LOC\nAlex I-PER'),
+        # nested
+        ('Alex B-PER#B-ORG\nis I-ORG\ngoing O\nAlex B-PER\nto O\nLos B-LOC\nAngeles I-LOC\nAlex B-PER#B-LOC',
+         'Alex I-PER#I-ORG\nis I-ORG\ngoing O\nAlex I-PER\nto O\nLos I-LOC\nAngeles I-LOC\nAlex I-PER#B-LOC'),
+        ('Alex B-PER#B-ORG\nis I-ORG\ngoing O\nAlex B-PER\nto O\nLos B-LOC#B-PER\nAngeles B-LOC#B-PER',
+         'Alex I-PER#B-ORG\nis I-ORG\ngoing O\nAlex I-PER\nto O\nLos I-LOC#I-PER\nAngeles B-LOC#I-PER'),
+        ('Alex B-PER#B-ORG\nis I-ORG\ngoing O\nAlex B-PER#B-NAV\nto I-NAV\nLos B-LOC#B-PER\nAngeles I-LOC#B-PER',
+         'Alex I-PER#I-ORG\nis I-ORG\ngoing O\nAlex I-PER#I-NAV\nto I-NAV\nLos I-LOC#I-PER\nAngeles I-LOC#B-PER')
+    ]
+)
+def test_iob2_to_iob(iob2_input, expected_output):
+    iob1 = iob2_to_iob(iob2_input)
+    assert iob1.split('\n') == expected_output.split('\n')