libinephany 0.16.2__tar.gz → 0.16.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. libinephany-0.16.4/CODE_VERSION.cfg +1 -0
  2. {libinephany-0.16.2/libinephany.egg-info → libinephany-0.16.4}/PKG-INFO +2 -1
  3. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observation_utils.py +167 -13
  4. libinephany-0.16.4/libinephany/observations/observers/global_observers/__init__.py +58 -0
  5. libinephany-0.16.4/libinephany/observations/observers/global_observers/base_classes.py +183 -0
  6. libinephany-0.16.4/libinephany/observations/observers/global_observers/constants.py +33 -0
  7. libinephany-0.16.4/libinephany/observations/observers/global_observers/gradient_observers.py +112 -0
  8. libinephany-0.16.4/libinephany/observations/observers/global_observers/hyperparameter_observers.py +286 -0
  9. libinephany-0.16.4/libinephany/observations/observers/global_observers/loss_observers.py +464 -0
  10. libinephany-0.16.4/libinephany/observations/observers/global_observers/model_observers.py +327 -0
  11. libinephany-0.16.4/libinephany/observations/observers/global_observers/progress_observers.py +142 -0
  12. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observers/observer_containers.py +3 -1
  13. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/constants.py +1 -0
  14. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/enums.py +18 -0
  15. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/samplers.py +36 -1
  16. {libinephany-0.16.2 → libinephany-0.16.4/libinephany.egg-info}/PKG-INFO +2 -1
  17. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany.egg-info/SOURCES.txt +8 -1
  18. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany.egg-info/requires.txt +1 -0
  19. {libinephany-0.16.2 → libinephany-0.16.4}/pyproject.toml +1 -0
  20. libinephany-0.16.2/CODE_VERSION.cfg +0 -1
  21. libinephany-0.16.2/libinephany/observations/observers/global_observers.py +0 -991
  22. {libinephany-0.16.2 → libinephany-0.16.4}/LICENSE +0 -0
  23. {libinephany-0.16.2 → libinephany-0.16.4}/MANIFEST.in +0 -0
  24. {libinephany-0.16.2 → libinephany-0.16.4}/README.md +0 -0
  25. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/__init__.py +0 -0
  26. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/aws/__init__.py +0 -0
  27. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/aws/s3_functions.py +0 -0
  28. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/__init__.py +0 -0
  29. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observer_pipeline.py +0 -0
  30. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observers/__init__.py +0 -0
  31. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observers/base_observers.py +0 -0
  32. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observers/local_observers.py +0 -0
  33. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/pipeline_coordinator.py +0 -0
  34. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/post_processors/__init__.py +0 -0
  35. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/post_processors/postprocessors.py +0 -0
  36. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/statistic_manager.py +0 -0
  37. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/statistic_trackers.py +0 -0
  38. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/__init__.py +0 -0
  39. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/configs/__init__.py +0 -0
  40. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/configs/hyperparameter_configs.py +0 -0
  41. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/configs/observer_config.py +0 -0
  42. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/configs/outer_model_config.py +0 -0
  43. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/__init__.py +0 -0
  44. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/agent_info.py +0 -0
  45. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/inner_task_profile.py +0 -0
  46. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/observation_models.py +0 -0
  47. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/request_schemas.py +0 -0
  48. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/response_schemas.py +0 -0
  49. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/tensor_statistics.py +0 -0
  50. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/states/__init__.py +0 -0
  51. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/states/hyperparameter_states.py +0 -0
  52. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/__init__.py +0 -0
  53. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/agent_utils.py +0 -0
  54. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/asyncio_worker.py +0 -0
  55. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/backend_statuses.py +0 -0
  56. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/directory_utils.py +0 -0
  57. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/dropout_utils.py +0 -0
  58. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/error_severities.py +0 -0
  59. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/exceptions.py +0 -0
  60. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/import_utils.py +0 -0
  61. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/optim_utils.py +0 -0
  62. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/random_seeds.py +0 -0
  63. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/standardizers.py +0 -0
  64. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/torch_distributed_utils.py +0 -0
  65. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/torch_utils.py +0 -0
  66. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/transforms.py +0 -0
  67. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/typing.py +0 -0
  68. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/web_apps/__init__.py +0 -0
  69. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/web_apps/error_logger.py +0 -0
  70. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/web_apps/web_app_utils.py +0 -0
  71. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany.egg-info/dependency_links.txt +0 -0
  72. {libinephany-0.16.2 → libinephany-0.16.4}/libinephany.egg-info/top_level.txt +0 -0
  73. {libinephany-0.16.2 → libinephany-0.16.4}/setup.cfg +0 -0
