libinephany 0.17.0__py3-none-any.whl → 0.18.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/observations/observer_pipeline.py +4 -2
- libinephany/observations/observers/base_observers.py +12 -3
- libinephany/observations/observers/observer_containers.py +8 -4
- libinephany/observations/pipeline_coordinator.py +1 -1
- libinephany/pydantic_models/schemas/observation_models.py +1 -1
- libinephany/pydantic_models/schemas/tensor_statistics.py +18 -0
- {libinephany-0.17.0.dist-info → libinephany-0.18.0.dist-info}/METADATA +1 -1
- {libinephany-0.17.0.dist-info → libinephany-0.18.0.dist-info}/RECORD +11 -11
- {libinephany-0.17.0.dist-info → libinephany-0.18.0.dist-info}/WHEEL +0 -0
- {libinephany-0.17.0.dist-info → libinephany-0.18.0.dist-info}/licenses/LICENSE +0 -0
- {libinephany-0.17.0.dist-info → libinephany-0.18.0.dist-info}/top_level.txt +0 -0
@@ -218,7 +218,9 @@ class ObserverPipeline:
|
|
218
218
|
tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
|
219
219
|
actions_taken: dict[str, float | int | None],
|
220
220
|
return_dict: bool = False,
|
221
|
-
) -> tuple[
|
221
|
+
) -> tuple[
|
222
|
+
dict[str, list[float | int]], bool, dict[str, dict[str, list[float | int] | dict[str, float | int]]] | None
|
223
|
+
]:
|
222
224
|
"""
|
223
225
|
:param observation_inputs: Observation input metrics not calculated with statistic trackers.
|
224
226
|
:param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
|
@@ -241,7 +243,7 @@ class ObserverPipeline:
|
|
241
243
|
)
|
242
244
|
|
243
245
|
local_obs: dict[str, list[float | int]] = {}
|
244
|
-
obs_as_dict: dict[str, dict[str, list[float | int]]] = {}
|
246
|
+
obs_as_dict: dict[str, dict[str, list[float | int] | dict[str, float | int]]] = {}
|
245
247
|
|
246
248
|
for agent_id, agent_observers in self.local_observers.items():
|
247
249
|
local_obs[agent_id], local_obs_dict = agent_observers.observe(
|
@@ -204,16 +204,22 @@ class Observer(ABC):
|
|
204
204
|
hyperparameter_states: HyperparameterStates,
|
205
205
|
tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
|
206
206
|
action_taken: float | int | None,
|
207
|
-
|
207
|
+
return_dict: bool = False,
|
208
|
+
) -> tuple[list[float | int], dict[str, float] | None]:
|
208
209
|
"""
|
209
210
|
:param observation_inputs: Observation input metrics not calculated with statistic trackers.
|
210
211
|
:param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
|
211
212
|
:param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
|
212
213
|
names to floats or TensorStatistic models.
|
213
214
|
:param action_taken: Action taken by the agent this class instance is assigned to.
|
214
|
-
:
|
215
|
+
:param return_dict: Whether to return a dictionary of observations as well as the normal vector.
|
216
|
+
:return: Tuple of:
|
217
|
+
- List of floats or integers to add to the agent's observation vector.
|
218
|
+
- Dictionary of specific observation values if the storage type is TensorStatistics and None otherwise.
|
215
219
|
"""
|
216
220
|
|
221
|
+
observations_dict: dict[str, float] | None = None
|
222
|
+
|
217
223
|
observations = self._observe(
|
218
224
|
observation_inputs=observation_inputs,
|
219
225
|
hyperparameter_states=hyperparameter_states,
|
@@ -225,6 +231,9 @@ class Observer(ABC):
|
|
225
231
|
self._cached_observation = deepcopy(observations)
|
226
232
|
|
227
233
|
if self.observation_format is StatisticStorageTypes.TENSOR_STATISTICS:
|
234
|
+
if return_dict:
|
235
|
+
observations_dict = observations.as_observation_dict() # type: ignore
|
236
|
+
|
228
237
|
observations = observations.to_list(skip_statistics=self.skip_statistics) # type: ignore
|
229
238
|
|
230
239
|
observations = [observations] if not isinstance(observations, list) else observations # type: ignore
|
@@ -241,7 +250,7 @@ class Observer(ABC):
|
|
241
250
|
if not self._validated_observation:
|
242
251
|
self._validate_observation(observations=observations)
|
243
252
|
|
244
|
-
return observations
|
253
|
+
return observations, observations_dict
|
245
254
|
|
246
255
|
@final
|
247
256
|
def inform(self) -> float | int | dict[str, float] | None:
|
@@ -135,7 +135,7 @@ class ObserverContainer(ABC):
|
|
135
135
|
tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
|
136
136
|
action_taken: float | int | None,
|
137
137
|
return_dict: bool = False,
|
138
|
-
) -> tuple[list[float | int], dict[str, list[float | int]] | None]:
|
138
|
+
) -> tuple[list[float | int], dict[str, list[float | int] | dict[str, float | int]] | None]:
|
139
139
|
"""
|
140
140
|
:param observation_inputs: Observation input metrics not calculated with statistic trackers.
|
141
141
|
:param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
|
@@ -148,17 +148,21 @@ class ObserverContainer(ABC):
|
|
148
148
|
"""
|
149
149
|
|
150
150
|
observations = []
|
151
|
-
observations_dict = {}
|
151
|
+
observations_dict: dict[str, list[float | int] | dict[str, float | int]] = {}
|
152
152
|
|
153
153
|
for observer in self._observers:
|
154
|
-
observer_obs = observer.observe(
|
154
|
+
observer_obs, observer_obs_dict = observer.observe(
|
155
155
|
observation_inputs=observation_inputs,
|
156
156
|
hyperparameter_states=hyperparameter_states,
|
157
157
|
tracked_statistics=tracked_statistics,
|
158
158
|
action_taken=action_taken,
|
159
|
+
return_dict=return_dict,
|
159
160
|
)
|
160
161
|
|
161
|
-
if return_dict:
|
162
|
+
if return_dict and observer_obs_dict is not None:
|
163
|
+
observations_dict[observer.__class__.__name__] = observer_obs_dict
|
164
|
+
|
165
|
+
elif return_dict:
|
162
166
|
observations_dict[observer.__class__.__name__] = observer_obs
|
163
167
|
|
164
168
|
observations += observer_obs
|
@@ -134,7 +134,7 @@ class ObserverPipelineCoordinator:
|
|
134
134
|
|
135
135
|
clipped_observations = False
|
136
136
|
observations = {}
|
137
|
-
observations_as_dict: dict[str, dict[str, list[float | int]]] = {}
|
137
|
+
observations_as_dict: dict[str, dict[str, list[float | int] | dict[str, float | int]]] = {}
|
138
138
|
|
139
139
|
for pipeline in self.pipelines:
|
140
140
|
pipeline_observations, pipeline_clipped_observations, pipeline_observations_dict = pipeline.observe(
|
@@ -51,7 +51,7 @@ class Observations(BaseModel):
|
|
51
51
|
hit_invalid_value: bool
|
52
52
|
|
53
53
|
agent_observations: dict[str, list[float | int]]
|
54
|
-
observations_as_dict: dict[str, dict[str, list[float | int]]] | None = None
|
54
|
+
observations_as_dict: dict[str, dict[str, list[float | int] | dict[str, float]]] | None = None
|
55
55
|
|
56
56
|
def observations_as_arrays(self, dtype: DTypeLike = np.float32) -> dict[str, np.ndarray]:
|
57
57
|
"""
|
@@ -7,6 +7,14 @@
|
|
7
7
|
import torch
|
8
8
|
from pydantic import BaseModel
|
9
9
|
|
10
|
+
# ======================================================================================================================
|
11
|
+
#
|
12
|
+
# CONSTANTS
|
13
|
+
#
|
14
|
+
# ======================================================================================================================
|
15
|
+
|
16
|
+
STRIP_SUFFIX = "_"
|
17
|
+
|
10
18
|
# ======================================================================================================================
|
11
19
|
#
|
12
20
|
# CLASSES
|
@@ -252,3 +260,13 @@ class TensorStatistics(BaseModel):
|
|
252
260
|
self.inter_quartile_range,
|
253
261
|
]
|
254
262
|
)
|
263
|
+
|
264
|
+
def as_observation_dict(self) -> dict[str, float]:
|
265
|
+
"""
|
266
|
+
:return: Dictionary of observation values.
|
267
|
+
"""
|
268
|
+
|
269
|
+
return {
|
270
|
+
field[:-1] if field.endswith(STRIP_SUFFIX) else field: field_value
|
271
|
+
for field, field_value in self.model_dump().items()
|
272
|
+
}
|
@@ -3,14 +3,14 @@ libinephany/aws/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
libinephany/aws/s3_functions.py,sha256=W8u85A6tDloo4FlJvydJbVHCUq_m9i8KDGdnKzy-Xpg,1745
|
4
4
|
libinephany/observations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
5
|
libinephany/observations/observation_utils.py,sha256=ejQ-hzKq_MX7A0304KypScmzloReNrH8dmcgonoRl4Q,16266
|
6
|
-
libinephany/observations/observer_pipeline.py,sha256=
|
7
|
-
libinephany/observations/pipeline_coordinator.py,sha256=
|
6
|
+
libinephany/observations/observer_pipeline.py,sha256=_xA4vrijhG8-9MCtGXnKAEmpd6q0nKVpJgY_qSbypIA,12979
|
7
|
+
libinephany/observations/pipeline_coordinator.py,sha256=mLfaHhkXVhMp9w5jWIAL3jPyauCM-795qOzyqwGOSdw,7932
|
8
8
|
libinephany/observations/statistic_manager.py,sha256=LLg1zSxnJr2oQQepYla3qoUuRy10rsthr9jta4wEbnc,8956
|
9
9
|
libinephany/observations/statistic_trackers.py,sha256=06NjCy1MI865oU0KB5f-wQE3b2RvYawOOWxNJx4rFpw,32939
|
10
10
|
libinephany/observations/observers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
libinephany/observations/observers/base_observers.py,sha256=
|
11
|
+
libinephany/observations/observers/base_observers.py,sha256=Tkk2AvQ6mRY6fdhhxHQWohA0xSHGdPqxul1C7C7Frj4,15924
|
12
12
|
libinephany/observations/observers/local_observers.py,sha256=7azTW227-rG_QJR5a0StfE0O4Ca1boMV9nrCphnhWik,40344
|
13
|
-
libinephany/observations/observers/observer_containers.py,sha256=
|
13
|
+
libinephany/observations/observers/observer_containers.py,sha256=VNyqGgxYJ4r49Msp_kk-POgicb-_5w54twuT1qfNdxw,9562
|
14
14
|
libinephany/observations/observers/global_observers/__init__.py,sha256=lQO-nq7hILu4F3ddFXcCR-ghfv0dzEA9nAYxZta7rxk,2306
|
15
15
|
libinephany/observations/observers/global_observers/base_classes.py,sha256=CCkRx86Lll3gFzfqervP0jKdzNFKkKU7tEBh8ic1Yrc,8249
|
16
16
|
libinephany/observations/observers/global_observers/constants.py,sha256=olXYxh353Th6hyhfk85kHxAWPeaDpuPkcz_awFvEz6c,1054
|
@@ -29,10 +29,10 @@ libinephany/pydantic_models/configs/outer_model_config.py,sha256=GQ0QBSC2Xht8x8X
|
|
29
29
|
libinephany/pydantic_models/schemas/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
30
30
|
libinephany/pydantic_models/schemas/agent_info.py,sha256=me5gDxvZjP9TNK588mpUvxiiJrPDqy3Z7ZHRzryAYTs,2628
|
31
31
|
libinephany/pydantic_models/schemas/inner_task_profile.py,sha256=Xu0tQmhGwV043tTamFiHekuE1RRXhhrUrGbtymjXo7g,11722
|
32
|
-
libinephany/pydantic_models/schemas/observation_models.py,sha256=
|
32
|
+
libinephany/pydantic_models/schemas/observation_models.py,sha256=MLhxqDet9Yol1D5mkQGQsQT23sm37AStRLnPc4sgcZc,2110
|
33
33
|
libinephany/pydantic_models/schemas/request_schemas.py,sha256=VED8eAUvBofxeAx9gWU8DyCZOTVD3QsHRq-TO7kyOqk,1260
|
34
34
|
libinephany/pydantic_models/schemas/response_schemas.py,sha256=SKFuasdjX5aH_I0vT3SwnpwhyMf9cNPB1ZpDeAGgoO8,2158
|
35
|
-
libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=
|
35
|
+
libinephany/pydantic_models/schemas/tensor_statistics.py,sha256=ZrBDiVy8Pt-Bw5RU3ePfjyFjP-jjC4_KsMmZy-Tjb1s,8115
|
36
36
|
libinephany/pydantic_models/states/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
37
37
|
libinephany/pydantic_models/states/hyperparameter_states.py,sha256=qQqnYyKR5ucENb_Zyk0x71_zpgUtG_rOzbOcE0kTrkQ,33478
|
38
38
|
libinephany/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -57,8 +57,8 @@ libinephany/utils/typing.py,sha256=rGbaPO3MaUndsWiC_wHzReD_TOLYqb43i01pKN-j7Xs,6
|
|
57
57
|
libinephany/web_apps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
58
58
|
libinephany/web_apps/error_logger.py,sha256=gAQIaqerqP4ornXZwFF1cghjnd2mMZEt3aVrTuUCr34,16653
|
59
59
|
libinephany/web_apps/web_app_utils.py,sha256=qiq_lasPipgN1RgRudPJc342kYci8O_4RqppxmIX8NY,4095
|
60
|
-
libinephany-0.
|
61
|
-
libinephany-0.
|
62
|
-
libinephany-0.
|
63
|
-
libinephany-0.
|
64
|
-
libinephany-0.
|
60
|
+
libinephany-0.18.0.dist-info/licenses/LICENSE,sha256=pogfDoMBP07ehIOvWymuWIar8pg2YLUhqOHsJQU3wdc,9250
|
61
|
+
libinephany-0.18.0.dist-info/METADATA,sha256=WTx51KP-dxUEpT5njeRj4LgnVuw94OO9zxj49FgoUuA,8390
|
62
|
+
libinephany-0.18.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
63
|
+
libinephany-0.18.0.dist-info/top_level.txt,sha256=bYAOXQdJgIoLkO2Ui0kxe7pSYegS_e38u0dMscd7COQ,12
|
64
|
+
libinephany-0.18.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|