humalab 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,375 @@
1
+ from typing import Any
2
+ from threading import RLock
3
+
4
+ import numpy as np
5
+ from omegaconf import OmegaConf, DictConfig, ListConfig
6
+ import yaml
7
+ from humalab.dists.bernoulli import Bernoulli
8
+ from humalab.dists.categorical import Categorical
9
+ from humalab.dists.uniform import Uniform
10
+ from humalab.dists.discrete import Discrete
11
+ from humalab.dists.log_uniform import LogUniform
12
+ from humalab.dists.gaussian import Gaussian
13
+ from humalab.dists.truncated_gaussian import TruncatedGaussian
14
+ from functools import partial
15
+ from humalab.constants import GraphType, MetricDimType
16
+ import copy
17
+ import uuid
18
+
19
# Scalar distribution families, keyed by their base resolver name
# (used in templates as e.g. ``${uniform:0,1}``).
_SCALAR_DISTRIBUTION_CLASSES = {
    "uniform": Uniform,
    "bernoulli": Bernoulli,
    "categorical": Categorical,
    "discrete": Discrete,
    "log_uniform": LogUniform,
    "gaussian": Gaussian,
    "truncated_gaussian": TruncatedGaussian,
}

# Resolver name -> distribution class.
# Every family has a ``_1d`` variant; only uniform, gaussian and
# truncated_gaussian also have ``_2d`` and ``_3d`` variants (no 4D or
# arbitrary-nD variants are registered yet). All variants of a family share
# the same class; the suffix determines the sample shape through
# DISTRIBUTION_DIMENSION_MAP.
DISTRIBUTION_MAP = {
    **_SCALAR_DISTRIBUTION_CLASSES,
    **{f"{family}_1d": cls for family, cls in _SCALAR_DISTRIBUTION_CLASSES.items()},
    **{
        f"{family}_{dim}d": _SCALAR_DISTRIBUTION_CLASSES[family]
        for dim in (2, 3)
        for family in ("uniform", "gaussian", "truncated_gaussian")
    },
}
75
+
76
# Sample dimensionality per resolver name: 0 yields a scalar sample, N >= 1 is
# passed by Scenario as the sample ``size`` (a vector of N values).
DISTRIBUTION_DIMENSION_MAP = {
    **dict.fromkeys(
        (
            "uniform", "bernoulli", "categorical", "discrete",
            "log_uniform", "gaussian", "truncated_gaussian",
        ),
        0,
    ),
    **dict.fromkeys(
        (
            "uniform_1d", "bernoulli_1d", "categorical_1d", "discrete_1d",
            "log_uniform_1d", "gaussian_1d", "truncated_gaussian_1d",
        ),
        1,
    ),
    **dict.fromkeys(("uniform_2d", "gaussian_2d", "truncated_gaussian_2d"), 2),
    **dict.fromkeys(("uniform_3d", "gaussian_3d", "truncated_gaussian_3d"), 3),
}
105
+
106
# Maximum number of positional resolver arguments per distribution family.
# Scenario drops any extra arguments with a warning. Every dimensional variant
# of a family takes the same number of arguments.
_FAMILY_PARAM_COUNTS = {
    "uniform": 2,
    "bernoulli": 1,
    "categorical": 2,
    "discrete": 3,
    "log_uniform": 2,
    "gaussian": 2,
    "truncated_gaussian": 4,
}

