discontinuum-1.0.4-py3-none-any.whl → discontinuum-1.0.5-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their public registries. It is provided for informational purposes only.
- discontinuum/_version.py +2 -2
- discontinuum/engines/gpytorch.py +117 -107
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/METADATA +1 -1
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/RECORD +8 -8
- rating_gp/models/gpytorch.py +9 -14
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/WHEEL +0 -0
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/licenses/LICENSE.md +0 -0
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/top_level.txt +0 -0
discontinuum/_version.py
CHANGED
discontinuum/engines/gpytorch.py
CHANGED
@@ -49,10 +49,11 @@ class MarginalGPyTorch(BaseModel):
         target: Dataset,
         target_unc: Dataset = None,
         iterations: int = 100,
-        optimizer: str = "
+        optimizer: str = "adamw",
         learning_rate: float = None,
         early_stopping: bool = False,
-
+        patience: int = 60,
+        gradient_noise: bool = False,
     ):
         """Fit the model to data.

@@ -67,13 +68,15 @@ class MarginalGPyTorch(BaseModel):
         iterations : int, optional
             Number of iterations for optimization. The default is 100.
         optimizer : str, optional
-            Optimization method. The default is "
+            Optimization method. Supported: "adam", "adamw". The default is "adamw".
         learning_rate : float, optional
             Learning rate for optimization. If None, uses adaptive defaults.
         early_stopping : bool, optional
             Whether to use early stopping. The default is False.
-
-            Number of iterations to wait without improvement before stopping. The default is
+        patience : int, optional
+            Number of iterations to wait without improvement before stopping. The default is 60.
+        gradient_noise : bool, optional
+            Whether to inject Gaussian noise into gradients each step (std = 0.1 × current learning rate). The default is False.
         """
         self.is_fitted = True
         # setup data manager (self.dm)
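For orientation, the updated signature could be called as sketched below. The `model`, `covariates`, and `target` names are illustrative placeholders for a MarginalGPyTorch subclass instance and its input Datasets; only the keyword arguments come from the diff above.

    # Hypothetical call using the new fit() signature (object names are placeholders)
    model.fit(
        covariates,
        target,
        iterations=500,
        optimizer="adamw",       # new default
        learning_rate=None,      # None selects the adaptive default (0.1)
        early_stopping=True,
        patience=60,             # new parameter
        gradient_noise=False,    # new parameter
    )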
@@ -95,31 +98,39 @@ class MarginalGPyTorch(BaseModel):
         self.model.train()
         self.likelihood.train()

-        # Adaptive learning rate selection for faster convergence
         if learning_rate is None:
             if optimizer == "adam":
-                learning_rate = 0.1 #
-            elif optimizer == "
-                learning_rate = 1
+                learning_rate = 0.1 # Aggressive default for faster convergence
+            elif optimizer == "adamw":
+                learning_rate = 0.1

- … (11 removed lines omitted)
+        if optimizer == "adamw":
+            optimizer_obj = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=learning_rate,
+                betas=(0.9, 0.999),
+                eps=1e-8,
+                weight_decay=1e-2 # Stronger regularization for AdamW
+            )
+        elif optimizer == "adam":
+            optimizer_obj = torch.optim.Adam(
+                self.model.parameters(),
+                lr=learning_rate,
+                betas=(0.9, 0.999),
+                eps=1e-8,
+                weight_decay=1e-4 # Lighter regularization for Adam
+            )
+        else:
+            raise NotImplementedError(f"Only 'adam' and 'adamw' optimizers are supported. Got '{optimizer}'.")
+
+        # Use ReduceLROnPlateau for more stable learning rate adaptation
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-
-            mode='min',
-            factor=0.
-            patience=
-
-
+            optimizer_obj,
+            mode='min',
+            factor=0.5, # Reduce LR by half
+            patience=max(2, patience),
+            threshold=1e-4,
+            min_lr=1e-5
         )

         # "Loss" for GPs - the marginal log likelihood
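The block above pairs AdamW (or Adam) with a ReduceLROnPlateau scheduler. The snippet below is a minimal, self-contained sketch of that same PyTorch pattern on a toy linear model; it is not code from this package.

    import torch

    # Toy model standing in for the GP; only the optimizer/scheduler wiring matters here.
    model = torch.nn.Linear(3, 1)
    optimizer_obj = torch.optim.AdamW(
        model.parameters(), lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_obj, mode='min', factor=0.5, patience=5, threshold=1e-4, min_lr=1e-5
    )

    x, y = torch.randn(32, 3), torch.randn(32, 1)
    for step in range(100):
        optimizer_obj.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer_obj.step()
        scheduler.step(loss.item())  # halves the learning rate when the loss plateaus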
@@ -133,98 +144,97 @@ class MarginalGPyTorch(BaseModel):
         min_lr_for_early_stop = 2e-5 # Stop if patience is exceeded and LR is below this

         for i in pbar:
- … (7 removed lines omitted)
-            loss
- … (26 removed lines omitted)
-                    f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
-                )
-                continue
-
-            loss.backward()
-
-            # Gradient clipping for stability
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
-            # Check for NaN gradients
-            has_nan_grad = False
+            # Adam/AdamW optimizer with stability features
+            optimizer_obj.zero_grad()
+            output = self.model(train_x)
+
+            # Attempt loss calculation with dynamic jitter
+            try:
+                with gpytorch.settings.cholesky_jitter(jitter):
+                    loss = -mll(output, train_y)
+            except Exception as e:
+                # Increase jitter if numerical issues occur
+                jitter = min(jitter * 10, 1e-2)
+                current_lr = optimizer_obj.param_groups[0]['lr']
+                pbar.set_postfix_str(
+                    f'lr={current_lr:.1e} jitter={jitter:.1e} | Numerical issue - increasing jitter'
+                )
+                continue
+
+            # Check for NaN loss
+            if torch.isnan(loss) or torch.isinf(loss):
+                current_lr = optimizer_obj.param_groups[0]['lr']
+                pbar.set_postfix_str(
+                    f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
+                )
+                continue
+
+            loss.backward()
+
+            # Get current learning rate before gradient noise injection
+            current_lr = optimizer_obj.param_groups[0]['lr']
+
+            # Gradient noise injection (if enabled)
+            if gradient_noise:
+                gradient_noise_scale = 0.1
+                adaptive_noise = gradient_noise_scale * current_lr
                 for param in self.model.parameters():
-                    if param.grad is not None
- … (19 removed lines omitted)
-                continue
-
-            optimizer.step()
-
-            # Update learning rate scheduler for Adam
-            scheduler.step(loss)
-            current_lr = optimizer.param_groups[0]['lr']
-
-            # Early stopping check (more aggressive)
+                    if param.grad is not None:
+                        noise = torch.normal(mean=0.0, std=adaptive_noise, size=param.grad.shape, device=param.grad.device)
+                        param.grad.add_(noise)
+
+            # Gradient clipping for stability
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+            # Check for NaN gradients
+            has_nan_grad = False
+            for param in self.model.parameters():
+                if param.grad is not None and torch.isnan(param.grad).any():
+                    has_nan_grad = True
+                    break
+
+            if has_nan_grad:
+                # Don't update scheduler on NaN gradients - this prevents rapid LR decay
+                # The scheduler should only respond to actual optimization progress
+                current_lr = optimizer_obj.param_groups[0]['lr']
+
+                # Update best loss tracking (loss is still valid, just gradients are NaN)
                 if loss.item() < best_loss:
                     best_loss = loss.item()
                     patience_counter = 0
                 else:
                     patience_counter += 1

-            # Display
+                # Display comprehensive info even with NaN gradients, skip normal progress update
+                pbar.set_postfix_str(
+                    f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f} | NaN gradients - skipping step'
+                )
+                continue
+
+            optimizer_obj.step()
+
+            # Update learning rate scheduler for Adam/AdamW
+            scheduler.step(loss.item())
+            current_lr = optimizer_obj.param_groups[0]['lr']
+
+            # Early stopping check (more aggressive)
+            if loss.item() < best_loss:
+                best_loss = loss.item()
+                patience_counter = 0
+            else:
+                patience_counter += 1
+
+            # Only update progress bar if not skipped above
+            if not has_nan_grad:
                 progress_info = f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f}'
                 if early_stopping:
-                    progress_info += f' patience={patience_counter}/
+                    progress_info += f' patience={patience_counter}/{patience}'
                 pbar.set_postfix_str(progress_info)

- … (4 removed lines omitted)
-                break
+            if early_stopping and patience_counter >= patience and current_lr <= min_lr_for_early_stop:
+                print(f"\nEarly stopping triggered after {i+1} iterations")
+                print(f"Best loss: {best_loss:.6f}")
+                break

     @is_fitted
     def predict(self,
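The rewritten loop adds three stability features before each optimizer step: optional gradient noise (zero-mean Gaussian noise with std equal to 0.1 times the current learning rate), gradient clipping, and a NaN-gradient guard. Below is a stand-alone sketch of that sequence on a toy model, not this package's GP; it uses torch.randn_like to draw the noise, which is equivalent to the torch.normal call in the diff.

    import torch

    model = torch.nn.Linear(3, 1)   # toy stand-in for the GP model
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
    x, y = torch.randn(16, 3), torch.randn(16, 1)

    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()

    # Gradient noise injection: std = 0.1 * current learning rate
    noise_std = 0.1 * optimizer.param_groups[0]['lr']
    for param in model.parameters():
        if param.grad is not None:
            param.grad.add_(torch.randn_like(param.grad) * noise_std)

    # Clip gradients, then skip the update if any gradient became NaN
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    has_nan_grad = any(p.grad is not None and torch.isnan(p.grad).any() for p in model.parameters())
    if not has_nan_grad:
        optimizer.step()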
{discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/RECORD
CHANGED
@@ -1,17 +1,17 @@
 discontinuum/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-discontinuum/_version.py,sha256=
+discontinuum/_version.py,sha256=cafGA7j4exK_daS29O0wW2JM2p5b8FwYDLqAYJR0jWo,511
 discontinuum/data_manager.py,sha256=LiZoPR0nnu7YAUfh5L1ZDRfaS3dgfVIELXIHkzUKyBg,4416
 discontinuum/pipeline.py,sha256=1avuZnFai-b3HmihcpZ8M3WFNQ8lXAFSNTrnfl2NrY0,10074
 discontinuum/plot.py,sha256=eZQS6-Ydq8FFcEukPtNuDVB-weV6lHyWMyJ1hqTkVrU,2969
 discontinuum/utils.py,sha256=07hIHQk_oDlkjz7tasgBjqqPOC6D0iNcy0eu-88aNbM,1540
 discontinuum/engines/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 discontinuum/engines/base.py,sha256=OlHd4ssIQoWvYHKoVqk5fKAVBcKsIIkR4ul9iNBvaYg,2396
-discontinuum/engines/gpytorch.py,sha256=
+discontinuum/engines/gpytorch.py,sha256=kRyAgCfxjKZbAJhJGViaDU_y8NO8sW4rSWRyEQlomHo,14383
 discontinuum/engines/pymc.py,sha256=phbtE-3UCSVcP1MhbXwAHIWDZWDr56wK9U7aRt-w-2o,5961
 discontinuum/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 discontinuum/providers/base.py,sha256=Yn2EHS1b4fYl09-m2MYuf2P9VRUXAP-WDpSoZrCbRvY,720
 discontinuum/tests/test_pipeline.py,sha256=_FhkGxbFIxNb35lGaIdZk7Zjgs6CkxEF3gFUX3PE8EU,918
-discontinuum-1.0.
+discontinuum-1.0.5.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
 loadest_gp/__init__.py,sha256=YISfvbc7Zy2y0BOxS1A2KzqxyoNJTz0EnLMnRW6iVT8,740
 loadest_gp/plot.py,sha256=x2PK7vBCc44dX9lu5YV-rvw1u4pvXSLdcrTSvYLiHMA,2595
 loadest_gp/utils.py,sha256=m5QaqR_0JiuRXPfryH8nI5lODp8PqvQla5C05WDN3LY,2772
@@ -25,11 +25,11 @@ rating_gp/pipeline.py,sha256=1HgxN6DD3ZL5lhUb3DK2in2IXiml7W4Ja272GBMTc08,1884
 rating_gp/plot.py,sha256=CJphwqWWAfIY22j5Oz5DRwj7TcQCRyIQvM79_3KEdlc,9635
 rating_gp/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rating_gp/models/base.py,sha256=e2Kq644I88YLHWPNA0qyRgitF5wimdLW4618vKX-o_s,1474
-rating_gp/models/gpytorch.py,sha256=
+rating_gp/models/gpytorch.py,sha256=eFHwtnW44GZ1zz0fLx5REbzIWwnb_x_uq-cGjDcHyWs,6907
 rating_gp/models/kernels.py,sha256=3xg2mhY3aEgjI3r5vyAll9MA4c3M5UKqRi3FApNhJJQ,11579
 rating_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rating_gp/providers/usgs.py,sha256=KmKYN3c8Mi-ly2l6X80WT3taEhqCPXeEcRNi9HvbJmY,8134
-discontinuum-1.0.
-discontinuum-1.0.
-discontinuum-1.0.
-discontinuum-1.0.
+discontinuum-1.0.5.dist-info/METADATA,sha256=GDh-fscmNYYMmXXoUE5Xd1QafuKEOPjgZjlssmjqGVg,6302
+discontinuum-1.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+discontinuum-1.0.5.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
+discontinuum-1.0.5.dist-info/RECORD,,
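Each RECORD entry above has the form path,hash,size, where the hash is the urlsafe base64 encoding of the file's SHA-256 digest with '=' padding stripped (the wheel RECORD convention). A small sketch for checking an entry; the file path is illustrative:

    import base64
    import hashlib

    def record_hash(path: str) -> str:
        # RECORD-style hash: urlsafe base64 of the SHA-256 digest, padding stripped.
        with open(path, "rb") as f:
            digest = hashlib.sha256(f.read()).digest()
        return "sha256=" + base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")

    # e.g. record_hash("discontinuum/_version.py") should reproduce the entry above,
    # "sha256=cafGA7j4exK_daS29O0wW2JM2p5b8FwYDLqAYJR0jWo"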
rating_gp/models/gpytorch.py
CHANGED
@@ -110,13 +110,13 @@ class ExactGPModel(gpytorch.models.ExactGP):
         # + stage * time kernel only at low stage with smaller time length.
         # Note that stage gets transformed to q, so the kernel is actually
         # q * time
-        b_min = np.quantile(train_y, 0.
+        b_min = np.quantile(train_y, 0.10)
         b_max = np.quantile(train_y, 0.90)
         self.covar_module = (
-            (self.cov_stage(ls_prior=GammaPrior(concentration=
+            (self.cov_stage(ls_prior=GammaPrior(concentration=1, rate=1))
             * self.cov_time(ls_prior=GammaPrior(concentration=1, rate=1)))
-            + (self.cov_stage(ls_prior=GammaPrior(concentration=
-            * self.cov_time(ls_prior=GammaPrior(concentration=
+            + (self.cov_stage(ls_prior=GammaPrior(concentration=3, rate=1))
+            * self.cov_time(ls_prior=GammaPrior(concentration=2, rate=5))
             * SigmoidKernel(
                 active_dims=self.stage_dim,
                 # a_prior=NormalPrior(loc=20, scale=1),
@@ -167,20 +167,15 @@ class ExactGPModel(gpytorch.models.ExactGP):
             ),
             outputscale_prior=eta,
         )
-
-        # Periodic
-        # Locally periodic kernel: Periodic * Matern
+
+        # Periodic performs beter than a locally periodic kernel
         periodic_kernel = ScaleKernel(
             gpytorch.kernels.PeriodicKernel(
                 active_dims=self.time_dim,
-                period_length_prior=NormalPrior(loc=1.0, scale=0.
-                lengthscale_prior=GammaPrior(concentration=
-            ) * MaternKernel(
-                active_dims=self.time_dim,
-                nu=2.5,
-                lengthscale_prior=GammaPrior(concentration=4, rate=3),
+                period_length_prior=NormalPrior(loc=1.0, scale=0.1), # ~1 year
+                lengthscale_prior=GammaPrior(concentration=2, rate=4),
             ),
-            outputscale_prior=HalfNormalPrior(scale=0.
+            outputscale_prior=HalfNormalPrior(scale=0.5),
         )

         return base_kernel + periodic_kernel
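The second hunk simplifies the seasonal component: the Periodic * Matern product is replaced by a plain periodic kernel, with a period length prior of Normal(1.0, 0.1) (about one year in scaled time units, per the inline comment), a lengthscale prior of Gamma(2, 4) (prior mean 2/4 = 0.5), and an outputscale prior of HalfNormal(0.5). A minimal sketch of a kernel built with the same gpytorch priors; `time_dim` is a placeholder for the model's time-dimension index:

    from gpytorch.kernels import PeriodicKernel, ScaleKernel
    from gpytorch.priors import GammaPrior, HalfNormalPrior, NormalPrior

    time_dim = (0,)  # placeholder index for the time dimension

    periodic_kernel = ScaleKernel(
        PeriodicKernel(
            active_dims=time_dim,
            period_length_prior=NormalPrior(loc=1.0, scale=0.1),
            lengthscale_prior=GammaPrior(concentration=2, rate=4),
        ),
        outputscale_prior=HalfNormalPrior(scale=0.5),
    )
    print(periodic_kernel)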
{discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/WHEEL
File without changes
{discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/licenses/LICENSE.md
File without changes
{discontinuum-1.0.4.dist-info → discontinuum-1.0.5.dist-info}/top_level.txt
File without changes