humalab 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
humalab/run.py ADDED
@@ -0,0 +1,325 @@
1
+ import uuid
2
+ import traceback
3
+ import pickle
4
+ import base64
5
+
6
+ from humalab.metrics.code import Code
7
+ from humalab.metrics.summary import Summary
8
+
9
+ from humalab.constants import DEFAULT_PROJECT, RESERVED_NAMES, ArtifactType
10
+ from humalab.metrics.scenario_stats import ScenarioStats
11
+ from humalab.humalab_api_client import EpisodeStatus, HumaLabApiClient, RunStatus
12
+ from humalab.metrics.metric import Metrics
13
+ from humalab.episode import Episode
14
+ from humalab.utils import is_standard_type
15
+
16
+ from humalab.scenarios.scenario import Scenario
17
+
18
class Run:
    """Represents a run containing multiple episodes for a scenario.

    A Run is a context manager that tracks experiments or evaluations using a specific
    scenario. It manages episode creation, metric logging, and code artifacts. The run
    can contain multiple episodes, each representing a single execution instance.

    Use as a context manager to automatically handle run lifecycle:
        with Run(scenario=my_scenario) as run:
            # Your code here
            pass

    On a normal exit the run is finished with ``RunStatus.FINISHED``. If the ``with``
    body raises, the run is finished with ``RunStatus.ERRORED`` and the formatted
    traceback as its error message; the exception still propagates.

    Attributes:
        project (str): The project name under which the run is created.
        id (str): The unique identifier for the run.
        name (str): The name of the run.
        description (str): A description of the run.
        tags (list[str]): A list of tags associated with the run.
        scenario (Scenario): The scenario associated with the run.
    """

    def __init__(self,
                 scenario: Scenario,
                 project: str = DEFAULT_PROJECT,
                 name: str | None = None,
                 description: str | None = None,
                 id: str | None = None,
                 tags: list[str] | None = None,

                 base_url: str | None = None,
                 api_key: str | None = None,
                 timeout: float | None = None,
                 ) -> None:
        """Initialize a new Run instance.

        Args:
            scenario (Scenario): The scenario instance for the run.
            project (str): The project name under which the run is created.
            name (str | None): The name of the run. Stored as "" when None.
            description (str | None): A description of the run. Stored as "" when None.
            id (str | None): The unique identifier for the run. If None, a UUID is generated.
            tags (list[str] | None): A list of tags associated with the run. The list is
                copied, so later changes to the caller's list do not affect the run.
            base_url (str | None): Passed through to ``HumaLabApiClient``.
            api_key (str | None): Passed through to ``HumaLabApiClient``.
            timeout (float | None): Passed through to ``HumaLabApiClient``.
        """
        self._project = project
        self._id = id or str(uuid.uuid4())
        self._name = name or ""
        self._description = description or ""
        # Copy so the run does not alias (and later mutate) the caller's list.
        self._tags = list(tags or [])

        self._scenario = scenario
        # Everything uploaded by finish(), keyed by artifact key: ScenarioStats,
        # Summary / other Metrics instances, Code artifacts, or plain values.
        self._logs: dict[str, object] = {}
        # Episodes created through create_episode(), keyed by episode id.
        self._episodes: dict[str, Episode] = {}
        self._is_finished = False

        self._api_client = HumaLabApiClient(base_url=base_url,
                                            api_key=api_key,
                                            timeout=timeout)

    @property
    def project(self) -> str:
        """The project name under which the run is created.

        Returns:
            str: The project name.
        """
        return self._project

    @property
    def id(self) -> str:
        """The unique identifier for the run.

        Returns:
            str: The run ID.
        """
        return self._id

    @property
    def name(self) -> str:
        """The name of the run.

        Returns:
            str: The run name.
        """
        return self._name

    @property
    def description(self) -> str:
        """The description of the run.

        Returns:
            str: The run description.
        """
        return self._description

    @property
    def tags(self) -> list[str]:
        """The tags associated with the run.

        Returns:
            list[str]: The list of tags.
        """
        return self._tags

    @property
    def scenario(self) -> Scenario:
        """The scenario associated with the run.

        Returns:
            Scenario: The scenario instance.
        """
        return self._scenario

    def __enter__(self):
        """Enter the run context."""
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        """Exit the run context and finalize the run.

        Returns None, so any exception raised inside the ``with`` body propagates.
        """
        if self._is_finished:
            return
        if exception_type is not None:
            err_msg = "".join(traceback.format_exception(exception_type, exception_value, exception_traceback))
            self.finish(status=RunStatus.ERRORED, err_msg=err_msg)
        else:
            self.finish()

    def create_episode(self, episode_id: str | None = None) -> Episode:
        """Create a new episode for this run.

        The scenario is resolved once per episode. The sampled values are recorded in
        per-parameter ScenarioStats, and the episode is registered with the run so
        that ``finish`` can close it.

        Args:
            episode_id (str | None): Optional unique identifier for the episode.
                If None, a UUID is generated automatically.

        Returns:
            Episode: The newly created episode instance.

        Raises:
            ValueError: If a resolved scenario parameter name collides with a
                non-scenario entry already logged on this run.
        """
        episode_id = episode_id or str(uuid.uuid4())
        cur_scenario, episode_vals = self._scenario.resolve()
        episode = Episode(run_id=self._id,
                          episode_id=episode_id,
                          scenario_conf=cur_scenario,
                          episode_vals=episode_vals)
        self._handle_scenario_stats(episode, episode_vals)
        # Register only after the stats were recorded successfully.
        self._episodes[episode.episode_id] = episode
        return episode

    def _handle_scenario_stats(self, episode: Episode, episode_vals: dict) -> None:
        """Record one episode's resolved scenario values in per-parameter ScenarioStats.

        ``episode_vals`` maps each parameter name to a dict that is indexed here by
        ``"distribution"`` and ``"value"``.
        """
        for metric_name, value in episode_vals.items():
            stat = self._logs.get(metric_name)
            if stat is None:
                stat = ScenarioStats(name=metric_name,
                                     distribution_type=value["distribution"])
                self._logs[metric_name] = stat
            elif not isinstance(stat, ScenarioStats):
                # Never log scenario samples into an unrelated user value/metric.
                raise ValueError(
                    f"Scenario parameter '{metric_name}' conflicts with an existing "
                    f"log entry of the same name."
                )
            stat.log(data=value["value"], x=episode.episode_id)

    def add_metric(self, name: str, metric: Metrics) -> None:
        """Add a metric to track for this run.

        Args:
            name (str): The name of the metric.
            metric (Metrics): The metric instance to add.

        Raises:
            ValueError: If the name is reserved or already used.
        """
        # Same reserved-name rule as log()/log_code(); a reserved key would
        # collide with artifacts finish() uploads itself.
        if name in RESERVED_NAMES:
            raise ValueError(f"{name} is a reserved name and is not allowed.")
        if name in self._logs:
            raise ValueError(f"{name} is already in use and cannot be added as a metric.")
        self._logs[name] = metric

    def log_code(self, key: str, code_content: str) -> None:
        """Log code content as an artifact.

        Args:
            key (str): The key for the code artifact.
            code_content (str): The code content to log.

        Raises:
            ValueError: If the key is reserved.
        """
        if key in RESERVED_NAMES:
            raise ValueError(f"{key} is a reserved name and is not allowed.")
        self._logs[key] = Code(
            run_id=self._id,
            key=key,
            code_content=code_content,
        )

    def log(self, data: dict, x: dict | None = None, replace: bool = False) -> None:
        """Log data points or values for the run.

        A new key stores its value as-is. An existing key holding a ``Metrics``
        instance has the value appended through ``Metrics.log``. Any other existing
        value is overwritten only when ``replace`` is True.

        Args:
            data (dict): Dictionary of key-value pairs to log.
            x (dict | None): Optional dictionary of x-axis values for each key.
            replace (bool): Whether to replace existing values. Defaults to False.

        Raises:
            ValueError: If a key is reserved or logging fails.
        """
        for key, value in data.items():
            if key in RESERVED_NAMES:
                raise ValueError(f"{key} is a reserved name and is not allowed.")
            if key not in self._logs:
                self._logs[key] = value
                continue
            cur_val = self._logs[key]
            if isinstance(cur_val, Metrics):
                cur_x = x.get(key) if x is not None else None
                cur_val.log(value, x=cur_x, replace=replace)
            elif replace:
                self._logs[key] = value
            else:
                raise ValueError(f"Cannot log value for key '{key}' as there is already a value logged.")

    def _finish_episodes(self,
                         status: RunStatus,
                         err_msg: str | None = None) -> None:
        """Finish every still-open episode with the status matching the run status.

        Unmapped run statuses leave the episodes untouched.
        """
        episode_status = {
            RunStatus.FINISHED: EpisodeStatus.SUCCESS,
            RunStatus.ERRORED: EpisodeStatus.ERRORED,
            RunStatus.CANCELED: EpisodeStatus.CANCELED,
        }.get(status)
        if episode_status is None:
            return
        for episode in self._episodes.values():
            if not episode.is_finished:
                episode.finish(status=episode_status, err_msg=err_msg)

    def finish(self,
               status: RunStatus = RunStatus.FINISHED,
               err_msg: str | None = None) -> None:
        """Finish the run and submit final metrics.

        Closes open episodes, then uploads the scenario YAML ("scenario"), the
        pickled scenario seed ("seed"), and every logged entry. Finally it updates
        the run status on the server. Calls after the first are no-ops.

        If any upload fails, the run is reported as ``RunStatus.ERRORED`` with the
        failure traceback, so it is not left open on the server, and the exception
        is re-raised.

        Args:
            status (RunStatus): The final status of the run.
            err_msg (str | None): An optional error message.

        Raises:
            ValueError: If a plain logged value is not a standard type.
        """
        if self._is_finished:
            return
        self._is_finished = True
        self._finish_episodes(status=status, err_msg=err_msg)

        try:
            self._upload_scenario_artifacts()
            for key, value in self._logs.items():
                self._upload_log(key, value)
        except Exception:
            # _is_finished already blocks retries, so report the failure now
            # instead of leaving the run open on the server.
            failure = f"Failed to upload run artifacts:\n{traceback.format_exc()}"
            self._api_client.update_run(
                run_id=self._id,
                status=RunStatus.ERRORED,
                err_msg=f"{err_msg}\n{failure}" if err_msg else failure,
            )
            raise

        self._api_client.update_run(
            run_id=self._id,
            status=status,
            err_msg=err_msg
        )

    def _upload_scenario_artifacts(self) -> None:
        """Upload the scenario YAML and its pickled seed under reserved keys."""
        self._api_client.upload_code(
            artifact_key="scenario",
            run_id=self._id,
            code_content=self.scenario.yaml
        )
        self._api_client.upload_python(
            artifact_key="seed",
            run_id=self._id,
            pickled_bytes=pickle.dumps(self.scenario.seed)
        )

    def _upload_log(self, key: str, value: object) -> None:
        """Upload a single logged entry using the API call that matches its type."""
        # ScenarioStats and Summary are tested before the generic Metrics branch.
        if isinstance(value, ScenarioStats):
            # Attach each episode's final status before finalizing the stats.
            for episode_id, episode in self._episodes.items():
                value.log_status(
                    episode_id=episode_id,
                    episode_status=episode.status
                )
            self._api_client.upload_scenario_stats_artifact(
                artifact_key=key,
                run_id=self._id,
                pickled_bytes=pickle.dumps(value.finalize()),
                graph_type=value.graph_type.value,
            )
        elif isinstance(value, Summary):
            # Only the "value" entry of the finalized summary is uploaded.
            self._api_client.upload_python(
                artifact_key=key,
                run_id=self._id,
                pickled_bytes=pickle.dumps(value.finalize()["value"])
            )
        elif isinstance(value, Metrics):
            self._api_client.upload_metrics(
                artifact_key=key,
                run_id=self._id,
                pickled_bytes=pickle.dumps(value.finalize()),
                graph_type=value.graph_type.value,
            )
        elif isinstance(value, Code):
            self._api_client.upload_code(
                artifact_key=value.key,
                run_id=value.run_id,
                episode_id=value.episode_id,
                code_content=value.code_content
            )
        else:
            if not is_standard_type(value):
                raise ValueError(f"Value for key '{key}' is not a standard type.")
            self._api_client.upload_python(
                artifact_key=key,
                run_id=self._id,
                pickled_bytes=pickle.dumps(value)
            )
@@ -0,0 +1,11 @@
1
+ """Scenario management and configuration.
2
+
3
+ This module provides the Scenario class and related utilities for managing scenario
4
+ configurations with probabilistic distributions, supporting randomized scenario generation
5
+ for robotics experiments.
6
+ """
7
+
8
+ from .scenario import Scenario
9
+ from .scenario_operator import list_scenarios, get_scenario
10
+
11
+ __all__ = ["Scenario", "list_scenarios", "get_scenario"]