humalab 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of humalab might be problematic. Click here for more details.
- humalab/__init__.py +11 -0
- humalab/assets/__init__.py +2 -2
- humalab/assets/files/resource_file.py +29 -3
- humalab/assets/files/urdf_file.py +14 -10
- humalab/assets/resource_operator.py +91 -0
- humalab/constants.py +39 -5
- humalab/dists/bernoulli.py +16 -0
- humalab/dists/categorical.py +4 -0
- humalab/dists/discrete.py +22 -0
- humalab/dists/gaussian.py +22 -0
- humalab/dists/log_uniform.py +22 -0
- humalab/dists/truncated_gaussian.py +36 -0
- humalab/dists/uniform.py +22 -0
- humalab/episode.py +196 -0
- humalab/humalab.py +116 -153
- humalab/humalab_api_client.py +760 -62
- humalab/humalab_config.py +0 -13
- humalab/humalab_test.py +46 -29
- humalab/metrics/__init__.py +5 -5
- humalab/metrics/code.py +28 -0
- humalab/metrics/metric.py +41 -108
- humalab/metrics/scenario_stats.py +95 -0
- humalab/metrics/summary.py +24 -18
- humalab/run.py +180 -115
- humalab/scenarios/__init__.py +4 -0
- humalab/scenarios/scenario.py +372 -0
- humalab/scenarios/scenario_operator.py +82 -0
- humalab/{scenario_test.py → scenarios/scenario_test.py} +150 -269
- humalab/utils.py +37 -0
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/METADATA +1 -1
- humalab-0.0.6.dist-info/RECORD +39 -0
- humalab/assets/resource_manager.py +0 -57
- humalab/metrics/dist_metric.py +0 -22
- humalab/scenario.py +0 -225
- humalab-0.0.4.dist-info/RECORD +0 -34
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/WHEEL +0 -0
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/entry_points.txt +0 -0
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,372 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from threading import RLock
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
from omegaconf import OmegaConf, DictConfig, ListConfig
|
|
6
|
+
import yaml
|
|
7
|
+
from humalab.dists.bernoulli import Bernoulli
|
|
8
|
+
from humalab.dists.categorical import Categorical
|
|
9
|
+
from humalab.dists.uniform import Uniform
|
|
10
|
+
from humalab.dists.discrete import Discrete
|
|
11
|
+
from humalab.dists.log_uniform import LogUniform
|
|
12
|
+
from humalab.dists.gaussian import Gaussian
|
|
13
|
+
from humalab.dists.truncated_gaussian import TruncatedGaussian
|
|
14
|
+
from functools import partial
|
|
15
|
+
from humalab.constants import GraphType, MetricDimType
|
|
16
|
+
import copy
|
|
17
|
+
import uuid
|
|
18
|
+
|
|
19
|
+
# Maps each resolver name usable in a scenario template to the distribution
# class that implements it. The "_<k>d" suffix selects the dimensionality of
# the sampled value; the same class serves every dimensionality.
#
# Not registered (disabled in this release):
#   * bernoulli / categorical / discrete / log_uniform in 2D and 3D
#   * every 4D and nD variant
DISTRIBUTION_MAP = {
    # 0D distributions
    "uniform": Uniform,
    "bernoulli": Bernoulli,
    "categorical": Categorical,
    "discrete": Discrete,
    "log_uniform": LogUniform,
    "gaussian": Gaussian,
    "truncated_gaussian": TruncatedGaussian,

    # 1D distributions
    "uniform_1d": Uniform,
    "bernoulli_1d": Bernoulli,
    "categorical_1d": Categorical,
    "discrete_1d": Discrete,
    "log_uniform_1d": LogUniform,
    "gaussian_1d": Gaussian,
    "truncated_gaussian_1d": TruncatedGaussian,

    # 2D distributions
    "uniform_2d": Uniform,
    "gaussian_2d": Gaussian,
    "truncated_gaussian_2d": TruncatedGaussian,

    # 3D distributions
    "uniform_3d": Uniform,
    "gaussian_3d": Gaussian,
    "truncated_gaussian_3d": TruncatedGaussian,
}
|
|
75
|
+
|
|
76
|
+
# Dimensionality associated with each resolver name. 0 denotes a scalar
# sample; a positive value k is used as the sample size for that resolver.
DISTRIBUTION_DIMENSION_MAP = {
    # 0D distributions
    "uniform": 0,
    "bernoulli": 0,
    "categorical": 0,
    "discrete": 0,
    "log_uniform": 0,
    "gaussian": 0,
    "truncated_gaussian": 0,

    # 1D distributions
    "uniform_1d": 1,
    "bernoulli_1d": 1,
    "categorical_1d": 1,
    "discrete_1d": 1,
    "log_uniform_1d": 1,
    "gaussian_1d": 1,
    "truncated_gaussian_1d": 1,

    # 2D distributions
    "uniform_2d": 2,
    "gaussian_2d": 2,
    "truncated_gaussian_2d": 2,

    # 3D distributions
    "uniform_3d": 3,
    "gaussian_3d": 3,
    "truncated_gaussian_3d": 3,
}
|
|
105
|
+
|
|
106
|
+
# Maximum number of positional parameters each resolver accepts. Surplus
# arguments passed to a resolver are dropped (with a warning) before the
# distribution is validated and created.
DISTRIBUTION_PARAM_NUM_MAP = {
    # 0D distributions
    "uniform": 2,
    "bernoulli": 1,
    "categorical": 2,
    "discrete": 3,
    "log_uniform": 2,
    "gaussian": 2,
    "truncated_gaussian": 4,

    # 1D distributions
    "uniform_1d": 2,
    "bernoulli_1d": 1,
    "categorical_1d": 2,
    "discrete_1d": 3,
    "log_uniform_1d": 2,
    "gaussian_1d": 2,
    "truncated_gaussian_1d": 4,

    # 2D distributions
    "uniform_2d": 2,
    "gaussian_2d": 2,
    "truncated_gaussian_2d": 4,

    # 3D distributions
    "uniform_3d": 2,
    "gaussian_3d": 2,
    "truncated_gaussian_3d": 4,
}
|
|
135
|
+
|
|
136
|
+
# Metric dimension type recorded alongside each sampled value. Scalar (0D)
# and 1D resolvers are both reported as ONE_D.
SCENARIO_STATS_DIM_TYPE_MAP = {
    # 0D distributions
    "uniform": MetricDimType.ONE_D,
    "bernoulli": MetricDimType.ONE_D,
    "categorical": MetricDimType.ONE_D,
    "discrete": MetricDimType.ONE_D,
    "log_uniform": MetricDimType.ONE_D,
    "gaussian": MetricDimType.ONE_D,
    "truncated_gaussian": MetricDimType.ONE_D,

    # 1D distributions
    "uniform_1d": MetricDimType.ONE_D,
    "bernoulli_1d": MetricDimType.ONE_D,
    "categorical_1d": MetricDimType.ONE_D,
    "discrete_1d": MetricDimType.ONE_D,
    "log_uniform_1d": MetricDimType.ONE_D,
    "gaussian_1d": MetricDimType.ONE_D,
    "truncated_gaussian_1d": MetricDimType.ONE_D,

    # 2D distributions
    "uniform_2d": MetricDimType.TWO_D,
    "gaussian_2d": MetricDimType.TWO_D,
    "truncated_gaussian_2d": MetricDimType.TWO_D,

    # 3D distributions
    "uniform_3d": MetricDimType.THREE_D,
    "gaussian_3d": MetricDimType.THREE_D,
    "truncated_gaussian_3d": MetricDimType.THREE_D,
}
|
|
165
|
+
|
|
166
|
+
# Graph type recorded alongside each sampled value, keyed by resolver name.
DISTRIBUTION_GRAPH_TYPE = {
    # 0D distributions
    "uniform": GraphType.HISTOGRAM,
    "bernoulli": GraphType.BAR,
    "categorical": GraphType.BAR,
    "discrete": GraphType.BAR,
    "log_uniform": GraphType.HISTOGRAM,
    "gaussian": GraphType.GAUSSIAN,
    "truncated_gaussian": GraphType.GAUSSIAN,

    # 1D distributions
    "uniform_1d": GraphType.HISTOGRAM,
    "bernoulli_1d": GraphType.BAR,
    "categorical_1d": GraphType.BAR,
    "discrete_1d": GraphType.BAR,
    "log_uniform_1d": GraphType.HISTOGRAM,
    "gaussian_1d": GraphType.GAUSSIAN,
    "truncated_gaussian_1d": GraphType.GAUSSIAN,

    # 2D distributions
    "uniform_2d": GraphType.SCATTER,
    "gaussian_2d": GraphType.HEATMAP,
    "truncated_gaussian_2d": GraphType.HEATMAP,

    # 3D distributions
    "uniform_3d": GraphType.THREE_D_MAP,
    "gaussian_3d": GraphType.THREE_D_MAP,
    "truncated_gaussian_3d": GraphType.THREE_D_MAP,
}
|
|
195
|
+
|
|
196
|
+
class Scenario:
    """A scenario template whose distribution interpolations are sampled on resolve.

    The template is an OmegaConf config. Each name in ``DISTRIBUTION_MAP`` is
    registered as an OmegaConf resolver, so an interpolation that uses one of
    those names is replaced by a fresh sample every time :meth:`resolve` runs.

    NOTE(review): resolvers are registered on ``OmegaConf`` itself and
    ``dist_cache`` is a class attribute, and :meth:`init` clears both. The
    registered resolvers close over the instance that called :meth:`init`
    last, so using several ``Scenario`` instances side by side looks unsafe --
    confirm against callers before relying on it.
    """

    # Distribution objects keyed by the string form of the interpolation node.
    # Class-level: shared by every Scenario instance.
    dist_cache = {}

    def __init__(self) -> None:
        """Create an empty scenario; call :meth:`init` to load a template."""
        self._generator = np.random.default_rng()
        self._scenario_template = OmegaConf.create()
        self._cur_scenario = OmegaConf.create()
        self._scenario_id = None
        # Defined up front so the attributes exist even before init() runs.
        self._scenario_version = None
        self._num_env = None

        self._episode_vals = {}
        self._lock = RLock()

    def init(self,
             scenario: str | list | dict | None = None,
             seed: int | None = None,
             scenario_id: str | None = None,
             # num_env: int | None = None
             ) -> None:
        """
        Initialize the scenario with the given parameters.

        Args:
            scenario: The scenario configuration (YAML string, list, or dict).
                A falsy value yields an empty template.
            seed: Optional seed for random number generation.
            scenario_id: Optional scenario ID, either ``"<id>"`` or
                ``"<id>:<version>"``. If None (or the id part is empty), a new
                UUID is generated.
            # num_env: Optional number of parallel environments.

        Raises:
            ValueError: If ``scenario_id`` carries a version suffix that is
                not an integer.
        """
        self._num_env = None  # num_env

        # Parse scenario id.
        # No scenario_id at all     -> version 1.
        # scenario_id without ":n"  -> version None.
        # NOTE(review): the two defaults differ; confirm this is intended.
        scenario_version = 1
        if scenario_id is not None:
            # str.split always yields at least one element, so only the
            # optional version part needs checking.
            parts = scenario_id.split(":")
            scenario_id = parts[0]
            if len(parts) > 1:
                try:
                    scenario_version = int(parts[1])
                except ValueError as err:
                    raise ValueError("Invalid scenario_id format. Expected 'scenario_id' or 'scenario_name:version'.") from err
            else:
                scenario_version = None
        self._scenario_id = scenario_id or str(uuid.uuid4())
        self._scenario_version = scenario_version

        self._generator = np.random.default_rng(seed)
        self._configure()
        scenario = scenario or {}

        self._scenario_template = OmegaConf.create(scenario)

    def _validate_distribution_params(self, dist_name: str, *args: Any) -> None:
        """Raise ValueError if ``args`` are rejected by the distribution's ``validate``.

        Args:
            dist_name: Resolver name; must be a key of ``DISTRIBUTION_MAP``.
            *args: Positional parameters given to the resolver.
        """
        dimensions = DISTRIBUTION_DIMENSION_MAP[dist_name]
        if not DISTRIBUTION_MAP[dist_name].validate(dimensions, *args):
            raise ValueError(f"Invalid parameters for distribution {dist_name} with dimensions {dimensions}: {args}")

    def _get_final_size(self, size: int | tuple[int, ...] | None) -> int | tuple[int, ...] | None:
        """Combine the per-sample ``size`` with the environment count.

        Returns ``self._num_env`` when ``size`` is None, ``size`` when
        ``self._num_env`` is None, and otherwise ``size`` with the environment
        count prepended as the leading dimension.
        """
        n = self._num_env
        if size is None:
            return n
        if n is None:
            return size
        if isinstance(size, int):
            return (n, size)
        return (n, *size)

    def _get_node_path(self, root: dict | list, node: str) -> str:
        """Return the path of the first value in ``root`` that equals ``node``.

        Dict keys are joined with ``.`` and list positions are written as
        ``key[idx]``. A list passed as ``root`` is treated as a dict keyed by
        the stringified index. Traversal is depth-first in insertion order.

        Returns:
            The path string, or ``""`` when no value matches.
        """
        if isinstance(root, list):
            root = {str(i): v for i, v in enumerate(root)}

        for key, value in root.items():
            if value == node:
                return str(key)
            if isinstance(value, dict):
                sub_path = self._get_node_path(value, node)
                if sub_path:
                    return f"{key}.{sub_path}"
            elif isinstance(value, list):
                for idx, item in enumerate(value):
                    if item == node:
                        return f"{key}[{idx}]"
                    if isinstance(item, (dict, list)):
                        sub_path = self._get_node_path(item, node)
                        if sub_path:
                            return f"{key}[{idx}].{sub_path}"
        return ""

    @staticmethod
    def _convert_to_python(obj) -> Any:
        """Convert NumPy scalars/arrays to built-in Python values.

        NumPy scalars and 0-D arrays become Python scalars, other arrays
        become (nested) lists, and anything else is returned unchanged.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.item() if obj.ndim == 0 else obj.tolist()
        return obj

    def _configure(self) -> None:
        """Reset OmegaConf resolvers and register one per supported distribution."""
        self._clear_resolvers()

        def distribution_resolver(dist_name: str, *args, _node_, _root_, _parent_, **kwargs):
            max_params = DISTRIBUTION_PARAM_NUM_MAP[dist_name]
            if len(args) > max_params:
                # Capture the count before truncating so the warning reports
                # what was actually supplied.
                received = len(args)
                args = args[:max_params]
                print(f"Warning: Too many parameters for {dist_name}, expected {max_params}, got {received}. Extra parameters will be ignored.")

            self._validate_distribution_params(dist_name, *args)

            # Locate the node by searching a plain-Python copy of the root for
            # a value equal to the node's string form.
            root_yaml = yaml.safe_load(OmegaConf.to_yaml(_root_))
            key_path = self._get_node_path(root_yaml, str(_node_))

            dimension = DISTRIBUTION_DIMENSION_MAP[dist_name]
            if dimension == -1:
                # Dimension -1: the last expected parameter carries the shape.
                shape = args[max_params - 1]
                args = args[:-1]
            else:
                shape = dimension if dimension > 0 else None
            shape = self._get_final_size(shape)

            # Reuse one distribution object per distinct interpolation string.
            key = str(_node_)
            if key not in Scenario.dist_cache:
                Scenario.dist_cache[key] = DISTRIBUTION_MAP[dist_name].create(self._generator, *args, size=shape, **kwargs)
            ret_val = Scenario.dist_cache[key].sample()
            ret_val = Scenario._convert_to_python(ret_val)

            if isinstance(ret_val, list):
                ret_val = ListConfig(ret_val)

            self._episode_vals[key_path] = {
                "value": ret_val,
                "distribution": dist_name,
                "graph_type": DISTRIBUTION_GRAPH_TYPE[dist_name],
                "metric_dim_type": SCENARIO_STATS_DIM_TYPE_MAP[dist_name],
            }
            return ret_val

        for dist_name in DISTRIBUTION_MAP:
            OmegaConf.register_new_resolver(dist_name, partial(distribution_resolver, dist_name))

    def _clear_resolvers(self) -> None:
        """Empty the shared distribution cache and clear OmegaConf's resolvers."""
        self.dist_cache.clear()
        OmegaConf.clear_resolvers()

    def resolve(self) -> tuple[DictConfig | ListConfig, dict]:
        """Resolve the scenario configuration, sampling all distributions.

        Works on a deep copy, so the template itself is left untouched.

        Returns:
            tuple[DictConfig | ListConfig, dict]: The resolved scenario and a
            copy of the values recorded by the distribution resolvers during
            this resolution, keyed by node path.
        """
        with self._lock:
            cur_scenario = copy.deepcopy(self._scenario_template)
            # Start from a clean slate; the resolvers fill this in as
            # OmegaConf.resolve() walks the config.
            self._episode_vals = {}
            OmegaConf.resolve(cur_scenario)
            episode_vals = copy.deepcopy(self._episode_vals)
            return cur_scenario, episode_vals

    @property
    def template(self) -> Any:
        """The template scenario configuration.

        Returns:
            Any: The template scenario as an OmegaConf object.
        """
        return self._scenario_template

    @property
    def yaml(self) -> str:
        """The scenario template as a YAML string.

        Returns:
            str: The YAML dump of the template, not of a resolved scenario.
        """
        return OmegaConf.to_yaml(self._scenario_template)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