@@ -0,0 +1 @@
1
+ 0.16.4
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 0.16.2
3
+ Version: 0.16.4
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -18,6 +18,7 @@ Requires-Dist: pydantic<3.0.0,>=2.5.0
18
18
  Requires-Dist: loguru<0.8.0,>=0.7.0
19
19
  Requires-Dist: requests<3.0.0,>=2.28.0
20
20
  Requires-Dist: numpy<2.0.0,>=1.24.0
21
+ Requires-Dist: scipy<2.0.0,>=1.10.0
21
22
  Requires-Dist: slack-sdk<4.0.0,>=3.20.0
22
23
  Requires-Dist: boto3<2.0.0,>=1.26.0
23
24
  Requires-Dist: fastapi<0.116.0,>=0.100.0
@@ -13,6 +13,7 @@ import numpy as np
13
13
  import pandas as pd
14
14
  import torch
15
15
  import torch.optim as optim
16
+ from scipy.stats import norm
16
17
 
17
18
  from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
18
19
  from libinephany.utils import optim_utils
@@ -24,6 +25,9 @@ from libinephany.utils import optim_utils
24
25
  # ======================================================================================================================
25
26
 
26
27
# String key "exp_avg" (presumably the Adam-style optimizer-state entry; confirm against its users).
EXP_AVERAGE = "exp_avg"

# Tolerance used by compute_cdf_weighted_mean_and_std to treat a decay factor as exactly 1.0 (uniform weights).
# Despite the name it is not a lower bound on the decay factor.
MIN_DECAY_FACTOR = 1.0e-10

# Integrated weights below this value make compute_cdf_weighted_mean_and_std fall back to the plain (unweighted)
# mean and standard deviation, for numerical stability.
MIN_TOTAL_WEIGHT = 1.0e-15
27
31
 
28
32
  # ======================================================================================================================
29
33
  #
@@ -60,20 +64,8 @@ def get_exponential_weighted_average(values: list[int | float]) -> float:
60
64
  :param values: List of values to average via EWA.
61
65
  :return: EWA of the given values.
62
66
  """
63
-
64
- # Check for NaN and infinite values in input
65
- valid_values = [float(val) for val in values if not math.isnan(float(val)) and not math.isinf(float(val))]
66
-
67
- if not valid_values:
68
- raise ValueError("Cannot compute exponential weighted average on empty list")
69
-
70
- if len(valid_values) == 1:
71
- return valid_values[0]
72
-
73
- exp_weighted_average = pd.Series(valid_values).ewm(alpha=0.1).mean().iloc[-1]
67
+ exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
74
68
  assert isinstance(exp_weighted_average, float)
75
- assert not math.isnan(exp_weighted_average)
76
- assert not math.isinf(exp_weighted_average)
77
69
  return exp_weighted_average
78
70
 
79
71
 
@@ -280,3 +272,165 @@ def concatenate_lists(lists: list[list[Any]]) -> list[Any]:
280
272
  """
281
273
 
282
274
  return list(chain(*lists))
