discontinuum 1.0.4-py3-none-any.whl → 1.0.5-py3-none-any.whl

This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
discontinuum/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '1.0.4'
- __version_tuple__ = version_tuple = (1, 0, 4)
+ __version__ = version = '1.0.5'
+ __version_tuple__ = version_tuple = (1, 0, 5)
discontinuum/engines/gpytorch.py CHANGED
@@ -49,10 +49,11 @@ class MarginalGPyTorch(BaseModel):
  target: Dataset,
  target_unc: Dataset = None,
  iterations: int = 100,
- optimizer: str = "adam",
+ optimizer: str = "adamw",
  learning_rate: float = None,
  early_stopping: bool = False,
- early_stopping_patience: int = 100,
+ patience: int = 60,
+ gradient_noise: bool = False,
  ):
  """Fit the model to data.

@@ -67,13 +68,15 @@ class MarginalGPyTorch(BaseModel):
  iterations : int, optional
  Number of iterations for optimization. The default is 100.
  optimizer : str, optional
- Optimization method. The default is "adam".
+ Optimization method. Supported: "adam", "adamw". The default is "adamw".
  learning_rate : float, optional
  Learning rate for optimization. If None, uses adaptive defaults.
  early_stopping : bool, optional
  Whether to use early stopping. The default is False.
- early_stopping_patience : int, optional
- Number of iterations to wait without improvement before stopping. The default is 100.
+ patience : int, optional
+ Number of iterations to wait without improvement before stopping. The default is 60.
+ gradient_noise : bool, optional
+ Whether to inject Gaussian noise into gradients each step (std = 0.1 × current learning rate). The default is False.
  """
  self.is_fitted = True
  # setup data manager (self.dm)
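The hunk above renames early_stopping_patience to patience, switches the default optimizer to "adamw", and adds a gradient_noise flag. A minimal usage sketch of the new signature follows; model, covariates, and target are placeholders for a MarginalGPyTorch subclass instance and its prepared Datasets, and the positional data arguments are assumed since they sit above this hunk:

    # Hypothetical call against the 1.0.5 signature; the objects are placeholders.
    model.fit(
        covariates,               # covariate Dataset (assumed name, not shown in this hunk)
        target,                   # target Dataset
        iterations=500,
        optimizer="adamw",        # new default; "adam" is still accepted
        early_stopping=True,
        patience=60,              # was early_stopping_patience=100 in 1.0.4
        gradient_noise=False,     # new flag: add N(0, (0.1*lr)^2) noise to gradients when True
    )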
@@ -95,31 +98,39 @@ class MarginalGPyTorch(BaseModel):
  self.model.train()
  self.likelihood.train()

- # Adaptive learning rate selection for faster convergence
  if learning_rate is None:
  if optimizer == "adam":
- learning_rate = 0.1 # More aggressive default for faster convergence
- elif optimizer == "lbfgs":
- learning_rate = 1.0 # L-BFGS doesn't use learning rate the same way
+ learning_rate = 0.1 # Aggressive default for faster convergence
+ elif optimizer == "adamw":
+ learning_rate = 0.1

- # Use the specified optimizer with stabilization
- if optimizer != "adam":
- raise NotImplementedError(f"Only 'adam' optimizer is supported. Got '{optimizer}'.")
- optimizer = torch.optim.Adam(
- self.model.parameters(),
- lr=learning_rate,
- betas=(0.9, 0.999), # Slightly more conservative momentum
- eps=1e-8, # Numerical stability
- weight_decay=1e-4 # Small L2 regularization
- )
- # More responsive learning rate scheduler for faster adaptation
+ if optimizer == "adamw":
+ optimizer_obj = torch.optim.AdamW(
+ self.model.parameters(),
+ lr=learning_rate,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-2 # Stronger regularization for AdamW
+ )
+ elif optimizer == "adam":
+ optimizer_obj = torch.optim.Adam(
+ self.model.parameters(),
+ lr=learning_rate,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=1e-4 # Lighter regularization for Adam
+ )
+ else:
+ raise NotImplementedError(f"Only 'adam' and 'adamw' optimizers are supported. Got '{optimizer}'.")
+
+ # Use ReduceLROnPlateau for more stable learning rate adaptation
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
- optimizer,
- mode='min',
- factor=0.6, # Reduce LR by 40% when loss plateaus (more aggressive)
- patience=40, # Reduce sooner for faster adaptation
- min_lr=1e-5, # Higher minimum learning rate
- threshold=1e-4 # Less sensitive to plateaus
+ optimizer_obj,
+ mode='min',
+ factor=0.5, # Reduce LR by half
+ patience=max(2, patience),
+ threshold=1e-4,
+ min_lr=1e-5
  )

  # "Loss" for GPs - the marginal log likelihood
@@ -133,98 +144,97 @@ class MarginalGPyTorch(BaseModel):
  min_lr_for_early_stop = 2e-5 # Stop if patience is exceeded and LR is below this

  for i in pbar:
- if optimizer.__class__.__name__ == "LBFGS":
- # L-BFGS requires a closure function
- def closure():
- optimizer.zero_grad()
- output = self.model(train_x)
- with gpytorch.settings.cholesky_jitter(jitter):
- loss = -mll(output, train_y).sum()
- loss.backward()
- return loss
-
- loss = optimizer.step(closure)
- pbar.set_postfix(loss=loss.item())
- else:
- # Adam optimizer with stability features
- optimizer.zero_grad()
- output = self.model(train_x)
-
- # Attempt loss calculation with dynamic jitter
- try:
- with gpytorch.settings.cholesky_jitter(jitter):
- loss = -mll(output, train_y)
- except Exception as e:
- # Increase jitter if numerical issues occur
- jitter = min(jitter * 10, 1e-2)
- current_lr = optimizer.param_groups[0]['lr']
- pbar.set_postfix_str(
- f'lr={current_lr:.1e} jitter={jitter:.1e} | Numerical issue - increasing jitter'
- )
- continue
-
- # Check for NaN loss
- if torch.isnan(loss) or torch.isinf(loss):
- current_lr = optimizer.param_groups[0]['lr']
- pbar.set_postfix_str(
- f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
- )
- continue
-
- loss.backward()
-
- # Gradient clipping for stability
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
-
- # Check for NaN gradients
- has_nan_grad = False
+ # Adam/AdamW optimizer with stability features
+ optimizer_obj.zero_grad()
+ output = self.model(train_x)
+
+ # Attempt loss calculation with dynamic jitter
+ try:
+ with gpytorch.settings.cholesky_jitter(jitter):
+ loss = -mll(output, train_y)
+ except Exception as e:
+ # Increase jitter if numerical issues occur
+ jitter = min(jitter * 10, 1e-2)
+ current_lr = optimizer_obj.param_groups[0]['lr']
+ pbar.set_postfix_str(
+ f'lr={current_lr:.1e} jitter={jitter:.1e} | Numerical issue - increasing jitter'
+ )
+ continue
+
+ # Check for NaN loss
+ if torch.isnan(loss) or torch.isinf(loss):
+ current_lr = optimizer_obj.param_groups[0]['lr']
+ pbar.set_postfix_str(
+ f'lr={current_lr:.1e} jitter={jitter:.1e} | NaN/Inf loss detected - skipping step'
+ )
+ continue
+
+ loss.backward()
+
+ # Get current learning rate before gradient noise injection
+ current_lr = optimizer_obj.param_groups[0]['lr']
+
+ # Gradient noise injection (if enabled)
+ if gradient_noise:
+ gradient_noise_scale = 0.1
+ adaptive_noise = gradient_noise_scale * current_lr
  for param in self.model.parameters():
- if param.grad is not None and torch.isnan(param.grad).any():
- has_nan_grad = True
- break
-
- if has_nan_grad:
- # Don't update scheduler on NaN gradients - this prevents rapid LR decay
- # The scheduler should only respond to actual optimization progress
- current_lr = optimizer.param_groups[0]['lr']
-
- # Update best loss tracking (loss is still valid, just gradients are NaN)
- if loss.item() < best_loss:
- best_loss = loss.item()
- patience_counter = 0
- else:
- patience_counter += 1
-
- # Display comprehensive info even with NaN gradients
- pbar.set_postfix_str(
- f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f} | NaN gradients - skipping step'
- )
- continue
-
- optimizer.step()
-
- # Update learning rate scheduler for Adam
- scheduler.step(loss)
- current_lr = optimizer.param_groups[0]['lr']
-
- # Early stopping check (more aggressive)
+ if param.grad is not None:
+ noise = torch.normal(mean=0.0, std=adaptive_noise, size=param.grad.shape, device=param.grad.device)
+ param.grad.add_(noise)
+
+ # Gradient clipping for stability
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+
+ # Check for NaN gradients
+ has_nan_grad = False
+ for param in self.model.parameters():
+ if param.grad is not None and torch.isnan(param.grad).any():
+ has_nan_grad = True
+ break
+
+ if has_nan_grad:
+ # Don't update scheduler on NaN gradients - this prevents rapid LR decay
+ # The scheduler should only respond to actual optimization progress
+ current_lr = optimizer_obj.param_groups[0]['lr']
+
+ # Update best loss tracking (loss is still valid, just gradients are NaN)
  if loss.item() < best_loss:
  best_loss = loss.item()
  patience_counter = 0
  else:
  patience_counter += 1

- # Display progress with comprehensive metadata
+ # Display comprehensive info even with NaN gradients, skip normal progress update
+ pbar.set_postfix_str(
+ f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f} | NaN gradients - skipping step'
+ )
+ continue
+
+ optimizer_obj.step()
+
+ # Update learning rate scheduler for Adam/AdamW
+ scheduler.step(loss.item())
+ current_lr = optimizer_obj.param_groups[0]['lr']
+
+ # Early stopping check (more aggressive)
+ if loss.item() < best_loss:
+ best_loss = loss.item()
+ patience_counter = 0
+ else:
+ patience_counter += 1
+
+ # Only update progress bar if not skipped above
+ if not has_nan_grad:
  progress_info = f'loss={loss.item():.4f} lr={current_lr:.1e} jitter={jitter:.1e} best={best_loss:.4f}'
  if early_stopping:
- progress_info += f' patience={patience_counter}/25'
+ progress_info += f' patience={patience_counter}/{patience}'
  pbar.set_postfix_str(progress_info)

- # More aggressive early stopping: patience=25 and require LR to be low
- if early_stopping and patience_counter >= 25 and current_lr <= min_lr_for_early_stop:
- print(f"\nEarly stopping triggered after {i+1} iterations (patience exceeded and LR low)")
- print(f"Best loss: {best_loss:.6f}")
- break
+ if early_stopping and patience_counter >= patience and current_lr <= min_lr_for_early_stop:
+ print(f"\nEarly stopping triggered after {i+1} iterations")
+ print(f"Best loss: {best_loss:.6f}")
+ break

  @is_fitted
  def predict(self,
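The rewritten loop drops the L-BFGS branch and, when gradient_noise is enabled, perturbs each gradient with zero-mean noise whose standard deviation is 0.1 times the current learning rate, applied before clipping and the NaN-gradient guard. A self-contained sketch of that ordering on a dummy model; it uses randn_like rather than the torch.normal call in the diff, which is equivalent for this purpose:

    import torch

    model = torch.nn.Linear(3, 1)
    opt = torch.optim.AdamW(model.parameters(), lr=0.1)
    x, y = torch.randn(16, 3), torch.randn(16, 1)

    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()

    current_lr = opt.param_groups[0]['lr']
    noise_std = 0.1 * current_lr                        # adaptive noise, shrinks as the LR decays
    for p in model.parameters():
        if p.grad is not None:
            p.grad.add_(torch.randn_like(p.grad) * noise_std)

    # clip after noise injection so the combined gradient norm stays bounded
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # skip the update if any gradient went NaN (mirrors the has_nan_grad guard)
    if not any(p.grad is not None and torch.isnan(p.grad).any() for p in model.parameters()):
        opt.step()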
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: discontinuum
- Version: 1.0.4
+ Version: 1.0.5
  Summary: Estimate discontinuous timeseries from continuous covariates.
  Maintainer-email: Timothy Hodson <thodson@usgs.gov>
  License: License
@@ -1,17 +1,17 @@
  discontinuum/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- discontinuum/_version.py,sha256=rXTOeD0YpRo_kJ2LqUiMnTKEFf43sO_PBvJHDh0SQUA,511
+ discontinuum/_version.py,sha256=cafGA7j4exK_daS29O0wW2JM2p5b8FwYDLqAYJR0jWo,511
  discontinuum/data_manager.py,sha256=LiZoPR0nnu7YAUfh5L1ZDRfaS3dgfVIELXIHkzUKyBg,4416
  discontinuum/pipeline.py,sha256=1avuZnFai-b3HmihcpZ8M3WFNQ8lXAFSNTrnfl2NrY0,10074
  discontinuum/plot.py,sha256=eZQS6-Ydq8FFcEukPtNuDVB-weV6lHyWMyJ1hqTkVrU,2969
  discontinuum/utils.py,sha256=07hIHQk_oDlkjz7tasgBjqqPOC6D0iNcy0eu-88aNbM,1540
  discontinuum/engines/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  discontinuum/engines/base.py,sha256=OlHd4ssIQoWvYHKoVqk5fKAVBcKsIIkR4ul9iNBvaYg,2396
- discontinuum/engines/gpytorch.py,sha256=36TxE_qfRUjuOB16eXmyrxPlicKzXkdQ7xnfqL2ucy0,14539
+ discontinuum/engines/gpytorch.py,sha256=kRyAgCfxjKZbAJhJGViaDU_y8NO8sW4rSWRyEQlomHo,14383
  discontinuum/engines/pymc.py,sha256=phbtE-3UCSVcP1MhbXwAHIWDZWDr56wK9U7aRt-w-2o,5961
  discontinuum/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  discontinuum/providers/base.py,sha256=Yn2EHS1b4fYl09-m2MYuf2P9VRUXAP-WDpSoZrCbRvY,720
  discontinuum/tests/test_pipeline.py,sha256=_FhkGxbFIxNb35lGaIdZk7Zjgs6CkxEF3gFUX3PE8EU,918
- discontinuum-1.0.4.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
+ discontinuum-1.0.5.dist-info/licenses/LICENSE.md,sha256=XElVHHnS2uQ15M_Z2giPH1vmeWMzdpGQ48ItkuZurVA,1650
  loadest_gp/__init__.py,sha256=YISfvbc7Zy2y0BOxS1A2KzqxyoNJTz0EnLMnRW6iVT8,740
  loadest_gp/plot.py,sha256=x2PK7vBCc44dX9lu5YV-rvw1u4pvXSLdcrTSvYLiHMA,2595
  loadest_gp/utils.py,sha256=m5QaqR_0JiuRXPfryH8nI5lODp8PqvQla5C05WDN3LY,2772
@@ -25,11 +25,11 @@ rating_gp/pipeline.py,sha256=1HgxN6DD3ZL5lhUb3DK2in2IXiml7W4Ja272GBMTc08,1884
  rating_gp/plot.py,sha256=CJphwqWWAfIY22j5Oz5DRwj7TcQCRyIQvM79_3KEdlc,9635
  rating_gp/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rating_gp/models/base.py,sha256=e2Kq644I88YLHWPNA0qyRgitF5wimdLW4618vKX-o_s,1474
- rating_gp/models/gpytorch.py,sha256=4SqOdWIvI93kDq9S4cDPHXX25EHNjT_hKwZijhAR4C0,7121
+ rating_gp/models/gpytorch.py,sha256=eFHwtnW44GZ1zz0fLx5REbzIWwnb_x_uq-cGjDcHyWs,6907
  rating_gp/models/kernels.py,sha256=3xg2mhY3aEgjI3r5vyAll9MA4c3M5UKqRi3FApNhJJQ,11579
  rating_gp/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  rating_gp/providers/usgs.py,sha256=KmKYN3c8Mi-ly2l6X80WT3taEhqCPXeEcRNi9HvbJmY,8134
- discontinuum-1.0.4.dist-info/METADATA,sha256=A6T6BQocZmIox600f7nU5Tb9r7x5YthC5ba1WRET2XM,6302
- discontinuum-1.0.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- discontinuum-1.0.4.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
- discontinuum-1.0.4.dist-info/RECORD,,
+ discontinuum-1.0.5.dist-info/METADATA,sha256=GDh-fscmNYYMmXXoUE5Xd1QafuKEOPjgZjlssmjqGVg,6302
+ discontinuum-1.0.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ discontinuum-1.0.5.dist-info/top_level.txt,sha256=mwU_PSFrZYSJrBgqIuTJTo7Pp9ODDv6XdDed7kAagXM,34
+ discontinuum-1.0.5.dist-info/RECORD,,
rating_gp/models/gpytorch.py CHANGED
@@ -110,13 +110,13 @@ class ExactGPModel(gpytorch.models.ExactGP):
  # + stage * time kernel only at low stage with smaller time length.
  # Note that stage gets transformed to q, so the kernel is actually
  # q * time
- b_min = np.quantile(train_y, 0.30)
+ b_min = np.quantile(train_y, 0.10)
  b_max = np.quantile(train_y, 0.90)
  self.covar_module = (
- (self.cov_stage(ls_prior=GammaPrior(concentration=2, rate=1))
+ (self.cov_stage(ls_prior=GammaPrior(concentration=1, rate=1))
  * self.cov_time(ls_prior=GammaPrior(concentration=1, rate=1)))
- + (self.cov_stage(ls_prior=GammaPrior(concentration=5, rate=1))
- * self.cov_time(ls_prior=GammaPrior(concentration=1, rate=5))
+ + (self.cov_stage(ls_prior=GammaPrior(concentration=3, rate=1))
+ * self.cov_time(ls_prior=GammaPrior(concentration=2, rate=5))
  * SigmoidKernel(
  active_dims=self.stage_dim,
  # a_prior=NormalPrior(loc=20, scale=1),
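The hunk above lowers the bottom of the sigmoid switch-point band from the 30th to the 10th percentile of the training target and relaxes the stage and time lengthscale priors. A small numpy sketch of the quantile bounds on synthetic data (the SigmoidKernel and cov_* helpers are package code and are not reproduced here):

    import numpy as np

    rng = np.random.default_rng(0)
    train_y = rng.normal(size=500)              # stand-in for the training target used to set the bounds

    b_min_old = np.quantile(train_y, 0.30)      # 1.0.4 lower bound for the sigmoid switch point
    b_min_new = np.quantile(train_y, 0.10)      # 1.0.5: lower bound moved down, widening the low-stage band
    b_max = np.quantile(train_y, 0.90)
    print(b_min_old, b_min_new, b_max)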
@@ -167,20 +167,15 @@ class ExactGPModel(gpytorch.models.ExactGP):
  ),
  outputscale_prior=eta,
  )
-
- # Periodic kernel for annual seasonality
- # Locally periodic kernel: Periodic * Matern
+
+ # Periodic performs better than a locally periodic kernel
  periodic_kernel = ScaleKernel(
  gpytorch.kernels.PeriodicKernel(
  active_dims=self.time_dim,
- period_length_prior=NormalPrior(loc=1.0, scale=0.05), # ~1 year
- lengthscale_prior=GammaPrior(concentration=6, rate=1),
- ) * MaternKernel(
- active_dims=self.time_dim,
- nu=2.5,
- lengthscale_prior=GammaPrior(concentration=4, rate=3),
+ period_length_prior=NormalPrior(loc=1.0, scale=0.1), # ~1 year
+ lengthscale_prior=GammaPrior(concentration=2, rate=4),
  ),
- outputscale_prior=HalfNormalPrior(scale=0.2),
+ outputscale_prior=HalfNormalPrior(scale=0.5),
  )

  return base_kernel + periodic_kernel
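To read the prior changes in the two hunks above in concrete terms: GammaPrior(concentration=c, rate=r) has mean c/r, and HalfNormalPrior(scale=s) has mean s*sqrt(2/pi). A small arithmetic sketch comparing the 1.0.4 and 1.0.5 settings (values only, stated in the model's transformed units):

    from math import pi, sqrt

    def gamma_mean(concentration, rate):
        return concentration / rate

    def half_normal_mean(scale):
        return scale * sqrt(2 / pi)

    # lengthscale prior means, old -> new
    print(gamma_mean(2, 1), "->", gamma_mean(1, 1))   # stage lengthscale: 2.0 -> 1.0
    print(gamma_mean(5, 1), "->", gamma_mean(3, 1))   # low-stage term, stage lengthscale: 5.0 -> 3.0
    print(gamma_mean(1, 5), "->", gamma_mean(2, 5))   # low-stage term, time lengthscale: 0.2 -> 0.4
    print(gamma_mean(6, 1), "->", gamma_mean(2, 4))   # periodic lengthscale: 6.0 -> 0.5
    # periodic outputscale prior mean: HalfNormal scale 0.2 -> 0.5
    print(half_normal_mean(0.2), "->", half_normal_mean(0.5))   # ~0.16 -> ~0.40

Together with the wider NormalPrior on the period length (scale 0.05 -> 0.1 around 1 year), the new settings favor a shorter-lengthscale, higher-amplitude seasonal component while dropping the Matern factor that previously made it only locally periodic.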