Continuous Quantum Generative Adversarial Network (QGAN)
- class qugen.main.generator.continuous_qgan_model_handler.ContinuousQGANModelHandler
- build(model_name: str, data_set: str, n_qubits: int = 2, circuit_depth: int = 1, random_seed: int = 42, transformation: str = 'pit', save_artifacts=True, slower_progress_update=False) → BaseModelHandler
Build the continuous QGAN model. This defines the model architecture, including the circuit ansatz and the data transformation, and whether artifacts are saved to disk.
- Args:
model_name (str): The name used when saving the data to disk.
data_set (str): The name of the data set, which is included in the model name.
n_qubits (int, optional): Number of qubits. Defaults to 2.
circuit_depth (int, optional): Number of repetitions of qml.StronglyEntanglingLayers. Defaults to 1.
random_seed (int, optional): Random seed for reproducibility. Defaults to 42.
transformation (str, optional): Type of normalization, either "minmax" or "pit". Defaults to "pit".
save_artifacts (bool, optional): Whether to save the artifacts to disk. Defaults to True.
slower_progress_update (bool, optional): Controls how often the progress bar is updated. If True, the bar is updated at most every 10 seconds; otherwise the tqdm defaults are used. Defaults to False.
- Returns:
BaseModelHandler: The built model handler. Reassigning the existing variable to this return value is not strictly necessary, since all changes are made in place.
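The `transformation` argument selects how the training data is mapped into the unit cube before training. The sketch below illustrates what the two options typically mean, assuming "minmax" rescales each column linearly using its minimum and maximum, and "pit" applies an empirical probability integral transform (each value replaced by its normalized rank). `minmax_transform` and `pit_transform` are hypothetical helpers written for illustration; qugen's internal implementation may differ, for example in how it builds its lookup table.

```python
import numpy as np

# Hypothetical helpers illustrating the two normalizations; not part of qugen's API.


def minmax_transform(data):
    """Scale each column of `data` into [0, 1] using its min and max.

    Returns the transformed data plus the per-column min and max, which
    are needed later to invert the transformation.
    """
    data = np.asarray(data, dtype=float)
    col_min = data.min(axis=0)
    col_max = data.max(axis=0)
    # Guard against constant columns to avoid division by zero.
    span = np.where(col_max > col_min, col_max - col_min, 1.0)
    return (data - col_min) / span, col_min, col_max


def pit_transform(data):
    """Empirical probability integral transform, column by column.

    Each value is replaced by (rank + 1) / n_rows, so every column becomes
    approximately uniform on (0, 1]. The sorted columns are returned as a
    lookup table that an inverse transform can use.
    """
    data = np.asarray(data, dtype=float)
    n_rows = data.shape[0]
    # argsort of argsort gives 0-based ranks per column; ties are broken by position.
    ranks = data.argsort(axis=0).argsort(axis=0)
    lookup_table = np.sort(data, axis=0)
    return (ranks + 1) / n_rows, lookup_table
```

Keeping the min/max (or the sorted table) alongside the transformed data is what makes it possible to map generated samples back to the original space later.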
- predict(n_samples: int = 32) → array
Generate samples from the trained model and apply the inverse of the data transformation that was used on the training data, so that the KL divergence can be computed in the original space.
- Args:
n_samples (int, optional): Number of samples to generate. Defaults to 32.
- Returns:
np.array: Array of samples of shape (n_samples, sample_dimension).
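Because `predict` undoes the transformation applied to the training data, the following sketch shows what such an inverse could look like for the two normalizations accepted by `build`. `inverse_minmax` and `inverse_pit` are hypothetical helpers, not part of qugen. The PIT variant assumes the reverse lookup table stores the sorted training values of each column and uses linear interpolation between them; qugen's actual table format and interpolation may differ.

```python
import numpy as np

# Hypothetical inverse transformations; not part of qugen's API.


def inverse_minmax(u, col_min, col_max):
    """Map unit-cube samples back to the original range of each column."""
    return np.asarray(u, dtype=float) * (col_max - col_min) + col_min


def inverse_pit(u, lookup_table):
    """Invert an empirical PIT using a per-column reverse lookup table.

    `lookup_table` holds the sorted training values of each column. A
    sample value u in [0, 1] is mapped to the empirical quantile at level
    u, interpolating linearly between neighbouring table entries. Values
    below the first level are clamped to the smallest training value.
    """
    u = np.asarray(u, dtype=float)
    n_rows = lookup_table.shape[0]
    # Probability levels matching the (rank + 1) / n_rows convention.
    levels = np.arange(1, n_rows + 1) / n_rows
    return np.column_stack(
        [np.interp(u[:, j], levels, lookup_table[:, j]) for j in range(u.shape[1])]
    )
```

A design note: interpolating in the lookup table lets generated samples take values between observed training points, whereas a pure nearest-rank lookup would only ever reproduce training values.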
-
predict_transform(n_samples: int = 32) → array
Generate samples from the trained model in the transformed space (the n-dimensional unit cube).
- Args:
n_samples (int, optional): Number of samples to generate. Defaults to 32.
- Returns:
np.array: Array of samples of shape (n_samples, sample_dimension).
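Since `predict_transform` returns samples in the unit cube, one simple way to compare them with the transformed training data is a histogram-based estimate of the KL divergence on a regular grid. This is an illustrative sketch only: `histogram_kl_divergence`, the bin count, and the `eps` smoothing are assumptions, and the metric qugen itself reports (computed in the original space, per the `predict` description) may be defined differently.

```python
import numpy as np


def histogram_kl_divergence(p_samples, q_samples, bins_per_dim=10, eps=1e-10):
    """Estimate KL(P || Q) from two sample sets lying in [0, 1]^d.

    Hypothetical helper: both sample sets are binned on the same regular
    grid, and eps avoids log(0) and division by zero in empty cells.
    """
    p_samples = np.asarray(p_samples, dtype=float)
    q_samples = np.asarray(q_samples, dtype=float)
    d = p_samples.shape[1]
    edges = [np.linspace(0.0, 1.0, bins_per_dim + 1)] * d
    p_hist, _ = np.histogramdd(p_samples, bins=edges)
    q_hist, _ = np.histogramdd(q_samples, bins=edges)
    # Normalize counts to probability masses over the grid cells.
    p = p_hist.ravel() / p_hist.sum() + eps
    q = q_hist.ravel() / q_hist.sum() + eps
    return float(np.sum(p * np.log(p / q)))
```

The number of grid cells grows as `bins_per_dim ** d`, so this kind of estimate is only practical for low-dimensional data such as the 2- or 3-column data sets that `train` accepts.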
- reload(model_name: str, epoch: int) → BaseModelHandler
Reload the model from its artifacts, including the generator and discriminator parameters, the metadata, and the data transformation file (the reverse lookup table or the original min and max of the training data).
- Args:
model_name (str): The name of the model to reload.
epoch (int): The epoch to reload.
- Returns:
BaseModelHandler: The reloaded model handler. The changes are also made in place.
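Reloading gathers three kinds of artifacts: the weights, the metadata, and a transformation file whose content depends on the normalization. The sketch below shows one way to read them back, assuming a hypothetical directory layout; the file names `weights_epoch_<epoch>.pickle`, `meta.json`, `reverse_lookup.npy` and `minmax.json`, and the `load_artifacts` function itself, are invented for illustration and are not qugen's actual artifact format.

```python
import json
import pickle
from pathlib import Path

import numpy as np


def load_artifacts(artifact_dir, epoch):
    """Load the artifacts needed to restore a trained model at a given epoch.

    HYPOTHETICAL layout (not qugen's actual file names):
      weights_epoch_<epoch>.pickle -> (generator_weights, discriminator_weights)
      meta.json                    -> {"transformation": "pit" or "minmax", ...}
      reverse_lookup.npy           -> sorted training values per column ("pit")
      minmax.json                  -> {"min": [...], "max": [...]} ("minmax")
    """
    artifact_dir = Path(artifact_dir)
    with open(artifact_dir / f"weights_epoch_{epoch}.pickle", "rb") as f:
        generator_weights, discriminator_weights = pickle.load(f)
    metadata = json.loads((artifact_dir / "meta.json").read_text())
    # Which transformation file exists depends on the normalization used in training.
    if metadata["transformation"] == "pit":
        transform_data = np.load(artifact_dir / "reverse_lookup.npy")
    else:
        stored = json.loads((artifact_dir / "minmax.json").read_text())
        transform_data = (np.array(stored["min"]), np.array(stored["max"]))
    return {
        "generator_weights": generator_weights,
        "discriminator_weights": discriminator_weights,
        "metadata": metadata,
        "transform_data": transform_data,
    }
```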
- sample(n_samples: int = 32)
Generate samples from the trained model.
- Args:
n_samples (int, optional): Number of samples to generate. Defaults to 32.
- Returns:
np.array: Array of samples of shape (n_samples, sample_dimension).
-
save(file_path: Path, overwrite: bool = True) → BaseModelHandler
Save the generator and discriminator weights to disk.
- Args:
file_path (Path): The path where the pickled tuple of generator and discriminator weights will be written.
overwrite (bool, optional): Whether to overwrite the file if it already exists. Defaults to True.
- Returns:
BaseModelHandler: The model, unchanged.
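The save format described above, a pickled tuple of generator and discriminator weights with an `overwrite` switch, can be sketched with the standard `pickle` module. `save_weights` is a hypothetical helper, and its behaviour when `overwrite=False` and the file exists (leave the file untouched and return `False`) is an assumption; the documentation does not say whether qugen skips, warns, or raises in that case.

```python
import pickle
from pathlib import Path


def save_weights(file_path, generator_weights, discriminator_weights, overwrite=True):
    """Pickle the (generator_weights, discriminator_weights) tuple to file_path.

    Hypothetical helper. Assumption: if overwrite is False and the file
    already exists, the existing file is left untouched and False is returned.
    """
    file_path = Path(file_path)
    if file_path.exists() and not overwrite:
        return False
    # Create missing parent directories so the write cannot fail on a fresh path.
    file_path.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, "wb") as f:
        pickle.dump((generator_weights, discriminator_weights), f)
    return True
```

Reading the file back with `pickle.load` yields the same two-element tuple, which can then be unpacked into generator and discriminator weights.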
-
train(train_dataset_original_space: array, n_epochs: int, initial_learning_rate_generator: float, initial_learning_rate_discriminator: float, batch_size=None) → BaseModelHandler
Train the continuous QGAN.
- Args:
train_dataset_original_space (np.array): The training data in the original space.
n_epochs (int): The number of iterations of the training loop. Despite the name, this is not the number of passes through the training data.
initial_learning_rate_generator (float): Learning rate for the quantum generator.
initial_learning_rate_discriminator (float): Learning rate for the classical discriminator.
batch_size (int, optional): Batch size. Defaults to None, in which case the whole training data is used in each iteration.
- Raises:
ValueError: If the training dataset's dimension (number of columns) is not 2 or 3.
- Returns:
BaseModelHandler: The trained model.
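The documented input check and the meaning of `n_epochs` and `batch_size` can be illustrated with a small batching loop. This is a minimal sketch, assuming that when `batch_size` is set each iteration draws a fresh random subset without replacement; `check_training_data` and `iterate_batches` are hypothetical helpers, and qugen's actual batching strategy may differ.

```python
import numpy as np


def check_training_data(train_data):
    """Mirror the documented check: only 2- or 3-column data is accepted."""
    train_data = np.asarray(train_data, dtype=float)
    if train_data.ndim != 2 or train_data.shape[1] not in (2, 3):
        raise ValueError(
            f"Training data must have 2 or 3 columns, got shape {train_data.shape}."
        )
    return train_data


def iterate_batches(train_data, n_epochs, batch_size=None, seed=42):
    """Return an iterator yielding one batch per training-loop iteration.

    Hypothetical helper. n_epochs counts loop iterations, not passes over
    the data. With batch_size=None every iteration uses the whole data set;
    otherwise batch_size rows are drawn at random without replacement.
    """
    # Validate eagerly so a bad shape fails before any batch is requested.
    data = check_training_data(train_data)
    rng = np.random.default_rng(seed)

    def _batches():
        for _ in range(n_epochs):
            if batch_size is None:
                yield data
            else:
                idx = rng.choice(data.shape[0], size=batch_size, replace=False)
                yield data[idx]

    return _batches()
```

Because `n_epochs` counts iterations, a run with a small `batch_size` sees fewer total training rows than a full-batch run with the same `n_epochs`.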
-
build(model_name: str, data_set: str, n_qubits: int =