libinephany 0.13.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- libinephany/__init__.py +0 -0
- libinephany/aws/__init__.py +0 -0
- libinephany/aws/s3_functions.py +57 -0
- libinephany/observations/__init__.py +0 -0
- libinephany/observations/observation_utils.py +243 -0
- libinephany/observations/observer_pipeline.py +307 -0
- libinephany/observations/observers/__init__.py +0 -0
- libinephany/observations/observers/base_observers.py +418 -0
- libinephany/observations/observers/global_observers.py +988 -0
- libinephany/observations/observers/local_observers.py +982 -0
- libinephany/observations/observers/observer_containers.py +270 -0
- libinephany/observations/pipeline_coordinator.py +198 -0
- libinephany/observations/post_processors/__init__.py +0 -0
- libinephany/observations/post_processors/postprocessors.py +153 -0
- libinephany/observations/statistic_manager.py +217 -0
- libinephany/observations/statistic_trackers.py +876 -0
- libinephany/pydantic_models/__init__.py +0 -0
- libinephany/pydantic_models/configs/__init__.py +0 -0
- libinephany/pydantic_models/configs/hyperparameter_configs.py +387 -0
- libinephany/pydantic_models/configs/observer_config.py +43 -0
- libinephany/pydantic_models/configs/outer_model_config.py +32 -0
- libinephany/pydantic_models/schemas/__init__.py +0 -0
- libinephany/pydantic_models/schemas/agent_info.py +65 -0
- libinephany/pydantic_models/schemas/inner_task_profile.py +284 -0
- libinephany/pydantic_models/schemas/observation_models.py +61 -0
- libinephany/pydantic_models/schemas/request_schemas.py +45 -0
- libinephany/pydantic_models/schemas/response_schemas.py +50 -0
- libinephany/pydantic_models/schemas/tensor_statistics.py +254 -0
- libinephany/pydantic_models/states/__init__.py +0 -0
- libinephany/pydantic_models/states/hyperparameter_states.py +808 -0
- libinephany/utils/__init__.py +0 -0
- libinephany/utils/agent_utils.py +82 -0
- libinephany/utils/asyncio_worker.py +87 -0
- libinephany/utils/backend_statuses.py +20 -0
- libinephany/utils/constants.py +76 -0
- libinephany/utils/directory_utils.py +41 -0
- libinephany/utils/dropout_utils.py +92 -0
- libinephany/utils/enums.py +90 -0
- libinephany/utils/error_severities.py +46 -0
- libinephany/utils/exceptions.py +60 -0
- libinephany/utils/import_utils.py +43 -0
- libinephany/utils/optim_utils.py +239 -0
- libinephany/utils/random_seeds.py +55 -0
- libinephany/utils/samplers.py +341 -0
- libinephany/utils/standardizers.py +217 -0
- libinephany/utils/torch_distributed_utils.py +85 -0
- libinephany/utils/torch_utils.py +77 -0
- libinephany/utils/transforms.py +104 -0
- libinephany/utils/typing.py +15 -0
- libinephany/web_apps/__init__.py +0 -0
- libinephany/web_apps/error_logger.py +421 -0
- libinephany/web_apps/web_app_utils.py +123 -0
- libinephany-0.13.1.dist-info/METADATA +282 -0
- libinephany-0.13.1.dist-info/RECORD +57 -0
- libinephany-0.13.1.dist-info/WHEEL +5 -0
- libinephany-0.13.1.dist-info/licenses/LICENSE +11 -0
- libinephany-0.13.1.dist-info/top_level.txt +1 -0
libinephany/__init__.py
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,57 @@
|
|
1
|
+
# ======================================================================================================================
|
2
|
+
#
|
3
|
+
# IMPORTS
|
4
|
+
#
|
5
|
+
# ======================================================================================================================
|
6
|
+
|
7
|
+
import boto3
|
8
|
+
from loguru import logger
|
9
|
+
|
10
|
+
# ======================================================================================================================
|
11
|
+
#
|
12
|
+
# CONSTANTS
|
13
|
+
#
|
14
|
+
# ======================================================================================================================
|
15
|
+
|
16
|
+
S3_URI_SCHEME = "s3://"
S3_BUCKET_SEPARATOR = "/"
S3_BOTO_CLIENT = "s3"

