libinephany 0.18.0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,6 +25,7 @@ from libinephany.utils import optim_utils
25
25
  # ======================================================================================================================
26
26
 
27
27
  EXP_AVERAGE = "exp_avg"
28
+ MOMENTUM_BUFFER = "momentum_buffer"
28
29
  MIN_DECAY_FACTOR = 1e-10
29
30
 
30
31
  MIN_TOTAL_WEIGHT = 1e-15 # Minimum total weight threshold for numerical stability
@@ -64,10 +65,8 @@ def get_exponential_weighted_average(values: list[int | float]) -> float:
64
65
  :param values: List of values to average via EWA.
65
66
  :return: EWA of the given values.
66
67
  """
67
-
68
68
  exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
69
69
  assert isinstance(exp_weighted_average, float)
70
-
71
70
  return exp_weighted_average
72
71
 
73
72
 
@@ -232,6 +231,24 @@ def form_update_tensor(
232
231
  raise NotImplementedError(f"Optimizer {type(optimizer).__name__} is not supported!")
233
232
 
234
233
 
234
def form_momentum_tensor(
    optimizer: optim.Optimizer, parameters: list[torch.Tensor], parameter_group: dict[str, Any]
) -> None | torch.Tensor:
    """
    Flatten the optimizer's per-parameter momentum state and concatenate it into a single 1-D tensor.

    :param optimizer: Optimizer to form the momentum tensor from.
    :param parameters: Parameters to create the momentum tensor from.
    :param parameter_group: Parameter group within the optimizer the given parameters came from. Not read by this
    function.
    :return: Concatenation of the flattened momentum state of every given parameter for which
    tensor_on_local_rank returns a truthy value, or None when no parameter qualifies.
    :raises NotImplementedError: If the optimizer type is in neither optim_utils.ADAM_OPTIMISERS nor
    optim_utils.SGD_OPTIMISERS.
    """

    optimizer_type = type(optimizer)

    # The two supported optimizer families keep their momentum under different state keys.
    if optimizer_type in optim_utils.ADAM_OPTIMISERS:
        state_key = EXP_AVERAGE
    elif optimizer_type in optim_utils.SGD_OPTIMISERS:
        state_key = MOMENTUM_BUFFER
    else:
        raise NotImplementedError(f"Optimizer {type(optimizer).__name__} is not supported!")

    flattened_momenta = []
    for parameter in parameters:
        if tensor_on_local_rank(parameter):
            flattened_momenta.append(optimizer.state[parameter][state_key].view(-1))

    if not flattened_momenta:
        return None

    return torch.cat(flattened_momenta)
250
+
251
+
235
252
  def null_standardizer(value_to_standardize: float, **kwargs) -> float:
236
253
  """
237
254
  :param value_to_standardize: Value to mock the standardization of.
@@ -8,7 +8,15 @@
8
8
  # ======================================================================================================================
9
9
 
10
10
 
