discontinuum-1.0.5-py3-none-any.whl → discontinuum-1.0.6-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- discontinuum/_version.py +2 -2
- discontinuum/engines/gpytorch.py +5 -6
- {discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/METADATA +1 -1
- {discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/RECORD +9 -9
- rating_gp/models/gpytorch.py +70 -68
- rating_gp/plot.py +23 -0
- {discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/WHEEL +0 -0
- {discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/licenses/LICENSE.md +0 -0
- {discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/top_level.txt +0 -0
discontinuum/_version.py
CHANGED
discontinuum/engines/gpytorch.py
CHANGED
```diff
@@ -52,7 +52,7 @@ class MarginalGPyTorch(BaseModel):
         optimizer: str = "adamw",
         learning_rate: float = None,
         early_stopping: bool = False,
-        patience: int =
+        patience: int = 30,
         gradient_noise: bool = False,
     ):
         """Fit the model to data.
@@ -99,10 +99,8 @@ class MarginalGPyTorch(BaseModel):
         self.likelihood.train()
 
         if learning_rate is None:
-
-
-            elif optimizer == "adamw":
-                learning_rate = 0.1
+            # More conservative starting LR
+            learning_rate = 0.05
 
         if optimizer == "adamw":
             optimizer_obj = torch.optim.AdamW(
@@ -129,7 +127,8 @@ class MarginalGPyTorch(BaseModel):
             mode='min',
             factor=0.5, # Reduce LR by half
             patience=max(2, patience),
-            threshold=
+            threshold=5e-1, # Aggressive plateau detection
+            #threshold_mode='rel', # Use relative threshold
             min_lr=1e-5
         )
 
```
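Taken together, these changes set a longer early-stopping patience (30), a more conservative default AdamW learning rate (0.05 instead of 0.1), and a coarse `ReduceLROnPlateau` threshold. The sketch below is illustrative only, not the package's actual `fit` implementation; `fit_sketch` and its arguments are hypothetical names showing how such defaults typically wire into a GPyTorch training loop.

```python
import gpytorch
import torch

def fit_sketch(model, likelihood, train_x, train_y,
               learning_rate=0.05, patience=30, iters=500):
    """Hypothetical loop using the defaults from the diff above."""
    model.train()
    likelihood.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    # Same scheduler settings as the diff: halve the LR when the loss
    # plateaus, with an aggressive threshold and a floor of 1e-5.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=max(2, patience),
        threshold=5e-1,
        min_lr=1e-5,
    )
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    for _ in range(iters):
        optimizer.zero_grad()
        loss = -mll(model(train_x), train_y)  # negative marginal log likelihood
        loss.backward()
        optimizer.step()
        scheduler.step(loss.item())  # scheduler tracks the training loss
    return model
```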
{discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/RECORD
CHANGED
```diff
@@ -1,17 +1,17 @@
 discontinuum/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-discontinuum/_version.py,sha256=
+discontinuum/_version.py,sha256=B7xX94ww8E2YsJCo2PT7LI1Lp5224NjscDIhXgKzj3U,511
 discontinuum/data_manager.py,sha256=LiZoPR0nnu7YAUfh5L1ZDRfaS3dgfVIELXIHkzUKyBg,4416
 discontinuum/pipeline.py,sha256=1avuZnFai-b3HmihcpZ8M3WFNQ8lXAFSNTrnfl2NrY0,10074
 discontinuum/plot.py,sha256=eZQS6-Ydq8FFcEukPtNuDVB-weV6lHyWMyJ1hqTkVrU,2969
 discontinuum/utils.py,sha256=07hIHQk_oDlkjz7tasgBjqqPOC6D0iNcy0eu-88aNbM,1540
 discontinuum/engines/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 discontinuum/engines/base.py,sha256=OlHd4ssIQoWvYHKoVqk5fKAVBcKsIIkR4ul9iNBvaYg,2396
-discontinuum/engines/gpytorch.py,sha256=
+discontinuum/engines/gpytorch.py,sha256=05x7Ha0g2vywM_moL18fMFDGeh0CF3vJpF-mDImrIx8,14387
 discontinuum/engines/pymc.py,sha256=phbtE-3UCSVcP1MhbXwAHIWDZWDr56wK9U7aRt-w-2o,5961
 discontinuum/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 discontinuum/providers/base.py,sha256=Yn2EHS1b4fYl09-m2MYuf2P9VRUXAP-WDpSoZrCbRvY,720
 discontinuum/tests/test_pipeline.py,sha256=_FhkGxbFIxNb35lGaIdZk7Zjgs6CkxEF3gFUX3PE8EU,918
-discontinuum-1.0.
+discontinuum-1.0.6.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
 loadest_gp/__init__.py,sha256=YISfvbc7Zy2y0BOxS1A2KzqxyoNJTz0EnLMnRW6iVT8,740
 loadest_gp/plot.py,sha256=x2PK7vBCc44dX9lu5YV-rvw1u4pvXSLdcrTSvYLiHMA,2595
 loadest_gp/utils.py,sha256=m5QaqR_0JiuRXPfryH8nI5lODp8PqvQla5C05WDN3LY,2772
@@ -22,14 +22,14 @@ loadest_gp/models/pymc.py,sha256=ShP-XNoEwNAf62yCvTuS6Q8iAiB9NQk0dVS69WgkPsE,344
 loadest_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 loadest_gp/providers/usgs.py,sha256=LJ5uh0g3nrZ-8I4poGwCdVqcXcpyZkroixwwt99vBcI,10885
 rating_gp/pipeline.py,sha256=1HgxN6DD3ZL5lhUb3DK2in2IXiml7W4Ja272GBMTc08,1884
-rating_gp/plot.py,sha256=
+rating_gp/plot.py,sha256=_XaeNLYXEcJxg7B4UCxyYccSNzNow0e4dV1z93_THaQ,10899
 rating_gp/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rating_gp/models/base.py,sha256=e2Kq644I88YLHWPNA0qyRgitF5wimdLW4618vKX-o_s,1474
-rating_gp/models/gpytorch.py,sha256=
+rating_gp/models/gpytorch.py,sha256=bNFJFT13DyVqlKhMRD3W0r6-Y72E3S2fw9E-0houyoM,7068
 rating_gp/models/kernels.py,sha256=3xg2mhY3aEgjI3r5vyAll9MA4c3M5UKqRi3FApNhJJQ,11579
 rating_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rating_gp/providers/usgs.py,sha256=KmKYN3c8Mi-ly2l6X80WT3taEhqCPXeEcRNi9HvbJmY,8134
-discontinuum-1.0.
-discontinuum-1.0.
-discontinuum-1.0.
-discontinuum-1.0.
+discontinuum-1.0.6.dist-info/METADATA,sha256=3h3AhrQZ3eNviDGAsbcKEd3yd_cYv8XaaIxMlCQdK0s,6302
+discontinuum-1.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+discontinuum-1.0.6.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
+discontinuum-1.0.6.dist-info/RECORD,,
```
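Each RECORD row has the form `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 SHA-256 of the file with trailing `=` padding stripped (per PEP 376/PEP 427). A small sketch for reproducing an entry locally; the example path assumes an unpacked wheel in the working directory.

```python
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    """Build a wheel RECORD line for a file (hash format per PEP 376/427)."""
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

# e.g. record_entry("discontinuum/_version.py")
# -> "discontinuum/_version.py,sha256=B7xX94ww...,511"
```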
rating_gp/models/gpytorch.py
CHANGED
```diff
@@ -5,9 +5,9 @@ from discontinuum.engines.gpytorch import MarginalGPyTorch, NoOpMean
 
 from gpytorch.kernels import (
     MaternKernel,
-    RBFKernel,
     RQKernel,
     ScaleKernel,
+    PeriodicKernel,
 )
 from gpytorch.priors import (
     GammaPrior,
@@ -15,10 +15,9 @@ from gpytorch.priors import (
     NormalPrior,
 )
 
-from linear_operator.operators import MatmulLinearOperator
 from rating_gp.models.base import RatingDataMixin, ModelConfig
 from rating_gp.plot import RatingPlotMixin
-from rating_gp.models.kernels import
+from rating_gp.models.kernels import SigmoidKernel
 
 
 class PowerLawTransform(torch.nn.Module):
@@ -94,53 +93,55 @@ class ExactGPModel(gpytorch.models.ExactGP):
 
         self.powerlaw = PowerLawTransform()
 
-        # self.mean_module = gpytorch.means.ConstantMean()
-        # self.mean_module = gpytorch.means.LinearMean(input_size=1)
         self.mean_module = NoOpMean()
 
-        #
-
-
-
-        # (self.cov_stage() * self.cov_stagetime())
-        # + self.cov_residual()
-        # )
-
-        # Stage * time kernel with large time length
-        # + stage * time kernel only at low stage with smaller time length.
-        # Note that stage gets transformed to q, so the kernel is actually
-        # q * time
-        b_min = np.quantile(train_y, 0.10)
-        b_max = np.quantile(train_y, 0.90)
+        # Use stage (not y) for sigmoid kernel constraint
+        stage = train_x[:, self.stage_dim[0]]#.cpu().numpy()
+        b_min = np.quantile(stage, 0.10)
+        b_max = np.quantile(stage, 0.90)
         self.covar_module = (
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # core time kernel
+            (
+                self.cov_time(
+                    #ls_prior=GammaPrior(concentration=2, rate=1),
+                    ls_prior=GammaPrior(concentration=3, rate=1),
+                    eta_prior=HalfNormalPrior(scale=0.3),
+                )
+                *
+                self.cov_stage(ls_prior=GammaPrior(concentration=3, rate=2))
+                #self.cov_stage(ls_prior=GammaPrior(concentration=2, rate=1))
+            )
+            # gated shift component
+            + (
+                self.cov_time(
+                    ls_prior=GammaPrior(concentration=2, rate=5),
+                    eta_prior=HalfNormalPrior(scale=1),
+                )
+                * SigmoidKernel(
+                    active_dims=self.stage_dim,
+                    # b_prior=NormalPrior(loc=0.7, scale=0.001),
+                    b_constraint=gpytorch.constraints.Interval(
+                        b_min,
+                        b_max,
+                    ),
+                )
+            )
+            # additive periodic component for seasonal effects
+            + self.cov_periodic()
         )
 
 
     def forward(self, x):
-        self.powerlaw.b.data.clamp_(1.
+        self.powerlaw.b.data.clamp_(1.2, 2.5)
         #x = x.clone()
         #q = self.powerlaw(x[:, self.stage_dim])
         #x_t[:, self.stage_dim] = self.warp_stage_dim(x_t[:, self.stage_dim])
         x_t = x.clone()
         x_t[:, self.stage_dim] = self.powerlaw(x_t[:, self.stage_dim])
         q = x_t[:, self.stage_dim]
-
         mean_x = self.mean_module(q)
-        covar_x = self.covar_module(x_t)
+        #covar_x = self.covar_module(x_t)
+        covar_x = self.covar_module(x)
         return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
 
     def cov_stage(self, ls_prior=None):
@@ -155,56 +156,57 @@ class ExactGPModel(gpytorch.models.ExactGP):
             outputscale_prior=eta,
         )
 
-    def cov_time(self, ls_prior=None):
-
+    def cov_time(self, ls_prior=None, eta_prior=None):
+        if eta_prior is None:
+            eta_prior = HalfNormalPrior(scale=1)
 
         # Base Matern kernel for long-term trends
-
+        return ScaleKernel(
             MaternKernel(
                 active_dims=self.time_dim,
                 lengthscale_prior=ls_prior,
-                nu=1.5, # was
+                nu=1.5, # was 1.5 XXX
             ),
-            outputscale_prior=
+            outputscale_prior=eta_prior,
         )
+
+    def cov_periodic(self, ls_prior=None, eta_prior=None):
+        """
+        Smooth, time-dependent periodic kernel for seasonal effects.
+        """
+        if eta_prior is None:
+            eta_prior = HalfNormalPrior(scale=0.5)
+
+        if ls_prior is None:
+            ls_prior = GammaPrior(concentration=3, rate=1)
 
-
-
-            gpytorch.kernels.PeriodicKernel(
+        return ScaleKernel(
+            PeriodicKernel(
                 active_dims=self.time_dim,
                 period_length_prior=NormalPrior(loc=1.0, scale=0.1), # ~1 year
-                lengthscale_prior=GammaPrior(concentration=2, rate=4),
+                # lengthscale_prior=GammaPrior(concentration=2, rate=4),
            ),
+            # *
+            # MaternKernel(
+            #     active_dims=self.stage_dim,
+            #     lengthscale_prior=ls_prior,
+            #     nu=2.5, # Smoother kernel (was nu=1.5)
+            # ),
             outputscale_prior=HalfNormalPrior(scale=0.5),
         )
-
-        return base_kernel + periodic_kernel
 
-
-
-
-
+    def cov_base(self):
+        """
+        Smooth, time-independent base rating curve using a Matern kernel on stage.
+        """
+        # Base should capture most variation
         eta = HalfNormalPrior(scale=1)
-        ls = GammaPrior(concentration=
-
-        return ScaleKernel(
-            StageTimeKernel(
-                active_dims=self.dims,
-                # lengthscale_prior=ls,
-            ),
-            # outputscale_prior=eta,
-        )
-
-    def cov_residual(self):
-        eta = HalfNormalPrior(scale=0.2)
-        ls = GammaPrior(concentration=2, rate=10)
-
+        ls = GammaPrior(concentration=3, rate=1)
         return ScaleKernel(
             MaternKernel(
-
-                nu=1.5,
-                active_dims=self.dims,
+                active_dims=self.stage_dim,
                 lengthscale_prior=ls,
             ),
            outputscale_prior=eta,
        )
+
```
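The rewritten covariance is a sum of three products: a smooth time kernel times a stage kernel, a shorter-lengthscale time kernel gated by a `SigmoidKernel` on stage (so rating shifts act mainly below the switch point), plus a periodic kernel for seasonality. Below is a structural sketch of that composition using only stock GPyTorch kernels; `SigmoidKernel` is rating_gp-specific, so an `RBFKernel` on stage stands in for the gate, and the dims and priors are illustrative.

```python
from gpytorch.kernels import MaternKernel, PeriodicKernel, RBFKernel, ScaleKernel
from gpytorch.priors import GammaPrior, HalfNormalPrior, NormalPrior

time_dim, stage_dim = (0,), (1,)  # assumed column layout: [time, stage]

# core kernel: slow variation in time, modulated across stage
core = ScaleKernel(
    MaternKernel(active_dims=time_dim, nu=1.5,
                 lengthscale_prior=GammaPrior(concentration=3, rate=1)),
    outputscale_prior=HalfNormalPrior(scale=0.3),
) * ScaleKernel(
    MaternKernel(active_dims=stage_dim,
                 lengthscale_prior=GammaPrior(concentration=3, rate=2)),
)

# gated shift: faster time variation, active only in part of stage space
shift = ScaleKernel(
    MaternKernel(active_dims=time_dim, nu=1.5,
                 lengthscale_prior=GammaPrior(concentration=2, rate=5)),
    outputscale_prior=HalfNormalPrior(scale=1),
) * RBFKernel(active_dims=stage_dim)  # stand-in for the sigmoid gate

# seasonal component with a ~1-year period
seasonal = ScaleKernel(
    PeriodicKernel(active_dims=time_dim,
                   period_length_prior=NormalPrior(loc=1.0, scale=0.1)),
    outputscale_prior=HalfNormalPrior(scale=0.5),
)

covar_module = core + shift + seasonal  # same sum-of-products layout as the diff
```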
rating_gp/plot.py
CHANGED
```diff
@@ -12,6 +12,7 @@ import pandas as pd
 from discontinuum.engines.base import is_fitted
 from discontinuum.plot import BasePlotMixin
 from scipy.stats import norm
+from rating_gp.models.kernels import SigmoidKernel
 import xarray as xr
 from xarray import DataArray
 from xarray.plot.utils import label_from_attrs
@@ -117,6 +118,28 @@ class RatingPlotMixin(BasePlotMixin):
             zorder=1,
             **kwargs
         )
+        # Plot switch point if sigmoid kernel is in model
+        try:
+            # find first SigmoidKernel in covar_module
+            sig_kernels = [m for m in self.model.covar_module.modules() if isinstance(m, SigmoidKernel)]
+            if sig_kernels:
+                sig = sig_kernels[0]
+                # b_sig is in normalized stage space: inverse-transform to original stage units
+                b_sig = sig.b.item()
+                # use scaler step directly to inverse-transform normalized stage
+                pipeline = self.dm.covariate_pipelines['stage']
+                scaler = pipeline.named_steps['scaler']
+                stage_switch = float(scaler.inverse_transform(b_sig))
+                # Draw switch point on every call; label only once
+                if not getattr(ax, 'switch_point_plotted', False):
+                    ax.axvline(stage_switch, linestyle='--', color='gray', label='switch point')
+                    ax.switch_point_plotted = True
+                    ax.legend()
+                else:
+                    # subsequent calls, draw without label
+                    ax.axvline(stage_switch, linestyle='--', color='gray')
+        except Exception:
+            pass
 
         # self.plot_observed_rating(ax, zorder=3)
 
```
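The new plotting block reads the sigmoid switch point `b` in normalized stage space and inverse-transforms it to stage units through the data manager's pipeline. A hedged sketch of that inverse-transform step with a scikit-learn `StandardScaler` follows; sklearn scalers require 2-D input, while discontinuum's own pipeline objects may accept scalars directly, so this shows the general pattern rather than the package's API.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy stage data standing in for the fitted covariate pipeline's scaler
scaler = StandardScaler().fit(np.array([[1.2], [2.3], [3.1]]))

b_sig = 0.4  # switch point in normalized stage space (illustrative value)
# Wrap the scalar as a 2-D array, inverse-transform, and unwrap
stage_switch = float(scaler.inverse_transform([[b_sig]])[0, 0])
print(stage_switch)  # switch point back in original stage units
```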
{discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/WHEEL
File without changes
{discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/licenses/LICENSE.md
File without changes
{discontinuum-1.0.5.dist-info → discontinuum-1.0.6.dist-info}/top_level.txt
File without changes