humalab 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
humalab/episode.py ADDED
@@ -0,0 +1,306 @@
1
+ from humalab.constants import RESERVED_NAMES, ArtifactType
2
+ from humalab.humalab_api_client import HumaLabApiClient, EpisodeStatus
3
+ from humalab.metrics.code import Code
4
+ from humalab.metrics.summary import Summary
5
+ from humalab.metrics.metric import Metrics
6
+ from omegaconf import DictConfig, ListConfig, OmegaConf
7
+ from typing import Any
8
+ import pickle
9
+ import traceback
10
+
11
+ from humalab.utils import is_standard_type
12
+
13
+
14
class Episode:
    """Represents a single episode within a run.

    An Episode is a context manager that tracks a single execution instance of a
    scenario. It provides access to scenario configuration values, supports metric
    logging, and manages the episode lifecycle with various completion statuses.

    Episodes can be finished with different statuses:
        - SUCCESS: Episode completed successfully
        - FAILED: Episode failed
        - CANCELED: Episode was discarded/canceled
        - ERRORED: Episode encountered an error

    Use as a context manager to automatically handle the episode lifecycle::

        with episode:
            # Your code here
            pass

    Leaving the ``with`` block normally finishes the episode with SUCCESS.
    Leaving it through an exception finishes it with ERRORED and the formatted
    traceback as the error message. If the episode was already finished
    explicitly, exiting the block does nothing.

    Attributes:
        run_id (str): The unique identifier of the parent run.
        episode_id (str): The unique identifier for this episode.
        scenario (DictConfig | ListConfig): The resolved scenario configuration.
        status (EpisodeStatus): The current status of the episode.
        episode_vals (dict): The sampled values from scenario distributions.
        is_finished (bool): Whether the episode has been finalized.
    """

    def __init__(self,
                 run_id: str,
                 episode_id: str,
                 scenario_conf: DictConfig | ListConfig,
                 episode_vals: dict | None = None,

                 base_url: str | None = None,
                 api_key: str | None = None,
                 timeout: float | None = None,):
        """Create an episode bound to a run and its resolved scenario.

        Args:
            run_id (str): Identifier of the parent run.
            episode_id (str): Identifier of this episode.
            scenario_conf (DictConfig | ListConfig): Resolved scenario configuration.
            episode_vals (dict | None): Sampled scenario values; ``None`` (or any
                falsy value) is stored as an empty dict.
            base_url (str | None): Forwarded to ``HumaLabApiClient``.
            api_key (str | None): Forwarded to ``HumaLabApiClient``.
            timeout (float | None): Forwarded to ``HumaLabApiClient``.
        """
        self._run_id = run_id
        self._episode_id = episode_id
        self._episode_status = EpisodeStatus.RUNNING
        self._scenario_conf = scenario_conf
        # Everything logged for this episode, keyed by name: Summary/Metrics
        # instances, Code artifacts, or plain Python values.
        self._logs = {}
        self._episode_vals = episode_vals or {}
        self._is_finished = False

        self._api_client = HumaLabApiClient(base_url=base_url,
                                            api_key=api_key,
                                            timeout=timeout)

    @property
    def run_id(self) -> str:
        """The unique identifier of the parent run.

        Returns:
            str: The run ID.
        """
        return self._run_id

    @property
    def episode_id(self) -> str:
        """The unique identifier for this episode.

        Returns:
            str: The episode ID.
        """
        return self._episode_id

    @property
    def scenario(self) -> DictConfig | ListConfig:
        """The resolved scenario configuration for this episode.

        Returns:
            DictConfig | ListConfig: The scenario configuration.
        """
        return self._scenario_conf

    @property
    def status(self) -> EpisodeStatus:
        """The current status of the episode.

        Returns:
            EpisodeStatus: The episode status.
        """
        return self._episode_status

    @property
    def episode_vals(self) -> dict:
        """The sampled values from scenario distributions.

        Returns:
            dict: Dictionary mapping scenario variable names to their sampled values.
        """
        return self._episode_vals

    @property
    def is_finished(self) -> bool:
        """Whether the episode has been finalized.

        Returns:
            bool: True if the episode is finished, False otherwise.
        """
        return self._is_finished

    def __enter__(self):
        """Enter the episode context."""
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """Exit the episode context and finalize the episode.

        Finishes with ERRORED (plus the formatted traceback) if the block raised,
        otherwise with SUCCESS. Returns None, so exceptions still propagate.
        """
        if self._is_finished:
            return
        if exception_type is not None:
            err_msg = "".join(traceback.format_exception(exception_type, exception_value, exception_traceback))
            self.finish(status=EpisodeStatus.ERRORED, err_msg=err_msg)
        else:
            self.finish(status=EpisodeStatus.SUCCESS)

    def __getattr__(self, name: str) -> Any:
        """Access scenario configuration values as attributes.

        Allows accessing scenario configuration using dot notation (e.g., episode.my_param).
        Only called when normal attribute lookup fails.

        Args:
            name (str): The attribute/key name from scenario configuration.

        Returns:
            Any: The value from scenario configuration.

        Raises:
            AttributeError: If the attribute is not in scenario configuration.
        """
        # Read the config through __dict__: on an instance whose __init__ has not
        # run (e.g. copy/pickle probing for __setstate__), ``self._scenario_conf``
        # would re-enter __getattr__ and recurse until RecursionError.
        scenario_conf = self.__dict__.get("_scenario_conf")
        if scenario_conf is not None and name in scenario_conf:
            return scenario_conf[name]
        raise AttributeError(f"'Episode' object has no attribute '{name}'")

    def __getitem__(self, key: Any) -> Any:
        """Access scenario configuration values using subscript notation.

        Allows accessing scenario configuration using bracket notation (e.g., episode['my_param']).

        Args:
            key (Any): The key name from scenario configuration.

        Returns:
            Any: The value from scenario configuration.

        Raises:
            KeyError: If the key is not in scenario configuration.
        """
        if key in self._scenario_conf:
            return self._scenario_conf[key]
        raise KeyError(f"'Episode' object has no key '{key}'")

    def add_metric(self, name: str, metric: Metrics) -> None:
        """Add a metric to track for this episode.

        Args:
            name (str): The name of the metric.
            metric (Metrics): The metric instance to add.

        Raises:
            ValueError: If the name is reserved or already used.
        """
        if name in RESERVED_NAMES:
            raise ValueError(f"{name} is a reserved name and is not allowed.")
        if name in self._logs:
            raise ValueError(f"{name} is already in use and cannot be added again.")
        self._logs[name] = metric

    def log_code(self, key: str, code_content: str) -> None:
        """Log code content as an artifact.

        An existing entry under ``key`` is replaced.

        Args:
            key (str): The key for the code artifact.
            code_content (str): The code content to log.

        Raises:
            ValueError: If the key is reserved.
        """
        if key in RESERVED_NAMES:
            raise ValueError(f"{key} is a reserved name and is not allowed.")
        self._logs[key] = Code(
            run_id=self._run_id,
            key=key,
            code_content=code_content,
            episode_id=self._episode_id
        )

    def log(self, data: dict, x: dict | None = None, replace: bool = False) -> None:
        """Log data points or values for the episode.

        A new key stores its value as-is. For an existing key holding a
        ``Metrics`` instance, the value is forwarded to ``Metrics.log``. For an
        existing plain value, the new value overwrites it only when ``replace``
        is true.

        Args:
            data (dict): Dictionary of key-value pairs to log.
            x (dict | None): Optional dictionary of x-axis values for each key.
            replace (bool): Whether to replace existing values. Defaults to False.

        Raises:
            ValueError: If a key is reserved or a plain value is already logged
                and ``replace`` is False.
        """
        # Reject reserved keys before touching any state, so a bad key late in
        # ``data`` doesn't leave the earlier keys already logged.
        for key in data:
            if key in RESERVED_NAMES:
                raise ValueError(f"{key} is a reserved name and is not allowed.")

        for key, value in data.items():
            if key not in self._logs:
                self._logs[key] = value
                continue
            cur_val = self._logs[key]
            if isinstance(cur_val, Metrics):
                cur_x = x.get(key) if x is not None else None
                cur_val.log(value, x=cur_x, replace=replace)
            elif replace:
                self._logs[key] = value
            else:
                raise ValueError(f"Cannot log value for key '{key}' as there is already a value logged.")

    @property
    def yaml(self) -> str:
        """The current scenario configuration as a YAML string.

        Returns:
            str: The current scenario as a YAML string.
        """
        return OmegaConf.to_yaml(self._scenario_conf)

    def discard(self) -> None:
        """Mark the episode as discarded/canceled."""
        self.finish(EpisodeStatus.CANCELED)

    def success(self) -> None:
        """Mark the episode as successfully completed."""
        self.finish(EpisodeStatus.SUCCESS)

    def fail(self) -> None:
        """Mark the episode as failed."""
        self.finish(EpisodeStatus.FAILED)

    def finish(self, status: EpisodeStatus, err_msg: str | None = None) -> None:
        """Finish the episode with a specific status.

        Uploads the scenario YAML and every logged entry, then reports the
        final status. Calling this on an already finished episode does nothing.

        Args:
            status (EpisodeStatus): The final status of the episode.
            err_msg (str | None): Optional error message if the episode errored.

        Raises:
            ValueError: If a plainly logged value is not a standard type. The
                episode is left unfinished and nothing is uploaded, so the value
                can be fixed and ``finish`` called again.
        """
        if self._is_finished:
            return
        # Validate before marking finished or uploading anything. Otherwise a bad
        # value would leave a half-uploaded episode whose final status is never
        # reported and which can no longer be finished.
        self._check_plain_values()

        self._is_finished = True
        self._episode_status = status

        self._api_client.upload_code(
            artifact_key="scenario",
            run_id=self._run_id,
            episode_id=self._episode_id,
            code_content=self.yaml
        )

        # TODO: submit final metrics
        for key, value in self._logs.items():
            self._upload_log(key, value)

        self._api_client.update_episode(
            run_id=self._run_id,
            episode_id=self._episode_id,
            status=status,
            err_msg=err_msg
        )

    def _check_plain_values(self) -> None:
        """Raise ValueError for any logged plain value that is not a standard type."""
        for key, value in self._logs.items():
            if isinstance(value, (Summary, Metrics, Code)):
                continue
            if not is_standard_type(value):
                raise ValueError(f"Value for key '{key}' is not a standard type.")

    def _upload_log(self, key: str, value: Any) -> None:
        """Upload one logged entry using the API call that matches its type."""
        # Summary is checked before Metrics so it takes the upload_python path.
        if isinstance(value, Summary):
            metric_val = value.finalize()
            self._api_client.upload_python(
                artifact_key=key,
                run_id=self._run_id,
                episode_id=self._episode_id,
                pickled_bytes=pickle.dumps(metric_val["value"])
            )
        elif isinstance(value, Metrics):
            metric_val = value.finalize()
            self._api_client.upload_metrics(
                artifact_key=key,
                run_id=self._run_id,
                episode_id=self._episode_id,
                pickled_bytes=pickle.dumps(metric_val),
                graph_type=value.graph_type.value,
            )
        elif isinstance(value, Code):
            self._api_client.upload_code(
                artifact_key=value.key,
                run_id=value.run_id,
                episode_id=value.episode_id,
                code_content=value.code_content
            )
        else:
            self._api_client.upload_python(
                artifact_key=key,
                run_id=self._run_id,
                episode_id=self._episode_id,
                pickled_bytes=pickle.dumps(value)
            )
humalab/humalab.py ADDED
@@ -0,0 +1,219 @@
1
+ from contextlib import contextmanager
2
+ import sys
3
+ import traceback
4
+
5
+ from omegaconf import OmegaConf
6
+
7
+ from humalab.constants import DEFAULT_PROJECT
8
+ from humalab.run import Run
9
+ from humalab.humalab_config import HumalabConfig
10
+ from humalab.humalab_api_client import HumaLabApiClient, RunStatus, EpisodeStatus
11
+ import requests
12
+
13
+ import uuid
14
+
15
+ from collections.abc import Generator
16
+
17
+ from humalab.scenarios.scenario import Scenario
18
+
19
# Module-level handle to the run opened by init(). finish() calls its finish()
# and resets this to None. Only one run is tracked at a time.
_cur_run: Run | None = None
20
+
21
def _pull_scenario(client: HumaLabApiClient,
                   project: str,
                   seed: int | None = None,
                   scenario: str | list | dict | Scenario | None = None,
                   scenario_id: str | None = None,) -> Scenario:
    """Resolve the scenario for a run, pulling it from the server if requested.

    When ``scenario_id`` is given, the server response's ``yaml_content`` is
    used and the local ``scenario`` argument is ignored. Otherwise ``scenario``
    is used. A ``Scenario`` instance is returned unchanged; any other value
    initialises a new ``Scenario``.

    Args:
        client (HumaLabApiClient): API client used to fetch the scenario.
        project (str): Project name the scenario belongs to.
        seed (int | None): Seed passed to ``Scenario.init`` when a new instance
            is created. It is not applied to an existing ``Scenario`` instance.
        scenario (str | list | dict | Scenario | None): Local scenario configuration.
        scenario_id (str | None): ID of a scenario to pull from the server, either
            ``"<id>"`` or ``"<id>:<version>"`` with an integer version.

    Returns:
        Scenario: The initialised scenario instance.

    Raises:
        ValueError: If ``scenario_id`` is empty or not in one of the formats above.
    """
    if scenario_id is not None:
        format_msg = (f"Invalid scenario_id format: {scenario_id!r}. "
                      "Expected 'scenario_id' or 'scenario_name:version'.")
        # partition() keeps everything after the first ":" in version_str, so an
        # id with extra separators ("a:1:2") fails int() below instead of having
        # the trailing part silently dropped.
        scenario_real_id, sep, version_str = scenario_id.partition(":")
        if not scenario_real_id:
            raise ValueError(format_msg)
        scenario_version = None
        if sep:
            try:
                scenario_version = int(version_str)
            except ValueError as err:
                raise ValueError(format_msg) from err

        scenario_response = client.get_scenario(
            project_name=project,
            uuid=scenario_real_id,
            version=scenario_version)
        final_scenario = scenario_response["yaml_content"]
    else:
        final_scenario = scenario

    if isinstance(final_scenario, Scenario):
        return final_scenario

    scenario_inst = Scenario()
    scenario_inst.init(scenario=final_scenario,
                       seed=seed,
                       scenario_id=scenario_id,
                       # TODO: forward num_env once parallel environments are supported.
                       )
    return scenario_inst
62
+
63
@contextmanager
def init(project: str | None = None,
         name: str | None = None,
         description: str | None = None,
         id: str | None = None,
         tags: list[str] | None = None,
         scenario: str | list | dict | Scenario | None = None,
         scenario_id: str | None = None,
         seed: int | None = None,
         auto_create_scenario: bool = False,
         # num_env: int | None = None,

         base_url: str | None = None,
         api_key: str | None = None,
         timeout: float | None = None,
         ) -> Generator[Run, None, None]:
    """
    Initialize a new HumaLab run.

    Resolves the scenario and ensures the project exists. It then reuses the run
    with the given ``id`` if the server has one, or creates a new run. The run
    is yielded as the context value and stored as the module's current run.
    When the ``with`` body completes normally the run is finished with FINISHED;
    when it raises, the run is finished with ERRORED and the traceback, and the
    exception is re-raised. If the body already called ``finish()``/``discard()``,
    nothing more is done.

    Args:
        project: The project name under which to create the run.
        name: The name of the run.
        description: A description of the run.
        id: The unique identifier for the run. If None, a new UUID will be generated.
        tags: A list of tags to associate with the run.
        scenario: The scenario configuration as a string, list, dict, or Scenario.
        scenario_id: The unique identifier of a pre-defined scenario to use.
        seed: An optional seed for scenario randomization.
        auto_create_scenario: Whether to automatically create the scenario if it does not exist.
        base_url: The base URL of the HumaLab server.
        api_key: The API key for authentication.
        timeout: The timeout for API requests.
        # num_env: The number of parallel environments to run. (Not supported yet.)

    Yields:
        Run: The active run.
    """
    global _cur_run

    # --- Setup. Failures here propagate directly: no run has been made current
    # yet, so there is nothing to finish (and any run left over from an earlier
    # init() must not be touched).
    project = project or DEFAULT_PROJECT
    name = name or ""
    description = description or ""
    id = id or str(uuid.uuid4())

    api_client = HumaLabApiClient(base_url=base_url,
                                  api_key=api_key,
                                  timeout=timeout)
    scenario_inst = _pull_scenario(client=api_client,
                                   project=project,
                                   seed=seed,
                                   scenario=scenario,
                                   scenario_id=scenario_id)

    project_resp = api_client.create_project(name=project)

    if scenario_id is None and scenario is not None and auto_create_scenario:
        scenario_response = api_client.create_scenario(
            project_name=project_resp['name'],
            name=f"{name} scenario",
            description="Auto-created scenario",
            yaml_content=OmegaConf.to_yaml(scenario_inst.template),
        )
        # NOTE(review): this id is not used afterwards, so the auto-created
        # scenario is not attached to scenario_inst or the Run. Confirm intent.
        scenario_id = scenario_response['uuid']

    try:
        run_response = api_client.get_run(run_id=id)
        api_client.update_run(
            run_id=run_response['run_id'],
            name=name,
            description=description,
            tags=tags,
            status=RunStatus.RUNNING,
        )
        # NOTE(review): Run below is built from run_response fetched *before*
        # this update, so its name/description/tags may be stale. Confirm.
    except requests.HTTPError as e:
        # A 404 means there is no run with this id yet: fall through and create
        # one. HTTPError.response may be None, so guard before reading it.
        if e.response is None or e.response.status_code != 404:
            raise
        run_response = None

    if run_response is None:
        run_response = api_client.create_run(name=name,
                                             project_name=project_resp['name'],
                                             description=description,
                                             tags=tags)
        id = run_response['run_id']
        api_client.update_run(
            run_id=id,
            description=description,
        )

    run = Run(
        project=project_resp['name'],
        name=run_response["name"],
        description=run_response.get("description"),
        id=run_response['run_id'],
        tags=run_response.get("tags"),
        scenario=scenario_inst,

        base_url=base_url,
        api_key=api_key,
        timeout=timeout
    )

    # --- Body. Only finish the run *this* call created. If the body already
    # finished/discarded it (or replaced the current run), leave that alone.
    _cur_run = run
    try:
        yield run
    except Exception:
        if _cur_run is run:
            finish(status=RunStatus.ERRORED,
                   err_msg=traceback.format_exc())
        raise
    else:
        if _cur_run is run:
            print("Finishing run...")
            finish(status=RunStatus.FINISHED)
180
+
181
def discard() -> None:
    """Cancel the current run, if any.

    Shorthand for ``finish`` with ``RunStatus.CANCELED``. When no run is
    active this is a no-op.
    """
    return finish(RunStatus.CANCELED)
184
+
185
def finish(status: RunStatus = RunStatus.FINISHED,
           err_msg: str | None = None) -> None:
    """Finalize the run tracked by this module, if one is active.

    Forwards ``status`` and ``err_msg`` to the run's own ``finish`` and then
    forgets it. If that call raises, the run stays current.

    Args:
        status (RunStatus): Final status to record. Defaults to FINISHED.
        err_msg (str | None): Optional error message for an errored run.
    """
    global _cur_run
    active_run = _cur_run
    if not active_run:
        return
    active_run.finish(status=status, err_msg=err_msg)
    _cur_run = None
197
+
198
def login(api_key: str | None = None,
          relogin: bool | None = None,
          host: str | None = None,
          force: bool | None = None,
          timeout: float | None = None) -> bool:
    """Store HumaLab credentials and connection settings on ``HumalabConfig``.

    A truthy argument replaces the corresponding setting. A falsy one (``None``,
    ``""``, ``0``) re-assigns the value the config already holds.

    Args:
        api_key (str | None): API key for authentication.
        relogin (bool | None): Accepted for compatibility; ignored.
        host (str | None): API host URL, stored as ``base_url``.
        force (bool | None): Accepted for compatibility; ignored.
        timeout (float | None): Request timeout in seconds.

    Returns:
        bool: Always ``True``.
    """
    config = HumalabConfig()
    overrides = (
        ("api_key", api_key),
        ("base_url", host),
        ("timeout", timeout),
    )
    for attr_name, override in overrides:
        # Always assign (falling back to the current value) so any side effect
        # of the config's attribute setters still happens, as before.
        setattr(config, attr_name, override or getattr(config, attr_name))
    return True