discontinuum 1.0.5__py3-none-any.whl → 1.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
discontinuum/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '1.0.5'
21
- __version_tuple__ = version_tuple = (1, 0, 5)
20
+ __version__ = version = '1.0.6'
21
+ __version_tuple__ = version_tuple = (1, 0, 6)
@@ -52,7 +52,7 @@ class MarginalGPyTorch(BaseModel):
52
52
  optimizer: str = "adamw",
53
53
  learning_rate: float = None,
54
54
  early_stopping: bool = False,
55
- patience: int = 60,
55
+ patience: int = 30,
56
56
  gradient_noise: bool = False,
57
57
  ):
58
58
  """Fit the model to data.
@@ -99,10 +99,8 @@ class MarginalGPyTorch(BaseModel):
99
99
  self.likelihood.train()
100
100
 
101
101
  if learning_rate is None:
102
- if optimizer == "adam":
103
- learning_rate = 0.1 # Aggressive default for faster convergence
104
- elif optimizer == "adamw":
105
- learning_rate = 0.1
102
+ # More conservative starting LR
103
+ learning_rate = 0.05
106
104
 
107
105
  if optimizer == "adamw":
108
106
  optimizer_obj = torch.optim.AdamW(
@@ -129,7 +127,8 @@ class MarginalGPyTorch(BaseModel):
129
127
  mode='min',
130
128
  factor=0.5, # Reduce LR by half
131
129
  patience=max(2, patience),
132
- threshold=1e-4,
130
+ threshold=5e-1, # Aggressive plateau detection
131
+ #threshold_mode='rel', # Use relative threshold
133
132
  min_lr=1e-5
134
133
  )
135
134
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: discontinuum
3
- Version: 1.0.5
3
+ Version: 1.0.6
4
4
  Summary: Estimate discontinuous timeseries from continuous covariates.
5
5
  Maintainer-email: Timothy Hodson <thodson@usgs.gov>
6
6
  License: License
@@ -1,17 +1,17 @@
1
1
  discontinuum/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- discontinuum/_version.py,sha256=cafGA7j4exK_daS29O0wW2JM2p5b8FwYDLqAYJR0jWo,511
2
+ discontinuum/_version.py,sha256=B7xX94ww8E2YsJCo2PT7LI1Lp5224NjscDIhXgKzj3U,511
3
3
  discontinuum/data_manager.py,sha256=LiZoPR0nnu7YAUfh5L1ZDRfaS3dgfVIELXIHkzUKyBg,4416
4
4
  discontinuum/pipeline.py,sha256=1avuZnFai-b3HmihcpZ8M3WFNQ8lXAFSNTrnfl2NrY0,10074
5
5
  discontinuum/plot.py,sha256=eZQS6-Ydq8FFcEukPtNuDVB-weV6lHyWMyJ1hqTkVrU,2969
6
6
  discontinuum/utils.py,sha256=07hIHQk_oDlkjz7tasgBjqqPOC6D0iNcy0eu-88aNbM,1540
7
7
  discontinuum/engines/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  discontinuum/engines/base.py,sha256=OlHd4ssIQoWvYHKoVqk5fKAVBcKsIIkR4ul9iNBvaYg,2396
9
- discontinuum/engines/gpytorch.py,sha256=kRyAgCfxjKZbAJhJGViaDU_y8NO8sW4rSWRyEQlomHo,14383
9
+ discontinuum/engines/gpytorch.py,sha256=05x7Ha0g2vywM_moL18fMFDGeh0CF3vJpF-mDImrIx8,14387
10
10
  discontinuum/engines/pymc.py,sha256=phbtE-3UCSVcP1MhbXwAHIWDZWDr56wK9U7aRt-w-2o,5961
11
11
  discontinuum/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  discontinuum/providers/base.py,sha256=Yn2EHS1b4fYl09-m2MYuf2P9VRUXAP-WDpSoZrCbRvY,720
13
13
  discontinuum/tests/test_pipeline.py,sha256=_FhkGxbFIxNb35lGaIdZk7Zjgs6CkxEF3gFUX3PE8EU,918
14
- discontinuum-1.0.5.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
14
+ discontinuum-1.0.6.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
15
15
  loadest_gp/__init__.py,sha256=YISfvbc7Zy2y0BOxS1A2KzqxyoNJTz0EnLMnRW6iVT8,740
16
16
  loadest_gp/plot.py,sha256=x2PK7vBCc44dX9lu5YV-rvw1u4pvXSLdcrTSvYLiHMA,2595
17
17
  loadest_gp/utils.py,sha256=m5QaqR_0JiuRXPfryH8nI5lODp8PqvQla5C05WDN3LY,2772
@@ -22,14 +22,14 @@ loadest_gp/models/pymc.py,sha256=ShP-XNoEwNAf62yCvTuS6Q8iAiB9NQk0dVS69WgkPsE,344
22
22
  loadest_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  loadest_gp/providers/usgs.py,sha256=LJ5uh0g3nrZ-8I4poGwCdVqcXcpyZkroixwwt99vBcI,10885
24
24
  rating_gp/pipeline.py,sha256=1HgxN6DD3ZL5lhUb3DK2in2IXiml7W4Ja272GBMTc08,1884
25
- rating_gp/plot.py,sha256=CJphwqWWAfIY22j5Oz5DRwj7TcQCRyIQvM79_3KEdlc,9635
25
+ rating_gp/plot.py,sha256=_XaeNLYXEcJxg7B4UCxyYccSNzNow0e4dV1z93_THaQ,10899
26
26
  rating_gp/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  rating_gp/models/base.py,sha256=e2Kq644I88YLHWPNA0qyRgitF5wimdLW4618vKX-o_s,1474
28
- rating_gp/models/gpytorch.py,sha256=eFHwtnW44GZ1zz0fLx5REbzIWwnb_x_uq-cGjDcHyWs,6907
28
+ rating_gp/models/gpytorch.py,sha256=bNFJFT13DyVqlKhMRD3W0r6-Y72E3S2fw9E-0houyoM,7068
29
29
  rating_gp/models/kernels.py,sha256=3xg2mhY3aEgjI3r5vyAll9MA4c3M5UKqRi3FApNhJJQ,11579
30
30
  rating_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
31
  rating_gp/providers/usgs.py,sha256=KmKYN3c8Mi-ly2l6X80WT3taEhqCPXeEcRNi9HvbJmY,8134
32
- discontinuum-1.0.5.dist-info/METADATA,sha256=GDh-fscmNYYMmXXoUE5Xd1QafuKEOPjgZjlssmjqGVg,6302
33
- discontinuum-1.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
34
- discontinuum-1.0.5.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
35
- discontinuum-1.0.5.dist-info/RECORD,,
32
+ discontinuum-1.0.6.dist-info/METADATA,sha256=3h3AhrQZ3eNviDGAsbcKEd3yd_cYv8XaaIxMlCQdK0s,6302
33
+ discontinuum-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
34
+ discontinuum-1.0.6.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
35
+ discontinuum-1.0.6.dist-info/RECORD,,
@@ -5,9 +5,9 @@ from discontinuum.engines.gpytorch import MarginalGPyTorch, NoOpMean
5
5
 
6
6
  from gpytorch.kernels import (
7
7
  MaternKernel,
8
- RBFKernel,
9
8
  RQKernel,
10
9
  ScaleKernel,
10
+ PeriodicKernel,
11
11
  )
12
12
  from gpytorch.priors import (
13
13
  GammaPrior,
@@ -15,10 +15,9 @@ from gpytorch.priors import (
15
15
  NormalPrior,
16
16
  )
17
17
 
18
- from linear_operator.operators import MatmulLinearOperator
19
18
  from rating_gp.models.base import RatingDataMixin, ModelConfig
20
19
  from rating_gp.plot import RatingPlotMixin
21
- from rating_gp.models.kernels import StageTimeKernel, SigmoidKernel, LogWarp, TanhWarp
20
+ from rating_gp.models.kernels import SigmoidKernel
22
21
 
23
22
 
24
23
  class PowerLawTransform(torch.nn.Module):
@@ -94,53 +93,55 @@ class ExactGPModel(gpytorch.models.ExactGP):
94
93
 
95
94
  self.powerlaw = PowerLawTransform()
96
95
 
97
- # self.mean_module = gpytorch.means.ConstantMean()
98
- # self.mean_module = gpytorch.means.LinearMean(input_size=1)
99
96
  self.mean_module = NoOpMean()
100
97
 
101
- #self.warp_stage_dim = TanhWarp()
102
- #self.warp_stage_dim = LogWarp()
103
-
104
- # self.covar_module = (
105
- # (self.cov_stage() * self.cov_stagetime())
106
- # + self.cov_residual()
107
- # )
108
-
109
- # Stage * time kernel with large time length
110
- # + stage * time kernel only at low stage with smaller time length.
111
- # Note that stage gets transformed to q, so the kernel is actually
112
- # q * time
113
- b_min = np.quantile(train_y, 0.10)
114
- b_max = np.quantile(train_y, 0.90)
98
+ # Use stage (not y) for sigmoid kernel constraint
99
+ stage = train_x[:, self.stage_dim[0]]#.cpu().numpy()
100
+ b_min = np.quantile(stage, 0.10)
101
+ b_max = np.quantile(stage, 0.90)
115
102
  self.covar_module = (
116
- (self.cov_stage(ls_prior=GammaPrior(concentration=1, rate=1))
117
- * self.cov_time(ls_prior=GammaPrior(concentration=1, rate=1)))
118
- + (self.cov_stage(ls_prior=GammaPrior(concentration=3, rate=1))
119
- * self.cov_time(ls_prior=GammaPrior(concentration=2, rate=5))
120
- * SigmoidKernel(
121
- active_dims=self.stage_dim,
122
- # a_prior=NormalPrior(loc=20, scale=1),
123
- # b_prior=NormalPrior(loc=0.5, scale=0.2),
124
- b_constraint=gpytorch.constraints.Interval(
125
- b_min,
126
- b_max,
127
- ),
128
- )
129
- )
103
+ # core time kernel
104
+ (
105
+ self.cov_time(
106
+ #ls_prior=GammaPrior(concentration=2, rate=1),
107
+ ls_prior=GammaPrior(concentration=3, rate=1),
108
+ eta_prior=HalfNormalPrior(scale=0.3),
109
+ )
110
+ *
111
+ self.cov_stage(ls_prior=GammaPrior(concentration=3, rate=2))
112
+ #self.cov_stage(ls_prior=GammaPrior(concentration=2, rate=1))
113
+ )
114
+ # gated shift component
115
+ + (
116
+ self.cov_time(
117
+ ls_prior=GammaPrior(concentration=2, rate=5),
118
+ eta_prior=HalfNormalPrior(scale=1),
119
+ )
120
+ * SigmoidKernel(
121
+ active_dims=self.stage_dim,
122
+ # b_prior=NormalPrior(loc=0.7, scale=0.001),
123
+ b_constraint=gpytorch.constraints.Interval(
124
+ b_min,
125
+ b_max,
126
+ ),
127
+ )
128
+ )
129
+ # additive periodic component for seasonal effects
130
+ + self.cov_periodic()
130
131
  )
131
132
 
132
133
 
133
134
  def forward(self, x):
134
- self.powerlaw.b.data.clamp_(1.5, 2.5)
135
+ self.powerlaw.b.data.clamp_(1.2, 2.5)
135
136
  #x = x.clone()
136
137
  #q = self.powerlaw(x[:, self.stage_dim])
137
138
  #x_t[:, self.stage_dim] = self.warp_stage_dim(x_t[:, self.stage_dim])
138
139
  x_t = x.clone()
139
140
  x_t[:, self.stage_dim] = self.powerlaw(x_t[:, self.stage_dim])
140
141
  q = x_t[:, self.stage_dim]
141
-
142
142
  mean_x = self.mean_module(q)
143
- covar_x = self.covar_module(x_t)
143
+ #covar_x = self.covar_module(x_t)
144
+ covar_x = self.covar_module(x)
144
145
  return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
145
146
 
146
147
  def cov_stage(self, ls_prior=None):
@@ -155,56 +156,57 @@ class ExactGPModel(gpytorch.models.ExactGP):
155
156
  outputscale_prior=eta,
156
157
  )
157
158
 
158
- def cov_time(self, ls_prior=None):
159
- eta = HalfNormalPrior(scale=1)
159
+ def cov_time(self, ls_prior=None, eta_prior=None):
160
+ if eta_prior is None:
161
+ eta_prior = HalfNormalPrior(scale=1)
160
162
 
161
163
  # Base Matern kernel for long-term trends
162
- base_kernel = ScaleKernel(
164
+ return ScaleKernel(
163
165
  MaternKernel(
164
166
  active_dims=self.time_dim,
165
167
  lengthscale_prior=ls_prior,
166
- nu=1.5, # was 2.5
168
+ nu=1.5, # was 1.5 XXX
167
169
  ),
168
- outputscale_prior=eta,
170
+ outputscale_prior=eta_prior,
169
171
  )
172
+
173
+ def cov_periodic(self, ls_prior=None, eta_prior=None):
174
+ """
175
+ Smooth, time-dependent periodic kernel for seasonal effects.
176
+ """
177
+ if eta_prior is None:
178
+ eta_prior = HalfNormalPrior(scale=0.5)
179
+
180
+ if ls_prior is None:
181
+ ls_prior = GammaPrior(concentration=3, rate=1)
170
182
 
171
- # Periodic performs beter than a locally periodic kernel
172
- periodic_kernel = ScaleKernel(
173
- gpytorch.kernels.PeriodicKernel(
183
+ return ScaleKernel(
184
+ PeriodicKernel(
174
185
  active_dims=self.time_dim,
175
186
  period_length_prior=NormalPrior(loc=1.0, scale=0.1), # ~1 year
176
- lengthscale_prior=GammaPrior(concentration=2, rate=4),
187
+ # lengthscale_prior=GammaPrior(concentration=2, rate=4),
177
188
  ),
189
+ # *
190
+ # MaternKernel(
191
+ # active_dims=self.stage_dim,
192
+ # lengthscale_prior=ls_prior,
193
+ # nu=2.5, # Smoother kernel (was nu=1.5)
194
+ # ),
178
195
  outputscale_prior=HalfNormalPrior(scale=0.5),
179
196
  )
180
-
181
- return base_kernel + periodic_kernel
182
197
 
183
-
184
-
185
-
186
- def cov_stagetime(self):
198
+ def cov_base(self):
199
+ """
200
+ Smooth, time-independent base rating curve using a Matern kernel on stage.
201
+ """
202
+ # Base should capture most variation
187
203
  eta = HalfNormalPrior(scale=1)
188
- ls = GammaPrior(concentration=2, rate=1)
189
-
190
- return ScaleKernel(
191
- StageTimeKernel(
192
- active_dims=self.dims,
193
- # lengthscale_prior=ls,
194
- ),
195
- # outputscale_prior=eta,
196
- )
197
-
198
- def cov_residual(self):
199
- eta = HalfNormalPrior(scale=0.2)
200
- ls = GammaPrior(concentration=2, rate=10)
201
-
204
+ ls = GammaPrior(concentration=3, rate=1)
202
205
  return ScaleKernel(
203
206
  MaternKernel(
204
- ard_num_dims=2,
205
- nu=1.5,
206
- active_dims=self.dims,
207
+ active_dims=self.stage_dim,
207
208
  lengthscale_prior=ls,
208
209
  ),
209
210
  outputscale_prior=eta,
210
211
  )
212
+
rating_gp/plot.py CHANGED
@@ -12,6 +12,7 @@ import pandas as pd
12
12
  from discontinuum.engines.base import is_fitted
13
13
  from discontinuum.plot import BasePlotMixin
14
14
  from scipy.stats import norm
15
+ from rating_gp.models.kernels import SigmoidKernel
15
16
  import xarray as xr
16
17
  from xarray import DataArray
17
18
  from xarray.plot.utils import label_from_attrs
@@ -117,6 +118,28 @@ class RatingPlotMixin(BasePlotMixin):
117
118
  zorder=1,
118
119
  **kwargs
119
120
  )
121
+ # Plot switch point if sigmoid kernel is in model
122
+ try:
123
+ # find first SigmoidKernel in covar_module
124
+ sig_kernels = [m for m in self.model.covar_module.modules() if isinstance(m, SigmoidKernel)]
125
+ if sig_kernels:
126
+ sig = sig_kernels[0]
127
+ # b_sig is in normalized stage space: inverse-transform to original stage units
128
+ b_sig = sig.b.item()
129
+ # use scaler step directly to inverse-transform normalized stage
130
+ pipeline = self.dm.covariate_pipelines['stage']
131
+ scaler = pipeline.named_steps['scaler']
132
+ stage_switch = float(scaler.inverse_transform(b_sig))
133
+ # Draw switch point on every call; label only once
134
+ if not getattr(ax, 'switch_point_plotted', False):
135
+ ax.axvline(stage_switch, linestyle='--', color='gray', label='switch point')
136
+ ax.switch_point_plotted = True
137
+ ax.legend()
138
+ else:
139
+ # subsequent calls, draw without label
140
+ ax.axvline(stage_switch, linestyle='--', color='gray')
141
+ except Exception:
142
+ pass
120
143
 
121
144
  # self.plot_observed_rating(ax, zorder=3)
122
145