libinephany 0.16.1__py3-none-any.whl → 0.16.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,7 @@ import numpy as np
13
13
  import pandas as pd
14
14
  import torch
15
15
  import torch.optim as optim
16
+ from scipy.stats import norm
16
17
 
17
18
  from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
18
19
  from libinephany.utils import optim_utils
@@ -24,6 +25,9 @@ from libinephany.utils import optim_utils
24
25
  # ======================================================================================================================
25
26
 
26
27
EXP_AVERAGE: str = "exp_avg"

# Used by compute_cdf_weighted_mean_and_std as the tolerance for treating a decay factor as exactly 1.0
# (uniform weights). Despite its name it is compared against |decay_factor - 1.0|, not against decay_factor.
MIN_DECAY_FACTOR: float = 1.0e-10

# Smallest total weight that compute_cdf_weighted_mean_and_std will divide by; below this it falls back to the
# unweighted mean / standard deviation for numerical stability.
MIN_TOTAL_WEIGHT: float = 1.0e-15
27
31
 
28
32
  # ======================================================================================================================
29
33
  #
@@ -280,3 +284,165 @@ def concatenate_lists(lists: list[list[Any]]) -> list[Any]:
280
284
  """
281
285
 
282
286
  return list(chain(*lists))
287
+
288
+
289
def compute_cdf_weighted_mean_and_std(
    time_series: list[tuple[float, float]], decay_factor: float
) -> tuple[float, float]:
    """
    Compute the exponentially time-weighted mean and standard deviation of a time series.

    The series is treated as a piecewise-linear function y(t) through the given points and weighted by
    w(t) = decay_factor ** t, so with decay_factor > 1 later observations receive more weight. The integrals
    ∫ w(t) y(t) dt and ∫ w(t) y(t)² dt are evaluated in closed form interval by interval via
    _weighted_interval_expectation (the squared values are interpolated linearly between points), and

        mean = ∫ w y dt / ∫ w dt,    var = ∫ w y² dt / ∫ w dt - mean².

    :param time_series: List of (time, value) pairs. Order does not matter; the pairs are sorted by time here.
        Pairs sharing a timestamp form zero-width intervals and contribute nothing to the integrals.
    :param decay_factor: Base b of the exponential weight w(t) = b^t, e.g. b in [1.25, 2.5, 5, 10, 20]. Must be
        positive (math.log raises ValueError otherwise).
    :return: Tuple of (weighted mean, weighted standard deviation). An empty series gives (0.0, 0.0) and a single
        observation gives (value, 0.0). When decay_factor is within MIN_DECAY_FACTOR of 1.0, or the total weight
        (measured with the time origin shifted so the largest weight is 1) is below MIN_TOTAL_WEIGHT, the
        unweighted mean and population standard deviation of the values are returned instead.
    """

    if len(time_series) == 0:
        return 0.0, 0.0

    if len(time_series) == 1:
        return time_series[0][1], 0.0

    sorted_series = sorted(time_series, key=lambda pair: pair[0])
    values = [value for _, value in sorted_series]

    # decay_factor == 1.0 means uniform weights, and log(1) == 0 would divide by zero below, so fall back to the
    # plain arithmetic statistics.
    if abs(decay_factor - 1.0) < MIN_DECAY_FACTOR:
        return float(np.mean(values)), float(np.std(values))

    log_decay_factor = math.log(decay_factor)

    # The weighted mean and variance are ratios of integrals, so multiplying every weight by the same constant
    # exp(-log_decay_factor * reference_time) leaves them unchanged. Choosing the reference time so that the
    # largest exponent is zero stops math.exp from overflowing when times are large (e.g. raw step counts).
    reference_time = sorted_series[-1][0] if log_decay_factor > 0 else sorted_series[0][0]

    total_weighted_value = 0.0  # ∫ w(t) y(t) dt
    total_weighted_squared = 0.0  # ∫ w(t) y(t)² dt

    for (start_time, start_value), (end_time, end_value) in zip(sorted_series, sorted_series[1:]):
        if end_time == start_time:
            # Duplicate timestamps give a zero-width interval: it contributes nothing to the integrals and would
            # otherwise divide by zero when computing the interval gradient.
            continue

        shifted_start = start_time - reference_time
        shifted_end = end_time - reference_time

        total_weighted_value += _weighted_interval_expectation(
            start_time_point=shifted_start,
            start_value=start_value,
            end_time_point=shifted_end,
            end_value=end_value,
            log_decay_factor=log_decay_factor,
        )
        total_weighted_squared += _weighted_interval_expectation(
            start_time_point=shifted_start,
            start_value=start_value**2,
            end_time_point=shifted_end,
            end_value=end_value**2,
            log_decay_factor=log_decay_factor,
        )

    # ∫ w(t) dt over the whole span, using the same shifted time origin as the interval integrals.
    total_weight = (
        math.exp(log_decay_factor * (sorted_series[-1][0] - reference_time))
        - math.exp(log_decay_factor * (sorted_series[0][0] - reference_time))
    ) / log_decay_factor

    # Too little weight to divide by safely (e.g. all timestamps identical): use unweighted statistics.
    if total_weight < MIN_TOTAL_WEIGHT:
        return float(np.mean(values)), float(np.std(values))

    weighted_mean = total_weighted_value / total_weight
    # Var(X) = E[X²] - (E[X])² under the normalised weight distribution.
    weighted_variance = total_weighted_squared / total_weight - weighted_mean**2
    # Rounding can leave the variance marginally negative; clamp before taking the square root.
    weighted_std = math.sqrt(max(0.0, weighted_variance))

    return float(weighted_mean), float(weighted_std)
375
+
376
+
377
def _weighted_interval_expectation(
    start_time_point: float,
    start_value: float,
    end_time_point: float,
    end_value: float,
    log_decay_factor: float,
) -> float:
    """
    Integrate an exponentially-weighted linear interpolation over one interval (Appendix E of the LHOPT paper).

    Computes ∫_{t0}^{t1} exp(a t) y(t) dt in closed form, where a = log_decay_factor and y(t) is the straight line
    from (t0, y0) = (start_time_point, start_value) to (t1, y1) = (end_time_point, end_value). Integrating by
    parts, with m = (y1 - y0) / (t1 - t0), gives

        (1 / a) * (y1 exp(a t1) - y0 exp(a t0)) - (m / a²) * (exp(a t1) - exp(a t0)).

    :param start_time_point: the start time value of the interval.
    :param start_value: the value at start_time_point.
    :param end_time_point: the end time value of the interval.
    :param end_value: the value at end_time_point.
    :param log_decay_factor: the logarithm of the decay factor used to weight the expectation.
    :return: the integral above; 0.0 for a zero-width interval (start_time_point == end_time_point).
    :raises ZeroDivisionError: if log_decay_factor is 0 and the interval has non-zero width.
    """

    if end_time_point == start_time_point:
        # A zero-width interval has zero integral; returning early also avoids dividing by zero below.
        return 0.0

    interval_gradient = (end_value - start_value) / (end_time_point - start_time_point)
    start_exp_time = math.exp(log_decay_factor * start_time_point)
    end_exp_time = math.exp(log_decay_factor * end_time_point)

    boundary_term = (end_value * end_exp_time - start_value * start_exp_time) / log_decay_factor
    gradient_term = interval_gradient * (end_exp_time - start_exp_time) / log_decay_factor**2

    # The gradient term is subtracted: d/dt[exp(a t) y(t) / a] = exp(a t) y(t) + m exp(a t) / a, so the extra
    # m exp(a t) / a contribution has to be removed from the boundary term.
    return boundary_term - gradient_term
401
+
402
+
403
def compute_cdf_feature(
    current_value: float,
    time_series: list[tuple[float, float]],
    decay_factor: float,
    current_time: float,
    time_window: int,
) -> float:
    """
    Compute a time-decayed CDF feature for the current value.

    The current observation is appended to ``time_series``, which is then trimmed to its last ``time_window``
    entries. An exponentially time-weighted mean and standard deviation are computed from the retained history
    (including the current observation) with compute_cdf_weighted_mean_and_std. The feature is the normal CDF
    of ``current_value`` under N(mean, std²), evaluated with scipy.stats.norm.cdf.

    This follows the CDF features of the LHOPT paper ("A Generalizable Approach to Learning Optimizers",
    Almeida et al., OpenAI, 2021): https://arxiv.org/abs/2106.00958

    :param current_value: Current value to compute the CDF feature for.
    :param time_series: List of (time, value) pairs forming the history. It is mutated in place on every call:
        the current observation is appended and older entries beyond ``time_window`` are dropped. Trimming keeps
        the last list entries, so entries are assumed to be appended in time order.
    :param decay_factor: Base b of the weight b^t used for the mean/std. b > 1 weights recent observations more
        heavily and b == 1 weights all times equally.
    :param current_time: Time of the current observation.
    :param time_window: Maximum number of observations kept in ``time_series``. NOTE: because of Python slice
        semantics (``lst[-0:]`` is the whole list), a value of 0 disables trimming.
    :return: CDF feature in [0, 1]. Returns 0.0 when fewer than two observations are available or the weighted
        standard deviation is not positive.
    """

    time_series.append((current_time, current_value))

    if len(time_series) > time_window:
        # Slice assignment keeps the caller's list object, so the trimming is visible to the caller.
        time_series[:] = time_series[-time_window:]

    # A spread needs at least two observations.
    if len(time_series) < 2:
        return 0.0

    cdf_mean, cdf_std = compute_cdf_weighted_mean_and_std(time_series, decay_factor)

    if cdf_std > 0:
        # Convert the numpy scalar returned by scipy into a plain Python float, matching the annotation.
        return float(norm.cdf(current_value, loc=cdf_mean, scale=cdf_std))

    # Zero (or NaN) spread: no usable distribution, so return 0.0.
    return 0.0
@@ -80,3 +80,4 @@ PREFIXES_TO_HPARAMS = {
80
80
  AGENT_PREFIX_SGD_MOMENTUM: SGD_MOMENTUM,
81
81
  AGENT_GRADIENT_ACCUMULATION: GRADIENT_ACCUMULATION,
82
82
  }
83
# Reverse lookup of PREFIXES_TO_HPARAMS (hyperparameter -> agent prefix). If several prefixes map to the same
# hyperparameter, the prefix appearing last in PREFIXES_TO_HPARAMS is the one kept.
HPARAMS_TO_PREFIXES = dict(zip(PREFIXES_TO_HPARAMS.values(), PREFIXES_TO_HPARAMS.keys()))
@@ -78,6 +78,24 @@ class AgentTypes(EnumWithIndices):
78
78
  Tokens = TOKENS
79
79
  Samples = SAMPLES
80
80
 
81
+ @classmethod
82
+ def get_possible_active_agents(cls) -> list["AgentTypes"]:
83
+ """
84
+ :return: List of active agents.
85
+ """
86
+
87
+ return [
88
+ cls.LearningRateAgent,
89
+ cls.WeightDecayAgent,
90
+ cls.DropoutAgent,
91
+ cls.GradientClippingAgent,
92
+ cls.AdamBetaOneAgent,
93
+ cls.AdamBetaTwoAgent,
94
+ cls.AdamEpsAgent,
95
+ cls.SGDMomentumAgent,
96
+ cls.GradientAccumulationAgent,
97
+ ]
98
+
81
99
 
82
100
  class ModelFamilies(EnumWithIndices):
83
101
 
@@ -68,6 +68,13 @@ class Sampler:
68
68
 
69
69
  raise NotImplementedError
70
70
 
71
+ @classmethod
72
+ def get_subclasses(cls):
73
+ """Recursively gets subclasses of the Sampler class."""
74
+ for subclass in cls.__subclasses__():
75
+ yield from subclass.get_subclasses()
76
+ yield subclass
77
+
71
78
 
72
79
  class LogUniformSampler(Sampler):
73
80
 
@@ -228,6 +235,34 @@ class DiscreteValueSampler(Sampler):
228
235
  ).astype(self.sample_dtype)
229
236
 
230
237
 
238
class DiscreteValueListSampler(DiscreteValueSampler):
    """
    Sampler whose samples are lists: each sample holds ``length`` values drawn with DiscreteValueSampler.sample.
    """

    def __init__(
        self,
        length: int,
        discrete_values: list[float | int | str],
        sample_dtype: type[np.generic | float | int | str] = np.float64,
        **kwargs,
    ) -> None:
        """
        :param length: Length of each sampled list.
        :param discrete_values: List of discrete values to sample from.
        :param sample_dtype: Type of the sampled values; passed unchanged to DiscreteValueSampler.
        :param kwargs: Miscellaneous keyword arguments. They are accepted so generic construction works (e.g.
            build_sampler always passes lower_bound/upper_bound) but are not forwarded to DiscreteValueSampler.
        """

        super().__init__(discrete_values=discrete_values, sample_dtype=sample_dtype)
        self.list_length = length

    def sample(self, number_of_samples: int = 1, **kwargs) -> list[np.ndarray | list[Any]]:
        """
        :param number_of_samples: Number of lists to sample.
        :param kwargs: Miscellaneous keyword arguments (unused).
        :return: Python list with ``number_of_samples`` entries, each the result of drawing ``self.list_length``
            values from DiscreteValueSampler.sample.
        """

        # Bind the parent method outside the comprehension. Before Python 3.12 a list comprehension runs in its
        # own implicit function, where zero-argument super() picks up the comprehension's iterator instead of
        # ``self`` and raises TypeError.
        parent_sample = super().sample
        return [parent_sample(number_of_samples=self.list_length) for _ in range(number_of_samples)]
264
+
265
+
231
266
  class RoundRobinDiscreteValueSampler(Sampler):
232
267
 
233
268
  def __init__(
@@ -287,7 +322,7 @@ def build_sampler(sampler_name: str, lower_bound: float | int, upper_bound: floa
287
322
  :return: Constructed sampler.
288
323
  """
