Commit ea4d3c46 authored by Maja Jabłońska, committed by Martyna Wiącek

Add GraphParser from COMBO

parent b73b5f08
1 merge request: !46 Merge COMBO 3.0 into master
"""
Adapted from COMBO.
Author: Mateusz Klimaszewski
"""
from typing import List, Optional, Union, Tuple, Dict
from combo import data
from combo.models import base
from combo.models.base import Predictor
import torch
import torch.nn.functional as F

class GraphHeadPredictionModel(Predictor):
    """Head prediction model."""

    def __init__(self,
                 head_projection_layer: base.Linear,
                 dependency_projection_layer: base.Linear,
                 cycle_loss_n: int = 0,
                 graph_weighting: float = 0.2):
        super().__init__()
        self.head_projection_layer = head_projection_layer
        self.dependency_projection_layer = dependency_projection_layer
        self.cycle_loss_n = cycle_loss_n
        self.graph_weighting = graph_weighting
    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                mask: Optional[torch.BoolTensor] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        if mask is None:
            # Default mask covers every token: shape (batch_size, sentence_length).
            mask = x.new_ones(x.size()[:-1])
        heads_labels = None
        if labels is not None and labels[0] is not None:
            heads_labels = labels
        head_arc_emb = self.head_projection_layer(x)
        dep_arc_emb = self.dependency_projection_layer(x)
        # Arc scores: a (batch, length, length) matrix of dependent-head affinities.
        x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))
        pred = x.sigmoid() > 0.5
        output = {
            "prediction": pred,
            "probability": x
        }
        if heads_labels is not None:
            if sample_weights is None:
                sample_weights = heads_labels.new_ones([mask.size(0)])
            output["loss"], output["cycle_loss"] = self._loss(x, heads_labels, mask, sample_weights)
        return output
    def _cycle_loss(self, pred: torch.Tensor) -> torch.Tensor:
        BATCH_SIZE, _, _ = pred.size()
        loss = pred.new_zeros(BATCH_SIZE)
        # Index from 1: use only non-__ROOT__ tokens.
        pred = pred.softmax(-1)[:, 1:, 1:]
        x = pred
        for i in range(self.cycle_loss_n):
            # trace(A^k) accumulates the weight of length-k cycles in the arc matrix.
            loss += self._batch_trace(x)
            # Don't multiply on the last iteration
            if i < self.cycle_loss_n - 1:
                x = x.bmm(pred)
        return loss
    @staticmethod
    def _batch_trace(x: torch.Tensor) -> torch.Tensor:
        assert len(x.size()) == 3
        BATCH_SIZE, N, M = x.size()
        assert N == M
        identity = x.new_tensor(torch.eye(N))
        identity = identity.reshape((1, N, N))
        batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
        # Multiplying by the identity keeps only the diagonal; summing yields the per-item trace.
        return (x * batch_identity).sum((-1, -2))
    def _loss(self, pred: torch.Tensor, labels: torch.Tensor, mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        BATCH_SIZE, N, M = pred.size()
        assert N == M
        SENTENCE_LENGTH = N
        valid_positions = mask.sum()
        result = []
        true = labels
        # Ignore the first pred dimension as it is the ROOT token prediction
        for i in range(SENTENCE_LENGTH - 1):
            pred_i = pred[:, i + 1, 1:].reshape(-1)
            true_i = true[:, i + 1, 1:].reshape(-1)
            mask_i = mask[:, i]
            bce_loss = F.binary_cross_entropy_with_logits(pred_i, true_i, reduction="none").mean(-1) * mask_i
            result.append(bce_loss)
        cycle_loss = self._cycle_loss(pred)
        loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
        return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
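
# Note on the cycle penalty: for an arc-probability matrix A, trace(A^k)
# accumulates the total weight of length-k cycles, so summing traces for
# k = 1..cycle_loss_n (the loop in _cycle_loss above) softly penalizes cyclic
# structures. A worked example with hypothetical numbers: if two tokens point
# at each other with probability ~1, A is approximately [[0, 1], [1, 0]];
# trace(A) = 0 but trace(A @ A) = 2, so cycle_loss_n >= 2 exposes the cycle.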


class GraphDependencyRelationModel(Predictor):
    """Dependency relation parsing model."""

    def __init__(self,
                 head_predictor: GraphHeadPredictionModel,
                 head_projection_layer: base.Linear,
                 dependency_projection_layer: base.Linear,
                 relation_prediction_layer: base.Linear):
        super().__init__()
        self.head_predictor = head_predictor
        self.head_projection_layer = head_projection_layer
        self.dependency_projection_layer = dependency_projection_layer
        self.relation_prediction_layer = relation_prediction_layer
    def forward(self,
                x: Union[torch.Tensor, List[torch.Tensor]],
                mask: Optional[torch.BoolTensor] = None,
                labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
        relations_labels, head_labels, enhanced_heads_labels = None, None, None
        if labels is not None and labels[0] is not None:
            relations_labels, head_labels, enhanced_heads_labels = labels
        head_output = self.head_predictor(x, enhanced_heads_labels, mask, sample_weights)
        head_pred = head_output["probability"]
        BATCH_SIZE, LENGTH, _ = head_pred.size()
        head_rel_emb = self.head_projection_layer(x)
        dep_rel_emb = self.dependency_projection_layer(x)
        # Enumerate all possible edge combinations for each batch item.
        # repeat_interleave yields [emb1, emb1, ... (LENGTH times), emb2, emb2, ...]
        head_rel_pred = head_rel_emb.repeat_interleave(LENGTH, -2)
        # Regular repeat yields all combinations [deprel1, deprel2, ..., deprelL, deprel1, ...]
        dep_rel_pred = dep_rel_emb.repeat(1, LENGTH, 1)
        # Concatenate head and dependent embeddings for every ordered pair.
        dep_rel_pred = torch.cat((head_rel_pred, dep_rel_pred), dim=-1)
        relation_prediction = self.relation_prediction_layer(dep_rel_pred).reshape(BATCH_SIZE, LENGTH, LENGTH, -1)
        output = head_output
        output["prediction"] = (relation_prediction.argmax(-1), head_output["prediction"])
        output["rel_probability"] = relation_prediction
        if labels is not None and labels[0] is not None:
            if sample_weights is None:
                sample_weights = labels.new_ones([mask.size(0)])
            loss = self._loss(relation_prediction, relations_labels, enhanced_heads_labels, mask, sample_weights)
            output["loss"] = (loss, head_output["loss"])
        return output
    @staticmethod
    def _loss(pred: torch.Tensor,
              true: torch.Tensor,
              heads_true: torch.Tensor,
              mask: torch.BoolTensor,
              sample_weights: torch.Tensor) -> torch.Tensor:
        # Score relations only on the gold head edges.
        correct_heads_mask = heads_true.long() == 1
        true = true[correct_heads_mask]
        pred = pred[correct_heads_mask]
        # reduction="none" keeps per-edge losses so the explicit normalization below is meaningful.
        loss = F.cross_entropy(pred, true.long(), reduction="none")
        return loss.sum() / pred.size(0)
    @classmethod
    def from_vocab(cls,
                   vocab: data.Vocabulary,
                   vocab_namespace: str,
                   head_predictor: GraphHeadPredictionModel,
                   head_projection_layer: base.Linear,
                   dependency_projection_layer: base.Linear
                   ):
        """Creates a parser combining model configuration and vocabulary data."""
        assert vocab_namespace in vocab.get_namespaces()
        relation_prediction_layer = base.Linear(
            in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
            out_features=vocab.get_vocab_size(vocab_namespace)
        )
        return cls(
            head_predictor=head_predictor,
            head_projection_layer=head_projection_layer,
            dependency_projection_layer=dependency_projection_layer,
            relation_prediction_layer=relation_prediction_layer
        )
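

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of COMBO's API): it assumes
# base.Linear behaves like torch.nn.Linear and that x is encoder output of
# shape (batch, length, dim); random tensors stand in for real data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    BATCH, LENGTH, DIM = 2, 5, 8
    x = torch.randn(BATCH, LENGTH, DIM)
    head_proj = nn.Linear(DIM, DIM)
    dep_proj = nn.Linear(DIM, DIM)

    # Arc scoring as in GraphHeadPredictionModel.forward:
    # (batch, length, dim) @ (batch, dim, length) -> (batch, length, length).
    scores = dep_proj(x).bmm(head_proj(x).transpose(2, 1))
    arcs = scores.sigmoid() > 0.5  # a token may keep several heads (a graph, not a tree)
    print(arcs.shape)  # torch.Size([2, 5, 5])

    # Pair enumeration as in GraphDependencyRelationModel.forward: entry
    # i * LENGTH + j holds the (head_i, dependent_j) embedding pair.
    head_rel = head_proj(x).repeat_interleave(LENGTH, -2)  # (batch, length * length, dim)
    dep_rel = dep_proj(x).repeat(1, LENGTH, 1)             # (batch, length * length, dim)
    pairs = torch.cat((head_rel, dep_rel), dim=-1)
    print(pairs.shape)  # torch.Size([2, 25, 16])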