|
|
4
|
+
from humalab.humalab_api_client import HumaLabApiClient
|
|
5
|
+
from humalab.scenarios.scenario import Scenario
|
|
6
|
+
from humalab.constants import DEFAULT_PROJECT
|
|
7
|
+
|
|
8
|
+
@dataclass
|
|
9
|
+
class ScenarioMetadata:
|
|
10
|
+
id: str
|
|
11
|
+
version: int
|
|
12
|
+
project: str
|
|
13
|
+
name: str
|
|
14
|
+
description: str | None
|
|
15
|
+
created_at: str
|
|
16
|
+
updated_at: str
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def list_scenarios(project: str = DEFAULT_PROJECT,
                   limit: int = 20,
                   offset: int = 0,
                   include_inactive: bool = False,
                   search: Optional[str] = None,
                   status_filter: Optional[str] = None,

                   base_url: str | None = None,
                   api_key: str | None = None,
                   timeout: float | None = None,
                   ) -> list[ScenarioMetadata]:
    """
    List scenarios for a given project.

    Args:
        project: The project name to list scenarios from.
        limit: Forwarded to the API client as ``limit``.
        offset: Forwarded to the API client as ``offset``.
        include_inactive: Forwarded to the API client as ``include_inactive``.
        search: Forwarded to the API client as ``search``.
        status_filter: Forwarded to the API client as ``status_filter``.
        base_url: The base URL of the HumaLab API.
        api_key: The API key for authentication.
        timeout: The timeout for API requests.

    Returns:
        A list of ScenarioMetadata objects, one per entry under the
        response's "scenarios" key (empty if that key is absent).

    Raises:
        KeyError: If a scenario entry lacks "uuid", "version" or "name".
    """
    api_client = HumaLabApiClient(base_url=base_url,
                                  api_key=api_key,
                                  timeout=timeout)
    resp = api_client.get_scenarios(project_name=project,
                                    limit=limit,
                                    offset=offset,
                                    include_inactive=include_inactive,
                                    search=search,
                                    status_filter=status_filter)
    # The project is supplied by the caller, so it is passed straight to the
    # metadata object rather than written back into the response dict.
    return [
        ScenarioMetadata(id=scenario["uuid"],
                         version=scenario["version"],
                         project=project,
                         name=scenario["name"],
                         description=scenario.get("description"),
                         created_at=scenario.get("created_at"),
                         updated_at=scenario.get("updated_at"))
        for scenario in resp.get("scenarios", [])
    ]
