humalab-0.0.4-py3-none-any.whl → humalab-0.0.6-py3-none-any.whl

This diff shows the contents of publicly available package versions that were released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of humalab might be problematic. Click here for more details.

Files changed (39)
  1. humalab/__init__.py +11 -0
  2. humalab/assets/__init__.py +2 -2
  3. humalab/assets/files/resource_file.py +29 -3
  4. humalab/assets/files/urdf_file.py +14 -10
  5. humalab/assets/resource_operator.py +91 -0
  6. humalab/constants.py +39 -5
  7. humalab/dists/bernoulli.py +16 -0
  8. humalab/dists/categorical.py +4 -0
  9. humalab/dists/discrete.py +22 -0
  10. humalab/dists/gaussian.py +22 -0
  11. humalab/dists/log_uniform.py +22 -0
  12. humalab/dists/truncated_gaussian.py +36 -0
  13. humalab/dists/uniform.py +22 -0
  14. humalab/episode.py +196 -0
  15. humalab/humalab.py +116 -153
  16. humalab/humalab_api_client.py +760 -62
  17. humalab/humalab_config.py +0 -13
  18. humalab/humalab_test.py +46 -29
  19. humalab/metrics/__init__.py +5 -5
  20. humalab/metrics/code.py +28 -0
  21. humalab/metrics/metric.py +41 -108
  22. humalab/metrics/scenario_stats.py +95 -0
  23. humalab/metrics/summary.py +24 -18
  24. humalab/run.py +180 -115
  25. humalab/scenarios/__init__.py +4 -0
  26. humalab/scenarios/scenario.py +372 -0
  27. humalab/scenarios/scenario_operator.py +82 -0
  28. humalab/{scenario_test.py → scenarios/scenario_test.py} +150 -269
  29. humalab/utils.py +37 -0
  30. {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/METADATA +1 -1
  31. humalab-0.0.6.dist-info/RECORD +39 -0
  32. humalab/assets/resource_manager.py +0 -57
  33. humalab/metrics/dist_metric.py +0 -22
  34. humalab/scenario.py +0 -225
  35. humalab-0.0.4.dist-info/RECORD +0 -34
  36. {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/WHEEL +0 -0
  37. {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/entry_points.txt +0 -0
  38. {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/licenses/LICENSE +0 -0
  39. {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/top_level.txt +0 -0
humalab/humalab_config.py CHANGED
@@ -6,13 +6,11 @@ class HumalabConfig:
6
6
  def __init__(self):
7
7
  self._config = {
8
8
  "workspace_path": "",
9
- "entity": "",
10
9
  "base_url": "",
11
10
  "api_key": "",
12
11
  "timeout": 30.0,
13
12
  }
14
13
  self._workspace_path = ""
15
- self._entity = ""
16
14
  self._base_url = ""
17
15
  self._api_key = ""
18
16
  self._timeout = 30.0
@@ -27,7 +25,6 @@ class HumalabConfig:
27
25
  with open(config_path, "r") as f:
28
26
  self._config = yaml.safe_load(f)
29
27
  self._workspace_path = os.path.expanduser(self._config["workspace_path"]) if self._config and "workspace_path" in self._config else home_path
30
- self._entity = self._config["entity"] if self._config and "entity" in self._config else ""
31
28
  self._base_url = self._config["base_url"] if self._config and "base_url" in self._config else ""
32
29
  self._api_key = self._config["api_key"] if self._config and "api_key" in self._config else ""
33
30
  self._timeout = self._config["timeout"] if self._config and "timeout" in self._config else 30.0
@@ -45,16 +42,6 @@ class HumalabConfig:
45
42
  self._config["workspace_path"] = path
46
43
  self._save()
47
44
 
48
- @property
49
- def entity(self) -> str:
50
- return str(self._entity)
51
-
52
- @entity.setter
53
- def entity(self, entity: str) -> None:
54
- self._entity = entity
55
- self._config["entity"] = entity
56
- self._save()
57
-
58
45
  @property
59
46
  def base_url(self) -> str:
60
47
  return str(self._base_url)
humalab/humalab_test.py CHANGED
@@ -2,12 +2,13 @@ import unittest
2
2
  from unittest.mock import patch, MagicMock, Mock
3
3
  import uuid
4
4
 
5
+ from humalab.constants import DEFAULT_PROJECT
5
6
  from humalab import humalab
6
7
  from humalab.run import Run
7
- from humalab.scenario import Scenario
8
+ from humalab.scenarios.scenario import Scenario
8
9
  from humalab.humalab_config import HumalabConfig
9
10
  from humalab.humalab_api_client import HumaLabApiClient
10
- from humalab.constants import EpisodeStatus
11
+ from humalab.humalab_api_client import EpisodeStatus, RunStatus
11
12
 
12
13
 
13
14
  class HumalabTest(unittest.TestCase):
@@ -30,9 +31,10 @@ class HumalabTest(unittest.TestCase):
30
31
  # Pre-condition
31
32
  client = Mock()
32
33
  scenario = {"key": "value"}
34
+ project = "test_project"
33
35
 
34
36
  # In-test
35
- result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=None)
37
+ result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=None)
36
38
 
37
39
  # Post-condition
38
40
  self.assertEqual(result, scenario)
@@ -42,32 +44,34 @@ class HumalabTest(unittest.TestCase):
42
44
  """Test that _pull_scenario fetches from API when scenario_id is provided."""
43
45
  # Pre-condition
44
46
  client = Mock()
47
+ project = "test_project"
45
48
  scenario_id = "test-scenario-id"
46
49
  yaml_content = "scenario: test"
47
50
  client.get_scenario.return_value = {"yaml_content": yaml_content}
48
51
 
49
52
  # In-test
50
- result = humalab._pull_scenario(client=client, scenario=None, scenario_id=scenario_id)
53
+ result = humalab._pull_scenario(client=client, project=project, scenario=None, scenario_id=scenario_id)
51
54
 
52
55
  # Post-condition
53
56
  self.assertEqual(result, yaml_content)
54
- client.get_scenario.assert_called_once_with(uuid=scenario_id)
57
+ client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
55
58
 
56
59
  def test_pull_scenario_should_prefer_scenario_id_over_scenario(self):
57
60
  """Test that _pull_scenario uses scenario_id even when scenario is provided."""
58
61
  # Pre-condition
59
62
  client = Mock()
63
+ project = "test_project"
60
64
  scenario = {"key": "value"}
61
65
  scenario_id = "test-scenario-id"
62
66
  yaml_content = "scenario: from_api"
63
67
  client.get_scenario.return_value = {"yaml_content": yaml_content}
64
68
 
65
69
  # In-test
66
- result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=scenario_id)
70
+ result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=scenario_id)
67
71
 
68
72
  # Post-condition
69
73
  self.assertEqual(result, yaml_content)
70
- client.get_scenario.assert_called_once_with(uuid=scenario_id)
74
+ client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
71
75
 
72
76
  # Tests for init context manager
73
77
 
@@ -78,7 +82,6 @@ class HumalabTest(unittest.TestCase):
78
82
  def test_init_should_create_run_with_provided_parameters(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
79
83
  """Test that init() creates a Run with provided parameters."""
80
84
  # Pre-condition
81
- entity = "test_entity"
82
85
  project = "test_project"
83
86
  name = "test_name"
84
87
  description = "test_description"
@@ -87,13 +90,14 @@ class HumalabTest(unittest.TestCase):
87
90
  scenario_data = {"key": "value"}
88
91
 
89
92
  mock_config = Mock()
90
- mock_config.entity = "default_entity"
91
93
  mock_config.base_url = "http://localhost:8000"
92
94
  mock_config.api_key = "test_key"
93
95
  mock_config.timeout = 30.0
94
96
  mock_config_class.return_value = mock_config
95
97
 
96
98
  mock_api_client = Mock()
99
+ mock_api_client.create_project.return_value = {"name": project}
100
+ mock_api_client.get_run.return_value = {"run_id": run_id, "name": name, "description": description, "tags": tags}
97
101
  mock_api_client_class.return_value = mock_api_client
98
102
 
99
103
  mock_scenario_inst = Mock()
@@ -104,7 +108,6 @@ class HumalabTest(unittest.TestCase):
104
108
 
105
109
  # In-test
106
110
  with humalab.init(
107
- entity=entity,
108
111
  project=project,
109
112
  name=name,
110
113
  description=description,
@@ -116,7 +119,6 @@ class HumalabTest(unittest.TestCase):
116
119
  self.assertEqual(run, mock_run_inst)
117
120
  mock_run_class.assert_called_once()
118
121
  call_kwargs = mock_run_class.call_args.kwargs
119
- self.assertEqual(call_kwargs['entity'], entity)
120
122
  self.assertEqual(call_kwargs['project'], project)
121
123
  self.assertEqual(call_kwargs['name'], name)
122
124
  self.assertEqual(call_kwargs['description'], description)
@@ -135,13 +137,14 @@ class HumalabTest(unittest.TestCase):
135
137
  """Test that init() uses config defaults when parameters are not provided."""
136
138
  # Pre-condition
137
139
  mock_config = Mock()
138
- mock_config.entity = "config_entity"
139
140
  mock_config.base_url = "http://config:8000"
140
141
  mock_config.api_key = "config_key"
141
142
  mock_config.timeout = 60.0
142
143
  mock_config_class.return_value = mock_config
143
144
 
144
145
  mock_api_client = Mock()
146
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
147
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
145
148
  mock_api_client_class.return_value = mock_api_client
146
149
 
147
150
  mock_scenario_inst = Mock()
@@ -154,8 +157,7 @@ class HumalabTest(unittest.TestCase):
154
157
  with humalab.init() as run:
155
158
  # Post-condition
156
159
  call_kwargs = mock_run_class.call_args.kwargs
157
- self.assertEqual(call_kwargs['entity'], "config_entity")
158
- self.assertEqual(call_kwargs['project'], "default")
160
+ self.assertEqual(call_kwargs['project'], DEFAULT_PROJECT)
159
161
  self.assertEqual(call_kwargs['name'], "")
160
162
  self.assertEqual(call_kwargs['description'], "")
161
163
  self.assertIsNotNone(call_kwargs['id']) # UUID generated
@@ -171,13 +173,23 @@ class HumalabTest(unittest.TestCase):
171
173
  """Test that init() generates a UUID when id is not provided."""
172
174
  # Pre-condition
173
175
  mock_config = Mock()
174
- mock_config.entity = "test_entity"
175
176
  mock_config.base_url = "http://localhost:8000"
176
177
  mock_config.api_key = "test_key"
177
178
  mock_config.timeout = 30.0
178
179
  mock_config_class.return_value = mock_config
179
180
 
181
+ # Mock HTTP 404 error for get_run (run doesn't exist yet)
182
+ import requests
183
+ http_error = requests.HTTPError()
184
+ http_error.response = Mock()
185
+ http_error.response.status_code = 404
186
+
180
187
  mock_api_client = Mock()
188
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
189
+ mock_api_client.get_run.side_effect = http_error
190
+ # Mock create_run to return a valid UUID
191
+ generated_uuid = str(uuid.uuid4())
192
+ mock_api_client.create_run.return_value = {"run_id": generated_uuid, "name": "", "description": "", "tags": None}
181
193
  mock_api_client_class.return_value = mock_api_client
182
194
 
183
195
  mock_scenario_inst = Mock()
@@ -207,13 +219,14 @@ class HumalabTest(unittest.TestCase):
207
219
  scenario_data = {"key": "value"}
208
220
 
209
221
  mock_config = Mock()
210
- mock_config.entity = "test_entity"
211
222
  mock_config.base_url = "http://localhost:8000"
212
223
  mock_config.api_key = "test_key"
213
224
  mock_config.timeout = 30.0
214
225
  mock_config_class.return_value = mock_config
215
226
 
216
227
  mock_api_client = Mock()
228
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
229
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
217
230
  mock_api_client_class.return_value = mock_api_client
218
231
 
219
232
  mock_scenario_inst = Mock()
@@ -243,13 +256,14 @@ class HumalabTest(unittest.TestCase):
243
256
  yaml_content = "scenario: from_api"
244
257
 
245
258
  mock_config = Mock()
246
- mock_config.entity = "test_entity"
247
259
  mock_config.base_url = "http://localhost:8000"
248
260
  mock_config.api_key = "test_key"
249
261
  mock_config.timeout = 30.0
250
262
  mock_config_class.return_value = mock_config
251
263
 
252
264
  mock_api_client = Mock()
265
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
266
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
253
267
  mock_api_client.get_scenario.return_value = {"yaml_content": yaml_content}
254
268
  mock_api_client_class.return_value = mock_api_client
255
269
 
@@ -262,7 +276,7 @@ class HumalabTest(unittest.TestCase):
262
276
  # In-test
263
277
  with humalab.init(scenario_id=scenario_id) as run:
264
278
  # Post-condition
265
- mock_api_client.get_scenario.assert_called_once_with(uuid=scenario_id)
279
+ mock_api_client.get_scenario.assert_called_once_with(project_name='default', uuid=scenario_id, version=None)
266
280
  mock_scenario_inst.init.assert_called_once()
267
281
  call_kwargs = mock_scenario_inst.init.call_args.kwargs
268
282
  self.assertEqual(call_kwargs['scenario'], yaml_content)
@@ -277,13 +291,14 @@ class HumalabTest(unittest.TestCase):
277
291
  """Test that init() sets the global _cur_run variable."""
278
292
  # Pre-condition
279
293
  mock_config = Mock()
280
- mock_config.entity = "test_entity"
281
294
  mock_config.base_url = "http://localhost:8000"
282
295
  mock_config.api_key = "test_key"
283
296
  mock_config.timeout = 30.0
284
297
  mock_config_class.return_value = mock_config
285
298
 
286
299
  mock_api_client = Mock()
300
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
301
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
287
302
  mock_api_client_class.return_value = mock_api_client
288
303
 
289
304
  mock_scenario_inst = Mock()
@@ -308,13 +323,14 @@ class HumalabTest(unittest.TestCase):
308
323
  """Test that init() calls finish even when exception occurs in context."""
309
324
  # Pre-condition
310
325
  mock_config = Mock()
311
- mock_config.entity = "test_entity"
312
326
  mock_config.base_url = "http://localhost:8000"
313
327
  mock_config.api_key = "test_key"
314
328
  mock_config.timeout = 30.0
315
329
  mock_config_class.return_value = mock_config
316
330
 
317
331
  mock_api_client = Mock()
332
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
333
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
318
334
  mock_api_client_class.return_value = mock_api_client
319
335
 
320
336
  mock_scenario_inst = Mock()
@@ -343,13 +359,14 @@ class HumalabTest(unittest.TestCase):
343
359
  timeout = 120.0
344
360
 
345
361
  mock_config = Mock()
346
- mock_config.entity = "test_entity"
347
362
  mock_config.base_url = "http://localhost:8000"
348
363
  mock_config.api_key = "default_key"
349
364
  mock_config.timeout = 30.0
350
365
  mock_config_class.return_value = mock_config
351
366
 
352
367
  mock_api_client = Mock()
368
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
369
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
353
370
  mock_api_client_class.return_value = mock_api_client
354
371
 
355
372
  mock_scenario_inst = Mock()
@@ -381,33 +398,33 @@ class HumalabTest(unittest.TestCase):
381
398
  humalab.finish()
382
399
 
383
400
  # Post-condition
384
- mock_run.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=None)
401
+ mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=None)
385
402
 
386
403
  def test_finish_should_call_finish_on_current_run_with_custom_status(self):
387
404
  """Test that finish() calls finish on the current run with custom status."""
388
405
  # Pre-condition
389
406
  mock_run = Mock()
390
407
  humalab._cur_run = mock_run
391
- status = EpisodeStatus.FAILED
408
+ status = RunStatus.ERRORED
392
409
 
393
410
  # In-test
394
411
  humalab.finish(status=status)
395
412
 
396
413
  # Post-condition
397
- mock_run.finish.assert_called_once_with(status=status, quiet=None)
414
+ mock_run.finish.assert_called_once_with(status=status, err_msg=None)
398
415
 
399
- def test_finish_should_call_finish_on_current_run_with_quiet_parameter(self):
400
- """Test that finish() calls finish on the current run with quiet parameter."""
416
+ def test_finish_should_call_finish_on_current_run_with_err_msg_parameter(self):
417
+ """Test that finish() calls finish on the current run with err_msg parameter."""
401
418
  # Pre-condition
402
419
  mock_run = Mock()
403
420
  humalab._cur_run = mock_run
404
- quiet = True
421
+ err_msg = "Test error message"
405
422
 
406
423
  # In-test
407
- humalab.finish(quiet=quiet)
424
+ humalab.finish(err_msg=err_msg)
408
425
 
409
426
  # Post-condition
410
- mock_run.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=quiet)
427
+ mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=err_msg)
411
428
 
