libinephany 0.16.2__tar.gz → 0.16.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany-0.16.4/CODE_VERSION.cfg +1 -0
- {libinephany-0.16.2/libinephany.egg-info → libinephany-0.16.4}/PKG-INFO +2 -1
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observation_utils.py +167 -13
- libinephany-0.16.4/libinephany/observations/observers/global_observers/__init__.py +58 -0
- libinephany-0.16.4/libinephany/observations/observers/global_observers/base_classes.py +183 -0
- libinephany-0.16.4/libinephany/observations/observers/global_observers/constants.py +33 -0
- libinephany-0.16.4/libinephany/observations/observers/global_observers/gradient_observers.py +112 -0
- libinephany-0.16.4/libinephany/observations/observers/global_observers/hyperparameter_observers.py +286 -0
- libinephany-0.16.4/libinephany/observations/observers/global_observers/loss_observers.py +464 -0
- libinephany-0.16.4/libinephany/observations/observers/global_observers/model_observers.py +327 -0
- libinephany-0.16.4/libinephany/observations/observers/global_observers/progress_observers.py +142 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observers/observer_containers.py +3 -1
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/constants.py +1 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/enums.py +18 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/samplers.py +36 -1
- {libinephany-0.16.2 → libinephany-0.16.4/libinephany.egg-info}/PKG-INFO +2 -1
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany.egg-info/SOURCES.txt +8 -1
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany.egg-info/requires.txt +1 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/pyproject.toml +1 -0
- libinephany-0.16.2/CODE_VERSION.cfg +0 -1
- libinephany-0.16.2/libinephany/observations/observers/global_observers.py +0 -991
- {libinephany-0.16.2 → libinephany-0.16.4}/LICENSE +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/MANIFEST.in +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/README.md +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/aws/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/aws/s3_functions.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observer_pipeline.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observers/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observers/base_observers.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/observers/local_observers.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/pipeline_coordinator.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/post_processors/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/post_processors/postprocessors.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/statistic_manager.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/observations/statistic_trackers.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/configs/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/configs/hyperparameter_configs.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/configs/observer_config.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/configs/outer_model_config.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/agent_info.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/inner_task_profile.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/observation_models.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/request_schemas.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/response_schemas.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/schemas/tensor_statistics.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/states/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/pydantic_models/states/hyperparameter_states.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/agent_utils.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/asyncio_worker.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/backend_statuses.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/directory_utils.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/dropout_utils.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/error_severities.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/exceptions.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/import_utils.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/optim_utils.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/random_seeds.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/standardizers.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/torch_distributed_utils.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/torch_utils.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/transforms.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/utils/typing.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/web_apps/__init__.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/web_apps/error_logger.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany/web_apps/web_app_utils.py +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany.egg-info/dependency_links.txt +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/libinephany.egg-info/top_level.txt +0 -0
- {libinephany-0.16.2 → libinephany-0.16.4}/setup.cfg +0 -0
@@ -0,0 +1 @@
+0.16.4
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: libinephany
-Version: 0.16.2
+Version: 0.16.4
 Summary: Inephany library containing code commonly used by multiple subpackages.
 Author-email: Inephany <info@inephany.com>
 License: Apache 2.0
@@ -18,6 +18,7 @@ Requires-Dist: pydantic<3.0.0,>=2.5.0
 Requires-Dist: loguru<0.8.0,>=0.7.0
 Requires-Dist: requests<3.0.0,>=2.28.0
 Requires-Dist: numpy<2.0.0,>=1.24.0
+Requires-Dist: scipy<2.0.0,>=1.10.0
 Requires-Dist: slack-sdk<4.0.0,>=3.20.0
 Requires-Dist: boto3<2.0.0,>=1.26.0
 Requires-Dist: fastapi<0.116.0,>=0.100.0
@@ -13,6 +13,7 @@ import numpy as np
 import pandas as pd
 import torch
 import torch.optim as optim
+from scipy.stats import norm

 from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
 from libinephany.utils import optim_utils
@@ -24,6 +25,9 @@ from libinephany.utils import optim_utils
 # ======================================================================================================================

 EXP_AVERAGE = "exp_avg"
+MIN_DECAY_FACTOR = 1e-10
+
+MIN_TOTAL_WEIGHT = 1e-15  # Minimum total weight threshold for numerical stability

 # ======================================================================================================================
 #
@@ -60,20 +64,8 @@ def get_exponential_weighted_average(values: list[int | float]) -> float:
     :param values: List of values to average via EWA.
     :return: EWA of the given values.
     """
-
-    # Check for NaN and infinite values in input
-    valid_values = [float(val) for val in values if not math.isnan(float(val)) and not math.isinf(float(val))]
-
-    if not valid_values:
-        raise ValueError("Cannot compute exponential weighted average on empty list")
-
-    if len(valid_values) == 1:
-        return valid_values[0]
-
-    exp_weighted_average = pd.Series(valid_values).ewm(alpha=0.1).mean().iloc[-1]
+    exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
     assert isinstance(exp_weighted_average, float)
-    assert not math.isnan(exp_weighted_average)
-    assert not math.isinf(exp_weighted_average)
     return exp_weighted_average

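Note that 0.16.4 strips the NaN/inf filtering and the empty-list guard from this helper, leaving only the pandas smoothing. A minimal sketch of the behaviour that remains (sample values invented for illustration):

    import pandas as pd

    values = [2.0, 4.0, 8.0]

    # ewm(alpha=0.1) weights recent entries more heavily; .iloc[-1] reads the
    # smoothed value at the last position, which is what the helper returns.
    ewa = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
    print(ewa)  # ~4.88, pulled toward the latest value (8.0) vs. the plain mean (~4.67)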
@@ -280,3 +272,165 @@ def concatenate_lists(lists: list[list[Any]]) -> list[Any]:
     """

     return list(chain(*lists))
+
+
+def compute_cdf_weighted_mean_and_std(
+    time_series: list[tuple[float, float]], decay_factor: float
+) -> tuple[float, float]:
+    """
+    Compute the CDF-weighted mean and standard deviation via numerical integration, using the same
+    exponential decay weights for both.
+
+    :param time_series: List of (time, value) pairs.
+    :param decay_factor: Decay factor b in the exponential weight formula, typically one of [1.25, 2.5, 5, 10, 20].
+    :return: Tuple of (weighted mean, weighted standard deviation).
+    """
+
+    if len(time_series) == 0:
+        return 0.0, 0.0
+
+    if len(time_series) == 1:
+        return time_series[0][1], 0.0
+
+    sorted_series = sorted(time_series, key=lambda x: x[0])
+
+    # Handle the special case when decay_factor = 1.0:
+    # w(t) = 1 for all t, so the result is just the arithmetic mean and standard deviation.
+    if abs(decay_factor - 1.0) < MIN_DECAY_FACTOR:
+        values = [v for _, v in sorted_series]
+        mean = float(np.mean(values))
+        std = float(np.std(values))
+        return mean, std
+
+    log_decay_factor = math.log(decay_factor)
+
+    total_weight = 0.0  # ∫ w(t) dt - total weight across all time intervals
+    total_weighted_value = 0.0  # ∫ w(t) y(t) dt - total weighted value
+    total_weighted_squared = 0.0  # ∫ w(t) y(t)² dt - total weighted squared value
+
+    for time_series_index in range(len(sorted_series) - 1):
+        start_time_point = sorted_series[time_series_index][0]
+        end_time_point = sorted_series[time_series_index + 1][0]
+        start_value = sorted_series[time_series_index][1]
+        end_value = sorted_series[time_series_index + 1][1]
+
+        time_interval = end_time_point - start_time_point
+        assert time_interval > 0, "Time interval must be positive"
+
+        interval_value = _weighted_interval_expectation(
+            start_time_point=start_time_point,
+            start_value=start_value,
+            end_time_point=end_time_point,
+            end_value=end_value,
+            log_decay_factor=log_decay_factor,
+        )
+        interval_squared_value = _weighted_interval_expectation(
+            start_time_point=start_time_point,
+            start_value=start_value**2,
+            end_time_point=end_time_point,
+            end_value=end_value**2,
+            log_decay_factor=log_decay_factor,
+        )
+
+        total_weighted_value += interval_value
+        total_weighted_squared += interval_squared_value
+
+    total_weight = (1 / log_decay_factor) * (
+        math.exp(log_decay_factor * sorted_series[-1][0]) - math.exp(log_decay_factor * sorted_series[0][0])
+    )
+    # Fall back to unweighted statistics if the total weight is too small (numerical stability)
+    if total_weight < MIN_TOTAL_WEIGHT:
+        values = [v for _, v in sorted_series]
+        mean = float(np.mean(values))
+        std = float(np.std(values))
+        return mean, std
+
+    # Weighted mean: μ = ∫ w(t) y(t) dt / ∫ w(t) dt,
+    # i.e. the expected value under the weight distribution.
+    weighted_mean = float(total_weighted_value / total_weight)
+
+    # Weighted variance: Var = ∫ w(t) y(t)² dt / ∫ w(t) dt - μ²,
+    # following Var(X) = E[X²] - (E[X])² with
+    # E[X] = ∫ w(t) y(t) dt / ∫ w(t) dt and E[X²] = ∫ w(t) y(t)² dt / ∫ w(t) dt.
+    weighted_variance = float(total_weighted_squared / total_weight - weighted_mean**2)
+
+    # Weighted standard deviation: σ = √Var, clamped at zero against floating-point error.
+    weighted_std = float(math.sqrt(max(0, weighted_variance)))
+
+    return weighted_mean, weighted_std
+
+
+def _weighted_interval_expectation(
+    start_time_point: float,
+    start_value: float,
+    end_time_point: float,
+    end_value: float,
+    log_decay_factor: float,
+) -> float:
+    """
+    Computes the weighted interval expectation from Appendix E of the LHOPT paper.
+
+    :param start_time_point: The start time of the interval.
+    :param start_value: The value at start_time_point.
+    :param end_time_point: The end time of the interval.
+    :param end_value: The value at end_time_point.
+    :param log_decay_factor: The logarithm of the decay factor used to weight the expectation.
+    :return: The exponentially-weighted expectation of the linear interpolation between the start and end points.
+    """
+
+    interval_gradient = (end_value - start_value) / (end_time_point - start_time_point)
+    start_exp_time = math.exp(log_decay_factor * start_time_point)
+    end_exp_time = math.exp(log_decay_factor * end_time_point)
+    return (1 / log_decay_factor) * (end_value * end_exp_time - start_value * start_exp_time) - (
+        1 / log_decay_factor**2
+    ) * interval_gradient * (end_exp_time - start_exp_time)
+
+
+def compute_cdf_feature(
+    current_value: float,
+    time_series: list[tuple[float, float]],
+    decay_factor: float,
+    current_time: float,
+    time_window: int,
+) -> float:
+    """
+    Compute a CDF feature: the cumulative probability of the current value under a normal distribution
+    fitted to the historical values, weighted by time decay. Uses scipy.stats.norm.cdf with loc (mean)
+    and scale (std) computed by compute_cdf_weighted_mean_and_std.
+
+    The mean and std formulae follow the LHOPT paper from OpenAI:
+    https://arxiv.org/abs/2106.00958
+
+    :param current_value: Current value to compute the CDF feature for.
+    :param time_series: List of (time, value) pairs for the CDF calculation. time_series is updated in place each time this function is called.
+    :param decay_factor: Decay factor for the CDF calculation, typically one of [1.25, 2.5, 5, 10, 20].
+    :param current_time: Current time step.
+    :param time_window: Maximum number of time steps to keep in the time series.
+    :return: CDF feature value (cumulative probability from the normal distribution).
+    """
+    # Add the current observation to the time series
+    time_series.append((current_time, current_value))
+
+    # Keep only the last time_window observations
+    if len(time_series) > time_window:
+        time_series[:] = time_series[-time_window:]
+
+    # Not enough data yet to fit a distribution
+    if len(time_series) < 2:
+        return 0.0
+
+    # Compute the CDF-weighted mean (loc) and standard deviation (scale)
+    cdf_mean, cdf_std = compute_cdf_weighted_mean_and_std(time_series, decay_factor)
+
+    if cdf_std > 0:
+        cdf_feature = norm.cdf(current_value, loc=cdf_mean, scale=cdf_std)
+        return cdf_feature
+    else:
+        # Degenerate distribution (zero standard deviation)
+        return 0.0
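For orientation, a minimal usage sketch of the new feature helper (loss values invented; the caller owns the time-series buffer and passes it back on every call):

    from libinephany.observations.observation_utils import compute_cdf_feature

    history: list[tuple[float, float]] = []  # (time, value) pairs, mutated in place

    for step, loss in enumerate([2.3, 2.1, 1.9, 1.8, 2.4]):
        feature = compute_cdf_feature(
            current_value=loss,
            time_series=history,   # appended to on every call, trimmed to time_window
            decay_factor=1.25,     # DEFAULT_DECAY_FACTOR from LHOPT_CONSTANTS
            current_time=float(step),
            time_window=32,
        )
        # feature is 0.0 until two points exist, then a probability in [0, 1];
        # values near 1.0 mean the current loss is high relative to recent history.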
@@ -0,0 +1,58 @@
+# ======================================================================================================================
+#
+# GLOBAL OBSERVERS PACKAGE
+#
+# This package contains all global observer classes used for collecting observations
+# across the entire training process, not specific to individual agents.
+#
+# ======================================================================================================================
+
+
+from .gradient_observers import GlobalFirstOrderGradients, GlobalSecondOrderGradients
+from .hyperparameter_observers import InitialHyperparameters, ModelFamilyOneHot, OptimizerTypeOneHot
+from .loss_observers import (
+    LHOPTLossRatio,
+    LHOPTTrainingLoss,
+    LHOPTValidationLoss,
+    LossRatio,
+    PercentileOfLossAtEachCheckpoint,
+    TrainingLoss,
+    TrainingScore,
+    ValidationLoss,
+    ValidationScore,
+)
+from .model_observers import (
+    GlobalActivations,
+    GlobalLAMBTrustRatio,
+    GlobalParameters,
+    GlobalParameterUpdates,
+    NumberOfLayers,
+    NumberOfParameters,
+)
+from .progress_observers import EpochsCompleted, ProgressAtEachCheckpoint, TrainingProgress
+
+__all__ = [
+    InitialHyperparameters.__name__,
+    OptimizerTypeOneHot.__name__,
+    ModelFamilyOneHot.__name__,
+    TrainingLoss.__name__,
+    ValidationLoss.__name__,
+    LossRatio.__name__,
+    TrainingScore.__name__,
+    ValidationScore.__name__,
+    GlobalFirstOrderGradients.__name__,
+    GlobalSecondOrderGradients.__name__,
+    GlobalActivations.__name__,
+    GlobalParameterUpdates.__name__,
+    GlobalParameters.__name__,
+    GlobalLAMBTrustRatio.__name__,
+    NumberOfParameters.__name__,
+    NumberOfLayers.__name__,
+    TrainingProgress.__name__,
+    EpochsCompleted.__name__,
+    ProgressAtEachCheckpoint.__name__,
+    LHOPTTrainingLoss.__name__,
+    LHOPTValidationLoss.__name__,
+    LHOPTLossRatio.__name__,
+    PercentileOfLossAtEachCheckpoint.__name__,
+]
@@ -0,0 +1,183 @@
+# ======================================================================================================================
+#
+# BASE CLASSES
+#
+# ======================================================================================================================
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+from libinephany.observations.observation_utils import StatisticStorageTypes, compute_cdf_feature
+from libinephany.observations.observers.base_observers import GlobalObserver
+from libinephany.observations.observers.global_observers.constants import LHOPT_CONSTANTS
+from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
+from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
+from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
+
+
+class LHOPTOuterStepBaseObserver(GlobalObserver, ABC):
+    """
+    Base class for LHOPT outer step observers that eliminates duplicated code.
+    """
+
+    def __init__(
+        self,
+        decay_factor: float = LHOPT_CONSTANTS["DEFAULT_DECAY_FACTOR"],
+        time_window: int = LHOPT_CONSTANTS["DEFAULT_TIME_WINDOW"],
+        **kwargs,
+    ) -> None:
+        """
+        :param decay_factor: Decay factor for the CDF calculation, typically one of [1.25, 2.5, 5, 10, 20].
+        :param time_window: Number of time steps to consider for the CDF calculation.
+        :param kwargs: Other observation keyword arguments.
+        """
+        super().__init__(**kwargs)
+        self.decay_factor = max(0.0, decay_factor)
+        self.time_window = max(1, time_window)
+
+        # Store time series data for CDF calculation
+        self._time_series: list[tuple[float, float]] = []  # (time, value) pairs
+        self._current_time: float = 0.0
+
+    @property
+    def can_standardize(self) -> bool:
+        """
+        This observer has its own CDF calculation, so there is no need to standardize.
+        :return: Whether the observation can be standardized.
+        """
+        return False
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the attributes of the StatisticStorageTypes
+            enumeration class.
+        """
+        return StatisticStorageTypes.VECTOR
+
+    def _compute_cdf_feature(self, value: float) -> float:
+        """
+        Compute the CDF feature for the given value.
+        The value is appended to the observer's time series as part of this call.
+        :param value: The value to compute the CDF feature for.
+        :return: CDF feature value.
+        """
+        return compute_cdf_feature(value, self._time_series, self.decay_factor, self._current_time, self.time_window)
+
+    def _update_time(self) -> None:
+        """Update the current time counter."""
+        self._current_time += 1.0
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+        return {}
+
+    def reset(self) -> None:
+        """Reset the observer by clearing the time series."""
+        self._time_series = []
+        self._current_time = 0.0
+
+    @abstractmethod
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        """
+        raise NotImplementedError
+
+
+class LHOPTCheckpointBaseObserver(GlobalObserver, ABC):
+    """
+    Base class for checkpoint-based observers that eliminates duplicated code.
+    """
+
+    def __init__(self, checkpoint_interval: int = LHOPT_CONSTANTS["DEFAULT_CHECKPOINT_INTERVAL"], **kwargs) -> None:
+        """
+        :param checkpoint_interval: How often to create checkpoints (in outer model steps).
+        :param kwargs: Miscellaneous keyword arguments.
+        """
+        super().__init__(**kwargs)
+        self.checkpoint_interval = checkpoint_interval
+        self._history: list[float] = []
+        self.last_value: float | None = None
+
+    @property
+    def can_standardize(self) -> bool:
+        """
+        This observer performs its own normalization, so there is no need to standardize.
+        :return: Whether the observation can be standardized.
+        """
+        return False
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in.
+        """
+        return StatisticStorageTypes.FLOAT
+
+    def _update_history(self, value: float) -> None:
+        """
+        Update the history with a new value and maintain a sliding window.
+
+        :param value: The new value to add to the history.
+        """
+        self._history.append(value)
+
+        # Keep only the last checkpoint_interval values for the sliding window
+        if len(self._history) > self.checkpoint_interval:
+            self._history = self._history[-self.checkpoint_interval :]
+
+    def _should_create_checkpoint(self) -> bool:
+        """
+        Check whether a checkpoint should be created.
+
+        :return: True if a checkpoint should be created, False otherwise.
+        """
+        return len(self._history) >= self.checkpoint_interval
+
+    def _cold_start(self, value: float) -> None:
+        """
+        Handle a cold start by setting the last value if it is not already set.
+
+        :param value: The value to use as the last value on a cold start.
+        """
+        if self.last_value is None:
+            self.last_value = value
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+        return {}
+
+    def reset(self) -> None:
+        """Reset the observer by clearing history."""
+        self._history = []
+        self.last_value = None
+
+    @abstractmethod
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        """
+        raise NotImplementedError
@@ -0,0 +1,33 @@
+# ======================================================================================================================
+#
+# CONSTANTS
+#
+# ======================================================================================================================
+
+from typing import TypedDict
+
+
+class LHOPTConstants(TypedDict):
+    IS_NAN: float
+    NOT_NAN: float
+    IS_INF: float
+    NOT_INF: float
+    TANH_BOUND: float
+    DEFAULT_DECAY_FACTOR: float
+    DEFAULT_TIME_WINDOW: int
+    DEFAULT_CHECKPOINT_INTERVAL: int
+    DEFAULT_PERCENTILE: float
+
+
+# Create the constants instance
+LHOPT_CONSTANTS: LHOPTConstants = LHOPTConstants(
+    IS_NAN=1.0,
+    NOT_NAN=0.0,
+    IS_INF=1.0,
+    NOT_INF=0.0,
+    TANH_BOUND=10.0,
+    DEFAULT_DECAY_FACTOR=1.25,
+    DEFAULT_TIME_WINDOW=32,
+    DEFAULT_CHECKPOINT_INTERVAL=100,
+    DEFAULT_PERCENTILE=0.6,
+)
@@ -0,0 +1,112 @@
+# ======================================================================================================================
+#
+# GRADIENT OBSERVERS
+#
+# ======================================================================================================================
+
+from typing import Any
+
+from libinephany.observations import observation_utils, statistic_trackers
+from libinephany.observations.observation_utils import StatisticStorageTypes
+from libinephany.observations.observers.base_observers import GlobalObserver
+from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
+from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
+from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
+
+
+class GlobalFirstOrderGradients(GlobalObserver):
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.TENSOR_STATISTICS
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.FirstOrderGradients.__name__]
+
+        return observation_utils.average_tensor_statistics(tensor_statistics=list(statistics.values()))  # type: ignore
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {statistic_trackers.FirstOrderGradients.__name__: dict(skip_statistics=self.skip_statistics)}
+
+
+class GlobalSecondOrderGradients(GlobalObserver):
+
+    def __init__(
+        self,
+        *,
+        compute_hessian_diagonal: bool = False,
+        **kwargs,
+    ) -> None:
+        """
+        :param compute_hessian_diagonal: Whether to compute the Hessian diagonal to determine second order gradients
+            or use the squared first order gradients as approximations in the same way Adam does.
+        :param kwargs: Miscellaneous keyword arguments.
+        """
+
+        super().__init__(**kwargs)
+
+        self.compute_hessian_diagonal = compute_hessian_diagonal
+
+    def _get_observation_format(self) -> StatisticStorageTypes:
+        """
+        :return: Format the observation returns data in. Must be one of the enum attributes in the StatisticStorageTypes
+            enumeration class.
+        """
+
+        return StatisticStorageTypes.TENSOR_STATISTICS
+
+    def _observe(
+        self,
+        observation_inputs: ObservationInputs,
+        hyperparameter_states: HyperparameterStates,
+        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+        action_taken: float | int | None,
+    ) -> float | int | list[int | float] | TensorStatistics:
+        """
+        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+            names to floats or TensorStatistic models.
+        :param action_taken: Action taken by the agent this class instance is assigned to.
+        :return: Single float/int, list of floats/ints or TensorStatistics model to add to the observation vector.
+        """
+
+        statistics = tracked_statistics[statistic_trackers.SecondOrderGradients.__name__]
+
+        return observation_utils.average_tensor_statistics(tensor_statistics=list(statistics.values()))  # type: ignore
+
+    def get_required_trackers(self) -> dict[str, dict[str, Any] | None]:
+        """
+        :return: Dictionary mapping statistic tracker class names to kwargs for the class or None if no kwargs are
+            needed.
+        """
+
+        return {
+            statistic_trackers.SecondOrderGradients.__name__: dict(
+                skip_statistics=self.skip_statistics, compute_hessian_diagonal=self.compute_hessian_diagonal
+            )
+        }
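To make the tracker contract above concrete, here is a hedged sketch of how the pipeline is expected to wire an observer to its trackers; the module names and dictionary shapes are illustrative, not taken from the package:

    # Hypothetical wiring: the pipeline queries each observer for its trackers,
    # runs them, then passes results back keyed by tracker class name.
    observer = GlobalFirstOrderGradients()  # assuming GlobalObserver's defaults suffice

    required = observer.get_required_trackers()
    # e.g. {"FirstOrderGradients": {"skip_statistics": [...]}}

    # The pipeline would then supply something shaped like:
    # tracked_statistics = {
    #     "FirstOrderGradients": {
    #         "encoder.layer0": TensorStatistics(...),  # per-module gradient stats
    #         "encoder.layer1": TensorStatistics(...),
    #     },
    # }
    # and _observe averages the per-module TensorStatistics into one global summary.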