DISTRIBUTION_PARAM_NUM_MAP = {
    **_FAMILY_PARAM_COUNTS,
    **{f"{family}_1d": count for family, count in _FAMILY_PARAM_COUNTS.items()},
    **{
        f"{family}_{dim}d": _FAMILY_PARAM_COUNTS[family]
        for dim in (2, 3)
        for family in ("uniform", "gaussian", "truncated_gaussian")
    },
}
135
+
136
# Metric dimension type associated with each resolver name. Scalar and 1D
# variants are ONE_D; 2D and 3D variants are TWO_D and THREE_D respectively.
# NOTE(review): not referenced in this file; presumably consumed where
# scenario statistics are reported — confirm against callers.
SCENARIO_STATS_DIM_TYPE_MAP = {
    **dict.fromkeys(
        (
            "uniform", "bernoulli", "categorical", "discrete",
            "log_uniform", "gaussian", "truncated_gaussian",
            "uniform_1d", "bernoulli_1d", "categorical_1d", "discrete_1d",
            "log_uniform_1d", "gaussian_1d", "truncated_gaussian_1d",
        ),
        MetricDimType.ONE_D,
    ),
    **dict.fromkeys(
        ("uniform_2d", "gaussian_2d", "truncated_gaussian_2d"),
        MetricDimType.TWO_D,
    ),
    **dict.fromkeys(
        ("uniform_3d", "gaussian_3d", "truncated_gaussian_3d"),
        MetricDimType.THREE_D,
    ),
}
165
+
166
class Scenario:
    """Manages scenario configurations with probabilistic distributions.

    A Scenario holds a configuration template that may contain distribution
    resolvers (e.g. ``${uniform:0,1}``). Each call to :meth:`resolve` deep-copies
    the template and samples every distribution in the copy, producing one
    concrete episode configuration with freshly sampled values.

    Supported resolvers are the keys of ``DISTRIBUTION_MAP``: uniform, gaussian,
    bernoulli, categorical, discrete, log_uniform and truncated_gaussian, plus
    ``_1d`` variants of each and ``_2d``/``_3d`` variants of uniform, gaussian
    and truncated_gaussian.

    OmegaConf resolvers are process-global, so the distribution resolvers can
    be bound to only one Scenario at a time. :meth:`init` binds them to the
    instance being initialized. :meth:`resolve` re-binds them if another
    instance took them over in between, so each scenario samples with its own
    (possibly seeded) generator and records its own episode values.

    Attributes:
        template (DictConfig | ListConfig): The template scenario configuration.
        yaml (str): The template scenario configuration as a YAML string.
    """

    # Distribution objects keyed by the interpolation text that created them
    # (e.g. "${uniform:0,1}") and reused across resolves. Shared by all
    # instances. Cleared whenever the resolvers are re-bound, because each
    # entry was created with the owning instance's random generator.
    dist_cache = {}

    # Instance that the registered distribution resolvers are currently bound to.
    _resolver_owner = None
    # Serializes access to the process-global resolvers and to dist_cache.
    _resolver_lock = RLock()

    def __init__(self) -> None:
        self._generator = np.random.default_rng()
        self._scenario_template = OmegaConf.create()
        self._cur_scenario = OmegaConf.create()
        self._scenario_id = None
        self._scenario_version = None
        self._seed = None
        # Leading batch dimension for samples (see _get_final_size); always None for now.
        self._num_env = None

        self._episode_vals = {}
        self._lock = RLock()

    def init(self,
             scenario: str | list | dict | None = None,
             seed: int | None = None,
             scenario_id: str | None = None,
             ) -> None:
        """Initialize the scenario with the given parameters.

        This also binds the process-global distribution resolvers to this
        instance and clears ``dist_cache``.

        Args:
            scenario (str | list | dict | None): The scenario configuration. Can be a
                YAML string, list, or dict. If None (or empty), an empty
                configuration is used.
            seed (int | None): Optional seed for random number generation. If None,
                uses a non-deterministic seed.
            scenario_id (str | None): Optional scenario ID in the format 'id' or
                'id:version'. If None (or empty), a new UUID is generated.

        Raises:
            ValueError: If ``scenario_id`` contains more than one ':' or its
                version part is not an integer.
        """
        self._num_env = None  # batching over multiple envs is not configurable yet
        self._seed = seed

        # A generated id starts at version 1. An explicit id without a
        # ':version' suffix leaves the version unspecified (None).
        scenario_version = 1
        if scenario_id is not None:
            parts = scenario_id.split(":")
            if len(parts) > 2:
                raise ValueError("Invalid scenario_id format. Expected 'scenario_id' or 'scenario_name:version'.")
            scenario_id = parts[0]
            scenario_version = None
            if len(parts) == 2:
                try:
                    scenario_version = int(parts[1])
                except ValueError as err:
                    raise ValueError(
                        f"Invalid scenario version {parts[1]!r} in scenario_id; expected an integer."
                    ) from err
        self._scenario_id = scenario_id or str(uuid.uuid4())
        self._scenario_version = scenario_version

        self._generator = np.random.default_rng(seed)
        self._configure()

        self._scenario_template = OmegaConf.create(scenario or {})

    def _validate_distribution_params(self, dist_name: str, *args: Any) -> None:
        """Raise ValueError if ``args`` are not valid parameters for ``dist_name``.

        Validation is delegated to the distribution class's ``validate``,
        which receives the resolver's dimension from DISTRIBUTION_DIMENSION_MAP.
        """
        dimensions = DISTRIBUTION_DIMENSION_MAP[dist_name]
        if not DISTRIBUTION_MAP[dist_name].validate(dimensions, *args):
            raise ValueError(f"Invalid parameters for distribution {dist_name} with dimensions {dimensions}: {args}")

    def _get_final_size(self, size: int | tuple[int, ...] | None) -> int | tuple[int, ...] | None:
        """Prepend the ``_num_env`` batch dimension to a sample size when it is set.

        Returns ``size`` unchanged when ``_num_env`` is None, and just
        ``_num_env`` when ``size`` is None.
        """
        n = self._num_env
        if size is None:
            return n
        if n is None:
            return size
        if isinstance(size, int):
            return (n, size)
        return (n, *size)

    def _get_node_path(self, root: dict | list, node: str) -> str:
        """Return the path of the first entry in ``root`` equal to ``node``.

        Dict keys are joined with '.', and list items inside dicts are written
        as ``key[i]``. A top-level list root uses bare indices ("0", "1", ...).
        Returns "" when ``node`` is not found.
        """
        if isinstance(root, list):
            root = {str(i): v for i, v in enumerate(root)}

        for key, value in root.items():
            if value == node:
                return str(key)
            if isinstance(value, dict):
                sub_path = self._get_node_path(value, node)
                if sub_path:
                    return f"{key}.{sub_path}"
            elif isinstance(value, list):
                for idx, item in enumerate(value):
                    if item == node:
                        return f"{key}[{idx}]"
                    if isinstance(item, (dict, list)):
                        sub_path = self._get_node_path(item, node)
                        if sub_path:
                            return f"{key}[{idx}].{sub_path}"
        return ""

    @staticmethod
    def _convert_to_python(obj) -> Any:
        """Convert NumPy scalars/arrays to built-in Python values; pass others through.

        NumPy scalars and 0-D arrays become Python scalars via ``item()``.
        Arrays with one or more dimensions become (nested) lists via ``tolist()``.
        """
        if isinstance(obj, np.ndarray) and obj.ndim > 0:
            return obj.tolist()
        if isinstance(obj, (np.generic, np.ndarray)):
            return obj.item()
        return obj

    def _configure(self) -> None:
        """Register one OmegaConf resolver per DISTRIBUTION_MAP entry, bound to this instance.

        Resolvers are registered with ``replace=True`` rather than by clearing
        the global registry first. This leaves resolvers registered by other
        code (OmegaConf's ``oc.*`` built-ins, Hydra, user code) intact.
        ``dist_cache`` is cleared because its entries belong to the previous owner.
        """
        def distribution_resolver(dist_name: str, *args, _node_, _root_, _parent_, **kwargs):
            # Handles ``${<dist_name>:arg1,arg2,...}``: samples the distribution and
            # records the value under the node's path in self._episode_vals.
            expected = DISTRIBUTION_PARAM_NUM_MAP[dist_name]
            if len(args) > expected:
                # Report the number of arguments actually received, before truncating.
                print(f"Warning: Too many parameters for {dist_name}, expected {expected}, got {len(args)}. Extra parameters will be ignored.")
                args = args[:expected]

            self._validate_distribution_params(dist_name, *args)

            # Find this node's path by searching a plain-Python copy of the root
            # for its interpolation text (the first match wins).
            root_yaml = yaml.safe_load(OmegaConf.to_yaml(_root_))
            key_path = self._get_node_path(root_yaml, str(_node_))

            dims = DISTRIBUTION_DIMENSION_MAP[dist_name]
            if dims == -1:
                # Shape-from-argument variant: the last expected argument is the shape.
                shape = args[expected - 1]
                args = args[:-1]
            else:
                shape = dims if dims > 0 else None
            shape = self._get_final_size(shape)

            key = str(_node_)
            if key not in Scenario.dist_cache:
                Scenario.dist_cache[key] = DISTRIBUTION_MAP[dist_name].create(
                    self._generator, *args, size=shape, **kwargs)
            ret_val = Scenario._convert_to_python(Scenario.dist_cache[key].sample())

            # Return list samples as a ListConfig (presumably so the resolved node
            # is a proper OmegaConf container).
            if isinstance(ret_val, list):
                ret_val = ListConfig(ret_val)

            self._episode_vals[key_path] = {
                "value": ret_val,
                "distribution": dist_name,
            }
            return ret_val

        with Scenario._resolver_lock:
            Scenario.dist_cache.clear()
            for dist_name in DISTRIBUTION_MAP:
                OmegaConf.register_new_resolver(
                    dist_name, partial(distribution_resolver, dist_name), replace=True)
            Scenario._resolver_owner = self

    def _clear_resolvers(self) -> None:
        """Drop cached distributions and clear ALL OmegaConf resolvers globally.

        ``OmegaConf.clear_resolvers()`` also removes resolvers registered by
        other code, so ``_configure`` does not call this.
        """
        with Scenario._resolver_lock:
            Scenario.dist_cache.clear()
            OmegaConf.clear_resolvers()
            # No instance's distribution resolvers are registered any more.
            Scenario._resolver_owner = None

    def resolve(self) -> tuple[DictConfig | ListConfig, dict]:
        """Resolve the scenario configuration, sampling all distributions.

        The template is not modified; a deep copy of it is resolved.

        Returns:
            tuple[DictConfig | ListConfig, dict]: The resolved scenario, and a
            dict mapping each sampled node's path to
            ``{"value": ..., "distribution": <resolver name>}``.
        """
        # The resolvers are process-global. Hold the class lock for the whole
        # resolve so no other instance can re-bind them midway. Re-bind them to
        # this instance if another Scenario registered its own after our init().
        with self._lock, Scenario._resolver_lock:
            if Scenario._resolver_owner is not self:
                self._configure()
            cur_scenario = copy.deepcopy(self._scenario_template)
            self._episode_vals = {}
            OmegaConf.resolve(cur_scenario)
            episode_vals = copy.deepcopy(self._episode_vals)
        return cur_scenario, episode_vals

    @property
    def scenario_id(self) -> str | None:
        """The scenario ID.

        Returns:
            str | None: The scenario ID, or None if not set.
        """
        return self._scenario_id

    @property
    def seed(self) -> int | None:
        """The random seed for the scenario.

        Returns:
            int | None: The random seed, or None if not set.
        """
        return self._seed

    @property
    def template(self) -> Any:
        """The template scenario configuration.

        Returns:
            Any: The template scenario as an OmegaConf object.
        """
        return self._scenario_template

    @property
    def yaml(self) -> str:
        """The template scenario configuration as a YAML string.

        Interpolations such as ``${uniform:0,1}`` are left unresolved.

        Returns:
            str: The template scenario as a YAML string.
        """
        return OmegaConf.to_yaml(self._scenario_template)
@@ -0,0 +1,114 @@
1
+ """Operations for managing and retrieving scenarios."""
2
+
3
+ from typing import Optional
4
+ from dataclasses import dataclass
5
+
6
+ from humalab.humalab_api_client import HumaLabApiClient
7
+ from humalab.scenarios.scenario import Scenario
8
+ from humalab.constants import DEFAULT_PROJECT
9
+
10
+ @dataclass
11
+ class ScenarioMetadata:
12
+ """Metadata for a scenario stored in HumaLab.
13
+
14
+ Attributes:
15
+ id (str): Unique identifier for the scenario.
16
+ version (int): Version number of the scenario.
17
+ project (str): Project name the scenario belongs to.
18
+ name (str): Human-readable scenario name.
19
+ description (str | None): Optional scenario description.
20
+ created_at (str): ISO timestamp when scenario was created.
21
+ updated_at (str): ISO timestamp when scenario was last updated.
22
+ """
23
+ id: str
24
+ version: int
25
+ project: str
26
+ name: str
27
+ description: str | None
28
+ created_at: str
29
+ updated_at: str
30
+
31
+
32
def list_scenarios(project: str = DEFAULT_PROJECT,
                   limit: int = 20,
                   offset: int = 0,
                   include_inactive: bool = False,
                   search: Optional[str] = None,
                   status_filter: Optional[str] = None,

                   base_url: str | None = None,
                   api_key: str | None = None,
                   timeout: float | None = None,
                   ) -> list[ScenarioMetadata]:
    """
    List all scenarios for a given project.

    Args:
        project (str): The project name to list scenarios from. Defaults to DEFAULT_PROJECT.
        limit (int): Maximum number of scenarios to return. Defaults to 20.
        offset (int): Number of scenarios to skip for pagination. Defaults to 0.
        include_inactive (bool): Whether to include inactive scenarios. Defaults to False.
        search (Optional[str]): Search query to filter scenarios by name or description. Defaults to None.
        status_filter (Optional[str]): Filter scenarios by status. Defaults to None.
        base_url (str | None): The base URL of the HumaLab API. If None, uses configured value.
        api_key (str | None): The API key for authentication. If None, uses configured value.
        timeout (float | None): The timeout for API requests in seconds. If None, uses configured value.

    Returns:
        list[ScenarioMetadata]: A list of scenario metadata objects. The list is
        empty when the response has no (or a null) "scenarios" field.
    """
    api_client = HumaLabApiClient(base_url=base_url,
                                  api_key=api_key,
                                  timeout=timeout)
    resp = api_client.get_scenarios(project_name=project,
                                    limit=limit,
                                    offset=offset,
                                    include_inactive=include_inactive,
                                    search=search,
                                    status_filter=status_filter)
    # Build the metadata without mutating the response dicts returned by the client.
    # `or []` also covers an explicit null "scenarios" field.
    return [
        ScenarioMetadata(id=scenario["uuid"],
                         version=scenario["version"],
                         project=project,
                         name=scenario["name"],
                         description=scenario.get("description"),
                         created_at=scenario.get("created_at"),
                         updated_at=scenario.get("updated_at"))
        for scenario in (resp.get("scenarios") or [])
    ]
