gplite 2.1.2__py3-none-any.whl

@@ -0,0 +1,158 @@
1
+ # ActiveLearning Module
2
+
3
+ Intelligent data sampling for efficient model training.
4
+
5
+ ## Overview
6
+
7
+ Pool-based active learning reduces the amount of labeled data needed by strategically selecting the most informative points from an unlabeled pool.
8
+
9
+ ### The Active Learning Loop
10
+
11
+ ```
12
+ 1. Train GP on current labeled set
13
+ 2. Compute acquisition scores for unlabeled points
14
+ 3. Select and label highest-scoring point(s)
15
+ 4. Repeat until stopping criterion met
16
+ ```
17
+
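The four steps above can be sketched without gplite. This is a toy illustration, not the package's implementation: a nearest-neighbour predictor stands in for the GP, and distance to the nearest labeled point stands in for the acquisition score. The names `nearest_labeled` and `run_loop` are hypothetical. Seeding with the first, middle, and last pool points mirrors `_initialize_training_data` in the package source.

```python
import math


def nearest_labeled(x, labeled):
    """Return (distance, y) for the labeled point closest to x."""
    return min((abs(x - xl), yl) for xl, yl in labeled)


def run_loop(pool_x, pool_y, rmse_threshold, max_points):
    """Toy pool-based loop; assumes a 1-D pool with at least 3 points."""
    n = len(pool_x)
    labeled_idx = [0, n // 2, n - 1]  # seed: first, middle, last
    remaining = [i for i in range(n) if i not in labeled_idx]
    while True:
        # Step 1: "train" the stand-in model on the current labeled set.
        labeled = [(pool_x[i], pool_y[i]) for i in labeled_idx]
        # Evaluate RMSE over the full pool and check the stopping criteria.
        sq_err = [
            (nearest_labeled(x, labeled)[1] - y) ** 2
            for x, y in zip(pool_x, pool_y)
        ]
        rmse = math.sqrt(sum(sq_err) / n)
        if rmse <= rmse_threshold or not remaining or len(labeled_idx) >= max_points:
            return labeled_idx, rmse
        # Steps 2 and 3: score every unlabeled point, label the best one.
        best = max(remaining, key=lambda i: nearest_labeled(pool_x[i], labeled)[0])
        labeled_idx.append(best)
        remaining.remove(best)


pool_x = [i * 2 * math.pi / 199 for i in range(200)]
pool_y = [math.sin(x) for x in pool_x]
idx, rmse = run_loop(pool_x, pool_y, rmse_threshold=0.05, max_points=50)
```

A real GP replaces both stand-ins: its predictive mean gives the prediction and its predictive variance gives the score.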
18
+ ## Usage
19
+
20
+ ### Basic Active Learning
21
+
22
+ ```python
23
+ from gplite import ActiveLearner, PeriodicKernel
24
+ import numpy as np
25
+
26
+ # full dataset
27
+ X_full = np.linspace(0, 2*np.pi, 200).reshape(-1, 1)
28
+ y_full = np.sin(X_full).ravel()
29
+
30
+ # create active learner
31
+ kernel = PeriodicKernel(length_scale=1.0, period=2*np.pi)
32
+ learner = ActiveLearner(
33
+ kernel=kernel,
34
+ x_full=X_full,
35
+ y_full=y_full,
36
+ rmse_threshold=0.05, # stop when RMSE <= 0.05
37
+ max_points=50, # or when 50 points used
38
+ optimize_interval=10 # re-optimize every 10 iterations
39
+ )
40
+
41
+ # run active learning
42
+ learner.learn(
43
+ learning_strategy="uncertainty",
44
+ batch_size=1,
45
+ final_optimization_method="rmse",
46
+ update=True, # print progress
47
+ update_interval=10 # print progress every 10 iterations
48
+ )
49
+
50
+ # access trained model
51
+ predictions = learner.gp.predict(X_full)
52
+ print(f"Points used: {len(learner.y_train)}")
53
+ ```
54
+
55
+ ## Selection Strategies
56
+
57
+ ### Uncertainty Sampling (`"uncertainty"`)
58
+
59
+ Selects points where the model is most uncertain (highest predictive variance).
60
+
61
+ ```python
62
+ learner.learn(learning_strategy="uncertainty")
63
+ ```
64
+
65
+ **Best for:** Exploration, when you want broad coverage of the input space.
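As a rough sketch of the selection rule, assume the predictive variances for the unlabeled points are already available as a plain list. The helper `select_most_uncertain` is hypothetical and is not the package's `max_uncertainty` function.

```python
def select_most_uncertain(variances, remaining_indices):
    """Return the pool index whose predictive variance is largest.

    variances[k] is assumed to be the variance at pool index
    remaining_indices[k].
    """
    k_best = max(range(len(variances)), key=lambda k: variances[k])
    return remaining_indices[k_best]


remaining = [4, 7, 9, 12]
variances = [0.10, 0.45, 0.02, 0.30]
picked = select_most_uncertain(variances, remaining)  # pool index 7
```

The rule never looks at target values, so it can be applied before any label for the candidate points is known.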
66
+
67
+ ### Maximum Absolute Error (`"mae"`)
68
+
69
+ Selects points where the absolute prediction error |y - μ(x)| is largest. This requires the pool's target values, which `y_full` provides.
70
+
71
+ ```python
72
+ learner.learn(learning_strategy="mae")
73
+ ```
74
+
75
+ **Best for:** Exploitation, when you want to fix specific problem areas.
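A minimal sketch of this rule, assuming predictions for the whole pool are held in a list aligned with the targets. The helper `select_max_abs_error` is hypothetical and is not the package's `max_absolute_error` function.

```python
def select_max_abs_error(y_true, y_pred, remaining_indices):
    """Return the unlabeled pool index with the largest |y - mu(x)|."""
    return max(remaining_indices, key=lambda i: abs(y_true[i] - y_pred[i]))


y_true = [0.0, 1.0, 0.5, -1.0]
y_pred = [0.1, 0.2, 0.5, -0.7]
remaining = [1, 2, 3]  # index 0 is already in the training set
picked = select_max_abs_error(y_true, y_pred, remaining)  # pool index 1
```

Because the score uses the true targets, this strategy fits settings where the full pool is already labeled and the goal is a compact training set.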
76
+
77
+ ### Expected Improvement — Maximize (`"ei_max"`)
78
+
79
+ Selects points with the highest expected improvement over the current best (maximum) observed value. Balances exploitation (high predicted mean) with exploration (high uncertainty).
80
+
81
+ ```
82
+ EI(x) = (μ(x) - f_best) * Φ(Z) + σ(x) * φ(Z)
83
+ where Z = (μ(x) - f_best) / σ(x), f_best = max(y_train)
84
+ ```
85
+
86
+ ```python
87
+ learner.learn(learning_strategy="ei_max")
88
+ ```
89
+
90
+ **Best for:** Bayesian optimization when searching for the maximum of a function.
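The formula above can be evaluated for a single point with the standard library. This is a sketch under assumptions: the function names are hypothetical, and returning `max(μ - f_best, 0)` when σ is zero is a choice made here, not something the source specifies.

```python
import math


def norm_pdf(z):
    """Standard normal density φ(z)."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def norm_cdf(z):
    """Standard normal cumulative distribution Φ(z)."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def ei_max(mu, sigma, f_best):
    """Expected improvement over f_best for a maximization objective."""
    if sigma <= 0.0:
        # Assumption: with no predictive uncertainty, the improvement is deterministic.
        return max(mu - f_best, 0.0)
    z = (mu - f_best) / sigma
    return (mu - f_best) * norm_cdf(z) + sigma * norm_pdf(z)


at_best = ei_max(mu=1.0, sigma=1.0, f_best=1.0)        # equals φ(0)
more_uncertain = ei_max(mu=0.0, sigma=2.0, f_best=1.0)
less_uncertain = ei_max(mu=0.0, sigma=1.0, f_best=1.0)
```

With the mean held fixed, the larger σ gives the larger score, which is the exploration term at work.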
91
+
92
+ ### Expected Improvement — Minimize (`"ei_min"`)
93
+
94
+ Selects points with the highest expected improvement below the current best (minimum) observed value.
95
+
96
+ ```
97
+ EI(x) = (f_best - μ(x)) * Φ(Z) + σ(x) * φ(Z)
98
+ where Z = (f_best - μ(x)) / σ(x), f_best = min(y_train)
99
+ ```
100
+
101
+ ```python
102
+ learner.learn(learning_strategy="ei_min")
103
+ ```
104
+
105
+ **Best for:** Bayesian optimization when searching for the minimum of a function (e.g., energy minimization).
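The minimization form is the maximization form applied to the negated problem: negating both μ(x) and f_best leaves Z and both terms unchanged. The sketch below checks that identity numerically. The function names are hypothetical, and σ is assumed to be strictly positive.

```python
import math


def norm_pdf(z):
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def norm_cdf(z):
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def ei_max(mu, sigma, f_best):
    """Expected improvement above f_best (assumes sigma > 0)."""
    z = (mu - f_best) / sigma
    return (mu - f_best) * norm_cdf(z) + sigma * norm_pdf(z)


def ei_min(mu, sigma, f_best):
    """Expected improvement below f_best (assumes sigma > 0)."""
    z = (f_best - mu) / sigma
    return (f_best - mu) * norm_cdf(z) + sigma * norm_pdf(z)


cases = [(0.3, 0.5, 0.0), (-1.2, 0.8, -1.0), (2.0, 1.5, 0.5)]
gaps = [abs(ei_min(mu, s, fb) - ei_max(-mu, s, -fb)) for mu, s, fb in cases]
```

One practical consequence is that a single EI routine is enough if the targets are negated before fitting.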
106
+
107
+ ### Random (`"random"`)
108
+
109
+ Baseline strategy with uniform random selection.
110
+
111
+ ```python
112
+ learner.learn(learning_strategy="random")
113
+ ```
114
+
115
+ **Best for:** Comparison baseline, or when domain knowledge suggests uniform sampling.
116
+
117
+ ## Class Reference
118
+
119
+ ### `ActiveLearner(kernel, x_full, y_full, ...)`
120
+
121
+ **Parameters:**
122
+ - `kernel` (Kernel): Kernel for the internal GP model
123
+ - `x_full` (array): Complete pool of input features
124
+ - `y_full` (array): Complete pool of target values
125
+ - `max_points` (int, optional): Maximum training points to use. Defaults to full dataset if not passed
126
+ - `rmse_threshold` (float): Target RMSE for stopping. Default: `0.5`
127
+ - `optimize_interval` (int, optional): Iterations between hyperparameter optimizations. Defaults to 1 (optimize every iteration); `None` disables periodic optimization
128
+
129
+ **Methods:**
130
+ - `learn(learning_strategy, batch_size=1, ...)`: Run the active learning loop
131
+ - `select_next_point(selection_function, n_points=1)`: Select next point(s) to add
132
+
133
+ **Attributes:**
134
+ - `gp`: The underlying GaussianProcess model
135
+ - `x_train`, `y_train`: Current training data
136
+ - `x_full`, `y_full`: Complete data pool
137
+ - `remaining_indices`: Indices of unlabeled points
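According to the `select_next_point` docstring, a selection function receives the learner and `n_points` and returns indices into the full dataset. The sketch below shows a custom strategy of that shape. Everything in it is hypothetical: `farthest_from_training` is not part of gplite, and a `SimpleNamespace` with 1-D lists stands in for a real `ActiveLearner`, whose `x_full` has shape (n, d).

```python
from types import SimpleNamespace


def farthest_from_training(learner, n_points=1):
    """Pick the unlabeled pool indices farthest from any training input."""

    def gap(i):
        return min(abs(learner.x_full[i] - xt) for xt in learner.x_train)

    ranked = sorted(learner.remaining_indices, key=gap, reverse=True)
    return ranked[:n_points]


# Stand-in object exposing only the attributes the strategy reads.
learner = SimpleNamespace(
    x_full=[float(v) for v in range(11)],  # pool inputs 0.0 .. 10.0
    x_train=[0.0, 2.0, 10.0],
    remaining_indices=[1, 3, 4, 5, 6, 7, 8, 9],
)
selected = farthest_from_training(learner, n_points=1)  # pool index 6
```

With a real learner, the same callable would be passed as `learner.select_next_point(farthest_from_training, n_points=1)`. A production version would presumably return an integer array rather than a list.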
138
+
139
+ ## Stopping Criteria
140
+
141
+ The learning loop stops when any of these conditions is met:
142
+
143
+ 1. **RMSE threshold reached**: Model achieves target accuracy
144
+ 2. **Max points reached**: Budget exhausted
145
+ 3. **Pool exhausted**: All points have been labeled
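In the `learn` source, these checks run once per iteration after the model is refit, in the order RMSE threshold, pool exhausted, max points. A compact sketch of that decision follows; `stop_reason` is a hypothetical helper and does not exist in gplite.

```python
def stop_reason(rmse, rmse_threshold, n_remaining, n_train, max_points):
    """Return the name of the first stopping criterion met, or None."""
    if rmse <= rmse_threshold:
        return "rmse_threshold"
    if n_remaining == 0:
        return "pool_exhausted"
    if max_points - n_train <= 0:
        return "max_points"
    return None


keep_going = stop_reason(0.20, 0.05, n_remaining=150, n_train=50 - 1, max_points=50)
budget_hit = stop_reason(0.20, 0.05, n_remaining=150, n_train=50, max_points=50)
accurate = stop_reason(0.05, 0.05, n_remaining=150, n_train=10, max_points=50)
```

The threshold comparison is inclusive, so an RMSE exactly equal to the threshold stops the loop.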
146
+
147
+ ## Batch Active Learning
148
+
149
+ Select multiple points per iteration:
150
+
151
+ ```python
152
+ learner.learn(
153
+ learning_strategy="uncertainty",
154
+ batch_size=5 # add 5 points per iteration
155
+ )
156
+ ```
157
+
158
+ Useful when each labeling round has a high fixed cost but each additional point within a round is cheap.
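In the `learn` source, each batch is clipped to the remaining budget with `min(batch_size, remaining_budget)`, so the last batch can be smaller than `batch_size`. The sketch below reproduces only that arithmetic. `batch_schedule` is a hypothetical helper, and it ignores the RMSE and pool-exhaustion checks that can end a real run earlier.

```python
def batch_schedule(n_initial, max_points, batch_size):
    """List the number of points added per iteration until the budget is spent."""
    sizes = []
    n_train = n_initial
    while n_train < max_points:
        step = min(batch_size, max_points - n_train)  # clip to remaining budget
        sizes.append(step)
        n_train += step
    return sizes


# Three seed points, a budget of 50, and batches of 5.
sizes = batch_schedule(n_initial=3, max_points=50, batch_size=5)
```

Here the budget left after seeding is 47 points, which is nine full batches of 5 followed by a final batch of 2.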
@@ -0,0 +1,3 @@
1
+ from gplite.ActiveLearning.active_learning import ActiveLearner
2
+
3
+ __all__ = ["ActiveLearner"]
@@ -0,0 +1,399 @@
1
+ """
2
+ Active learning implementation for intelligent data sampling with Gaussian
3
+ Process models.
4
+
5
+ Active learning aims to achieve high model accuracy with minimal labeled data
6
+ by strategically selecting the most informative points from an unlabeled pool.
7
+
8
+ Common acquisition strategies:
9
+ - Uncertainty Sampling: Select points where σ²(x) is highest, exploring
10
+ regions where the model is least confident.
11
+ - Expected Improvement: Select points that maximize the expected improvement
12
+ over the current best observation (separate variants for maximization and
13
+ minimization objectives).
14
+ - Error-based: Select points with highest |y - μ(x)|, focusing on regions
15
+ where the model performs worst.
16
+ - Random: Baseline strategy with uniform random selection.
17
+
18
+ The learning loop:
19
+ 1. Train GP on current labeled set
20
+ 2. Compute acquisition scores for unlabeled points
21
+ 3. Select and label highest-scoring point(s)
22
+ 4. Repeat until stopping criterion (RMSE threshold, budget, etc.)
23
+ """
24
+
25
+ import csv
26
+ import warnings
27
+ from collections.abc import Callable
28
+ from pathlib import Path
29
+
30
+ import numpy as np
31
+ from gplite._utils._computation import compute_rmse_across_dataset
32
+ from gplite._utils._errors import ValidationError
33
+ from gplite._utils._types import Arrf64, Arri64, f64
34
+ from gplite._utils._validation import (
35
+ validate_input_and_target_data,
36
+ validate_numeric_value,
37
+ )
38
+ from gplite.ActiveLearning.selection_functions import (
39
+ expected_improvement_max,
40
+ expected_improvement_min,
41
+ max_absolute_error,
42
+ max_uncertainty,
43
+ random_selection,
44
+ )
45
+ from gplite.GaussianProcess.gaussian_process import GaussianProcess
46
+ from gplite.Kernels._base import Kernel
47
+ from gplite.Optimization.active_learning.optimization import (
48
+ optimize_hyperparameters,
49
+ )
50
+
51
+ LEARNING_STRATEGIES: dict[str, Callable] = {
52
+ "random": random_selection,
53
+ "random_choice": random_selection,
54
+ "uncertainty": max_uncertainty,
55
+ "max_uncertainty": max_uncertainty,
56
+ "mae": max_absolute_error,
57
+ "max_absolute_error": max_absolute_error,
58
+ "ei_max": expected_improvement_max,
59
+ "expected_improvement_max": expected_improvement_max,
60
+ "ei_min": expected_improvement_min,
61
+ "expected_improvement_min": expected_improvement_min,
62
+ }
63
+
64
+
65
+ class ActiveLearner:
66
+ """
67
+ Active learning system that intelligently selects training points to
68
+ minimize required data while maintaining model accuracy.
69
+
70
+ Uses a Gaussian Process model to identify the most informative points
71
+ from a pool of unlabeled data based on various selection strategies.
72
+
73
+ Attributes:
74
+ - gp (GaussianProcess): The underlying Gaussian Process model.
75
+ - x_full (Arrf64): Complete pool of input features.
76
+ - y_full (Arrf64): Complete pool of target values.
77
+ - x_train (Arrf64): Current training input features.
78
+ - y_train (Arrf64): Current training target values.
79
+ - remaining_indices (Arri64): Indices of points not yet in training set.
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ kernel: Kernel,
85
+ x_full: Arrf64,
86
+ y_full: Arrf64,
87
+ max_points: int | None = None,
88
+ rmse_threshold: f64 = np.float64(0.5),
89
+ optimize_interval: int | None = 1,
90
+ ) -> None:
91
+ """
92
+ Initializes an active learner with the given kernel and data pool.
93
+
94
+ Args:
95
+ - kernel (Kernel): Kernel instance for the underlying GP model.
96
+ - x_full (Arrf64): Full dataset input features of shape (n, d).
97
+ - y_full (Arrf64): Full dataset target values of shape (n,).
98
+ - max_points (int | None): Maximum training points to use. Defaults
99
+ to full dataset size.
100
+ - rmse_threshold (f64): RMSE target for stopping criterion. Defaults
101
+ to 0.5.
102
+ - optimize_interval (int | None): Iterations between hyperparameter
103
+ optimization. None disables. Defaults to 1.
104
+
105
+ Raises:
106
+ ValidationError: If kernel is invalid or data arrays are
107
+ incompatible.
108
+ """
109
+ if not isinstance(kernel, Kernel):
110
+ err_msg = "Error: 'kernel' must be a valid Kernel instance"
111
+ raise ValidationError(err_msg)
112
+
113
+ self.x_full, self.y_full = validate_input_and_target_data(
114
+ x_full, y_full
115
+ )
116
+
117
+ self.kernel = kernel
118
+
119
+ self.gp = GaussianProcess(self.kernel)
120
+
121
+ self.rmse_threshold = validate_numeric_value(
122
+ rmse_threshold, "Active Learner RMSE Threshold", False
123
+ )
124
+
125
+ if max_points:
126
+ self.max_points = int(
127
+ validate_numeric_value(
128
+ max_points, "Active Learner Max Points", False
129
+ )
130
+ )
131
+
132
+ else:
133
+ self.max_points = len(self.y_full)
134
+
135
+ if optimize_interval:
136
+ self.optimize_interval = int(
137
+ validate_numeric_value(
138
+ optimize_interval, "Active Learner Optimize Interval", False
139
+ )
140
+ )
141
+ else:
142
+ self.optimize_interval = None
143
+
144
+ # initialize training sets and pool of points that remain
145
+ # available to be picked
146
+ self.x_train = np.array([])
147
+ self.y_train = np.array([])
148
+ self.remaining_indices = np.array([])
149
+
150
+ self._initialize_training_data()
151
+
152
+ def _initialize_training_data(self) -> None:
153
+ """
154
+ Initializes the training set with three strategically selected points:
155
+ first, middle, and last from the dataset. Sets up the remaining
156
+ indices pool for active selection.
157
+
158
+ Warns:
159
+ UserWarning: If dataset has fewer than 3 samples.
160
+ """
161
+ num_samples = self.x_full.shape[0]
162
+
163
+ if num_samples < 3:
164
+ warning_msg = (
165
+ "Warning: Active Learning data has < 3 samples. Using full "
166
+ "dataset for training."
167
+ )
168
+ warnings.warn(warning_msg)
169
+
170
+ self.x_train = self.x_full
171
+ self.y_train = self.y_full
172
+
173
+ return
174
+
175
+ # first, middle, and last points of the dataset
176
+ initial_indices = [0, num_samples // 2, num_samples - 1]
177
+
178
+ self.x_train = self.x_full[initial_indices]
179
+ self.y_train = self.y_full[initial_indices]
180
+
181
+ # remove indices from training pool
182
+ self.remaining_indices = np.setdiff1d(
183
+ np.arange(num_samples), initial_indices
184
+ )
185
+
186
+ def select_next_point(
187
+ self, selection_function: Callable, n_points: int = 1
188
+ ) -> Arri64:
189
+ """
190
+ Selects the next point(s) to add to the training set using the given
191
+ selection strategy.
192
+
193
+ Args:
194
+ - selection_function (Callable): Function that takes learner and
195
+ n_points and returns indices.
196
+ - n_points (int): Number of points to select. Defaults to 1.
197
+
198
+ Returns:
199
+ Arri64: Indices of selected points from the full dataset.
200
+ """
201
+ return selection_function(self, n_points)
202
+
203
+ def _update_log(self, rmse, log_file: Path) -> None:
204
+ """
205
+ Private method to update the log file of the learning loop.
206
+ """
207
+ with log_file.open("a", encoding="utf-8", newline="") as f:
208
+ writer = csv.writer(f)
209
+ writer.writerow([self.x_train.shape[0], rmse])
210
+
211
+ def learn(
212
+ self,
213
+ learning_strategy: str,
214
+ batch_size: int = 1,
215
+ final_optimization_method: str = "rmse",
216
+ update: bool = False,
217
+ log: bool = False,
218
+ update_interval: int = 10,
219
+ log_update_interval: int = 5,
220
+ ) -> None:
221
+ """
222
+ Executes the active learning loop, iteratively selecting and adding
223
+ points until a stopping criterion is met.
224
+
225
+ Stopping criteria include: reaching RMSE threshold, exhausting all
226
+ points, or reaching max_points limit.
227
+
228
+ Args:
229
+ - learning_strategy (str): Point selection strategy. Options:
230
+ 'random', 'uncertainty', 'mae', 'ei_max', 'ei_min'.
231
+ - batch_size (int): Points to add per iteration. Defaults to 1.
232
+ - final_optimization_method (str): Objective for final optimization.
233
+ Defaults to 'rmse'.
234
+ - update (bool): Whether to print progress updates. Defaults to
235
+ False.
236
+ - log (bool): Whether to log status updates. Works like the
237
+ "update" parameter, but updates are appended to the CSV file
238
+ ./active_learning_<learning_strategy>.csv rather than stdout. Defaults to False.
239
+ - update_interval (int): Iterations between progress updates.
240
+ Defaults to 10.
241
+ - log_update_interval (int): Iterations between log progress updates.
242
+ Defaults to 5.
243
+
244
+ Raises:
245
+ ValidationError: If learning_strategy is not recognized.
246
+
247
+ Warns:
248
+ UserWarning: If learning stops early due to errors.
249
+ """
250
+ batch_size = int(
251
+ validate_numeric_value(
252
+ batch_size, "Number of Learning Points", allow_nonpositive=False
253
+ )
254
+ )
255
+ update_interval = int(
256
+ validate_numeric_value(
257
+ update_interval, "Update Interval", allow_nonpositive=False
258
+ )
259
+ )
260
+
261
+ log_update_interval = int(
262
+ validate_numeric_value(
263
+ log_update_interval,
264
+ "Log Update Interval",
265
+ allow_nonpositive=False,
266
+ )
267
+ )
268
+
269
+ if learning_strategy not in LEARNING_STRATEGIES:
270
+ err_msg = (
271
+ f"Error: {learning_strategy} is not a valid learning strategy. "
272
+ f"Valid strategies include: {list(LEARNING_STRATEGIES.keys())}"
273
+ )
274
+ raise ValidationError(err_msg)
275
+
276
+ log_file = None
277
+ if log:
278
+ log_file = Path(
279
+ f"./active_learning_{learning_strategy}.csv"
280
+ ).resolve()
281
+
282
+ with log_file.open("w", encoding="utf-8") as f:
283
+ writer = csv.writer(f)
284
+ writer.writerow(["num_points_used", "rmse"])
285
+
286
+ for iteration in range(self.max_points):
287
+ should_optimize = (
288
+ self.optimize_interval is not None
289
+ and iteration % self.optimize_interval == 0
290
+ )
291
+ should_update = update and (iteration + 1) % update_interval == 0
292
+ should_log = log and (iteration + 1) % log_update_interval == 0
293
+
294
+ # step 1: fit model to current training data, optimize with lml
295
+ self.gp.fit(self.x_train, self.y_train, optimize=should_optimize)
296
+
297
+ # step 2: compute rmse and check if the threshold has been reached
298
+ current_rmse = compute_rmse_across_dataset(
299
+ self.gp, self.x_full, self.y_full
300
+ )
301
+
302
+ if should_update:
303
+ print(f"Iteration {iteration + 1}: RMSE: {current_rmse}")
304
+
305
+ if should_log and log_file is not None:
306
+ self._update_log(current_rmse, log_file)
307
+
308
+ if current_rmse <= self.rmse_threshold:
309
+ optimize_hyperparameters(self, final_optimization_method)
310
+ final_rmse = compute_rmse_across_dataset(
311
+ self.gp, self.x_full, self.y_full
312
+ )
313
+
314
+ if update:
315
+ # make sure the update printing and final printing have a
316
+ # new line between them
317
+ print()
318
+
319
+ if log and log_file is not None:
320
+ self._update_log(final_rmse, log_file)
321
+
322
+ print(
323
+ "\033[4mRMSE threshold reached\033[0m:",
324
+ f"\nFinal RMSE: {final_rmse:.4f}",
325
+ f"\nPoints used: {len(self.y_train)}",
326
+ )
327
+
328
+ break
329
+
330
+ if len(self.remaining_indices) == 0:
331
+ optimize_hyperparameters(self, final_optimization_method)
332
+ final_rmse = compute_rmse_across_dataset(
333
+ self.gp, self.x_full, self.y_full
334
+ )
335
+
336
+ if update:
337
+ # make sure the update printing and final printing have a
338
+ # new line between them
339
+ print()
340
+
341
+ if log and log_file is not None:
342
+ self._update_log(final_rmse, log_file)
343
+
344
+ print(
345
+ "\033[4mAll points used\033[0m:",
346
+ f"\nFinal RMSE: {final_rmse:.4f}",
347
+ f"\nPoints used: {len(self.y_train)}",
348
+ )
349
+
350
+ break
351
+
352
+ remaining_budget = self.max_points - len(self.y_train)
353
+ if remaining_budget <= 0:
354
+ optimize_hyperparameters(self, final_optimization_method)
355
+ final_rmse = compute_rmse_across_dataset(
356
+ self.gp, self.x_full, self.y_full
357
+ )
358
+
359
+ if update:
360
+ # make sure the update printing and final printing have a
361
+ # new line between them
362
+ print()
363
+
364
+ if log and log_file is not None:
365
+ self._update_log(final_rmse, log_file)
366
+
367
+ print(
368
+ "\033[4mMax points reached\033[0m:",
369
+ f"\nFinal RMSE: {final_rmse:.4f}",
370
+ f"\nPoints used: {len(self.y_train)}",
371
+ )
372
+
373
+ break
374
+
375
+ try:
376
+ points_to_add = min(batch_size, remaining_budget)
377
+ selection_function = LEARNING_STRATEGIES[learning_strategy]
378
+ selected_indices = self.select_next_point(
379
+ selection_function, points_to_add
380
+ )
381
+ self.x_train = np.vstack(
382
+ [self.x_train, self.x_full[selected_indices]]
383
+ )
384
+ self.y_train = np.append(
385
+ self.y_train, self.y_full[selected_indices]
386
+ )
387
+
388
+ self.remaining_indices = np.setdiff1d(
389
+ self.remaining_indices, selected_indices
390
+ )
391
+
392
+ except ValueError as exc:
393
+ # usually due to running out of points
394
+ warning_msg = f"Warning: Learning stopped early: {exc!s}"
395
+ warnings.warn(warning_msg)
396
+
397
+ break
398
+
399
+ return