|
|
62
|
+
|
|
63
|
+
def get_scenario(scenario_id: str,
                 version: int | None = None,
                 project: str = DEFAULT_PROJECT,
                 seed: int | None=None,

                 base_url: str | None = None,
                 api_key: str | None = None,
                 timeout: float | None = None,) -> Scenario:
    """
    Fetch a scenario from the HumaLab API and build an initialized Scenario.

    Args:
        scenario_id: Identifier sent to the API as ``uuid``.
        version: Optional version to fetch; also appended to the id given to
            the Scenario as ``"<scenario_id>:<version>"`` when not None.
        project: The project name the scenario belongs to.
        seed: Optional seed passed to ``Scenario.init``.
        base_url: The base URL of the HumaLab API.
        api_key: The API key for authentication.
        timeout: The timeout for API requests.

    Returns:
        A Scenario initialized from the response's "yaml_content".
    """
    client = HumaLabApiClient(base_url=base_url,
                              api_key=api_key,
                              timeout=timeout)
    response = client.get_scenario(project_name=project,
                                   uuid=scenario_id,
                                   version=version)

    # Scenario.init expects "<id>" or "<id>:<version>".
    if version is not None:
        qualified_id = f"{scenario_id}:{version}"
    else:
        qualified_id = scenario_id

    result = Scenario()
    result.init(scenario=response["yaml_content"],
                seed=seed,
                scenario_id=qualified_id)
    return result
|