11
- from .gradient_observers import GlobalFirstOrderGradients, GlobalSecondOrderGradients, LHOPTGradientVarianceFraction
11
+ from .gradient_observers import (
12
+ CosineSimilarityObserverOfGradientAndMomentum,
13
+ CosineSimilarityObserverOfGradientAndUpdate,
14
+ CosineSimilarityOfGradientAndParameter,
15
+ GlobalFirstOrderGradients,
16
+ GlobalSecondOrderGradients,
17
+ LHOPTGradientVarianceFraction,
18
+ LHOPTMomentumGradientRatio,
19
+ )
12
20
  from .hyperparameter_observers import (
13
21
  InitialHyperparameters,
14
22
  LHOPTHyperparameterRatio,
@@ -31,8 +39,11 @@ from .model_observers import (
31
39
  GlobalLAMBTrustRatio,
32
40
  GlobalParameters,
33
41
  GlobalParameterUpdates,
42
+ LHOPTAverageParameterUpdateMagnitudeObserver,
43
+ LHOPTGlobalLAMBTrustRatio,
34
44
  LogRatioOfPreviousAndCurrentParamNormEnvStepObserver,
35
45
  LogRatioOfUpdateAndPreviousParamNormEnvStepObserver,
46
+ LogRatioOfUpdateAndPreviousParamNormInnerStepObserver,
36
47
  NumberOfLayers,
37
48
  NumberOfParameters,
38
49
  )
@@ -51,14 +62,17 @@ __all__ = [
51
62
  GlobalFirstOrderGradients.__name__,
52
63
  GlobalSecondOrderGradients.__name__,
53
64
  LHOPTGradientVarianceFraction.__name__,
65
+ LHOPTMomentumGradientRatio.__name__,
54
66
  GlobalActivations.__name__,
55
67
  GlobalParameterUpdates.__name__,
56
68
  GlobalParameters.__name__,
57
69
  GlobalLAMBTrustRatio.__name__,
58
70
  NumberOfParameters.__name__,
59
71
  NumberOfLayers.__name__,
72
+ LHOPTAverageParameterUpdateMagnitudeObserver.__name__,
60
73
  LogRatioOfPreviousAndCurrentParamNormEnvStepObserver.__name__,
61
74
  LogRatioOfUpdateAndPreviousParamNormEnvStepObserver.__name__,
75
+ LogRatioOfUpdateAndPreviousParamNormInnerStepObserver.__name__,
62
76
  TrainingProgress.__name__,
63
77
  EpochsCompleted.__name__,
64
78
  ProgressAtEachCheckpoint.__name__,
@@ -66,4 +80,8 @@ __all__ = [
66
80
  LHOPTValidationLoss.__name__,
67
81
  LHOPTLossRatio.__name__,
68
82
  PercentileOfLossAtEachCheckpoint.__name__,
83
+ LHOPTGlobalLAMBTrustRatio.__name__,
84
+ CosineSimilarityObserverOfGradientAndMomentum.__name__,
85
+ CosineSimilarityObserverOfGradientAndUpdate.__name__,
86
+ CosineSimilarityOfGradientAndParameter.__name__,
69
87
  ]
@@ -20,6 +20,7 @@ class LHOPTConstants(TypedDict):
20
20
  ZERO_DIVISION_TOLERANCE: float
21
21
  DEFAULT_SAMPLE_FREQUENCY: int
22
22
  DEFAULT_VARIANCE_THRESHOLD: float
23
+ DEFAULT_ENV_STEP_SAMPLE_FREQUENCY: int
23
24
 
24
25
 
25
26
  # Create the constants instance
@@ -36,4 +37,5 @@ LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
36
37
  ZERO_DIVISION_TOLERANCE=1e-8,
37
38
  DEFAULT_SAMPLE_FREQUENCY=4,
38
39
  DEFAULT_VARIANCE_THRESHOLD=1e-6,
40
+ DEFAULT_ENV_STEP_SAMPLE_FREQUENCY=10,
39
41
  )
@@ -4,6 +4,7 @@
4
4
  #
5
5
  # ======================================================================================================================
6
6
 
7
+ import math
7
8
  from typing import Any
8
9
 
9
10
  from libinephany.observations import observation_utils, statistic_trackers
@@ -189,5 +190,322 @@ class LHOPTGradientVarianceFraction(LHOPTBaseObserver):
189
190
  """
190
191
 
191
192
  return {
192
- statistic_trackers.GradientVarianceFraction.__name__: dict(variance_threshold=self.variance_threshold),
193
+ statistic_trackers.GradientVarianceFraction.__name__: dict(
194
+ variance_threshold=self.variance_threshold, sample_frequency=LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"]
195
+ ),
196
+ }
197
+
198
+
199
class LHOPTMomentumGradientRatio(LHOPTBaseObserver):
    """
    Global observer reporting the momentum gradient ratio as a two-element vector: [raw_value, cdf_feature].

    The original documentation attributes this observer to the OpenAI paper "Learning to Optimize with
    Reinforcement Learning" (https://arxiv.org/abs/2305.18291).
    """

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.VECTOR, since this observer emits a vector.
        """

        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in the observation vector: the raw value and its CDF feature.
        """

        return 2

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers. Not read here.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters. Not read here.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to. Not read here.
        :return: Two-element list holding the first tracked momentum gradient ratio and its CDF feature.
        """

        tracker_name = statistic_trackers.MomentumGradientRatioStatistics.__name__
        tracked_values = list(tracked_statistics[tracker_name].values())

        # Only the first tracked entry is used.
        ratio = tracked_values[0]  # type: ignore[list-item]

        ratio_cdf = self._compute_cdf_feature(ratio)  # type: ignore[arg-type]
        self._update_time()

        return [ratio, ratio_cdf]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        tracker_kwargs = {
            "skip_statistics": self.skip_statistics,
            "sample_frequency": LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
        }

        return {statistic_trackers.MomentumGradientRatioStatistics.__name__: tracker_kwargs}
259
+
260
+
261
class CosineSimilarityObserverOfGradientAndMomentum(LHOPTBaseObserver):
    """
    Global observer reporting the cosine similarity of gradient and momentum as a three-element vector:
    [raw_value, cdf_feature, logit_of_cdf_feature].

    The original documentation attributes this observer to the OpenAI paper "Learning to Optimize with
    Reinforcement Learning" (https://arxiv.org/abs/2305.18291).
    """

    def __init__(
        self,
        *,
        skip_statistics: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        :param skip_statistics: Value forwarded as the skip_statistics kwarg of the required statistic tracker.
        :param kwargs: Miscellaneous keyword arguments passed on to the parent constructor.
        """

        super().__init__(**kwargs)

        self.skip_statistics = skip_statistics

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.VECTOR, since this observer emits a vector.
        """

        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in the observation vector: raw value, CDF feature and logit of the CDF feature.
        """

        return 3

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers. Not read here.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters. Not read here.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to. Not read here.
        :return: Three-element list holding the first tracked cosine similarity, its CDF feature and the logit of
        that CDF feature.
        """

        tracker_name = statistic_trackers.CosineSimilarityObserverOfGradientAndMomentumStatistics.__name__
        tracked_values = list(tracked_statistics[tracker_name].values())

        # Only the first tracked entry is used.
        similarity = tracked_values[0]  # type: ignore[list-item]

        similarity_cdf = self._compute_cdf_feature(similarity)  # type: ignore[arg-type]
        self._update_time()

        # The logit is undefined at 0 and 1, so those boundary values are mapped to 0.0 instead.
        cdf_logit = (
            0.0
            if similarity_cdf <= 0.0 or similarity_cdf >= 1.0
            else math.log(similarity_cdf / (1 - similarity_cdf))
        )

        return [similarity, similarity_cdf, cdf_logit]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        tracker_kwargs = {
            "skip_statistics": self.skip_statistics,
            "sample_frequency": LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
        }

        return {
            statistic_trackers.CosineSimilarityObserverOfGradientAndMomentumStatistics.__name__: tracker_kwargs,
        }
345
+
346
+
347
class CosineSimilarityObserverOfGradientAndUpdate(LHOPTBaseObserver):
    """
    Global observer reporting the cosine similarity of gradient and update as a three-element vector:
    [raw_value, cdf_feature, logit_of_cdf_feature].

    The original documentation attributes this observer to the OpenAI paper "Learning to Optimize with
    Reinforcement Learning" (https://arxiv.org/abs/2305.18291).
    """

    def __init__(
        self,
        *,
        skip_statistics: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        :param skip_statistics: List of statistics to skip; forwarded to the required statistic tracker.
        :param kwargs: Miscellaneous keyword arguments passed on to the parent constructor.
        """

        super().__init__(**kwargs)

        self.skip_statistics = skip_statistics

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.VECTOR, since this observer emits a vector.
        """

        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in the observation vector: raw value, CDF feature and logit of the CDF feature.
        """

        return 3

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers. Not read here.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters. Not read here.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to. Not read here.
        :return: Three-element list holding the first tracked cosine similarity, its CDF feature and the logit of
        that CDF feature.
        """

        tracker_name = statistic_trackers.CosineSimilarityObserverOfGradientAndUpdateStatistics.__name__
        tracked_values = list(tracked_statistics[tracker_name].values())

        # Only the first tracked entry is used.
        similarity = tracked_values[0]  # type: ignore[list-item]

        similarity_cdf = self._compute_cdf_feature(similarity)  # type: ignore[arg-type]
        self._update_time()

        # The logit is undefined at 0 and 1, so those boundary values are mapped to 0.0 instead.
        cdf_logit = (
            0.0
            if similarity_cdf <= 0.0 or similarity_cdf >= 1.0
            else math.log(similarity_cdf / (1 - similarity_cdf))
        )

        return [similarity, similarity_cdf, cdf_logit]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        tracker_kwargs = {
            "skip_statistics": self.skip_statistics,
            "sample_frequency": LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
        }

        return {
            statistic_trackers.CosineSimilarityObserverOfGradientAndUpdateStatistics.__name__: tracker_kwargs,
        }
430
+
431
+
432
class CosineSimilarityOfGradientAndParameter(LHOPTBaseObserver):
    """
    Global observer reporting the cosine similarity of gradient and parameter as a three-element vector:
    [raw_value, cdf_feature, logit_of_cdf_feature].

    The original documentation attributes this observer to the OpenAI paper "Learning to Optimize with
    Reinforcement Learning" (https://arxiv.org/abs/2305.18291).
    """

    def __init__(
        self,
        *,
        skip_statistics: list[str] | None = None,
        **kwargs,
    ) -> None:
        """
        :param skip_statistics: List of statistics to skip; forwarded to the required statistic tracker.
        :param kwargs: Miscellaneous keyword arguments passed on to the parent constructor.
        """

        super().__init__(**kwargs)

        self.skip_statistics = skip_statistics

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.VECTOR, since this observer emits a vector.
        """

        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in the observation vector: raw value, CDF feature and logit of the CDF feature.
        """

        return 3

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers. Not read here.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters. Not read here.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to. Not read here.
        :return: Three-element list holding the first tracked cosine similarity, its CDF feature and the logit of
        that CDF feature.
        """

        tracker_name = statistic_trackers.CosineSimilarityOfGradientAndParameterStatistics.__name__
        tracked_values = list(tracked_statistics[tracker_name].values())

        # Only the first tracked entry is used.
        similarity = tracked_values[0]  # type: ignore[list-item]

        similarity_cdf = self._compute_cdf_feature(similarity)  # type: ignore[arg-type]
        self._update_time()

        # The logit is undefined at 0 and 1, so those boundary values are mapped to 0.0 instead.
        cdf_logit = (
            0.0
            if similarity_cdf <= 0.0 or similarity_cdf >= 1.0
            else math.log(similarity_cdf / (1 - similarity_cdf))
        )

        return [similarity, similarity_cdf, cdf_logit]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        tracker_kwargs = {
            "skip_statistics": self.skip_statistics,
            "sample_frequency": LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
        }

        return {
            statistic_trackers.CosineSimilarityOfGradientAndParameterStatistics.__name__: tracker_kwargs,
        }
@@ -385,7 +385,9 @@ class LogRatioOfPreviousAndCurrentParamNormEnvStepObserver(LHOPTBaseObserver):
385
385
  """
386
386
 
387
387
  return {
388
- statistic_trackers.ParameterStatistics.__name__: dict(skip_statistics=self.skip_statistics),
388
+ statistic_trackers.ParameterStatistics.__name__: dict(
389
+ skip_statistics=self.skip_statistics,
390
+ ),
389
391
  }
390
392
 
391
393
  def reset(self) -> None:
@@ -456,8 +458,12 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
456
458
  """
457
459
 
458
460
  return {
459
- statistic_trackers.ParameterUpdateStatistics.__name__: dict(skip_statistics=self.skip_statistics),
460
- statistic_trackers.ParameterStatistics.__name__: dict(skip_statistics=self.skip_statistics),
461
+ statistic_trackers.ParameterUpdateStatistics.__name__: dict(
462
+ skip_statistics=self.skip_statistics,
463
+ ),
464
+ statistic_trackers.ParameterStatistics.__name__: dict(
465
+ skip_statistics=self.skip_statistics,
466
+ ),
461
467
  }
462
468
 
463
469
  def reset(self) -> None:
@@ -467,3 +473,213 @@ class LogRatioOfUpdateAndPreviousParamNormEnvStepObserver(LHOPTBaseObserver):
467
473
 
468
474
  super().reset()
469
475
  self._previous_param_norm = None
476
+
477
+
478
class LHOPTAverageParameterUpdateMagnitudeObserver(LHOPTBaseObserver):
    """
    Observer reporting the tracked average parameter update magnitude as a two-element vector:
    [raw_feature, cdf_feature]. The observation is flagged as not standardizable.
    """

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.VECTOR, since this observer emits a vector.
        """

        return StatisticStorageTypes.VECTOR

    @property
    def can_standardize(self) -> bool:
        """
        :return: Always False; this observation is not to be standardized.
        """

        return False

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in the observation vector: the raw feature and its CDF feature.
        """

        return 2

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers. Not read here.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters. Not read here.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to. Not read here.
        :return: Two-element list holding the first tracked update magnitude and its CDF feature.
        """

        tracker_name = statistic_trackers.AverageParameterUpdateMagnitudeStatistics.__name__
        tracked_values = list(tracked_statistics[tracker_name].values())

        # Only the first tracked entry is used.
        magnitude = tracked_values[0]  # type: ignore[list-item]

        magnitude_cdf = self._compute_cdf_feature(magnitude)  # type: ignore[arg-type]
        self._update_time()

        return [magnitude, magnitude_cdf]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        tracker_kwargs = {
            "skip_statistics": self.skip_statistics,
            "sample_frequency": LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
        }

        return {statistic_trackers.AverageParameterUpdateMagnitudeStatistics.__name__: tracker_kwargs}
539
+
540
+
541
class LogRatioOfUpdateAndPreviousParamNormInnerStepObserver(LHOPTBaseObserver):
    def __init__(self, **kwargs):
        """
        Observer computing the log ratio of the update norm and the previous parameter norm for the inner step.
        Its trackers are requested with LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"] as their sample frequency; the
        original documentation states this needs to be 4, according to the OpenAI paper.

        :param kwargs: Keyword arguments passed on to the parent constructor.
        """

        super().__init__(**kwargs)

        # Parameter norm seen on the previous observation; None until the first observation has been made.
        self._previous_param_norm = None

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in the observation vector: the tanh feature and the CDF feature.
        """

        return 2

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers. Not read here.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters. Not read here.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistics models.
        :param action_taken: Action taken by the agent this class instance is assigned to. Not read here.
        :return: List containing [tanh_feature, cdf_feature], where tanh_feature is the tanh of the log ratio after
        clamping it to +/- LHOPT_CONSTANTS["TANH_BOUND"]. On the first observation after construction or reset the
        list is [0.0, 0.0].
        """

        def mean_norm(statistics: dict[str, float | TensorStatistics]):
            # Non-TensorStatistics entries are dropped before averaging.
            tensor_only = [entry for entry in statistics.values() if isinstance(entry, TensorStatistics)]
            return observation_utils.average_tensor_statistics(tensor_statistics=tensor_only).norm_

        update_statistics = tracked_statistics[statistic_trackers.InnerStepParameterUpdateStatistics.__name__]
        param_statistics = tracked_statistics[statistic_trackers.InnerStepParameterStatistics.__name__]

        update_norm = mean_norm(update_statistics)
        current_param_norm = mean_norm(param_statistics)

        if self._previous_param_norm is None:
            # No previous norm to compare against yet: store the current one and feed a default value of 0.0 to
            # the CDF computation, discarding its result.
            self._previous_param_norm = current_param_norm
            self._compute_cdf_feature(0.0)
            self._update_time()

            return [0.0, 0.0]

        log_ratio = self._compute_log_ratio(update_norm, self._previous_param_norm)

        # Clamp the log ratio before squashing it with tanh; the CDF feature uses the unclamped value.
        bounded_log_ratio = max(-LHOPT_CONSTANTS["TANH_BOUND"], min(LHOPT_CONSTANTS["TANH_BOUND"], log_ratio))
        tanh_feature = math.tanh(bounded_log_ratio)

        cdf_feature = self._compute_cdf_feature(log_ratio)
        self._update_time()
        self._previous_param_norm = current_param_norm

        return [tanh_feature, cdf_feature]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        required_trackers = (
            statistic_trackers.InnerStepParameterUpdateStatistics,
            statistic_trackers.InnerStepParameterStatistics,
        )

        # Each tracker gets its own kwargs dictionary holding the same two settings.
        return {
            tracker.__name__: {
                "skip_statistics": self.skip_statistics,
                "sample_frequency": LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
            }
            for tracker in required_trackers
        }

    def reset(self) -> None:
        """
        Reset the observer via the parent reset and forget the previous parameter norm.
        """

        super().reset()
        self._previous_param_norm = None
618
+
619
+
620
class LHOPTGlobalLAMBTrustRatio(LHOPTBaseObserver):
    """
    Observer reporting the mean of the tracked LAMB trust ratios as a two-element vector: [raw_value, cdf_feature].
    """

    def __init__(
        self,
        *,
        use_log_transform: bool = False,
        **kwargs,
    ) -> None:
        """
        :param use_log_transform: Whether to transform the LAMB trust ratio by taking ln(1 + R). Forwarded to the
        required statistic tracker.
        :param kwargs: Other observation keyword arguments passed on to the parent constructor.
        """

        super().__init__(**kwargs)

        self.use_log_transform = use_log_transform

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: StatisticStorageTypes.VECTOR, since this observer emits a vector.
        """

        return StatisticStorageTypes.VECTOR

    @property
    def vector_length(self) -> int:
        """
        :return: Number of elements in the observation vector: the raw value and its CDF feature.
        """

        return 2

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers. Not read here.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters. Not read here.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to. Not read here.
        :return: Two-element list holding the mean tracked trust ratio and its CDF feature.
        """

        trust_ratios = tracked_statistics[statistic_trackers.LHOPTLAMBTrustRatioStatistics.__name__]

        # Arithmetic mean over every tracked entry; an empty mapping raises ZeroDivisionError.
        total_ratio = sum(trust_ratios.values())  # type: ignore[arg-type]
        mean_ratio = total_ratio / len(trust_ratios)

        mean_ratio_cdf = self._compute_cdf_feature(mean_ratio)  # type: ignore[arg-type]
        self._update_time()

        return [mean_ratio, mean_ratio_cdf]  # type: ignore[list-item]

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
        needed.
        """

        tracker_kwargs = {
            "use_log_transform": self.use_log_transform,
            "sample_frequency": LHOPT_CONSTANTS["DEFAULT_SAMPLE_FREQUENCY"],
        }

        return {statistic_trackers.LHOPTLAMBTrustRatioStatistics.__name__: tracker_kwargs}