diff --git a/combo/predictors/predictor_model.py b/combo/predictors/predictor_model.py index 70025092004af25c289597cbc88cf4f2253c1fa0..38c9c668bbf00d379a42a98d74b533f985e7ef43 100644 --- a/combo/predictors/predictor_model.py +++ b/combo/predictors/predictor_model.py @@ -24,6 +24,7 @@ from combo.data.dataset_readers.dataset_reader import DatasetReader from combo.data.instance import JsonDict, Instance from combo.modules.model import Model from combo.nn import utils +from combo.nn.utils import move_to_device logger = logging.getLogger(__name__) @@ -135,7 +136,7 @@ class PredictorModule(pl.LightningModule, FromParameters): dataset = Batch(instances) dataset.index_instances(self._model.vocab) - dataset_tensor_dict = util.move_to_device(dataset.as_tensor_dict(), self.cuda_device) + dataset_tensor_dict = move_to_device(dataset.as_tensor_dict(), self.cuda_device) # To bypass "RuntimeError: cudnn RNN backward can only be called in training mode" with backends.cudnn.flags(enabled=False): outputs = self._model.make_output_human_readable( @@ -331,96 +332,3 @@ class PredictorModule(pl.LightningModule, FromParameters): for json_dict in json_dicts: instances.append(self._json_to_instance(json_dict)) return instances - # - # @classmethod - # def from_path( - # cls, - # archive_path: Union[str, Path], - # predictor_name: str = None, - # cuda_device: int = -1, - # dataset_reader_to_load: str = "validation", - # frozen: bool = True, - # import_plugins: bool = True, - # overrides: Union[str, Dict[str, Any]] = "", - # ) -> "Predictor": - # """ - # Instantiate a `Predictor` from an archive path. - # - # If you need more detailed configuration options, such as overrides, - # please use `from_archive`. - # - # # Parameters - # - # archive_path : `Union[str, Path]` - # The path to the archive. - # predictor_name : `str`, optional (default=`None`) - # Name that the predictor is registered as, or None to use the - # predictor associated with the model. - # cuda_device : `int`, optional (default=`-1`) - # If `cuda_device` is >= 0, the model will be loaded onto the - # corresponding GPU. Otherwise it will be loaded onto the CPU. - # dataset_reader_to_load : `str`, optional (default=`"validation"`) - # Which dataset reader to load from the archive, either "train" or - # "validation". - # frozen : `bool`, optional (default=`True`) - # If we should call `model.eval()` when building the predictor. - # import_plugins : `bool`, optional (default=`True`) - # If `True`, we attempt to import plugins before loading the predictor. - # This comes with additional overhead, but means you don't need to explicitly - # import the modules that your predictor depends on as long as those modules - # can be found by `allennlp.common.plugins.import_plugins()`. - # overrides : `Union[str, Dict[str, Any]]`, optional (default = `""`) - # JSON overrides to apply to the unarchived `Params` object. - # - # # Returns - # - # `Predictor` - # A Predictor instance. - # """ - # if import_plugins: - # plugins.import_plugins() - # return Predictor.from_archive( - # load_archive(archive_path, cuda_device=cuda_device, overrides=overrides), - # predictor_name, - # dataset_reader_to_load=dataset_reader_to_load, - # frozen=frozen, - # ) - # - # @classmethod - # def from_archive( - # cls, - # archive: Archive, - # predictor_name: str = None, - # dataset_reader_to_load: str = "validation", - # frozen: bool = True, - # ) -> "Predictor": - # """ - # Instantiate a `Predictor` from an [`Archive`](../models/archival.md); - # that is, from the result of training a model. Optionally specify which `Predictor` - # subclass; otherwise, we try to find a corresponding predictor in `DEFAULT_PREDICTORS`, or if - # one is not found, the base class (i.e. `Predictor`) will be used. Optionally specify - # which [`DatasetReader`](../data/dataset_readers/dataset_reader.md) should be loaded; - # otherwise, the validation one will be used if it exists followed by the training dataset reader. - # Optionally specify if the loaded model should be frozen, meaning `model.eval()` will be called. - # """ - # # Duplicate the config so that the config inside the archive doesn't get consumed - # config = archive.config.duplicate() - # - # if not predictor_name: - # model_type = config.get("model").get("type") - # model_class, _ = Model.resolve_class_name(model_type) - # predictor_name = model_class.default_predictor - # predictor_class: Type[Predictor] = ( - # Predictor.by_name(predictor_name) if predictor_name is not None else cls # type: ignore - # ) - # - # if dataset_reader_to_load == "validation": - # dataset_reader = archive.validation_dataset_reader - # else: - # dataset_reader = archive.dataset_reader - # - # model = archive.model - # if frozen: - # model.eval() - # - # return predictor_class(model, dataset_reader)