discontinuum-1.0.4-py3-none-any.whl → discontinuum-1.0.6-py3-none-any.whl

This diff shows the content of two publicly released versions of the package as they appear in the public registry; it is provided for informational purposes only.
discontinuum/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE
 
- __version__ = version = '1.0.4'
- __version_tuple__ = version_tuple = (1, 0, 4)
+ __version__ = version = '1.0.6'
+ __version_tuple__ = version_tuple = (1, 0, 6)
discontinuum/engines/gpytorch.py CHANGED
@@ -49,10 +49,11 @@ class MarginalGPyTorch(BaseModel):
  target: Dataset,
  target_unc: Dataset = None,
  iterations: int = 100,
- optimizer: str = "adam",
+ optimizer: str = "adamw",
  learning_rate: float = None,
  early_stopping: bool = False,
- early_stopping_patience: int = 100,
+ patience: int = 30,
+ gradient_noise: bool = False,
  ):
  """Fit the model to data.
 
@@ -67,13 +68,15 @@ class MarginalGPyTorch(BaseModel):
  iterations : int, optional
  Number of iterations for optimization. The default is 100.
  optimizer : str, optional
- Optimization method. The default is "adam".
+ Optimization method. Supported: "adam", "adamw". The default is "adamw".
  learning_rate : float, optional
  Learning rate for optimization. If None, uses adaptive defaults.
  early_stopping : bool, optional
  Whether to use early stopping. The default is False.
- early_stopping_patience : int, optional
- Number of iterations to wait without improvement before stopping. The default is 100.
+ patience : int, optional
+ Number of iterations to wait without improvement before stopping. The default is 60.
+ gradient_noise : bool, optional
+ Whether to inject Gaussian noise into gradients each step (std = 0.1 × current learning rate). The default is False.
  """
  self.is_fitted = True
  # setup data manager (self.dm)
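Together, the two hunks above change the public fit() API: the default optimizer becomes "adamw", early_stopping_patience is renamed to patience, and a gradient_noise flag is added. A minimal, illustrative call of the new signature (model, covariates, and target are placeholders; the arguments before target are assumed unchanged from earlier releases):

    # `model` is any MarginalGPyTorch subclass (e.g. a loadest_gp or rating_gp model);
    # `covariates` and `target` are xarray Datasets prepared as in earlier releases.
    model.fit(
        covariates,
        target,
        iterations=500,
        optimizer="adamw",      # new default; "adam" is still accepted
        learning_rate=None,     # None now falls back to 0.05
        early_stopping=True,
        patience=30,            # renamed from early_stopping_patience
        gradient_noise=True,    # new: per-step Gaussian noise, std = 0.1 * current lr
    )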
@@ -95,31 +98,38 @@ class MarginalGPyTorch(BaseModel):
  self.model.train()
  self.likelihood.train()
 
- # Adaptive learning rate selection for faster convergence
  if learning_rate is None:
- if optimizer == "adam":
- learning_rate = 0.1 # More aggressive default for faster convergence
- elif optimizer == "lbfgs":
- learning_rate = 1.0 # L-BFGS doesn't use learning rate the same way
+ # More conservative starting LR
+ learning_rate = 0.05
 
- # Use the specified optimizer with stabilization
- if optimizer != "adam":
- raise NotImplementedError(f"Only 'adam' optimizer is supported. Got '{optimizer}'.")
- optimizer = torch.optim.Adam(
- self.model.parameters(),
- lr=learning_rate,
- betas=(0.9, 0.999), # Slightly more conservative momentum
- eps=1e-8, # Numerical stability
- weight_decay=1e-4 # Small L2 regularization
- )
- # More responsive learning rate scheduler for faster adaptation
+ if optimizer == "adamw":
+ optimizer_obj = torch.optim.AdamW(
+ self.model.parameters(),
+ lr=learning_rate,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2 # Stronger regularization for AdamW
+ )
+ elif optimizer == "adam":
+ optimizer_obj = torch.optim.Adam(
+ self.model.parameters(),
+ lr=learning_rate,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-4 # Lighter regularization for Adam
+ )
+ else:
+ raise NotImplementedError(f"Only 'adam' and 'adamw' optimizers are supported. Got '{optimizer}'.")
+
+ # Use ReduceLROnPlateau for more stable learning rate adaptation
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
- optimizer,
- mode='min',
- factor=0.6, # Reduce LR by 40% when loss plateaus (more aggressive)
- patience=40, # Reduce sooner for faster adaptation
- min_lr=1e-5, # Higher minimum learning rate
- threshold=1e-4 # Less sensitive to plateaus
+ optimizer_obj,
+ mode='min',
+ factor=0.5, # Reduce LR by half
+ patience=max(2, patience),
+ threshold=5e-1, # Aggressive plateau detection
+ #threshold_mode='rel', # Use relative threshold
+ min_lr=1e-5
  )
 
  # "Loss" for GPs - the marginal log likelihood
@@ -133,98 +143,97 @@ class MarginalGPyTorch(BaseModel):
  min_lr_for_early_stop = 2e-5 # Stop if patience is exceeded and LR is below this
 
  for i in pbar:
- if optimizer.__class__.__name__ == "LBFGS":
- # L-BFGS requires a closure function
- def closure():
- optimizer.zero_grad()
- output = self.model(train_x)
- with gpytorch.settings.cholesky_jitter(jitter):
- loss = -mll(output, train_y).sum()
- loss.backward()
- return loss
-
- loss = optimizer.step(closure)
- pbar.set_postfix(loss=loss.item())
- else:
- # Adam optimizer with stability features
- optimizer.zero_grad()
- output = self.model(train_x)
-
- # Attempt loss calculation with dynamic jitter
- try:
- with gpytorch.settings.cholesky_jitter(jitter):
- loss = -mll(output, train_y)
- except Exception as e:
- # Increase jitter if numerical issues occur
- jitter = min(jitter * 10, 1e-2)
- current_lr = optimizer.param_groups[0]['lr']
- pbar.set_postfix_str(
- f'lr={current_lr:.1e} jitter={jitter:.1e} | Numerical issue - increasing jitter'
- )
- continue
-
- # Check for NaN loss
- if torch.isnan(loss) or torch.isinf(loss):
- current_lr = optimizer.param_groups[0]['lr']
- pbar.set_postfix_str(
- f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
- )
- continue
-
- loss.backward()
-
- # Gradient clipping for stability
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
- # Check for NaN gradients
- has_nan_grad = False
+ # Adam/AdamW optimizer with stability features
+ optimizer_obj.zero_grad()
+ output = self.model(train_x)
+
+ # Attempt loss calculation with dynamic jitter
+ try:
+ with gpytorch.settings.cholesky_jitter(jitter):
+ loss = -mll(output, train_y)
+ except Exception as e:
+ # Increase jitter if numerical issues occur
+ jitter = min(jitter * 10, 1e-2)
+ current_lr = optimizer_obj.param_groups[0]['lr']
+ pbar.set_postfix_str(
+ f'lr={current_lr:.1e} jitter={jitter:.1e} | Numerical issue - increasing jitter'
+ )
+ continue
+
+ # Check for NaN loss
+ if torch.isnan(loss) or torch.isinf(loss):
+ current_lr = optimizer_obj.param_groups[0]['lr']
+ pbar.set_postfix_str(
+ f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
+ )
+ continue
+
+ loss.backward()
+
+ # Get current learning rate before gradient noise injection
+ current_lr = optimizer_obj.param_groups[0]['lr']
+
+ # Gradient noise injection (if enabled)
+ if gradient_noise:
+ gradient_noise_scale = 0.1
+ adaptive_noise = gradient_noise_scale * current_lr
  for param in self.model.parameters():
- if param.grad is not None and torch.isnan(param.grad).any():
- has_nan_grad = True
- break
-
- if has_nan_grad:
- # Don't update scheduler on NaN gradients - this prevents rapid LR decay
- # The scheduler should only respond to actual optimization progress
- current_lr = optimizer.param_groups[0]['lr']
-
- # Update best loss tracking (loss is still valid, just gradients are NaN)
- if loss.item() < best_loss:
- best_loss = loss.item()
- patience_counter = 0
- else:
- patience_counter += 1
-
- # Display comprehensive info even with NaN gradients
- pbar.set_postfix_str(
- f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f} | NaN gradients - skipping step'
- )
- continue
-
- optimizer.step()
-
- # Update learning rate scheduler for Adam
- scheduler.step(loss)
- current_lr = optimizer.param_groups[0]['lr']
-
- # Early stopping check (more aggressive)
+ if param.grad is not None:
+ noise = torch.normal(mean=0.0, std=adaptive_noise, size=param.grad.shape, device=param.grad.device)
+ param.grad.add_(noise)
+
+ # Gradient clipping for stability
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+ # Check for NaN gradients
+ has_nan_grad = False
+ for param in self.model.parameters():
+ if param.grad is not None and torch.isnan(param.grad).any():
+ has_nan_grad = True
+ break
+
+ if has_nan_grad:
+ # Don't update scheduler on NaN gradients - this prevents rapid LR decay
+ # The scheduler should only respond to actual optimization progress
+ current_lr = optimizer_obj.param_groups[0]['lr']
+
+ # Update best loss tracking (loss is still valid, just gradients are NaN)
  if loss.item() < best_loss:
  best_loss = loss.item()
  patience_counter = 0
  else:
  patience_counter += 1
 
- # Display progress with comprehensive metadata
+ # Display comprehensive info even with NaN gradients, skip normal progress update
+ pbar.set_postfix_str(
+ f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f} | NaN gradients - skipping step'
+ )
+ continue
+
+ optimizer_obj.step()
+
+ # Update learning rate scheduler for Adam/AdamW
+ scheduler.step(loss.item())
+ current_lr = optimizer_obj.param_groups[0]['lr']
+
+ # Early stopping check (more aggressive)
+ if loss.item() < best_loss:
+ best_loss = loss.item()
+ patience_counter = 0
+ else:
+ patience_counter += 1
+
+ # Only update progress bar if not skipped above
+ if not has_nan_grad:
  progress_info = f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f}'
  if early_stopping:
- progress_info += f' patience={patience_counter}/25'
+ progress_info += f' patience={patience_counter}/{patience}'
  pbar.set_postfix_str(progress_info)
 
- # More aggressive early stopping: patience=25 and require LR to be low
- if early_stopping and patience_counter >= 25 and current_lr <= min_lr_for_early_stop:
- print(f"\nEarly stopping triggered after {i+1} iterations (patience exceeded and LR low)")
- print(f"Best loss: {best_loss:.6f}")
- break
+ if early_stopping and patience_counter >= patience and current_lr <= min_lr_for_early_stop:
+ print(f"\nEarly stopping triggered after {i+1} iterations")
+ print(f"Best loss: {best_loss:.6f}")
+ break
 
  @is_fitted
  def predict(self,
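The gradient-noise branch above adds zero-mean Gaussian noise, with standard deviation 0.1 times the current learning rate, to each gradient after backward() and before clipping. A self-contained sketch of the same idea as a generic helper (torch.randn_like is used here as an equivalent way to draw the noise; this is not the package's code):

    import torch

    def add_gradient_noise(module: torch.nn.Module, lr: float, scale: float = 0.1) -> None:
        """Add N(0, (scale * lr)^2) noise to every parameter gradient that exists."""
        std = scale * lr
        for p in module.parameters():
            if p.grad is not None:
                p.grad.add_(torch.randn_like(p.grad) * std)

    # inside a training step, after loss.backward() and before clipping / stepping:
    #   add_gradient_noise(model, lr=optimizer.param_groups[0]['lr'])
    #   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    #   optimizer.step()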
discontinuum-1.0.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: discontinuum
- Version: 1.0.4
+ Version: 1.0.6
  Summary: Estimate discontinuous timeseries from continuous covariates.
  Maintainer-email: Timothy Hodson <thodson@usgs.gov>
  License: License
discontinuum-1.0.6.dist-info/RECORD CHANGED
@@ -1,17 +1,17 @@
  discontinuum/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- discontinuum/_version.py,sha256=rXTOeD0YpRo_kJ2LqUiMnTKEFf43sO_PBvJHDh0SQUA,511
+ discontinuum/_version.py,sha256=B7xX94ww8E2YsJCo2PT7LI1Lp5224NjscDIhXgKzj3U,511
  discontinuum/data_manager.py,sha256=LiZoPR0nnu7YAUfh5L1ZDRfaS3dgfVIELXIHkzUKyBg,4416
  discontinuum/pipeline.py,sha256=1avuZnFai-b3HmihcpZ8M3WFNQ8lXAFSNTrnfl2NrY0,10074
  discontinuum/plot.py,sha256=eZQS6-Ydq8FFcEukPtNuDVB-weV6lHyWMyJ1hqTkVrU,2969
  discontinuum/utils.py,sha256=07hIHQk_oDlkjz7tasgBjqqPOC6D0iNcy0eu-88aNbM,1540
  discontinuum/engines/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  discontinuum/engines/base.py,sha256=OlHd4ssIQoWvYHKoVqk5fKAVBcKsIIkR4ul9iNBvaYg,2396
- discontinuum/engines/gpytorch.py,sha256=36TxE_qfRUjuOB16eXmyrxPlicKzXkdQ7xnfqL2ucy0,14539
+ discontinuum/engines/gpytorch.py,sha256=05x7Ha0g2vywM_moL18fMFDGeh0CF3vJpF-mDImrIx8,14387
  discontinuum/engines/pymc.py,sha256=phbtE-3UCSVcP1MhbXwAHIWDZWDr56wK9U7aRt-w-2o,5961
  discontinuum/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  discontinuum/providers/base.py,sha256=Yn2EHS1b4fYl09-m2MYuf2P9VRUXAP-WDpSoZrCbRvY,720
  discontinuum/tests/test_pipeline.py,sha256=_FhkGxbFIxNb35lGaIdZk7Zjgs6CkxEF3gFUX3PE8EU,918
- discontinuum-1.0.4.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
+ discontinuum-1.0.6.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
  loadest_gp/__init__.py,sha256=YISfvbc7Zy2y0BOxS1A2KzqxyoNJTz0EnLMnRW6iVT8,740
  loadest_gp/plot.py,sha256=x2PK7vBCc44dX9lu5YV-rvw1u4pvXSLdcrTSvYLiHMA,2595
  loadest_gp/utils.py,sha256=m5QaqR_0JiuRXPfryH8nI5lODp8PqvQla5C05WDN3LY,2772
@@ -22,14 +22,14 @@ loadest_gp/models/pymc.py,sha256=ShP-XNoEwNAf62yCvTuS6Q8iAiB9NQk0dVS69WgkPsE,344
  loadest_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  loadest_gp/providers/usgs.py,sha256=LJ5uh0g3nrZ-8I4poGwCdVqcXcpyZkroixwwt99vBcI,10885
  rating_gp/pipeline.py,sha256=1HgxN6DD3ZL5lhUb3DK2in2IXiml7W4Ja272GBMTc08,1884
- rating_gp/plot.py,sha256=CJphwqWWAfIY22j5Oz5DRwj7TcQCRyIQvM79_3KEdlc,9635
+ rating_gp/plot.py,sha256=_XaeNLYXEcJxg7B4UCxyYccSNzNow0e4dV1z93_THaQ,10899
  rating_gp/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rating_gp/models/base.py,sha256=e2Kq644I88YLHWPNA0qyRgitF5wimdLW4618vKX-o_s,1474
- rating_gp/models/gpytorch.py,sha256=4SqOdWIvI93kDq9S4cDPHXX25EHNjT_hKwZijhAR4C0,7121
+ rating_gp/models/gpytorch.py,sha256=bNFJFT13DyVqlKhMRD3W0r6-Y72E3S2fw9E-0houyoM,7068
  rating_gp/models/kernels.py,sha256=3xg2mhY3aEgjI3r5vyAll9MA4c3M5UKqRi3FApNhJJQ,11579
  rating_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rating_gp/providers/usgs.py,sha256=KmKYN3c8Mi-ly2l6X80WT3taEhqCPXeEcRNi9HvbJmY,8134
- discontinuum-1.0.4.dist-info/METADATA,sha256=A6T6BQocZmIox600f7nU5Tb9r7x5YthC5ba1WRET2XM,6302
- discontinuum-1.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- discontinuum-1.0.4.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
- discontinuum-1.0.4.dist-info/RECORD,,
+ discontinuum-1.0.6.dist-info/METADATA,sha256=3h3AhrQZ3eNviDGAsbcKEd3yd_cYv8XaaIxMlCQdK0s,6302
+ discontinuum-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ discontinuum-1.0.6.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
+ discontinuum-1.0.6.dist-info/RECORD,,
rating_gp/models/gpytorch.py CHANGED
@@ -5,9 +5,9 @@ from discontinuum.engines.gpytorch import MarginalGPyTorch, NoOpMean
 
  from gpytorch.kernels import (
  MaternKernel,
- RBFKernel,
  RQKernel,
  ScaleKernel,
+ PeriodicKernel,
  )
  from gpytorch.priors import (
  GammaPrior,
@@ -15,10 +15,9 @@ from gpytorch.priors import (
  NormalPrior,
  )
 
- from linear_operator.operators import MatmulLinearOperator
  from rating_gp.models.base import RatingDataMixin, ModelConfig
  from rating_gp.plot import RatingPlotMixin
- from rating_gp.models.kernels import StageTimeKernel, SigmoidKernel, LogWarp, TanhWarp
+ from rating_gp.models.kernels import SigmoidKernel
 
 
  class PowerLawTransform(torch.nn.Module):
@@ -94,53 +93,55 @@ class ExactGPModel(gpytorch.models.ExactGP):
 
  self.powerlaw = PowerLawTransform()
 
- # self.mean_module = gpytorch.means.ConstantMean()
- # self.mean_module = gpytorch.means.LinearMean(input_size=1)
  self.mean_module = NoOpMean()
 
- #self.warp_stage_dim = TanhWarp()
- #self.warp_stage_dim = LogWarp()
-
- # self.covar_module = (
- # (self.cov_stage() * self.cov_stagetime())
- # + self.cov_residual()
- # )
-
- # Stage * time kernel with large time length
- # + stage * time kernel only at low stage with smaller time length.
- # Note that stage gets transformed to q, so the kernel is actually
- # q * time
- b_min = np.quantile(train_y, 0.30)
- b_max = np.quantile(train_y, 0.90)
+ # Use stage (not y) for sigmoid kernel constraint
+ stage = train_x[:, self.stage_dim[0]]#.cpu().numpy()
+ b_min = np.quantile(stage, 0.10)
+ b_max = np.quantile(stage, 0.90)
  self.covar_module = (
- (self.cov_stage(ls_prior=GammaPrior(concentration=2, rate=1))
- * self.cov_time(ls_prior=GammaPrior(concentration=1, rate=1)))
- + (self.cov_stage(ls_prior=GammaPrior(concentration=5, rate=1))
- * self.cov_time(ls_prior=GammaPrior(concentration=1, rate=5))
- * SigmoidKernel(
- active_dims=self.stage_dim,
- # a_prior=NormalPrior(loc=20, scale=1),
- # b_prior=NormalPrior(loc=0.5, scale=0.2),
- b_constraint=gpytorch.constraints.Interval(
- b_min,
- b_max,
- ),
- )
- )
+ # core time kernel
+ (
+ self.cov_time(
+ #ls_prior=GammaPrior(concentration=2, rate=1),
+ ls_prior=GammaPrior(concentration=3, rate=1),
+ eta_prior=HalfNormalPrior(scale=0.3),
+ )
+ *
+ self.cov_stage(ls_prior=GammaPrior(concentration=3, rate=2))
+ #self.cov_stage(ls_prior=GammaPrior(concentration=2, rate=1))
+ )
+ # gated shift component
+ + (
+ self.cov_time(
+ ls_prior=GammaPrior(concentration=2, rate=5),
+ eta_prior=HalfNormalPrior(scale=1),
+ )
+ * SigmoidKernel(
+ active_dims=self.stage_dim,
+ # b_prior=NormalPrior(loc=0.7, scale=0.001),
+ b_constraint=gpytorch.constraints.Interval(
+ b_min,
+ b_max,
+ ),
+ )
+ )
+ # additive periodic component for seasonal effects
+ + self.cov_periodic()
  )
 
 
  def forward(self, x):
- self.powerlaw.b.data.clamp_(1.5, 2.5)
+ self.powerlaw.b.data.clamp_(1.2, 2.5)
  #x = x.clone()
  #q = self.powerlaw(x[:, self.stage_dim])
  #x_t[:, self.stage_dim] = self.warp_stage_dim(x_t[:, self.stage_dim])
  x_t = x.clone()
  x_t[:, self.stage_dim] = self.powerlaw(x_t[:, self.stage_dim])
  q = x_t[:, self.stage_dim]
-
  mean_x = self.mean_module(q)
- covar_x = self.covar_module(x_t)
+ #covar_x = self.covar_module(x_t)
+ covar_x = self.covar_module(x)
  return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
 
  def cov_stage(self, ls_prior=None):
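The rebuilt covariance is a sum of three parts: a core time-by-stage kernel, a shift component gated to a stage band by the package's custom SigmoidKernel, and an additive periodic term from the new cov_periodic() helper. A standalone sketch of that composition pattern using stock GPyTorch kernels (the column layout and kernel choices here are assumptions for illustration; the SigmoidKernel gate is only indicated in a comment because it lives in rating_gp.models.kernels):

    from gpytorch.kernels import MaternKernel, PeriodicKernel, ScaleKernel

    time_dim, stage_dim = (0,), (1,)   # assumed column layout: [time, stage]

    core = (
        ScaleKernel(MaternKernel(nu=1.5, active_dims=time_dim))
        * ScaleKernel(MaternKernel(nu=1.5, active_dims=stage_dim))
    )
    shift = (
        ScaleKernel(MaternKernel(nu=1.5, active_dims=time_dim))
        # * SigmoidKernel(active_dims=stage_dim, ...)  # gate to a stage band (custom kernel)
    )
    seasonal = ScaleKernel(PeriodicKernel(active_dims=time_dim))

    covar_module = core + shift + seasonal   # AdditiveKernel of the three components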
@@ -155,61 +156,57 @@ class ExactGPModel(gpytorch.models.ExactGP):
  outputscale_prior=eta,
  )
 
- def cov_time(self, ls_prior=None):
- eta = HalfNormalPrior(scale=1)
+ def cov_time(self, ls_prior=None, eta_prior=None):
+ if eta_prior is None:
+ eta_prior = HalfNormalPrior(scale=1)
 
  # Base Matern kernel for long-term trends
- base_kernel = ScaleKernel(
+ return ScaleKernel(
  MaternKernel(
  active_dims=self.time_dim,
  lengthscale_prior=ls_prior,
- nu=1.5, # was 2.5
- ),
- outputscale_prior=eta,
- )
-
- # Periodic kernel for annual seasonality
- # Locally periodic kernel: Periodic * Matern
- periodic_kernel = ScaleKernel(
- gpytorch.kernels.PeriodicKernel(
- active_dims=self.time_dim,
- period_length_prior=NormalPrior(loc=1.0, scale=0.05), # ~1 year
- lengthscale_prior=GammaPrior(concentration=6, rate=1),
- ) * MaternKernel(
- active_dims=self.time_dim,
- nu=2.5,
- lengthscale_prior=GammaPrior(concentration=4, rate=3),
+ nu=1.5, # was 1.5 XXX
  ),
- outputscale_prior=HalfNormalPrior(scale=0.2),
+ outputscale_prior=eta_prior,
  )
-
- return base_kernel + periodic_kernel
 
+ def cov_periodic(self, ls_prior=None, eta_prior=None):
+ """
+ Smooth, time-dependent periodic kernel for seasonal effects.
+ """
+ if eta_prior is None:
+ eta_prior = HalfNormalPrior(scale=0.5)
 
-
-
- def cov_stagetime(self):
- eta = HalfNormalPrior(scale=1)
- ls = GammaPrior(concentration=2, rate=1)
+ if ls_prior is None:
+ ls_prior = GammaPrior(concentration=3, rate=1)
 
  return ScaleKernel(
- StageTimeKernel(
- active_dims=self.dims,
- # lengthscale_prior=ls,
+ PeriodicKernel(
+ active_dims=self.time_dim,
+ period_length_prior=NormalPrior(loc=1.0, scale=0.1), # ~1 year
+ # lengthscale_prior=GammaPrior(concentration=2, rate=4),
  ),
- # outputscale_prior=eta,
+ # *
+ # MaternKernel(
+ # active_dims=self.stage_dim,
+ # lengthscale_prior=ls_prior,
+ # nu=2.5, # Smoother kernel (was nu=1.5)
+ # ),
+ outputscale_prior=HalfNormalPrior(scale=0.5),
  )
-
- def cov_residual(self):
- eta = HalfNormalPrior(scale=0.2)
- ls = GammaPrior(concentration=2, rate=10)
-
+
+ def cov_base(self):
+ """
+ Smooth, time-independent base rating curve using a Matern kernel on stage.
+ """
+ # Base should capture most variation
+ eta = HalfNormalPrior(scale=1)
+ ls = GammaPrior(concentration=3, rate=1)
  return ScaleKernel(
  MaternKernel(
- ard_num_dims=2,
- nu=1.5,
- active_dims=self.dims,
+ active_dims=self.stage_dim,
  lengthscale_prior=ls,
  ),
  outputscale_prior=eta,
  )
+
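For context on the new cov_periodic() helper, a standalone sketch of a periodic kernel with a prior centered on a period of 1.0, which corresponds to roughly one year if time is scaled as the "~1 year" comment suggests (illustration only; assumes a recent gpytorch with the linear_operator API):

    import torch
    from gpytorch.kernels import PeriodicKernel, ScaleKernel
    from gpytorch.priors import NormalPrior

    # Seasonal kernel on the time dimension only; the prior centers the period
    # at 1.0 (about one year under the assumed time scaling).
    seasonal = ScaleKernel(
        PeriodicKernel(
            active_dims=(0,),
            period_length_prior=NormalPrior(loc=1.0, scale=0.1),
        )
    )

    t = torch.linspace(0, 3, 50).unsqueeze(-1)   # three "years" of scaled time
    K = seasonal(t).to_dense()                   # 50 x 50 covariance matrix
    print(K.shape)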
rating_gp/plot.py CHANGED
@@ -12,6 +12,7 @@ import pandas as pd
  from discontinuum.engines.base import is_fitted
  from discontinuum.plot import BasePlotMixin
  from scipy.stats import norm
+ from rating_gp.models.kernels import SigmoidKernel
  import xarray as xr
  from xarray import DataArray
  from xarray.plot.utils import label_from_attrs
@@ -117,6 +118,28 @@ class RatingPlotMixin(BasePlotMixin):
  zorder=1,
  **kwargs
  )
+ # Plot switch point if sigmoid kernel is in model
+ try:
+ # find first SigmoidKernel in covar_module
+ sig_kernels = [m for m in self.model.covar_module.modules() if isinstance(m, SigmoidKernel)]
+ if sig_kernels:
+ sig = sig_kernels[0]
+ # b_sig is in normalized stage space: inverse-transform to original stage units
+ b_sig = sig.b.item()
+ # use scaler step directly to inverse-transform normalized stage
+ pipeline = self.dm.covariate_pipelines['stage']
+ scaler = pipeline.named_steps['scaler']
+ stage_switch = float(scaler.inverse_transform(b_sig))
+ # Draw switch point on every call; label only once
+ if not getattr(ax, 'switch_point_plotted', False):
+ ax.axvline(stage_switch, linestyle='--', color='gray', label='switch point')
+ ax.switch_point_plotted = True
+ ax.legend()
+ else:
+ # subsequent calls, draw without label
+ ax.axvline(stage_switch, linestyle='--', color='gray')
+ except Exception:
+ pass
 
  # self.plot_observed_rating(ax, zorder=3)
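The new block above maps the SigmoidKernel switch point b, learned in normalized stage space, back to stage units through the data manager's stage pipeline. A standalone sketch of that inverse-transform step with a hypothetical scikit-learn StandardScaler standing in for the pipeline's scaler step (note that scikit-learn transformers expect 2-D input, hence the reshaping here):

    import numpy as np
    from sklearn.preprocessing import StandardScaler

    # Hypothetical stand-in for self.dm.covariate_pipelines['stage'].named_steps['scaler']:
    # map a kernel parameter learned in normalized stage space back to stage units.
    scaler = StandardScaler().fit(np.array([[1.2], [2.5], [3.1], [4.0]]))  # toy stage data

    b_sig = 0.3  # e.g. a SigmoidKernel switch point in normalized space
    stage_switch = float(scaler.inverse_transform(np.array([[b_sig]]))[0, 0])
    print(stage_switch)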