discontinuum 1.0.4__py3-none-any.whl → 1.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- discontinuum/_version.py +2 -2
- discontinuum/engines/gpytorch.py +117 -108
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.6.dist-info}/METADATA +1 -1
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.6.dist-info}/RECORD +9 -9
- rating_gp/models/gpytorch.py +72 -75
- rating_gp/plot.py +23 -0
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.6.dist-info}/WHEEL +0 -0
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.6.dist-info}/licenses/LICENSE.md +0 -0
- {discontinuum-1.0.4.dist-info → discontinuum-1.0.6.dist-info}/top_level.txt +0 -0
discontinuum/_version.py
CHANGED
discontinuum/engines/gpytorch.py
CHANGED
@@ -49,10 +49,11 @@ class MarginalGPyTorch(BaseModel):
         target: Dataset,
         target_unc: Dataset = None,
         iterations: int = 100,
-        optimizer: str = "
+        optimizer: str = "adamw",
         learning_rate: float = None,
         early_stopping: bool = False,
-
+        patience: int = 30,
+        gradient_noise: bool = False,
     ):
         """Fit the model to data.
 
@@ -67,13 +68,15 @@ class MarginalGPyTorch(BaseModel):
         iterations : int, optional
             Number of iterations for optimization. The default is 100.
         optimizer : str, optional
-            Optimization method. The default is "
+            Optimization method. Supported: "adam", "adamw". The default is "adamw".
         learning_rate : float, optional
            Learning rate for optimization. If None, uses adaptive defaults.
         early_stopping : bool, optional
            Whether to use early stopping. The default is False.
-
-            Number of iterations to wait without improvement before stopping. The default is
+        patience : int, optional
+            Number of iterations to wait without improvement before stopping. The default is 60.
+        gradient_noise : bool, optional
+            Whether to inject Gaussian noise into gradients each step (std = 0.1 × current learning rate). The default is False.
         """
         self.is_fitted = True
         # setup data manager (self.dm)
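For orientation, the new keywords are used from the caller's side roughly as follows. This is an illustrative sketch only: MyRatingModel, covariates, and target are placeholders for a concrete MarginalGPyTorch subclass and its training datasets; only the keyword names and defaults are taken from the diff above.

    # Illustrative only: MyRatingModel / covariates / target are placeholders,
    # not discontinuum API; the keywords are the ones documented above.
    model = MyRatingModel()
    model.fit(
        covariates,
        target,
        iterations=500,
        optimizer="adamw",      # new default; "adam" remains supported
        early_stopping=True,
        patience=30,            # new in 1.0.6
        gradient_noise=True,    # new in 1.0.6
    )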
@@ -95,31 +98,38 @@ class MarginalGPyTorch(BaseModel):
         self.model.train()
         self.likelihood.train()
 
-        # Adaptive learning rate selection for faster convergence
         if learning_rate is None:
-
-
-            elif optimizer == "lbfgs":
-                learning_rate = 1.0 # L-BFGS doesn't use learning rate the same way
+            # More conservative starting LR
+            learning_rate = 0.05
 
-
-
-
-
-
-
-
-
-
-
-
+        if optimizer == "adamw":
+            optimizer_obj = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=learning_rate,
+                betas=(0.9, 0.999),
+                eps=1e-8,
+                weight_decay=1e-2 # Stronger regularization for AdamW
+            )
+        elif optimizer == "adam":
+            optimizer_obj = torch.optim.Adam(
+                self.model.parameters(),
+                lr=learning_rate,
+                betas=(0.9, 0.999),
+                eps=1e-8,
+                weight_decay=1e-4 # Lighter regularization for Adam
+            )
+        else:
+            raise NotImplementedError(f"Only 'adam' and 'adamw' optimizers are supported. Got '{optimizer}'.")
+
+        # Use ReduceLROnPlateau for more stable learning rate adaptation
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-
-            mode='min',
-            factor=0.
-            patience=
-
-
+            optimizer_obj,
+            mode='min',
+            factor=0.5, # Reduce LR by half
+            patience=max(2, patience),
+            threshold=5e-1, # Aggressive plateau detection
+            #threshold_mode='rel', # Use relative threshold
+            min_lr=1e-5
         )
 
         # "Loss" for GPs - the marginal log likelihood
@@ -133,98 +143,97 @@ class MarginalGPyTorch(BaseModel):
         min_lr_for_early_stop = 2e-5 # Stop if patience is exceeded and LR is below this
 
         for i in pbar:
-
-
-
-
-
-
-
-                loss
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
-                )
-                continue
-
-            loss.backward()
-
-            # Gradient clipping for stability
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
-            # Check for NaN gradients
-            has_nan_grad = False
+            # Adam/AdamW optimizer with stability features
+            optimizer_obj.zero_grad()
+            output = self.model(train_x)
+
+            # Attempt loss calculation with dynamic jitter
+            try:
+                with gpytorch.settings.cholesky_jitter(jitter):
+                    loss = -mll(output, train_y)
+            except Exception as e:
+                # Increase jitter if numerical issues occur
+                jitter = min(jitter * 10, 1e-2)
+                current_lr = optimizer_obj.param_groups[0]['lr']
+                pbar.set_postfix_str(
+                    f'lr={current_lr:.1e} jitter={jitter:.1e} | Numerical issue - increasing jitter'
+                )
+                continue
+
+            # Check for NaN loss
+            if torch.isnan(loss) or torch.isinf(loss):
+                current_lr = optimizer_obj.param_groups[0]['lr']
+                pbar.set_postfix_str(
+                    f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
+                )
+                continue
+
+            loss.backward()
+
+            # Get current learning rate before gradient noise injection
+            current_lr = optimizer_obj.param_groups[0]['lr']
+
+            # Gradient noise injection (if enabled)
+            if gradient_noise:
+                gradient_noise_scale = 0.1
+                adaptive_noise = gradient_noise_scale * current_lr
                 for param in self.model.parameters():
-                if param.grad is not None
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                continue
-
-            optimizer.step()
-
-            # Update learning rate scheduler for Adam
-            scheduler.step(loss)
-            current_lr = optimizer.param_groups[0]['lr']
-
-            # Early stopping check (more aggressive)
+                    if param.grad is not None:
+                        noise = torch.normal(mean=0.0, std=adaptive_noise, size=param.grad.shape, device=param.grad.device)
+                        param.grad.add_(noise)
+
+            # Gradient clipping for stability
+            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+            # Check for NaN gradients
+            has_nan_grad = False
+            for param in self.model.parameters():
+                if param.grad is not None and torch.isnan(param.grad).any():
+                    has_nan_grad = True
+                    break
+
+            if has_nan_grad:
+                # Don't update scheduler on NaN gradients - this prevents rapid LR decay
+                # The scheduler should only respond to actual optimization progress
+                current_lr = optimizer_obj.param_groups[0]['lr']
+
+                # Update best loss tracking (loss is still valid, just gradients are NaN)
                 if loss.item() < best_loss:
                     best_loss = loss.item()
                     patience_counter = 0
                 else:
                     patience_counter += 1
 
-            # Display
+                # Display comprehensive info even with NaN gradients, skip normal progress update
+                pbar.set_postfix_str(
+                    f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f} | NaN gradients - skipping step'
+                )
+                continue
+
+            optimizer_obj.step()
+
+            # Update learning rate scheduler for Adam/AdamW
+            scheduler.step(loss.item())
+            current_lr = optimizer_obj.param_groups[0]['lr']
+
+            # Early stopping check (more aggressive)
+            if loss.item() < best_loss:
+                best_loss = loss.item()
+                patience_counter = 0
+            else:
+                patience_counter += 1
+
+            # Only update progress bar if not skipped above
+            if not has_nan_grad:
                 progress_info = f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f}'
                 if early_stopping:
-                progress_info += f' patience={patience_counter}/
+                    progress_info += f' patience={patience_counter}/{patience}'
                 pbar.set_postfix_str(progress_info)
 
-
-
-
-
-                break
+            if early_stopping and patience_counter >= patience and current_lr <= min_lr_for_early_stop:
+                print(f"\nEarly stopping triggered after {i+1} iterations")
+                print(f"Best loss: {best_loss:.6f}")
+                break
 
     @is_fitted
     def predict(self,
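The loop above combines AdamW, a ReduceLROnPlateau scheduler, optional gradient-noise injection, and gradient clipping. Below is a self-contained sketch of the same pattern on a toy PyTorch model, for readers who want to see the moving parts without the GP-specific jitter handling; the toy model and data are placeholders, not discontinuum API.

    import torch

    model = torch.nn.Linear(3, 1)                    # toy stand-in for the GP model
    x, y = torch.randn(64, 3), torch.randn(64, 1)    # toy data

    optimizer = torch.optim.AdamW(model.parameters(), lr=0.05, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=30, min_lr=1e-5
    )
    gradient_noise = True

    for step in range(200):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()

        if gradient_noise:
            # noise std scales with the current learning rate, as in the loop above
            lr = optimizer.param_groups[0]['lr']
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.add_(torch.randn_like(p.grad) * 0.1 * lr)

        # clip after noise injection, mirroring the order used above
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step(loss.item())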
{discontinuum-1.0.4.dist-info → discontinuum-1.0.6.dist-info}/RECORD
CHANGED
@@ -1,17 +1,17 @@
 discontinuum/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-discontinuum/_version.py,sha256=
+discontinuum/_version.py,sha256=B7xX94ww8E2YsJCo2PT7LI1Lp5224NjscDIhXgKzj3U,511
 discontinuum/data_manager.py,sha256=LiZoPR0nnu7YAUfh5L1ZDRfaS3dgfVIELXIHkzUKyBg,4416
 discontinuum/pipeline.py,sha256=1avuZnFai-b3HmihcpZ8M3WFNQ8lXAFSNTrnfl2NrY0,10074
 discontinuum/plot.py,sha256=eZQS6-Ydq8FFcEukPtNuDVB-weV6lHyWMyJ1hqTkVrU,2969
 discontinuum/utils.py,sha256=07hIHQk_oDlkjz7tasgBjqqPOC6D0iNcy0eu-88aNbM,1540
 discontinuum/engines/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 discontinuum/engines/base.py,sha256=OlHd4ssIQoWvYHKoVqk5fKAVBcKsIIkR4ul9iNBvaYg,2396
-discontinuum/engines/gpytorch.py,sha256=
+discontinuum/engines/gpytorch.py,sha256=05x7Ha0g2vywM_moL18fMFDGeh0CF3vJpF-mDImrIx8,14387
 discontinuum/engines/pymc.py,sha256=phbtE-3UCSVcP1MhbXwAHIWDZWDr56wK9U7aRt-w-2o,5961
 discontinuum/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 discontinuum/providers/base.py,sha256=Yn2EHS1b4fYl09-m2MYuf2P9VRUXAP-WDpSoZrCbRvY,720
 discontinuum/tests/test_pipeline.py,sha256=_FhkGxbFIxNb35lGaIdZk7Zjgs6CkxEF3gFUX3PE8EU,918
-discontinuum-1.0.
+discontinuum-1.0.6.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
 loadest_gp/__init__.py,sha256=YISfvbc7Zy2y0BOxS1A2KzqxyoNJTz0EnLMnRW6iVT8,740
 loadest_gp/plot.py,sha256=x2PK7vBCc44dX9lu5YV-rvw1u4pvXSLdcrTSvYLiHMA,2595
 loadest_gp/utils.py,sha256=m5QaqR_0JiuRXPfryH8nI5lODp8PqvQla5C05WDN3LY,2772
@@ -22,14 +22,14 @@ loadest_gp/models/pymc.py,sha256=ShP-XNoEwNAf62yCvTuS6Q8iAiB9NQk0dVS69WgkPsE,344
 loadest_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 loadest_gp/providers/usgs.py,sha256=LJ5uh0g3nrZ-8I4poGwCdVqcXcpyZkroixwwt99vBcI,10885
 rating_gp/pipeline.py,sha256=1HgxN6DD3ZL5lhUb3DK2in2IXiml7W4Ja272GBMTc08,1884
-rating_gp/plot.py,sha256=
+rating_gp/plot.py,sha256=_XaeNLYXEcJxg7B4UCxyYccSNzNow0e4dV1z93_THaQ,10899
 rating_gp/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rating_gp/models/base.py,sha256=e2Kq644I88YLHWPNA0qyRgitF5wimdLW4618vKX-o_s,1474
-rating_gp/models/gpytorch.py,sha256=
+rating_gp/models/gpytorch.py,sha256=bNFJFT13DyVqlKhMRD3W0r6-Y72E3S2fw9E-0houyoM,7068
 rating_gp/models/kernels.py,sha256=3xg2mhY3aEgjI3r5vyAll9MA4c3M5UKqRi3FApNhJJQ,11579
 rating_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rating_gp/providers/usgs.py,sha256=KmKYN3c8Mi-ly2l6X80WT3taEhqCPXeEcRNi9HvbJmY,8134
-discontinuum-1.0.
-discontinuum-1.0.
-discontinuum-1.0.
-discontinuum-1.0.
+discontinuum-1.0.6.dist-info/METADATA,sha256=3h3AhrQZ3eNviDGAsbcKEd3yd_cYv8XaaIxMlCQdK0s,6302
+discontinuum-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+discontinuum-1.0.6.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
+discontinuum-1.0.6.dist-info/RECORD,,
rating_gp/models/gpytorch.py
CHANGED
@@ -5,9 +5,9 @@ from discontinuum.engines.gpytorch import MarginalGPyTorch, NoOpMean
 
 from gpytorch.kernels import (
     MaternKernel,
-    RBFKernel,
     RQKernel,
     ScaleKernel,
+    PeriodicKernel,
 )
 from gpytorch.priors import (
     GammaPrior,
@@ -15,10 +15,9 @@ from gpytorch.priors import (
     NormalPrior,
 )
 
-from linear_operator.operators import MatmulLinearOperator
 from rating_gp.models.base import RatingDataMixin, ModelConfig
 from rating_gp.plot import RatingPlotMixin
-from rating_gp.models.kernels import
+from rating_gp.models.kernels import SigmoidKernel
 
 
 class PowerLawTransform(torch.nn.Module):
@@ -94,53 +93,55 @@ class ExactGPModel(gpytorch.models.ExactGP):
 
         self.powerlaw = PowerLawTransform()
 
-        # self.mean_module = gpytorch.means.ConstantMean()
-        # self.mean_module = gpytorch.means.LinearMean(input_size=1)
         self.mean_module = NoOpMean()
 
-        #
-
-
-
-        # (self.cov_stage() * self.cov_stagetime())
-        # + self.cov_residual()
-        # )
-
-        # Stage * time kernel with large time length
-        # + stage * time kernel only at low stage with smaller time length.
-        # Note that stage gets transformed to q, so the kernel is actually
-        # q * time
-        b_min = np.quantile(train_y, 0.30)
-        b_max = np.quantile(train_y, 0.90)
+        # Use stage (not y) for sigmoid kernel constraint
+        stage = train_x[:, self.stage_dim[0]]#.cpu().numpy()
+        b_min = np.quantile(stage, 0.10)
+        b_max = np.quantile(stage, 0.90)
         self.covar_module = (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # core time kernel
+            (
+                self.cov_time(
+                    #ls_prior=GammaPrior(concentration=2, rate=1),
+                    ls_prior=GammaPrior(concentration=3, rate=1),
+                    eta_prior=HalfNormalPrior(scale=0.3),
+                )
+                *
+                self.cov_stage(ls_prior=GammaPrior(concentration=3, rate=2))
+                #self.cov_stage(ls_prior=GammaPrior(concentration=2, rate=1))
+            )
+            # gated shift component
+            + (
+                self.cov_time(
+                    ls_prior=GammaPrior(concentration=2, rate=5),
+                    eta_prior=HalfNormalPrior(scale=1),
+                )
+                * SigmoidKernel(
+                    active_dims=self.stage_dim,
+                    # b_prior=NormalPrior(loc=0.7, scale=0.001),
+                    b_constraint=gpytorch.constraints.Interval(
+                        b_min,
+                        b_max,
+                    ),
+                )
+            )
+            # additive periodic component for seasonal effects
+            + self.cov_periodic()
         )
 
 
     def forward(self, x):
-        self.powerlaw.b.data.clamp_(1.
+        self.powerlaw.b.data.clamp_(1.2, 2.5)
         #x = x.clone()
         #q = self.powerlaw(x[:, self.stage_dim])
         #x_t[:, self.stage_dim] = self.warp_stage_dim(x_t[:, self.stage_dim])
         x_t = x.clone()
         x_t[:, self.stage_dim] = self.powerlaw(x_t[:, self.stage_dim])
         q = x_t[:, self.stage_dim]
-
         mean_x = self.mean_module(q)
-        covar_x = self.covar_module(x_t)
+        #covar_x = self.covar_module(x_t)
+        covar_x = self.covar_module(x)
         return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
 
     def cov_stage(self, ls_prior=None):
@@ -155,61 +156,57 @@ class ExactGPModel(gpytorch.models.ExactGP):
             outputscale_prior=eta,
         )
 
-    def cov_time(self, ls_prior=None):
-
+    def cov_time(self, ls_prior=None, eta_prior=None):
+        if eta_prior is None:
+            eta_prior = HalfNormalPrior(scale=1)
 
         # Base Matern kernel for long-term trends
-
+        return ScaleKernel(
             MaternKernel(
                 active_dims=self.time_dim,
                 lengthscale_prior=ls_prior,
-                nu=1.5, # was
-            ),
-            outputscale_prior=eta,
-        )
-
-        # Periodic kernel for annual seasonality
-        # Locally periodic kernel: Periodic * Matern
-        periodic_kernel = ScaleKernel(
-            gpytorch.kernels.PeriodicKernel(
-                active_dims=self.time_dim,
-                period_length_prior=NormalPrior(loc=1.0, scale=0.05), # ~1 year
-                lengthscale_prior=GammaPrior(concentration=6, rate=1),
-            ) * MaternKernel(
-                active_dims=self.time_dim,
-                nu=2.5,
-                lengthscale_prior=GammaPrior(concentration=4, rate=3),
+                nu=1.5, # was 1.5 XXX
             ),
-            outputscale_prior=
+            outputscale_prior=eta_prior,
         )
-
-        return base_kernel + periodic_kernel
 
+    def cov_periodic(self, ls_prior=None, eta_prior=None):
+        """
+        Smooth, time-dependent periodic kernel for seasonal effects.
+        """
+        if eta_prior is None:
+            eta_prior = HalfNormalPrior(scale=0.5)
 
-
-
-    def cov_stagetime(self):
-        eta = HalfNormalPrior(scale=1)
-        ls = GammaPrior(concentration=2, rate=1)
+        if ls_prior is None:
+            ls_prior = GammaPrior(concentration=3, rate=1)
 
         return ScaleKernel(
-
-            active_dims=self.
-
+            PeriodicKernel(
+                active_dims=self.time_dim,
+                period_length_prior=NormalPrior(loc=1.0, scale=0.1), # ~1 year
+                # lengthscale_prior=GammaPrior(concentration=2, rate=4),
             ),
-            #
+            # *
+            # MaternKernel(
+            #     active_dims=self.stage_dim,
+            #     lengthscale_prior=ls_prior,
+            #     nu=2.5, # Smoother kernel (was nu=1.5)
+            # ),
+            outputscale_prior=HalfNormalPrior(scale=0.5),
         )
-
-    def
-
-
-
+
+    def cov_base(self):
+        """
+        Smooth, time-independent base rating curve using a Matern kernel on stage.
+        """
+        # Base should capture most variation
+        eta = HalfNormalPrior(scale=1)
+        ls = GammaPrior(concentration=3, rate=1)
         return ScaleKernel(
             MaternKernel(
-
-                nu=1.5,
-                active_dims=self.dims,
+                active_dims=self.stage_dim,
                 lengthscale_prior=ls,
             ),
             outputscale_prior=eta,
         )
+
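The new covar_module is a sum of three terms: a smooth time kernel multiplied by a stage kernel, a time kernel gated by the package's custom SigmoidKernel on stage, and an additive periodic seasonal term. The sketch below shows the same sum-of-products composition using only stock GPyTorch kernels; the dimension indices are assumptions, and a Matern on stage stands in for SigmoidKernel purely for illustration.

    from gpytorch.kernels import MaternKernel, PeriodicKernel, ScaleKernel
    from gpytorch.priors import GammaPrior, NormalPrior

    # dims 0/1 stand in for time/stage; SigmoidKernel is replaced by a Matern here
    time_dim, stage_dim = (0,), (1,)

    core = (
        ScaleKernel(MaternKernel(nu=1.5, active_dims=time_dim,
                                 lengthscale_prior=GammaPrior(3.0, 1.0)))
        * ScaleKernel(MaternKernel(nu=1.5, active_dims=stage_dim,
                                   lengthscale_prior=GammaPrior(3.0, 2.0)))
    )
    gated_shift = (
        ScaleKernel(MaternKernel(nu=1.5, active_dims=time_dim,
                                 lengthscale_prior=GammaPrior(2.0, 5.0)))
        * MaternKernel(nu=2.5, active_dims=stage_dim)  # stand-in for SigmoidKernel
    )
    seasonal = ScaleKernel(PeriodicKernel(active_dims=time_dim,
                                          period_length_prior=NormalPrior(1.0, 0.1)))

    covar_module = core + gated_shift + seasonal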
rating_gp/plot.py
CHANGED
@@ -12,6 +12,7 @@ import pandas as pd
 from discontinuum.engines.base import is_fitted
 from discontinuum.plot import BasePlotMixin
 from scipy.stats import norm
+from rating_gp.models.kernels import SigmoidKernel
 import xarray as xr
 from xarray import DataArray
 from xarray.plot.utils import label_from_attrs
@@ -117,6 +118,28 @@ class RatingPlotMixin(BasePlotMixin):
             zorder=1,
             **kwargs
         )
+        # Plot switch point if sigmoid kernel is in model
+        try:
+            # find first SigmoidKernel in covar_module
+            sig_kernels = [m for m in self.model.covar_module.modules() if isinstance(m, SigmoidKernel)]
+            if sig_kernels:
+                sig = sig_kernels[0]
+                # b_sig is in normalized stage space: inverse-transform to original stage units
+                b_sig = sig.b.item()
+                # use scaler step directly to inverse-transform normalized stage
+                pipeline = self.dm.covariate_pipelines['stage']
+                scaler = pipeline.named_steps['scaler']
+                stage_switch = float(scaler.inverse_transform(b_sig))
+                # Draw switch point on every call; label only once
+                if not getattr(ax, 'switch_point_plotted', False):
+                    ax.axvline(stage_switch, linestyle='--', color='gray', label='switch point')
+                    ax.switch_point_plotted = True
+                    ax.legend()
+                else:
+                    # subsequent calls, draw without label
+                    ax.axvline(stage_switch, linestyle='--', color='gray')
+        except Exception:
+            pass
 
         # self.plot_observed_rating(ax, zorder=3)
 
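The switch-point overlay locates the SigmoidKernel by walking the composed kernel with modules(), which works because GPyTorch kernels are torch.nn.Module subclasses. A small self-contained illustration of the same lookup pattern, searching for a stock PeriodicKernel instead of the package's SigmoidKernel:

    from gpytorch.kernels import MaternKernel, PeriodicKernel, ScaleKernel

    # composed kernels are nn.Modules, so .modules() walks every sub-kernel
    covar_module = ScaleKernel(MaternKernel(nu=1.5)) + ScaleKernel(PeriodicKernel())

    periodic = [m for m in covar_module.modules() if isinstance(m, PeriodicKernel)]
    if periodic:
        print(periodic[0].period_length)  # parameters of the matched kernel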
File without changes
File without changes
File without changes