diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 6ad25590e3f29bcde42266b8ee9cc720787b4388..d8e9d7a28a7fa36b60d108a30a3286026a327e51 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -107,18 +107,16 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
 
     def __init__(self,
                  model_name: str,
-                 projection_dim: int,
+                 projection_dim: int = 0,
                  projection_activation: Optional[allen_nn.Activation] = lambda x: x,
                  projection_dropout_rate: Optional[float] = 0.0,
                  freeze_transformer: bool = True,
                  tokenizer_kwargs: Optional[Dict[str, Any]] = None,
                  transformer_kwargs: Optional[Dict[str, Any]] = None):
-        super().__init__(model_name, tokenizer_kwargs=tokenizer_kwargs, transformer_kwargs=transformer_kwargs)
-        self.freeze_transformer = freeze_transformer
-        if self.freeze_transformer:
-            self._matched_embedder.eval()
-            for param in self._matched_embedder.parameters():
-                param.requires_grad = False
+        super().__init__(model_name,
+                         train_parameters=not freeze_transformer,
+                         tokenizer_kwargs=tokenizer_kwargs,
+                         transformer_kwargs=transformer_kwargs)
         if projection_dim:
             self.projection_layer = base.Linear(in_features=super().get_output_dim(),
                                                 out_features=projection_dim,
@@ -148,20 +146,6 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
     def get_output_dim(self):
         return self.output_dim
 
-    @overrides
-    def train(self, mode: bool):
-        if self.freeze_transformer:
-            self.projection_layer.train(mode)
-        else:
-            super().train(mode)
-
-    @overrides
-    def eval(self):
-        if self.freeze_transformer:
-            self.projection_layer.eval()
-        else:
-            super().eval()
-
 
 @token_embedders.TokenEmbedder.register("feats_embedding")
 class FeatsTokenEmbedder(token_embedders.Embedding):
diff --git a/combo/models/parser.py b/combo/models/parser.py
index dfb53ab8ded369b01eae4851dd1d7a9936c05bbe..511edffc2f8d17edbc3fd0702e6425a4ec645e4e 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -158,7 +158,7 @@ class DependencyRelationModel(base.Predictor):
             output["prediction"] = (relation_prediction.argmax(-1)[:, 1:], head_output["prediction"])
         else:
             # Mask root label whenever head is not 0.
-            relation_prediction_output = relation_prediction[:, 1:]
+            relation_prediction_output = relation_prediction[:, 1:].clone()
             mask = (head_output["prediction"] == 0)
             vocab_size = relation_prediction_output.size(-1)
             root_idx = torch.tensor([self.root_idx], device=device)
diff --git a/combo/predict.py b/combo/predict.py
index a5c99fd883b4ee40c5e9c76af44a7c7dbad85bdc..bd9f5d4637bf410bace029604afb915ed05311c2 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -24,7 +24,7 @@ class COMBO(predictor.Predictor):
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
-                 batch_size: int = 32,
+                 batch_size: int = 1024,
                  line_to_conllu: bool = True) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
@@ -140,54 +140,56 @@ class COMBO(predictor.Predictor):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
-        for idx, token in enumerate(tree_tokens):
-            for field_name in field_names:
-                if field_name in predictions:
-                    if field_name in ["xpostag", "upostag", "semrel", "deprel"]:
-                        value = self.vocab.get_token_from_index(predictions[field_name][idx], field_name + "_labels")
-                        token[field_name] = value
-                    elif field_name in ["head"]:
-                        token[field_name] = int(predictions[field_name][idx])
-                    elif field_name == "deps":
-                        # Handled after every other decoding
-                        continue
-
-                    elif field_name in ["feats"]:
-                        slices = self._model.morphological_feat.slices
-                        features = []
-                        prediction = predictions[field_name][idx]
-                        for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
-                            if cat not in ["__PAD__", "_"]:
-                                value = self.vocab.get_token_from_index(cat_indices[pred_idx],
-                                                                        field_name + "_labels")
-                                # Exclude auxiliary values
-                                if "=None" not in value:
-                                    features.append(value)
-                        if len(features) == 0:
-                            field_value = "_"
-                        else:
-                            lowercase_features = [f.lower() for f in features]
-                            arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
-                            field_value = "|".join(np.array(features)[arg_indices].tolist())
-
-                        token[field_name] = field_value
-                    elif field_name == "lemma":
-                        prediction = predictions[field_name][idx]
-                        word_chars = []
-                        for char_idx in prediction[1:-1]:
-                            pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
-
-                            if pred_char == "__END__":
-                                break
-                            elif pred_char == "__PAD__":
-                                continue
-                            elif "_" in pred_char:
-                                pred_char = "?"
-
-                            word_chars.append(pred_char)
-                        token[field_name] = "".join(word_chars)
+        for field_name in field_names:
+            if field_name not in predictions:
+                continue
+            field_predictions = predictions[field_name]
+            for idx, token in enumerate(tree_tokens):
+                if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
+                    value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
+                    token[field_name] = value
+                elif field_name == "head":
+                    token[field_name] = int(field_predictions[idx])
+                elif field_name == "deps":
+                    # Handled after every other decoding
+                    continue
+
+                elif field_name == "feats":
+                    slices = self._model.morphological_feat.slices
+                    features = []
+                    prediction = field_predictions[idx]
+                    for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
+                        if cat not in ["__PAD__", "_"]:
+                            value = self.vocab.get_token_from_index(cat_indices[pred_idx],
+                                                                    field_name + "_labels")
+                            # Exclude auxiliary values
+                            if "=None" not in value:
+                                features.append(value)
+                    if len(features) == 0:
+                        field_value = "_"
                     else:
-                        raise NotImplementedError(f"Unknown field name {field_name}!")
+                        lowercase_features = [f.lower() for f in features]
+                        arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
+                        field_value = "|".join(np.array(features)[arg_indices].tolist())
+
+                    token[field_name] = field_value
+                elif field_name == "lemma":
+                    prediction = field_predictions[idx]
+                    word_chars = []
+                    for char_idx in prediction[1:-1]:
+                        pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
+
+                        if pred_char == "__END__":
+                            break
+                        elif pred_char == "__PAD__":
+                            continue
+                        elif "_" in pred_char:
+                            pred_char = "?"
+
+                        word_chars.append(pred_char)
+                    token[field_name] = "".join(word_chars)
+                else:
+                    raise NotImplementedError(f"Unknown field name {field_name}!")
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             # TODO off-by-one hotfix, refactor
@@ -212,7 +214,7 @@ class COMBO(predictor.Predictor):
 
     @classmethod
     def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
-                        batch_size: int = 32,
+                        batch_size: int = 1024,
                         cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index aeb9f097369f515b3860d961f638870ec17f6786..26bd75f7fbe6917f144b820bbbb1c7e14c3c8e9d 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -230,22 +230,24 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
         patience: int = None,
         validation_metric: str = "-loss",
         num_epochs: int = 20,
-        cuda_device: int = -1,
+        cuda_device: Optional[Union[int, torch.device]] = -1,
         grad_norm: float = None,
         grad_clipping: float = None,
         distributed: bool = None,
         world_size: int = 1,
         num_gradient_accumulation_steps: int = 1,
-        opt_level: Optional[str] = None,
         use_amp: bool = False,
-        optimizer: common.Lazy[optimizers.Optimizer] = None,
+        no_grad: List[str] = None,
+        optimizer: common.Lazy[optimizers.Optimizer] = common.Lazy(optimizers.Optimizer.default),
         learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
         momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
         tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
         moving_average: common.Lazy[moving_average.MovingAverage] = None,
-        checkpointer: common.Lazy[training.Checkpointer] = None,
+        checkpointer: common.Lazy[training.Checkpointer] = common.Lazy(training.Checkpointer),
         batch_callbacks: List[training.BatchCallback] = None,
         epoch_callbacks: List[training.EpochCallback] = None,
+        end_callbacks: List[training.EpochCallback] = None,
+        trainer_callbacks: List[training.TrainerCallback] = None,
     ) -> "training.Trainer":
         if tensorboard_writer is None:
             tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
@@ -265,6 +267,7 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             world_size=world_size,
             num_gradient_accumulation_steps=num_gradient_accumulation_steps,
             use_amp=use_amp,
+            no_grad=no_grad,
             optimizer=optimizer,
             learning_rate_scheduler=learning_rate_scheduler,
             momentum_scheduler=momentum_scheduler,
@@ -273,4 +276,6 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             checkpointer=checkpointer,
             batch_callbacks=batch_callbacks,
             epoch_callbacks=epoch_callbacks,
+            end_callbacks=end_callbacks,
+            trainer_callbacks=trainer_callbacks,
         )
diff --git a/combo/utils/graph.py b/combo/utils/graph.py
index 651c14a7d79b7ea3c277b9466f5e050435a7a01b..3352625e6665ca1cd3196506ed5e50183fedfbb0 100644
--- a/combo/utils/graph.py
+++ b/combo/utils/graph.py
@@ -88,6 +88,7 @@ def _dfs(graph, start, end):
 
 
 def restore_collapse_edges(tree_tokens):
+    # https://gist.github.com/hankcs/776e7d95c19e5ff5da8469fe4e9ab050
     empty_tokens = []
     for token in tree_tokens:
         deps = token["deps"].split("|")