libinephany 0.14.1__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,7 +4,7 @@
4
4
  #
5
5
  # ======================================================================================================================
6
6
 
7
- from typing import Any
7
+ from typing import Any, cast
8
8
 
9
9
  from pydantic import BaseModel, ConfigDict, ValidationError, field_serializer, field_validator, model_validator
10
10
 
@@ -232,6 +232,23 @@ class BatchSizeHParamConfig(HParamConfig):
232
232
  sample_discrete_values: list[float | int] | None = None
233
233
 
234
234
 
235
+ class GradientAccumulationHParamConfig(HParamConfig):
236
+ max_hparam_value: float | int = 64
237
+ min_hparam_value: float | int = 1
238
+ hparam_dtype: type[float | int] = int
239
+ initial_value: int = 1
240
+ initial_delta: float = 0.0
241
+ scale: float = 1.0
242
+
243
+ sampler: str = "DiscreteRangeSampler"
244
+ sample_initial_values: bool = False
245
+ sample_lower_bound: int = 1
246
+ sample_upper_bound: int = 64
247
+ sample_step: int = 1
248
+ sample_discrete_values: list[float | int] | None = None
249
+ force_limit: float | int = 64
250
+
251
+
235
252
  class EpochsHParamConfig(HParamConfig):
236
253
  max_hparam_value: float | int = 16
237
254
  min_hparam_value: float | int = 1
@@ -289,6 +306,7 @@ class HParamConfigs(BaseModel):
289
306
  sgd_momentum_config: HParamConfig = SGDMomentumHParamConfig()
290
307
 
291
308
  batch_size_config: HParamConfig = BatchSizeHParamConfig()
309
+ gradient_accumulation_config: GradientAccumulationHParamConfig = GradientAccumulationHParamConfig()
292
310
  epochs_config: HParamConfig = EpochsHParamConfig()
293
311
  token_config: HParamConfig = TokensHParamConfig()
294
312
  samples_config: HParamConfig = SamplesHParamConfig()
@@ -351,6 +369,9 @@ class HParamConfigs(BaseModel):
351
369
  case AgentTypes.BatchSize:
352
370
  self.batch_size_config = hparam_config
353
371
 
372
+ case AgentTypes.GradientAccumulationAgent:
373
+ self.gradient_accumulation_config = cast(GradientAccumulationHParamConfig, hparam_config)
374
+
354
375
  case AgentTypes.Epochs:
355
376
  self.epochs_config = hparam_config
356
377
 
@@ -400,6 +421,9 @@ class HParamConfigs(BaseModel):
400
421
  case AgentTypes.BatchSize:
401
422
  return self.batch_size_config
402
423
 
424
+ case AgentTypes.GradientAccumulationAgent:
425
+ return self.gradient_accumulation_config
426
+
403
427
  case AgentTypes.Epochs:
404
428
  return self.epochs_config
405
429
 
@@ -140,6 +140,40 @@ class InnerTaskProfiles(BaseModel):
140
140
 
141
141
  return sum(self.compiled_action_sizes.values())
142
142
 
143
+ @property
144
+ def max_total_observation_size(self) -> int:
145
+ """
146
+ :return: The summed observation size of all agents with the task that has the most layers.
147
+ """
148
+
149
+ if not self.profiles:
150
+ raise ValueError(
151
+ "No profiles to calculate max total observation size. Ensure profiles have been "
152
+ "added before executing the training loop"
153
+ )
154
+
155
+ largest_task_name = max(self.profiles, key=lambda k: self.profiles[k].number_of_layers)
156
+ largest_task = self.profiles[largest_task_name]
157
+
158
+ return sum(largest_task.observation_space_sizes.values())
159
+
160
+ @property
161
+ def max_total_action_size(self) -> int:
162
+ """
163
+ :return: The summed action size of all agents with the task that has the most layers.
164
+ """
165
+
166
+ if not self.profiles:
167
+ raise ValueError(
168
+ "No profiles to calculate max total action size. Ensure profiles have been "
169
+ "added before executing the training loop"
170
+ )
171
+
172
+ largest_task_name = max(self.profiles, key=lambda k: self.profiles[k].number_of_layers)
173
+ largest_task = self.profiles[largest_task_name]
174
+
175
+ return sum(largest_task.action_space_sizes.values())
176
+
143
177
  @staticmethod