275
+
276
+
277
def compute_cdf_weighted_mean_and_std(
    time_series: list[tuple[float, float]], decay_factor: float
) -> tuple[float, float]:
    """
    Compute the exponentially time-weighted mean and standard deviation of a piecewise-linear time series.

    The series is treated as the linear interpolation between consecutive (time, value) points and weighted by
    w(t) = decay_factor ** t. The integrals ∫ w(t) y(t) dt, ∫ w(t) y(t)² dt and ∫ w(t) dt over the covered time span
    are evaluated in closed form, interval by interval. Squared values are interpolated linearly between the squared
    endpoints. The results are:

        mean = ∫ w y dt / ∫ w dt
        var  = ∫ w y² dt / ∫ w dt - mean²

    Before exponentiating, all times are shifted by a reference time so that every exponent is <= 0. The shift
    multiplies every weight by the same constant, which cancels in both ratios. It prevents math.exp from
    overflowing when absolute times are large, e.g. an ever-increasing step counter combined with a large decay
    factor.

    :param time_series: List of (time, value) pairs. They are sorted by time internally; times must be distinct.
    :param decay_factor: Decay factor b in the exponential weight formula, b in [1.25, 2.5, 5, 10, 20]. Must be
        positive. A value within MIN_DECAY_FACTOR of 1.0 means uniform weights, giving the plain mean and std.
    :return: Tuple of (weighted mean, weighted standard deviation).
        - (0.0, 0.0) for an empty series.
        - (value, 0.0) for a single point.
    :raises ValueError: If decay_factor is not positive and the series has at least two points.
    :raises AssertionError: If two points share the same time.
    """

    if len(time_series) == 0:
        return 0.0, 0.0

    if len(time_series) == 1:
        return time_series[0][1], 0.0

    sorted_series = sorted(time_series, key=lambda point: point[0])
    values = [value for _, value in sorted_series]

    if decay_factor <= 0:
        # math.log would otherwise fail with an opaque "math domain error".
        raise ValueError(f"decay_factor must be positive, got {decay_factor}.")

    # decay_factor == 1.0 means w(t) = 1 for all t, so the result is the arithmetic mean / std. The closed-form
    # integrals below divide by log(decay_factor) and cannot be used there.
    if abs(decay_factor - 1.0) < MIN_DECAY_FACTOR:
        return float(np.mean(values)), float(np.std(values))

    log_decay_factor = math.log(decay_factor)

    # Shift times so that log_decay_factor * t <= 0 everywhere:
    # - for b > 1, measure time back from the latest point;
    # - for b < 1, measure it forward from the earliest point.
    # The constant factor this introduces cancels in the ratios below.
    reference_time = sorted_series[-1][0] if log_decay_factor > 0 else sorted_series[0][0]
    shifted_series = [(time - reference_time, value) for time, value in sorted_series]

    total_weighted_value = 0.0  # ∫ w(t) y(t) dt
    total_weighted_squared = 0.0  # ∫ w(t) y(t)² dt

    for (start_time_point, start_value), (end_time_point, end_value) in zip(shifted_series, shifted_series[1:]):
        assert end_time_point - start_time_point > 0, "Time interval must be positive"

        total_weighted_value += _weighted_interval_expectation(
            start_time_point=start_time_point,
            start_value=start_value,
            end_time_point=end_time_point,
            end_value=end_value,
            log_decay_factor=log_decay_factor,
        )
        total_weighted_squared += _weighted_interval_expectation(
            start_time_point=start_time_point,
            start_value=start_value**2,
            end_time_point=end_time_point,
            end_value=end_value**2,
            log_decay_factor=log_decay_factor,
        )

    # ∫ w(t) dt over the whole (shifted) span, in closed form.
    total_weight = (1 / log_decay_factor) * (
        math.exp(log_decay_factor * shifted_series[-1][0]) - math.exp(log_decay_factor * shifted_series[0][0])
    )

    # Guard against dividing by a vanishing total weight: fall back to the unweighted statistics.
    if total_weight < MIN_TOTAL_WEIGHT:
        return float(np.mean(values)), float(np.std(values))

    weighted_mean = float(total_weighted_value / total_weight)

    # Var(X) = E[X²] - E[X]². Floating-point cancellation can make this slightly negative, hence the clamp.
    weighted_variance = float(total_weighted_squared / total_weight - weighted_mean**2)
    weighted_std = float(math.sqrt(max(0, weighted_variance)))

    return weighted_mean, weighted_std
