Commit 02429b61 authored by Mateusz Klimaszewski

Add simple in-code predictor creation from a pretrained model.

parent 7210dc10
@@ -83,7 +83,7 @@ class FeedForwardPredictor(Predictor):
         pred = pred.reshape(-1, CLASSES)
         true = true.reshape(-1)
         mask = mask.reshape(-1)
-        loss = utils.masked_cross_entropy(pred, true, mask) * mask
+        loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
...
@@ -74,12 +74,13 @@ class LemmatizerModel(base.Predictor):
         valid_positions = mask.sum()
         mask = mask.reshape(-1)
         true = true.reshape(-1)
-        loss = utils.masked_cross_entropy(pred, true, mask) * mask
+        loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions

     @classmethod
     def from_vocab(cls,
                    vocab: data.Vocabulary,
                    char_vocab_namespace: str,
                    lemma_vocab_namespace: str,
...
@@ -63,11 +63,11 @@ class MorphologicalFeatures(base.Predictor):
             if loss is None:
                 loss = loss_func(pred[:, cat_indices],
                                  true[:, cat_indices].argmax(dim=1),
-                                 mask) * mask
+                                 mask)
             else:
                 loss += loss_func(pred[:, cat_indices],
                                   true[:, cat_indices].argmax(dim=1),
-                                  mask) * mask
+                                  mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
...
@@ -93,7 +93,7 @@ class HeadPredictionModel(base.Predictor):
             pred_i = pred[:, i + 1, :].reshape(BATCH_SIZE, SENTENCE_LENGTH)
             true_i = true[:, i].reshape(-1)
             mask_i = mask[:, i]
-            cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i) * mask_i
+            cross_entropy_loss = utils.masked_cross_entropy(pred_i, true_i, mask_i)
             result.append(cross_entropy_loss)
         cycle_loss = self._cycle_loss(pred)
         loss = torch.stack(result).transpose(1, 0) * sample_weights.unsqueeze(-1)
@@ -162,7 +162,7 @@ class DependencyRelationModel(base.Predictor):
         pred = pred.reshape(-1, DEPENDENCY_RELATIONS)
         true = true.reshape(-1)
         mask = mask.reshape(-1)
-        loss = utils.masked_cross_entropy(pred, true, mask) * mask
+        loss = utils.masked_cross_entropy(pred, true, mask)
         loss = loss.reshape(BATCH_SIZE, -1) * sample_weights.unsqueeze(-1)
         return loss.sum() / valid_positions
...
@@ -5,4 +5,4 @@ import torch.nn.functional as F
 def masked_cross_entropy(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
     mask = mask.float().unsqueeze(-1)
     pred = pred + (mask + 1e-45).log()
-    return F.cross_entropy(pred, true, reduction='none')
+    return F.cross_entropy(pred, true, reduction='none') * mask
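
Taken together, these hunks move the masking responsibility into the helper: instead of every predictor multiplying the returned per-position loss by its mask, masked_cross_entropy now zeroes masked positions itself, which is why each call site above drops its trailing `* mask`. Below is a minimal, self-contained sketch of that idea, not the repository's exact code; masked_ce and the shapes are illustrative assumptions:

import torch
import torch.nn.functional as F

def masked_ce(pred: torch.Tensor, true: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
    fmask = mask.float()
    # Adding log(~0) shifts every logit of a masked row by the same constant,
    # so it leaves the softmax (and the loss value) unchanged; the actual
    # masking happens in the final multiply below.
    logits = pred + (fmask.unsqueeze(-1) + 1e-45).log()
    # reduction='none' keeps one loss per position; zeroing masked entries
    # here means callers no longer append `* mask` themselves.
    return F.cross_entropy(logits, true, reduction='none') * fmask

pred = torch.randn(4, 5)                        # 4 positions, 5 classes
true = torch.tensor([0, 1, 2, 3])
mask = torch.tensor([True, True, False, True])
loss = masked_ce(pred, true, mask)              # loss[2] == 0 with no call-site masking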
...
@@ -57,6 +57,9 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         logger.info('Took {} ms'.format((end_time - start_time) * 1000.0))
         return result

+    def predict_string(self, sentence: str):
+        return self.predict_json({'sentence': sentence})
+
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
         start_time = time.time()
@@ -139,3 +142,13 @@ class SemanticMultitaskPredictor(predictor.Predictor):
     def with_spacy_tokenizer(cls, model: models.Model,
                              dataset_reader: allen_data.DatasetReader):
         return cls(model, dataset_reader, tokenizers.SpacyTokenizer())
+
+    @classmethod
+    def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer()):
+        util.import_module_and_submodules('combo.commands')
+        util.import_module_and_submodules('combo.models')
+        util.import_module_and_submodules('combo.training')
+        model = models.Model.from_archive(path)
+        dataset_reader = allen_data.DatasetReader.from_params(
+            models.load_archive(path).config['dataset_reader'])
+        return cls(model, dataset_reader, tokenizer)
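
These two additions give the simple workflow the commit message promises: restore a predictor directly from a training archive and parse raw strings without assembling the model, dataset reader, and tokenizer by hand. A hypothetical usage sketch follows; the module path and the archive location are assumptions, only the class and method names come from this diff:

from combo.predict import SemanticMultitaskPredictor  # import path is an assumption

# '/path/to/model.tar.gz' stands in for a trained model archive.
predictor = SemanticMultitaskPredictor.from_pretrained('/path/to/model.tar.gz')
result = predictor.predict_string('A sentence to parse.')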