412
429
  def test_finish_should_do_nothing_when_no_current_run(self):
413
430
  """Test that finish() does nothing when _cur_run is None."""
humalab/metrics/__init__.py CHANGED
@@ -1,11 +1,11 @@
1
- from .metric import MetricGranularity, MetricType, Metrics
2
- from .dist_metric import DistributionMetric
1
+ from .metric import Metrics
2
+ from .code import Code
3
+ from .scenario_stats import ScenarioStats
3
4
  from .summary import Summary
4
5
 
5
6
  __all__ = [
6
- "MetricGranularity",
7
- "MetricType",
7
+ "Code",
8
8
  "Metrics",
9
- "DistributionMetric",
9
+ "ScenarioStats",
10
10
  "Summary",
11
11
  ]
humalab/metrics/code.py ADDED
@@ -0,0 +1,28 @@
1
+ class Code:
2
+ """Class for logging code artifacts."""
3
+ def __init__(self,
4
+ run_id: str,
5
+ key: str,
6
+ code_content: str,
7
+ episode_id: str | None = None) -> None:
8
+ super().__init__()
9
+ self._run_id = run_id
10
+ self._key = key
11
+ self._code_content = code_content
12
+ self._episode_id = episode_id
13
+
14
+ @property
15
+ def run_id(self) -> str:
16
+ return self._run_id
17
+
18
+ @property
19
+ def key(self) -> str:
20
+ return self._key
21
+
22
+ @property
23
+ def code_content(self) -> str:
24
+ return self._code_content
25
+
26
+ @property
27
+ def episode_id(self) -> str | None:
28
+ return self._episode_id
humalab/metrics/metric.py CHANGED
@@ -1,129 +1,62 @@
1
- from enum import Enum
2
- from typing import Any
3
- from humalab.constants import EpisodeStatus
4
-
5
-
6
- class MetricType(Enum):
7
- DEFAULT = "default"
8
- STREAM = "stream"
9
- DISTRIBUTION = "distribution"
10
- SUMMARY = "summary"
11
-
12
-
13
- class MetricGranularity(Enum):
14
- STEP = "step"
15
- EPISODE = "episode"
16
- RUN = "run"
1
+ from typing import Any
2
+ from humalab.constants import MetricDimType, GraphType
17
3
 
