sgptools 1.2.0__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sgptools/utils/gpflow.py CHANGED
@@ -13,238 +13,399 @@
  # limitations under the License.

  import gpflow
+ from gpflow import set_trainable
  from gpflow.utilities.traversal import print_summary
-
  import tensorflow as tf
  import tensorflow_probability as tfp
-
  import numpy as np
- import matplotlib.pyplot as plt
-
  from .misc import get_inducing_pts
+ from typing import Union, List, Optional, Tuple, Any, Callable
+ from tensorflow.keras import optimizers


- def plot_loss(losses, save_file=None):
-     """Helper function to plot the training loss
-
-     Args:
-         losses (list): list of loss values
-         save_file (str): If passed, the loss plot will be saved to the `save_file`
+ def get_model_params(
+         X_train: np.ndarray,
+         y_train: np.ndarray,
+         max_steps: int = 1500,
+         verbose: bool = True,
+         lengthscales: Union[float, List[float]] = 1.0,
+         variance: float = 1.0,
+         noise_variance: float = 0.1,
+         kernel: Optional[gpflow.kernels.Kernel] = None,
+         return_model: bool = False,
+         train_inducing_pts: bool = False,
+         num_inducing_pts: int = 500,
+         **kwargs: Any
+ ) -> Union[Tuple[np.ndarray, float, gpflow.kernels.Kernel], Tuple[
+         np.ndarray, float, gpflow.kernels.Kernel, Union[gpflow.models.GPR,
+                                                         gpflow.models.SGPR]]]:
      """
-     plt.plot(losses)
-     plt.title('Log Likelihood')
-     plt.xlabel('Iteration')
-     plt.ylabel('Log Likelihood')
-     ax = plt.gca()
-     ax.ticklabel_format(useOffset=False)
-
-     if save_file is not None:
-         plt.savefig(save_file, bbox_inches='tight')
-         plt.close()
-     else:
-         plt.show()
-
- def get_model_params(X_train, y_train,
-                      max_steps=1500,
-                      lr=1e-2,
-                      print_params=True,
-                      lengthscales=1.0,
-                      variance=1.0,
-                      noise_variance=0.1,
-                      kernel=None,
-                      return_gp=False,
-                      train_inducing_pts=False,
-                      num_inducing_pts=500,
-                      **kwargs):
-     """Train a GP on the given training set.
-     Trains a sparse GP if the training set is larger than 1000 samples.
+     Trains a Gaussian Process (GP) or Sparse Gaussian Process (SGP) model on the given training set.
+     A Sparse GP is used if the training set size exceeds 1500 samples.

      Args:
-         X_train (ndarray): (n, d); Training set inputs
-         y_train (ndarray): (n, 1); Training set labels
-         max_steps (int): Maximum number of optimization steps
-         lr (float): Optimization learning rate
-         print_params (bool): If True, prints the optimized GP parameters
-         lengthscales (float or list): Kernel lengthscale(s), if passed as a list,
-                                       each element corresponds to each data dimension
-         variance (float): Kernel variance
-         noise_variance (float): Data noise variance
-         kernel (gpflow.kernels.Kernel): gpflow kernel function
-         return_gp (bool): If True, returns the trained GP model
-         train_inducing_pts (bool): If True, trains the inducing points when
-                                    using a sparse GP model
-         num_inducing_pts (int): Number of inducing points to use when training
-                                 a sparse GP model
+         X_train (np.ndarray): (n, d); Training set input features. `n` is the number of samples,
+                               `d` is the number of input dimensions.
+         y_train (np.ndarray): (n, 1); Training set labels. `n` is the number of samples.
+         max_steps (int): Maximum number of optimization steps. Defaults to 1500.
+         verbose (bool): If True, prints a summary of the optimized GP parameters. Defaults to True.
+         lengthscales (Union[float, List[float]]): Initial kernel lengthscale(s). If a float, it's
+                                                   applied uniformly to all dimensions. If a list, each element
+                                                   corresponds to a data dimension. Defaults to 1.0.
+         variance (float): Initial kernel variance. Defaults to 1.0.
+         noise_variance (float): Initial data noise variance. Defaults to 0.1.
+         kernel (Optional[gpflow.kernels.Kernel]): A pre-defined GPflow kernel function. If None,
+                                                   a `gpflow.kernels.SquaredExponential` kernel is created
+                                                   with the provided `lengthscales` and `variance`. Defaults to None.
+         return_model (bool): If True, the trained GP/SGP model object is returned along with
+                              loss, variance, and kernel. Defaults to False.
+         train_inducing_pts (bool): If True and using a Sparse GP model, the inducing points
+                                    are optimized along with other model parameters. If False,
+                                    inducing points remain fixed (default for SGP). Defaults to False.
+         num_inducing_pts (int): Number of inducing points to use when training a Sparse GP model.
+                                 Ignored if `len(X_train)` is less than or equal to 1500. Defaults to 500.
+         **kwargs: Additional keyword arguments passed to the `optimize_model` function.

      Returns:
-         loss (list): Loss values obtained during training
-         variance (float): Optimized data noise variance
-         kernel (gpflow.kernels.Kernel): Optimized gpflow kernel function
-         gp (gpflow.models.GPR): Optimized gpflow GP model.
-                                 Returned only if ```return_gp=True```.
-
+         Union[Tuple[np.ndarray, float, gpflow.kernels.Kernel], Tuple[np.ndarray, float, gpflow.kernels.Kernel, Union[gpflow.models.GPR, gpflow.models.SGPR]]]:
+         - If `return_model` is False:
+           Tuple: (loss (np.ndarray), variance (float), kernel (gpflow.kernels.Kernel)).
+           `loss` is an array of loss values obtained during training.
+           `variance` is the optimized data noise variance.
+           `kernel` is the optimized GPflow kernel function.
+         - If `return_model` is True:
+           Tuple: (loss (np.ndarray), variance (float), kernel (gpflow.kernels.Kernel), gp (Union[gpflow.models.GPR, gpflow.models.SGPR])).
+           `gp` is the optimized GPflow GPR or SGPR model object.
+
+     Usage:
+         ```python
+         import numpy as np
+         # Generate some dummy data
+         X = np.random.rand(1000, 2) * 10
+         y = np.sin(X[:, 0:1]) + np.cos(X[:, 1:2]) + np.random.randn(1000, 1) * 0.1
+
+         # Train a GPR model (since 1000 samples <= 1500)
+         losses, noise_var, trained_kernel = get_model_params(X, y, max_steps=500, verbose=True)
+
+         # Train an SGPR model (more than 1500 samples)
+         X_large = np.random.rand(2000, 2) * 10
+         y_large = np.sin(X_large[:, 0:1]) + np.cos(X_large[:, 1:2]) + np.random.randn(2000, 1) * 0.1
+         losses_sgpr, noise_var_sgpr, trained_kernel_sgpr, sgpr_model = \
+             get_model_params(X_large, y_large, max_steps=500, num_inducing_pts=100, return_model=True)
+         ```
      """
      if kernel is None:
-         kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
+         kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
                                                     variance=variance)

+     model: Union[gpflow.models.GPR, gpflow.models.SGPR]
+     trainable_variables_list: List[tf.Variable]
+
      if len(X_train) <= 1500:
-         gpr = gpflow.models.GPR(data=(X_train, y_train),
-                                 kernel=kernel,
-                                 noise_variance=noise_variance)
-         trainable_variables=gpr.trainable_variables
+         model = gpflow.models.GPR(data=(X_train, y_train),
+                                   kernel=kernel,
+                                   noise_variance=noise_variance)
+         trainable_variables_list = model.trainable_variables
      else:
          inducing_pts = get_inducing_pts(X_train, num_inducing_pts)
-         gpr = gpflow.models.SGPR(data=(X_train, y_train),
-                                  kernel=kernel,
-                                  inducing_variable=inducing_pts,
-                                  noise_variance=noise_variance)
+         model = gpflow.models.SGPR(data=(X_train, y_train),
+                                    kernel=kernel,
+                                    inducing_variable=inducing_pts,
+                                    noise_variance=noise_variance)
          if train_inducing_pts:
-             trainable_variables=gpr.trainable_variables
+             trainable_variables_list = model.trainable_variables
          else:
-             trainable_variables=gpr.trainable_variables[1:]
+             # Exclude inducing points from trainable variables if not specified to be trained
+             # Assuming inducing_variable is the first parameter in SGPR's trainable_variables
+             trainable_variables_list = model.trainable_variables[1:]

+     loss_values: np.ndarray
      if max_steps > 0:
-         loss = optimize_model(gpr, max_steps=max_steps, lr=lr,
-                               trainable_variables=trainable_variables,
-                               **kwargs)
+         loss_values = optimize_model(
+             model,
+             max_steps=max_steps,
+             trainable_variables=trainable_variables_list,
+             verbose=verbose,
+             **kwargs)
      else:
-         loss = 0
+         # If no optimization steps, return an array with a single '0' loss
+         loss_values = np.array([0.0])

-     if print_params:
-         print_summary(gpr)
-
-     if return_gp:
-         return loss, gpr.likelihood.variance, kernel, gpr
+     if verbose:
+         print_summary(model)
+
+     if return_model:
+         return loss_values, model.likelihood.variance.numpy(), kernel, model
      else:
-         return loss, gpr.likelihood.variance, kernel
+         return loss_values, model.likelihood.variance.numpy(), kernel


  class TraceInducingPts(gpflow.monitor.MonitorTask):
-     '''
-     GPflow monitoring task, used to trace the inducing points
-     states at every step during optimization.
+     """
+     A GPflow monitoring task designed to trace the state of inducing points
+     at every step during optimization of a Sparse Gaussian Process (SGP) model.
+     This is particularly useful for visualizing the movement of inducing points
+     during training.

-     Args:
-         model (gpflow.models.sgpr): GPflow GP/SGP model
-     '''
-     def __init__(self, model):
+     Attributes:
+         trace (List[np.ndarray]): A list to store the numpy arrays of inducing points
+                                   at each optimization step.
+         model (Union[gpflow.models.GPR, gpflow.models.SGPR]): The GPflow model being monitored.
+     """
+
+     def __init__(self, model: Union[gpflow.models.GPR, gpflow.models.SGPR]):
+         """
+         Initializes the TraceInducingPts monitor task.
+
+         Args:
+             model (Union[gpflow.models.GPR, gpflow.models.SGPR]): The GPflow GP or SGP model instance
+                                                                   to monitor. It is expected to have an
+                                                                   `inducing_variable.Z` attribute and potentially
+                                                                   a `transform` attribute.
+         """
          super().__init__()
-         self.trace = []
+         self.trace: List[np.ndarray] = []
          self.model = model

+     def run(self, **kwargs: Any) -> None:
+         """
+         Executes the monitoring task. This method is called by the GPflow `Monitor`
+         at specified intervals. It extracts the current inducing points, applies
+         any associated transformations (e.g., `IPPTransform`'s fixed points expansion),
+         and appends them to the internal trace list.
+
+         Args:
+             **kwargs: Additional keyword arguments (e.g., `step`, `loss_value`)
+                       passed by the `gpflow.monitor.Monitor` framework.
+
+         Usage:
+             This method is called internally by `gpflow.monitor.Monitor` and typically
+             not invoked directly by the user.
+         """
+         Xu = self.model.inducing_variable.Z
+         Xu_exp: np.ndarray
+         # Apply IPP fixed points transform if available, without expanding sensor model
+         try:
+             Xu_exp = self.model.transform.expand(
+                 Xu, expand_sensor_model=False).numpy()
+         except AttributeError:
+             Xu_exp = Xu
+         self.trace.append(Xu_exp)
+
      def run(self, **kwargs):
          '''
          Method used to extract the inducing points and
          apply IPP fixed points transform if available
          '''
          Xu = self.model.inducing_variable.Z
-         Xu_exp = self.model.transform.expand(Xu,
-                                              expand_sensor_model=False).numpy()
+         Xu_exp = self.model.transform.expand(
+             Xu, expand_sensor_model=False).numpy()
          self.trace.append(Xu_exp)

-     def get_trace(self):
-         '''
-         Returns the inducing points collected at each optimization step
+     def get_trace(self) -> np.ndarray:
+         """
+         Returns the collected inducing points at each optimization step.

          Returns:
-             trace (ndarray): (n, m, d); Array with the inducing points.
-                              `n` is the number of optimization steps;
-                              `m` is the number of inducing points;
-                              `d` is the dimension of the inducing points.
-         '''
+             np.ndarray: (num_steps, num_inducing_points, num_dimensions);
+                         An array where:
+                         - `num_steps` is the number of optimization steps monitored.
+                         - `num_inducing_points` is the number of inducing points.
+                         - `num_dimensions` is the dimensionality of the inducing points.
+
+         Usage:
+             ```python
+             # Assuming `model` is an SGPR and `optimize_model` was called with `trace_fn='traceXu'`
+             # trace_task = TraceInducingPts(model)
+             # Then retrieve the trace after optimization
+             # inducing_points_history = trace_task.get_trace()
+             ```
+         """
          return np.array(self.trace)


- def optimize_model(model,
-                    max_steps=2000,
-                    kernel_grad=True,
-                    lr=1e-2,
-                    optimizer='scipy',
-                    method=None,
-                    verbose=False,
-                    trace_fn=None,
-                    convergence_criterion=True,
-                    trainable_variables=None,
-                    tol=None):
+ def optimize_model(model: Union[gpflow.models.GPR, gpflow.models.SGPR],
+                    max_steps: int = 2000,
+                    optimize_hparams: bool = True,
+                    optimizer: str = 'scipy.L-BFGS-B',
+                    verbose: bool = False,
+                    trace_fn: Optional[Union[str, Callable[[Any], Any]]] = None,
+                    convergence_criterion: bool = True,
+                    trainable_variables: Optional[List[tf.Variable]] = None,
+                    **kwargs: Any) -> np.ndarray:
      """
-     Trains a GP/SGP model
+     Trains a GPflow GP or SGP model using either SciPy's optimizers or TensorFlow's optimizers.

      Args:
-         model (gpflow.models): GPflow GP/SGP model to train.
-         max_steps (int): Maximum number of training steps.
-         kernel_grad (bool): If `False`, the kernel parameters will not be optimized.
-                             Ignored when `trainable_variables` are passed.
-         lr (float): Optimization learning rate.
-         optimizer (str): Optimizer to use for training (`scipy` or `tf`).
-         method (str): Optimization method refer to [scipy minimize](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html#scipy.optimize.minimize)
-                       and [tf optimizers](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers) for full list
-         verbose (bool): If `True`, the training progress will be printed when using Scipy.
-         trace_fn (str): Function to trace metrics during training.
-                         If `None`, the loss values are returned;
-                         If `traceXu`, it the inducing points states at each optimization step are returned (increases computation time).
-         convergence_criterion (bool): If `True` and using a tensorflow optimizer, it
-                                       enables early stopping when the loss plateaus.
-         trainable_variables (list): List of model variables to train.
-         tol (float): Convergence tolerance to decide when to stop optimization.
+         model (Union[gpflow.models.GPR, gpflow.models.SGPR]): The GPflow model (GPR or SGPR) to be trained.
+         max_steps (int): Maximum number of training steps (iterations). Defaults to 2000.
+         optimize_hparams (bool): If `False`, the model hyperparameters (kernel parameters and data likelihood)
+                                  will not be optimized. This is ignored if `trainable_variables` is explicitly passed.
+                                  Defaults to True.
+         optimizer (str): Specifies the optimizer to use in "<backend>.<method>" format.
+                          Supported backends: `scipy` and `tf` (TensorFlow).
+                          - For the `scipy` backend: Refer to the `scipy.optimize.minimize` documentation for available
+                            methods (e.g., 'L-BFGS-B', 'CG'). Only first-order and quasi-Newton methods
+                            that do not require the Hessian are supported.
+                          - For the `tf` backend: Refer to `tf.keras.optimizers` for available methods
+                            (e.g., 'Adam', 'SGD').
+                          Defaults to 'scipy.L-BFGS-B'.
+         verbose (bool): If `True`, the training progress will be printed. For SciPy optimizers,
+                         this controls the `disp` option. Defaults to False.
+         trace_fn (Optional[Union[str, Callable[[Any], Any]]]): Specifies what to trace during training:
+                          - `None`: Returns the loss values.
+                          - `'traceXu'`: Traces the inducing points' states at each optimization step.
+                            This increases computation time.
+                          - `Callable`: A custom function that takes the traceable quantities from the optimizer
+                            and returns the desired output.
+                            For the `scipy` backend: Refer to `gpflow.monitor.MonitorTask`.
+                            For the `tf` backend: Refer to the `trace_fn` argument of `tfp.math.minimize`.
+                          Defaults to None.
+         convergence_criterion (bool): If `True` and using a TensorFlow optimizer, it enables early
+                                       stopping when the loss plateaus (using `tfp.optimizer.convergence_criteria.LossNotDecreasing`).
+                                       Defaults to True.
+         trainable_variables (Optional[List[tf.Variable]]): A list of specific model variables to train.
+                                                            If None, the variables are determined based on `optimize_hparams`. Defaults to None.
+         **kwargs: Additional keyword arguments passed to the backend optimizers.
+
+     Returns:
+         np.ndarray: An array of loss values (or traced quantities if `trace_fn` is specified)
+                     recorded during the optimization process. The shape depends on `trace_fn`.
+
+     Raises:
+         ValueError: If an invalid optimizer format or an unsupported backend is specified.
+
+     Usage:
+         ```python
+         import gpflow
+         import numpy as np
+
+         # Create a dummy model (e.g., GPR for simplicity)
+         X = np.random.rand(100, 1)
+         y = X + np.random.randn(100, 1) * 0.1
+         kernel = gpflow.kernels.SquaredExponential()
+         model = gpflow.models.GPR((X, y), kernel=kernel, noise_variance=0.1)
+
+         # 1. Optimize using SciPy's L-BFGS-B (default)
+         losses_scipy = optimize_model(model, max_steps=500, verbose=True)
+
+         # 2. Optimize using TensorFlow's Adam optimizer
+         # Re-initialize model to reset parameters for new optimization
+         model_tf = gpflow.models.GPR((X, y), kernel=gpflow.kernels.SquaredExponential(), noise_variance=0.1)
+         losses_tf = optimize_model(model_tf, max_steps=1000, learning_rate=0.01, optimizer='tf.Adam', verbose=False)
+
+         # 3. Optimize SGPR and trace inducing points
+         X_sgpr = np.random.rand(2000, 2)
+         y_sgpr = np.sin(X_sgpr[:, 0:1]) + np.random.randn(2000, 1) * 0.1
+         inducing_points_init = get_inducing_pts(X_sgpr, 100)
+         sgpr_model = gpflow.models.SGPR((X_sgpr, y_sgpr), kernel=gpflow.kernels.SquaredExponential(),
+                                         inducing_variable=inducing_points_init, noise_variance=0.1)
+         traced_ips = optimize_model(sgpr_model, max_steps=100, optimizer='tf.Adam', trace_fn='traceXu', verbose=False)
+         ```
      """
-     # Train all variables if trainable_variables are not provided
-     # If kernel_gradient is False, disable the kernel parameter gradient updates
-     if trainable_variables is None and kernel_grad:
-         trainable_variables=model.trainable_variables
-     elif trainable_variables is None and not kernel_grad:
-         trainable_variables=model.trainable_variables[:1]
-
-     if optimizer == 'scipy':
-         if method is None:
-             method = 'L-BFGS-B'
+     # Determine which variables to train
+     if trainable_variables is None:
+         # Disable hyperparameter gradients (kernel and likelihood parameters)
+         if not optimize_hparams:
+             set_trainable(model.kernel, False)
+             set_trainable(model.likelihood, False)
+         trainable_variables = model.trainable_variables
+
+     # Parse optimizer string
+     optimizer_parts = optimizer.split('.')
+     if len(optimizer_parts) != 2:
+         raise ValueError(
+             f"Invalid optimizer format! Expected <backend>.<method>; got {optimizer}"
+         )
+     backend, method = optimizer_parts
+
+     losses_output: Any  # Will hold the final loss values or traced data

+     if backend == 'scipy':
+         # Configure SciPy monitor if tracing is requested
+         scipy_monitor: Optional[gpflow.monitor.Monitor] = None
+         trace_task_instance: Optional[TraceInducingPts] = None
          if trace_fn == 'traceXu':
-             execute_task = TraceInducingPts(model)
-             task_group = gpflow.monitor.MonitorTaskGroup(execute_task,
+             trace_task_instance = TraceInducingPts(model)
+             # Period=1 means run task at every step
+             task_group = gpflow.monitor.MonitorTaskGroup(trace_task_instance,
                                                           period=1)
-             trace_fn = gpflow.monitor.Monitor(task_group)
+             scipy_monitor = gpflow.monitor.Monitor(task_group)

          opt = gpflow.optimizers.Scipy()
-         losses = opt.minimize(model.training_loss,
-                               trainable_variables,
-                               method=method,
-                               options=dict(disp=verbose, maxiter=max_steps),
-                               tol=tol,
-                               step_callback=trace_fn)
-         if trace_fn is None:
-             losses = losses.fun
+         # SciPy optimize method returns a `ScipyOptimizerResults` object
+         # which has `fun` attribute for the final loss. `step_callback` is used for tracing.
+         results = opt.minimize(
+             model.training_loss,
+             trainable_variables,
+             method=method,
+             options=dict(disp=verbose, maxiter=max_steps),
+             step_callback=scipy_monitor,  # Pass the monitor as step_callback
+             **kwargs)
+
+         if trace_fn == 'traceXu' and trace_task_instance is not None:
+             losses_output = trace_task_instance.get_trace()
          else:
-             losses = trace_fn.task_groups[0].tasks[0].get_trace()
-     else:
+             # If no inducing-point tracing was requested, `results.fun` contains the final loss
+             losses_output = np.array([results.fun])  # Return as an array for consistency
+             # Note: For SciPy, `results.fun` is typically just the final loss, not a history.
+             # To get a history, a custom callback capturing the loss at each step would be needed.
+
+     elif backend == 'tf':
+         tf_trace_fn: Optional[Callable[[Any], Any]] = None
          if trace_fn is None:
-             trace_fn = lambda x: x.loss
+             # Default TF trace function to capture loss history
+             tf_trace_fn = lambda traceable_quantities: traceable_quantities.loss
          elif trace_fn == 'traceXu':
-             def trace_fn(traceable_quantities):
+
+             def tf_trace_fn(traceable_quantities):
                  return model.inducing_variable.Z.numpy()
+         elif callable(trace_fn):
+             tf_trace_fn = trace_fn
+         else:
+             raise ValueError(
+                 f"Invalid trace_fn for TensorFlow backend: {trace_fn}")

-         if method is None:
-             method = 'adam'
-         opt = tf.keras.optimizers.get(method)
-         opt.learning_rate = lr
-         loss_fn = model.training_loss
+         # Get Keras optimizer instance
+         opt = getattr(optimizers, method)(**kwargs)
+
+         # Define the training loss function
+         loss_function_to_minimize = model.training_loss
+
+         # Configure convergence criterion
+         tf_convergence_criterion: Optional[
+             tfp.optimizer.convergence_criteria.ConvergenceCriterion] = None
          if convergence_criterion:
-             convergence_criterion = tfp.optimizer.convergence_criteria.LossNotDecreasing(
-                 atol=1e-5,
-                 window_size=50,
-                 min_num_steps=int(max_steps*0.1))
-         else:
-             convergence_criterion = None
-         losses = tfp.math.minimize(loss_fn,
-                                    trainable_variables=trainable_variables,
-                                    num_steps=max_steps,
-                                    optimizer=opt,
-                                    convergence_criterion=convergence_criterion,
-                                    trace_fn=trace_fn)
-         losses = losses.numpy()
-
-         return losses
-
-
- if __name__ == "__main__":
-     pass
+             tf_convergence_criterion = tfp.optimizer.convergence_criteria.LossNotDecreasing(
+                 atol=1e-5,  # Absolute tolerance for checking decrease
+                 window_size=50,  # Number of steps to consider for plateau detection
+                 min_num_steps=int(max_steps * 0.1))  # Minimum steps before early stopping is considered
+
+         # Run TensorFlow optimization
+         results_tf = tfp.math.minimize(
+             loss_function_to_minimize,
+             trainable_variables=trainable_variables,
+             num_steps=max_steps,
+             optimizer=opt,
+             convergence_criterion=tf_convergence_criterion,
+             trace_fn=tf_trace_fn)
+
+         # Convert the traced quantities (the loss history by default) to a NumPy array
+         losses_output = np.array(results_tf.numpy())
+
+     else:
+         raise ValueError(
+             f"Invalid backend! Expected `scipy` or `tf`; got {backend}")
+
+     # Re-enable hyperparameter training if it was disabled above
+     if not optimize_hparams:
+         set_trainable(model.kernel, True)
+         set_trainable(model.likelihood, True)
+
+     return losses_output
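
As the signatures above show, `print_params` and `return_gp` are renamed to `verbose` and `return_model`, `kernel_grad` becomes `optimize_hparams`, and the separate `optimizer`, `method`, and `lr` arguments are folded into a single `"<backend>.<method>"` string with optimizer options forwarded through `**kwargs`. The sketch below is not part of the diff; it is a minimal migration example based only on the signatures shown here, assuming the module path `sgptools.utils.gpflow` and using illustrative data and variable names.

```python
import numpy as np
from sgptools.utils.gpflow import get_model_params, optimize_model

# Illustrative training data (not from the package)
X = np.random.rand(200, 2)
y = np.sin(X[:, 0:1]) + np.random.randn(200, 1) * 0.1

# 1.2.0-style call:
#   losses, noise_var, kernel, gp = get_model_params(X, y, lr=1e-2,
#                                                    print_params=True, return_gp=True)
# 2.0.0-style call: `print_params` -> `verbose`, `return_gp` -> `return_model`,
# and the learning rate is no longer a direct argument of get_model_params.
losses, noise_var, kernel, gp = get_model_params(X, y, max_steps=500,
                                                 verbose=True, return_model=True)

# 1.2.0-style optimizer selection:
#   losses = optimize_model(gp, optimizer='tf', method='adam', lr=1e-2)
# 2.0.0-style: one "<backend>.<method>" string; optimizer options such as the
# learning rate are forwarded to the backend optimizer via **kwargs.
losses = optimize_model(gp, max_steps=200, optimizer='tf.Adam', learning_rate=1e-2)
```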