discontinuum 1.0.3.tar.gz → 1.0.4.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. {discontinuum-1.0.3 → discontinuum-1.0.4}/PKG-INFO +4 -3
  2. {discontinuum-1.0.3 → discontinuum-1.0.4}/README.md +3 -2
  3. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/source/notebooks/loadest-gp-demo.ipynb +4 -2
  4. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/source/notebooks/rating-gp-demo.ipynb +13 -4
  5. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/_version.py +2 -2
  6. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/engines/gpytorch.py +134 -15
  7. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum.egg-info/PKG-INFO +4 -3
  8. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/models/gpytorch.py +35 -8
  9. {discontinuum-1.0.3 → discontinuum-1.0.4}/.github/workflows/deploy-docs.yml +0 -0
  10. {discontinuum-1.0.3 → discontinuum-1.0.4}/.github/workflows/python-package.yml +0 -0
  11. {discontinuum-1.0.3 → discontinuum-1.0.4}/.github/workflows/python-publish.yml +0 -0
  12. {discontinuum-1.0.3 → discontinuum-1.0.4}/.gitignore +0 -0
  13. {discontinuum-1.0.3 → discontinuum-1.0.4}/DISCLAIMER.md +0 -0
  14. {discontinuum-1.0.3 → discontinuum-1.0.4}/LICENSE.md +0 -0
  15. {discontinuum-1.0.3 → discontinuum-1.0.4}/code.json +0 -0
  16. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/Makefile +0 -0
  17. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/assets/illinois-river-nitrate.png +0 -0
  18. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/make.bat +0 -0
  19. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/source/api_reference.rst +0 -0
  20. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/source/conf.py +0 -0
  21. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/source/getting_started.md +0 -0
  22. {discontinuum-1.0.3 → discontinuum-1.0.4}/docs/source/index.md +0 -0
  23. {discontinuum-1.0.3 → discontinuum-1.0.4}/examples/nwqn-loadest-example/Dockerfile_discontinuum +0 -0
  24. {discontinuum-1.0.3 → discontinuum-1.0.4}/examples/nwqn-loadest-example/README.md +0 -0
  25. {discontinuum-1.0.3 → discontinuum-1.0.4}/examples/nwqn-loadest-example/lithops.yaml +0 -0
  26. {discontinuum-1.0.3 → discontinuum-1.0.4}/examples/nwqn-loadest-example/nwqn-loadest-example.py +0 -0
  27. {discontinuum-1.0.3 → discontinuum-1.0.4}/examples/nwqn-loadest-example/requirements.txt +0 -0
  28. {discontinuum-1.0.3 → discontinuum-1.0.4}/pyproject.toml +0 -0
  29. {discontinuum-1.0.3 → discontinuum-1.0.4}/setup.cfg +0 -0
  30. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/__init__.py +0 -0
  31. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/data_manager.py +0 -0
  32. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/engines/__init__.py +0 -0
  33. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/engines/base.py +0 -0
  34. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/engines/pymc.py +0 -0
  35. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/pipeline.py +0 -0
  36. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/plot.py +0 -0
  37. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/providers/__init__.py +0 -0
  38. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/providers/base.py +0 -0
  39. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/tests/test_pipeline.py +0 -0
  40. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/utils.py +0 -0
  41. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum.egg-info/SOURCES.txt +0 -0
  42. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum.egg-info/dependency_links.txt +0 -0
  43. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum.egg-info/requires.txt +0 -0
  44. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum.egg-info/top_level.txt +0 -0
  45. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/__init__.py +0 -0
  46. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/models/__init__.py +0 -0
  47. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/models/base.py +0 -0
  48. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/models/gpytorch.py +0 -0
  49. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/models/pymc.py +0 -0
  50. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/plot.py +0 -0
  51. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/providers/__init__.py +0 -0
  52. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/providers/usgs.py +0 -0
  53. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/loadest_gp/utils.py +0 -0
  54. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/models/__init__.py +0 -0
  55. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/models/base.py +0 -0
  56. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/models/kernels.py +0 -0
  57. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/pipeline.py +0 -0
  58. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/plot.py +0 -0
  59. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/providers/__init__.py +0 -0
  60. {discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/providers/usgs.py +0 -0
  61. {discontinuum-1.0.3 → discontinuum-1.0.4}/tests/test_loadest_gp.py +0 -0
  62. {discontinuum-1.0.3 → discontinuum-1.0.4}/tests/test_rating_gp.py +0 -0
{discontinuum-1.0.3 → discontinuum-1.0.4}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: discontinuum
- Version: 1.0.3
+ Version: 1.0.4
  Summary: Estimate discontinuous timeseries from continuous covariates.
  Maintainer-email: Timothy Hodson <thodson@usgs.gov>
  License: License
@@ -124,11 +124,12 @@ However, LOADEST has several serious limitations
  the more flexible Weighted Regression on Time Discharge and Season (WRTDS),
  which allows the relation between target and covariate to vary through time.
  `loadest-gp` takes the WRTDS idea and reimplements it as a GP.
- Try it out in the [loadest-gp demo](https://code.usgs.gov/wma/uncertainty/discontinuum/-/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb).
+ github/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb
+ Try it out in the [loadest-gp demo](https://github.com/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb).

  ### rating-gp
  `rating-gp` is a Gaussian-process model for estimating river flow from stage time series.
- Try it out in the [rating-gp demo](https://code.usgs.gov/wma/uncertainty/discontinuum/-/blob/main/docs/source/notebooks/rating-gp-demo.ipynb).
+ Try it out in the [rating-gp demo](https://github.com/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/rating-gp-demo.ipynb).

  ## Engines
  Currently, the only supported engines are the marginal likelihood implementation in `pymc` and `gpytorch`.
{discontinuum-1.0.3 → discontinuum-1.0.4}/README.md
@@ -36,11 +36,12 @@ However, LOADEST has several serious limitations
  the more flexible Weighted Regression on Time Discharge and Season (WRTDS),
  which allows the relation between target and covariate to vary through time.
  `loadest-gp` takes the WRTDS idea and reimplements it as a GP.
- Try it out in the [loadest-gp demo](https://code.usgs.gov/wma/uncertainty/discontinuum/-/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb).
+ github/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb
+ Try it out in the [loadest-gp demo](https://github.com/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb).

  ### rating-gp
  `rating-gp` is a Gaussian-process model for estimating river flow from stage time series.
- Try it out in the [rating-gp demo](https://code.usgs.gov/wma/uncertainty/discontinuum/-/blob/main/docs/source/notebooks/rating-gp-demo.ipynb).
+ Try it out in the [rating-gp demo](https://github.com/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/rating-gp-demo.ipynb).

  ## Engines
  Currently, the only supported engines are the marginal likelihood implementation in `pymc` and `gpytorch`.
{discontinuum-1.0.3 → discontinuum-1.0.4}/docs/source/notebooks/loadest-gp-demo.ipynb
@@ -14,7 +14,9 @@
  "# loadest-gp (prototype)\n",
  "LOAD ESTimator (LOADEST) is a software program for estimating some constituent using surrogate variables (covariates).\n",
  "However, LOADEST has several serious limitations, and it has been all but replaced by another model known as Weighted Regressions on Time, Discharge, and Season (WRTDS).\n",
- "`loadest-gp` essentially reimplements WRTDS as a Gaussian process."
+ "`loadest-gp` essentially reimplements WRTDS as a Gaussian process.\n",
+ "\n",
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb)"
  ]
  },
  {
@@ -25,7 +27,7 @@
  "outputs": [],
  "source": [
  "# install the latest version of discontinuum\n",
- "# !pip install discontinuum\n",
+ "# !pip install discontinuum[loadest_gp]\n",
  "%load_ext autoreload\n",
  "%autoreload 2"
  ]
{discontinuum-1.0.3 → discontinuum-1.0.4}/docs/source/notebooks/rating-gp-demo.ipynb
@@ -13,7 +13,9 @@
  "source": [
  "# rating-gp (prototype)\n",
  "`rating-gp` is a prototype model that can fit rating curves (stage-discharge relationship) using a Gaussian process.\n",
- "This model seeks to expand the typical rating curve fitting process to include shifts in the rating curve with time such that the time evolution in the rating curve can be included in the model."
+ "This model seeks to expand the typical rating curve fitting process to include shifts in the rating curve with time such that the time evolution in the rating curve can be included in the model.\n",
+ "\n",
+ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/rating-gp-demo.ipynb)"
  ]
  },
  {
@@ -24,7 +26,9 @@
  "outputs": [],
  "source": [
  "%load_ext autoreload\n",
- "%autoreload 2"
+ "%autoreload 2\n",
+ "\n",
+ "#!pip install discontinuum[rating_gp]"
  ]
  },
  {
@@ -116,7 +120,10 @@
  "model = RatingGP()\n",
  "model.fit(target=training_data['discharge'],\n",
  " covariates=training_data[['stage']],\n",
- " target_unc=training_data['discharge_unc'])"
+ " target_unc=training_data['discharge_unc'],\n",
+ " iterations=2000,\n",
+ " early_stopping=True,\n",
+ " )"
  ]
  },
  {
@@ -242,7 +249,9 @@
  " model[site].fit(target=training_data['discharge'],\n",
  " covariates=training_data[['stage']],\n",
  " target_unc=training_data['discharge_unc'],\n",
- " iterations=200)"
+ " iterations=2000,\n",
+ " early_stopping=True,\n",
+ " )"
  ]
  },
  {
{discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/_version.py
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '1.0.3'
- __version_tuple__ = version_tuple = (1, 0, 3)
+ __version__ = version = '1.0.4'
+ __version_tuple__ = version_tuple = (1, 0, 4)
{discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum/engines/gpytorch.py
@@ -50,6 +50,9 @@ class MarginalGPyTorch(BaseModel):
  target_unc: Dataset = None,
  iterations: int = 100,
  optimizer: str = "adam",
+ learning_rate: float = None,
+ early_stopping: bool = False,
+ early_stopping_patience: int = 100,
  ):
  """Fit the model to data.

@@ -65,6 +68,12 @@ class MarginalGPyTorch(BaseModel):
  Number of iterations for optimization. The default is 100.
  optimizer : str, optional
  Optimization method. The default is "adam".
+ learning_rate : float, optional
+ Learning rate for optimization. If None, uses adaptive defaults.
+ early_stopping : bool, optional
+ Whether to use early stopping. The default is False.
+ early_stopping_patience : int, optional
+ Number of iterations to wait without improvement before stopping. The default is 100.
  """
  self.is_fitted = True
  # setup data manager (self.dm)
@@ -86,26 +95,136 @@
  self.model.train()
  self.likelihood.train()

- # Use the adam optimizer
- if optimizer == "adam":
- optimizer = torch.optim.Adam(self.model.parameters(), lr=0.05) # default previously lr=0.1
- else:
- raise NotImplementedError("Only Adam optimizer is implemented")
+ # Adaptive learning rate selection for faster convergence
+ if learning_rate is None:
+ if optimizer == "adam":
+ learning_rate = 0.1 # More aggressive default for faster convergence
+ elif optimizer == "lbfgs":
+ learning_rate = 1.0 # L-BFGS doesn't use learning rate the same way
+
+ # Use the specified optimizer with stabilization
+ if optimizer != "adam":
+ raise NotImplementedError(f"Only 'adam' optimizer is supported. Got '{optimizer}'.")
+ optimizer = torch.optim.Adam(
+ self.model.parameters(),
+ lr=learning_rate,
+ betas=(0.9, 0.999), # Slightly more conservative momentum
+ eps=1e-8, # Numerical stability
+ weight_decay=1e-4 # Small L2 regularization
+ )
+ # More responsive learning rate scheduler for faster adaptation
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+ optimizer,
+ mode='min',
+ factor=0.6, # Reduce LR by 40% when loss plateaus (more aggressive)
+ patience=40, # Reduce sooner for faster adaptation
+ min_lr=1e-5, # Higher minimum learning rate
+ threshold=1e-4 # Less sensitive to plateaus
+ )

  # "Loss" for GPs - the marginal log likelihood
  mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self.model)

- pbar = tqdm.tqdm(range(iterations), ncols=70)
+ # Training loop with stability features
+ pbar = tqdm.tqdm(range(iterations), ncols=100) # Wider progress bar
+ jitter = 1e-6 # Dynamic jitter for numerical stability
+ best_loss = float('inf')
+ patience_counter = 0
+ min_lr_for_early_stop = 2e-5 # Stop if patience is exceeded and LR is below this
+
  for i in pbar:
- # Zero gradients from previous iteration
- optimizer.zero_grad()
- # Output from model
- output = self.model(train_x)
- # Calc loss and backprop gradients
- loss = -mll(output, train_y)
- loss.backward()
- pbar.set_postfix(loss=loss.item())
- optimizer.step()
+ if optimizer.__class__.__name__ == "LBFGS":
+ # L-BFGS requires a closure function
+ def closure():
+ optimizer.zero_grad()
+ output = self.model(train_x)
+ with gpytorch.settings.cholesky_jitter(jitter):
+ loss = -mll(output, train_y).sum()
+ loss.backward()
+ return loss
+
+ loss = optimizer.step(closure)
+ pbar.set_postfix(loss=loss.item())
+ else:
+ # Adam optimizer with stability features
+ optimizer.zero_grad()
+ output = self.model(train_x)
+
+ # Attempt loss calculation with dynamic jitter
+ try:
+ with gpytorch.settings.cholesky_jitter(jitter):
+ loss = -mll(output, train_y)
+ except Exception as e:
+ # Increase jitter if numerical issues occur
+ jitter = min(jitter * 10, 1e-2)
+ current_lr = optimizer.param_groups[0]['lr']
+ pbar.set_postfix_str(
+ f'lr={current_lr:.1e} jitter={jitter:.1e} | Numerical issue - increasing jitter'
+ )
+ continue
+
+ # Check for NaN loss
+ if torch.isnan(loss) or torch.isinf(loss):
+ current_lr = optimizer.param_groups[0]['lr']
+ pbar.set_postfix_str(
+ f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
+ )
+ continue
+
+ loss.backward()
+
+ # Gradient clipping for stability
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+ # Check for NaN gradients
+ has_nan_grad = False
+ for param in self.model.parameters():
+ if param.grad is not None and torch.isnan(param.grad).any():
+ has_nan_grad = True
+ break
+
+ if has_nan_grad:
+ # Don't update scheduler on NaN gradients - this prevents rapid LR decay
+ # The scheduler should only respond to actual optimization progress
+ current_lr = optimizer.param_groups[0]['lr']
+
+ # Update best loss tracking (loss is still valid, just gradients are NaN)
+ if loss.item() < best_loss:
+ best_loss = loss.item()
+ patience_counter = 0
+ else:
+ patience_counter += 1
+
+ # Display comprehensive info even with NaN gradients
+ pbar.set_postfix_str(
+ f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f} | NaN gradients - skipping step'
+ )
+ continue
+
+ optimizer.step()
+
+ # Update learning rate scheduler for Adam
+ scheduler.step(loss)
+ current_lr = optimizer.param_groups[0]['lr']
+
+ # Early stopping check (more aggressive)
+ if loss.item() < best_loss:
+ best_loss = loss.item()
+ patience_counter = 0
+ else:
+ patience_counter += 1
+
+ # Display progress with comprehensive metadata
+ progress_info = f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f}'
+ if early_stopping:
+ progress_info += f' patience={patience_counter}/25'
+ pbar.set_postfix_str(progress_info)
+
+ # More aggressive early stopping: patience=25 and require LR to be low
+ if early_stopping and patience_counter >= 25 and current_lr <= min_lr_for_early_stop:
+ print(f"\nEarly stopping triggered after {i+1} iterations (patience exceeded and LR low)")
+ print(f"Best loss: {best_loss:.6f}")
+ break

  @is_fitted
  predict(self,
{discontinuum-1.0.3 → discontinuum-1.0.4}/src/discontinuum.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: discontinuum
- Version: 1.0.3
+ Version: 1.0.4
  Summary: Estimate discontinuous timeseries from continuous covariates.
  Maintainer-email: Timothy Hodson <thodson@usgs.gov>
  License: License
@@ -124,11 +124,12 @@ However, LOADEST has several serious limitations
  the more flexible Weighted Regression on Time Discharge and Season (WRTDS),
  which allows the relation between target and covariate to vary through time.
  `loadest-gp` takes the WRTDS idea and reimplements it as a GP.
- Try it out in the [loadest-gp demo](https://code.usgs.gov/wma/uncertainty/discontinuum/-/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb).
+ github/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb
+ Try it out in the [loadest-gp demo](https://github.com/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/loadest-gp-demo.ipynb).

  ### rating-gp
  `rating-gp` is a Gaussian-process model for estimating river flow from stage time series.
- Try it out in the [rating-gp demo](https://code.usgs.gov/wma/uncertainty/discontinuum/-/blob/main/docs/source/notebooks/rating-gp-demo.ipynb).
+ Try it out in the [rating-gp demo](https://github.com/thodson-usgs/discontinuum/blob/main/docs/source/notebooks/rating-gp-demo.ipynb).

  ## Engines
  Currently, the only supported engines are the marginal likelihood implementation in `pymc` and `gpytorch`.
{discontinuum-1.0.3 → discontinuum-1.0.4}/src/rating_gp/models/gpytorch.py
@@ -71,8 +71,9 @@ class RatingGPMarginalGPyTorch(
  # noise, *and* you did not specify noise. This is treated as a no-op."
  self.likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(
  noise=noise,
+ #learn_additional_noise=False,
  learn_additional_noise=True,
- noise_prior=gpytorch.priors.HalfNormalPrior(scale=0.01),
+ noise_prior=gpytorch.priors.HalfNormalPrior(scale=0.005),
  )

  model = ExactGPModel(X, y, self.likelihood)
@@ -109,17 +110,20 @@ class ExactGPModel(gpytorch.models.ExactGP):
  # + stage * time kernel only at low stage with smaller time length.
  # Note that stage gets transformed to q, so the kernel is actually
  # q * time
+ b_min = np.quantile(train_y, 0.30)
+ b_max = np.quantile(train_y, 0.90)
  self.covar_module = (
- (self.cov_stage()
+ (self.cov_stage(ls_prior=GammaPrior(concentration=2, rate=1))
  * self.cov_time(ls_prior=GammaPrior(concentration=1, rate=1)))
- + (self.cov_stage(ls_prior=GammaPrior(concentration=1, rate=2))
- * self.cov_time(ls_prior=GammaPrior(concentration=2, rate=5))
+ + (self.cov_stage(ls_prior=GammaPrior(concentration=5, rate=1))
+ * self.cov_time(ls_prior=GammaPrior(concentration=1, rate=5))
  * SigmoidKernel(
  active_dims=self.stage_dim,
  # a_prior=NormalPrior(loc=20, scale=1),
+ # b_prior=NormalPrior(loc=0.5, scale=0.2),
  b_constraint=gpytorch.constraints.Interval(
- train_y.min(),
- train_y.max(),
+ b_min,
+ b_max,
  ),
  )
  )
@@ -141,11 +145,12 @@ class ExactGPModel(gpytorch.models.ExactGP):

  def cov_stage(self, ls_prior=None):
  eta = HalfNormalPrior(scale=1)
-
+
  return ScaleKernel(
  MaternKernel(
  active_dims=self.stage_dim,
  lengthscale_prior=ls_prior,
+ nu=2.5, # Smoother kernel (was nu=1.5)
  ),
  outputscale_prior=eta,
  )
@@ -153,13 +158,35 @@ class ExactGPModel(gpytorch.models.ExactGP):
  def cov_time(self, ls_prior=None):
  eta = HalfNormalPrior(scale=1)

- return ScaleKernel(
+ # Base Matern kernel for long-term trends
+ base_kernel = ScaleKernel(
  MaternKernel(
  active_dims=self.time_dim,
  lengthscale_prior=ls_prior,
+ nu=1.5, # was 2.5
  ),
  outputscale_prior=eta,
  )
+
+ # Periodic kernel for annual seasonality
+ # Locally periodic kernel: Periodic * Matern
+ periodic_kernel = ScaleKernel(
+ gpytorch.kernels.PeriodicKernel(
+ active_dims=self.time_dim,
+ period_length_prior=NormalPrior(loc=1.0, scale=0.05), # ~1 year
+ lengthscale_prior=GammaPrior(concentration=6, rate=1),
+ ) * MaternKernel(
+ active_dims=self.time_dim,
+ nu=2.5,
+ lengthscale_prior=GammaPrior(concentration=4, rate=3),
+ ),
+ outputscale_prior=HalfNormalPrior(scale=0.2),
+ )
+
+ return base_kernel + periodic_kernel
+
+
+

  def cov_stagetime(self):
  eta = HalfNormalPrior(scale=1)
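
The new cov_time above combines a Matern trend kernel with a locally periodic (Periodic * Matern) component for annual seasonality. The following stand-alone sketch reproduces that composition in isolation; the model's priors, active_dims, and wiring are omitted, the values are illustrative, and the one-year period assumes time is expressed in years (as the period_length prior above suggests):

import torch
from gpytorch.kernels import MaternKernel, PeriodicKernel, ScaleKernel

# Long-term trend component (Matern, nu=1.5 as in the diff).
trend = ScaleKernel(MaternKernel(nu=1.5))

# Locally periodic seasonal component: Periodic * Matern, so the annual
# cycle can drift over time rather than repeat exactly forever.
seasonal = ScaleKernel(PeriodicKernel() * MaternKernel(nu=2.5))
seasonal.base_kernel.kernels[0].period_length = 1.0  # time in years -> ~1-year period (assumption)

cov_time = trend + seasonal

t = torch.linspace(0.0, 5.0, 200).unsqueeze(-1)  # five "years" of inputs
K = cov_time(t)                                   # lazy 200 x 200 covariance matrix
print(K.shape)

The HalfNormal outputscale priors in the diff then control how much variance the trend and seasonal components each explain.
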