libinephany 0.19.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
@@ -43,15 +43,15 @@ class Observer(ABC):
         standardizer: Standardizer | None,
         observer_config: ObserverConfig,
         should_standardize: bool = True,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
         :param standardizer: None or the standardizer to apply to the returned observations.
         :param global_config: ObserverConfig that can be used to inform various observation calculations.
         :param should_standardize: Whether standardization should be applied to returned values.
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
@@ -63,7 +63,10 @@ class Observer(ABC):
         self.standardize = standardizer if standardizer is not None else observation_utils.null_standardizer
         self.should_standardize = should_standardize and self.can_standardize
 
-        self.skip_statistics = TensorStatistics.filter_skip_statistics(skip_statistics=skip_statistics)
+        self.include_statistics: list[str] | None = None
+
+        if include_statistics is not None:
+            self.include_statistics = TensorStatistics.filter_include_statistics(include_statistics=include_statistics)
 
     @final
     @property
@@ -102,7 +105,10 @@ class Observer(ABC):
         observation_format = self.observation_format
 
         if observation_format is StatisticStorageTypes.TENSOR_STATISTICS:
-            return len([field for field in TensorStatistics.model_fields.keys() if field not in self.skip_statistics])
+            if self.include_statistics is None:
+                raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
+            return len([field for field in TensorStatistics.model_fields.keys() if field in self.include_statistics])
 
         elif observation_format is StatisticStorageTypes.FLOAT:
             return 1
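
The three hunks above replace the exclude-style `skip_statistics` list with an include-list: an observer now declares which `TensorStatistics` fields it emits, and its observation width is the count of included fields. A minimal sketch of the new semantics, using a hypothetical `StandInStatistics` model in place of the real `TensorStatistics` and a plausible (unverified) behaviour for `filter_include_statistics`:

```python
from pydantic import BaseModel


class StandInStatistics(BaseModel):
    # Hypothetical stand-in: the real TensorStatistics model has more fields.
    mean: float = 0.0
    variance: float = 0.0
    norm_: float = 0.0


def filter_include_statistics(include_statistics: list[str]) -> list[str]:
    # Plausible behaviour for TensorStatistics.filter_include_statistics:
    # keep only names that are actual fields of the model.
    return [name for name in include_statistics if name in StandInStatistics.model_fields]


include_statistics = filter_include_statistics(["mean", "norm_", "not_a_field"])
print(include_statistics)  # ['mean', 'norm_']

# The observation width is now the count of included fields rather than
# all model fields minus a skip list.
print(len([f for f in StandInStatistics.model_fields if f in include_statistics]))  # 2
```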
@@ -231,10 +237,13 @@ class Observer(ABC):
         self._cached_observation = deepcopy(observations)
 
         if self.observation_format is StatisticStorageTypes.TENSOR_STATISTICS:
+            if self.include_statistics is None:
+                raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
             if return_dict:
                 observations_dict = observations.as_observation_dict()  # type: ignore
 
-            observations = observations.to_list(skip_statistics=self.skip_statistics)  # type: ignore
+            observations = observations.to_list(include_statistics=self.include_statistics)  # type: ignore
 
             observations = [observations] if not isinstance(observations, list) else observations  # type: ignore
 
@@ -256,7 +265,7 @@ class Observer(ABC):
     def inform(self) -> float | int | dict[str, float] | None:
         """
         :return: The cached observation. If the observation format is TensorStatistics then it is converted to a
-            dictionary with the statistics specified in skip_statistics excluded.
+            dictionary with the statistics specified in include_statistics included.
         """
 
         if not self.can_inform:
@@ -269,7 +278,10 @@ class Observer(ABC):
             )
 
         if self.observation_format is StatisticStorageTypes.TENSOR_STATISTICS:
-            observation = self._cached_observation.model_dump(exclude=set(self.skip_statistics))  # type: ignore
+            if self.include_statistics is None:
+                raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
+            observation = self._cached_observation.model_dump(include=set(self.include_statistics))  # type: ignore
 
         else:
            observation = self._cached_observation
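
The `inform` hunk above maps directly onto pydantic v2's `model_dump`, which accepts either an `exclude` or an `include` set of field names. A self-contained sketch with a stand-in model:

```python
from pydantic import BaseModel


class StandInStatistics(BaseModel):
    mean: float = 1.0
    variance: float = 2.0
    norm_: float = 3.0


stats = StandInStatistics()

# 0.19.0 behaviour: dump everything except the skip list.
print(stats.model_dump(exclude={"variance"}))       # {'mean': 1.0, 'norm_': 3.0}

# 1.0.1 behaviour: dump only the include list.
print(stats.model_dump(include={"mean", "norm_"}))  # {'mean': 1.0, 'norm_': 3.0}
```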
@@ -53,7 +53,7 @@ class GlobalFirstOrderGradients(GlobalObserver):
         needed.
         """
 
-        return {statistic_trackers.FirstOrderGradients.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.FirstOrderGradients.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class GlobalSecondOrderGradients(GlobalObserver):
@@ -110,7 +110,7 @@ class GlobalSecondOrderGradients(GlobalObserver):
 
         return {
             statistic_trackers.SecondOrderGradients.__name__: dict(
-                skip_statistics=self.skip_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
+                include_statistics=self.include_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
             )
         }
 
@@ -252,7 +252,7 @@ class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
 
         return {
             statistic_trackers.MomentumGradientRatioStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics,
+                include_statistics=self.include_statistics,
                 sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
             ),
         }
@@ -269,18 +269,17 @@ class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
-            or use the squared first order gradients as approximations in the same way Adam does.
+        :param include_statistics: List of statistics to include.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
 
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
@@ -338,7 +337,7 @@ class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
 
         return {
             statistic_trackers.CosineSimilarityObserverOfGradientAndMomentumStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics,
+                include_statistics=self.include_statistics,
                 sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
             )
         }
@@ -355,17 +354,17 @@ class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: List of statistics to skip.
+        :param include_statistics: List of statistics to include.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
 
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
@@ -423,7 +422,7 @@ class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
 
         return {
             statistic_trackers.CosineSimilarityObserverOfGradientAndUpdateStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics,
+                include_statistics=self.include_statistics,
                 sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
             )
         }
@@ -440,16 +439,16 @@ class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
     def __init__(
         self,
         *,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
-        :param skip_statistics: List of statistics to skip.
+        :param include_statistics: List of statistics to include.
         :param kwargs: Miscellaneous keyword arguments.
         """
         super().__init__(**kwargs)
 
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
 
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
@@ -505,7 +504,7 @@ class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
 
         return {
             statistic_trackers.CosineSimilarityOfGradientAndParameterStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics,
+                include_statistics=self.include_statistics,
                 sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
             )
         }
@@ -21,18 +21,16 @@ from libinephany.utils.enums import ModelFamilies
 
 class InitialHyperparameters(GlobalObserver):
 
-    def __init__(self, skip_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        force_skip = ["samples", "gradient_accumulation"]
-        skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
-        self.skip_hparams = [] if skip_hparams is None else skip_hparams
+        self.include_hparams = include_hparams
         self.pad_with = pad_with
 
     @property
@@ -41,9 +39,12 @@ class InitialHyperparameters(GlobalObserver):
         :return: Length of the vector returned by this observation if it returns a vector.
         """
 
+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_all_hyperparameters()
 
-        return len([hparam for hparam in available_hparams if hparam not in self.skip_hparams])
+        return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
 
     @property
     def can_standardize(self) -> bool:
@@ -85,12 +86,14 @@
         :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
         """
 
-        initial_internal_values = hyperparameter_states.get_initial_internal_values(self.skip_hparams)
+        assert self.include_hparams is not None
+
+        initial_internal_values = hyperparameter_states.get_initial_internal_values(self.include_hparams)
         self._cached_observation = initial_internal_values
         initial_internal_values_list = [
             self.pad_with if initial_internal_value is None else initial_internal_value
             for hparam_name, initial_internal_value in initial_internal_values.items()
-            if hparam_name not in self.skip_hparams
+            if hparam_name in self.include_hparams
         ]
         return initial_internal_values_list
 
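A sketch of the include-and-pad logic in `_observe` above; the hyperparameter names and values here are illustrative, not taken from the package:

```python
include_hparams = ["learning_rate", "weight_decay"]
pad_with = 0.0

# Illustrative stand-in for hyperparameter_states.get_initial_internal_values(...).
initial_internal_values = {
    "learning_rate": -6.9,
    "weight_decay": None,  # unset, so it is replaced by pad_with
    "samples": 32.0,       # not in include_hparams, so it is dropped
}

initial_internal_values_list = [
    pad_with if value is None else value
    for hparam_name, value in initial_internal_values.items()
    if hparam_name in include_hparams
]
print(initial_internal_values_list)  # [-6.9, 0.0]
```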
@@ -179,7 +182,8 @@ class ModelFamilyOneHot(GlobalObserver):
         **kwargs,
     ) -> None:
         """
-        :param skip_observations: List of episode boundary observations to ignore.
+        :param zero_vector_chance: Chance of the output vector being masked with zeros.
+        :param zero_vector_frequency_unit: Unit of time to sample the zero vector.
         :param kwargs: Miscellaneous keyword arguments.
         """
         super().__init__(**kwargs)
@@ -294,17 +298,16 @@ class LHOPTHyperparameterRatio(GlobalObserver):
     providing insights into how much hyperparameters have changed from their starting values.
     """
 
-    def __init__(self, skip_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        force_skip = ["samples", "gradient_accumulation"]
-        self.skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
+        self.include_hparams = include_hparams
         self.pad_with = pad_with
 
     @property
@@ -313,9 +316,12 @@
         :return: Length of the vector returned by this observation if it returns a vector.
         """
 
+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_all_hyperparameters()
 
-        return len([hparam for hparam in available_hparams if hparam not in self.skip_hparams])
+        return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
 
     @property
     def can_standardize(self) -> bool:
@@ -357,18 +363,20 @@ class LHOPTHyperparameterRatio(GlobalObserver):
         :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
         """
 
+        assert self.include_hparams is not None
+
         # Get initial and current hyperparameter values
-        initial_values = hyperparameter_states.get_initial_internal_values(self.skip_hparams)
+        initial_values = hyperparameter_states.get_initial_internal_values(self.include_hparams)
         initial_values = {
             hparam_name: self.pad_with if initial_value is None else initial_value
             for hparam_name, initial_value in initial_values.items()
-            if hparam_name not in self.skip_hparams
+            if hparam_name in self.include_hparams
         }
-        current_values = hyperparameter_states.get_current_internal_values(self.skip_hparams)
+        current_values = hyperparameter_states.get_current_internal_values(self.include_hparams)
         current_values = {
             hparam_name: self.pad_with if current_value is None else current_value
             for hparam_name, current_value in current_values.items()
-            if hparam_name not in self.skip_hparams
+            if hparam_name in self.include_hparams
         }
 
         ratios = []
@@ -59,7 +59,7 @@ class GlobalActivations(GlobalObserver):
         needed.
         """
 
-        return {statistic_trackers.ActivationStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ActivationStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class GlobalParameterUpdates(GlobalObserver):
@@ -98,7 +98,7 @@ class GlobalParameterUpdates(GlobalObserver):
         needed.
         """
 
-        return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class GlobalParameters(GlobalObserver):
@@ -137,7 +137,7 @@ class GlobalParameters(GlobalObserver):
         needed.
         """
 
-        return {statistic_trackers.ParameterStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class GlobalLAMBTrustRatio(GlobalObserver):
@@ -385,9 +385,7 @@ class LogRatioOfPreviousAndCurrentParamNormEnvStepObserver(LHOPTBaseObserver):
         """
 
         return {
-            statistic_trackers.ParameterStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics,
-            ),
+            statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics),
         }
 
     def reset(self) -> None:
@@ -443,9 +441,11 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
             self._compute_cdf_feature(0.0)  # default value since we can't compute log ratio yet
             self._update_time()
             return [0.0, 0.0]
+
         log_ratio = self._compute_log_ratio(update_norm, self._previous_param_norm)
         tanh_feature = math.tanh(max(-LHOPT_CONSTANTS["TANH_BOUND"], min(LHOPT_CONSTANTS["TANH_BOUND"], log_ratio)))
         cdf_feature = self._compute_cdf_feature(log_ratio)
+
         self._update_time()
         self._previous_param_norm = current_param_norm
 
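The hunk above builds a bounded feature from the log ratio of two norms. A sketch of that transform, assuming `_compute_log_ratio` is a plain log of the norm ratio (its body is not shown in this diff) and substituting a placeholder for `LHOPT_CONSTANTS["TANH_BOUND"]`:

```python
import math

TANH_BOUND = 10.0  # placeholder; the real bound lives in LHOPT_CONSTANTS


def clamped_tanh_feature(update_norm: float, previous_param_norm: float) -> float:
    # Assumed form of _compute_log_ratio: log of the ratio of the two norms.
    log_ratio = math.log(update_norm / previous_param_norm)
    # Clip to +/- TANH_BOUND, then squash into (-1, 1), as in the hunk above.
    return math.tanh(max(-TANH_BOUND, min(TANH_BOUND, log_ratio)))


print(clamped_tanh_feature(1e-3, 1.0))  # ~-0.999998: tiny updates saturate toward -1
```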
@@ -458,12 +458,8 @@
         """
 
         return {
-            statistic_trackers.ParameterUpdateStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics,
-            ),
-            statistic_trackers.ParameterStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics,
-            ),
+            statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics),
+            statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics),
         }
 
     def reset(self) -> None:
@@ -533,7 +529,7 @@ class LHOPTAverageParameterUpdateMagnitudeObserver(LHOPTBaseObserver):
 
         return {
             statistic_trackers.AverageParameterUpdateMagnitudeStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+                include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
             )
         }
 
@@ -570,8 +566,8 @@ class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
         :return: List containing [raw_log_ratio, cdf_feature].
         """
 
-        update_statistics = tracked_statistics[statistic_trackers.InnerStepParameterUpdateStatistics.__name__]
-        param_statistics = tracked_statistics[statistic_trackers.InnerStepParameterStatistics.__name__]
+        update_statistics = tracked_statistics[statistic_trackers.LHOPTParameterUpdateStatistics.__name__]
+        param_statistics = tracked_statistics[statistic_trackers.LHOPTParameterStatistics.__name__]
         update_norm = observation_utils.average_tensor_statistics(
             tensor_statistics=[stats for stats in update_statistics.values() if isinstance(stats, TensorStatistics)]
         ).norm_
@@ -600,11 +596,11 @@ class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
         """
 
         return {
-            statistic_trackers.InnerStepParameterUpdateStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+            statistic_trackers.LHOPTParameterUpdateStatistics.__name__: dict(
+                include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
             ),
-            statistic_trackers.InnerStepParameterStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+            statistic_trackers.LHOPTParameterStatistics.__name__: dict(
+                include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
             ),
         }
 
@@ -680,6 +676,8 @@ class LHOPTGlobalLAMBTrustRatio(LHOPTBaseObserver):
 
         return {
             statistic_trackers.LHOPTLAMBTrustRatioStatistics.__name__: dict(
-                use_log_transform=self.use_log_transform, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+                include_statistics=self.include_statistics,
+                use_log_transform=self.use_log_transform,
+                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
             )
         }
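
This hunk also shows the tracker-spec pattern the rest of the release follows: each observer returns a mapping of tracker class name to keyword arguments, and `include_statistics` is now threaded through it (here `LHOPTLAMBTrustRatioStatistics` previously received no statistics filter at all). A sketch of the shape of that mapping, with a placeholder sample frequency:

```python
include_statistics = ["mean", "norm_"]
use_log_transform = True
DEFAULT_SAMPLE_FREQUENCY = 1  # placeholder for LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]

# Keys are tracker class names (statistic_trackers.X.__name__ in the package);
# values are the keyword arguments used to construct each tracker.
required_trackers = {
    "LHOPTLAMBTrustRatioStatistics": dict(
        include_statistics=include_statistics,
        use_log_transform=use_log_transform,
        sample_frequency=DEFAULT_SAMPLE_FREQUENCY,
    ),
}
print(required_trackers)
```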
@@ -1,6 +1,6 @@
 # ======================================================================================================================
 #
-# imports
+# IMPORTS
 #
 # ======================================================================================================================
 
@@ -13,6 +13,12 @@ from libinephany.pydantic_models.schemas.observation_models import ObservationIn
 from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
 from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
 
+# ======================================================================================================================
+#
+# CLASSES
+#
+# ======================================================================================================================
+
 
 class TrainingProgress(GlobalObserver):
 
@@ -68,7 +68,7 @@ class FirstOrderGradients(LocalObserver):
         needed.
         """
 
-        return {statistic_trackers.FirstOrderGradients.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.FirstOrderGradients.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class SecondOrderGradients(LocalObserver):
@@ -132,7 +132,7 @@ class SecondOrderGradients(LocalObserver):
 
         return {
             statistic_trackers.SecondOrderGradients.__name__: dict(
-                skip_statistics=self.skip_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
+                include_statistics=self.include_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
             )
         }
 
@@ -180,7 +180,7 @@ class Activations(LocalObserver):
         needed.
         """
 
-        return {statistic_trackers.ActivationStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ActivationStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class ParameterUpdates(LocalObserver):
@@ -226,7 +226,7 @@ class ParameterUpdates(LocalObserver):
         needed.
         """
 
-        return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ParameterUpdateStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class Parameters(LocalObserver):
@@ -272,7 +272,7 @@ class Parameters(LocalObserver):
         needed.
         """
 
-        return {statistic_trackers.ParameterStatistics.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.ParameterStatistics.__name__: dict(include_statistics=self.include_statistics)}
 
 
 class LAMBTrustRatio(LocalObserver):
@@ -692,16 +692,16 @@ class ModuleTypeOneHot(LocalObserver):
 
 class CurrentHyperparameters(LocalObserver):
 
-    def __init__(self, skip_hparams: list[str] | None = None, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_hparams = skip_hparams if skip_hparams is not None else []
+        self.include_hparams = include_hparams
 
     @property
     def can_standardize(self) -> bool:
@@ -725,11 +725,12 @@ class CurrentHyperparameters(LocalObserver):
         :return: Length of the vector returned by this observation if it returns a vector.
         """
 
+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
 
-        return len(
-            [hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
-        )
+        return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
 
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
@@ -758,7 +759,7 @@ class CurrentHyperparameters(LocalObserver):
         assert self.parameter_group_name is not None
 
         current_internal_values = hyperparameter_states[self.parameter_group_name].get_current_internal_values(
-            skip_hparams=self.skip_hparams
+            include_hparams=self.include_hparams
        )
 
         self._cached_observation = current_internal_values
@@ -776,16 +777,16 @@ class CurrentHyperparameterDeltas(LocalObserver):
 
 class CurrentHyperparameterDeltas(LocalObserver):
 
-    def __init__(self, skip_hparams: list[str] | None = None, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the initial deltas vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the initial deltas vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_hparams = skip_hparams if skip_hparams is not None else []
+        self.include_hparams = include_hparams
 
     @property
     def can_standardize(self) -> bool:
@@ -809,11 +810,12 @@ class CurrentHyperparameterDeltas(LocalObserver):
         :return: Length of the vector returned by this observation if it returns a vector.
         """
 
+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
 
-        return len(
-            [hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
-        )
+        return len([hparam for hparam in available_hparams if hparam in self.include_hparams])
 
     def _get_observation_format(self) -> StatisticStorageTypes:
         """
@@ -840,9 +842,10 @@
         """
 
         assert self.parameter_group_name is not None
+        assert self.include_hparams is not None
 
         current_deltas = hyperparameter_states[self.parameter_group_name].get_current_deltas(
-            skip_hparams=self.skip_hparams
+            include_hparams=self.include_hparams
         )
 
         self._cached_observation = current_deltas
@@ -862,16 +865,16 @@ class HyperparameterTransformTypes(LocalObserver):
 
     TRANSFORM_TYPE_TO_IDX = dict(((s, i) for i, s in enumerate(HyperparameterTransformType)))
 
-    def __init__(self, skip_hparams: list[str] | None = None, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the transforms vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the transforms vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_hparams = skip_hparams if skip_hparams is not None else []
+        self.include_hparams = include_hparams
 
     @property
     def can_standardize(self) -> bool:
@@ -895,10 +898,13 @@
         :return: Length of the vector returned by this observation if it returns a vector.
         """
 
+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_layerwise_hyperparameters()
 
         return len(HyperparameterTransformType) * len(
-            [hparam for hparam in available_hparams if not any(skipped in hparam for skipped in self.skip_hparams)]
+            [hparam for hparam in available_hparams if hparam in self.include_hparams]
         )
 
     def _get_observation_format(self) -> StatisticStorageTypes:
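
The multiplication in the hunk above reflects that the transform-type observation is one-hot per (transform type, included hyperparameter) pair. A sketch with a stand-in enum for `HyperparameterTransformType`:

```python
from enum import Enum


class StandInTransformType(Enum):  # stand-in for HyperparameterTransformType
    LINEAR = "linear"
    LOG = "log"


available_hparams = ["learning_rate", "weight_decay", "dropout"]
include_hparams = ["learning_rate", "dropout"]

# One slot per (transform type, included hyperparameter) pair.
length = len(StandInTransformType) * len(
    [hparam for hparam in available_hparams if hparam in include_hparams]
)
print(length)  # 2 transform types x 2 included hyperparameters = 4
```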
@@ -926,10 +932,11 @@
         """
 
         assert self.parameter_group_name is not None
+        assert self.include_hparams is not None
 
         parameter_group_hparams = hyperparameter_states[self.parameter_group_name]
         hyperparameter_transform_types = parameter_group_hparams.get_hyperparameter_transform_types(
-            skip_hparams=self.skip_hparams
+            include_hparams=self.include_hparams
         )
         hyperparameter_transform_types_onehot_list = [
             observation_utils.create_one_hot_observation(
@@ -1080,20 +1087,20 @@ class LogOfNoiseScaleObserver(LocalObserver):
         *,
         decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
         time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
         :param decay_factor: Decay factor for CDF calculation in [1, 2.5, 5, 10, 20]
         :param time_window: Number of time steps to consider for CDF calculation
-        :param skip_statistics: Whether to skip the statistics
+        :param include_statistics: List of statistics to include.
             or use the squared first order gradients as approximations in the same way Adam does.
         :param kwargs: Miscellaneous keyword arguments.
         """
 
         super().__init__(**kwargs)
 
-        self.skip_statistics = skip_statistics
+        self.include_statistics = include_statistics
         self.decay_factor = max(0.0, decay_factor)
         self.time_window = max(1, time_window)
 
@@ -1187,7 +1194,7 @@
 
         return {
             statistic_trackers.LogOfNoiseScaleStatistics.__name__: dict(
-                skip_statistics=self.skip_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+                include_statistics=self.include_statistics, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
             )
         }
 