diff --git a/combo/main.py b/combo/main.py index 250ac5ae23e827c795093ffbc57889b169a963c6..4faa916624c1fde4fb48110fadf12ff9292c6cea 100755 --- a/combo/main.py +++ b/combo/main.py @@ -26,6 +26,7 @@ from data import LamboTokenizer, Sentence, Vocabulary, DatasetReader from data.dataset_loaders import DataLoader from modules.model import Model from utils import ConfigurationError +from utils.matrices import extract_combo_matrices logging.setLoggerClass(ComboLogger) logger = logging.getLogger(__name__) @@ -78,6 +79,8 @@ flags.DEFINE_string(name="tensorboard_name", default="combo", help="Name of the model in TensorBoard logs.") flags.DEFINE_string(name="config_path", default="", help="Config file path.") +flags.DEFINE_boolean(name="save_matrices", default=True, + help="Save relation distribution matrices.") # Finetune after training flags flags.DEFINE_string(name="finetuning_training_data_path", default="", @@ -351,20 +354,34 @@ def run(_): logger.info("Predicting examples from file", prefix=prefix) + predictions = [] if FLAGS.conllu_format: test_trees = dataset_reader.read(FLAGS.input_file) with open(FLAGS.output_file, "w") as file: for tree in tqdm(test_trees): - file.writelines(api.sentence2conllu(predictor.predict_instance(tree), + prediction = predictor.predict_instance(tree) + file.writelines(api.sentence2conllu(prediction, keep_semrel=dataset_reader.use_sem).serialize()) + predictions.append(prediction) + else: tokenizer = LamboTokenizer(FLAGS.tokenizer_language) with open(FLAGS.input_file, "r") as file: input_sentences = tokenizer.segment(file.read()) with open(FLAGS.output_file, "w") as file: for sentence in tqdm(input_sentences): - file.writelines(api.sentence2conllu(predictor.predict(' '.join(sentence)), + prediction = predictor.predict(' '.join(sentence)) + file.writelines(api.sentence2conllu(prediction, keep_semrel=dataset_reader.use_sem).serialize()) + predictions.append(prediction) + + if FLAGS.save_matrices: + logger.info("Saving matrices", prefix=prefix) + extract_combo_matrices(predictions, + pathlib.Path(FLAGS.serialization_dir), + pathlib.Path(FLAGS.input_file), + logger, + prefix) else: msg = 'No output file for input file {input_file} specified.'.format(input_file=FLAGS.input_file) diff --git a/combo/models/parser.py b/combo/models/parser.py deleted file mode 100644 index 13943b5dc5d436107d010f2c666fe6cf4de579d1..0000000000000000000000000000000000000000 --- a/combo/models/parser.py +++ /dev/null @@ -1,222 +0,0 @@ -""" -Adapted from COMBO -Author: Mateusz Klimaszewski -""" -from typing import Tuple, Dict, Optional, Union, List - -import numpy as np -import torch -import torch.nn.functional as F - -from combo import data -from combo.models import base, utils -from combo.nn import chu_liu_edmonds - - -class HeadPredictionModel(base.Predictor): - """Head prediction model.""" - - def __init__(self, - head_projection_layer: base.Linear, - dependency_projection_layer: base.Linear, - cycle_loss_n: int = 0): - super().__init__() - self.head_projection_layer = head_projection_layer - self.dependency_projection_layer = dependency_projection_layer - self.cycle_loss_n = cycle_loss_n - - def forward(self, - x: Union[torch.Tensor, List[torch.Tensor]], - mask: Optional[torch.BoolTensor] = None, - labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: - if mask is None: - mask = x.new_ones(x.size()[-1]) - - head_arc_emb = self.head_projection_layer(x) - dep_arc_emb = self.dependency_projection_layer(x) - x = dep_arc_emb.bmm(head_arc_emb.transpose(2, 1)) - - if self.training: - pred = x.argmax(-1) - else: - pred = [] - # Adding non existing in mask ROOT to lengths - lengths = mask.data.sum(dim=1).long().cpu().numpy() + 1 - for idx, length in enumerate(lengths): - probs = x[idx, :].softmax(dim=-1).cpu().numpy() - - # We do not want any word to be parent of the root node (ROOT, 0). - # Also setting it to -1 instead of 0 fixes edge case where softmax made all - # but ROOT prediction to EXACTLY 0.0 and it might cause in many ROOT -> word edges) - probs[:, 0] = -1 - heads, _ = chu_liu_edmonds.decode_mst(probs.T, length=length, has_labels=False) - heads[0] = 0 - pred.append(heads) - pred = torch.from_numpy(np.stack(pred)).to(x.device) - - output = { - "prediction": pred[:, 1:], - "probability": x - } - - if labels is not None: - if sample_weights is None: - sample_weights = labels.new_ones([mask.size(0)]) - output["loss"], output["cycle_loss"] = self._loss(x, labels, mask, sample_weights) - - return output - - def _cycle_loss(self, pred: torch.Tensor): - BATCH_SIZE, _, _ = pred.size() - loss = pred.new_zeros(BATCH_SIZE) - # Index from 1: as using non __ROOT__ tokens - pred = pred.softmax(-1)[:, 1:, 1:] - x = pred - for i in range(self.cycle_loss_n): - loss += self._batch_trace(x) - - # Don't multiple on last iteration - if i < self.cycle_loss_n - 1: - x = x.bmm(pred) - - return loss - - @staticmethod - def _batch_trace(x: torch.Tensor) -> torch.Tensor: - assert len(x.size()) == 3 - BATCH_SIZE, N, M = x.size() - assert N == M - identity = x.new_tensor(torch.eye(N)) - identity = identity.reshape((1, N, N)) - batch_identity = identity.repeat(BATCH_SIZE, 1, 1) - return (x * batch_identity).sum((-1, -2)) - - def _loss(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor, - sample_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - BATCH_SIZE, N, M = pred.size() - assert N == M - SENTENCE_LENGTH = N - - valid_positions = mask.sum() - - result = [] - # Ignore first pred dimension as it is ROOT token prediction - for i in range(SENTENCE_LENGTH - 1): - pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH) - true_i = true[:, i].reshape(-1) - mask_i = mask[:, i] - cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i) - result.append(cross_entropy_loss) - cycle_loss = self._cycle_loss(pred) - loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1) - return loss.sum() / valid_positions + cycle_loss.mean(), cycle_loss.mean() - - -class DependencyRelationModel(base.Predictor): - """Dependency relation parsing model.""" - - def __init__(self, - root_idx: int, - head_predictor: HeadPredictionModel, - head_projection_layer: base.Linear, - dependency_projection_layer: base.Linear, - relation_prediction_layer: base.Linear): - super().__init__() - self.root_idx = root_idx - self.head_predictor = head_predictor - self.head_projection_layer = head_projection_layer - self.dependency_projection_layer = dependency_projection_layer - self.relation_prediction_layer = relation_prediction_layer - - def forward(self, - x: Union[torch.Tensor, List[torch.Tensor]], - mask: Optional[torch.BoolTensor] = None, - labels: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None, - sample_weights: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None) -> Dict[str, torch.Tensor]: - device = x.device - if mask is not None: - mask = mask[:, 1:] - relations_labels, head_labels = None, None - if labels is not None and labels[0] is not None: - relations_labels, head_labels = labels - if mask is None: - mask = head_labels.new_ones(head_labels.size()) - - head_output = self.head_predictor(x, mask, head_labels, sample_weights) - head_pred = head_output["probability"] - head_pred_soft = F.softmax(head_pred, dim=-1) - - head_rel_emb = self.head_projection_layer(x) - - dep_rel_emb = self.dependency_projection_layer(x) - - dep_rel_pred = head_pred_soft.bmm(head_rel_emb) - dep_rel_pred = torch.cat((dep_rel_pred, dep_rel_emb), dim=-1) - relation_prediction = self.relation_prediction_layer(dep_rel_pred) - output = head_output - output["embedding"] = dep_rel_pred - - if self.training: - output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"]) - else: - # Mask root label whenever head is not 0. - relation_prediction_output = relation_prediction[:, 1:].clone() - mask = (head_output["prediction"] == 0) - vocab_size = relation_prediction_output.size(-1) - root_idx = torch.tensor([self.root_idx], device=device) - relation_prediction_output[mask] = (relation_prediction_output - .masked_select(mask.unsqueeze(-1)) - .reshape(-1, vocab_size) - .index_fill(-1, root_idx, 10e10)) - relation_prediction_output[~mask] = (relation_prediction_output - .masked_select(~(mask.unsqueeze(-1))) - .reshape(-1, vocab_size) - .index_fill(-1, root_idx, -10e10)) - output["prediction"] = (relation_prediction_output.argmax(-1), head_output["prediction"]) - - if labels is not None and labels[0] is not None: - if sample_weights is None: - sample_weights = labels.new_ones([mask.size(0)]) - loss = self._loss(relation_prediction[:, 1:], relations_labels, mask, sample_weights) - output["loss"] = (loss, head_output["loss"]) - - return output - - @staticmethod - def _loss(pred: torch.Tensor, - true: torch.Tensor, - mask: torch.BoolTensor, - sample_weights: torch.Tensor) -> torch.Tensor: - - valid_positions = mask.sum() - - BATCH_SIZE, _, DEPENDENCY_RELATIONS = pred.size() - pred = pred.reshape(-1, DEPENDENCY_RELATIONS) - true = true.reshape(-1) - mask = mask.reshape(-1) - loss = utils.masked_cross_entropy(pred, true, mask) - loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1) - return loss.sum() / valid_positions - - @classmethod - def from_vocab(cls, - vocab: data.Vocabulary, - vocab_namespace: str, - head_predictor: HeadPredictionModel, - head_projection_layer: base.Linear, - dependency_projection_layer: base.Linear - ): - """Creates parser combining model configuration and vocabulary data.""" - assert vocab_namespace in vocab.get_namespaces() - relation_prediction_layer = base.Linear( - in_features=head_projection_layer.get_output_dim() + dependency_projection_layer.get_output_dim(), - out_features=vocab.get_vocab_size(vocab_namespace) - ) - return cls( - head_predictor=head_predictor, - head_projection_layer=head_projection_layer, - dependency_projection_layer=dependency_projection_layer, - relation_prediction_layer=relation_prediction_layer, - root_idx=vocab.get_token_index("root", vocab_namespace) - ) diff --git a/combo/utils/matrices.py b/combo/utils/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..7a982db7ecb8a0a19db0a30ce04024706d92697c --- /dev/null +++ b/combo/utils/matrices.py @@ -0,0 +1,64 @@ +""" +Author: Åukasz Pszenny +""" +from typing import Optional, List + +from combo.common.tqdm import Tqdm +import numpy as np +import pandas as pd +from pathlib import Path + +from data import Sentence +from utils import ComboLogger + + +def extract_combo_matrices(predictions: List[Sentence], + serialization_dir: Path, + input_data_path: Path, + logger: Optional[ComboLogger] = None, + logging_prefix: str = ''): + OUTPUT_DIRECTORY_MATRICES = serialization_dir / "combo_output" / "dependency_tree_matrices" + OUTPUT_RELATION_LABEL_DISTRIBUTION = serialization_dir / "combo_output" / "label_distribution_matrices" + + meta_sentences = [] + meta_ids = [] + meta_splits = [] + meta_file_names = [] + # For saving file names + tmp_meta_file_names = [] + + sentences = [] + sentence_ids = [] + + # Create directory if it doesn't exist + OUTPUT_DIRECTORY_MATRICES.mkdir(parents=True, exist_ok=True) + OUTPUT_RELATION_LABEL_DISTRIBUTION.mkdir(parents=True, exist_ok=True) + + for sentence_id, predicted_sentence in Tqdm.tqdm(enumerate(predictions)): + dependency_tree_matrix = predicted_sentence.relation_distribution + np.savetxt(fname=OUTPUT_DIRECTORY_MATRICES / f'{input_data_path.stem}-{sentence_id}.csv', + X=dependency_tree_matrix, + delimiter=',') + tmp_meta_file_names.append(f'{input_data_path.stem}-{sentence_id}.csv') + + # Save relation label matrix + label_distribution_matrix = predicted_sentence.relation_label_distribution + np.savetxt(fname=OUTPUT_RELATION_LABEL_DISTRIBUTION / f'{input_data_path.stem}-{sentence_id}.csv', + X=label_distribution_matrix, + delimiter=',') + sentences.append(' '.join([t.text for t in predicted_sentence.tokens])) + sentence_ids.append(sentence_id) + logger.info(f"\nFinished processing : {input_data_path}", prefix=logging_prefix) + + meta_ids += sentence_ids + meta_splits += [input_data_path.stem] * len(sentence_ids) + meta_file_names += tmp_meta_file_names + + meta_data = pd.DataFrame({'split': meta_splits, + 'sentence_id': meta_ids, + 'full_sentence': sentences, + 'file_names': meta_file_names}) + + logger.info("Saving metadata", prefix=logging_prefix) + meta_data.to_csv(serialization_dir / 'metadata.csv', index=False) + logger.info("Finished processing", prefix=logging_prefix) diff --git a/requirements.txt b/requirements.txt index b25e1f617c182c60aec37f6833a379c3b2a896d1..0835382a9608b729b290c071eb07ad34c0636444 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ requests~=2.28.2 tqdm~=4.64.1 urllib3~=1.26.14 filelock~=3.9.0 +pandas~=2.1.3 pytest~=7.2.2 transformers~=4.27.3 sacremoses~=0.0.53