humalab-0.0.5-py3-none-any.whl → humalab-0.0.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of humalab might be problematic. Click here for more details.
- humalab/__init__.py +25 -0
- humalab/assets/__init__.py +8 -2
- humalab/assets/files/resource_file.py +96 -6
- humalab/assets/files/urdf_file.py +49 -11
- humalab/assets/resource_operator.py +139 -0
- humalab/constants.py +48 -5
- humalab/dists/__init__.py +7 -0
- humalab/dists/bernoulli.py +26 -1
- humalab/dists/categorical.py +25 -0
- humalab/dists/discrete.py +27 -2
- humalab/dists/distribution.py +11 -0
- humalab/dists/gaussian.py +27 -2
- humalab/dists/log_uniform.py +29 -3
- humalab/dists/truncated_gaussian.py +33 -4
- humalab/dists/uniform.py +24 -0
- humalab/episode.py +291 -11
- humalab/humalab.py +93 -38
- humalab/humalab_api_client.py +297 -95
- humalab/humalab_config.py +49 -0
- humalab/humalab_test.py +46 -17
- humalab/metrics/__init__.py +11 -5
- humalab/metrics/code.py +59 -0
- humalab/metrics/metric.py +69 -102
- humalab/metrics/scenario_stats.py +163 -0
- humalab/metrics/summary.py +45 -24
- humalab/run.py +224 -101
- humalab/scenarios/__init__.py +11 -0
- humalab/{scenario.py → scenarios/scenario.py} +130 -136
- humalab/scenarios/scenario_operator.py +114 -0
- humalab/{scenario_test.py → scenarios/scenario_test.py} +150 -269
- humalab/utils.py +37 -0
- {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/METADATA +1 -1
- humalab-0.0.7.dist-info/RECORD +39 -0
- humalab/assets/resource_manager.py +0 -58
- humalab/evaluators/__init__.py +0 -16
- humalab/humalab_main.py +0 -119
- humalab/metrics/dist_metric.py +0 -22
- humalab-0.0.5.dist-info/RECORD +0 -37
- {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/WHEEL +0 -0
- {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/entry_points.txt +0 -0
- {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/licenses/LICENSE +0 -0
- {humalab-0.0.5.dist-info → humalab-0.0.7.dist-info}/top_level.txt +0 -0
humalab/humalab_config.py
CHANGED
|
@@ -3,6 +3,18 @@ import yaml
|
|
|
3
3
|
import os
|
|
4
4
|
|
|
5
5
|
class HumalabConfig:
|
|
6
|
+
"""Manages HumaLab SDK configuration settings.
|
|
7
|
+
|
|
8
|
+
Configuration is stored in ~/.humalab/config.yaml and includes workspace path,
|
|
9
|
+
API credentials, and connection settings. Values are automatically loaded on
|
|
10
|
+
initialization and saved when modified through property setters.
|
|
11
|
+
|
|
12
|
+
Attributes:
|
|
13
|
+
workspace_path (str): The local workspace directory path.
|
|
14
|
+
base_url (str): The HumaLab API base URL.
|
|
15
|
+
api_key (str): The API key for authentication.
|
|
16
|
+
timeout (float): Request timeout in seconds.
|
|
17
|
+
"""
|
|
6
18
|
def __init__(self):
|
|
7
19
|
self._config = {
|
|
8
20
|
"workspace_path": "",
|
|
@@ -17,6 +29,7 @@ class HumalabConfig:
|
|
|
17
29
|
self._load_config()
|
|
18
30
|
|
|
19
31
|
def _load_config(self):
|
|
32
|
+
"""Load configuration from ~/.humalab/config.yaml."""
|
|
20
33
|
home_path = Path.home()
|
|
21
34
|
config_path = home_path / ".humalab" / "config.yaml"
|
|
22
35
|
if not config_path.exists():
|
|
@@ -30,10 +43,16 @@ class HumalabConfig:
|
|
|
30
43
|
self._timeout = self._config["timeout"] if self._config and "timeout" in self._config else 30.0
|
|
31
44
|
|
|
32
45
|
def _save(self) -> None:
|
|
46
|
+
"""Save current configuration to ~/.humalab/config.yaml."""
|
|
33
47
|
yaml.dump(self._config, open(Path.home() / ".humalab" / "config.yaml", "w"))
|
|
34
48
|
|
|
35
49
|
@property
|
|
36
50
|
def workspace_path(self) -> str:
|
|
51
|
+
"""The local workspace directory path.
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
str: The workspace path.
|
|
55
|
+
"""
|
|
37
56
|
return str(self._workspace_path)
|
|
38
57
|
|
|
39
58
|
@workspace_path.setter
|
|
@@ -44,30 +63,60 @@ class HumalabConfig:
|
|
|
44
63
|
|
|
45
64
|
@property
|
|
46
65
|
def base_url(self) -> str:
|
|
66
|
+
"""The HumaLab API base URL.
|
|
67
|
+
|
|
68
|
+
Returns:
|
|
69
|
+
str: The base URL.
|
|
70
|
+
"""
|
|
47
71
|
return str(self._base_url)
|
|
48
72
|
|
|
49
73
|
@base_url.setter
|
|
50
74
|
def base_url(self, base_url: str) -> None:
|
|
75
|
+
"""Set the HumaLab API base URL and save to config.
|
|
76
|
+
|
|
77
|
+
Args:
|
|
78
|
+
base_url (str): The new base URL.
|
|
79
|
+
"""
|
|
51
80
|
self._base_url = base_url
|
|
52
81
|
self._config["base_url"] = base_url
|
|
53
82
|
self._save()
|
|
54
83
|
|
|
55
84
|
@property
|
|
56
85
|
def api_key(self) -> str:
|
|
86
|
+
"""The API key for authentication.
|
|
87
|
+
|
|
88
|
+
Returns:
|
|
89
|
+
str: The API key.
|
|
90
|
+
"""
|
|
57
91
|
return str(self._api_key)
|
|
58
92
|
|
|
59
93
|
@api_key.setter
|
|
60
94
|
def api_key(self, api_key: str) -> None:
|
|
95
|
+
"""Set the API key and save to config.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
api_key (str): The new API key.
|
|
99
|
+
"""
|
|
61
100
|
self._api_key = api_key
|
|
62
101
|
self._config["api_key"] = api_key
|
|
63
102
|
self._save()
|
|
64
103
|
|
|
65
104
|
@property
|
|
66
105
|
def timeout(self) -> float:
|
|
106
|
+
"""Request timeout in seconds.
|
|
107
|
+
|
|
108
|
+
Returns:
|
|
109
|
+
float: The timeout value.
|
|
110
|
+
"""
|
|
67
111
|
return self._timeout
|
|
68
112
|
|
|
69
113
|
@timeout.setter
|
|
70
114
|
def timeout(self, timeout: float) -> None:
|
|
115
|
+
"""Set the request timeout and save to config.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
timeout (float): The new timeout in seconds.
|
|
119
|
+
"""
|
|
71
120
|
self._timeout = timeout
|
|
72
121
|
self._config["timeout"] = timeout
|
|
73
122
|
self._save()
|
humalab/humalab_test.py
CHANGED
|
@@ -2,12 +2,13 @@ import unittest
|
|
|
2
2
|
from unittest.mock import patch, MagicMock, Mock
|
|
3
3
|
import uuid
|
|
4
4
|
|
|
5
|
+
from humalab.constants import DEFAULT_PROJECT
|
|
5
6
|
from humalab import humalab
|
|
6
7
|
from humalab.run import Run
|
|
7
|
-
from humalab.scenario import Scenario
|
|
8
|
+
from humalab.scenarios.scenario import Scenario
|
|
8
9
|
from humalab.humalab_config import HumalabConfig
|
|
9
10
|
from humalab.humalab_api_client import HumaLabApiClient
|
|
10
|
-
from humalab.
|
|
11
|
+
from humalab.humalab_api_client import EpisodeStatus, RunStatus
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class HumalabTest(unittest.TestCase):
|
|
@@ -30,9 +31,10 @@ class HumalabTest(unittest.TestCase):
|
|
|
30
31
|
# Pre-condition
|
|
31
32
|
client = Mock()
|
|
32
33
|
scenario = {"key": "value"}
|
|
34
|
+
project = "test_project"
|
|
33
35
|
|
|
34
36
|
# In-test
|
|
35
|
-
result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=None)
|
|
37
|
+
result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=None)
|
|
36
38
|
|
|
37
39
|
# Post-condition
|
|
38
40
|
self.assertEqual(result, scenario)
|
|
@@ -42,32 +44,34 @@ class HumalabTest(unittest.TestCase):
|
|
|
42
44
|
"""Test that _pull_scenario fetches from API when scenario_id is provided."""
|
|
43
45
|
# Pre-condition
|
|
44
46
|
client = Mock()
|
|
47
|
+
project = "test_project"
|
|
45
48
|
scenario_id = "test-scenario-id"
|
|
46
49
|
yaml_content = "scenario: test"
|
|
47
50
|
client.get_scenario.return_value = {"yaml_content": yaml_content}
|
|
48
51
|
|
|
49
52
|
# In-test
|
|
50
|
-
result = humalab._pull_scenario(client=client, scenario=None, scenario_id=scenario_id)
|
|
53
|
+
result = humalab._pull_scenario(client=client, project=project, scenario=None, scenario_id=scenario_id)
|
|
51
54
|
|
|
52
55
|
# Post-condition
|
|
53
56
|
self.assertEqual(result, yaml_content)
|
|
54
|
-
client.get_scenario.assert_called_once_with(uuid=scenario_id)
|
|
57
|
+
client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
|
|
55
58
|
|
|
56
59
|
def test_pull_scenario_should_prefer_scenario_id_over_scenario(self):
|
|
57
60
|
"""Test that _pull_scenario uses scenario_id even when scenario is provided."""
|
|
58
61
|
# Pre-condition
|
|
59
62
|
client = Mock()
|
|
63
|
+
project = "test_project"
|
|
60
64
|
scenario = {"key": "value"}
|
|
61
65
|
scenario_id = "test-scenario-id"
|
|
62
66
|
yaml_content = "scenario: from_api"
|
|
63
67
|
client.get_scenario.return_value = {"yaml_content": yaml_content}
|
|
64
68
|
|
|
65
69
|
# In-test
|
|
66
|
-
result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=scenario_id)
|
|
70
|
+
result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=scenario_id)
|
|
67
71
|
|
|
68
72
|
# Post-condition
|
|
69
73
|
self.assertEqual(result, yaml_content)
|
|
70
|
-
client.get_scenario.assert_called_once_with(uuid=scenario_id)
|
|
74
|
+
client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
|
|
71
75
|
|
|
72
76
|
# Tests for init context manager
|
|
73
77
|
|
|
@@ -92,6 +96,8 @@ class HumalabTest(unittest.TestCase):
|
|
|
92
96
|
mock_config_class.return_value = mock_config
|
|
93
97
|
|
|
94
98
|
mock_api_client = Mock()
|
|
99
|
+
mock_api_client.create_project.return_value = {"name": project}
|
|
100
|
+
mock_api_client.get_run.return_value = {"run_id": run_id, "name": name, "description": description, "tags": tags}
|
|
95
101
|
mock_api_client_class.return_value = mock_api_client
|
|
96
102
|
|
|
97
103
|
mock_scenario_inst = Mock()
|
|
@@ -137,6 +143,8 @@ class HumalabTest(unittest.TestCase):
|
|
|
137
143
|
mock_config_class.return_value = mock_config
|
|
138
144
|
|
|
139
145
|
mock_api_client = Mock()
|
|
146
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
147
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
140
148
|
mock_api_client_class.return_value = mock_api_client
|
|
141
149
|
|
|
142
150
|
mock_scenario_inst = Mock()
|
|
@@ -149,7 +157,7 @@ class HumalabTest(unittest.TestCase):
|
|
|
149
157
|
with humalab.init() as run:
|
|
150
158
|
# Post-condition
|
|
151
159
|
call_kwargs = mock_run_class.call_args.kwargs
|
|
152
|
-
self.assertEqual(call_kwargs['project'],
|
|
160
|
+
self.assertEqual(call_kwargs['project'], DEFAULT_PROJECT)
|
|
153
161
|
self.assertEqual(call_kwargs['name'], "")
|
|
154
162
|
self.assertEqual(call_kwargs['description'], "")
|
|
155
163
|
self.assertIsNotNone(call_kwargs['id']) # UUID generated
|
|
@@ -170,7 +178,18 @@ class HumalabTest(unittest.TestCase):
|
|
|
170
178
|
mock_config.timeout = 30.0
|
|
171
179
|
mock_config_class.return_value = mock_config
|
|
172
180
|
|
|
181
|
+
# Mock HTTP 404 error for get_run (run doesn't exist yet)
|
|
182
|
+
import requests
|
|
183
|
+
http_error = requests.HTTPError()
|
|
184
|
+
http_error.response = Mock()
|
|
185
|
+
http_error.response.status_code = 404
|
|
186
|
+
|
|
173
187
|
mock_api_client = Mock()
|
|
188
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
189
|
+
mock_api_client.get_run.side_effect = http_error
|
|
190
|
+
# Mock create_run to return a valid UUID
|
|
191
|
+
generated_uuid = str(uuid.uuid4())
|
|
192
|
+
mock_api_client.create_run.return_value = {"run_id": generated_uuid, "name": "", "description": "", "tags": None}
|
|
174
193
|
mock_api_client_class.return_value = mock_api_client
|
|
175
194
|
|
|
176
195
|
mock_scenario_inst = Mock()
|
|
@@ -206,6 +225,8 @@ class HumalabTest(unittest.TestCase):
|
|
|
206
225
|
mock_config_class.return_value = mock_config
|
|
207
226
|
|
|
208
227
|
mock_api_client = Mock()
|
|
228
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
229
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
209
230
|
mock_api_client_class.return_value = mock_api_client
|
|
210
231
|
|
|
211
232
|
mock_scenario_inst = Mock()
|
|
@@ -241,6 +262,8 @@ class HumalabTest(unittest.TestCase):
|
|
|
241
262
|
mock_config_class.return_value = mock_config
|
|
242
263
|
|
|
243
264
|
mock_api_client = Mock()
|
|
265
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
266
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
244
267
|
mock_api_client.get_scenario.return_value = {"yaml_content": yaml_content}
|
|
245
268
|
mock_api_client_class.return_value = mock_api_client
|
|
246
269
|
|
|
@@ -253,7 +276,7 @@ class HumalabTest(unittest.TestCase):
|
|
|
253
276
|
# In-test
|
|
254
277
|
with humalab.init(scenario_id=scenario_id) as run:
|
|
255
278
|
# Post-condition
|
|
256
|
-
mock_api_client.get_scenario.assert_called_once_with(uuid=scenario_id)
|
|
279
|
+
mock_api_client.get_scenario.assert_called_once_with(project_name='default', uuid=scenario_id, version=None)
|
|
257
280
|
mock_scenario_inst.init.assert_called_once()
|
|
258
281
|
call_kwargs = mock_scenario_inst.init.call_args.kwargs
|
|
259
282
|
self.assertEqual(call_kwargs['scenario'], yaml_content)
|
|
@@ -274,6 +297,8 @@ class HumalabTest(unittest.TestCase):
|
|
|
274
297
|
mock_config_class.return_value = mock_config
|
|
275
298
|
|
|
276
299
|
mock_api_client = Mock()
|
|
300
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
301
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
277
302
|
mock_api_client_class.return_value = mock_api_client
|
|
278
303
|
|
|
279
304
|
mock_scenario_inst = Mock()
|
|
@@ -304,6 +329,8 @@ class HumalabTest(unittest.TestCase):
|
|
|
304
329
|
mock_config_class.return_value = mock_config
|
|
305
330
|
|
|
306
331
|
mock_api_client = Mock()
|
|
332
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
333
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
307
334
|
mock_api_client_class.return_value = mock_api_client
|
|
308
335
|
|
|
309
336
|
mock_scenario_inst = Mock()
|
|
@@ -338,6 +365,8 @@ class HumalabTest(unittest.TestCase):
|
|
|
338
365
|
mock_config_class.return_value = mock_config
|
|
339
366
|
|
|
340
367
|
mock_api_client = Mock()
|
|
368
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
369
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
341
370
|
mock_api_client_class.return_value = mock_api_client
|
|
342
371
|
|
|
343
372
|
mock_scenario_inst = Mock()
|
|
@@ -369,33 +398,33 @@ class HumalabTest(unittest.TestCase):
|
|
|
369
398
|
humalab.finish()
|
|
370
399
|
|
|
371
400
|
# Post-condition
|
|
372
|
-
mock_run.finish.assert_called_once_with(status=
|
|
401
|
+
mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=None)
|
|
373
402
|
|
|
374
403
|
def test_finish_should_call_finish_on_current_run_with_custom_status(self):
|
|
375
404
|
"""Test that finish() calls finish on the current run with custom status."""
|
|
376
405
|
# Pre-condition
|
|
377
406
|
mock_run = Mock()
|
|
378
407
|
humalab._cur_run = mock_run
|
|
379
|
-
status =
|
|
408
|
+
status = RunStatus.ERRORED
|
|
380
409
|
|
|
381
410
|
# In-test
|
|
382
411
|
humalab.finish(status=status)
|
|
383
412
|
|
|
384
413
|
# Post-condition
|
|
385
|
-
mock_run.finish.assert_called_once_with(status=status,
|
|
414
|
+
mock_run.finish.assert_called_once_with(status=status, err_msg=None)
|
|
386
415
|
|
|
387
|
-
def
|
|
388
|
-
"""Test that finish() calls finish on the current run with
|
|
416
|
+
def test_finish_should_call_finish_on_current_run_with_err_msg_parameter(self):
|
|
417
|
+
"""Test that finish() calls finish on the current run with err_msg parameter."""
|
|
389
418
|
# Pre-condition
|
|
390
419
|
mock_run = Mock()
|
|
391
420
|
humalab._cur_run = mock_run
|
|
392
|
-
|
|
421
|
+
err_msg = "Test error message"
|
|
393
422
|
|
|
394
423
|
# In-test
|
|
395
|
-
humalab.finish(
|
|
424
|
+
humalab.finish(err_msg=err_msg)
|
|
396
425
|
|
|
397
426
|
# Post-condition
|
|
398
|
-
mock_run.finish.assert_called_once_with(status=
|
|
427
|
+
mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=err_msg)
|
|
399
428
|
|
|
400
429
|
def test_finish_should_do_nothing_when_no_current_run(self):
|
|
401
430
|
"""Test that finish() does nothing when _cur_run is None."""
|
humalab/metrics/__init__.py
CHANGED
|
@@ -1,11 +1,17 @@
|
|
|
1
|
-
|
|
2
|
-
|
|
1
|
+
"""Metrics tracking and management.
|
|
2
|
+
|
|
3
|
+
This module provides classes for tracking various types of metrics during runs and episodes,
|
|
4
|
+
including general metrics, summary statistics, code artifacts, and scenario statistics.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from .metric import Metrics
|
|
8
|
+
from .code import Code
|
|
9
|
+
from .scenario_stats import ScenarioStats
|
|
3
10
|
from .summary import Summary
|
|
4
11
|
|
|
5
12
|
__all__ = [
|
|
6
|
-
"
|
|
7
|
-
"MetricType",
|
|
13
|
+
"Code",
|
|
8
14
|
"Metrics",
|
|
9
|
-
"
|
|
15
|
+
"ScenarioStats",
|
|
10
16
|
"Summary",
|
|
11
17
|
]
|
humalab/metrics/code.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
class Code:
|
|
2
|
+
"""Class for logging code artifacts.
|
|
3
|
+
|
|
4
|
+
Code artifacts capture source code or configuration files associated with
|
|
5
|
+
runs or episodes. They are stored as text content and can be retrieved later
|
|
6
|
+
for reproducibility and debugging purposes.
|
|
7
|
+
|
|
8
|
+
Attributes:
|
|
9
|
+
run_id (str): The unique identifier of the associated run.
|
|
10
|
+
key (str): The artifact key/name for this code.
|
|
11
|
+
code_content (str): The actual code or text content.
|
|
12
|
+
episode_id (str | None): Optional episode identifier if scoped to an episode.
|
|
13
|
+
"""
|
|
14
|
+
def __init__(self,
|
|
15
|
+
run_id: str,
|
|
16
|
+
key: str,
|
|
17
|
+
code_content: str,
|
|
18
|
+
episode_id: str | None = None) -> None:
|
|
19
|
+
super().__init__()
|
|
20
|
+
self._run_id = run_id
|
|
21
|
+
self._key = key
|
|
22
|
+
self._code_content = code_content
|
|
23
|
+
self._episode_id = episode_id
|
|
24
|
+
|
|
25
|
+
@property
|
|
26
|
+
def run_id(self) -> str:
|
|
27
|
+
"""The unique identifier of the associated run.
|
|
28
|
+
|
|
29
|
+
Returns:
|
|
30
|
+
str: The run ID.
|
|
31
|
+
"""
|
|
32
|
+
return self._run_id
|
|
33
|
+
|
|
34
|
+
@property
|
|
35
|
+
def key(self) -> str:
|
|
36
|
+
"""The artifact key/name for this code.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
str: The artifact key.
|
|
40
|
+
"""
|
|
41
|
+
return self._key
|
|
42
|
+
|
|
43
|
+
@property
|
|
44
|
+
def code_content(self) -> str:
|
|
45
|
+
"""The actual code or text content.
|
|
46
|
+
|
|
47
|
+
Returns:
|
|
48
|
+
str: The code content.
|
|
49
|
+
"""
|
|
50
|
+
return self._code_content
|
|
51
|
+
|
|
52
|
+
@property
|
|
53
|
+
def episode_id(self) -> str | None:
|
|
54
|
+
"""Optional episode identifier if scoped to an episode.
|
|
55
|
+
|
|
56
|
+
Returns:
|
|
57
|
+
str | None: The episode ID, or None if run-scoped.
|
|
58
|
+
"""
|
|
59
|
+
return self._episode_id
|
humalab/metrics/metric.py
CHANGED
|
@@ -1,129 +1,96 @@
|
|
|
1
|
-
from
|
|
2
|
-
from
|
|
3
|
-
from humalab.constants import EpisodeStatus
|
|
1
|
+
from typing import Any
|
|
2
|
+
from humalab.constants import MetricDimType, GraphType
|
|
4
3
|
|
|
4
|
+
GRAPH_TO_DIM_TYPE = {
|
|
5
|
+
GraphType.NUMERIC: MetricDimType.ZERO_D,
|
|
6
|
+
GraphType.LINE: MetricDimType.ONE_D,
|
|
7
|
+
GraphType.HISTOGRAM: MetricDimType.ONE_D,
|
|
8
|
+
GraphType.BAR: MetricDimType.ONE_D,
|
|
9
|
+
GraphType.GAUSSIAN: MetricDimType.ONE_D,
|
|
10
|
+
GraphType.SCATTER: MetricDimType.TWO_D,
|
|
11
|
+
GraphType.HEATMAP: MetricDimType.TWO_D,
|
|
12
|
+
GraphType.THREE_D_MAP: MetricDimType.THREE_D,
|
|
13
|
+
}
|
|
5
14
|
|
|
6
|
-
class MetricType(Enum):
|
|
7
|
-
DEFAULT = "default"
|
|
8
|
-
STREAM = "stream"
|
|
9
|
-
DISTRIBUTION = "distribution"
|
|
10
|
-
SUMMARY = "summary"
|
|
11
15
|
|
|
16
|
+
class Metrics:
|
|
17
|
+
"""Base class for tracking and logging metrics during runs and episodes.
|
|
12
18
|
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
RUN = "run"
|
|
19
|
+
Metrics provide a flexible way to log time-series data or aggregated values
|
|
20
|
+
during experiments. Data points are collected with optional x-axis values
|
|
21
|
+
and can be visualized using different graph types.
|
|
17
22
|
|
|
23
|
+
Subclasses should override _finalize() to implement custom processing logic.
|
|
24
|
+
|
|
25
|
+
Attributes:
|
|
26
|
+
graph_type (GraphType): The type of graph used for visualization.
|
|
27
|
+
"""
|
|
28
|
+
def __init__(self,
|
|
29
|
+
graph_type: GraphType=GraphType.LINE) -> None:
|
|
30
|
+
"""Initialize a new Metrics instance.
|
|
18
31
|
|
|
19
|
-
class Metrics:
|
|
20
|
-
def __init__(self,
|
|
21
|
-
name: str,
|
|
22
|
-
metric_type: MetricType,
|
|
23
|
-
episode_id: str,
|
|
24
|
-
run_id: str,
|
|
25
|
-
granularity: MetricGranularity = MetricGranularity.STEP) -> None:
|
|
26
|
-
"""
|
|
27
|
-
Base class for different types of metrics.
|
|
28
|
-
|
|
29
32
|
Args:
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
episode_id (str): The ID of the episode.
|
|
33
|
-
run_id (str): The ID of the run.
|
|
34
|
-
granularity (MetricGranularity): The granularity of the metric.
|
|
33
|
+
graph_type (GraphType): The type of graph to use for visualization
|
|
34
|
+
(e.g., LINE, BAR, HISTOGRAM, SCATTER). Defaults to LINE.
|
|
35
35
|
"""
|
|
36
|
-
self._name = name
|
|
37
|
-
self._metric_type = metric_type
|
|
38
|
-
self._granularity = granularity
|
|
39
36
|
self._values = []
|
|
40
37
|
self._x_values = []
|
|
41
|
-
self.
|
|
42
|
-
self.
|
|
43
|
-
self.
|
|
44
|
-
|
|
45
|
-
def reset(self,
|
|
46
|
-
episode_id: str | None = None) -> None:
|
|
47
|
-
"""Reset the metric for a new episode or run.
|
|
38
|
+
self._step = -1
|
|
39
|
+
self._metric_dim_type = GRAPH_TO_DIM_TYPE.get(graph_type, MetricDimType.ONE_D)
|
|
40
|
+
self._graph_type = graph_type
|
|
48
41
|
|
|
49
|
-
Args:
|
|
50
|
-
episode_id (str | None): Optional new episode ID. If None, keeps the current episode ID.
|
|
51
|
-
"""
|
|
52
|
-
if self._granularity != MetricGranularity.RUN:
|
|
53
|
-
self._submit()
|
|
54
|
-
self._values = []
|
|
55
|
-
self._x_values = []
|
|
56
|
-
self._last_step = -1
|
|
57
|
-
self._episode_id = episode_id
|
|
58
|
-
|
|
59
|
-
@property
|
|
60
|
-
def name(self) -> str:
|
|
61
|
-
"""The name of the metric.
|
|
62
|
-
|
|
63
|
-
Returns:
|
|
64
|
-
str: The name of the metric.
|
|
65
|
-
"""
|
|
66
|
-
return self._name
|
|
67
|
-
|
|
68
42
|
@property
|
|
69
|
-
def
|
|
70
|
-
"""The
|
|
43
|
+
def metric_dim_type(self) -> MetricDimType:
|
|
44
|
+
"""The dimensionality of the metric data.
|
|
71
45
|
|
|
72
46
|
Returns:
|
|
73
|
-
|
|
47
|
+
MetricDimType: The metric dimension type.
|
|
74
48
|
"""
|
|
75
|
-
return self.
|
|
76
|
-
|
|
49
|
+
return self._metric_dim_type
|
|
50
|
+
|
|
77
51
|
@property
|
|
78
|
-
def
|
|
79
|
-
"""The
|
|
52
|
+
def graph_type(self) -> GraphType:
|
|
53
|
+
"""The type of graph used for visualization.
|
|
80
54
|
|
|
81
55
|
Returns:
|
|
82
|
-
|
|
56
|
+
GraphType: The graph type.
|
|
83
57
|
"""
|
|
84
|
-
return self.
|
|
85
|
-
|
|
86
|
-
def log(self, data: Any,
|
|
87
|
-
"""Log a new data point for the metric.
|
|
58
|
+
return self._graph_type
|
|
59
|
+
|
|
60
|
+
def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
|
|
61
|
+
"""Log a new data point for the metric.
|
|
88
62
|
|
|
89
63
|
Args:
|
|
90
64
|
data (Any): The data point to log.
|
|
91
|
-
|
|
92
|
-
|
|
65
|
+
x (Any | None): The x-axis value associated with the data point.
|
|
66
|
+
If None, uses an auto-incrementing step counter.
|
|
67
|
+
replace (bool): Whether to replace the last logged value. Defaults to False.
|
|
93
68
|
"""
|
|
94
|
-
if
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
69
|
+
if replace:
|
|
70
|
+
self._values[-1] = data
|
|
71
|
+
if x is not None:
|
|
72
|
+
self._x_values[-1] = x
|
|
73
|
+
else:
|
|
74
|
+
self._values.append(data)
|
|
75
|
+
if x is not None:
|
|
76
|
+
self._x_values.append(x)
|
|
102
77
|
else:
|
|
103
|
-
self.
|
|
104
|
-
self.
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
self._x_values = [self._episode_id]
|
|
110
|
-
else: # MetricGranularity.RUN
|
|
111
|
-
if len(self._values) > 0 and not replace:
|
|
112
|
-
raise ValueError("Cannot log the data at the same run.")
|
|
113
|
-
self._values = [data]
|
|
114
|
-
self._x_values = [self._run_id]
|
|
115
|
-
|
|
116
|
-
def _submit(self) -> None:
|
|
117
|
-
if not self._values:
|
|
118
|
-
# If there is no data to submit, then return.
|
|
119
|
-
return
|
|
120
|
-
# TODO: Implement commit logic
|
|
78
|
+
self._x_values.append(self._step)
|
|
79
|
+
self._step += 1
|
|
80
|
+
|
|
81
|
+
def finalize(self) -> dict:
|
|
82
|
+
"""Finalize the logged data for processing."""
|
|
83
|
+
ret_result = self._finalize()
|
|
121
84
|
|
|
122
|
-
|
|
85
|
+
return ret_result
|
|
86
|
+
|
|
87
|
+
def _finalize(self) -> dict:
|
|
88
|
+
"""Process the logged data before submission. To be implemented by subclasses."""
|
|
89
|
+
ret_val = {
|
|
90
|
+
"values": self._values,
|
|
91
|
+
"x_values": self._x_values
|
|
92
|
+
}
|
|
123
93
|
self._values = []
|
|
124
94
|
self._x_values = []
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
"""Finish the metric logging and submit the final data."""
|
|
128
|
-
self.reset()
|
|
129
|
-
self._submit()
|
|
95
|
+
self._step = -1
|
|
96
|
+
return ret_val
|