
Resolve "Flair Embeddings"

3 files changed: +40 −9

Files

@@ -2,17 +2,24 @@ from transformers import AutoModel, AutoTokenizer
 import torch.nn as nn
 import torch.nn.functional as F
 import torch
+from flair.data import Sentence
+import flair
+from flair.embeddings import FlairEmbeddings, StackedEmbeddings, WordEmbeddings
 class AutoTokenizerForTokenClassification(nn.Module):
     def __init__(self, pretrained_path, n_labels, hidden_size=768, dropout_p=0.2, label_ignore_idx=0,
-                 head_init_range=0.04, device='cuda'):
+                 head_init_range=0.04, device='cuda', flair_embeddings=False, flair_embed_size=4096):
         super().__init__()
         self.n_labels = n_labels
-        self.linear_1 = nn.Linear(hidden_size, hidden_size)
-        self.classification_head = nn.Linear(hidden_size, n_labels)
+        self.flair_embeddings = flair_embeddings
+        if flair_embeddings:
+            self.linear_1 = nn.Linear(hidden_size + flair_embed_size, hidden_size)
+            self.classification_head = nn.Linear(hidden_size, n_labels)
+        else:
+            self.linear_1 = nn.Linear(hidden_size, hidden_size)
+            self.classification_head = nn.Linear(hidden_size, n_labels)
         self.label_ignore_idx = label_ignore_idx
         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
@@ -25,9 +32,15 @@ class AutoTokenizerForTokenClassification(nn.Module):
         # initializing classification head
         self.classification_head.weight.data.normal_(mean=0.0, std=head_init_range)
+        if flair_embeddings:
+            # keep Flair on the same device as the rest of the model
+            flair.device = torch.device(self.device)
+            self.stacked_embeddings = StackedEmbeddings([
+                FlairEmbeddings('pl-forward'),
+                FlairEmbeddings('pl-backward'),
+            ])
-    def forward(self, inputs_ids, labels, labels_mask, valid_mask):
+    def forward(self, inputs_ids, labels, labels_mask, valid_mask, flair_embed_size=4096):
         '''
         Computes a forward pass through the sequence tagging model.
         Args:
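A quick sanity check that the flair_embed_size default matches these stacked embeddings (a minimal sketch, assuming the 'pl-forward'/'pl-backward' Flair language models can be downloaded):

    from flair.embeddings import FlairEmbeddings, StackedEmbeddings

    stacked = StackedEmbeddings([FlairEmbeddings('pl-forward'), FlairEmbeddings('pl-backward')])
    # expected to print 4096 (2048 per direction), matching the flair_embed_size default above
    print(stacked.embedding_length)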
@@ -44,7 +57,20 @@ class AutoTokenizerForTokenClassification(nn.Module):
         self.model.train()
         transformer_out = self.model(inputs_ids, return_dict=True)[0]
-        out_1 = F.relu(self.linear_1(transformer_out))
+        if self.flair_embeddings:
+            embedded = []
+            for batch in inputs_ids:
+                # decode the subword ids back to text and embed it with the stacked Flair embeddings
+                flair_embeds_for_batch = []
+                flair_embeds_for_batch.extend(self.embed_sentence(self.tokenizer.decode(batch).replace('<pad>', ' ')))
+                # Flair tokenizes at the word level, so the number of embeddings can differ from the
+                # subword sequence length: truncate any surplus and zero-pad the rest
+                flair_embeds_for_batch = flair_embeds_for_batch[:len(batch)]
+                while len(flair_embeds_for_batch) < len(batch):
+                    flair_embeds_for_batch.append(torch.zeros(flair_embed_size, device=flair.device))
+                flair_embeds_for_batch = torch.stack(flair_embeds_for_batch)
+                embedded.append(flair_embeds_for_batch)
+            embedded = torch.stack(embedded).to(transformer_out.device)
+            merged = torch.cat((transformer_out, embedded), dim=2)
+            out_1 = F.relu(self.linear_1(merged))
+        else:
+            out_1 = F.relu(self.linear_1(transformer_out))
         out_1 = self.dropout(out_1)
         logits = self.classification_head(out_1)
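For reference, a shape sketch of the merge step above (illustrative only, with made-up batch and sequence sizes and the default hidden_size=768 / flair_embed_size=4096):

    import torch

    transformer_out = torch.randn(2, 10, 768)     # (batch, seq_len, hidden_size)
    flair_out = torch.randn(2, 10, 4096)          # (batch, seq_len, flair_embed_size)
    merged = torch.cat((transformer_out, flair_out), dim=2)    # (2, 10, 4864)
    linear_1 = torch.nn.Linear(768 + 4096, 768)
    out_1 = torch.relu(linear_1(merged))          # projected back to (2, 10, 768)
    print(merged.shape, out_1.shape)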
@@ -72,3 +98,8 @@ class AutoTokenizerForTokenClassification(nn.Module):
         tensor_ids = self.tokenizer.encode(s)
         # remove <s> and </s> ids
         return tensor_ids[1:-1]
+
+    def embed_sentence(self, sentence):
+        sentence = Sentence(sentence)
+        self.stacked_embeddings.embed(sentence)
+        return [token.embedding for token in sentence]
\ No newline at end of file
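Rough usage sketch of the new embed_sentence helper; the model instance and the sentence text are hypothetical, and flair_embeddings=True is assumed so that stacked_embeddings exists:

    # 'model' is a hypothetical AutoTokenizerForTokenClassification built with flair_embeddings=True
    token_embeddings = model.embed_sentence("Ala ma kota")
    print(len(token_embeddings))        # one tensor per Flair (word-level) token, 3 here
    print(token_embeddings[0].shape)    # torch.Size([4096]) with the defaults above

Since Flair tokenizes at the word level while the transformer consumes subword ids, the number of returned embeddings generally differs from the subword sequence length, which is why forward truncates and zero-pads them to len(batch) before concatenation.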