289
324
 
290
- possible_samplers = {sampler_type.__name__: sampler_type for sampler_type in Sampler.__subclasses__()}
325
+ possible_samplers = {sampler_type.__name__: sampler_type for sampler_type in Sampler.get_subclasses()}
291
326
 
292
327
  try:
293
328
  return possible_samplers[sampler_name](lower_bound=lower_bound, upper_bound=upper_bound, **kwargs) # type: ignore
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 0.16.1
3
+ Version: 0.16.3
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -18,6 +18,7 @@ Requires-Dist: pydantic<3.0.0,>=2.5.0
18
18
  Requires-Dist: loguru<0.8.0,>=0.7.0
19
19
  Requires-Dist: requests<3.0.0,>=2.28.0
20
20
  Requires-Dist: numpy<2.0.0,>=1.24.0
21
+ Requires-Dist: scipy<2.0.0,>=1.10.0
21
22
  Requires-Dist: slack-sdk<4.0.0,>=3.20.0
22
23
  Requires-Dist: boto3<2.0.0,>=1.26.0
23
24
  Requires-Dist: fastapi<0.116.0,>=0.100.0
@@ -2,7 +2,7 @@ libinephany/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
4
4
  libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- libinephany/observations/observation_utils.py,sha256=z3WfEf7Dvj8sS0FmbPpjolPeFWX64SVUJF5Rydf3Whs,9949