18
4
 
19
5
  class Metrics:
20
- def __init__(self,
21
- name: str,
22
- metric_type: MetricType,
23
- episode_id: str,
24
- run_id: str,
25
- granularity: MetricGranularity = MetricGranularity.STEP) -> None:
6
+ def __init__(self,
7
+ metric_dim_type: MetricDimType= MetricDimType.ONE_D,
8
+ graph_type: GraphType=GraphType.LINE) -> None:
26
9
  """
27
10
  Base class for different types of metrics.
28
-
29
- Args:
30
- name (str): The name of the metric.
31
- metric_type (MetricType): The type of the metric.
32
- episode_id (str): The ID of the episode.
33
- run_id (str): The ID of the run.
34
- granularity (MetricGranularity): The granularity of the metric.
35
11
  """
36
- self._name = name
37
- self._metric_type = metric_type
38
- self._granularity = granularity
39
12
  self._values = []
40
13
  self._x_values = []
41
- self._episode_id = episode_id
42
- self._run_id = run_id
43
- self._last_step = -1
44
-
45
- def reset(self,
46
- episode_id: str | None = None) -> None:
47
- """Reset the metric for a new episode or run.
14
+ self._step = -1
15
+ self._metric_dim_type = metric_dim_type
16
+ self._graph_type = graph_type
48
17
 
49
- Args:
50
- episode_id (str | None): Optional new episode ID. If None, keeps the current episode ID.
51
- """
52
- if self._granularity != MetricGranularity.RUN:
53
- self._submit()
54
- self._values = []
55
- self._x_values = []
56
- self._last_step = -1
57
- self._episode_id = episode_id
58
-
59
18
  @property
