humalab 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of humalab has been flagged as possibly problematic by the registry scanner.

Files changed (42)
  1. humalab/__init__.py +25 -0
  2. humalab/assets/__init__.py +8 -2
  3. humalab/assets/files/resource_file.py +96 -6
  4. humalab/assets/files/urdf_file.py +49 -11
  5. humalab/assets/resource_operator.py +139 -0
  6. humalab/constants.py +48 -5
  7. humalab/dists/__init__.py +7 -0
  8. humalab/dists/bernoulli.py +26 -1
  9. humalab/dists/categorical.py +25 -0
  10. humalab/dists/discrete.py +27 -2
  11. humalab/dists/distribution.py +11 -0
  12. humalab/dists/gaussian.py +27 -2
  13. humalab/dists/log_uniform.py +29 -3
  14. humalab/dists/truncated_gaussian.py +33 -4
  15. humalab/dists/uniform.py +24 -0
  16. humalab/episode.py +291 -11
  17. humalab/humalab.py +93 -38
  18. humalab/humalab_api_client.py +297 -95
  19. humalab/humalab_config.py +49 -0
  20. humalab/humalab_test.py +46 -17
  21. humalab/metrics/__init__.py +11 -5
  22. humalab/metrics/code.py +59 -0
  23. humalab/metrics/metric.py +69 -102
  24. humalab/metrics/scenario_stats.py +163 -0
  25. humalab/metrics/summary.py +45 -24
  26. humalab/run.py +224 -101
  27. humalab/scenarios/__init__.py +11 -0
  28. humalab/{scenario.py → scenarios/scenario.py} +130 -136
  29. humalab/scenarios/scenario_operator.py +114 -0
  30. humalab/{scenario_test.py → scenarios/scenario_test.py} +150 -269
  31. humalab/utils.py +37 -0
  32. {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/METADATA +1 -1
  33. humalab-0.0.7.dist-info/RECORD +39 -0
  34. humalab/assets/resource_manager.py +0 -58
  35. humalab/evaluators/__init__.py +0 -16
  36. humalab/humalab_main.py +0 -119
  37. humalab/metrics/dist_metric.py +0 -22
  38. humalab-0.0.5.dist-info/RECORD +0 -37
  39. {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/WHEEL +0 -0
  40. {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/entry_points.txt +0 -0
  41. {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/licenses/LICENSE +0 -0
  42. {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/top_level.txt +0 -0
humalab/humalab_config.py CHANGED
@@ -3,6 +3,18 @@ import yaml
 import os
 
 class HumalabConfig:
+    """Manages HumaLab SDK configuration settings.
+
+    Configuration is stored in ~/.humalab/config.yaml and includes workspace path,
+    API credentials, and connection settings. Values are automatically loaded on
+    initialization and saved when modified through property setters.
+
+    Attributes:
+        workspace_path (str): The local workspace directory path.
+        base_url (str): The HumaLab API base URL.
+        api_key (str): The API key for authentication.
+        timeout (float): Request timeout in seconds.
+    """
     def __init__(self):
         self._config = {
             "workspace_path": "",
@@ -17,6 +29,7 @@ class HumalabConfig:
         self._load_config()
 
     def _load_config(self):
+        """Load configuration from ~/.humalab/config.yaml."""
         home_path = Path.home()
         config_path = home_path / ".humalab" / "config.yaml"
         if not config_path.exists():
@@ -30,10 +43,16 @@ class HumalabConfig:
         self._timeout = self._config["timeout"] if self._config and "timeout" in self._config else 30.0
 
     def _save(self) -> None:
+        """Save current configuration to ~/.humalab/config.yaml."""
         yaml.dump(self._config, open(Path.home() / ".humalab" / "config.yaml", "w"))
 
     @property
     def workspace_path(self) -> str:
+        """The local workspace directory path.
+
+        Returns:
+            str: The workspace path.
+        """
         return str(self._workspace_path)
 
     @workspace_path.setter
@@ -44,30 +63,60 @@ class HumalabConfig:
 
     @property
     def base_url(self) -> str:
+        """The HumaLab API base URL.
+
+        Returns:
+            str: The base URL.
+        """
         return str(self._base_url)
 
     @base_url.setter
     def base_url(self, base_url: str) -> None:
+        """Set the HumaLab API base URL and save to config.
+
+        Args:
+            base_url (str): The new base URL.
+        """
         self._base_url = base_url
         self._config["base_url"] = base_url
         self._save()
 
     @property
     def api_key(self) -> str:
+        """The API key for authentication.
+
+        Returns:
+            str: The API key.
+        """
         return str(self._api_key)
 
     @api_key.setter
     def api_key(self, api_key: str) -> None:
+        """Set the API key and save to config.
+
+        Args:
+            api_key (str): The new API key.
+        """
        self._api_key = api_key
         self._config["api_key"] = api_key
         self._save()
 
     @property
     def timeout(self) -> float:
+        """Request timeout in seconds.
+
+        Returns:
+            float: The timeout value.
+        """
         return self._timeout
 
     @timeout.setter
     def timeout(self, timeout: float) -> None:
+        """Set the request timeout and save to config.
+
+        Args:
+            timeout (float): The new timeout in seconds.
+        """
         self._timeout = timeout
         self._config["timeout"] = timeout
         self._save()
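
The new docstrings describe the intended usage of HumalabConfig: values are loaded from ~/.humalab/config.yaml on construction and each property setter persists the file. A minimal sketch of that flow (the URL and key values are illustrative, not taken from the package):

    from humalab.humalab_config import HumalabConfig

    config = HumalabConfig()                        # loads ~/.humalab/config.yaml if present
    config.base_url = "https://api.example.com"     # hypothetical URL; setter writes the file
    config.api_key = "my-api-key"                   # illustrative key; also persisted
    config.timeout = 60.0
    print(config.workspace_path)
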
humalab/humalab_test.py CHANGED
@@ -2,12 +2,13 @@ import unittest
 from unittest.mock import patch, MagicMock, Mock
 import uuid
 
+from humalab.constants import DEFAULT_PROJECT
 from humalab import humalab
 from humalab.run import Run
-from humalab.scenario import Scenario
+from humalab.scenarios.scenario import Scenario
 from humalab.humalab_config import HumalabConfig
 from humalab.humalab_api_client import HumaLabApiClient
-from humalab.constants import EpisodeStatus
+from humalab.humalab_api_client import EpisodeStatus, RunStatus
 
 
 class HumalabTest(unittest.TestCase):
@@ -30,9 +31,10 @@ class HumalabTest(unittest.TestCase):
         # Pre-condition
         client = Mock()
         scenario = {"key": "value"}
+        project = "test_project"
 
         # In-test
-        result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=None)
+        result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=None)
 
         # Post-condition
         self.assertEqual(result, scenario)
@@ -42,32 +44,34 @@ class HumalabTest(unittest.TestCase):
         """Test that _pull_scenario fetches from API when scenario_id is provided."""
         # Pre-condition
         client = Mock()
+        project = "test_project"
         scenario_id = "test-scenario-id"
         yaml_content = "scenario: test"
         client.get_scenario.return_value = {"yaml_content": yaml_content}
 
         # In-test
-        result = humalab._pull_scenario(client=client, scenario=None, scenario_id=scenario_id)
+        result = humalab._pull_scenario(client=client, project=project, scenario=None, scenario_id=scenario_id)
 
         # Post-condition
         self.assertEqual(result, yaml_content)
-        client.get_scenario.assert_called_once_with(uuid=scenario_id)
+        client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
 
     def test_pull_scenario_should_prefer_scenario_id_over_scenario(self):
         """Test that _pull_scenario uses scenario_id even when scenario is provided."""
         # Pre-condition
         client = Mock()
+        project = "test_project"
         scenario = {"key": "value"}
         scenario_id = "test-scenario-id"
         yaml_content = "scenario: from_api"
         client.get_scenario.return_value = {"yaml_content": yaml_content}
 
         # In-test
-        result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=scenario_id)
+        result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=scenario_id)
 
         # Post-condition
         self.assertEqual(result, yaml_content)
-        client.get_scenario.assert_called_once_with(uuid=scenario_id)
+        client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
 
     # Tests for init context manager
 
@@ -92,6 +96,8 @@ class HumalabTest(unittest.TestCase):
         mock_config_class.return_value = mock_config
 
         mock_api_client = Mock()
+        mock_api_client.create_project.return_value = {"name": project}
+        mock_api_client.get_run.return_value = {"run_id": run_id, "name": name, "description": description, "tags": tags}
         mock_api_client_class.return_value = mock_api_client
 
         mock_scenario_inst = Mock()
@@ -137,6 +143,8 @@ class HumalabTest(unittest.TestCase):
         mock_config_class.return_value = mock_config
 
         mock_api_client = Mock()
+        mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+        mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
         mock_api_client_class.return_value = mock_api_client
 
         mock_scenario_inst = Mock()
@@ -149,7 +157,7 @@ class HumalabTest(unittest.TestCase):
         with humalab.init() as run:
             # Post-condition
             call_kwargs = mock_run_class.call_args.kwargs
-            self.assertEqual(call_kwargs['project'], "default")
+            self.assertEqual(call_kwargs['project'], DEFAULT_PROJECT)
             self.assertEqual(call_kwargs['name'], "")
             self.assertEqual(call_kwargs['description'], "")
             self.assertIsNotNone(call_kwargs['id'])  # UUID generated
@@ -170,7 +178,18 @@ class HumalabTest(unittest.TestCase):
         mock_config.timeout = 30.0
         mock_config_class.return_value = mock_config
 
+        # Mock HTTP 404 error for get_run (run doesn't exist yet)
+        import requests
+        http_error = requests.HTTPError()
+        http_error.response = Mock()
+        http_error.response.status_code = 404
+
         mock_api_client = Mock()
+        mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+        mock_api_client.get_run.side_effect = http_error
+        # Mock create_run to return a valid UUID
+        generated_uuid = str(uuid.uuid4())
+        mock_api_client.create_run.return_value = {"run_id": generated_uuid, "name": "", "description": "", "tags": None}
         mock_api_client_class.return_value = mock_api_client
 
         mock_scenario_inst = Mock()
@@ -206,6 +225,8 @@ class HumalabTest(unittest.TestCase):
         mock_config_class.return_value = mock_config
 
         mock_api_client = Mock()
+        mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+        mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
         mock_api_client_class.return_value = mock_api_client
 
         mock_scenario_inst = Mock()
@@ -241,6 +262,8 @@ class HumalabTest(unittest.TestCase):
         mock_config_class.return_value = mock_config
 
         mock_api_client = Mock()
+        mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+        mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
         mock_api_client.get_scenario.return_value = {"yaml_content": yaml_content}
         mock_api_client_class.return_value = mock_api_client
 
@@ -253,7 +276,7 @@ class HumalabTest(unittest.TestCase):
         # In-test
         with humalab.init(scenario_id=scenario_id) as run:
             # Post-condition
-            mock_api_client.get_scenario.assert_called_once_with(uuid=scenario_id)
+            mock_api_client.get_scenario.assert_called_once_with(project_name='default', uuid=scenario_id, version=None)
             mock_scenario_inst.init.assert_called_once()
             call_kwargs = mock_scenario_inst.init.call_args.kwargs
             self.assertEqual(call_kwargs['scenario'], yaml_content)
@@ -274,6 +297,8 @@ class HumalabTest(unittest.TestCase):
         mock_config_class.return_value = mock_config
 
         mock_api_client = Mock()
+        mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+        mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
         mock_api_client_class.return_value = mock_api_client
 
         mock_scenario_inst = Mock()
@@ -304,6 +329,8 @@ class HumalabTest(unittest.TestCase):
         mock_config_class.return_value = mock_config
 
         mock_api_client = Mock()
+        mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+        mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
         mock_api_client_class.return_value = mock_api_client
 
         mock_scenario_inst = Mock()
@@ -338,6 +365,8 @@ class HumalabTest(unittest.TestCase):
         mock_config_class.return_value = mock_config
 
         mock_api_client = Mock()
+        mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
+        mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
         mock_api_client_class.return_value = mock_api_client
 
         mock_scenario_inst = Mock()
@@ -369,33 +398,33 @@ class HumalabTest(unittest.TestCase):
         humalab.finish()
 
         # Post-condition
-        mock_run.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=None)
+        mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=None)
 
     def test_finish_should_call_finish_on_current_run_with_custom_status(self):
         """Test that finish() calls finish on the current run with custom status."""
         # Pre-condition
         mock_run = Mock()
         humalab._cur_run = mock_run
-        status = EpisodeStatus.FAILED
+        status = RunStatus.ERRORED
 
         # In-test
         humalab.finish(status=status)
 
         # Post-condition
-        mock_run.finish.assert_called_once_with(status=status, quiet=None)
+        mock_run.finish.assert_called_once_with(status=status, err_msg=None)
 
-    def test_finish_should_call_finish_on_current_run_with_quiet_parameter(self):
-        """Test that finish() calls finish on the current run with quiet parameter."""
+    def test_finish_should_call_finish_on_current_run_with_err_msg_parameter(self):
+        """Test that finish() calls finish on the current run with err_msg parameter."""
         # Pre-condition
         mock_run = Mock()
         humalab._cur_run = mock_run
-        quiet = True
+        err_msg = "Test error message"
 
         # In-test
-        humalab.finish(quiet=quiet)
+        humalab.finish(err_msg=err_msg)
 
         # Post-condition
-        mock_run.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=quiet)
+        mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=err_msg)
 
     def test_finish_should_do_nothing_when_no_current_run(self):
         """Test that finish() does nothing when _cur_run is None."""
humalab/metrics/__init__.py CHANGED
@@ -1,11 +1,17 @@
-from .metric import MetricGranularity, MetricType, Metrics
-from .dist_metric import DistributionMetric
+"""Metrics tracking and management.
+
+This module provides classes for tracking various types of metrics during runs and episodes,
+including general metrics, summary statistics, code artifacts, and scenario statistics.
+"""
+
+from .metric import Metrics
+from .code import Code
+from .scenario_stats import ScenarioStats
 from .summary import Summary
 
 __all__ = [
-    "MetricGranularity",
-    "MetricType",
+    "Code",
     "Metrics",
-    "DistributionMetric",
+    "ScenarioStats",
     "Summary",
 ]
humalab/metrics/code.py ADDED
@@ -0,0 +1,59 @@
+class Code:
+    """Class for logging code artifacts.
+
+    Code artifacts capture source code or configuration files associated with
+    runs or episodes. They are stored as text content and can be retrieved later
+    for reproducibility and debugging purposes.
+
+    Attributes:
+        run_id (str): The unique identifier of the associated run.
+        key (str): The artifact key/name for this code.
+        code_content (str): The actual code or text content.
+        episode_id (str | None): Optional episode identifier if scoped to an episode.
+    """
+    def __init__(self,
+                 run_id: str,
+                 key: str,
+                 code_content: str,
+                 episode_id: str | None = None) -> None:
+        super().__init__()
+        self._run_id = run_id
+        self._key = key
+        self._code_content = code_content
+        self._episode_id = episode_id
+
+    @property
+    def run_id(self) -> str:
+        """The unique identifier of the associated run.
+
+        Returns:
+            str: The run ID.
+        """
+        return self._run_id
+
+    @property
+    def key(self) -> str:
+        """The artifact key/name for this code.
+
+        Returns:
+            str: The artifact key.
+        """
+        return self._key
+
+    @property
+    def code_content(self) -> str:
+        """The actual code or text content.
+
+        Returns:
+            str: The code content.
+        """
+        return self._code_content
+
+    @property
+    def episode_id(self) -> str | None:
+        """Optional episode identifier if scoped to an episode.
+
+        Returns:
+            str | None: The episode ID, or None if run-scoped.
+        """
+        return self._episode_id
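
Code is a plain value object: constructing one only records the content and its identifiers, and the new __init__.py exports it from humalab.metrics. A minimal sketch, with illustrative IDs and content:

    from humalab.metrics import Code

    snippet = Code(run_id="run-1234",                       # illustrative run ID
                   key="policy_config",                     # illustrative artifact key
                   code_content="lr: 0.001\nbatch_size: 64",
                   episode_id=None)                         # None means run-scoped
    print(snippet.key, len(snippet.code_content))
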
humalab/metrics/metric.py CHANGED
@@ -1,129 +1,96 @@
-from enum import Enum
-from typing import Any
-from humalab.constants import EpisodeStatus
+from typing import Any
+from humalab.constants import MetricDimType, GraphType
 
+GRAPH_TO_DIM_TYPE = {
+    GraphType.NUMERIC: MetricDimType.ZERO_D,
+    GraphType.LINE: MetricDimType.ONE_D,
+    GraphType.HISTOGRAM: MetricDimType.ONE_D,
+    GraphType.BAR: MetricDimType.ONE_D,
+    GraphType.GAUSSIAN: MetricDimType.ONE_D,
+    GraphType.SCATTER: MetricDimType.TWO_D,
+    GraphType.HEATMAP: MetricDimType.TWO_D,
+    GraphType.THREE_D_MAP: MetricDimType.THREE_D,
+}
 
-class MetricType(Enum):
-    DEFAULT = "default"
-    STREAM = "stream"
-    DISTRIBUTION = "distribution"
-    SUMMARY = "summary"
 
+class Metrics:
+    """Base class for tracking and logging metrics during runs and episodes.
 
-class MetricGranularity(Enum):
-    STEP = "step"
-    EPISODE = "episode"
-    RUN = "run"
+    Metrics provide a flexible way to log time-series data or aggregated values
+    during experiments. Data points are collected with optional x-axis values
+    and can be visualized using different graph types.
 
+    Subclasses should override _finalize() to implement custom processing logic.
+
+    Attributes:
+        graph_type (GraphType): The type of graph used for visualization.
+    """
+    def __init__(self,
+                 graph_type: GraphType=GraphType.LINE) -> None:
+        """Initialize a new Metrics instance.
 
-class Metrics:
-    def __init__(self,
-                 name: str,
-                 metric_type: MetricType,
-                 episode_id: str,
-                 run_id: str,
-                 granularity: MetricGranularity = MetricGranularity.STEP) -> None:
-        """
-        Base class for different types of metrics.
-
         Args:
-            name (str): The name of the metric.
-            metric_type (MetricType): The type of the metric.
-            episode_id (str): The ID of the episode.
-            run_id (str): The ID of the run.
-            granularity (MetricGranularity): The granularity of the metric.
+            graph_type (GraphType): The type of graph to use for visualization
+                (e.g., LINE, BAR, HISTOGRAM, SCATTER). Defaults to LINE.
         """
-        self._name = name
-        self._metric_type = metric_type
-        self._granularity = granularity
         self._values = []
         self._x_values = []
-        self._episode_id = episode_id
-        self._run_id = run_id
-        self._last_step = -1
-
-    def reset(self,
-              episode_id: str | None = None) -> None:
-        """Reset the metric for a new episode or run.
+        self._step = -1
+        self._metric_dim_type = GRAPH_TO_DIM_TYPE.get(graph_type, MetricDimType.ONE_D)
+        self._graph_type = graph_type
 
-        Args:
-            episode_id (str | None): Optional new episode ID. If None, keeps the current episode ID.
-        """
-        if self._granularity != MetricGranularity.RUN:
-            self._submit()
-        self._values = []
-        self._x_values = []
-        self._last_step = -1
-        self._episode_id = episode_id
-
-    @property
-    def name(self) -> str:
-        """The name of the metric.
-
-        Returns:
-            str: The name of the metric.
-        """
-        return self._name
-
     @property
-    def metric_type(self) -> MetricType:
-        """The type of the metric.
+    def metric_dim_type(self) -> MetricDimType:
+        """The dimensionality of the metric data.
 
         Returns:
-            MetricType: The type of the metric.
+            MetricDimType: The metric dimension type.
         """
-        return self._metric_type
-
+        return self._metric_dim_type
+
     @property
-    def granularity(self) -> MetricGranularity:
-        """The granularity of the metric.
+    def graph_type(self) -> GraphType:
+        """The type of graph used for visualization.
 
         Returns:
-            MetricGranularity: The granularity of the metric.
+            GraphType: The graph type.
         """
-        return self._granularity
-
-    def log(self, data: Any, step: int | None = None, replace: bool = True) -> None:
-        """Log a new data point for the metric. The behavior depends on the granularity.
+        return self._graph_type
+
+    def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
+        """Log a new data point for the metric.
 
         Args:
             data (Any): The data point to log.
-            step (int | None): The step number for STEP granularity. Must be provided if granularity is STEP.
-            replace (bool): Whether to replace the last logged value if logging at the same step/episode/run.
+            x (Any | None): The x-axis value associated with the data point.
+                If None, uses an auto-incrementing step counter.
+            replace (bool): Whether to replace the last logged value. Defaults to False.
         """
-        if self._granularity == MetricGranularity.STEP:
-            if step is None:
-                raise ValueError("step Must be provided!")
-            if step == self._last_step:
-                if replace:
-                    self._values[-1] = data
-                else:
-                    raise ValueError("Cannot log the data at the same step.")
+        if replace:
+            self._values[-1] = data
+            if x is not None:
+                self._x_values[-1] = x
+        else:
+            self._values.append(data)
+            if x is not None:
+                self._x_values.append(x)
             else:
-                self._values.append(data)
-                self._x_values.append(step)
-        elif self._granularity == MetricGranularity.EPISODE:
-            if len(self._x_values) > 0 and not replace:
-                raise ValueError("Cannot log the data at the same episode.")
-            self._values = [data]
-            self._x_values = [self._episode_id]
-        else: # MetricGranularity.RUN
-            if len(self._values) > 0 and not replace:
-                raise ValueError("Cannot log the data at the same run.")
-            self._values = [data]
-            self._x_values = [self._run_id]
-
-    def _submit(self) -> None:
-        if not self._values:
-            # If there is no data to submit, then return.
-            return
-        # TODO: Implement commit logic
+                self._x_values.append(self._step)
+                self._step += 1
+
+    def finalize(self) -> dict:
+        """Finalize the logged data for processing."""
+        ret_result = self._finalize()
 
-        # Clear data after the submission.
+        return ret_result
+
+    def _finalize(self) -> dict:
+        """Process the logged data before submission. To be implemented by subclasses."""
+        ret_val = {
+            "values": self._values,
+            "x_values": self._x_values
+        }
         self._values = []
         self._x_values = []
-
-    def finish(self) -> None:
-        """Finish the metric logging and submit the final data."""
-        self.reset()
-        self._submit()
+        self._step = -1
+        return ret_val
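
The reworked Metrics class drops the granularity/step machinery: log() appends a value with either an explicit x value or an auto-incrementing counter, and finalize() returns the collected values and x values while clearing the buffers. A minimal sketch of that flow, with illustrative loss values:

    from humalab.constants import GraphType
    from humalab.metrics import Metrics

    metric = Metrics(graph_type=GraphType.LINE)
    for step, loss in enumerate([0.9, 0.7, 0.55]):   # illustrative values
        metric.log(loss, x=step)
    result = metric.finalize()
    # result == {"values": [0.9, 0.7, 0.55], "x_values": [0, 1, 2]}
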