363
+
364
+
365
def _weighted_interval_expectation(
    start_time_point: float,
    start_value: float,
    end_time_point: float,
    end_value: float,
    log_decay_factor: float,
) -> float:
    """
    Computes the weighted interval expectation from Appendix E of the LHOPT paper.

    With l = log_decay_factor and y(t) the straight line through (start_time_point, start_value) and
    (end_time_point, end_value), this returns the closed form of ∫ e^(l t) y(t) dt over the interval (integration
    by parts):

        (y1 e^(l t1) - y0 e^(l t0)) / l - slope (e^(l t1) - e^(l t0)) / l²

    The result is not normalised by the interval's total weight; callers divide by ∫ e^(l t) dt themselves.

    :param start_time_point: the start time value of the interval.
    :param start_value: the value at start_time_point.
    :param end_time_point: the end time value of the interval.
    :param end_value: the value at end_time_point.
    :param log_decay_factor: the logarithm of the decay factor used to weight the expectation.
    :return: the exponentially-weighted integral of the linear interpolation between the start and end points.
    """

    slope = (end_value - start_value) / (end_time_point - start_time_point)
    weight_at_start = math.exp(log_decay_factor * start_time_point)
    weight_at_end = math.exp(log_decay_factor * end_time_point)

    boundary_term = (1 / log_decay_factor) * (end_value * weight_at_end - start_value * weight_at_start)
    slope_term = (1 / log_decay_factor**2) * slope * (weight_at_end - weight_at_start)
    return boundary_term - slope_term
389
+
390
+
391
def compute_cdf_feature(
    current_value: float,
    time_series: list[tuple[float, float]],
    decay_factor: float,
    current_time: float,
    time_window: int,
) -> float:
    """
    Compute a CDF feature for the current value against its recent, time-decay-weighted history.

    The feature is the cumulative probability of current_value under a normal distribution. That distribution's
    loc and scale are the weighted mean and standard deviation from compute_cdf_weighted_mean_and_std. The
    (current_time, current_value) pair is appended to time_series first, and the series is then trimmed to its last
    time_window entries, so the current observation is part of the window the statistics are computed from.

    NOTE(review): an earlier version of this docstring cited https://arxiv.org/pdf/2305.18290.pdf as "the OpenAI
    paper" for the mean/std formula. That arXiv id appears to belong to a different paper. The LHOPT paper
    ("A Generalizable Approach to Learning Optimizers") is believed to be arXiv:2106.00958 — confirm the correct
    reference.

    :param current_value: Current value to compute the CDF feature for.
    :param time_series: List of (time, value) pairs for the CDF calculation. It is updated in place on every call.
    :param decay_factor: Decay factor b of the time weighting w(t) = b ** t, e.g. one of [1.25, 2.5, 5, 10, 20].
        Must be positive.
    :param current_time: Current time step.
    :param time_window: Maximum number of observations to keep in time_series. Values <= 0 keep none.
    :return: norm.cdf(current_value, loc=mean, scale=std) as a float. Returns 0.0 when fewer than two observations
        are kept, or when the weighted standard deviation is not positive.
    """
    time_series.append((current_time, current_value))

    # Drop the oldest entries so at most time_window remain. This is done explicitly because the slice
    # time_series[-time_window:] keeps the whole list when time_window == 0, letting the series grow without bound.
    excess = len(time_series) - max(time_window, 0)
    if excess > 0:
        del time_series[:excess]

    # Not enough data for a spread estimate.
    if len(time_series) < 2:
        return 0.0

    cdf_mean, cdf_std = compute_cdf_weighted_mean_and_std(time_series, decay_factor)

    if cdf_std > 0:
        # float(): norm.cdf returns a numpy scalar; callers expect a plain float.
        return float(norm.cdf(current_value, loc=cdf_mean, scale=cdf_std))

    # A zero (or NaN) spread gives no usable distribution.
    return 0.0