5
+ libinephany/observations/observation_utils.py,sha256=wsCxVIhtCmJpaTKq9AcYsJGc9WK5qO_RE4DK_fzBE8w,16703
6
6
  libinephany/observations/observer_pipeline.py,sha256=RvMH-TTDTu1Nk4S_KSHDkII1YuIRMSOXkPhn6g4B9ow,12815
7
7
  libinephany/observations/pipeline_coordinator.py,sha256=mw3c5jy_BWvNigUKNjIWMpReOjxFDblzOcWtsIkcls4,7907
8
8
  libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
@@ -32,16 +32,16 @@ libinephany/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU
32
32
  libinephany/utils/agent_utils.py,sha256=_2w1AY5Y4mQ5hes_Rq014VhZXOtIOn-W92mZgeixv3g,2658
33
33
  libinephany/utils/asyncio_worker.py,sha256=Ew23zKIbG1zwyCudcyiObMrw4G0f3p2QXzZfM4mePqI,2751
34
34
  libinephany/utils/backend_statuses.py,sha256=ZbpBPbz0qKmeqxyGGN_ePTrQ7Wrxh7KM6W26UDbPXtQ,644
35
- libinephany/utils/constants.py,sha256=piawYQa51vCxxAHCH3YoWOgUhTlgqgQxKMCenkoQTsc,2170
35
+ libinephany/utils/constants.py,sha256=Qh8iz5o1R4UDVVCB69jOQPX2SLWRCncpb_2yTHpFSbY,2259
36
36
  libinephany/utils/directory_utils.py,sha256=408unVeE_5_Hm-ZYZuxc9sdvfuU0CgYELX7EzPlPieo,1217