# ======================================================================================================================
#
# FUNCTIONS
#
# ======================================================================================================================


def parse_s3_url(s3_url: str) -> tuple[str, str]:
    """
    Splits an S3 URL into its bucket name and blob name.

    :param s3_url: S3 URL to parse and extract the bucket and blob names from. The leading "s3://" scheme is optional.
    :return: Tuple of:
        - Bucket name.
        - Blob name, or an empty string when the URL contains no bucket separator.
    """

    # Only the leading scheme is stripped. A plain str.replace would also delete any later occurrence of the scheme
    # that happens to be part of the blob name.
    without_scheme = s3_url.removeprefix(S3_URI_SCHEME)

    # Split on the first separator only: everything after it, further separators included, is the blob name.
    bucket, _, key = without_scheme.partition(S3_BUCKET_SEPARATOR)

    return bucket, key
|
43
|
+
|
44
|
+
|
45
|
+
def download_s3_file(s3_url: str, local_path: str) -> None:
    """
    Downloads the object referenced by the given S3 URL and stores it at the given local path.

    :param s3_url: S3 URL to download data from.
    :param local_path: Path to save the contents of the download to.
    """

    bucket_name, blob_name = parse_s3_url(s3_url)
    client = boto3.client(S3_BOTO_CLIENT)

    logger.info(f"Downloading {s3_url} to {local_path}...")
    client.download_file(bucket_name, blob_name, local_path)
    logger.success("Done.")
|
File without changes
|
@@ -0,0 +1,243 @@
|
|
1
|
+
# ======================================================================================================================
|
2
|
+
#
|
3
|
+
# IMPORTS
|
4
|
+
#
|
5
|
+
# ======================================================================================================================
|
6
|
+
|
7
|
+
from enum import Enum
|
8
|
+
from itertools import chain
|
9
|
+
from typing import Any, Callable
|
10
|
+
|
11
|
+
import numpy as np
|
12
|
+
import pandas as pd
|
13
|
+
import torch
|
14
|
+
import torch.optim as optim
|
15
|
+
|
16
|
+
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
17
|
+
from libinephany.utils import optim_utils
|
18
|
+
|
19
|
+
# ======================================================================================================================
|
20
|
+
#
|
21
|
+
# CONSTANTS
|
22
|
+
#
|
23
|
+
# ======================================================================================================================
|
24
|
+
|
25
|
+
# Key used to look up the per-parameter entry of optimizer.state when forming update tensors for Adam-type optimizers.
EXP_AVERAGE = "exp_avg"
|
26
|
+
|
27
|
+
# ======================================================================================================================
|
28
|
+
#
|
29
|
+
# CLASSES
|
30
|
+
#
|
31
|
+
# ======================================================================================================================
|
32
|
+
|
33
|
+
|
34
|
+
class StatisticsCallStage(Enum):
    """
    Enumerates named call stages, each identified by a string value.

    NOTE(review): presumably these are the points of the training loop at which statistics are gathered - confirm
    against the statistic tracker implementations.
    """

    ON_BATCH_END = "on_batch_end"
    ON_OPTIMIZER_STEP = "on_optimizer_step"
    ON_TRAIN_END = "on_train_end"

    FORWARD_HOOK = "forward_hook"
|
41
|
+
|
42
|
+
|
43
|
+
class StatisticStorageTypes(Enum):
    """
    Enumerates the names of the storage formats a statistic can take.
    """

    # Value is the class name of the TensorStatistics model.
    TENSOR_STATISTICS = TensorStatistics.__name__
    # Value is the name of the built-in float type, i.e. "float".
    FLOAT = float.__name__
    VECTOR = "vector"
|
48
|
+
|
49
|
+
|
50
|
+
# ======================================================================================================================
|
51
|
+
#
|
52
|
+
# FUNCTIONS
|
53
|
+
#
|
54
|
+
# ======================================================================================================================
|
55
|
+
|
56
|
+
|
57
|
+
def get_exponential_weighted_average(values: list[float], alpha: float = 0.1) -> float:
    """
    Computes the exponentially weighted average (EWA) of the given values with pandas and returns the final element of
    the smoothed series, so the last value in the list carries the largest weight.

    :param values: List of values to average via EWA. Must not be empty.
    :param alpha: Smoothing factor forwarded to pandas. Defaults to 0.1, the value that was previously hard-coded.
    :return: EWA of the given values as a built-in float.
    """

    smoothed_series = pd.Series(values).ewm(alpha=alpha).mean()

    # Convert explicitly rather than asserting the type: asserts are stripped when Python runs with -O, and pandas
    # hands back a numpy scalar instead of a plain float.
    return float(smoothed_series.iloc[-1])
|
67
|
+
|
68
|
+
|
69
|
+
def apply_averaging_function_to_tensor_statistics(
    tensor_statistics: list[TensorStatistics], averaging_function: Callable[[list[float]], float]
) -> TensorStatistics:
    """
    Collapses a list of TensorStatistics models into one model by applying the averaging function to each field.

    :param tensor_statistics: List of statistics models to average over.
    :param averaging_function: Function that reduces the list of values gathered for one field to a single value.
    :return: TensorStatistics whose fields hold the averaging function's result for that field.
    """

    averaged_by_field = {}

    for field_name in TensorStatistics.model_fields.keys():
        # Gather this field's value from every model, in the order the models were given.
        field_values = [getattr(model, field_name) for model in tensor_statistics]
        averaged_by_field[field_name] = averaging_function(field_values)

    return TensorStatistics(**averaged_by_field)
|
84
|
+
|
85
|
+
|
86
|
+
def apply_averaging_function_to_dictionary_of_tensor_statistics(
    data: dict[str, list[TensorStatistics]], averaging_function: Callable[[list[float]], float]
) -> dict[str, TensorStatistics]:
    """
    Averages the TensorStatistics models of every parameter group independently.

    :param data: Dictionary mapping parameter group names to list of TensorStatistics from that parameter group.
    :param averaging_function: Function to average the values with.
    :return: Dictionary mapping parameter group names to a single TensorStatistics model averaged over all of the
    models given for that group.
    """

    averaged_groups = {}

    for group_name, group_statistics in data.items():
        averaged_groups[group_name] = apply_averaging_function_to_tensor_statistics(
            tensor_statistics=group_statistics, averaging_function=averaging_function
        )

    return averaged_groups
|
102
|
+
|
103
|
+
|
104
|
+
def apply_averaging_function_to_dictionary_of_metric_lists(
    data: dict[str, list[float]], averaging_function: Callable[[list[float]], float]
) -> dict[str, float]:
    """
    Reduces the metric list of every parameter group to a single value with the given averaging function.

    :param data: Dictionary mapping parameter group names to list of metrics from that parameter group.
    :param averaging_function: Function to average the values with.
    :return: Dictionary mapping parameter group names to the average over that group's metrics.
    """

    group_averages = {}

    for group_name, group_metrics in data.items():
        group_averages[group_name] = averaging_function(group_metrics)

    return group_averages
|
114
|
+
|
115
|
+
|
116
|
+
def average_tensor_statistics(tensor_statistics: list[TensorStatistics]) -> TensorStatistics:
    """
    Computes the arithmetic mean of every TensorStatistics field across the given models.

    :param tensor_statistics: List of TensorStatistics models to average into one model.
    :return: TensorStatistics model holding the per-field mean over all given models.
    """

    model_count = len(tensor_statistics)
    field_means = {}

    for field_name in TensorStatistics.model_fields.keys():
        field_total = sum([getattr(model, field_name) for model in tensor_statistics])
        field_means[field_name] = field_total / model_count

    return TensorStatistics(**field_means)
|
129
|
+
|
130
|
+
|
131
|
+
def create_one_hot_observation(vector_length: int, one_hot_index: int | None) -> list[int | float]:
|
132
|
+
"""
|
133
|
+
:param vector_length: Length of the one-hot vector.
|
134
|
+
:param one_hot_index: Index of the vector whose element should be set to 1.0, leaving all others as 0.0.
|
135
|
+
:return: Constructed one-hot vector in a list.
|
136
|
+
"""
|
137
|
+
|
138
|
+
if one_hot_index is not None and one_hot_index < 0:
|
139
|
+
raise ValueError("One hot indices must be greater than 0.")
|
140
|
+
|
141
|
+
one_hot = np.zeros(vector_length, dtype=np.int8)
|
142
|
+
|
143
|
+
if one_hot_index is not None:
|
144
|
+
one_hot[one_hot_index] = 1
|
145
|
+
|
146
|
+
as_list = one_hot.tolist()
|
147
|
+
|
148
|
+
assert isinstance(as_list, list), "One-hot vector must be a list."
|
149
|
+
|
150
|
+
return as_list
|
151
|
+
|
152
|
+
|
153
|
+
def create_one_hot_depth_encoding(agent_controlled_modules: list[str], parameter_group_name: str) -> list[int | float]:
    """
    Encodes the relative position of a parameter group within the ordered module list as a one-hot vector of length 3.

    :param agent_controlled_modules: Ordered list of parameter group names in the inner model.
    :param parameter_group_name: Name of the parameter group to create a depth one-hot vector for.
    :return: Constructed one-hot depth encoding in a list.

    :note: GANNO encodes depths to one-hot vectors of length 3 regardless of the size of the model.
    """

    group_position = agent_controlled_modules.index(parameter_group_name)
    group_count = len(agent_controlled_modules)

    # Map the position onto one of three equally sized depth buckets, never exceeding the last bucket.
    depth_bucket = (group_position * 3) // group_count

    if depth_bucket > 2:
        depth_bucket = 2

    return create_one_hot_observation(vector_length=3, one_hot_index=depth_bucket)
|
168
|
+
|
169
|
+
|
170
|
+
def tensor_on_local_rank(tensor: torch.Tensor | None) -> bool:
|
171
|
+
"""
|
172
|
+
:param tensor: Tensor to check whether it is owned by the local rank, partially or entirely.
|
173
|
+
:return: Whether the tensor is owned by the local rank.
|
174
|
+
"""
|
175
|
+
|
176
|
+
return tensor is not None and tensor.grad is not None and tensor.numel() > 0
|
177
|
+
|
178
|
+
|
179
|
+
def form_update_tensor(
    optimizer: optim.Optimizer, parameters: list[torch.Tensor], parameter_group: dict[str, Any]
) -> None | torch.Tensor:
    """
    Forms a flat update tensor for the given parameters from the optimizer, dispatching on the optimizer's exact type.

    :param optimizer: Optimizer to form the update tensor from.
    :param parameters: Parameters to create the update tensor from.
    :param parameter_group: Parameter group within the optimizer the given parameters came from.
    :return: None or the formed update tensor.
    :raises NotImplementedError: If the optimizer's type is in neither the Adam nor the SGD optimizer collections.
    """

    optimizer_type = type(optimizer)

    if optimizer_type in optim_utils.ADAM_OPTIMISERS:
        flattened_updates = []

        for parameter in parameters:
            # Parameters that are not held on this rank are left out of the update tensor.
            if tensor_on_local_rank(parameter):
                flattened_updates.append(optimizer.state[parameter][EXP_AVERAGE].view(-1))

        if not flattened_updates:
            return None

        return torch.cat(flattened_updates)

    if optimizer_type in optim_utils.SGD_OPTIMISERS:
        return optim_utils.compute_sgd_optimizer_update_stats(
            optimizer=optimizer, parameter_group=parameter_group, parameters=parameters
        )

    raise NotImplementedError(f"Optimizer {type(optimizer).__name__} is not supported!")
|
200
|
+
|
201
|
+
|
202
|
+
def null_standardizer(value_to_standardize: float, **kwargs) -> float:
    """
    Stand-in standardizer that performs no standardization at all.

    :param value_to_standardize: Value to mock the standardization of.
    :param kwargs: Accepted and ignored.
    :return: The given value, unchanged.
    """

    unchanged_value = value_to_standardize

    return unchanged_value
|
209
|
+
|
210
|
+
|
211
|
+
def create_sinusoidal_depth_encoding(
|
212
|
+
agent_controlled_modules: list[str], parameter_group_name: str, dimensionality: int
|
213
|
+
) -> list[int | float]:
|
214
|
+
"""
|
215
|
+
:param agent_controlled_modules: Ordered list of parameter group names in the inner model.
|
216
|
+
:param parameter_group_name: Name of the parameter group to create a depth encoding for.
|
217
|
+
:param dimensionality: Length of the depth vector.
|
218
|
+
:return: Sinusoidal depth encoding.
|
219
|
+
"""
|
220
|
+
|
221
|
+
assert dimensionality % 2 == 0, "Dimensionality of a sinusoidal depth encoding must be even."
|
222
|
+
|
223
|
+
depth = agent_controlled_modules.index(parameter_group_name)
|
224
|
+
|
225
|
+
positions = np.arange(dimensionality // 2)
|
226
|
+
frequencies = 1 / (10000 ** (2 * positions / dimensionality))
|
227
|
+
|
228
|
+
encoding = np.zeros(dimensionality)
|
229
|
+
encoding[0::2] = np.sin(depth * frequencies)
|
230
|
+
encoding[1::2] = np.cos(depth * frequencies)
|
231
|
+
|
232
|
+
vector = encoding.tolist()
|
233
|
+
|
234
|
+
return vector
|
235
|
+
|
236
|
+
|
237
|
+
def concatenate_lists(lists: list[list[Any]]) -> list[Any]:
    """
    Flattens a list of lists by one level, preserving element order.

    :param lists: Lists to concatenate.
    :return: Concatenated lists.
    """

    concatenated = chain.from_iterable(lists)

    return list(concatenated)
|
@@ -0,0 +1,307 @@
|
|
1
|
+
# ======================================================================================================================
|
2
|
+
#
|
3
|
+
# IMPORTS
|
4
|
+
#
|
5
|
+
# ======================================================================================================================
|
6
|
+
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
import gymnasium as gym
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from libinephany.observations.observers.observer_containers import GlobalObserverContainer, LocalObserverContainer
|
13
|
+
from libinephany.observations.post_processors import postprocessors
|
14
|
+
from libinephany.observations.post_processors.postprocessors import ObservationPostProcessor
|
15
|
+
from libinephany.pydantic_models.configs.observer_config import AgentObserverConfig, ObserverConfig
|
16
|
+
from libinephany.pydantic_models.schemas.observation_models import ObservationInputs
|
17
|
+
from libinephany.pydantic_models.schemas.tensor_statistics import TensorStatistics
|
18
|
+
from libinephany.pydantic_models.states.hyperparameter_states import HyperparameterStates
|
19
|
+
from libinephany.utils.standardizers import Standardizer
|
20
|
+
from libinephany.utils.typing import ObservationInformation
|
21
|
+
|
22
|
+
# ======================================================================================================================
|
23
|
+
#
|
24
|
+
# CLASSES
|
25
|
+
#
|
26
|
+
# ======================================================================================================================
|
27
|
+
|
28
|
+
|
29
|
+
class ObserverPipeline:
    """
    Owns the global observer container, one local observer container per agent and the configured post-processors,
    and combines their outputs into one clipped observation vector per agent.
    """

    def __init__(
        self,
        observer_config: ObserverConfig,
        agent_config: AgentObserverConfig,
        standardizer: Standardizer | None,
        agent_id_to_modules: dict[str, str | None],
    ):
        """
        :param observer_config: ObserverConfig that contains various parameters applicable to all agent observers as
        well as the observer configs for specific agents.
        :param agent_config: AgentObserverConfig storing configuration of the agent's actions and observation space for
        this type of agent.
        :param standardizer: None or the standardizer to apply to the returned observations.
        :param agent_id_to_modules: Dictionary mapping agent IDs to None or the name of the parameter group that agent
        is modulating.
        """

        self.observer_config = observer_config
        self.agent_config = agent_config
        self.standardizer = standardizer
        self.agent_id_to_modules = agent_id_to_modules

        self.global_observers = GlobalObserverContainer(
            global_config=observer_config,
            agent_config=agent_config,
            standardizer=standardizer,
        )

        self.local_observers: dict[str, LocalObserverContainer] = self._build_local_observers(
            agent_id_to_modules=agent_id_to_modules,
        )

        self.post_processors: list[ObservationPostProcessor] = self._build_post_processors()

        self.clipping_threshold = self.observer_config.observation_clipping_threshold
        self.invalid_threshold = self.observer_config.invalid_observation_threshold

    @property
    def observation_size(self) -> int:
        """
        :return: Total size of the observation vectors produced by this pipeline.
        """

        globals_size = self.global_observers.total_observer_size

        # Only the first local container is consulted, so every agent's container is taken to have the same size.
        locals_size = list(self.local_observers.values())[0].total_observer_size

        total_size = globals_size + locals_size

        if self.agent_config.prepend_invalid_indicator:
            # One extra slot is reserved for the invalid-observation indicator prepended in clip_observations.
            total_size += 1

        return total_size

    def _build_local_observers(
        self,
        agent_id_to_modules: dict[str, str | None],
    ) -> dict[str, LocalObserverContainer]:
        """
        :param agent_id_to_modules: Dictionary mapping agent IDs to None or the name of the parameter group that agent
        is modulating.
        :return: Dictionary mapping agent IDs to LocalObserverContainers for each agent.
        """

        return {
            agent_id: LocalObserverContainer(
                global_config=self.observer_config,
                agent_config=self.agent_config,
                standardizer=self.standardizer,
                agent_id=agent_id,
                parameter_group_name=parameter_group_name,
            )
            for agent_id, parameter_group_name in agent_id_to_modules.items()
        }

    def _build_post_processors(self) -> list[ObservationPostProcessor]:
        """
        Instantiates every post-processor named in the agent config by looking its class up in the postprocessors
        module.

        :return: List of post-processors to apply to the global and local observations.
        :raises AttributeError: If a configured post-processor name does not exist in the postprocessors module.
        """

        if self.agent_config.postprocessors is None:
            return []

        post_processors = []

        for observer_name, post_processor_kwargs in self.agent_config.postprocessors.items():
            try:
                post_processor_type: type[ObservationPostProcessor] = getattr(postprocessors, observer_name)

            except AttributeError as e:
                raise AttributeError(f"The class {observer_name} does not exist within {postprocessors}!") from e

            if post_processor_kwargs is None:
                post_processor_kwargs = {}

            post_processors.append(
                post_processor_type(
                    observer_config=self.observer_config,
                    **post_processor_kwargs,
                )
            )

        return post_processors

    @staticmethod
    def merge_globals_to_locals(
        global_obs: list[float | int],
        local_obs: dict[str, list[float | int]],
    ) -> dict[str, list[float | int]]:
        """
        Prefixes every agent's local observations with the shared global observations.

        :param global_obs: Global observations placed at the front of every agent's vector.
        :param local_obs: Dictionary mapping agent IDs to local observations of that agent.
        :return: Dictionary mapping agent ID to that agent's completed observation vector.
        """

        return {agent_id: global_obs + agent_obs for agent_id, agent_obs in local_obs.items()}

    def get_required_trackers(self) -> list[dict[str, dict[str, Any] | None]]:
        """
        :return: List of trackers required by each of the stored observers.
        """

        # Copy into a fresh list: extending the list returned by the global container in place could mutate state
        # owned by that container if it hands out a shared list.
        required_trackers = list(self.global_observers.get_required_trackers())

        for agent_container in self.local_observers.values():
            required_trackers.extend(agent_container.get_required_trackers())

        return required_trackers

    def clip_observation_vector(self, observation_vector: list[float | int]) -> tuple[bool, list[float | int]]:
        """
        Clips every observation to the clipping threshold and replaces NaN observations with the configured
        replacement value.

        :param observation_vector: Observations to clip.
        :return: Tuple of whether an invalid observation was encountered (a NaN, or a magnitude at or beyond the
        invalid threshold) and a list of the clipped observations.
        """

        invalid_encountered = False
        post_processed = []

        for observation in observation_vector:
            if abs(observation) >= self.clipping_threshold:
                # Clipping alone does not mark the vector invalid; only reaching the invalid threshold does.
                if abs(observation) >= self.invalid_threshold:
                    invalid_encountered = True

                observation = -self.clipping_threshold if observation < 0 else self.clipping_threshold

            elif np.isnan(observation):
                # NaN fails the comparison above, so it is handled separately.
                observation = self.observer_config.invalid_observation_replacement_value
                invalid_encountered = True

            post_processed.append(observation)

        return invalid_encountered, post_processed

    def clip_observations(
        self,
        globals_to_clip: list[float | int],
        locals_to_clip: dict[str, list[float | int]],
    ) -> tuple[list[float | int], dict[str, list[float | int]], bool]:
        """
        Clips the global and all local observations, optionally prepending an invalid indicator to the globals.

        :param globals_to_clip: Global observations to clip.
        :param locals_to_clip: Dictionary mapping agent IDs to local observations of that agent.
        :return: Tuple of clipped global observations, clipped local observations and whether any invalid observation
        was encountered in either of them.
        """

        invalid_encountered, globals_to_clip = self.clip_observation_vector(observation_vector=globals_to_clip)
        post_processed_locals = {}

        for agent_id, agent_observations in locals_to_clip.items():
            agent_invalid_encountered, post_processed = self.clip_observation_vector(
                observation_vector=agent_observations
            )

            post_processed_locals[agent_id] = post_processed
            invalid_encountered = invalid_encountered or agent_invalid_encountered

        if self.agent_config.prepend_invalid_indicator:
            # The indicator covers globals and locals alike, so it is only prepended once every vector was checked.
            globals_to_clip = [int(invalid_encountered)] + globals_to_clip

        return globals_to_clip, post_processed_locals, invalid_encountered

    def observe(
        self,
        observation_inputs: ObservationInputs,
        hyperparameter_states: HyperparameterStates,
        tracked_statistics: dict[str, dict[str, float | TensorStatistics]],
        actions_taken: dict[str, float | int | None],
    ) -> tuple[dict[str, list[float | int]], bool]:
        """
        Gathers global and local observations, applies the post-processors in order, clips the results and merges
        them into one vector per agent.

        :param observation_inputs: Observation input metrics not calculated with statistic trackers.
        :param hyperparameter_states: HyperparameterStates that manages the hyperparameters.
        :param tracked_statistics: Dictionary mapping statistic tracker class names to dictionaries mapping module
        names to floats or TensorStatistic models.
        :param actions_taken: Dictionary mapping agent IDs to actions taken by that agent. Agents missing from the
        dictionary are observed with an action of None.
        :return: Tuple of a dictionary mapping agent ID to that agent's completed observation vector and a boolean
        indicating whether an invalid observation was encountered while clipping.
        """

        global_obs = self.global_observers.observe(
            observation_inputs=observation_inputs,
            hyperparameter_states=hyperparameter_states,
            tracked_statistics=tracked_statistics,
            action_taken=None,
        )

        local_obs = {
            agent_id: agent_observers.observe(
                observation_inputs=observation_inputs,
                hyperparameter_states=hyperparameter_states,
                tracked_statistics=tracked_statistics,
                action_taken=actions_taken.get(agent_id),
            )
            for agent_id, agent_observers in self.local_observers.items()
        }

        for post_processor in self.post_processors:
            global_obs, local_obs = post_processor.postprocess(
                global_observations=global_obs, local_observations=local_obs
            )

        global_obs, local_obs, obs_clipped = self.clip_observations(
            globals_to_clip=global_obs, locals_to_clip=local_obs
        )

        return self.merge_globals_to_locals(global_obs=global_obs, local_obs=local_obs), obs_clipped

    def inform(self) -> tuple[ObservationInformation, dict[str, ObservationInformation]]:
        """
        :return: Tuple of the observation info reported by the global observers and a dictionary mapping agent IDs to
        the observation info reported by that agent's local observers.
        """

        global_information = self.global_observers.inform()

        agent_information = {agent_id: observer.inform() for agent_id, observer in self.local_observers.items()}

        return global_information, agent_information

    def get_observation_spaces(self) -> dict[str, gym.spaces.Box]:
        """
        :return: Dictionary mapping agent IDs to their observation spaces.
        """

        return {
            agent_id: gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.observation_size,), dtype=np.float32)
            for agent_id in self.local_observers
        }

    def reset(self) -> None:
        """
        Resets all global and local observers.
        """

        self.global_observers.reset()

        for local_observers in self.local_observers.values():
            local_observers.reset()

    def train(self) -> None:
        """
        Sets all observer containers into training mode.
        """

        self.global_observers.train()

        for local_observers in self.local_observers.values():
            local_observers.train()

    def infer(self) -> None:
        """
        Sets all observer containers into inference mode.
        """

        self.global_observers.infer()

        for local_observers in self.local_observers.values():
            local_observers.infer()
|
File without changes
|