Skip to content
Snippets Groups Projects
Commit cb34f943 authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Add matrix extraction

parent 4115a0b1
Branches
Tags
1 merge request!46Merge COMBO 3.0 into master
...@@ -26,6 +26,7 @@ from data import LamboTokenizer, Sentence, Vocabulary, DatasetReader ...@@ -26,6 +26,7 @@ from data import LamboTokenizer, Sentence, Vocabulary, DatasetReader
from data.dataset_loaders import DataLoader from data.dataset_loaders import DataLoader
from modules.model import Model from modules.model import Model
from utils import ConfigurationError from utils import ConfigurationError
from utils.matrices import extract_combo_matrices
logging.setLoggerClass(ComboLogger) logging.setLoggerClass(ComboLogger)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -78,6 +79,8 @@ flags.DEFINE_string(name="tensorboard_name", default="combo", ...@@ -78,6 +79,8 @@ flags.DEFINE_string(name="tensorboard_name", default="combo",
help="Name of the model in TensorBoard logs.") help="Name of the model in TensorBoard logs.")
flags.DEFINE_string(name="config_path", default="", flags.DEFINE_string(name="config_path", default="",
help="Config file path.") help="Config file path.")
flags.DEFINE_boolean(name="save_matrices", default=True,
help="Save relation distribution matrices.")
# Finetune after training flags # Finetune after training flags
flags.DEFINE_string(name="finetuning_training_data_path", default="", flags.DEFINE_string(name="finetuning_training_data_path", default="",
...@@ -351,20 +354,34 @@ def run(_): ...@@ -351,20 +354,34 @@ def run(_):
logger.info("Predicting examples from file", prefix=prefix) logger.info("Predicting examples from file", prefix=prefix)
predictions = []
if FLAGS.conllu_format: if FLAGS.conllu_format:
test_trees = dataset_reader.read(FLAGS.input_file) test_trees = dataset_reader.read(FLAGS.input_file)
with open(FLAGS.output_file, "w") as file: with open(FLAGS.output_file, "w") as file:
for tree in tqdm(test_trees): for tree in tqdm(test_trees):
file.writelines(api.sentence2conllu(predictor.predict_instance(tree), prediction = predictor.predict_instance(tree)
file.writelines(api.sentence2conllu(prediction,
keep_semrel=dataset_reader.use_sem).serialize()) keep_semrel=dataset_reader.use_sem).serialize())
predictions.append(prediction)
else: else:
tokenizer = LamboTokenizer(FLAGS.tokenizer_language) tokenizer = LamboTokenizer(FLAGS.tokenizer_language)
with open(FLAGS.input_file, "r") as file: with open(FLAGS.input_file, "r") as file:
input_sentences = tokenizer.segment(file.read()) input_sentences = tokenizer.segment(file.read())
with open(FLAGS.output_file, "w") as file: with open(FLAGS.output_file, "w") as file:
for sentence in tqdm(input_sentences): for sentence in tqdm(input_sentences):
file.writelines(api.sentence2conllu(predictor.predict(' '.join(sentence)), prediction = predictor.predict(' '.join(sentence))
file.writelines(api.sentence2conllu(prediction,
keep_semrel=dataset_reader.use_sem).serialize()) keep_semrel=dataset_reader.use_sem).serialize())
predictions.append(prediction)
if FLAGS.save_matrices:
logger.info("Saving matrices", prefix=prefix)
extract_combo_matrices(predictions,
pathlib.Path(FLAGS.serialization_dir),
pathlib.Path(FLAGS.input_file),
logger,
prefix)
else: else:
msg = 'No output file for input file {input_file} specified.'.format(input_file=FLAGS.input_file) msg = 'No output file for input file {input_file} specified.'.format(input_file=FLAGS.input_file)
......
"""
Adapted from COMBO
Author: Mateusz Klimaszewski
"""
from typing import Tuple, Dict, Optional, Union, List
import numpy as np
import torch
import torch.nn.functional as F
from combo import data
from combo.models import base, utils
from combo.nn import chu_liu_edmonds
class HeadPredictionModel(base.Predictor):
"""Head prediction model."""
def __init__(self,
head_projection_layer: base.Linear,
dependency_projection_layer: base.Linear,
cycle_loss_n: int = 0):
super().__init__()
self.head_projection_layer = head_projection_layer
self.dependency_projection_layer = dependency_projection_layer
self.cycle_loss_n = cycle_loss_n
def forward(self,
x: Union[torch.Tensor, List[torch.Tensor]],
mask: Optional[torch.BoolTensor] = None,
labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
if mask is None:
mask = x.new_ones(x.size()[-1])
head_arc_emb = self.head_projection_layer(x)
dep_arc_emb = self.dependency_projection_layer(x)
x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1))
if self.training:
pred = x.argmax(-1)
else:
pred = []
# Adding non existing in mask ROOT to lengths
lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1
for idx, length in enumerate(lengths):
probs = x[idx, :].softmax(dim=-1).cpu().numpy()
# We do not want any word to be parent of the root node (ROOT, 0).
# Also setting it to -1 instead of 0 fixes edge case where softmax made all
# but ROOT prediction to EXACTLY 0.0 and it might cause in many ROOT -> word edges)
probs[:, 0] = -1
heads, _ = chu_liu_edmonds.decode_mst(probs.T, length=length, has_labels=False)
heads[0] = 0
pred.append(heads)
pred = torch.from_numpy(np.stack(pred)).to(x.device)
output = {
"prediction": pred[:, 1:],
"probability": x
}
if labels is not None:
if sample_weights is None:
sample_weights = labels.new_ones([mask.size(0)])
output["loss"], output["cycle_loss"] = self._loss(x, labels, mask, sample_weights)
return output
def _cycle_loss(self, pred: torch.Tensor):
BATCH_SIZE, _, _ = pred.size()
loss = pred.new_zeros(BATCH_SIZE)
# Index from 1: as using non __ROOT__ tokens
pred = pred.softmax(-1)[:, 1:, 1:]
x = pred
for i in range(self.cycle_loss_n):
loss += self._batch_trace(x)
# Don't multiple on last iteration
if i < self.cycle_loss_n - 1:
x = x.bmm(pred)
return loss
@staticmethod
def _batch_trace(x: torch.Tensor) -> torch.Tensor:
assert len(x.size()) == 3
BATCH_SIZE, N, M = x.size()
assert N == M
identity = x.new_tensor(torch.eye(N))
identity = identity.reshape((1, N, N))
batch_identity = identity.repeat(BATCH_SIZE, 1, 1)
return (x * batch_identity).sum((-1, -2))
def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor,
sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
BATCH_SIZE, N, M = pred.size()
assert N == M
SENTENCE_LENGTH = N
valid_positions = mask.sum()
result = []
# Ignore first pred dimension as it is ROOT token prediction
for i in range(SENTENCE_LENGTH - 1):
pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH)
true_i = true[:, i].reshape(-1)
mask_i = mask[:, i]
cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i)
result.append(cross_entropy_loss)
cycle_loss = self._cycle_loss(pred)
loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean()
class DependencyRelationModel(base.Predictor):
"""Dependency relation parsing model."""
def __init__(self,
root_idx: int,
head_predictor: HeadPredictionModel,
head_projection_layer: base.Linear,
dependency_projection_layer: base.Linear,
relation_prediction_layer: base.Linear):
super().__init__()
self.root_idx = root_idx
self.head_predictor = head_predictor
self.head_projection_layer = head_projection_layer
self.dependency_projection_layer = dependency_projection_layer
self.relation_prediction_layer = relation_prediction_layer
def forward(self,
x: Union[torch.Tensor, List[torch.Tensor]],
mask: Optional[torch.BoolTensor] = None,
labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]:
device = x.device
if mask is not None:
mask = mask[:, 1:]
relations_labels, head_labels = None, None
if labels is not None and labels[0] is not None:
relations_labels, head_labels = labels
if mask is None:
mask = head_labels.new_ones(head_labels.size())
head_output = self.head_predictor(x, mask, head_labels, sample_weights)
head_pred = head_output["probability"]
head_pred_soft = F.softmax(head_pred, dim=-1)
head_rel_emb = self.head_projection_layer(x)
dep_rel_emb = self.dependency_projection_layer(x)
dep_rel_pred = head_pred_soft.bmm(head_rel_emb)
dep_rel_pred = torch.cat((dep_rel_pred, dep_rel_emb), dim=-1)
relation_prediction = self.relation_prediction_layer(dep_rel_pred)
output = head_output
output["embedding"] = dep_rel_pred
if self.training:
output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
else:
# Mask root label whenever head is not 0.
relation_prediction_output = relation_prediction[:, 1:].clone()
mask = (head_output["prediction"] == 0)
vocab_size = relation_prediction_output.size(-1)
root_idx = torch.tensor([self.root_idx], device=device)
relation_prediction_output[mask] = (relation_prediction_output
.masked_select(mask.unsqueeze(-1))
.reshape(-1, vocab_size)
.index_fill(-1, root_idx, 10e10))
relation_prediction_output[~mask] = (relation_prediction_output
.masked_select(~(mask.unsqueeze(-1)))
.reshape(-1, vocab_size)
.index_fill(-1, root_idx, -10e10))
output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"])
if labels is not None and labels[0] is not None:
if sample_weights is None:
sample_weights = labels.new_ones([mask.size(0)])
loss = self._loss(relation_prediction[:, 1:], relations_labels, mask, sample_weights)
output["loss"] = (loss, head_output["loss"])
return output
@staticmethod
def _loss(pred: torch.Tensor,
true: torch.Tensor,
mask: torch.BoolTensor,
sample_weights: torch.Tensor) -> torch.Tensor:
valid_positions = mask.sum()
BATCH_SIZE, _, DEPENDENCY_RELATIONS = pred.size()
pred = pred.reshape(-1, DEPENDENCY_RELATIONS)
true = true.reshape(-1)
mask = mask.reshape(-1)
loss = utils.masked_cross_entropy(pred, true, mask)
loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
return loss.sum() / valid_positions
@classmethod
def from_vocab(cls,
vocab: data.Vocabulary,
vocab_namespace: str,
head_predictor: HeadPredictionModel,
head_projection_layer: base.Linear,
dependency_projection_layer: base.Linear
):
"""Creates parser combining model configuration and vocabulary data."""
assert vocab_namespace in vocab.get_namespaces()
relation_prediction_layer = base.Linear(
in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(),
out_features=vocab.get_vocab_size(vocab_namespace)
)
return cls(
head_predictor=head_predictor,
head_projection_layer=head_projection_layer,
dependency_projection_layer=dependency_projection_layer,
relation_prediction_layer=relation_prediction_layer,
root_idx=vocab.get_token_index("root", vocab_namespace)
)
"""
Author: Łukasz Pszenny
"""
from typing import Optional, List
from combo.common.tqdm import Tqdm
import numpy as np
import pandas as pd
from pathlib import Path
from data import Sentence
from utils import ComboLogger
def extract_combo_matrices(predictions: List[Sentence],
serialization_dir: Path,
input_data_path: Path,
logger: Optional[ComboLogger] = None,
logging_prefix: str = ''):
OUTPUT_DIRECTORY_MATRICES = serialization_dir / "combo_output" / "dependency_tree_matrices"
OUTPUT_RELATION_LABEL_DISTRIBUTION = serialization_dir / "combo_output" / "label_distribution_matrices"
meta_sentences = []
meta_ids = []
meta_splits = []
meta_file_names = []
# For saving file names
tmp_meta_file_names = []
sentences = []
sentence_ids = []
# Create directory if it doesn't exist
OUTPUT_DIRECTORY_MATRICES.mkdir(parents=True, exist_ok=True)
OUTPUT_RELATION_LABEL_DISTRIBUTION.mkdir(parents=True, exist_ok=True)
for sentence_id, predicted_sentence in Tqdm.tqdm(enumerate(predictions)):
dependency_tree_matrix = predicted_sentence.relation_distribution
np.savetxt(fname=OUTPUT_DIRECTORY_MATRICES / f'{input_data_path.stem}-{sentence_id}.csv',
X=dependency_tree_matrix,
delimiter=',')
tmp_meta_file_names.append(f'{input_data_path.stem}-{sentence_id}.csv')
# Save relation label matrix
label_distribution_matrix = predicted_sentence.relation_label_distribution
np.savetxt(fname=OUTPUT_RELATION_LABEL_DISTRIBUTION / f'{input_data_path.stem}-{sentence_id}.csv',
X=label_distribution_matrix,
delimiter=',')
sentences.append(' '.join([t.text for t in predicted_sentence.tokens]))
sentence_ids.append(sentence_id)
logger.info(f"\nFinished processing : {input_data_path}", prefix=logging_prefix)
meta_ids += sentence_ids
meta_splits += [input_data_path.stem] * len(sentence_ids)
meta_file_names += tmp_meta_file_names
meta_data = pd.DataFrame({'split': meta_splits,
'sentence_id': meta_ids,
'full_sentence': sentences,
'file_names': meta_file_names})
logger.info("Saving metadata", prefix=logging_prefix)
meta_data.to_csv(serialization_dir / 'metadata.csv', index=False)
logger.info("Finished processing", prefix=logging_prefix)
...@@ -15,6 +15,7 @@ requests~=2.28.2 ...@@ -15,6 +15,7 @@ requests~=2.28.2
tqdm~=4.64.1 tqdm~=4.64.1
urllib3~=1.26.14 urllib3~=1.26.14
filelock~=3.9.0 filelock~=3.9.0
pandas~=2.1.3
pytest~=7.2.2 pytest~=7.2.2
transformers~=4.27.3 transformers~=4.27.3
sacremoses~=0.0.53 sacremoses~=0.0.53
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment