libinephany-0.13.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. libinephany/__init__.py +0 -0
  2. libinephany/aws/__init__.py +0 -0
  3. libinephany/aws/s3_functions.py +57 -0
  4. libinephany/observations/__init__.py +0 -0
  5. libinephany/observations/observation_utils.py +243 -0
  6. libinephany/observations/observer_pipeline.py +307 -0
  7. libinephany/observations/observers/__init__.py +0 -0
  8. libinephany/observations/observers/base_observers.py +418 -0
  9. libinephany/observations/observers/global_observers.py +988 -0
  10. libinephany/observations/observers/local_observers.py +982 -0
  11. libinephany/observations/observers/observer_containers.py +270 -0
  12. libinephany/observations/pipeline_coordinator.py +198 -0
  13. libinephany/observations/post_processors/__init__.py +0 -0
  14. libinephany/observations/post_processors/postprocessors.py +153 -0
  15. libinephany/observations/statistic_manager.py +217 -0
  16. libinephany/observations/statistic_trackers.py +876 -0
  17. libinephany/pydantic_models/__init__.py +0 -0
  18. libinephany/pydantic_models/configs/__init__.py +0 -0
  19. libinephany/pydantic_models/configs/hyperparameter_configs.py +387 -0
  20. libinephany/pydantic_models/configs/observer_config.py +43 -0
  21. libinephany/pydantic_models/configs/outer_model_config.py +32 -0
  22. libinephany/pydantic_models/schemas/__init__.py +0 -0
  23. libinephany/pydantic_models/schemas/agent_info.py +65 -0
  24. libinephany/pydantic_models/schemas/inner_task_profile.py +284 -0
  25. libinephany/pydantic_models/schemas/observation_models.py +61 -0
  26. libinephany/pydantic_models/schemas/request_schemas.py +45 -0
  27. libinephany/pydantic_models/schemas/response_schemas.py +50 -0
  28. libinephany/pydantic_models/schemas/tensor_statistics.py +254 -0
  29. libinephany/pydantic_models/states/__init__.py +0 -0
  30. libinephany/pydantic_models/states/hyperparameter_states.py +808 -0
  31. libinephany/utils/__init__.py +0 -0
  32. libinephany/utils/agent_utils.py +82 -0
  33. libinephany/utils/asyncio_worker.py +87 -0
  34. libinephany/utils/backend_statuses.py +20 -0
  35. libinephany/utils/constants.py +76 -0
  36. libinephany/utils/directory_utils.py +41 -0
  37. libinephany/utils/dropout_utils.py +92 -0
  38. libinephany/utils/enums.py +90 -0
  39. libinephany/utils/error_severities.py +46 -0
  40. libinephany/utils/exceptions.py +60 -0
  41. libinephany/utils/import_utils.py +43 -0
  42. libinephany/utils/optim_utils.py +239 -0
  43. libinephany/utils/random_seeds.py +55 -0
  44. libinephany/utils/samplers.py +341 -0
  45. libinephany/utils/standardizers.py +217 -0
  46. libinephany/utils/torch_distributed_utils.py +85 -0
  47. libinephany/utils/torch_utils.py +77 -0
  48. libinephany/utils/transforms.py +104 -0
  49. libinephany/utils/typing.py +15 -0
  50. libinephany/web_apps/__init__.py +0 -0
  51. libinephany/web_apps/error_logger.py +421 -0
  52. libinephany/web_apps/web_app_utils.py +123 -0
  53. libinephany-0.13.1.dist-info/METADATA +282 -0
  54. libinephany-0.13.1.dist-info/RECORD +57 -0
  55. libinephany-0.13.1.dist-info/WHEEL +5 -0
  56. libinephany-0.13.1.dist-info/licenses/LICENSE +11 -0
  57. libinephany-0.13.1.dist-info/top_level.txt +1 -0
File without changes: libinephany/__init__.py
File without changes: libinephany/aws/__init__.py

libinephany/aws/s3_functions.py
@@ -0,0 +1,57 @@
+ # ======================================================================================================================
+ #
+ # IMPORTS
+ #
+ # ======================================================================================================================
+
+ import boto3
+ from loguru import logger
+
+ # ======================================================================================================================
+ #
+ # CONSTANTS
+ #
+ # ======================================================================================================================
+
+ S3_URI_SCHEME = "s3://"
+ S3_BUCKET_SEPARATOR = "/"
+ S3_BOTO_CLIENT = "s3"
+
+ # ======================================================================================================================
+ #
+ # FUNCTIONS
+ #
+ # ======================================================================================================================
+
+
+ def parse_s3_url(s3_url: str) -> tuple[str, str]:
+     """
+     :param s3_url: S3 URL to parse and extract the bucket and blob names from.
+     :return: Tuple of:
+         - Bucket name.
+         - Blob name.
+     """
+
+     if s3_url.startswith(S3_URI_SCHEME):
+         # Strip only the leading scheme; str.replace would also mangle any later occurrence in the key.
+         s3_url = s3_url[len(S3_URI_SCHEME):]
+
+     parts = s3_url.split(S3_BUCKET_SEPARATOR, 1)
+     bucket = parts[0]
+     key = parts[1] if len(parts) > 1 else ""
+
+     return bucket, key
+
+
+ def download_s3_file(s3_url: str, local_path: str) -> None:
+     """
+     :param s3_url: S3 URL to download data from.
+     :param local_path: Path to save the contents of the download to.
+     """
+
+     bucket, key = parse_s3_url(s3_url)
+
+     s3 = boto3.client(S3_BOTO_CLIENT)
+
+     logger.info(f"Downloading {s3_url} to {local_path}...")
+     s3.download_file(bucket, key, local_path)
+     logger.success("Done.")
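
For illustration, a minimal usage sketch of the two functions above. The bucket, key, and local path are hypothetical, and download_s3_file assumes boto3 can already find AWS credentials (for example via environment variables):

    from libinephany.aws.s3_functions import download_s3_file, parse_s3_url

    bucket, key = parse_s3_url("s3://example-bucket/models/checkpoint.pt")
    # bucket == "example-bucket", key == "models/checkpoint.pt"

    # Downloads the object to the given local path, logging progress via loguru.
    download_s3_file("s3://example-bucket/models/checkpoint.pt", local_path="/tmp/checkpoint.pt")
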
File without changes: libinephany/observations/__init__.py

libinephany/observations/observation_utils.py
@@ -0,0 +1,243 @@
+ # ======================================================================================================================
+ #
+ # IMPORTS
+ #
+ # ======================================================================================================================
+
+ from enum import Enum
+ from itertools import chain
+ from typing import Any, Callable
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.optim as optim
+
+ from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
+ from libinephany.utils import optim_utils
+
+ # ======================================================================================================================
+ #
+ # CONSTANTS
+ #
+ # ======================================================================================================================
+
+ EXP_AVERAGE = "exp_avg"
+
+ # ======================================================================================================================
+ #
+ # CLASSES
+ #
+ # ======================================================================================================================
+
+
+ class StatisticsCallStage(Enum):
+
+     ON_BATCH_END = "on_batch_end"
+     ON_OPTIMIZER_STEP = "on_optimizer_step"
+     ON_TRAIN_END = "on_train_end"
+
+     FORWARD_HOOK = "forward_hook"
+
+
+ class StatisticStorageTypes(Enum):
+
+     TENSOR_STATISTICS = TensorStatistics.__name__
+     FLOAT = float.__name__
+     VECTOR = "vector"
+
+
+ # ======================================================================================================================
+ #
+ # FUNCTIONS
+ #
+ # ======================================================================================================================
+
+
+ def get_exponential_weighted_average(values: list[float]) -> float:
+     """
+     :param values: List of values to average via EWA.
+     :return: EWA of the given values.
+     """
+
+     exp_weighted_average = pd.Series(values).ewm(alpha=0.1).mean().iloc[-1]
+     assert isinstance(exp_weighted_average, float)
+
+     return exp_weighted_average
+
+
+ def apply_averaging_function_to_tensor_statistics(
+     tensor_statistics: list[TensorStatistics], averaging_function: Callable[[list[float]], float]
+ ) -> TensorStatistics:
+     """
+     :param tensor_statistics: List of statistics models to average over.
+     :param averaging_function: Function to average the values with.
+     :return: TensorStatistics containing the average over all given tensor statistics.
+     """
+
+     fields = TensorStatistics.model_fields.keys()
+     averaged_metrics = {
+         field: averaging_function([getattr(statistics, field) for statistics in tensor_statistics]) for field in fields
+     }
+
+     return TensorStatistics(**averaged_metrics)
+
+
+ def apply_averaging_function_to_dictionary_of_tensor_statistics(
+     data: dict[str, list[TensorStatistics]], averaging_function: Callable[[list[float]], float]
+ ) -> dict[str, TensorStatistics]:
+     """
+     :param data: Dictionary mapping parameter group names to lists of TensorStatistics from that parameter group.
+     :param averaging_function: Function to average the values with.
+     :return: Dictionary mapping parameter group names to TensorStatistics averaged over all statistics in the given
+         TensorStatistics models.
+     """
+
+     return {
+         group: apply_averaging_function_to_tensor_statistics(
+             tensor_statistics=metrics, averaging_function=averaging_function
+         )
+         for group, metrics in data.items()
+     }
+
+
+ def apply_averaging_function_to_dictionary_of_metric_lists(
+     data: dict[str, list[float]], averaging_function: Callable[[list[float]], float]
+ ) -> dict[str, float]:
+     """
+     :param data: Dictionary mapping parameter group names to lists of metrics from that parameter group.
+     :param averaging_function: Function to average the values with.
+     :return: Dictionary mapping parameter group names to averages over all metrics from each parameter group.
+     """
+
+     return {group: averaging_function(metrics) for group, metrics in data.items()}
+
+
+ def average_tensor_statistics(tensor_statistics: list[TensorStatistics]) -> TensorStatistics:
+     """
+     :param tensor_statistics: List of TensorStatistics models to average into one model.
+     :return: Averages over all given tensor statistics models.
+     """
+
+     totals = {
+         field: sum([getattr(statistics_model, field) for statistics_model in tensor_statistics])
+         for field in TensorStatistics.model_fields.keys()
+     }
+     averaged = {field: total / len(tensor_statistics) for field, total in totals.items()}
+
+     return TensorStatistics(**averaged)
+
+
+ def create_one_hot_observation(vector_length: int, one_hot_index: int | None) -> list[int | float]:
+     """
+     :param vector_length: Length of the one-hot vector.
+     :param one_hot_index: Index of the vector element to set to 1, leaving all others as 0. If None, a zero
+         vector is returned.
+     :return: Constructed one-hot vector in a list.
+     """
+
+     if one_hot_index is not None and one_hot_index < 0:
+         raise ValueError("One hot indices must be non-negative.")
+
+     one_hot = np.zeros(vector_length, dtype=np.int8)
+
+     if one_hot_index is not None:
+         one_hot[one_hot_index] = 1
+
+     as_list = one_hot.tolist()
+
+     assert isinstance(as_list, list), "One-hot vector must be a list."
+
+     return as_list
+
+
+ def create_one_hot_depth_encoding(agent_controlled_modules: list[str], parameter_group_name: str) -> list[int | float]:
+     """
+     :param agent_controlled_modules: Ordered list of parameter group names in the inner model.
+     :param parameter_group_name: Name of the parameter group to create a depth one-hot vector for.
+     :return: Constructed one-hot depth encoding in a list.
+
+     :note: GANNO encodes depths to one-hot vectors of length 3 regardless of the size of the model.
+     """
+
+     module_index = agent_controlled_modules.index(parameter_group_name)
+     number_of_modules = len(agent_controlled_modules)
+
+     one_hot_index = min(2, (module_index * 3) // number_of_modules)
+
+     return create_one_hot_observation(vector_length=3, one_hot_index=one_hot_index)
+
+
+ def tensor_on_local_rank(tensor: torch.Tensor | None) -> bool:
+     """
+     :param tensor: Tensor to check whether it is owned by the local rank, partially or entirely.
+     :return: Whether the tensor is owned by the local rank.
+     """
+
+     return tensor is not None and tensor.grad is not None and tensor.numel() > 0
+
+
+ def form_update_tensor(
+     optimizer: optim.Optimizer, parameters: list[torch.Tensor], parameter_group: dict[str, Any]
+ ) -> None | torch.Tensor:
+     """
+     :param optimizer: Optimizer to form the update tensor from.
+     :param parameters: Parameters to create the update tensor from.
+     :param parameter_group: Parameter group within the optimizer the given parameters came from.
+     :return: None or the formed update tensor.
+     """
+
+     if type(optimizer) in optim_utils.ADAM_OPTIMISERS:
+         updates_list = [optimizer.state[p][EXP_AVERAGE].view(-1) for p in parameters if tensor_on_local_rank(p)]
+         return torch.cat(updates_list) if updates_list else None
+
+     elif type(optimizer) in optim_utils.SGD_OPTIMISERS:
+         return optim_utils.compute_sgd_optimizer_update_stats(
+             optimizer=optimizer, parameter_group=parameter_group, parameters=parameters
+         )
+
+     else:
+         raise NotImplementedError(f"Optimizer {type(optimizer).__name__} is not supported!")
+
+
+ def null_standardizer(value_to_standardize: float, **kwargs) -> float:
+     """
+     :param value_to_standardize: Value to mock the standardization of.
+     :return: Given value to standardize.
+     """
+
+     return value_to_standardize
+
+
+ def create_sinusoidal_depth_encoding(
+     agent_controlled_modules: list[str], parameter_group_name: str, dimensionality: int
+ ) -> list[int | float]:
+     """
+     :param agent_controlled_modules: Ordered list of parameter group names in the inner model.
+     :param parameter_group_name: Name of the parameter group to create a depth encoding for.
+     :param dimensionality: Length of the depth vector.
+     :return: Sinusoidal depth encoding.
+     """
+
+     assert dimensionality % 2 == 0, "Dimensionality of a sinusoidal depth encoding must be even."
+
+     depth = agent_controlled_modules.index(parameter_group_name)
+
+     positions = np.arange(dimensionality // 2)
+     frequencies = 1 / (10000 ** (2 * positions / dimensionality))
+
+     encoding = np.zeros(dimensionality)
+     encoding[0::2] = np.sin(depth * frequencies)
+     encoding[1::2] = np.cos(depth * frequencies)
+
+     return encoding.tolist()
+
+
+ def concatenate_lists(lists: list[list[Any]]) -> list[Any]:
+     """
+     :param lists: Lists to concatenate.
+     :return: Concatenated lists.
+     """
+
+     return list(chain(*lists))
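
As a worked example of the two depth encodings defined above, with a hypothetical five-module inner model:

    from libinephany.observations.observation_utils import (
        create_one_hot_depth_encoding,
        create_sinusoidal_depth_encoding,
    )

    modules = ["embed", "block_0", "block_1", "block_2", "head"]

    # (module_index * 3) // number_of_modules buckets the modules into thirds:
    # indices 0 and 1 map to bucket 0, indices 2 and 3 to bucket 1, index 4 to bucket 2.
    create_one_hot_depth_encoding(modules, "block_1")  # [0, 1, 0]

    # For depth 2 and dimensionality 4: frequencies = [1.0, 0.01], so the encoding
    # interleaves sines and cosines as [sin(2.0), cos(2.0), sin(0.02), cos(0.02)].
    create_sinusoidal_depth_encoding(modules, "block_1", dimensionality=4)
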
libinephany/observations/observer_pipeline.py
@@ -0,0 +1,307 @@
+ # ======================================================================================================================
+ #
+ # IMPORTS
+ #
+ # ======================================================================================================================
+
+ from typing import Any
+
+ import gymnasium as gym
+ import numpy as np
+
+ from libinephany.observations.observers.observer_containers import GlobalObserverContainer, LocalObserverContainer
+ from libinephany.observations.post_processors import postprocessors
+ from libinephany.observations.post_processors.postprocessors import ObservationPostProcessor
+ from libinephany.pydantic_models.configs.observer_config import AgentObserverConfig, ObserverConfig
+ from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
+ from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
+ from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
+ from libinephany.utils.standardizers import Standardizer
+ from libinephany.utils.typing import ObservationInformation
+
+ # ======================================================================================================================
+ #
+ # CLASSES
+ #
+ # ======================================================================================================================
+
+
+ class ObserverPipeline:
+
+     def __init__(
+         self,
+         observer_config: ObserverConfig,
+         agent_config: AgentObserverConfig,
+         standardizer: Standardizer | None,
+         agent_id_to_modules: dict[str, str | None],
+     ):
+         """
+         :param observer_config: ObserverConfig that contains various parameters applicable to all agent observers as
+             well as the observer configs for specific agents.
+         :param agent_config: AgentObserverConfig storing configuration of the agent's actions and observation space
+             for this type of agent.
+         :param standardizer: None or the standardizer to apply to the returned observations.
+         :param agent_id_to_modules: Dictionary mapping agent IDs to None or the name of the parameter group that
+             agent is modulating.
+         """
+
+         self.observer_config = observer_config
+         self.agent_config = agent_config
+         self.standardizer = standardizer
+         self.agent_id_to_modules = agent_id_to_modules
+
+         self.global_observers = GlobalObserverContainer(
+             global_config=observer_config,
+             agent_config=agent_config,
+             standardizer=standardizer,
+         )
+
+         self.local_observers: dict[str, LocalObserverContainer] = self._build_local_observers(
+             agent_id_to_modules=agent_id_to_modules,
+         )
+
+         self.post_processors: list[ObservationPostProcessor] = self._build_post_processors()
+
+         self.clipping_threshold = self.observer_config.observation_clipping_threshold
+         self.invalid_threshold = self.observer_config.invalid_observation_threshold
+
+     @property
+     def observation_size(self) -> int:
+         """
+         :return: Total size of the observation vectors produced by this pipeline.
+         """
+
+         globals_size = self.global_observers.total_observer_size
+         locals_size = next(iter(self.local_observers.values())).total_observer_size
+
+         total_size = globals_size + locals_size
+
+         if self.agent_config.prepend_invalid_indicator:
+             total_size += 1
+
+         return total_size
+
+     def _build_local_observers(
+         self,
+         agent_id_to_modules: dict[str, str | None],
+     ) -> dict[str, LocalObserverContainer]:
+         """
+         :param agent_id_to_modules: Dictionary mapping agent IDs to None or the name of the parameter group that
+             agent is modulating.
+         :return: Dictionary mapping agent IDs to LocalObserverContainers for each agent.
+         """
+
+         observers = {}
+
+         for agent_id, parameter_group_name in agent_id_to_modules.items():
+             observers[agent_id] = LocalObserverContainer(
+                 global_config=self.observer_config,
+                 agent_config=self.agent_config,
+                 standardizer=self.standardizer,
+                 agent_id=agent_id,
+                 parameter_group_name=parameter_group_name,
+             )
+
+         return observers
+
+     def _build_post_processors(self) -> list[ObservationPostProcessor]:
+         """
+         :return: List of post-processors to apply to the global and local observations.
+         """
+
+         if self.agent_config.postprocessors is None:
+             return []
+
+         post_processors = []
+
+         for observer_name, post_processor_kwargs in self.agent_config.postprocessors.items():
+             try:
+                 post_processor_type: type[ObservationPostProcessor] = getattr(postprocessors, observer_name)
+
+             except AttributeError as e:
+                 raise AttributeError(f"The class {observer_name} does not exist within {postprocessors}!") from e
+
+             if post_processor_kwargs is None:
+                 post_processor_kwargs = {}
+
+             post_processor = post_processor_type(
+                 observer_config=self.observer_config,
+                 **post_processor_kwargs,
+             )
+
+             post_processors.append(post_processor)
+
+         return post_processors
+
+     @staticmethod
+     def merge_globals_to_locals(
+         global_obs: list[float | int],
+         local_obs: dict[str, list[float | int]],
+     ) -> dict[str, list[float | int]]:
+         """
+         :param global_obs: Global observations to prepend to each agent's local observations.
+         :param local_obs: Dictionary mapping agent IDs to local observations of that agent.
+         :return: Dictionary mapping agent ID to that agent's completed observation vector.
+         """
+
+         return {agent_id: global_obs + agent_obs for agent_id, agent_obs in local_obs.items()}
+
+     def get_required_trackers(self) -> list[dict[str, dict[str, Any] | None]]:
+         """
+         :return: List of trackers required by each of the stored observers.
+         """
+
+         required_trackers = self.global_observers.get_required_trackers()
+
+         for agent_container in self.local_observers.values():
+             required_trackers += agent_container.get_required_trackers()
+
+         return required_trackers
+
+     def clip_observation_vector(self, observation_vector: list[float | int]) -> tuple[bool, list[float | int]]:
+         """
+         :param observation_vector: Observations to clip.
+         :return: Tuple indicating whether an invalid observation was encountered and a list of the post-processed
+             observations.
+         """
+
+         invalid_encountered = False
+         post_processed = []
+
+         for observation in observation_vector:
+             if abs(observation) >= self.clipping_threshold:
+                 if abs(observation) >= self.invalid_threshold:
+                     invalid_encountered = True
+
+                 observation = -self.clipping_threshold if observation < 0 else self.clipping_threshold
+
+             elif np.isnan(observation):
+                 observation = self.observer_config.invalid_observation_replacement_value
+                 invalid_encountered = True
+
+             post_processed.append(observation)
+
+         return invalid_encountered, post_processed
+
+     def clip_observations(
+         self,
+         globals_to_clip: list[float | int],
+         locals_to_clip: dict[str, list[float | int]],
+     ) -> tuple[list[float | int], dict[str, list[float | int]], bool]:
+         """
+         :param globals_to_clip: Global observations to post-process.
+         :param locals_to_clip: Dictionary mapping agent IDs to local observations of that agent.
+         :return: Tuple of clipped global and local observations and whether any observation clipping occurred.
+         """
+
+         invalid_encountered, globals_to_clip = self.clip_observation_vector(observation_vector=globals_to_clip)
+         post_processed_locals = {}
+
+         for agent_id, agent_observations in locals_to_clip.items():
+             agent_invalid_encountered, post_processed = self.clip_observation_vector(
+                 observation_vector=agent_observations
+             )
+
+             post_processed_locals[agent_id] = post_processed
+             invalid_encountered = invalid_encountered or agent_invalid_encountered
+
+         if self.agent_config.prepend_invalid_indicator:
+             globals_to_clip = [int(invalid_encountered)] + globals_to_clip
+
+         return globals_to_clip, post_processed_locals, invalid_encountered
+
+     def observe(
+         self,
+         observation_inputs: ObservationInputs,
+         hyperparameter_states: HyperparameterStates,
+         tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
+         actions_taken: dict[str, float | int | None],
+     ) -> tuple[dict[str, list[float | int]], bool]:
+         """
+         :param observation_inputs: Observation input metrics not calculated with statistic trackers.
+         :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
+         :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
+             names to floats or TensorStatistics models.
+         :param actions_taken: Dictionary mapping agent IDs to actions taken by that agent.
+         :return: Tuple of a dictionary mapping agent ID to that agent's completed observation vector and a boolean
+             indicating whether an observation clip occurred.
+         """
+
+         global_obs = self.global_observers.observe(
+             observation_inputs=observation_inputs,
+             hyperparameter_states=hyperparameter_states,
+             tracked_statistics=tracked_statistics,
+             action_taken=None,
+         )
+
+         local_obs = {
+             agent_id: agent_observers.observe(
+                 observation_inputs=observation_inputs,
+                 hyperparameter_states=hyperparameter_states,
+                 tracked_statistics=tracked_statistics,
+                 action_taken=actions_taken.get(agent_id),
+             )
+             for agent_id, agent_observers in self.local_observers.items()
+         }
+
+         for post_processor in self.post_processors:
+             global_obs, local_obs = post_processor.postprocess(
+                 global_observations=global_obs, local_observations=local_obs
+             )
+
+         global_obs, local_obs, obs_clipped = self.clip_observations(
+             globals_to_clip=global_obs, locals_to_clip=local_obs
+         )
+
+         return self.merge_globals_to_locals(global_obs=global_obs, local_obs=local_obs), obs_clipped
+
+     def inform(self) -> tuple[ObservationInformation, dict[str, ObservationInformation]]:
+         """
+         :return: Tuple of global observation information and a dictionary mapping agent IDs to that agent's
+             observation information, to add to the agent info.
+         """
+
+         global_information = self.global_observers.inform()
+
+         agent_information = {agent_id: observer.inform() for agent_id, observer in self.local_observers.items()}
+
+         return global_information, agent_information
+
+     def get_observation_spaces(self) -> dict[str, gym.spaces.Box]:
+         """
+         :return: Dictionary mapping agent IDs to their observation spaces.
+         """
+
+         return {
+             agent_id: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.observation_size,), dtype=np.float32)
+             for agent_id in self.local_observers
+         }
+
+     def reset(self) -> None:
+         """
+         Resets all global and local observers.
+         """
+
+         self.global_observers.reset()
+
+         for local_observers in self.local_observers.values():
+             local_observers.reset()
+
+     def train(self) -> None:
+         """
+         Sets all observer containers into training mode.
+         """
+
+         self.global_observers.train()
+
+         for local_observers in self.local_observers.values():
+             local_observers.train()
+
+     def infer(self) -> None:
+         """
+         Sets all observer containers into inference mode.
+         """
+
+         self.global_observers.infer()
+
+         for local_observers in self.local_observers.values():
+             local_observers.infer()
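
A sketch of how the pipeline might be driven. The config objects, observation inputs, hyperparameter states, and tracked statistics are assumed to be built elsewhere; every name below is a hypothetical placeholder rather than part of the package:

    pipeline = ObserverPipeline(
        observer_config=observer_config,    # an ObserverConfig
        agent_config=agent_config,          # an AgentObserverConfig
        standardizer=None,                  # skip standardization of observations
        agent_id_to_modules={"agent_0": "block_0", "agent_1": "block_1"},
    )

    spaces = pipeline.get_observation_spaces()  # one gym.spaces.Box per agent ID

    observations, obs_clipped = pipeline.observe(
        observation_inputs=observation_inputs,
        hyperparameter_states=hyperparameter_states,
        tracked_statistics=tracked_statistics,
        actions_taken={"agent_0": 0.1, "agent_1": None},
    )
    # observations maps each agent ID to the global observations concatenated with
    # that agent's local observations; obs_clipped reports clipping or NaN replacement.
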
File without changes: libinephany/observations/observers/__init__.py