diff --git a/combo/models/base.py b/combo/models/base.py
index b824d3f2f0199fce0e72490ae02b33f1355888ba..f6c1affb363f4c96bdc52fc4fc2331713373d169 100644
--- a/combo/models/base.py
+++ b/combo/models/base.py
@@ -83,7 +83,7 @@ class FeedForwardPredictor(Predictor):
         pred = pred.reshape(-1, CLASSES)
         true = true.reshape(-1)
         mask = mask.reshape(-1)
-        loss = utils.masked_cross_entropy(pred, true, mask) * mask
+        loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
diff --git a/combo/models/lemma.py b/combo/models/lemma.py
index 8ab4030245ceff6931a7bd13049ce92cc2a0ccb6..0bea1c6c4289966dff636856ce3d2f96c9c8ee96 100644
--- a/combo/models/lemma.py
+++ b/combo/models/lemma.py
@@ -74,12 +74,13 @@ class LemmatizerModel(base.Predictor):
         valid_positions = mask.sum()
         mask = mask.reshape(-1)
         true = true.reshape(-1)
-        loss = utils.masked_cross_entropy(pred, true, mask) * mask
+        loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions

     @classmethod
     def from_vocab(cls,
+                   vocab: data.Vocabulary,
                    char_vocab_namespace: str,
                    lemma_vocab_namespace: str,
diff --git a/combo/models/morpho.py b/combo/models/morpho.py
index 238d4284ce97d5508ccb6ef663a25c317efc77bf..ee7cc24f5fa76970b71014d445d29cdf7df77c2d 100644
--- a/combo/models/morpho.py
+++ b/combo/models/morpho.py
@@ -63,11 +63,11 @@ class MorphologicalFeatures(base.Predictor):
             if loss is None:
                 loss = loss_func(pred[:, cat_indices],
                                  true[:, cat_indices].argmax(dim=1),
-                                 mask) * mask
+                                 mask)
             else:
                 loss += loss_func(pred[:, cat_indices],
                                   true[:, cat_indices].argmax(dim=1),
-                                  mask) * mask
+                                  mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
diff --git a/combo/models/parser.py b/combo/models/parser.py
index 63f4c8af6641c9aca49ca683bbe257bd3a0b1b6f..91d77ce22dd6bf7495cd8c365d17870eda90db7a 100644
--- a/combo/models/parser.py
+++ b/combo/models/parser.py
@@ -93,7 +93,7 @@ class HeadPredictionModel(base.Predictor):
             pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH)
             true_i = true[:, i].reshape(-1)
             mask_i = mask[:, i]
-            cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i) * mask_i
+            cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i)
             result.append(cross_entropy_loss)
         cycle_loss = self._cycle_loss(pred)
         loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
@@ -162,7 +162,7 @@ class DependencyRelationModel(base.Predictor):
         pred = pred.reshape(-1, DEPENDENCY_RELATIONS)
         true = true.reshape(-1)
         mask = mask.reshape(-1)
-        loss = utils.masked_cross_entropy(pred, true, mask) * mask
+        loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
diff --git a/combo/models/utils.py b/combo/models/utils.py
index f97913303415552f6c594ac857b563cab52b6c3b..d4e29760ea2e9b89e2800fcc22292eb32eaf2d92 100644
--- a/combo/models/utils.py
+++ b/combo/models/utils.py
@@ -5,4 +5,3 @@ import torch.nn.functional as F
 def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
-    mask = mask.float().unsqueeze(-1)
-    pred = pred + (mask + 1e-45).log()
-    return F.cross_entropy(pred, true, reduction='none')
+    pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
+    return F.cross_entropy(pred, true, reduction='none') * mask
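For reference, a minimal self-contained sketch (not taken from the repository) of the contract this refactor establishes: `masked_cross_entropy` applies the mask internally, so every call site above can drop its trailing `* mask`. Shapes are assumed to be `(positions, classes)` logits, `(positions,)` gold indices, and a `(positions,)` boolean mask. Note the multiplication must use the original 1-D mask; unsqueezing it first (as the pre-fix helper did) would broadcast the per-position loss into a square matrix.

```python
import torch
import torch.nn.functional as F


def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor,
                         mask: torch.BoolTensor) -> torch.Tensor:
    # Shift logits of masked positions; harmless for valid rows since
    # log(1 + 1e-45) ~ 0, and masked rows are zeroed below anyway.
    pred = pred + (mask.float().unsqueeze(-1) + 1e-45).log()
    # Mask applied once here, not at each call site.
    return F.cross_entropy(pred, true, reduction='none') * mask


pred = torch.randn(4, 3)                         # (positions, classes) logits
true = torch.tensor([0, 2, 1, 1])                # gold class per position
mask = torch.tensor([True, True, False, False])  # last two are padding

loss = masked_cross_entropy(pred, true, mask)
assert torch.all(loss[~mask] == 0)               # padding contributes no loss
```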
diff --git a/combo/predict.py b/combo/predict.py
index 9e4f8326702724efcdcdefb629d567353606b976..6514ba29424a14778da13abd5044828f02e3fed8 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -57,6 +57,9 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         logger.info('Took {} ms'.format((end_time - start_time) * 1000.0))
         return result

+    def predict_string(self, sentence: str):
+        return self.predict_json({'sentence': sentence})
+
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
         start_time = time.time()
@@ -139,3 +142,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def with_spacy_tokenizer(cls, model: models.Model,
                              dataset_reader: allen_data.DatasetReader):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
+
+    @classmethod
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer()):
+        util.import_module_and_submodules('combo.commands')
+        util.import_module_and_submodules('combo.models')
+        util.import_module_and_submodules('combo.training')
+        model = models.Model.from_archive(path)
+        dataset_reader = allen_data.DatasetReader.from_params(
+            models.load_archive(path).config['dataset_reader'])
+        return cls(model, dataset_reader, tokenizer)
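A hypothetical end-to-end use of the two new convenience entry points; the archive path below is a placeholder for any trained COMBO `model.tar.gz`, not a published model:

```python
from combo.predict import SemanticMultitaskPredictor

# from_pretrained registers combo's AllenNLP components, loads the archived
# model, and rebuilds the dataset reader from the archive's config.
nlp = SemanticMultitaskPredictor.from_pretrained('/tmp/model.tar.gz')

# predict_string is sugar for predict_json({'sentence': ...}).
annotation = nlp.predict_string('The quick brown fox jumps over the lazy dog.')
print(annotation)
```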