37
37
  libinephany/utils/dropout_utils.py,sha256=X43yCW7Dh1cC5sNnivgS5j1fn871K_RCvxCBTT0YHKg,3392
38
- libinephany/utils/enums.py,sha256=kEECkJO2quKAyVAqzgOzOP-d4qIENE3z_RyymSvyIB8,2420
38
+ libinephany/utils/enums.py,sha256=6_6k_1I2BwYTIfquUOsoaQT5fkhMXUWtwCxLoTYuFyU,2906
39
39
  libinephany/utils/error_severities.py,sha256=B9oidqOVaYOe0W6P6GwjpmuDsrkyTX30v1xdiUStCFk,1427
40
40
  libinephany/utils/exceptions.py,sha256=kgwLpHOgy3kciUz_I18xnYsWRtzdonfadUtwG2uDYk8,1823
41
41
  libinephany/utils/import_utils.py,sha256=WzC6V6UIa0nCiU2MekROwG82fWBh9RuVzichtby5EvM,1495
42
42
  libinephany/utils/optim_utils.py,sha256=-PLqsyuq4ZH3spBy_olNB3yuLwvhnLrCF0384elCmXc,8777
43
43
  libinephany/utils/random_seeds.py,sha256=eF-ErrMShu8mp9V_gXrB_iUxR-Lb-OtHypEEUQAGn2Y,1565
