humalab 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of humalab might be problematic. Click here for more details.
- humalab/__init__.py +11 -0
- humalab/assets/__init__.py +2 -2
- humalab/assets/files/resource_file.py +29 -3
- humalab/assets/files/urdf_file.py +14 -10
- humalab/assets/resource_operator.py +91 -0
- humalab/constants.py +39 -5
- humalab/dists/bernoulli.py +16 -0
- humalab/dists/categorical.py +4 -0
- humalab/dists/discrete.py +22 -0
- humalab/dists/gaussian.py +22 -0
- humalab/dists/log_uniform.py +22 -0
- humalab/dists/truncated_gaussian.py +36 -0
- humalab/dists/uniform.py +22 -0
- humalab/episode.py +196 -0
- humalab/humalab.py +116 -153
- humalab/humalab_api_client.py +760 -62
- humalab/humalab_config.py +0 -13
- humalab/humalab_test.py +46 -29
- humalab/metrics/__init__.py +5 -5
- humalab/metrics/code.py +28 -0
- humalab/metrics/metric.py +41 -108
- humalab/metrics/scenario_stats.py +95 -0
- humalab/metrics/summary.py +24 -18
- humalab/run.py +180 -115
- humalab/scenarios/__init__.py +4 -0
- humalab/scenarios/scenario.py +372 -0
- humalab/scenarios/scenario_operator.py +82 -0
- humalab/{scenario_test.py → scenarios/scenario_test.py} +150 -269
- humalab/utils.py +37 -0
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/METADATA +1 -1
- humalab-0.0.6.dist-info/RECORD +39 -0
- humalab/assets/resource_manager.py +0 -57
- humalab/metrics/dist_metric.py +0 -22
- humalab/scenario.py +0 -225
- humalab-0.0.4.dist-info/RECORD +0 -34
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/WHEEL +0 -0
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/entry_points.txt +0 -0
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {humalab-0.0.4.dist-info → humalab-0.0.6.dist-info}/top_level.txt +0 -0
humalab/humalab_config.py
CHANGED
|
@@ -6,13 +6,11 @@ class HumalabConfig:
|
|
|
6
6
|
def __init__(self):
|
|
7
7
|
self._config = {
|
|
8
8
|
"workspace_path": "",
|
|
9
|
-
"entity": "",
|
|
10
9
|
"base_url": "",
|
|
11
10
|
"api_key": "",
|
|
12
11
|
"timeout": 30.0,
|
|
13
12
|
}
|
|
14
13
|
self._workspace_path = ""
|
|
15
|
-
self._entity = ""
|
|
16
14
|
self._base_url = ""
|
|
17
15
|
self._api_key = ""
|
|
18
16
|
self._timeout = 30.0
|
|
@@ -27,7 +25,6 @@ class HumalabConfig:
|
|
|
27
25
|
with open(config_path, "r") as f:
|
|
28
26
|
self._config = yaml.safe_load(f)
|
|
29
27
|
self._workspace_path = os.path.expanduser(self._config["workspace_path"]) if self._config and "workspace_path" in self._config else home_path
|
|
30
|
-
self._entity = self._config["entity"] if self._config and "entity" in self._config else ""
|
|
31
28
|
self._base_url = self._config["base_url"] if self._config and "base_url" in self._config else ""
|
|
32
29
|
self._api_key = self._config["api_key"] if self._config and "api_key" in self._config else ""
|
|
33
30
|
self._timeout = self._config["timeout"] if self._config and "timeout" in self._config else 30.0
|
|
@@ -45,16 +42,6 @@ class HumalabConfig:
|
|
|
45
42
|
self._config["workspace_path"] = path
|
|
46
43
|
self._save()
|
|
47
44
|
|
|
48
|
-
@property
|
|
49
|
-
def entity(self) -> str:
|
|
50
|
-
return str(self._entity)
|
|
51
|
-
|
|
52
|
-
@entity.setter
|
|
53
|
-
def entity(self, entity: str) -> None:
|
|
54
|
-
self._entity = entity
|
|
55
|
-
self._config["entity"] = entity
|
|
56
|
-
self._save()
|
|
57
|
-
|
|
58
45
|
@property
|
|
59
46
|
def base_url(self) -> str:
|
|
60
47
|
return str(self._base_url)
|
humalab/humalab_test.py
CHANGED
|
@@ -2,12 +2,13 @@ import unittest
|
|
|
2
2
|
from unittest.mock import patch, MagicMock, Mock
|
|
3
3
|
import uuid
|
|
4
4
|
|
|
5
|
+
from humalab.constants import DEFAULT_PROJECT
|
|
5
6
|
from humalab import humalab
|
|
6
7
|
from humalab.run import Run
|
|
7
|
-
from humalab.scenario import Scenario
|
|
8
|
+
from humalab.scenarios.scenario import Scenario
|
|
8
9
|
from humalab.humalab_config import HumalabConfig
|
|
9
10
|
from humalab.humalab_api_client import HumaLabApiClient
|
|
10
|
-
from humalab.
|
|
11
|
+
from humalab.humalab_api_client import EpisodeStatus, RunStatus
|
|
11
12
|
|
|
12
13
|
|
|
13
14
|
class HumalabTest(unittest.TestCase):
|
|
@@ -30,9 +31,10 @@ class HumalabTest(unittest.TestCase):
|
|
|
30
31
|
# Pre-condition
|
|
31
32
|
client = Mock()
|
|
32
33
|
scenario = {"key": "value"}
|
|
34
|
+
project = "test_project"
|
|
33
35
|
|
|
34
36
|
# In-test
|
|
35
|
-
result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=None)
|
|
37
|
+
result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=None)
|
|
36
38
|
|
|
37
39
|
# Post-condition
|
|
38
40
|
self.assertEqual(result, scenario)
|
|
@@ -42,32 +44,34 @@ class HumalabTest(unittest.TestCase):
|
|
|
42
44
|
"""Test that _pull_scenario fetches from API when scenario_id is provided."""
|
|
43
45
|
# Pre-condition
|
|
44
46
|
client = Mock()
|
|
47
|
+
project = "test_project"
|
|
45
48
|
scenario_id = "test-scenario-id"
|
|
46
49
|
yaml_content = "scenario: test"
|
|
47
50
|
client.get_scenario.return_value = {"yaml_content": yaml_content}
|
|
48
51
|
|
|
49
52
|
# In-test
|
|
50
|
-
result = humalab._pull_scenario(client=client, scenario=None, scenario_id=scenario_id)
|
|
53
|
+
result = humalab._pull_scenario(client=client, project=project, scenario=None, scenario_id=scenario_id)
|
|
51
54
|
|
|
52
55
|
# Post-condition
|
|
53
56
|
self.assertEqual(result, yaml_content)
|
|
54
|
-
client.get_scenario.assert_called_once_with(uuid=scenario_id)
|
|
57
|
+
client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
|
|
55
58
|
|
|
56
59
|
def test_pull_scenario_should_prefer_scenario_id_over_scenario(self):
|
|
57
60
|
"""Test that _pull_scenario uses scenario_id even when scenario is provided."""
|
|
58
61
|
# Pre-condition
|
|
59
62
|
client = Mock()
|
|
63
|
+
project = "test_project"
|
|
60
64
|
scenario = {"key": "value"}
|
|
61
65
|
scenario_id = "test-scenario-id"
|
|
62
66
|
yaml_content = "scenario: from_api"
|
|
63
67
|
client.get_scenario.return_value = {"yaml_content": yaml_content}
|
|
64
68
|
|
|
65
69
|
# In-test
|
|
66
|
-
result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=scenario_id)
|
|
70
|
+
result = humalab._pull_scenario(client=client, project=project, scenario=scenario, scenario_id=scenario_id)
|
|
67
71
|
|
|
68
72
|
# Post-condition
|
|
69
73
|
self.assertEqual(result, yaml_content)
|
|
70
|
-
client.get_scenario.assert_called_once_with(uuid=scenario_id)
|
|
74
|
+
client.get_scenario.assert_called_once_with(project_name=project, uuid=scenario_id, version=None)
|
|
71
75
|
|
|
72
76
|
# Tests for init context manager
|
|
73
77
|
|
|
@@ -78,7 +82,6 @@ class HumalabTest(unittest.TestCase):
|
|
|
78
82
|
def test_init_should_create_run_with_provided_parameters(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
|
|
79
83
|
"""Test that init() creates a Run with provided parameters."""
|
|
80
84
|
# Pre-condition
|
|
81
|
-
entity = "test_entity"
|
|
82
85
|
project = "test_project"
|
|
83
86
|
name = "test_name"
|
|
84
87
|
description = "test_description"
|
|
@@ -87,13 +90,14 @@ class HumalabTest(unittest.TestCase):
|
|
|
87
90
|
scenario_data = {"key": "value"}
|
|
88
91
|
|
|
89
92
|
mock_config = Mock()
|
|
90
|
-
mock_config.entity = "default_entity"
|
|
91
93
|
mock_config.base_url = "http://localhost:8000"
|
|
92
94
|
mock_config.api_key = "test_key"
|
|
93
95
|
mock_config.timeout = 30.0
|
|
94
96
|
mock_config_class.return_value = mock_config
|
|
95
97
|
|
|
96
98
|
mock_api_client = Mock()
|
|
99
|
+
mock_api_client.create_project.return_value = {"name": project}
|
|
100
|
+
mock_api_client.get_run.return_value = {"run_id": run_id, "name": name, "description": description, "tags": tags}
|
|
97
101
|
mock_api_client_class.return_value = mock_api_client
|
|
98
102
|
|
|
99
103
|
mock_scenario_inst = Mock()
|
|
@@ -104,7 +108,6 @@ class HumalabTest(unittest.TestCase):
|
|
|
104
108
|
|
|
105
109
|
# In-test
|
|
106
110
|
with humalab.init(
|
|
107
|
-
entity=entity,
|
|
108
111
|
project=project,
|
|
109
112
|
name=name,
|
|
110
113
|
description=description,
|
|
@@ -116,7 +119,6 @@ class HumalabTest(unittest.TestCase):
|
|
|
116
119
|
self.assertEqual(run, mock_run_inst)
|
|
117
120
|
mock_run_class.assert_called_once()
|
|
118
121
|
call_kwargs = mock_run_class.call_args.kwargs
|
|
119
|
-
self.assertEqual(call_kwargs['entity'], entity)
|
|
120
122
|
self.assertEqual(call_kwargs['project'], project)
|
|
121
123
|
self.assertEqual(call_kwargs['name'], name)
|
|
122
124
|
self.assertEqual(call_kwargs['description'], description)
|
|
@@ -135,13 +137,14 @@ class HumalabTest(unittest.TestCase):
|
|
|
135
137
|
"""Test that init() uses config defaults when parameters are not provided."""
|
|
136
138
|
# Pre-condition
|
|
137
139
|
mock_config = Mock()
|
|
138
|
-
mock_config.entity = "config_entity"
|
|
139
140
|
mock_config.base_url = "http://config:8000"
|
|
140
141
|
mock_config.api_key = "config_key"
|
|
141
142
|
mock_config.timeout = 60.0
|
|
142
143
|
mock_config_class.return_value = mock_config
|
|
143
144
|
|
|
144
145
|
mock_api_client = Mock()
|
|
146
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
147
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
145
148
|
mock_api_client_class.return_value = mock_api_client
|
|
146
149
|
|
|
147
150
|
mock_scenario_inst = Mock()
|
|
@@ -154,8 +157,7 @@ class HumalabTest(unittest.TestCase):
|
|
|
154
157
|
with humalab.init() as run:
|
|
155
158
|
# Post-condition
|
|
156
159
|
call_kwargs = mock_run_class.call_args.kwargs
|
|
157
|
-
self.assertEqual(call_kwargs['
|
|
158
|
-
self.assertEqual(call_kwargs['project'], "default")
|
|
160
|
+
self.assertEqual(call_kwargs['project'], DEFAULT_PROJECT)
|
|
159
161
|
self.assertEqual(call_kwargs['name'], "")
|
|
160
162
|
self.assertEqual(call_kwargs['description'], "")
|
|
161
163
|
self.assertIsNotNone(call_kwargs['id']) # UUID generated
|
|
@@ -171,13 +173,23 @@ class HumalabTest(unittest.TestCase):
|
|
|
171
173
|
"""Test that init() generates a UUID when id is not provided."""
|
|
172
174
|
# Pre-condition
|
|
173
175
|
mock_config = Mock()
|
|
174
|
-
mock_config.entity = "test_entity"
|
|
175
176
|
mock_config.base_url = "http://localhost:8000"
|
|
176
177
|
mock_config.api_key = "test_key"
|
|
177
178
|
mock_config.timeout = 30.0
|
|
178
179
|
mock_config_class.return_value = mock_config
|
|
179
180
|
|
|
181
|
+
# Mock HTTP 404 error for get_run (run doesn't exist yet)
|
|
182
|
+
import requests
|
|
183
|
+
http_error = requests.HTTPError()
|
|
184
|
+
http_error.response = Mock()
|
|
185
|
+
http_error.response.status_code = 404
|
|
186
|
+
|
|
180
187
|
mock_api_client = Mock()
|
|
188
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
189
|
+
mock_api_client.get_run.side_effect = http_error
|
|
190
|
+
# Mock create_run to return a valid UUID
|
|
191
|
+
generated_uuid = str(uuid.uuid4())
|
|
192
|
+
mock_api_client.create_run.return_value = {"run_id": generated_uuid, "name": "", "description": "", "tags": None}
|
|
181
193
|
mock_api_client_class.return_value = mock_api_client
|
|
182
194
|
|
|
183
195
|
mock_scenario_inst = Mock()
|
|
@@ -207,13 +219,14 @@ class HumalabTest(unittest.TestCase):
|
|
|
207
219
|
scenario_data = {"key": "value"}
|
|
208
220
|
|
|
209
221
|
mock_config = Mock()
|
|
210
|
-
mock_config.entity = "test_entity"
|
|
211
222
|
mock_config.base_url = "http://localhost:8000"
|
|
212
223
|
mock_config.api_key = "test_key"
|
|
213
224
|
mock_config.timeout = 30.0
|
|
214
225
|
mock_config_class.return_value = mock_config
|
|
215
226
|
|
|
216
227
|
mock_api_client = Mock()
|
|
228
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
229
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
217
230
|
mock_api_client_class.return_value = mock_api_client
|
|
218
231
|
|
|
219
232
|
mock_scenario_inst = Mock()
|
|
@@ -243,13 +256,14 @@ class HumalabTest(unittest.TestCase):
|
|
|
243
256
|
yaml_content = "scenario: from_api"
|
|
244
257
|
|
|
245
258
|
mock_config = Mock()
|
|
246
|
-
mock_config.entity = "test_entity"
|
|
247
259
|
mock_config.base_url = "http://localhost:8000"
|
|
248
260
|
mock_config.api_key = "test_key"
|
|
249
261
|
mock_config.timeout = 30.0
|
|
250
262
|
mock_config_class.return_value = mock_config
|
|
251
263
|
|
|
252
264
|
mock_api_client = Mock()
|
|
265
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
266
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
253
267
|
mock_api_client.get_scenario.return_value = {"yaml_content": yaml_content}
|
|
254
268
|
mock_api_client_class.return_value = mock_api_client
|
|
255
269
|
|
|
@@ -262,7 +276,7 @@ class HumalabTest(unittest.TestCase):
|
|
|
262
276
|
# In-test
|
|
263
277
|
with humalab.init(scenario_id=scenario_id) as run:
|
|
264
278
|
# Post-condition
|
|
265
|
-
mock_api_client.get_scenario.assert_called_once_with(uuid=scenario_id)
|
|
279
|
+
mock_api_client.get_scenario.assert_called_once_with(project_name='default', uuid=scenario_id, version=None)
|
|
266
280
|
mock_scenario_inst.init.assert_called_once()
|
|
267
281
|
call_kwargs = mock_scenario_inst.init.call_args.kwargs
|
|
268
282
|
self.assertEqual(call_kwargs['scenario'], yaml_content)
|
|
@@ -277,13 +291,14 @@ class HumalabTest(unittest.TestCase):
|
|
|
277
291
|
"""Test that init() sets the global _cur_run variable."""
|
|
278
292
|
# Pre-condition
|
|
279
293
|
mock_config = Mock()
|
|
280
|
-
mock_config.entity = "test_entity"
|
|
281
294
|
mock_config.base_url = "http://localhost:8000"
|
|
282
295
|
mock_config.api_key = "test_key"
|
|
283
296
|
mock_config.timeout = 30.0
|
|
284
297
|
mock_config_class.return_value = mock_config
|
|
285
298
|
|
|
286
299
|
mock_api_client = Mock()
|
|
300
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
301
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
287
302
|
mock_api_client_class.return_value = mock_api_client
|
|
288
303
|
|
|
289
304
|
mock_scenario_inst = Mock()
|
|
@@ -308,13 +323,14 @@ class HumalabTest(unittest.TestCase):
|
|
|
308
323
|
"""Test that init() calls finish even when exception occurs in context."""
|
|
309
324
|
# Pre-condition
|
|
310
325
|
mock_config = Mock()
|
|
311
|
-
mock_config.entity = "test_entity"
|
|
312
326
|
mock_config.base_url = "http://localhost:8000"
|
|
313
327
|
mock_config.api_key = "test_key"
|
|
314
328
|
mock_config.timeout = 30.0
|
|
315
329
|
mock_config_class.return_value = mock_config
|
|
316
330
|
|
|
317
331
|
mock_api_client = Mock()
|
|
332
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
333
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
318
334
|
mock_api_client_class.return_value = mock_api_client
|
|
319
335
|
|
|
320
336
|
mock_scenario_inst = Mock()
|
|
@@ -343,13 +359,14 @@ class HumalabTest(unittest.TestCase):
|
|
|
343
359
|
timeout = 120.0
|
|
344
360
|
|
|
345
361
|
mock_config = Mock()
|
|
346
|
-
mock_config.entity = "test_entity"
|
|
347
362
|
mock_config.base_url = "http://localhost:8000"
|
|
348
363
|
mock_config.api_key = "default_key"
|
|
349
364
|
mock_config.timeout = 30.0
|
|
350
365
|
mock_config_class.return_value = mock_config
|
|
351
366
|
|
|
352
367
|
mock_api_client = Mock()
|
|
368
|
+
mock_api_client.create_project.return_value = {"name": DEFAULT_PROJECT}
|
|
369
|
+
mock_api_client.get_run.return_value = {"run_id": "", "name": "", "description": "", "tags": None}
|
|
353
370
|
mock_api_client_class.return_value = mock_api_client
|
|
354
371
|
|
|
355
372
|
mock_scenario_inst = Mock()
|
|
@@ -381,33 +398,33 @@ class HumalabTest(unittest.TestCase):
|
|
|
381
398
|
humalab.finish()
|
|
382
399
|
|
|
383
400
|
# Post-condition
|
|
384
|
-
mock_run.finish.assert_called_once_with(status=
|
|
401
|
+
mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=None)
|
|
385
402
|
|
|
386
403
|
def test_finish_should_call_finish_on_current_run_with_custom_status(self):
|
|
387
404
|
"""Test that finish() calls finish on the current run with custom status."""
|
|
388
405
|
# Pre-condition
|
|
389
406
|
mock_run = Mock()
|
|
390
407
|
humalab._cur_run = mock_run
|
|
391
|
-
status =
|
|
408
|
+
status = RunStatus.ERRORED
|
|
392
409
|
|
|
393
410
|
# In-test
|
|
394
411
|
humalab.finish(status=status)
|
|
395
412
|
|
|
396
413
|
# Post-condition
|
|
397
|
-
mock_run.finish.assert_called_once_with(status=status,
|
|
414
|
+
mock_run.finish.assert_called_once_with(status=status, err_msg=None)
|
|
398
415
|
|
|
399
|
-
def
|
|
400
|
-
"""Test that finish() calls finish on the current run with
|
|
416
|
+
def test_finish_should_call_finish_on_current_run_with_err_msg_parameter(self):
|
|
417
|
+
"""Test that finish() calls finish on the current run with err_msg parameter."""
|
|
401
418
|
# Pre-condition
|
|
402
419
|
mock_run = Mock()
|
|
403
420
|
humalab._cur_run = mock_run
|
|
404
|
-
|
|
421
|
+
err_msg = "Test error message"
|
|
405
422
|
|
|
406
423
|
# In-test
|
|
407
|
-
humalab.finish(
|
|
424
|
+
humalab.finish(err_msg=err_msg)
|
|
408
425
|
|
|
409
426
|
# Post-condition
|
|
410
|
-
mock_run.finish.assert_called_once_with(status=
|
|
427
|
+
mock_run.finish.assert_called_once_with(status=RunStatus.FINISHED, err_msg=err_msg)
|
|
411
428
|
|
|
412
429
|
def test_finish_should_do_nothing_when_no_current_run(self):
|
|
413
430
|
"""Test that finish() does nothing when _cur_run is None."""
|
humalab/metrics/__init__.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
from .metric import
|
|
2
|
-
from .
|
|
1
|
+
from .metric import Metrics
|
|
2
|
+
from .code import Code
|
|
3
|
+
from .scenario_stats import ScenarioStats
|
|
3
4
|
from .summary import Summary
|
|
4
5
|
|
|
5
6
|
__all__ = [
|
|
6
|
-
"
|
|
7
|
-
"MetricType",
|
|
7
|
+
"Code",
|
|
8
8
|
"Metrics",
|
|
9
|
-
"
|
|
9
|
+
"ScenarioStats",
|
|
10
10
|
"Summary",
|
|
11
11
|
]
|
humalab/metrics/code.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
class Code:
|
|
2
|
+
"""Class for logging code artifacts."""
|
|
3
|
+
def __init__(self,
|
|
4
|
+
run_id: str,
|
|
5
|
+
key: str,
|
|
6
|
+
code_content: str,
|
|
7
|
+
episode_id: str | None = None) -> None:
|
|
8
|
+
super().__init__()
|
|
9
|
+
self._run_id = run_id
|
|
10
|
+
self._key = key
|
|
11
|
+
self._code_content = code_content
|
|
12
|
+
self._episode_id = episode_id
|
|
13
|
+
|
|
14
|
+
@property
|
|
15
|
+
def run_id(self) -> str:
|
|
16
|
+
return self._run_id
|
|
17
|
+
|
|
18
|
+
@property
|
|
19
|
+
def key(self) -> str:
|
|
20
|
+
return self._key
|
|
21
|
+
|
|
22
|
+
@property
|
|
23
|
+
def code_content(self) -> str:
|
|
24
|
+
return self._code_content
|
|
25
|
+
|
|
26
|
+
@property
|
|
27
|
+
def episode_id(self) -> str | None:
|
|
28
|
+
return self._episode_id
|
humalab/metrics/metric.py
CHANGED
|
@@ -1,129 +1,62 @@
|
|
|
1
|
-
from
|
|
2
|
-
from
|
|
3
|
-
from humalab.constants import EpisodeStatus
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
class MetricType(Enum):
|
|
7
|
-
DEFAULT = "default"
|
|
8
|
-
STREAM = "stream"
|
|
9
|
-
DISTRIBUTION = "distribution"
|
|
10
|
-
SUMMARY = "summary"
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
class MetricGranularity(Enum):
|
|
14
|
-
STEP = "step"
|
|
15
|
-
EPISODE = "episode"
|
|
16
|
-
RUN = "run"
|
|
1
|
+
from typing import Any
|
|
2
|
+
from humalab.constants import MetricDimType, GraphType
|
|
17
3
|
|
|
18
4
|
|
|
19
5
|
class Metrics:
|
|
20
|
-
def __init__(self,
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
episode_id: str,
|
|
24
|
-
run_id: str,
|
|
25
|
-
granularity: MetricGranularity = MetricGranularity.STEP) -> None:
|
|
6
|
+
def __init__(self,
|
|
7
|
+
metric_dim_type: MetricDimType= MetricDimType.ONE_D,
|
|
8
|
+
graph_type: GraphType=GraphType.LINE) -> None:
|
|
26
9
|
"""
|
|
27
10
|
Base class for different types of metrics.
|
|
28
|
-
|
|
29
|
-
Args:
|
|
30
|
-
name (str): The name of the metric.
|
|
31
|
-
metric_type (MetricType): The type of the metric.
|
|
32
|
-
episode_id (str): The ID of the episode.
|
|
33
|
-
run_id (str): The ID of the run.
|
|
34
|
-
granularity (MetricGranularity): The granularity of the metric.
|
|
35
11
|
"""
|
|
36
|
-
self._name = name
|
|
37
|
-
self._metric_type = metric_type
|
|
38
|
-
self._granularity = granularity
|
|
39
12
|
self._values = []
|
|
40
13
|
self._x_values = []
|
|
41
|
-
self.
|
|
42
|
-
self.
|
|
43
|
-
self.
|
|
44
|
-
|
|
45
|
-
def reset(self,
|
|
46
|
-
episode_id: str | None = None) -> None:
|
|
47
|
-
"""Reset the metric for a new episode or run.
|
|
14
|
+
self._step = -1
|
|
15
|
+
self._metric_dim_type = metric_dim_type
|
|
16
|
+
self._graph_type = graph_type
|
|
48
17
|
|
|
49
|
-
Args:
|
|
50
|
-
episode_id (str | None): Optional new episode ID. If None, keeps the current episode ID.
|
|
51
|
-
"""
|
|
52
|
-
if self._granularity != MetricGranularity.RUN:
|
|
53
|
-
self._submit()
|
|
54
|
-
self._values = []
|
|
55
|
-
self._x_values = []
|
|
56
|
-
self._last_step = -1
|
|
57
|
-
self._episode_id = episode_id
|
|
58
|
-
|
|
59
18
|
@property
|
|
60
|
-
def
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
Returns:
|
|
64
|
-
str: The name of the metric.
|
|
65
|
-
"""
|
|
66
|
-
return self._name
|
|
19
|
+
def metric_dim_type(self) -> MetricDimType:
|
|
20
|
+
return self._metric_dim_type
|
|
67
21
|
|
|
68
22
|
@property
|
|
69
|
-
def
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
Returns:
|
|
73
|
-
MetricType: The type of the metric.
|
|
74
|
-
"""
|
|
75
|
-
return self._metric_type
|
|
23
|
+
def graph_type(self) -> GraphType:
|
|
24
|
+
return self._graph_type
|
|
76
25
|
|
|
77
|
-
|
|
78
|
-
def granularity(self) -> MetricGranularity:
|
|
79
|
-
"""The granularity of the metric.
|
|
80
|
-
|
|
81
|
-
Returns:
|
|
82
|
-
MetricGranularity: The granularity of the metric.
|
|
83
|
-
"""
|
|
84
|
-
return self._granularity
|
|
85
|
-
|
|
86
|
-
def log(self, data: Any, step: int | None = None, replace: bool = True) -> None:
|
|
26
|
+
def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
|
|
87
27
|
"""Log a new data point for the metric. The behavior depends on the granularity.
|
|
88
28
|
|
|
89
29
|
Args:
|
|
90
30
|
data (Any): The data point to log.
|
|
91
|
-
|
|
92
|
-
|
|
31
|
+
x (Any | None): The x-axis value associated with the data point.
|
|
32
|
+
if None, the current step is used.
|
|
33
|
+
replace (bool): Whether to replace the last logged value.
|
|
93
34
|
"""
|
|
94
|
-
if
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
35
|
+
if replace:
|
|
36
|
+
self._values[-1] = data
|
|
37
|
+
if x is not None:
|
|
38
|
+
self._x_values[-1] = x
|
|
39
|
+
else:
|
|
40
|
+
self._values.append(data)
|
|
41
|
+
if x is not None:
|
|
42
|
+
self._x_values.append(x)
|
|
102
43
|
else:
|
|
103
|
-
self.
|
|
104
|
-
self.
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
# If there is no data to submit, then return.
|
|
119
|
-
return
|
|
120
|
-
# TODO: Implement commit logic
|
|
121
|
-
|
|
122
|
-
# Clear data after the submission.
|
|
44
|
+
self._x_values.append(self._step)
|
|
45
|
+
self._step += 1
|
|
46
|
+
|
|
47
|
+
def finalize(self) -> dict:
|
|
48
|
+
"""Finalize the logged data for processing."""
|
|
49
|
+
ret_result = self._finalize()
|
|
50
|
+
|
|
51
|
+
return ret_result
|
|
52
|
+
|
|
53
|
+
def _finalize(self) -> dict:
|
|
54
|
+
"""Process the logged data before submission. To be implemented by subclasses."""
|
|
55
|
+
ret_val = {
|
|
56
|
+
"values": self._values,
|
|
57
|
+
"x_values": self._x_values
|
|
58
|
+
}
|
|
123
59
|
self._values = []
|
|
124
60
|
self._x_values = []
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
"""Finish the metric logging and submit the final data."""
|
|
128
|
-
self.reset()
|
|
129
|
-
self._submit()
|
|
61
|
+
self._step = -1
|
|
62
|
+
return ret_val
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from humalab.metrics.metric import Metrics
|
|
2
|
+
from humalab.constants import ArtifactType, GraphType, MetricDimType
|
|
3
|
+
from humalab.humalab_api_client import EpisodeStatus
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# Distribution type names whose logged samples are unwrapped (``data[0]``)
# before being stored by ScenarioStats.log.
SCENARIO_STATS_NEED_FLATTEN = {
    "uniform_1d",
    "bernoulli_1d",
    "categorical_1d",
    "discrete_1d",
    "log_uniform_1d",
    "gaussian_1d",
    "truncated_gaussian_1d",
}
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ScenarioStats(Metrics):
    """Metric that records one sampled value and one status per episode.

    Two dictionaries are maintained:

    * ``_values`` maps the ``x`` key given to :meth:`log` to the logged data
      (the duplicate-key error message refers to this key as an
      ``episode_id``).
    * ``_results`` maps an ``episode_id`` to the ``EpisodeStatus`` given to
      :meth:`log_status`.

    Note that ``_values`` is a dict here, replacing the list that the base
    class initialises.

    Args:
        name: Name of the statistic.
        distribution_type: Name of the distribution the values come from.
            When it is listed in ``SCENARIO_STATS_NEED_FLATTEN`` each logged
            value is stored as ``data[0]``.
        metric_dim_type: Forwarded unchanged to ``Metrics.__init__``.
        graph_type: Forwarded unchanged to ``Metrics.__init__``.
    """

    def __init__(self,
                 name: str,
                 distribution_type: str,
                 metric_dim_type: MetricDimType,
                 graph_type: GraphType,
                 ) -> None:
        super().__init__(
            metric_dim_type=metric_dim_type,
            graph_type=graph_type
        )
        self._name = name
        self._distribution_type = distribution_type
        self._artifact_type = ArtifactType.SCENARIO_STATS
        self._values = {}
        self._results = {}

    @property
    def name(self) -> str:
        """The name passed at construction."""
        return self._name

    @property
    def distribution_type(self) -> str:
        """The distribution type name passed at construction."""
        return self._distribution_type

    @property
    def artifact_type(self) -> ArtifactType:
        """Always ``ArtifactType.SCENARIO_STATS``."""
        return self._artifact_type

    def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
        """Store ``data`` under the key ``x``.

        Args:
            data: The value to record. For distribution types listed in
                ``SCENARIO_STATS_NEED_FLATTEN`` only ``data[0]`` is stored
                (presumably 1-D samples arrive wrapped in a sequence --
                confirm against the caller).
            x: Key to store the value under; must be hashable. The default
                ``None`` is itself used as a key.
            replace: Allow overwriting a value already stored under ``x``.

        Raises:
            ValueError: If ``x`` already has a value and ``replace`` is False.
        """
        # Reject duplicates before touching ``data`` so a refused call has
        # no side effects.
        if x in self._values and not replace:
            raise ValueError(f"Data for episode_id {x} already exists. Use replace=True to overwrite.")
        if self._distribution_type in SCENARIO_STATS_NEED_FLATTEN:
            data = data[0]
        self._values[x] = data

    def log_status(self,
                   episode_id: str,
                   episode_status: EpisodeStatus,
                   replace: bool = False) -> None:
        """Record the status of an episode.

        Args:
            episode_id: Key identifying the episode.
            episode_status: The status to store for that episode.
            replace: Allow overwriting a status already stored for
                ``episode_id``.

        Raises:
            ValueError: If ``episode_id`` already has a status and
                ``replace`` is False.
        """
        if episode_id in self._results and not replace:
            raise ValueError(f"Data for episode_id {episode_id} already exists. Use replace=True to overwrite.")
        self._results[episode_id] = episode_status

    def _finalize(self) -> dict:
        """Return everything logged so far and reset the instance.

        Returns:
            A dict with keys ``"values"``, ``"results"`` and
            ``"distribution_type"``. The two mappings are the very objects
            that were being filled; the instance switches to fresh empty
            dicts afterwards.
        """
        ret_val = {
            "values": self._values,
            "results": self._results,
            "distribution_type": self._distribution_type,
        }
        self._values = {}
        self._results = {}
        return ret_val
|
|
94
|
+
|
|
95
|
+
|