60
- def name(self) -> str:
61
- """The name of the metric.
62
-
63
- Returns:
64
- str: The name of the metric.
65
- """
66
- return self._name
19
+ def metric_dim_type(self) -> MetricDimType:
20
+ return self._metric_dim_type
67
21
 
68
22
  @property
69
- def metric_type(self) -> MetricType:
70
- """The type of the metric.
71
-
72
- Returns:
73
- MetricType: The type of the metric.
74
- """
75
- return self._metric_type
23
+ def graph_type(self) -> GraphType:
24
+ return self._graph_type
76
25
 
77
- @property
78
- def granularity(self) -> MetricGranularity:
79
- """The granularity of the metric.
80
-
81
- Returns:
82
- MetricGranularity: The granularity of the metric.
83
- """
84
- return self._granularity
85
-
86
- def log(self, data: Any, step: int | None = None, replace: bool = True) -> None:
26
+ def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
87
27
  """Log a new data point for the metric. The behavior depends on the granularity.
88
28
 
89
29
  Args:
90
30
  data (Any): The data point to log.
91
- step (int | None): The step number for STEP granularity. Must be provided if granularity is STEP.
92
- replace (bool): Whether to replace the last logged value if logging at the same step/episode/run.
31
+ x (Any | None): The x-axis value associated with the data point.
32
+ if None, the current step is used.
33
+ replace (bool): Whether to replace the last logged value.
93
34
  """
