libinephany 0.18.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
@@ -25,6 +25,7 @@ from libinephany.utils import optim_utils
 # ======================================================================================================================

 EXP_AVERAGE = "exp_avg"
+MOMENTUM_BUFFER = "momentum_buffer"
 MIN_DECAY_FACTOR = 1e-10

 MIN_TOTAL_WEIGHT = 1e-15  # Minimum total weight threshold for numerical stability
@@ -64,10 +65,8 @@ def get_exponential_weighted_average(values: list[int | float]) -> float:
     :param values: List of values to average via EWA.
     :return: EWA of the given values.
     """
-
     exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
     assert isinstance(exp_weighted_average, float)
-
     return exp_weighted_average


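For reference, a minimal self-contained sketch of what this helper computes: pandas' ewm(alpha=0.1) running mean, whose last entry is the exponentially weighted average of the whole list (the input values and the print are illustrative only):

    import pandas as pd

    values = [1.0, 2.0, 3.0, 4.0]

    # ewm(alpha=0.1) weights recent entries more heavily; .iloc[-1] takes the
    # running mean over the full list, as get_exponential_weighted_average does.
    ewa = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
    print(ewa)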
@@ -232,6 +231,24 @@ def form_update_tensor(
         raise NotImplementedError(f"Optimizer {type(optimizer).__name__} is not supported!")


+def form_momentum_tensor(
+    optimizer: optim.Optimizer, parameters: list[torch.Tensor], parameter_group: dict[str, Any]
+) -> None | torch.Tensor:
+    """
+    :param optimizer: Optimizer to form the momentum tensor from.
+    :param parameters: Parameters to create the momentum tensor from.
+    :param parameter_group: Parameter group within the optimizer the given parameters came from.
+    """
+    if type(optimizer) in optim_utils.ADAM_OPTIMISERS:
+        momentum_list = [optimizer.state[p][EXP_AVERAGE].view(-1) for p in parameters if tensor_on_local_rank(p)]
+        return torch.cat(momentum_list) if momentum_list else None
+    elif type(optimizer) in optim_utils.SGD_OPTIMISERS:
+        momentum_list = [optimizer.state[p][MOMENTUM_BUFFER].view(-1) for p in parameters if tensor_on_local_rank(p)]
+        return torch.cat(momentum_list) if momentum_list else None
+    else:
+        raise NotImplementedError(f"Optimizer {type(optimizer).__name__} is not supported!")
+
+
 def null_standardizer(value_to_standardize: float, **kwargs) -> float:
     """
     :param value_to_standardize: Value to mock the standardization of.
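The new form_momentum_tensor flattens each parameter's optimizer momentum state into a single vector, reading exp_avg for Adam-family optimizers and momentum_buffer for SGD-family ones. A minimal sketch of the underlying state access using plain torch; optim_utils.ADAM_OPTIMISERS, optim_utils.SGD_OPTIMISERS and tensor_on_local_rank are library internals not shown in this diff, so they are omitted here:

    import torch
    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    model(torch.randn(8, 4)).sum().backward()
    optimizer.step()  # the first step populates state["momentum_buffer"] per parameter

    # Mirrors the SGD branch: flatten each buffer and concatenate into one vector.
    momentum_list = [optimizer.state[p]["momentum_buffer"].view(-1) for p in model.parameters()]
    momentum = torch.cat(momentum_list) if momentum_list else None
    print(momentum.shape)  # torch.Size([10]): 8 weight entries + 2 bias entries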
@@ -43,15 +43,15 @@ class Observer(ABC):
         standardizer: Standardizer | None,
         observer_config: ObserverConfig,
         should_standardize: bool = True,
-        skip_statistics: list[str] | None = None,
+        include_statistics: list[str] | None = None,
         **kwargs,
     ) -> None:
         """
         :param standardizer: None or the standardizer to apply to the returned observations.
         :param global_config: ObserverConfig that can be used to inform various observation calculations.
         :param should_standardize: Whether standardization should be applied to returned values.
-        :param skip_statistics: If the observation uses the TensorStatistic model to return observations, names of the
-            fields in the model to not include in returned observations.
+        :param include_statistics: If the observation uses the TensorStatistic model to return observations, names of the
+            fields in the model to include in returned observations.
         :param kwargs: Miscellaneous keyword arguments.
         """

@@ -63,7 +63,10 @@ class Observer(ABC):
         self.standardize = standardizer if standardizer is not None else observation_utils.null_standardizer
         self.should_standardize = should_standardize and self.can_standardize

-        self.skip_statistics = TensorStatistics.filter_skip_statistics(skip_statistics=skip_statistics)
+        self.include_statistics: list[str] | None = None
+
+        if include_statistics is not None:
+            self.include_statistics = TensorStatistics.filter_include_statistics(include_statistics=include_statistics)

     @final
     @property
@@ -102,7 +105,10 @@ class Observer(ABC):
         observation_format = self.observation_format

         if observation_format is StatisticStorageTypes.TENSOR_STATISTICS:
-            return len([field for field in TensorStatistics.model_fields.keys() if field not in self.skip_statistics])
+            if self.include_statistics is None:
+                raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
+            return len([field for field in TensorStatistics.model_fields.keys() if field in self.include_statistics])

         elif observation_format is StatisticStorageTypes.FLOAT:
             return 1
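With the inversion from skip_statistics to include_statistics, the observation length is now the count of explicitly included fields. A sketch with a stand-in pydantic model (TensorStatistics' real field set is not shown in this diff):

    from pydantic import BaseModel

    class Stats(BaseModel):  # stand-in for TensorStatistics
        mean: float = 0.0
        std: float = 0.0
        norm: float = 0.0

    include_statistics = ["mean", "norm"]

    # Iterating model_fields yields field names; only included ones are counted.
    print(len([field for field in Stats.model_fields if field in include_statistics]))  # 2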
@@ -231,10 +237,13 @@ class Observer(ABC):
         self._cached_observation = deepcopy(observations)

         if self.observation_format is StatisticStorageTypes.TENSOR_STATISTICS:
+            if self.include_statistics is None:
+                raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
             if return_dict:
                 observations_dict = observations.as_observation_dict()  # type: ignore

-            observations = observations.to_list(skip_statistics=self.skip_statistics)  # type: ignore
+            observations = observations.to_list(include_statistics=self.include_statistics)  # type: ignore

             observations = [observations] if not isinstance(observations, list) else observations  # type: ignore

@@ -256,7 +265,7 @@ class Observer(ABC):
     def inform(self) -> float | int | dict[str, float] | None:
         """
         :return: The cached observation. If the observation format is TensorStatistics then it is converted to a
-            dictionary with the statistics specified in skip_statistics excluded.
+            dictionary with the statistics specified in include_statistics included.
         """

         if not self.can_inform:
@@ -269,7 +278,10 @@ class Observer(ABC):
         )

         if self.observation_format is StatisticStorageTypes.TENSOR_STATISTICS:
-            observation = self._cached_observation.model_dump(exclude=set(self.skip_statistics))  # type: ignore
+            if self.include_statistics is None:
+                raise ValueError(f"{self.__class__.__name__} must be provided with include_statistics.")
+
+            observation = self._cached_observation.model_dump(include=set(self.include_statistics))  # type: ignore

         else:
             observation = self._cached_observation
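inform() correspondingly switches pydantic's model_dump from an exclude filter to an include filter; the two produce identical dictionaries when the field lists are complements. A self-contained sketch with the same kind of stand-in model:

    from pydantic import BaseModel

    class Stats(BaseModel):  # stand-in for TensorStatistics
        mean: float = 0.0
        std: float = 0.0
        norm: float = 0.0

    stats = Stats()
    print(stats.model_dump(exclude={"std"}))           # old style: drop the listed fields
    print(stats.model_dump(include={"mean", "norm"}))  # new style: keep the listed fields
    # both print {'mean': 0.0, 'norm': 0.0}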
@@ -8,7 +8,15 @@
 # ======================================================================================================================


-from .gradient_observers import GlobalFirstOrderGradients, GlobalSecondOrderGradients, LHOPTGradientVarianceFraction
+from .gradient_observers import (
+    CosineSimilarityObserverOfGradientAndMomentum,
+    CosineSimilarityObserverOfGradientAndUpdate,
+    CosineSimilarityOfGradientAndParameter,
+    GlobalFirstOrderGradients,
+    GlobalSecondOrderGradients,
+    LHOPTGradientVarianceFraction,
+    LHOPTMomentumGradientRatio,
+)
 from .hyperparameter_observers import (
     InitialHyperparameters,
     LHOPTHyperparameterRatio,
@@ -31,8 +39,11 @@ from .model_observers import (
     GlobalLAMBTrustRatio,
     GlobalParameters,
     GlobalParameterUpdates,
+    LHOPTAverageParameterUpdateMagnitudeObserver,
+    LHOPTGlobalLAMBTrustRatio,
     LogRatioOfPreviousAndCurrentParamNormEnvStepObserver,
     LogRatioOfUpdateAndPreviousParamNormEnvStepObserver,
+    LogRatioOfUpdateAndPreviousParamNormInnerStepObserver,
     NumberOfLayers,
     NumberOfParameters,
 )
@@ -51,14 +62,17 @@ __all__ = [
     GlobalFirstOrderGradients.__name__,
     GlobalSecondOrderGradients.__name__,
     LHOPTGradientVarianceFraction.__name__,
+    LHOPTMomentumGradientRatio.__name__,
     GlobalActivations.__name__,
     GlobalParameterUpdates.__name__,
     GlobalParameters.__name__,
     GlobalLAMBTrustRatio.__name__,
     NumberOfParameters.__name__,
     NumberOfLayers.__name__,
+    LHOPTAverageParameterUpdateMagnitudeObserver.__name__,
     LogRatioOfPreviousAndCurrentParamNormEnvStepObserver.__name__,
     LogRatioOfUpdateAndPreviousParamNormEnvStepObserver.__name__,
+    LogRatioOfUpdateAndPreviousParamNormInnerStepObserver.__name__,
     TrainingProgress.__name__,
     EpochsCompleted.__name__,
     ProgressAtEachCheckpoint.__name__,
@@ -66,4 +80,8 @@ __all__ = [
     LHOPTValidationLoss.__name__,
     LHOPTLossRatio.__name__,
     PercentileOfLossAtEachCheckpoint.__name__,
+    LHOPTGlobalLAMBTrustRatio.__name__,
+    CosineSimilarityObserverOfGradientAndMomentum.__name__,
+    CosineSimilarityObserverOfGradientAndUpdate.__name__,
+    CosineSimilarityOfGradientAndParameter.__name__,
 ]
@@ -20,6 +20,7 @@ class LHOPTConstants(TypedDict):
     ZERO_DIVISION_TOLERANCE: float
     DEFAULT_SAMPLE_FREQUENCY: int
     DEFAULT_VARIANCE_THRESHOLD: float
+    DEFAULT_ENV_STEP_SAMPLE_FREQUENCY: int


 # Create the constants instance
@@ -36,4 +37,5 @@ LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
     ZERO_DIVISION_TOLERANCE=1e-8,
     DEFAULT_SAMPLE_FREQUENCY=4,
     DEFAULT_VARIANCE_THRESHOLD=1e-6,
+    DEFAULT_ENV_STEP_SAMPLE_FREQUENCY=10,
 )
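A minimal sketch of the TypedDict pattern these two hunks extend, reduced to the fields this diff touches:

    from typing import TypedDict

    class LHOPTConstants(TypedDict):
        DEFAULT_SAMPLE_FREQUENCY: int
        DEFAULT_ENV_STEP_SAMPLE_FREQUENCY: int

    # TypedDict classes are plain dicts at runtime; calling the class only type-checks keys.
    LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
        DEFAULT_SAMPLE_FREQUENCY=4,
        DEFAULT_ENV_STEP_SAMPLE_FREQUENCY=10,
    )
    print(LHOPT_CONSTANTS["DEFAULT_ENV_STEP_SAMPLE_FREQUENCY"])  # 10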
@@ -4,6 +4,7 @@
 #
 # ======================================================================================================================

+import math
 from typing import Any

 from libinephany.observations import observation_utils, statistic_trackers
@@ -52,7 +53,7 @@ class GlobalFirstOrderGradients(GlobalObserver):
             needed.
         """

-        return {statistic_trackers.FirstOrderGradients.__name__: dict(skip_statistics=self.skip_statistics)}
+        return {statistic_trackers.FirstOrderGradients.__name__: dict(include_statistics=self.include_statistics)}


 class GlobalSecondOrderGradients(GlobalObserver):
@@ -109,7 +110,7 @@ class GlobalSecondOrderGradients(GlobalObserver):

         return {
             statistic_trackers.SecondOrderGradients.__name__: dict(
-                skip_statistics=self.skip_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
+                include_statistics=self.include_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
             )
         }

@@ -189,5 +190,321 @@ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
         """

         return {
-            statistic_trackers.GradientVarianceFraction.__name__: dict(variance_threshold=self.variance_threshold),
+            statistic_trackers.GradientVarianceFraction.__name__: dict(
+                variance_threshold=self.variance_threshold, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
+            ),
+        }
+
+
+class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
+    """
+    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
+    https://arxiv.org/abs/2305.18291.
+
+    It returns two-dimensional observations: [raw_value, cdf_feature] for momentum gradient ratio values.
+    """
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.VECTOR
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 2  # [raw_value, cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.MomentumGradientRatioStatistics.__name__]
+
+        raw_value = list(statistics.values())[0]  # type: ignore[list-item]
+
+        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
+        self._update_time()
+
+        return [raw_value, cdf_feature]  # type: ignore[list-item]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.MomentumGradientRatioStatistics.__name__: dict(
+                include_statistics=self.include_statistics,
+                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
+            ),
+        }
+
+
+class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
+    """
+    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
+    https://arxiv.org/abs/2305.18291.
+
+    It returns three-dimensional observations: [raw_value, cdf_feature, logit_of_cdf_feature] for cosine similarity of gradient and momentum values.
+    """
+
+    def __init__(
+        self,
+        *,
+        include_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param include_statistics: List of statistics to include.
+        :param kwargs: Miscellaneous keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.include_statistics = include_statistics
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.VECTOR
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[
+            statistic_trackers.CosineSimilarityObserverOfGradientAndMomentumStatistics.__name__
+        ]
+
+        raw_value = list(statistics.values())[0]  # type: ignore[list-item]
+
+        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
+        self._update_time()
+
+        # Handle edge cases for logit calculation
+        if cdf_feature <= 0.0 or cdf_feature >= 1.0:
+            logit_of_cdf_feature = 0.0
+        else:
+            logit_of_cdf_feature = math.log(cdf_feature / (1 - cdf_feature))
+
+        return [raw_value, cdf_feature, logit_of_cdf_feature]  # type: ignore[list-item]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.CosineSimilarityObserverOfGradientAndMomentumStatistics.__name__: dict(
+                include_statistics=self.include_statistics,
+                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
+            )
+        }
+
+
+class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
+    """
+    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
+    https://arxiv.org/abs/2305.18291.
+
+    It returns three-dimensional observations: [raw_value, cdf_feature, logit_of_cdf_feature] for cosine similarity of gradient and update values.
+    """
+
+    def __init__(
+        self,
+        *,
+        include_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param include_statistics: List of statistics to include.
+        :param kwargs: Miscellaneous keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.include_statistics = include_statistics
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.VECTOR
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[
+            statistic_trackers.CosineSimilarityObserverOfGradientAndUpdateStatistics.__name__
+        ]
+
+        raw_value = list(statistics.values())[0]  # type: ignore[list-item]
+
+        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
+        self._update_time()
+
+        # Handle edge cases for logit calculation
+        if cdf_feature <= 0.0 or cdf_feature >= 1.0:
+            logit_of_cdf_feature = 0.0
+        else:
+            logit_of_cdf_feature = math.log(cdf_feature / (1 - cdf_feature))
+
+        return [raw_value, cdf_feature, logit_of_cdf_feature]  # type: ignore[list-item]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.CosineSimilarityObserverOfGradientAndUpdateStatistics.__name__: dict(
+                include_statistics=self.include_statistics,
+                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
+            )
+        }
+
+
+class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
+    """
+    This is a global observer from the OpenAI paper "Learning to Optimize with Reinforcement Learning"
+    https://arxiv.org/abs/2305.18291.
+
+    It returns three-dimensional observations: [raw_value, cdf_feature, logit_of_cdf_feature] for cosine similarity of gradient and parameter values.
+    """
+
+    def __init__(
+        self,
+        *,
+        include_statistics: list[str] | None = None,
+        **kwargs,
+    ) -> None:
+        """
+        :param include_statistics: List of statistics to include.
+        :param kwargs: Miscellaneous keyword arguments.
+        """
+        super().__init__(**kwargs)
+
+        self.include_statistics = include_statistics
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.VECTOR
+
+    @property
+    def vector_length(self) -> int:
+        """
+        :return: Length of the vector returned by this observation if it returns a vector.
+        """
+        return 3  # [raw_value, cdf_feature, logit_of_cdf_feature]
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.CosineSimilarityOfGradientAndParameterStatistics.__name__]
+
+        raw_value = list(statistics.values())[0]  # type: ignore[list-item]
+
+        cdf_feature = self._compute_cdf_feature(raw_value)  # type: ignore[arg-type]
+        self._update_time()
+
+        # Handle edge cases for logit calculation
+        if cdf_feature <= 0.0 or cdf_feature >= 1.0:
+            logit_of_cdf_feature = 0.0
+        else:
+            logit_of_cdf_feature = math.log(cdf_feature / (1 - cdf_feature))
+
+        return [raw_value, cdf_feature, logit_of_cdf_feature]  # type: ignore[list-item]
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.CosineSimilarityOfGradientAndParameterStatistics.__name__: dict(
+                include_statistics=self.include_statistics,
+                sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
+            )
         }
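All three new cosine-similarity observers emit the same [raw_value, cdf_feature, logit_of_cdf_feature] encoding. A self-contained sketch of that transform; _compute_cdf_feature is library code not shown in this diff, so a standard normal CDF stands in for it:

    import math

    def cdf_feature(raw_value: float) -> float:
        # Stand-in CDF: maps the raw value into (0, 1) via the standard normal CDF.
        return 0.5 * (1.0 + math.erf(raw_value / math.sqrt(2.0)))

    raw_value = 0.3
    cdf = cdf_feature(raw_value)

    # Mirrors the observers' edge-case handling: the logit defaults to 0.0 at the boundaries.
    logit = 0.0 if cdf <= 0.0 or cdf >= 1.0 else math.log(cdf / (1.0 - cdf))
    print([raw_value, cdf, logit])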
@@ -21,18 +21,16 @@ from libinephany.utils.enums import ModelFamilies

 class InitialHyperparameters(GlobalObserver):

-    def __init__(self, skip_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """

         super().__init__(**kwargs)

-        force_skip = ["samples", "gradient_accumulation"]
-        skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
-        self.skip_hparams = [] if skip_hparams is None else skip_hparams
+        self.include_hparams = include_hparams
         self.pad_with = pad_with

     @property
@@ -41,9 +39,12 @@ class InitialHyperparameters(GlobalObserver):
         :return: Length of the vector returned by this observation if it returns a vector.
         """

+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_all_hyperparameters()

-        return len([hparam for hparam in available_hparams if hparam not in self.skip_hparams])
+        return len([hparam for hparam in available_hparams if hparam in self.include_hparams])

     @property
     def can_standardize(self) -> bool:
@@ -85,12 +86,14 @@ class InitialHyperparameters(GlobalObserver):
         :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
         """

-        initial_internal_values = hyperparameter_states.get_initial_internal_values(self.skip_hparams)
+        assert self.include_hparams is not None
+
+        initial_internal_values = hyperparameter_states.get_initial_internal_values(self.include_hparams)
         self._cached_observation = initial_internal_values
         initial_internal_values_list = [
             self.pad_with if initial_internal_value is None else initial_internal_value
             for hparam_name, initial_internal_value in initial_internal_values.items()
-            if hparam_name not in self.skip_hparams
+            if hparam_name in self.include_hparams
         ]
         return initial_internal_values_list

@@ -179,7 +182,8 @@ class ModelFamilyOneHot(GlobalObserver):
         **kwargs,
     ) -> None:
         """
-        :param skip_observations: List of episode boundary observations to ignore.
+        :param zero_vector_chance: Chance of the output vector being masked with zeros.
+        :param zero_vector_frequency_unit: Unit of time to sample the zero vector.
         :param kwargs: Miscellaneous keyword arguments.
         """
         super().__init__(**kwargs)
@@ -294,17 +298,16 @@ class LHOPTHyperparameterRatio(GlobalObserver):
     providing insights into how much hyperparameters have changed from their starting values.
     """

-    def __init__(self, skip_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
+    def __init__(self, include_hparams: list[str] | None = None, pad_with: float = 0.0, **kwargs) -> None:
         """
-        :param skip_hparams: Names of the hyperparameters to not include in the initial values vector returned by
+        :param include_hparams: Names of the hyperparameters to include in the initial values vector returned by
             this observation.
         :param kwargs: Miscellaneous keyword arguments.
         """

         super().__init__(**kwargs)

-        force_skip = ["samples", "gradient_accumulation"]
-        self.skip_hparams = force_skip if skip_hparams is None else skip_hparams + force_skip
+        self.include_hparams = include_hparams
         self.pad_with = pad_with

     @property
@@ -313,9 +316,12 @@ class LHOPTHyperparameterRatio(GlobalObserver):
         :return: Length of the vector returned by this observation if it returns a vector.
         """

+        if self.include_hparams is None:
+            raise ValueError(f"{self.__class__.__name__} must be provided with include_hparams.")
+
         available_hparams = HyperparameterStates.get_all_hyperparameters()

-        return len([hparam for hparam in available_hparams if hparam not in self.skip_hparams])
+        return len([hparam for hparam in available_hparams if hparam in self.include_hparams])

     @property
     def can_standardize(self) -> bool:
@@ -357,18 +363,20 @@ class LHOPTHyperparameterRatio(GlobalObserver):
         :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
         """

+        assert self.include_hparams is not None
+
         # Get initial and current hyperparameter values
-        initial_values = hyperparameter_states.get_initial_internal_values(self.skip_hparams)
+        initial_values = hyperparameter_states.get_initial_internal_values(self.include_hparams)
         initial_values = {
             hparam_name: self.pad_with if initial_value is None else initial_value
             for hparam_name, initial_value in initial_values.items()
-            if hparam_name not in self.skip_hparams
+            if hparam_name in self.include_hparams
         }
-        current_values = hyperparameter_states.get_current_internal_values(self.skip_hparams)
+        current_values = hyperparameter_states.get_current_internal_values(self.include_hparams)
         current_values = {
             hparam_name: self.pad_with if current_value is None else current_value
             for hparam_name, current_value in current_values.items()
-            if hparam_name not in self.skip_hparams
+            if hparam_name in self.include_hparams
         }

         ratios = []