humalab-0.0.5-py3-none-any.whl → humalab-0.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (37)
  1. humalab/__init__.py +11 -0
  2. humalab/assets/__init__.py +2 -2
  3. humalab/assets/files/resource_file.py +29 -3
  4. humalab/assets/files/urdf_file.py +14 -10
  5. humalab/assets/resource_operator.py +91 -0
  6. humalab/constants.py +39 -5
  7. humalab/dists/bernoulli.py +2 -1
  8. humalab/dists/discrete.py +2 -2
  9. humalab/dists/gaussian.py +2 -2
  10. humalab/dists/log_uniform.py +2 -2
  11. humalab/dists/truncated_gaussian.py +4 -4
  12. humalab/episode.py +181 -11
  13. humalab/humalab.py +44 -28
  14. humalab/humalab_api_client.py +301 -94
  15. humalab/humalab_test.py +46 -17
  16. humalab/metrics/__init__.py +5 -5
  17. humalab/metrics/code.py +28 -0
  18. humalab/metrics/metric.py +41 -108
  19. humalab/metrics/scenario_stats.py +95 -0
  20. humalab/metrics/summary.py +24 -18
  21. humalab/run.py +180 -103
  22. humalab/scenarios/__init__.py +4 -0
  23. humalab/{scenario.py → scenarios/scenario.py} +120 -129
  24. humalab/scenarios/scenario_operator.py +82 -0
  25. humalab/{scenario_test.py → scenarios/scenario_test.py} +150 -269
  26. humalab/utils.py +37 -0
  27. {humalab-0.0.5.dist-info → humalab-0.0.6.dist-info}/METADATA +1 -1
  28. humalab-0.0.6.dist-info/RECORD +39 -0
  29. humalab/assets/resource_manager.py +0 -58
  30. humalab/evaluators/__init__.py +0 -16
  31. humalab/humalab_main.py +0 -119
  32. humalab/metrics/dist_metric.py +0 -22
  33. humalab-0.0.5.dist-info/RECORD +0 -37
  34. {humalab-0.0.5.dist-info → humalab-0.0.6.dist-info}/WHEEL +0 -0
  35. {humalab-0.0.5.dist-info → humalab-0.0.6.dist-info}/entry_points.txt +0 -0
  36. {humalab-0.0.5.dist-info → humalab-0.0.6.dist-info}/licenses/LICENSE +0 -0
  37. {humalab-0.0.5.dist-info → humalab-0.0.6.dist-info}/top_level.txt +0 -0
humalab/humalab_test.py CHANGED
@@ -2,12 +2,13 @@ import unittest
  from unittest.mock import patch, MagicMock, Mock
  import uuid
 
+ from humalab.constants import DEFAULT_PROJECT
  from humalab import humalab
  from humalab.run import Run
- from humalab.scenario import Scenario
+ from humalab.scenarios.scenario import Scenario
  from humalab.humalab_config import HumalabConfig
  from humalab.humalab_api_client import HumaLabApiClient
- from humalab.constants import EpisodeStatus
+ from humalab.humalab_api_client import EpisodeStatus, RunStatus
 
 
  class HumalabTest(unittest.TestCase):
@@ -30,9 +31,10 @@ class HumalabTest(unittest.TestCase):
  # Pre-condition
  client = Mock()
  scenario = {"key": "value"}
+ project = "test_project"
 
  # In-test
- result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=None)
+ result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=None)
 
  # Post-condition
  self.assertEqual(result, scenario)
@@ -42,32 +44,34 @@ class HumalabTest(unittest.TestCase):
  """Test that _pull_scenario fetches from API when scenario_id is provided."""
  # Pre-condition
  client = Mock()
+ project = "test_project"
  scenario_id = "test-scenario-id"
  yaml_content = "scenario: test"
  client.get_scenario.return_value = {"yaml_content": yaml_content}
 
  # In-test
- result = humalab._pull_scenario(client=client, scenario=None, scenario_id=scenario_id)
+ result = humalab._pull_scenario(client=client, project=project, scenario=None, scenario_id=scenario_id)
 
  # Post-condition
  self.assertEqual(result, yaml_content)
- client.get_scenario.assert_called_once_with(uuid=scenario_id)
+ client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
 
  def test_pull_scenario_should_prefer_scenario_id_over_scenario(self):
  """Test that _pull_scenario uses scenario_id even when scenario is provided."""
  # Pre-condition
  client = Mock()
+ project = "test_project"
  scenario = {"key": "value"}
  scenario_id = "test-scenario-id"
  yaml_content = "scenario: from_api"
  client.get_scenario.return_value = {"yaml_content": yaml_content}
 
  # In-test
- result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=scenario_id)
+ result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=scenario_id)
 
  # Post-condition
  self.assertEqual(result, yaml_content)
- client.get_scenario.assert_called_once_with(uuid=scenario_id)
+ client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
 
  # Tests for init context manager
 
@@ -92,6 +96,8 @@ class HumalabTest(unittest.TestCase):
  mock_config_class.return_value = mock_config
 
  mock_api_client = Mock()
+ mock_api_client.create_project.return_value = {"name": project}
+ mock_api_client.get_run.return_value = {"run_id": run_id, "name": name, "description": description, "tags": tags}
  mock_api_client_class.return_value = mock_api_client
 
  mock_scenario_inst = Mock()
@@ -137,6 +143,8 @@ class HumalabTest(unittest.TestCase):
  mock_config_class.return_value = mock_config
 
  mock_api_client = Mock()
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
  mock_api_client_class.return_value = mock_api_client
 
  mock_scenario_inst = Mock()
@@ -149,7 +157,7 @@ class HumalabTest(unittest.TestCase):
  with humalab.init() as run:
  # Post-condition
  call_kwargs = mock_run_class.call_args.kwargs
- self.assertEqual(call_kwargs['project'], "default")
+ self.assertEqual(call_kwargs['project'], DEFAULT_PROJECT)
  self.assertEqual(call_kwargs['name'], "")
  self.assertEqual(call_kwargs['description'], "")
  self.assertIsNotNone(call_kwargs['id']) # UUID generated
@@ -170,7 +178,18 @@ class HumalabTest(unittest.TestCase):
  mock_config.timeout = 30.0
  mock_config_class.return_value = mock_config
 
+ # Mock HTTP 404 error for get_run (run doesn't exist yet)
+ import requests
+ http_error = requests.HTTPError()
+ http_error.response = Mock()
+ http_error.response.status_code = 404
+
  mock_api_client = Mock()
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+ mock_api_client.get_run.side_effect = http_error
+ # Mock create_run to return a valid UUID
+ generated_uuid = str(uuid.uuid4())
+ mock_api_client.create_run.return_value = {"run_id": generated_uuid, "name": "", "description": "", "tags": None}
  mock_api_client_class.return_value = mock_api_client
 
  mock_scenario_inst = Mock()
@@ -206,6 +225,8 @@ class HumalabTest(unittest.TestCase):
  mock_config_class.return_value = mock_config
 
  mock_api_client = Mock()
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
  mock_api_client_class.return_value = mock_api_client
 
  mock_scenario_inst = Mock()
@@ -241,6 +262,8 @@ class HumalabTest(unittest.TestCase):
  mock_config_class.return_value = mock_config
 
  mock_api_client = Mock()
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
  mock_api_client.get_scenario.return_value = {"yaml_content": yaml_content}
  mock_api_client_class.return_value = mock_api_client
 
@@ -253,7 +276,7 @@ class HumalabTest(unittest.TestCase):
  # In-test
  with humalab.init(scenario_id=scenario_id) as run:
  # Post-condition
- mock_api_client.get_scenario.assert_called_once_with(uuid=scenario_id)
+ mock_api_client.get_scenario.assert_called_once_with(project_name='default', uuid=scenario_id, version=None)
  mock_scenario_inst.init.assert_called_once()
  call_kwargs = mock_scenario_inst.init.call_args.kwargs
  self.assertEqual(call_kwargs['scenario'], yaml_content)
@@ -274,6 +297,8 @@ class HumalabTest(unittest.TestCase):
  mock_config_class.return_value = mock_config
 
  mock_api_client = Mock()
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
  mock_api_client_class.return_value = mock_api_client
 
  mock_scenario_inst = Mock()
@@ -304,6 +329,8 @@ class HumalabTest(unittest.TestCase):
  mock_config_class.return_value = mock_config
 
  mock_api_client = Mock()
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
  mock_api_client_class.return_value = mock_api_client
 
  mock_scenario_inst = Mock()
@@ -338,6 +365,8 @@ class HumalabTest(unittest.TestCase):
  mock_config_class.return_value = mock_config
 
  mock_api_client = Mock()
+ mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+ mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
  mock_api_client_class.return_value = mock_api_client
 
  mock_scenario_inst = Mock()
@@ -369,33 +398,33 @@ class HumalabTest(unittest.TestCase):
  humalab.finish()
 
  # Post-condition
- mock_run.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=None)
+ mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=None)
 
  def test_finish_should_call_finish_on_current_run_with_custom_status(self):
  """Test that finish() calls finish on the current run with custom status."""
  # Pre-condition
  mock_run = Mock()
  humalab._cur_run = mock_run
- status = EpisodeStatus.FAILED
+ status = RunStatus.ERRORED
 
  # In-test
  humalab.finish(status=status)
 
  # Post-condition
- mock_run.finish.assert_called_once_with(status=status, quiet=None)
+ mock_run.finish.assert_called_once_with(status=status, err_msg=None)
 
- def test_finish_should_call_finish_on_current_run_with_quiet_parameter(self):
- """Test that finish() calls finish on the current run with quiet parameter."""
+ def test_finish_should_call_finish_on_current_run_with_err_msg_parameter(self):
+ """Test that finish() calls finish on the current run with err_msg parameter."""
  # Pre-condition
  mock_run = Mock()
  humalab._cur_run = mock_run
- quiet = True
+ err_msg = "Test error message"
 
  # In-test
- humalab.finish(quiet=quiet)
+ humalab.finish(err_msg=err_msg)
 
  # Post-condition
- mock_run.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=quiet)
+ mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=err_msg)
 
  def test_finish_should_do_nothing_when_no_current_run(self):
  """Test that finish() does nothing when _cur_run is None."""
humalab/metrics/__init__.py CHANGED
@@ -1,11 +1,11 @@
- from .metric import MetricGranularity, MetricType, Metrics
- from .dist_metric import DistributionMetric
+ from .metric import Metrics
+ from .code import Code
+ from .scenario_stats import ScenarioStats
  from .summary import Summary
 
  __all__ = [
- "MetricGranularity",
- "MetricType",
+ "Code",
  "Metrics",
- "DistributionMetric",
+ "ScenarioStats",
  "Summary",
  ]
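After this change the public surface of humalab.metrics is exactly the four names re-exported above:

    from humalab.metrics import Code, Metrics, ScenarioStats, Summary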
humalab/metrics/code.py ADDED
@@ -0,0 +1,28 @@
+ class Code:
+ """Class for logging code artifacts."""
+ def __init__(self,
+ run_id: str,
+ key: str,
+ code_content: str,
+ episode_id: str | None = None) -> None:
+ super().__init__()
+ self._run_id = run_id
+ self._key = key
+ self._code_content = code_content
+ self._episode_id = episode_id
+
+ @property
+ def run_id(self) -> str:
+ return self._run_id
+
+ @property
+ def key(self) -> str:
+ return self._key
+
+ @property
+ def code_content(self) -> str:
+ return self._code_content
+
+ @property
+ def episode_id(self) -> str | None:
+ return self._episode_id
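Code is a plain value object for attaching source snippets to a run (or a single episode). A short illustration using placeholder IDs:

    from humalab.metrics.code import Code

    snippet = Code(
        run_id="run-123",            # placeholder run ID
        key="policy_source",
        code_content="def act(obs):\n    return 0\n",
        episode_id=None,             # optional; omit to attach at run scope
    )
    print(snippet.key, len(snippet.code_content))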
humalab/metrics/metric.py CHANGED
@@ -1,129 +1,62 @@
- from enum import Enum
- from typing import Any
- from humalab.constants import EpisodeStatus
-
-
- class MetricType(Enum):
- DEFAULT = "default"
- STREAM = "stream"
- DISTRIBUTION = "distribution"
- SUMMARY = "summary"
-
-
- class MetricGranularity(Enum):
- STEP = "step"
- EPISODE = "episode"
- RUN = "run"
+ from typing import Any
+ from humalab.constants import MetricDimType, GraphType
 
 
  class Metrics:
- def __init__(self,
- name: str,
- metric_type: MetricType,
- episode_id: str,
- run_id: str,
- granularity: MetricGranularity = MetricGranularity.STEP) -> None:
+ def __init__(self,
+ metric_dim_type: MetricDimType= MetricDimType.ONE_D,
+ graph_type: GraphType=GraphType.LINE) -> None:
  """
  Base class for different types of metrics.
-
- Args:
- name (str): The name of the metric.
- metric_type (MetricType): The type of the metric.
- episode_id (str): The ID of the episode.
- run_id (str): The ID of the run.
- granularity (MetricGranularity): The granularity of the metric.
  """
- self._name = name
- self._metric_type = metric_type
- self._granularity = granularity
  self._values = []
  self._x_values = []
- self._episode_id = episode_id
- self._run_id = run_id
- self._last_step = -1
-
- def reset(self,
- episode_id: str | None = None) -> None:
- """Reset the metric for a new episode or run.
+ self._step = -1
+ self._metric_dim_type = metric_dim_type
+ self._graph_type = graph_type
 
- Args:
- episode_id (str | None): Optional new episode ID. If None, keeps the current episode ID.
- """
- if self._granularity != MetricGranularity.RUN:
- self._submit()
- self._values = []
- self._x_values = []
- self._last_step = -1
- self._episode_id = episode_id
-
  @property
- def name(self) -> str:
- """The name of the metric.
-
- Returns:
- str: The name of the metric.
- """
- return self._name
+ def metric_dim_type(self) -> MetricDimType:
+ return self._metric_dim_type
 
  @property
- def metric_type(self) -> MetricType:
- """The type of the metric.
-
- Returns:
- MetricType: The type of the metric.
- """
- return self._metric_type
+ def graph_type(self) -> GraphType:
+ return self._graph_type
 
- @property
- def granularity(self) -> MetricGranularity:
- """The granularity of the metric.
-
- Returns:
- MetricGranularity: The granularity of the metric.
- """
- return self._granularity
-
- def log(self, data: Any, step: int | None = None, replace: bool = True) -> None:
+ def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
  """Log a new data point for the metric. The behavior depends on the granularity.
 
  Args:
  data (Any): The data point to log.
- step (int | None): The step number for STEP granularity. Must be provided if granularity is STEP.
- replace (bool): Whether to replace the last logged value if logging at the same step/episode/run.
+ x (Any | None): The x-axis value associated with the data point.
+ if None, the current step is used.
+ replace (bool): Whether to replace the last logged value.
  """
- if self._granularity == MetricGranularity.STEP:
- if step is None:
- raise ValueError("step Must be provided!")
- if step == self._last_step:
- if replace:
- self._values[-1] = data
- else:
- raise ValueError("Cannot log the data at the same step.")
+ if replace:
+ self._values[-1] = data
+ if x is not None:
+ self._x_values[-1] = x
+ else:
+ self._values.append(data)
+ if x is not None:
+ self._x_values.append(x)
  else:
- self._values.append(data)
- self._x_values.append(step)
- elif self._granularity == MetricGranularity.EPISODE:
- if len(self._x_values) > 0 and not replace:
- raise ValueError("Cannot log the data at the same episode.")
- self._values = [data]
- self._x_values = [self._episode_id]
- else: # MetricGranularity.RUN
- if len(self._values) > 0 and not replace:
- raise ValueError("Cannot log the data at the same run.")
- self._values = [data]
- self._x_values = [self._run_id]
-
- def _submit(self) -> None:
- if not self._values:
- # If there is no data to submit, then return.
- return
- # TODO: Implement commit logic
-
- # Clear data after the submission.
+ self._x_values.append(self._step)
+ self._step += 1
+
+ def finalize(self) -> dict:
+ """Finalize the logged data for processing."""
+ ret_result = self._finalize()
+
+ return ret_result
+
+ def _finalize(self) -> dict:
+ """Process the logged data before submission. To be implemented by subclasses."""
+ ret_val = {
+ "values": self._values,
+ "x_values": self._x_values
+ }
  self._values = []
  self._x_values = []
-
- def finish(self) -> None:
- """Finish the metric logging and submit the final data."""
- self.reset()
- self._submit()
+ self._step = -1
+ return ret_val
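The reworked base class drops the granularity/submit machinery: log() appends a value (or replaces the last one), using an internal step counter as the x value when none is given, and finalize() drains the buffers through the overridable _finalize() hook. A rough sketch of that flow, with values chosen only for illustration:

    from humalab.metrics.metric import Metrics
    from humalab.constants import MetricDimType, GraphType

    m = Metrics(metric_dim_type=MetricDimType.ONE_D, graph_type=GraphType.LINE)
    m.log(0.5)                # x defaults to the internal step counter
    m.log(0.7, x=10)          # or pass an explicit x value
    m.log(0.9, replace=True)  # overwrite the last logged point

    payload = m.finalize()    # {"values": [...], "x_values": [...]}; buffers reset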
humalab/metrics/scenario_stats.py ADDED
@@ -0,0 +1,95 @@
+ from humalab.metrics.metric import Metrics
+ from humalab.constants import ArtifactType, GraphType, MetricDimType
+ from humalab.humalab_api_client import EpisodeStatus
+ from typing import Any
+
+
+ SCENARIO_STATS_NEED_FLATTEN = {
+ "uniform_1d",
+ "bernoulli_1d",
+ "categorical_1d",
+ "discrete_1d",
+ "log_uniform_1d",
+ "gaussian_1d",
+ "truncated_gaussian_1d"
+ }
+
+
+ class ScenarioStats(Metrics):
+ """Metric to track scenario statistics such as total reward, length, and success.
+
+ Attributes:
+ """
+
+ def __init__(self,
+ name: str,
+ distribution_type: str,
+ metric_dim_type: MetricDimType,
+ graph_type: GraphType,
+ ) -> None:
+ super().__init__(
+ metric_dim_type=metric_dim_type,
+ graph_type=graph_type
+ )
+ self._name = name
+ self._distribution_type = distribution_type
+ self._artifact_type = ArtifactType.SCENARIO_STATS
+ self._values = {}
+ self._results = {}
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def distribution_type(self) -> str:
+ return self._distribution_type
+
+ @property
+ def artifact_type(self) -> ArtifactType:
+ return self._artifact_type
+
+ def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
+ if x in self._values:
+ if replace:
+ if self._distribution_type in SCENARIO_STATS_NEED_FLATTEN:
+ data = data[0]
+ self._values[x] = data
+ else:
+ raise ValueError(f"Data for episode_id {x} already exists. Use replace=True to overwrite.")
+ else:
+ if self._distribution_type in SCENARIO_STATS_NEED_FLATTEN:
+ data = data[0]
+ self._values[x] = data
+
+ def log_status(self,
+ episode_id: str,
+ episode_status: EpisodeStatus,
+ replace: bool = False) -> None:
+ """Log a new data point for the metric. The behavior depends on the granularity.
+
+ Args:
+ data (Any): The data point to log.
+ x (Any | None): The x-axis value associated with the data point.
+ if None, the current step is used.
+ replace (bool): Whether to replace the last logged value.
+ """
+ if episode_id in self._results:
+ if replace:
+ self._results[episode_id] = episode_status
+ else:
+ raise ValueError(f"Data for episode_id {episode_id} already exists. Use replace=True to overwrite.")
+ else:
+ self._results[episode_id] = episode_status
+
+ def _finalize(self) -> dict:
+ ret_val = {
+ "values": self._values,
+ "results": self._results,
+ "distribution_type": self._distribution_type,
+ }
+ self._values = {}
+ self._results = {}
+ return ret_val
+
+
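ScenarioStats keys sampled scenario values and per-episode outcomes by episode ID, flattening single-element samples for the 1-D distribution types listed in SCENARIO_STATS_NEED_FLATTEN. A hedged usage sketch; the distribution name and IDs are examples, and the EpisodeStatus member name is assumed since its values are not part of this diff:

    from humalab.metrics.scenario_stats import ScenarioStats
    from humalab.constants import MetricDimType, GraphType
    from humalab.humalab_api_client import EpisodeStatus

    stats = ScenarioStats(
        name="object_mass",
        distribution_type="gaussian_1d",      # in SCENARIO_STATS_NEED_FLATTEN
        metric_dim_type=MetricDimType.ONE_D,
        graph_type=GraphType.LINE,
    )
    stats.log([3.2], x="episode-1")           # 1-D sample flattened to 3.2
    stats.log_status("episode-1", EpisodeStatus.PASS)   # member name assumed
    print(stats.finalize())                   # values, results, distribution_type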
humalab/metrics/summary.py CHANGED
@@ -1,15 +1,11 @@
 
- from humalab.metrics.metric import MetricGranularity, Metrics, MetricType
- from humalab.constants import EpisodeStatus
+ from humalab.metrics.metric import Metrics
+ from humalab.constants import MetricDimType, GraphType
 
 
  class Summary(Metrics):
  def __init__(self,
- name: str,
  summary: str,
- episode_id: str,
- run_id: str,
- granularity: MetricGranularity = MetricGranularity.RUN,
  ) -> None:
  """
  A summary metric that captures a single value per episode or run.
@@ -22,26 +18,33 @@ class Summary(Metrics):
  from being generated.
  granularity (MetricGranularity): The granularity of the metric.
  """
- if granularity == MetricGranularity.RUN:
- raise ValueError("Summary metrics cannot have RUN granularity.")
  if summary not in {"min", "max", "mean", "last", "first", "none"}:
  raise ValueError(f"Unsupported summary type: {summary}. Supported types are 'min', 'max', 'mean', 'last', 'first', and 'none'.")
- super().__init__(name, MetricType.SUMMARY, episode_id=episode_id, run_id=run_id, granularity=granularity)
- self.summary = summary
+ super().__init__(metric_dim_type= MetricDimType.ZERO_D,
+ graph_type=GraphType.NUMERIC)
+ self._summary = summary
+
+ @property
+ def summary(self) -> str:
+ return self._summary
 
- def _submit(self) -> None:
+ def _finalize(self) -> dict:
  if not self._values:
- return
+ return {
+ "value": None,
+ "summary": self.summary
+ }
+ final_val = None
  # For summary metrics, we only keep the latest value
  if self.summary == "last":
- self._values = [self._values[-1]]
+ final_val = self._values[-1]
  elif self.summary == "first":
- self._values = [self._values[0]]
+ final_val = self._values[0]
  elif self.summary == "none":
- self._values = []
+ final_val = None
  elif self.summary in {"min", "max", "mean"}:
  if not self._values:
- self._values = []
+ final_val = None
  else:
  if self.summary == "min":
  agg_value = min(self._values)
@@ -49,6 +52,9 @@ class Summary(Metrics):
  agg_value = max(self._values)
  elif self.summary == "mean":
  agg_value = sum(self._values) / len(self._values)
- self._values = [agg_value]
+ final_val = agg_value
 
- super()._submit()
+ return {
+ "value": final_val,
+ "summary": self.summary
+ }
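Summary now computes the aggregate itself and returns it from _finalize() instead of trimming self._values inside _submit(). Given the code above, a small worked example (finalize() is inherited from Metrics):

    from humalab.metrics.summary import Summary

    s = Summary(summary="mean")
    s.log(1.0)
    s.log(2.0)
    s.log(6.0)
    print(s.finalize())   # expected: {"value": 3.0, "summary": "mean"}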