@@ -0,0 +1,58 @@
1
+ # ======================================================================================================================
2
+ #
3
+ # GLOBAL OBSERVERS PACKAGE
4
+ #
5
+ # This package contains all global observer classes used for collecting observations
6
+ # across the entire training process, not specific to individual agents.
7
+ #
8
+ # ======================================================================================================================
9
+
10
+
11
+ from .gradient_observers import GlobalFirstOrderGradients, GlobalSecondOrderGradients
12
+ from .hyperparameter_observers import InitialHyperparameters, ModelFamilyOneHot, OptimizerTypeOneHot
13
+ from .loss_observers import (
14
+ LHOPTLossRatio,
15
+ LHOPTTrainingLoss,
16
+ LHOPTValidationLoss,
17
+ LossRatio,
18
+ PercentileOfLossAtEachCheckpoint,
19
+ TrainingLoss,
20
+ TrainingScore,
21
+ ValidationLoss,
22
+ ValidationScore,
23
+ )
24
+ from .model_observers import (
25
+ GlobalActivations,
26
+ GlobalLAMBTrustRatio,
27
+ GlobalParameters,
28
+ GlobalParameterUpdates,
29
+ NumberOfLayers,
30
+ NumberOfParameters,
31
+ )
32
+ from .progress_observers import EpochsCompleted, ProgressAtEachCheckpoint, TrainingProgress
33
+
34
# Public API of the global observers package, spelled via each class's __name__ so renames stay in sync.
__all__ = [
    exported_class.__name__
    for exported_class in (
        InitialHyperparameters,
        OptimizerTypeOneHot,
        ModelFamilyOneHot,
        TrainingLoss,
        ValidationLoss,
        LossRatio,
        TrainingScore,
        ValidationScore,
        GlobalFirstOrderGradients,
        GlobalSecondOrderGradients,
        GlobalActivations,
        GlobalParameterUpdates,
        GlobalParameters,
        GlobalLAMBTrustRatio,
        NumberOfParameters,
        NumberOfLayers,
        TrainingProgress,
        EpochsCompleted,
        ProgressAtEachCheckpoint,
        LHOPTTrainingLoss,
        LHOPTValidationLoss,
        LHOPTLossRatio,
        PercentileOfLossAtEachCheckpoint,
    )
]
@@ -0,0 +1,183 @@
1
+ # ======================================================================================================================
2
+ #
3
+ # BASE CLASSES
4
+ #
5
+ # ======================================================================================================================
6
+
7
+ from abc import ABC, abstractmethod
8
+ from typing import Any
9
+
10
+ from libinephany.observations.observation_utils import StatisticStorageTypes, compute_cdf_feature
11
+ from libinephany.observations.observers.base_observers import GlobalObserver
12
+ from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
13
+ from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
14
+ from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
15
+ from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
16
+
17
+
18
class LHOPTOuterStepBaseObserver(GlobalObserver, ABC):
    """
    Base class for LHOPT outer step observers to eliminate duplicate code.

    Keeps a sliding window of (time, value) pairs and a step counter. _compute_cdf_feature turns a new value into a
    CDF feature measured against that window. Subclasses implement _observe.
    """

    def __init__(
        self,
        decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
        time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
        **kwargs,
    ) -> None:
        """
        :param decay_factor: Decay factor b of the CDF time weighting w(t) = b ** t, typically one of
            [1.25, 2.5, 5, 10, 20]. Must be positive.
        :param time_window: Number of time steps to consider for CDF calculation. Values below 1 are clamped to 1.
        :param kwargs: Other observation keyword arguments.
        :raises ValueError: If decay_factor is not a positive number.
        """
        super().__init__(**kwargs)

        # The CDF weighting takes log(decay_factor). A non-positive (or NaN) factor would therefore fail later with an
        # opaque "math domain error" (clamping it to 0.0 fails the same way), so reject it here.
        if not decay_factor > 0:
            raise ValueError(f"decay_factor must be positive, got {decay_factor}.")

        self.decay_factor = decay_factor
        self.time_window = max(1, time_window)

        # Store time series data for CDF calculation; compute_cdf_feature updates this list in place.
        self._time_series: list[tuple[float, float]] = []  # (time, value) pairs
        self._current_time: float = 0.0

    @property
    def can_standardize(self) -> bool:
        """
        This observer has its own CDF calculation, no need to standardize.

        :return: Whether the observation can be standardized.
        """
        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in. Must be one of the StatisticStorageTypes
            enumeration class.
        """
        return StatisticStorageTypes.VECTOR

    def _compute_cdf_feature(self, value: float) -> float:
        """
        Compute the CDF feature for the given value.

        As part of this call, (self._current_time, value) is appended to the stored time series, and the series is
        trimmed to self.time_window entries, before the weighted statistics are computed. The time counter is not
        advanced here; call _update_time for that.

        :param value: The value to compute the CDF feature for.
        :return: CDF feature value.
        """
        return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)

    def _update_time(self) -> None:
        """Advance the current time counter by one step."""
        self._current_time += 1.0

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
            needed. This observer needs no trackers.
        """
        return {}

    def reset(self) -> None:
        """Reset the observer by clearing the time series and rewinding the time counter."""
        self._time_series = []
        self._current_time = 0.0

    @abstractmethod
    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
            names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
        """
        raise NotImplementedError
97
+
98
+
99
class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
    """
    Base class for checkpoint-based observers to eliminate duplicate code.

    Maintains a sliding window of the most recent checkpoint_interval values plus the last recorded value.
    Subclasses implement _observe.
    """

    def __init__(self, checkpoint_interval: int = LHOPT_CONSTANTS["DEFAULT_CHECKPOINT_INTERVAL"], **kwargs) -> None:
        """
        :param checkpoint_interval: How often to create checkpoints (in outer model steps). Values below 1 are clamped
            to 1.
        :param kwargs: Miscellaneous keyword arguments.
        """
        super().__init__(**kwargs)

        # The history slice [-interval:] keeps every value when the interval is 0, so the history would grow without
        # bound. A negative interval makes the window meaningless. Clamp to a window of at least one value.
        self.checkpoint_interval = max(1, checkpoint_interval)
        self._history: list[float] = []
        self.last_value: float | None = None

    @property
    def can_standardize(self) -> bool:
        """
        Checkpoint observers are excluded from standardization.

        :return: Whether the observation can be standardized.
        """
        return False

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in.
        """
        return StatisticStorageTypes.FLOAT

    def _update_history(self, value: float) -> None:
        """
        Update the history with a new value and maintain sliding window.

        :param value: The new value to add to history.
        """
        self._history.append(value)

        # Keep only the last checkpoint_interval values for sliding window
        if len(self._history) > self.checkpoint_interval:
            self._history = self._history[-self.checkpoint_interval :]

    def _should_create_checkpoint(self) -> bool:
        """
        Check if we should create a checkpoint.

        NOTE(review): the window never shrinks, so once it is full this returns True on every call until reset().
        Presumably subclasses account for this; confirm that checkpoints are not meant to fire only once per interval.

        :return: True once the history holds checkpoint_interval values, False otherwise.
        """
        return len(self._history) >= self.checkpoint_interval

    def _cold_start(self, value: float) -> None:
        """
        Handle cold start by setting the last value if not already set.

        :param value: The value to set as last value if cold start.
        """
        if self.last_value is None:
            self.last_value = value

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
            needed. This observer needs no trackers.
        """
        return {}

    def reset(self) -> None:
        """Reset the observer by clearing history and the last recorded value."""
        self._history = []
        self.last_value = None

    @abstractmethod
    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
            names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
        """
        raise NotImplementedError