80
+
81
def get_scenario(scenario_id: str,
                 version: int | None = None,
                 project: str = DEFAULT_PROJECT,
                 seed: int | None = None,

                 base_url: str | None = None,
                 api_key: str | None = None,
                 timeout: float | None = None,) -> Scenario:
    """Fetch a scenario's YAML from HumaLab and return it as an initialized Scenario.

    Args:
        scenario_id (str): The unique identifier of the scenario.
        version (int | None): Specific version to retrieve; None lets the API choose.
        project (str): The project name. Defaults to DEFAULT_PROJECT.
        seed (int | None): Optional seed for the scenario's random sampling.
        base_url (str | None): Optional API host override.
        api_key (str | None): Optional API key override.
        timeout (float | None): Optional request timeout override.

    Returns:
        Scenario: A Scenario initialized from the response's ``yaml_content``.
    """
    client = HumaLabApiClient(base_url=base_url, api_key=api_key, timeout=timeout)
    response = client.get_scenario(project_name=project, uuid=scenario_id, version=version)

    # Scenario.init accepts "id" or "id:version"; add the suffix only when a
    # version was explicitly requested.
    qualified_id = scenario_id if version is None else f"{scenario_id}:{version}"

    scenario = Scenario()
    scenario.init(scenario=response["yaml_content"], seed=seed, scenario_id=qualified_id)
    return scenario