94
- if self._granularity == MetricGranularity.STEP:
95
- if step is None:
96
- raise ValueError("step Must be provided!")
97
- if step == self._last_step:
98
- if replace:
99
- self._values[-1] = data
100
- else:
101
- raise ValueError("Cannot log the data at the same step.")
35
+ if replace:
36
+ self._values[-1] = data
37
+ if x is not None:
38
+ self._x_values[-1] = x
39
+ else:
40
+ self._values.append(data)
41
+ if x is not None:
42
+ self._x_values.append(x)
102
43
  else:
103
- self._values.append(data)
104
- self._x_values.append(step)
105
- elif self._granularity == MetricGranularity.EPISODE:
106
- if len(self._x_values) > 0 and not replace:
107
- raise ValueError("Cannot log the data at the same episode.")
108
- self._values = [data]
109
- self._x_values = [self._episode_id]
110
- else: # MetricGranularity.RUN
111
- if len(self._values) > 0 and not replace:
112
- raise ValueError("Cannot log the data at the same run.")
113
- self._values = [data]
114
- self._x_values = [self._run_id]
115
-
116
- def _submit(self) -> None:
117
- if not self._values:
118
- # If there is no data to submit, then return.
119
- return
120
- # TODO: Implement commit logic
121
-
122
- # Clear data after the submission.
44
+ self._x_values.append(self._step)
45
+ self._step += 1
46
+
47
+ def finalize(self) -> dict:
48
+ """Finalize the logged data for processing."""
49
+ ret_result = self._finalize()
50
+
51
+ return ret_result
52
+
53
+ def _finalize(self) -> dict:
54
+ """Process the logged data before submission. To be implemented by subclasses."""
55
+ ret_val = {
56
+ "values": self._values,
57
+ "x_values": self._x_values
58
+ }
123
59
  self._values = []
124
60
  self._x_values = []
125
-
126
- def finish(self) -> None:
127
- """Finish the metric logging and submit the final data."""
128
- self.reset()
129
- self._submit()
61
+ self._step = -1
62
+ return ret_val
humalab/metrics/scenario_stats.py ADDED
@@ -0,0 +1,95 @@
from humalab.metrics.metric import Metrics
from humalab.constants import ArtifactType, GraphType, MetricDimType
from humalab.humalab_api_client import EpisodeStatus
from typing import Any


# Distribution types whose logged samples are reduced to ``data[0]`` before
# being stored. NOTE(review): presumably these ``*_1d`` samplers yield a
# length-1 sequence per episode -- confirm against the distribution classes.
SCENARIO_STATS_NEED_FLATTEN = {
    "uniform_1d",
    "bernoulli_1d",
    "categorical_1d",
    "discrete_1d",
    "log_uniform_1d",
    "gaussian_1d",
    "truncated_gaussian_1d"
}


class ScenarioStats(Metrics):
    """Per-episode values and statuses for one named scenario distribution.

    Values recorded with :meth:`log` are keyed by episode id. Episode
    outcomes recorded with :meth:`log_status` are kept in a separate mapping.
    ``finalize()`` returns both mappings together with the distribution type
    and then clears them.

    Args:
        name: Name of the tracked scenario statistic.
        distribution_type: Distribution type string. Types listed in
            ``SCENARIO_STATS_NEED_FLATTEN`` have ``data[0]`` stored instead of
            ``data``.
        metric_dim_type: Dimensionality, passed through to :class:`Metrics`.
        graph_type: Graph type, passed through to :class:`Metrics`.
    """

    def __init__(self,
                 name: str,
                 distribution_type: str,
                 metric_dim_type: MetricDimType,
                 graph_type: GraphType,
                 ) -> None:
        super().__init__(
            metric_dim_type=metric_dim_type,
            graph_type=graph_type
        )
        self._name = name
        self._distribution_type = distribution_type
        self._artifact_type = ArtifactType.SCENARIO_STATS
        # Replaces the base-class list: values here are keyed by episode id.
        self._values: dict[Any, Any] = {}
        # Episode id -> EpisodeStatus recorded via log_status().
        self._results: dict[Any, EpisodeStatus] = {}

    @property
    def name(self) -> str:
        """Name of the tracked scenario statistic."""
        return self._name

    @property
    def distribution_type(self) -> str:
        """Distribution type string given at construction."""
        return self._distribution_type

    @property
    def artifact_type(self) -> ArtifactType:
        """Always ``ArtifactType.SCENARIO_STATS``."""
        return self._artifact_type

    def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
        """Record ``data`` for the episode identified by ``x``.

        Args:
            data (Any): The value to record. For distribution types in
                ``SCENARIO_STATS_NEED_FLATTEN`` only ``data[0]`` is stored.
            x (Any): The episode id used as the key. ``None`` is used as a
                key as-is; unlike the base class, no step counter is
                substituted.
            replace (bool): Overwrite an existing value for ``x`` instead
                of raising.

        Raises:
            ValueError: If a value for ``x`` already exists and ``replace``
                is False.
        """
        # Check for duplicates before touching ``data`` so a rejected call
        # has no side effects.
        if x in self._values and not replace:
            raise ValueError(f"Data for episode_id {x} already exists. Use replace=True to overwrite.")
        if self._distribution_type in SCENARIO_STATS_NEED_FLATTEN:
            data = data[0]
        self._values[x] = data

    def log_status(self,
                   episode_id: str,
                   episode_status: EpisodeStatus,
                   replace: bool = False) -> None:
        """Record the outcome status of an episode.

        Args:
            episode_id (str): Identifier of the episode.
            episode_status (EpisodeStatus): The status to record for it.
            replace (bool): Overwrite an existing status for ``episode_id``
                instead of raising.

        Raises:
            ValueError: If a status for ``episode_id`` already exists and
                ``replace`` is False.
        """
        if episode_id in self._results and not replace:
            raise ValueError(f"Data for episode_id {episode_id} already exists. Use replace=True to overwrite.")
        self._results[episode_id] = episode_status

    def _finalize(self) -> dict:
        """Return the collected values and statuses, then reset both.

        Returns:
            dict: ``{"values": ..., "results": ..., "distribution_type": ...}``
            where the first two are the episode-keyed mappings collected so far.
        """
        ret_val = {
            "values": self._values,
            "results": self._results,
            "distribution_type": self._distribution_type,
        }
        # Rebind rather than clear() so the returned dicts stay intact.
        self._values = {}
        self._results = {}
        return ret_val