libinephany 0.17.0__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -218,7 +218,9 @@ class ObserverPipeline:
218
218
  tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
219
219
  actions_taken: dict[str, float | int | None],
220
220
  return_dict: bool = False,
221
- ) -> tuple[dict[str, list[float | int]], bool, dict[str, dict[str, list[float | int]]] | None]:
221
+ ) -> tuple[
222
+ dict[str, list[float | int]], bool, dict[str, dict[str, list[float | int] | dict[str, float | int]]] | None
223
+ ]:
222
224
  """
223
225
  :param observation_inputs: Observation input metrics not calculated with statistic trackers.
224
226
  :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
@@ -241,7 +243,7 @@ class ObserverPipeline:
241
243
  )
242
244
 
243
245
  local_obs: dict[str, list[float | int]] = {}
244
- obs_as_dict: dict[str, dict[str, list[float | int]]] = {}
246
+ obs_as_dict: dict[str, dict[str, list[float | int] | dict[str, float | int]]] = {}
245
247
 
246
248
  for agent_id, agent_observers in self.local_observers.items():
247
249
  local_obs[agent_id], local_obs_dict = agent_observers.observe(
@@ -204,16 +204,22 @@ class Observer(ABC):
204
204
  hyperparameter_states: HyperparameterStates,
205
205
  tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
206
206
  action_taken: float | int | None,
207
- ) -> list[float | int]:
207
+ return_dict: bool = False,
208
+ ) -> tuple[list[float | int], dict[str, float] | None]:
208
209
  """
209
210
  :param observation_inputs: Observation input metrics not calculated with statistic trackers.
210
211
  :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
211
212
  :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
212
213
  names to floats or TensorStatistic models.
213
214
  :param action_taken: Action taken by the agent this class instance is assigned to.
214
- :return: List of floats or integers to add to the agent's observation vector.
215
+ :param return_dict: Whether to return a dictionary of observations as well as the normal vector.
216
+ :return: Tuple of:
217
+ - List of floats or integers to add to the agent's observation vector.
218
+ - Dictionary of specific observation values if the storage type is TensorStatistics and None otherwise.
215
219
  """
216
220
 
221
+ observations_dict: dict[str, float] | None = None
222
+
217
223
  observations = self._observe(
218
224
  observation_inputs=observation_inputs,
219
225
  hyperparameter_states=hyperparameter_states,
@@ -225,6 +231,9 @@ class Observer(ABC):
225
231
  self._cached_observation = deepcopy(observations)
226
232
 
227
233
  if self.observation_format is StatisticStorageTypes.TENSOR_STATISTICS:
234
+ if return_dict:
235
+ observations_dict = observations.as_observation_dict() # type: ignore
236
+
228
237
  observations = observations.to_list(skip_statistics=self.skip_statistics) # type: ignore
229
238
 
230
239
  observations = [observations] if not isinstance(observations, list) else observations # type: ignore
@@ -241,7 +250,7 @@ class Observer(ABC):
241
250
  if not self._validated_observation:
242
251
  self._validate_observation(observations=observations)
243
252
 
244
- return observations
253
+ return observations, observations_dict
245
254
 
246
255
  @final
247
256
  def inform(self) -> float | int | dict[str, float] | None:
@@ -135,7 +135,7 @@ class ObserverContainer(ABC):
135
135
  tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
136
136
  action_taken: float | int | None,
137
137
  return_dict: bool = False,
138
- ) -> tuple[list[float | int], dict[str, list[float | int]] | None]:
138
+ ) -> tuple[list[float | int], dict[str, list[float | int] | dict[str, float | int]] | None]:
139
139
  """
140
140
  :param observation_inputs: Observation input metrics not calculated with statistic trackers.
141
141
  :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
@@ -148,17 +148,21 @@ class ObserverContainer(ABC):
148
148
  """
149
149
 
150
150
  observations = []
151
- observations_dict = {}
151
+ observations_dict: dict[str, list[float | int] | dict[str, float | int]] = {}
152
152
 
153
153
  for observer in self._observers:
154
- observer_obs = observer.observe(
154
+ observer_obs, observer_obs_dict = observer.observe(
155
155
  observation_inputs=observation_inputs,
156
156
  hyperparameter_states=hyperparameter_states,
157
157
  tracked_statistics=tracked_statistics,
158
158
  action_taken=action_taken,
159
+ return_dict=return_dict,
159
160
  )
160
161
 
161
- if return_dict:
162
+ if return_dict and observer_obs_dict is not None:
163
+ observations_dict[observer.__class__.__name__] = observer_obs_dict
164
+
165
+ elif return_dict:
162
166
  observations_dict[observer.__class__.__name__] = observer_obs
163
167
 
164
168
  observations += observer_obs
@@ -134,7 +134,7 @@ class ObserverPipelineCoordinator:
134
134
 
135
135
  clipped_observations = False
136
136
  observations = {}
137
- observations_as_dict: dict[str, dict[str, list[float | int]]] = {}
137
+ observations_as_dict: dict[str, dict[str, list[float | int] | dict[str, float | int]]] = {}
138
138
 
139
139
  for pipeline in self.pipelines:
140
140
  pipeline_observations, pipeline_clipped_observations, pipeline_observations_dict = pipeline.observe(
@@ -51,7 +51,7 @@ class Observations(BaseModel):
51
51
  hit_invalid_value: bool
52
52
 
53
53
  agent_observations: dict[str, list[float | int]]
54
- observations_as_dict: dict[str, dict[str, list[float | int]]] | None = None
54
+ observations_as_dict: dict[str, dict[str, list[float | int] | dict[str, float]]] | None = None
55
55
 
56
56
  def observations_as_arrays(self, dtype: DTypeLike = np.float32) -> dict[str, np.ndarray]:
57
57
  """
@@ -7,6 +7,14 @@
7
7
  import torch
8
8
  from pydantic import BaseModel
9
9
 
10
+ # ======================================================================================================================
11
+ #
12
+ # CONSTANTS
13
+ #
14
+ # ======================================================================================================================
15
+
16
+ STRIP_SUFFIX = "_"
17
+
10
18
  # ======================================================================================================================
11
19
  #
12
20
  # CLASSES
@@ -252,3 +260,13 @@ class TensorStatistics(BaseModel):
252
260
  self.inter_quartile_range,
253
261
  ]
254
262
  )
263
+
264
+ def as_observation_dict(self) -> dict[str, float]:
265
+ """
266
+ :return: Dictionary of observation values.
267
+ """
268
+
269
+ return {
270
+ field[:-1] if field.endswith(STRIP_SUFFIX) else field: field_value
271
+ for field, field_value in self.model_dump().items()
272
+ }
@@ -43,7 +43,7 @@ AGENT_PREFIX_EPS = "adam-eps"
43
43
  AGENT_PREFIX_SGD_MOMENTUM = "sgd-momentum"
44
44
 
45
45
  AGENT_BATCH_SIZE = "batch-size"
46
- AGENT_GRADIENT_ACCUMULATION = "gradient-accumulation"
46
+ AGENT_PREFIX_GRADIENT_ACCUMULATION = "gradient-accumulation"
47
47
 
48
48
  AGENT_BANDIT_SUFFIX = "bandit-agent"
49
49
 
@@ -68,7 +68,7 @@ PREFIXES = [
68
68
  AGENT_PREFIX_BETA_TWO,
69
69
  AGENT_PREFIX_EPS,
70
70
  AGENT_PREFIX_SGD_MOMENTUM,
71
- AGENT_GRADIENT_ACCUMULATION,
71
+ AGENT_PREFIX_GRADIENT_ACCUMULATION,
72
72
  ]
73
73
  PREFIXES_TO_HPARAMS = {
74
74
  AGENT_PREFIX_LR: LEARNING_RATE,
@@ -79,6 +79,6 @@ PREFIXES_TO_HPARAMS = {
79
79
  AGENT_PREFIX_BETA_TWO: ADAM_BETA_TWO,
80
80
  AGENT_PREFIX_EPS: ADAM_EPS,
81
81
  AGENT_PREFIX_SGD_MOMENTUM: SGD_MOMENTUM,
82
- AGENT_GRADIENT_ACCUMULATION: GRADIENT_ACCUMULATION,
82
+ AGENT_PREFIX_GRADIENT_ACCUMULATION: GRADIENT_ACCUMULATION,
83
83
  }
84
84
  HPARAMS_TO_PREFIXES = {hparam: prefix for prefix, hparam in PREFIXES_TO_HPARAMS.items()}
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: libinephany
3
- Version: 0.17.0
3
+ Version: 0.18.1
4
4
  Summary: Inephany library containing code commonly used by multiple subpackages.
5
5
  Author-email: Inephany <info@inephany.com>
6
6
  License: Apache 2.0
@@ -3,14 +3,14 @@ libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
4
4
  libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  libinephany/observations/observation_utils.py,sha256=ejQ-hzKq_MX7A0304KypScmzloReNrH8dmcgonoRl4Q,16266
6
- libinephany/observations/observer_pipeline.py,sha256=ySIXueu_SRONewygsye4EYxWoJvjhLx7M-MKGZpDOgo,12915
7
- libinephany/observations/pipeline_coordinator.py,sha256=mw3c5jy_BWvNigUKNjIWMpReOjxFDblzOcWtsIkcls4,7907
6
+ libinephany/observations/observer_pipeline.py,sha256=_xA4vrijhG8-9MCtGXnKAEmpd6q0nKVpJgY_qSbypIA,12979
7
+ libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPyauCM-795qOzyqwGOSdw,7932
8
8
  libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
9
9
  libinephany/observations/statistic_trackers.py,sha256=06NjCy1MI865oU0KB5f-wQE3b2RvYawOOWxNJx4rFpw,32939
10
10
  libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- libinephany/observations/observers/base_observers.py,sha256=RkG5SW0b6Ooy0_oscRHxyB_YFNP7k8fxu37jBZElxIM,15418
11
+ libinephany/observations/observers/base_observers.py,sha256=Tkk2AvQ6mRY6fdhhxHQWohA0xSHGdPqxul1C7C7Frj4,15924
12
12
  libinephany/observations/observers/local_observers.py,sha256=7azTW227-rG_QJR5a0StfE0O4Ca1boMV9nrCphnhWik,40344
13
- libinephany/observations/observers/observer_containers.py,sha256=Inz4vkqaXsyj09EJYaWWx1eOIuzmOjbiul5wRL5Gm5Y,9274
13
+ libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
14
14
  libinephany/observations/observers/global_observers/__init__.py,sha256=lQO-nq7hILu4F3ddFXcCR-ghfv0dzEA9nAYxZta7rxk,2306
15
15
  libinephany/observations/observers/global_observers/base_classes.py,sha256=CCkRx86Lll3gFzfqervP0jKdzNFKkKU7tEBh8ic1Yrc,8249
16
16
  libinephany/observations/observers/global_observers/constants.py,sha256=olXYxh353Th6hyhfk85kHxAWPeaDpuPkcz_awFvEz6c,1054
@@ -29,17 +29,17 @@ libinephany/pydantic_models/configs/outer_model_config.py,sha256=GQ0QBSC2Xht8x8X
29
29
  libinephany/pydantic_models/schemas/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  libinephany/pydantic_models/schemas/agent_info.py,sha256=me5gDxvZjP9TNK588mpUvxiiJrPDqy3Z7ZHRzryAYTs,2628
31
31
  libinephany/pydantic_models/schemas/inner_task_profile.py,sha256=Xu0tQmhGwV043tTamFiHekuE1RRXhhrUrGbtymjXo7g,11722
32
- libinephany/pydantic_models/schemas/observation_models.py,sha256=lOLlCAFemDJh9Ml3hublncC3MwkKVQcz5PWTDMIYOEU,2091
32
+ libinephany/pydantic_models/schemas/observation_models.py,sha256=MLhxqDet9Yol1D5mkQGQsQT23sm37AStRLnPc4sgcZc,2110
33
33
  libinephany/pydantic_models/schemas/request_schemas.py,sha256=VED8eAUvBofxeAx9gWU8DyCZOTVD3QsHRq-TO7kyOqk,1260
34
34
  libinephany/pydantic_models/schemas/response_schemas.py,sha256=SKFuasdjX5aH_I0vT3SwnpwhyMf9cNPB1ZpDeAGgoO8,2158
35
- libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=Z-x-Fi_Dm0pLoHI88DnJO1krY671o0zbGRzx-gXPtVY,7534
35
+ libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=ZrBDiVy8Pt-Bw5RU3ePfjyFjP-jjC4_KsMmZy-Tjb1s,8115
36
36
  libinephany/pydantic_models/states/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
37
37
  libinephany/pydantic_models/states/hyperparameter_states.py,sha256=qQqnYyKR5ucENb_Zyk0x71_zpgUtG_rOzbOcE0kTrkQ,33478
38
38
  libinephany/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
39
39
  libinephany/utils/agent_utils.py,sha256=_2w1AY5Y4mQ5hes_Rq014VhZXOtIOn-W92mZgeixv3g,2658
40
40
  libinephany/utils/asyncio_worker.py,sha256=Ew23zKIbG1zwyCudcyiObMrw4G0f3p2QXzZfM4mePqI,2751
41
41
  libinephany/utils/backend_statuses.py,sha256=ZbpBPbz0qKmeqxyGGN_ePTrQ7Wrxh7KM6W26UDbPXtQ,644
42
- libinephany/utils/constants.py,sha256=q4R_DjFEvMhV4TKN0pswEr9YIpFTGzZeNGB-fTNttdM,2309
42
+ libinephany/utils/constants.py,sha256=XAOuPowvM4FDSbfvNsubKTAqSB84AANX4CoHb7LwgEI,2330
43
43
  libinephany/utils/directory_utils.py,sha256=408unVeE_5_Hm-ZYZuxc9sdvfuU0CgYELX7EzPlPieo,1217
44
44
  libinephany/utils/dropout_utils.py,sha256=X43yCW7Dh1cC5sNnivgS5j1fn871K_RCvxCBTT0YHKg,3392
45
45
  libinephany/utils/enums.py,sha256=6_6k_1I2BwYTIfquUOsoaQT5fkhMXUWtwCxLoTYuFyU,2906
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
57
57
  libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
58
58
  libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
59
59
  libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
60
- libinephany-0.17.0.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
61
- libinephany-0.17.0.dist-info/METADATA,sha256=OJryt-xjeAYOYaDzTyv4UXSeTRGxBFipvjlDPNOLZvI,8390
62
- libinephany-0.17.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
- libinephany-0.17.0.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
64
- libinephany-0.17.0.dist-info/RECORD,,
60
+ libinephany-0.18.1.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
61
+ libinephany-0.18.1.dist-info/METADATA,sha256=d2rH-yA7F1cVk2OnTyvdaQ6_CFey1uiY4ATbqsxH9Pg,8390
62
+ libinephany-0.18.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
63
+ libinephany-0.18.1.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
64
+ libinephany-0.18.1.dist-info/RECORD,,