Base Model Handler
- class qugen.main.generator.base_model_handler.BaseModelHandler
Defines the interface implemented by every model handler (continuous QGAN/QCBM and discrete QGAN/QCBM): building a model, training it, saving and reloading it, and generating samples from it.
- abstract build(*args, **kwargs) → BaseModelHandler
Defines the model architecture. Weight initialization is typically also performed here.
- abstract predict(*args) → array
Draws samples from the model.
- abstract reload(file_path: Path) → BaseModelHandler
Loads the model from previously saved weights.
- Parameters:
file_path (pathlib.Path): source file for the model weights
- abstract save(file_path: Path, overwrite: bool = True) → BaseModelHandler
Saves the model weights to a file.
- Parameters:
file_path (pathlib.Path): destination file for the model weights
overwrite (bool): flag indicating whether an existing file at the target location should be overwritten
- abstract train(*args) → BaseModelHandler
Trains the model.
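The interface above can be sketched with Python's `abc` module. This is a minimal sketch under stated assumptions, not the qugen source: it assumes `BaseModelHandler` behaves like an `abc.ABC` (the "abstract" markers suggest so, but the real class may differ). `IndependentBitsHandler` is a hypothetical toy handler that models each bit as an independent Bernoulli variable instead of a QGAN/QCBM; it exists only to show how a concrete handler might fill in each method. Returning `self` from `build`, `train`, `save` and `reload` matches the documented `BaseModelHandler` return types and allows chaining. The JSON weight format is an assumption, and the toy `predict` returns nested lists rather than an array to avoid a NumPy dependency.

```python
import abc
import json
import random
import tempfile
from pathlib import Path


class BaseModelHandler(abc.ABC):
    """Sketch of the documented interface (assumed ABC; not the qugen source)."""

    @abc.abstractmethod
    def build(self, *args, **kwargs) -> "BaseModelHandler":
        """Define the architecture and initialize the weights."""

    @abc.abstractmethod
    def train(self, *args) -> "BaseModelHandler":
        """Fit the model to data."""

    @abc.abstractmethod
    def predict(self, *args):
        """Draw samples from the model."""

    @abc.abstractmethod
    def save(self, file_path: Path, overwrite: bool = True) -> "BaseModelHandler":
        """Write the model weights to file_path."""

    @abc.abstractmethod
    def reload(self, file_path: Path) -> "BaseModelHandler":
        """Load weights previously written by save()."""


class IndependentBitsHandler(BaseModelHandler):
    """Hypothetical toy handler: every bit is an independent Bernoulli variable."""

    def build(self, n_bits: int, seed: int = 0) -> "IndependentBitsHandler":
        self.n_bits = n_bits
        self.probs = [0.5] * n_bits  # "weight initialization": uniform bits
        self.rng = random.Random(seed)
        return self

    def train(self, data: list) -> "IndependentBitsHandler":
        # Toy "training": each bit's probability becomes its empirical frequency.
        n = len(data)
        self.probs = [sum(row[i] for row in data) / n for i in range(self.n_bits)]
        return self

    def predict(self, n_samples: int) -> list:
        # Each sample is a list of 0/1 values drawn from the learned probabilities.
        return [
            [int(self.rng.random() < p) for p in self.probs]
            for _ in range(n_samples)
        ]

    def save(self, file_path: Path, overwrite: bool = True) -> "IndependentBitsHandler":
        file_path = Path(file_path)
        if file_path.exists() and not overwrite:
            raise FileExistsError(f"{file_path} exists and overwrite=False")
        file_path.write_text(json.dumps({"n_bits": self.n_bits, "probs": self.probs}))
        return self

    def reload(self, file_path: Path) -> "IndependentBitsHandler":
        state = json.loads(Path(file_path).read_text())
        self.n_bits = state["n_bits"]
        self.probs = state["probs"]
        if not hasattr(self, "rng"):  # reload() may be called without build()
            self.rng = random.Random(0)
        return self


# Usage: build -> train -> save -> reload, chained through the returned handler.
data = [[1, 0, 1], [1, 1, 1], [1, 0, 0], [1, 0, 1]]
handler = IndependentBitsHandler().build(n_bits=3).train(data)
print(handler.probs)  # [1.0, 0.25, 0.75]

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "weights.json"
    handler.save(path)
    restored = IndependentBitsHandler().reload(path)
    print(restored.probs == handler.probs)  # True
```

Because `save` defaults to `overwrite=True`, saving twice to the same path silently replaces the file; with `overwrite=False` the toy raises `FileExistsError` instead, which is one plausible reading of the flag rather than documented qugen behavior.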