144
178
  def _compile_gym_space_sizes(spaces: dict[str, dict[str, int]]) -> dict[str, int]:
145
179
  """
@@ -20,6 +20,7 @@ from libinephany.utils.constants import (
20
20
  DROPOUT,
21
21
  EPOCHS,
22
22
  GRAD_NORM_CLIP,
23
+ GRADIENT_ACCUMULATION,
23
24
  LEARNING_RATE,
24
25
  SAMPLES,
25
26
  SGD_MOMENTUM,
@@ -60,6 +61,7 @@ class UpdateCallbacks(BaseModel):
60
61
  sgd_momentum: Callable[..., None]
61
62
 
62
63
  batch_size: Callable[..., None] | None
64
+ gradient_accumulation: Callable[..., None] | None
63
65
  epochs: Callable[..., None] | None
64
66
 
65
67
  def __getitem__(self, item: str) -> Callable[..., None] | None:
@@ -457,6 +459,7 @@ class ParameterGroupHParams(HyperparameterContainer):
457
459
  class GlobalHParams(HyperparameterContainer):
458
460
 
459
461
  batch_size: Hyperparameter
462
+ gradient_accumulation: Hyperparameter
460
463
  epochs: Hyperparameter
461
464
  tokens: Hyperparameter
462
465
  samples: Hyperparameter
@@ -550,6 +553,14 @@ class HyperparameterStates(BaseModel):
550
553
  """
551
554
  return self.global_hparams.batch_size
552
555
 
556
+ @computed_field # type: ignore[misc]
557
+ @property
558
+ def gradient_accumulation(self) -> Hyperparameter:
559
+ """
560
+ :return: The gradient accumulation steps of the inner model.
561
+ """
562
+ return self.global_hparams.gradient_accumulation
563
+
553
564
  @computed_field # type: ignore[misc]
554
565
  @property
555
566
  def epochs(self) -> Hyperparameter:
@@ -676,6 +687,7 @@ class HyperparameterStates(BaseModel):
676
687
 