44
- libinephany/utils/samplers.py,sha256=uyVGAy5cm5bCyWMOuySJmzUc_vFuieO_3zydJciwdv4,12158
44
+ libinephany/utils/samplers.py,sha256=7h_el2dLJi2J97f_zpvc4BrEzoM_EJgZk1-ZjRkOhZ8,13357
45
45
  libinephany/utils/standardizers.py,sha256=pG1K_XL4OR_NjVtT6Hjbln1dk1BtQdDuSK1PQTkA17Y,8014
46
46
  libinephany/utils/torch_distributed_utils.py,sha256=UPMfhdZZwyHX_r3h55AAK4PcB-zFtjK37Z5aawAKNmE,2968
47
47
  libinephany/utils/torch_utils.py,sha256=o5TsqrXe6Id04P6SqB_avGBRZutbu6IBB61llAHQ_PY,2696
@@ -50,8 +50,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
50
50
  libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
52
52
  libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
53
- libinephany-0.16.1.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
54
- libinephany-0.16.1.dist-info/METADATA,sha256=qqXRHyzLSH1dm1SlIbn2dthXuQ-WH00OsbTvM8RmwcE,8354
55
- libinephany-0.16.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
56
- libinephany-0.16.1.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
57
- libinephany-0.16.1.dist-info/RECORD,,
53
+ libinephany-0.16.3.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
54
+ libinephany-0.16.3.dist-info/METADATA,sha256=qMiO9s8TRo6kshtkrv79aGT1BYQFjO55-9th2Wm7rdk,8390
55
+ libinephany-0.16.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
56
+ libinephany-0.16.3.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
57
+ libinephany-0.16.3.dist-info/RECORD,,