@@ -0,0 +1,33 @@
1
+ # ======================================================================================================================
2
+ #
3
+ # CONSTANTS
4
+ #
5
+ # ======================================================================================================================
6
+
7
+ from typing import TypedDict
8
+
9
+
10
class LHOPTConstants(TypedDict):
    """Schema of LHOPT_CONSTANTS, the constants shared by the LHOPT global observers."""

    # 1.0 / 0.0 indicator values; presumably used as NaN / inf flags in observations — confirm against callers.
    IS_NAN: float
    NOT_NAN: float
    IS_INF: float
    NOT_INF: float
    # Bound for tanh-style squashing; its use is not visible in this module — confirm against callers.
    TANH_BOUND: float
    # Default decay factor and window size of the outer-step CDF observers.
    DEFAULT_DECAY_FACTOR: float
    DEFAULT_TIME_WINDOW: int
    # Default sliding-window length of the checkpoint observers.
    DEFAULT_CHECKPOINT_INTERVAL: int
    # Default percentile; presumably consumed by a percentile observer — confirm.
    DEFAULT_PERCENTILE: float


# The shared constants instance (a plain dict conforming to LHOPTConstants).
LHOPT_CONSTANTS: LHOPTConstants = {
    "IS_NAN": 1.0,
    "NOT_NAN": 0.0,
    "IS_INF": 1.0,
    "NOT_INF": 0.0,
    "TANH_BOUND": 10.0,
    "DEFAULT_DECAY_FACTOR": 1.25,
    "DEFAULT_TIME_WINDOW": 32,
    "DEFAULT_CHECKPOINT_INTERVAL": 100,
    "DEFAULT_PERCENTILE": 0.6,
}
@@ -0,0 +1,112 @@
1
+ # ======================================================================================================================
2
+ #
3
+ # GRADIENT OBSERVERS
4
+ #
5
+ # ======================================================================================================================
6
+
7
+ from typing import Any
8
+
9
+ from libinephany.observations import observation_utils, statistic_trackers
10
+ from libinephany.observations.observation_utils import StatisticStorageTypes
11
+ from libinephany.observations.observers.base_observers import GlobalObserver
12
+ from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
13
+ from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
14
+ from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
15
+
16
+
17
class GlobalFirstOrderGradients(GlobalObserver):
    """
    Global observer reporting first-order gradient statistics averaged over every tracked module.
    """

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
            needed. Only the first-order gradient tracker is required.
        """

        tracker_kwargs = {"skip_statistics": self.skip_statistics}
        return {statistic_trackers.FirstOrderGradients.__name__: tracker_kwargs}

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in; always TensorStatistics for this observer.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Average the per-module first-order gradient statistics into a single TensorStatistics model.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
            names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
        """

        per_module = tracked_statistics[statistic_trackers.FirstOrderGradients.__name__]
        module_statistics = list(per_module.values())

        return observation_utils.average_tensor_statistics(tensor_statistics=module_statistics)  # type: ignore
