gplite 2.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gplite/ActiveLearning/README.md +158 -0
- gplite/ActiveLearning/__init__.py +3 -0
- gplite/ActiveLearning/active_learning.py +399 -0
- gplite/ActiveLearning/selection_functions.py +184 -0
- gplite/GaussianProcess/README.md +133 -0
- gplite/GaussianProcess/__init__.py +3 -0
- gplite/GaussianProcess/gaussian_process.py +366 -0
- gplite/Kernels/README.md +176 -0
- gplite/Kernels/__init__.py +11 -0
- gplite/Kernels/_base.py +345 -0
- gplite/Kernels/_composite.py +427 -0
- gplite/Kernels/constant.py +181 -0
- gplite/Kernels/matern.py +381 -0
- gplite/Kernels/periodic.py +364 -0
- gplite/Kernels/rbf.py +320 -0
- gplite/Optimization/README.md +119 -0
- gplite/Optimization/__init__.py +0 -0
- gplite/Optimization/active_learning/__init__.py +0 -0
- gplite/Optimization/active_learning/loss_functions.py +58 -0
- gplite/Optimization/active_learning/optimization.py +189 -0
- gplite/Optimization/gaussian_process/__init__.py +0 -0
- gplite/Optimization/gaussian_process/loss_functions.py +108 -0
- gplite/Optimization/gaussian_process/optimization.py +193 -0
- gplite/__init__.py +17 -0
- gplite/_utils/__init__.py +0 -0
- gplite/_utils/_computation.py +129 -0
- gplite/_utils/_constants.py +10 -0
- gplite/_utils/_data.py +134 -0
- gplite/_utils/_errors.py +7 -0
- gplite/_utils/_types.py +16 -0
- gplite/_utils/_validation.py +360 -0
- gplite/py.typed +0 -0
- gplite-2.1.2.dist-info/METADATA +822 -0
- gplite-2.1.2.dist-info/RECORD +36 -0
- gplite-2.1.2.dist-info/WHEEL +4 -0
- gplite-2.1.2.dist-info/licenses/LICENSE +674 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
# ActiveLearning Module
|
|
2
|
+
|
|
3
|
+
Intelligent data sampling for efficient model training.
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
Pool-based active learning reduces the amount of labeled data needed by strategically selecting the most informative points from an unlabeled pool.
|
|
8
|
+
|
|
9
|
+
### The Active Learning Loop
|
|
10
|
+
|
|
11
|
+
```
|
|
12
|
+
1. Train GP on current labeled set
|
|
13
|
+
2. Compute acquisition scores for unlabeled points
|
|
14
|
+
3. Select and label highest-scoring point(s)
|
|
15
|
+
4. Repeat until stopping criterion met
|
|
16
|
+
```
|
|
17
|
+
|
|
18
|
+
## Usage
|
|
19
|
+
|
|
20
|
+
### Basic Active Learning
|
|
21
|
+
|
|
22
|
+
```python
|
|
23
|
+
from gplite import ActiveLearner, PeriodicKernel
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
# full dataset
|
|
27
|
+
X_full = np.linspace(0, 2*np.pi, 200).reshape(-1, 1)
|
|
28
|
+
y_full = np.sin(X_full).ravel()
|
|
29
|
+
|
|
30
|
+
# create active learner
|
|
31
|
+
kernel = PeriodicKernel(length_scale=1.0, period=2*np.pi)
|
|
32
|
+
learner = ActiveLearner(
|
|
33
|
+
kernel=kernel,
|
|
34
|
+
x_full=X_full,
|
|
35
|
+
y_full=y_full,
|
|
36
|
+
rmse_threshold=0.05, # stop when RMSE < 0.05
|
|
37
|
+
max_points=50, # or when 50 points used
|
|
38
|
+
optimize_interval=10 # re-optimize every 10 iterations
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
# run active learning
|
|
42
|
+
learner.learn(
|
|
43
|
+
learning_strategy="uncertainty",
|
|
44
|
+
batch_size=1,
|
|
45
|
+
final_optimization_method="rmse",
|
|
46
|
+
update=True, # print progress
|
|
47
|
+
update_interval=10 # print progress every 10 iterations
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
# access trained model
|
|
51
|
+
predictions = learner.gp.predict(X_full)
|
|
52
|
+
print(f"Points used: {len(learner.y_train)}")
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Selection Strategies
|
|
56
|
+
|
|
57
|
+
### Uncertainty Sampling (`"uncertainty"`)
|
|
58
|
+
|
|
59
|
+
Selects points where the model is most uncertain (highest predictive variance).
|
|
60
|
+
|
|
61
|
+
```python
|
|
62
|
+
learner.learn(learning_strategy="uncertainty")
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
**Best for:** Exploration, when you want broad coverage of the input space.
|
|
66
|
+
|
|
67
|
+
### Maximum Absolute Error (`"mae"`)
|
|
68
|
+
|
|
69
|
+
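Selects points where the model's absolute error |y - μ(x)| is largest. Because this compares predictions against the pool's target values, those values must be known in advance.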
|
|
70
|
+
|
|
71
|
+
```python
|
|
72
|
+
learner.learn(learning_strategy="mae")
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
**Best for:** Exploitation, when you want to fix specific problem areas.
|
|
76
|
+
|
|
77
|
+
### Expected Improvement — Maximize (`"ei_max"`)
|
|
78
|
+
|
|
79
|
+
Selects points with the highest expected improvement over the current best (maximum) observed value. Balances exploitation (high predicted mean) with exploration (high uncertainty).
|
|
80
|
+
|
|
81
|
+
```
|
|
82
|
+
EI(x) = (μ(x) - f_best) * Φ(Z) + σ(x) * φ(Z)
|
|
83
|
+
where Z = (μ(x) - f_best) / σ(x), f_best = max(y_train)
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
```python
|
|
87
|
+
learner.learn(learning_strategy="ei_max")
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
**Best for:** Bayesian optimization when searching for the maximum of a function.
|
|
91
|
+
|
|
92
|
+
### Expected Improvement — Minimize (`"ei_min"`)
|
|
93
|
+
|
|
94
|
+
Selects points with the highest expected improvement below the current best (minimum) observed value.
|
|
95
|
+
|
|
96
|
+
```
|
|
97
|
+
EI(x) = (f_best - μ(x)) * Φ(Z) + σ(x) * φ(Z)
|
|
98
|
+
where Z = (f_best - μ(x)) / σ(x), f_best = min(y_train)
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
```python
|
|
102
|
+
learner.learn(learning_strategy="ei_min")
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
**Best for:** Bayesian optimization when searching for the minimum of a function (e.g., energy minimization).
|
|
106
|
+
|
|
107
|
+
### Random (`"random"`)
|
|
108
|
+
|
|
109
|
+
Baseline strategy with uniform random selection.
|
|
110
|
+
|
|
111
|
+
```python
|
|
112
|
+
learner.learn(learning_strategy="random")
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
**Best for:** Comparison baseline, or when domain knowledge suggests uniform sampling.
|
|
116
|
+
|
|
117
|
+
## Class Reference
|
|
118
|
+
|
|
119
|
+
### `ActiveLearner(kernel, x_full, y_full, ...)`
|
|
120
|
+
|
|
121
|
+
**Parameters:**
|
|
122
|
+
- `kernel` (Kernel): Kernel for the internal GP model
|
|
123
|
+
- `x_full` (array): Complete pool of input features
|
|
124
|
+
- `y_full` (array): Complete pool of target values
|
|
125
|
+
- `max_points` (int, optional): Maximum number of training points to use. Defaults to the full dataset size if not passed.
|
|
126
|
+
- `rmse_threshold` (float): Target RMSE for stopping. Default: `0.5`
|
|
127
|
+
- `optimize_interval` (int, optional): Iterations between hyperparameter optimization. Defaults to 1 (optimize every iteration); pass `None` to disable periodic optimization.
|
|
128
|
+
|
|
129
|
+
**Methods:**
|
|
130
|
+
- `learn(learning_strategy, batch_size=1, ...)`: Run the active learning loop
|
|
131
|
+
- `select_next_point(selection_function, n_points=1)`: Select next point(s) to add
|
|
132
|
+
|
|
133
|
+
**Attributes:**
|
|
134
|
+
- `gp`: The underlying GaussianProcess model
|
|
135
|
+
- `x_train`, `y_train`: Current training data
|
|
136
|
+
- `x_full`, `y_full`: Complete data pool
|
|
137
|
+
- `remaining_indices`: Indices of unlabeled points
|
|
138
|
+
|
|
139
|
+
## Stopping Criteria
|
|
140
|
+
|
|
141
|
+
The learning loop stops when any of these conditions is met:
|
|
142
|
+
|
|
143
|
+
1. **RMSE threshold reached**: Model achieves target accuracy
|
|
144
|
+
2. **Max points reached**: Budget exhausted
|
|
145
|
+
3. **Pool exhausted**: All points have been labeled
|
|
146
|
+
|
|
147
|
+
## Batch Active Learning
|
|
148
|
+
|
|
149
|
+
Select multiple points per iteration:
|
|
150
|
+
|
|
151
|
+
```python
|
|
152
|
+
learner.learn(
|
|
153
|
+
learning_strategy="uncertainty",
|
|
154
|
+
batch_size=5 # add 5 points per iteration
|
|
155
|
+
)
|
|
156
|
+
```
|
|
157
|
+
|
|
158
|
+
Useful when each labeling round has a high fixed cost but each additional point in the round has a low marginal cost.
|
|
@@ -0,0 +1,399 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Active learning implementation for intelligent data sampling with Gaussian
|
|
3
|
+
Process models.
|
|
4
|
+
|
|
5
|
+
Active learning aims to achieve high model accuracy with minimal labeled data
|
|
6
|
+
by strategically selecting the most informative points from an unlabeled pool.
|
|
7
|
+
|
|
8
|
+
Common acquisition strategies:
|
|
9
|
+
- Uncertainty Sampling: Select points where σ²(x) is highest, exploring
|
|
10
|
+
regions where the model is least confident.
|
|
11
|
+
- Expected Improvement: Select points that maximize the expected improvement
|
|
12
|
+
over the current best observation (separate variants for maximization and
|
|
13
|
+
minimization objectives).
|
|
14
|
+
- Error-based: Select points with highest |y - μ(x)|, focusing on regions
|
|
15
|
+
where the model performs worst.
|
|
16
|
+
- Random: Baseline strategy with uniform random selection.
|
|
17
|
+
|
|
18
|
+
The learning loop:
|
|
19
|
+
1. Train GP on current labeled set
|
|
20
|
+
2. Compute acquisition scores for unlabeled points
|
|
21
|
+
3. Select and label highest-scoring point(s)
|
|
22
|
+
4. Repeat until stopping criterion (RMSE threshold, budget, etc.)
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import csv
|
|
26
|
+
import warnings
|
|
27
|
+
from collections.abc import Callable
|
|
28
|
+
from pathlib import Path
|
|
29
|
+
|
|
30
|
+
import numpy as np
|
|
31
|
+
from gplite._utils._computation import compute_rmse_across_dataset
|
|
32
|
+
from gplite._utils._errors import ValidationError
|
|
33
|
+
from gplite._utils._types import Arrf64, Arri64, f64
|
|
34
|
+
from gplite._utils._validation import (
|
|
35
|
+
validate_input_and_target_data,
|
|
36
|
+
validate_numeric_value,
|
|
37
|
+
)
|
|
38
|
+
from gplite.ActiveLearning.selection_functions import (
|
|
39
|
+
expected_improvement_max,
|
|
40
|
+
expected_improvement_min,
|
|
41
|
+
max_absolute_error,
|
|
42
|
+
max_uncertainty,
|
|
43
|
+
random_selection,
|
|
44
|
+
)
|
|
45
|
+
from gplite.GaussianProcess.gaussian_process import GaussianProcess
|
|
46
|
+
from gplite.Kernels._base import Kernel
|
|
47
|
+
from gplite.Optimization.active_learning.optimization import (
|
|
48
|
+
optimize_hyperparameters,
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
# Registry of point-selection strategies accepted by ActiveLearner.learn().
# Each selection function is listed with every name it may be requested by:
# a short alias first, then its long descriptive name. Insertion order is
# preserved, so the keys appear in exactly this order in error messages.
LEARNING_STRATEGIES: dict[str, Callable] = {
    alias: strategy
    for strategy, aliases in (
        (random_selection, ("random", "random_choice")),
        (max_uncertainty, ("uncertainty", "max_uncertainty")),
        (max_absolute_error, ("mae", "max_absolute_error")),
        (expected_improvement_max, ("ei_max", "expected_improvement_max")),
        (expected_improvement_min, ("ei_min", "expected_improvement_min")),
    )
    for alias in aliases
}
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class ActiveLearner:
    """
    Active learning system that intelligently selects training points to
    minimize required data while maintaining model accuracy.

    Uses a Gaussian Process model to identify the most informative points
    from a pool of unlabeled data based on various selection strategies.

    Attributes:
        - gp (GaussianProcess): The underlying Gaussian Process model.
        - x_full (Arrf64): Complete pool of input features.
        - y_full (Arrf64): Complete pool of target values.
        - x_train (Arrf64): Current training input features.
        - y_train (Arrf64): Current training target values.
        - remaining_indices (Arri64): Indices of points not yet in training set.
    """

    def __init__(
        self,
        kernel: Kernel,
        x_full: Arrf64,
        y_full: Arrf64,
        max_points: int | None = None,
        rmse_threshold: f64 = np.float64(0.5),
        optimize_interval: int | None = 1,
    ) -> None:
        """
        Initializes an active learner with the given kernel and data pool.

        Args:
            - kernel (Kernel): Kernel instance for the underlying GP model.
            - x_full (Arrf64): Full dataset input features of shape (n, d).
            - y_full (Arrf64): Full dataset target values of shape (n,).
            - max_points (int | None): Maximum training points to use. A
                falsy value (None, or 0) falls back to the full dataset size.
            - rmse_threshold (f64): RMSE target for stopping criterion.
                Defaults to 0.5.
            - optimize_interval (int | None): Iterations between
                hyperparameter optimization. A falsy value (None, or 0)
                disables periodic optimization. Defaults to 1.

        Raises:
            ValidationError: If kernel is invalid or data arrays are
                incompatible.
        """
        if not isinstance(kernel, Kernel):
            err_msg = "Error: 'kernel' must be a valid Kernel instance"
            raise ValidationError(err_msg)

        self.x_full, self.y_full = validate_input_and_target_data(
            x_full, y_full
        )

        self.kernel = kernel
        self.gp = GaussianProcess(self.kernel)

        self.rmse_threshold = validate_numeric_value(
            rmse_threshold, "Active Learner RMSE Threshold", False
        )

        # any falsy value (None and also 0) means "use the whole pool"
        if max_points:
            self.max_points = int(
                validate_numeric_value(
                    max_points, "Active Learner Max Points", False
                )
            )
        else:
            self.max_points = len(self.y_full)

        # any falsy value (None and also 0) disables periodic optimization
        self.optimize_interval: int | None
        if optimize_interval:
            self.optimize_interval = int(
                validate_numeric_value(
                    optimize_interval, "Active Learner Optimize Interval", False
                )
            )
        else:
            self.optimize_interval = None

        # initialize training sets and the pool of points that remain
        # available to be picked; the pool is an integer index array even
        # when empty so it can always be used for indexing
        self.x_train = np.array([])
        self.y_train = np.array([])
        self.remaining_indices = np.array([], dtype=np.int64)

        self._initialize_training_data()

    def _initialize_training_data(self) -> None:
        """
        Initializes the training set with three strategically selected points:
        first, middle, and last from the dataset. Sets up the remaining
        indices pool for active selection.

        Warns:
            UserWarning: If dataset has fewer than 3 samples (the full
                dataset is then used for training and the pool is empty).
        """
        num_samples = self.x_full.shape[0]

        if num_samples < 3:
            warning_msg = (
                "Warning: Active Learning data has < 3 samples. Using full "
                "dataset for training."
            )
            # stacklevel=3 attributes the warning to the ActiveLearner(...)
            # call site rather than this private helper
            warnings.warn(warning_msg, stacklevel=3)

            self.x_train = self.x_full
            self.y_train = self.y_full

            return

        # first, middle, and last points of the dataset (distinct for n >= 3)
        initial_indices = [0, num_samples // 2, num_samples - 1]

        self.x_train = self.x_full[initial_indices]
        self.y_train = self.y_full[initial_indices]

        # remove indices from training pool
        self.remaining_indices = np.setdiff1d(
            np.arange(num_samples), initial_indices
        )

    def select_next_point(
        self, selection_function: Callable, n_points: int = 1
    ) -> Arri64:
        """
        Selects the next point(s) to add to the training set using the given
        selection strategy.

        Args:
            - selection_function (Callable): Function that takes learner and
                n_points and returns indices.
            - n_points (int): Number of points to select. Defaults to 1.

        Returns:
            Arri64: Indices of selected points from the full dataset.
        """
        return selection_function(self, n_points)

    def _update_log(self, rmse: float, log_file: Path) -> None:
        """
        Private method to append one row (number of training points, rmse)
        to the CSV log file of the learning loop.
        """
        with log_file.open("a", encoding="utf-8", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([self.x_train.shape[0], rmse])

    def _finish_learning(
        self,
        reason: str,
        final_optimization_method: str,
        update: bool,
        log_file: Path | None,
    ) -> None:
        """
        Private method run once when a stopping criterion is met: performs the
        final hyperparameter optimization, logs the final RMSE (if logging is
        enabled) and prints a summary headed by ``reason``.
        """
        optimize_hyperparameters(self, final_optimization_method)
        final_rmse = compute_rmse_across_dataset(
            self.gp, self.x_full, self.y_full
        )

        if update:
            # make sure the update printing and final printing have a
            # new line between them
            print()

        if log_file is not None:
            self._update_log(final_rmse, log_file)

        print(
            f"\033[4m{reason}\033[0m:",
            f"\nFinal RMSE: {final_rmse:.4f}",
            f"\nPoints used: {len(self.y_train)}",
        )

    def learn(
        self,
        learning_strategy: str,
        batch_size: int = 1,
        final_optimization_method: str = "rmse",
        update: bool = False,
        log: bool = False,
        update_interval: int = 10,
        log_update_interval: int = 5,
    ) -> None:
        """
        Executes the active learning loop, iteratively selecting and adding
        points until a stopping criterion is met.

        Stopping criteria include: reaching RMSE threshold, exhausting all
        points, or reaching max_points limit. When one of them is met, a
        final hyperparameter optimization is run and a summary is printed.

        Args:
            - learning_strategy (str): Point selection strategy; any key of
                LEARNING_STRATEGIES: 'random' / 'random_choice',
                'uncertainty' / 'max_uncertainty', 'mae' /
                'max_absolute_error', 'ei_max' / 'expected_improvement_max',
                'ei_min' / 'expected_improvement_min'.
            - batch_size (int): Points to add per iteration. Defaults to 1.
                Each batch is capped by the remaining budget and by the
                number of points left in the pool.
            - final_optimization_method (str): Objective for final optimization.
                Defaults to 'rmse'.
            - update (bool): Whether to print progress updates. Defaults to
                False.
            - log (bool): Whether to log status updates. Works like the
                "update" parameter, but updates are written to the CSV file
                ./active_learning_<learning_strategy>.csv rather than stdout.
                Defaults to False.
            - update_interval (int): Iterations between progress updates.
                Defaults to 10.
            - log_update_interval (int): Iterations between log progress updates.
                Defaults to 5.

        Raises:
            ValidationError: If learning_strategy is not recognized or a
                numeric argument is invalid.

        Warns:
            UserWarning: If learning stops early because point selection
                raised a ValueError.
        """
        batch_size = int(
            validate_numeric_value(
                batch_size, "Number of Learning Points", allow_nonpositive=False
            )
        )
        update_interval = int(
            validate_numeric_value(
                update_interval, "Update Interval", allow_nonpositive=False
            )
        )
        log_update_interval = int(
            validate_numeric_value(
                log_update_interval,
                "Log Update Interval",
                allow_nonpositive=False,
            )
        )

        if learning_strategy not in LEARNING_STRATEGIES:
            err_msg = (
                f"Error: {learning_strategy} is not a valid learning strategy. "
                f"Valid strategies include: {list(LEARNING_STRATEGIES.keys())}"
            )
            raise ValidationError(err_msg)

        # loop-invariant: resolve the strategy once
        selection_function = LEARNING_STRATEGIES[learning_strategy]

        # log_file is non-None exactly when logging is enabled
        log_file: Path | None = None
        if log:
            log_file = Path(
                f"./active_learning_{learning_strategy}.csv"
            ).resolve()

            # newline="" is required by the csv module (and matches
            # _update_log); without it Windows gets blank lines between rows
            with log_file.open("w", encoding="utf-8", newline="") as f:
                writer = csv.writer(f)
                writer.writerow(["num_points_used", "rmse"])

        for iteration in range(self.max_points):
            should_optimize = (
                self.optimize_interval is not None
                and iteration % self.optimize_interval == 0
            )
            should_update = update and (iteration + 1) % update_interval == 0
            should_log = log and (iteration + 1) % log_update_interval == 0

            # step 1: fit model to current training data, optimize with lml
            self.gp.fit(self.x_train, self.y_train, optimize=should_optimize)

            # step 2: compute rmse and check if a stopping criterion is met
            current_rmse = compute_rmse_across_dataset(
                self.gp, self.x_full, self.y_full
            )

            if should_update:
                print(f"Iteration {iteration + 1}: RMSE: {current_rmse}")

            if should_log and log_file is not None:
                self._update_log(current_rmse, log_file)

            if current_rmse <= self.rmse_threshold:
                self._finish_learning(
                    "RMSE threshold reached",
                    final_optimization_method,
                    update,
                    log_file,
                )
                break

            if len(self.remaining_indices) == 0:
                self._finish_learning(
                    "All points used",
                    final_optimization_method,
                    update,
                    log_file,
                )
                break

            remaining_budget = self.max_points - len(self.y_train)
            if remaining_budget <= 0:
                self._finish_learning(
                    "Max points reached",
                    final_optimization_method,
                    update,
                    log_file,
                )
                break

            # step 3: select and add the next batch. Never request more points
            # than remain in the pool: a selection function asked for more
            # than that may raise, which would end learning without the final
            # optimization; with the cap, the next iteration reports
            # "All points used" instead.
            points_to_add = min(
                batch_size, remaining_budget, len(self.remaining_indices)
            )

            try:
                selected_indices = self.select_next_point(
                    selection_function, points_to_add
                )
                self.x_train = np.vstack(
                    [self.x_train, self.x_full[selected_indices]]
                )
                self.y_train = np.append(
                    self.y_train, self.y_full[selected_indices]
                )

                self.remaining_indices = np.setdiff1d(
                    self.remaining_indices, selected_indices
                )

            except ValueError as exc:
                # e.g. a selection function failing on the remaining pool
                warning_msg = f"Warning: Learning stopped early: {exc!s}"
                warnings.warn(warning_msg, stacklevel=2)

                break
|