677
688
  return {
678
689
  BATCH_SIZE: hparam_configs.batch_size_config,
690
+ GRADIENT_ACCUMULATION: hparam_configs.gradient_accumulation_config,
679
691
  EPOCHS: hparam_configs.epochs_config,
680
692
  TOKENS: hparam_configs.token_config,
681
693
  SAMPLES: hparam_configs.samples_config,
@@ -21,6 +21,7 @@ ADAM_BETA_TWO = "adam_beta_two"
21
21
  ADAM_EPS = "adam_eps"
22
22
  SGD_MOMENTUM = "sgd_momentum"
23
23
  BATCH_SIZE = "batch_size"
24
+ GRADIENT_ACCUMULATION = "gradient_accumulation"
24
25
  EPOCHS = "epochs"
25
26
  TOKENS = "tokens"
26
27
  SAMPLES = "samples"
@@ -41,6 +42,7 @@ AGENT_PREFIX_EPS = "adam-eps"
41
42
  AGENT_PREFIX_SGD_MOMENTUM = "sgd-momentum"
42
43
 
43
44
  AGENT_BATCH_SIZE = "batch-size"
45
+ AGENT_GRADIENT_ACCUMULATION = "gradient-accumulation"
44
46
 
45
47
  AGENT_BANDIT_SUFFIX = "bandit-agent"
46
48
 
@@ -53,6 +55,7 @@ AGENT_TYPES = [
53
55
  ADAM_BETA_TWO,
54
56
  ADAM_EPS,
55
57
  SGD_MOMENTUM,
58
+ GRADIENT_ACCUMULATION,
56
59
  ]
57
60
  SUFFIXES = [AGENT_BANDIT_SUFFIX]
58
61
  PREFIXES = [
@@ -64,6 +67,7 @@ PREFIXES = [
64
67
  AGENT_PREFIX_BETA_TWO,
65
68
  AGENT_PREFIX_EPS,
66
69
  AGENT_PREFIX_SGD_MOMENTUM,
70
+ AGENT_GRADIENT_ACCUMULATION,
67
71
  ]
68
72
  PREFIXES_TO_HPARAMS = {
69
73
  AGENT_PREFIX_LR: LEARNING_RATE,
@@ -74,4 +78,5 @@ PREFIXES_TO_HPARAMS = {
74
78
  AGENT_PREFIX_BETA_TWO: ADAM_BETA_TWO,
75
79
  AGENT_PREFIX_EPS: ADAM_EPS,
76
80
  AGENT_PREFIX_SGD_MOMENTUM: SGD_MOMENTUM,
81
+ AGENT_GRADIENT_ACCUMULATION: GRADIENT_ACCUMULATION,
77
82
  }
@@ -14,6 +14,7 @@ from libinephany.utils.constants import (
14
14
  DROPOUT,
15
15
  EPOCHS,
16
16
  GRAD_NORM_CLIP,
17
+ GRADIENT_ACCUMULATION,
17
18
  LEARNING_RATE,
18
19
  SAMPLES,
19
20
  SGD_MOMENTUM,
@@ -69,6 +70,7 @@ class AgentTypes(EnumWithIndices):
69
70
  AdamBetaTwoAgent = ADAM_BETA_TWO
70
71
  AdamEpsAgent = ADAM_EPS
71
72
  SGDMomentumAgent = SGD_MOMENTUM
73
+ GradientAccumulationAgent = GRADIENT_ACCUMULATION
72
74
 
73
75
  # Deprecated or Non-Agent
74
76
  BatchSize = BATCH_SIZE
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 0.14.1
3
+ Version: 0.15.0
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -16,26 +16,26 @@ libinephany/observations/post_processors/__init__.py,sha256=47DEQpj8HBSa-_TImW-5
16
16
  libinephany/observations/post_processors/postprocessors.py,sha256=43_e5UaDPr2KbAvqc_w3wLqnlm7bgRjqgCtyQ95-8cM,5913
17
17
  libinephany/pydantic_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  libinephany/pydantic_models/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- libinephany/pydantic_models/configs/hyperparameter_configs.py,sha256=5BH7aPiFtSfhTKp-Z0WTbCYuGJvUTskgFgMttxpDZb0,13654
19
+ libinephany/pydantic_models/configs/hyperparameter_configs.py,sha256=FYl8A2_9L-ohg36aZEW5kREO3tcqIyztYpW62s99tqY,14562
20
20
  libinephany/pydantic_models/configs/observer_config.py,sha256=v_ChzaVXC_rlZ7eDZPuCae1DdG7-PS3mPwC-OaWpGQo,1355
21
21
  libinephany/pydantic_models/configs/outer_model_config.py,sha256=GQ0QBSC2Xht8x8X_TEMfYM2GF_x1kErLuFrA_H6Jhs0,1209
22
22
  libinephany/pydantic_models/schemas/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
23
  libinephany/pydantic_models/schemas/agent_info.py,sha256=me5gDxvZjP9TNK588mpUvxiiJrPDqy3Z7ZHRzryAYTs,2628
24
- libinephany/pydantic_models/schemas/inner_task_profile.py,sha256=wgclLbDqDjuk-_v2f66Eb6btf3RDzz6leCZvSthQde4,10464
24
+ libinephany/pydantic_models/schemas/inner_task_profile.py,sha256=Xu0tQmhGwV043tTamFiHekuE1RRXhhrUrGbtymjXo7g,11722
25
25
  libinephany/pydantic_models/schemas/observation_models.py,sha256=YjQmrWZ0r-_LRp92jvhSD8p1grKsMVXCXoou4q15Ue8,1849
26
26
  libinephany/pydantic_models/schemas/request_schemas.py,sha256=VED8eAUvBofxeAx9gWU8DyCZOTVD3QsHRq-TO7kyOqk,1260
27
27
  libinephany/pydantic_models/schemas/response_schemas.py,sha256=SKFuasdjX5aH_I0vT3SwnpwhyMf9cNPB1ZpDeAGgoO8,2158
28
28
  libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=Z-x-Fi_Dm0pLoHI88DnJO1krY671o0zbGRzx-gXPtVY,7534
29
29
  libinephany/pydantic_models/states/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
- libinephany/pydantic_models/states/hyperparameter_states.py,sha256=JWNSmQmncnQu2GTh0MeT8pvPLxttyrhsKXErucdURqQ,32223
30
+ libinephany/pydantic_models/states/hyperparameter_states.py,sha256=fwqUmRbT5WxcnMPK8DmRXkBQOtCs9n6V24BeCyFTFL8,32688
31
31
  libinephany/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
32
  libinephany/utils/agent_utils.py,sha256=_2w1AY5Y4mQ5hes_Rq014VhZXOtIOn-W92mZgeixv3g,2658
33
33
  libinephany/utils/asyncio_worker.py,sha256=Ew23zKIbG1zwyCudcyiObMrw4G0f3p2QXzZfM4mePqI,2751
34
34
  libinephany/utils/backend_statuses.py,sha256=ZbpBPbz0qKmeqxyGGN_ePTrQ7Wrxh7KM6W26UDbPXtQ,644
35
- libinephany/utils/constants.py,sha256=RH-0fZe6SL1WrrsrW4KP4k7ClQZDq8IFEdok3hEnRt4,1952
35
+ libinephany/utils/constants.py,sha256=piawYQa51vCxxAHCH3YoWOgUhTlgqgQxKMCenkoQTsc,2170
36
36
  libinephany/utils/directory_utils.py,sha256=408unVeE_5_Hm-ZYZuxc9sdvfuU0CgYELX7EzPlPieo,1217
37
37
  libinephany/utils/dropout_utils.py,sha256=X43yCW7Dh1cC5sNnivgS5j1fn871K_RCvxCBTT0YHKg,3392
38
- libinephany/utils/enums.py,sha256=YH10mUhW4kjYS0cp4XUASok9vfPl0jv9ZhS3HpZD0Zg,2339
38
+ libinephany/utils/enums.py,sha256=kEECkJO2quKAyVAqzgOzOP-d4qIENE3z_RyymSvyIB8,2420
39
39
  libinephany/utils/error_severities.py,sha256=B9oidqOVaYOe0W6P6GwjpmuDsrkyTX30v1xdiUStCFk,1427
40
40
  libinephany/utils/exceptions.py,sha256=kgwLpHOgy3kciUz_I18xnYsWRtzdonfadUtwG2uDYk8,1823
41
41
  libinephany/utils/import_utils.py,sha256=WzC6V6UIa0nCiU2MekROwG82fWBh9RuVzichtby5EvM,1495
@@ -50,8 +50,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
50
50
  libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
51
  libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
52
52
  libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
53
- libinephany-0.14.1.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
54
- libinephany-0.14.1.dist-info/METADATA,sha256=k82YNw3190axmX8DPGPpYHSChw5nvJojLns-JNAhXw8,8354
55
- libinephany-0.14.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
56
- libinephany-0.14.1.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
57
- libinephany-0.14.1.dist-info/RECORD,,
53
+ libinephany-0.15.0.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
54
+ libinephany-0.15.0.dist-info/METADATA,sha256=lU7SqV1ArMEAyuZ845Z1jAYxNUEYGfJ8Tl6Df6EwSpc,8354
55
+ libinephany-0.15.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
56
+ libinephany-0.15.0.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
57
+ libinephany-0.15.0.dist-info/RECORD,,