54
+
55
+
56
class GlobalSecondOrderGradients(GlobalObserver):
    """
    Global observer reporting second-order gradient statistics averaged over every tracked module.
    """

    def __init__(
        self,
        *,
        compute_hessian_diagonal: bool = False,
        **kwargs,
    ) -> None:
        """
        :param compute_hessian_diagonal: If True, second-order gradients come from the Hessian diagonal. Otherwise
            the squared first-order gradients are used as approximations, in the same way Adam does.
        :param kwargs: Miscellaneous keyword arguments.
        """

        super().__init__(**kwargs)
        self.compute_hessian_diagonal = compute_hessian_diagonal

    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
        """
        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
            needed. Only the second-order gradient tracker is required, configured with this observer's settings.
        """

        tracker_kwargs = {
            "skip_statistics": self.skip_statistics,
            "compute_hessian_diagonal": self.compute_hessian_diagonal,
        }
        return {statistic_trackers.SecondOrderGradients.__name__: tracker_kwargs}

    def _get_observation_format(self) -> StatisticStorageTypes:
        """
        :return: Format the observation returns data in; always TensorStatistics for this observer.
        """

        return StatisticStorageTypes.TENSOR_STATISTICS

    def _observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        action_taken: float | int | None,
    ) -> float | int | list[int | float] | TensorStatistics:
        """
        Average the per-module second-order gradient statistics into a single TensorStatistics model.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
            names to floats or TensorStatistic models.
        :param action_taken: Action taken by the agent this class instance is assigned to.
        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
        """

        per_module = tracked_statistics[statistic_trackers.SecondOrderGradients.__name__]
        module_statistics = list(per_module.values())

        return observation_utils.average_tensor_statistics(tensor_statistics=module_statistics)  # type: ignore