humalab 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of humalab might be problematic. Click here for more details.
- humalab/__init__.py +9 -0
- humalab/assets/__init__.py +0 -0
- humalab/assets/archive.py +101 -0
- humalab/assets/resource_file.py +28 -0
- humalab/assets/resource_handler.py +175 -0
- humalab/constants.py +7 -0
- humalab/dists/__init__.py +17 -0
- humalab/dists/bernoulli.py +44 -0
- humalab/dists/categorical.py +49 -0
- humalab/dists/discrete.py +56 -0
- humalab/dists/distribution.py +38 -0
- humalab/dists/gaussian.py +49 -0
- humalab/dists/log_uniform.py +49 -0
- humalab/dists/truncated_gaussian.py +64 -0
- humalab/dists/uniform.py +49 -0
- humalab/humalab.py +149 -0
- humalab/humalab_api_client.py +273 -0
- humalab/humalab_config.py +86 -0
- humalab/humalab_test.py +510 -0
- humalab/metrics/__init__.py +11 -0
- humalab/metrics/dist_metric.py +22 -0
- humalab/metrics/metric.py +129 -0
- humalab/metrics/summary.py +54 -0
- humalab/run.py +214 -0
- humalab/scenario.py +225 -0
- humalab/scenario_test.py +911 -0
- humalab-0.0.1.dist-info/METADATA +43 -0
- humalab-0.0.1.dist-info/RECORD +32 -0
- humalab-0.0.1.dist-info/WHEEL +5 -0
- humalab-0.0.1.dist-info/entry_points.txt +2 -0
- humalab-0.0.1.dist-info/licenses/LICENSE +21 -0
- humalab-0.0.1.dist-info/top_level.txt +1 -0
humalab/humalab_test.py
ADDED
|
@@ -0,0 +1,510 @@
|
|
|
1
|
+
import unittest
|
|
2
|
+
from unittest.mock import patch, MagicMock, Mock
|
|
3
|
+
import uuid
|
|
4
|
+
|
|
5
|
+
from humalab_sdk import humalab
|
|
6
|
+
from humalab_sdk.run import Run
|
|
7
|
+
from humalab_sdk.scenario import Scenario
|
|
8
|
+
from humalab_sdk.humalab_config import HumalabConfig
|
|
9
|
+
from humalab_sdk.humalab_api_client import HumaLabApiClient
|
|
10
|
+
from humalab_sdk.constants import EpisodeStatus
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class HumalabTest(unittest.TestCase):
    """Unit tests for humalab module functions."""

    def setUp(self):
        """Set up test fixtures before each test method."""
        # Reset the global _cur_run before each test
        humalab._cur_run = None

    def tearDown(self):
        """Clean up after each test method."""
        # Reset the global _cur_run after each test
        humalab._cur_run = None

    @staticmethod
    def _wire_init_mocks(mock_run_class, mock_scenario_class, mock_config_class,
                         mock_api_client_class, *, entity="test_entity",
                         base_url="http://localhost:8000", api_key="test_key",
                         timeout=30.0):
        """Make the four patched classes return fresh mocks for an init() test.

        Args:
            mock_run_class: Patched ``Run`` class.
            mock_scenario_class: Patched ``Scenario`` class.
            mock_config_class: Patched ``HumalabConfig`` class.
            mock_api_client_class: Patched ``HumaLabApiClient`` class.
            entity: Value exposed as ``config.entity``.
            base_url: Value exposed as ``config.base_url``.
            api_key: Value exposed as ``config.api_key``.
            timeout: Value exposed as ``config.timeout``.

        Returns:
            tuple: ``(run instance mock, scenario instance mock, API client mock)``.
        """
        mock_config = Mock()
        mock_config.entity = entity
        mock_config.base_url = base_url
        mock_config.api_key = api_key
        mock_config.timeout = timeout
        mock_config_class.return_value = mock_config

        mock_api_client = Mock()
        mock_api_client_class.return_value = mock_api_client

        mock_scenario_inst = Mock()
        mock_scenario_class.return_value = mock_scenario_inst

        mock_run_inst = Mock()
        mock_run_class.return_value = mock_run_inst

        return mock_run_inst, mock_scenario_inst, mock_api_client

    # Tests for _pull_scenario

    def test_pull_scenario_should_return_scenario_when_no_scenario_id(self):
        """Test that _pull_scenario returns scenario when scenario_id is None."""
        # Pre-condition
        client = Mock()
        scenario = {"key": "value"}

        # In-test
        result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=None)

        # Post-condition
        self.assertEqual(result, scenario)
        client.get_scenario.assert_not_called()

    def test_pull_scenario_should_fetch_scenario_from_client_when_scenario_id_provided(self):
        """Test that _pull_scenario fetches from API when scenario_id is provided."""
        # Pre-condition
        client = Mock()
        scenario_id = "test-scenario-id"
        yaml_content = "scenario: test"
        client.get_scenario.return_value = {"yaml_content": yaml_content}

        # In-test
        result = humalab._pull_scenario(client=client, scenario=None, scenario_id=scenario_id)

        # Post-condition
        self.assertEqual(result, yaml_content)
        client.get_scenario.assert_called_once_with(uuid=scenario_id)

    def test_pull_scenario_should_prefer_scenario_id_over_scenario(self):
        """Test that _pull_scenario uses scenario_id even when scenario is provided."""
        # Pre-condition
        client = Mock()
        scenario = {"key": "value"}
        scenario_id = "test-scenario-id"
        yaml_content = "scenario: from_api"
        client.get_scenario.return_value = {"yaml_content": yaml_content}

        # In-test
        result = humalab._pull_scenario(client=client, scenario=scenario, scenario_id=scenario_id)

        # Post-condition
        self.assertEqual(result, yaml_content)
        client.get_scenario.assert_called_once_with(uuid=scenario_id)

    # Tests for init context manager

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_create_run_with_provided_parameters(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
        """Test that init() creates a Run with provided parameters."""
        # Pre-condition
        entity = "test_entity"
        project = "test_project"
        name = "test_name"
        description = "test_description"
        run_id = "test_id"
        tags = ["tag1", "tag2"]
        scenario_data = {"key": "value"}

        mock_run_inst, mock_scenario_inst, _ = self._wire_init_mocks(
            mock_run_class, mock_scenario_class, mock_config_class,
            mock_api_client_class, entity="default_entity",
        )

        # In-test
        with humalab.init(
            entity=entity,
            project=project,
            name=name,
            description=description,
            id=run_id,
            tags=tags,
            scenario=scenario_data
        ) as run:
            # Post-condition
            self.assertEqual(run, mock_run_inst)
            mock_run_class.assert_called_once()
            call_kwargs = mock_run_class.call_args.kwargs
            self.assertEqual(call_kwargs['entity'], entity)
            self.assertEqual(call_kwargs['project'], project)
            self.assertEqual(call_kwargs['name'], name)
            self.assertEqual(call_kwargs['description'], description)
            self.assertEqual(call_kwargs['id'], run_id)
            self.assertEqual(call_kwargs['tags'], tags)
            self.assertEqual(call_kwargs['scenario'], mock_scenario_inst)

        # Verify finish was called
        mock_run_inst.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_use_config_defaults_when_parameters_not_provided(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
        """Test that init() uses config defaults when parameters are not provided."""
        # Pre-condition
        mock_run_inst, _, _ = self._wire_init_mocks(
            mock_run_class, mock_scenario_class, mock_config_class,
            mock_api_client_class, entity="config_entity",
            base_url="http://config:8000", api_key="config_key", timeout=60.0,
        )

        # In-test
        with humalab.init():
            # Post-condition
            call_kwargs = mock_run_class.call_args.kwargs
            self.assertEqual(call_kwargs['entity'], "config_entity")
            self.assertEqual(call_kwargs['project'], "default")
            self.assertEqual(call_kwargs['name'], "")
            self.assertEqual(call_kwargs['description'], "")
            self.assertIsNotNone(call_kwargs['id'])  # UUID generated
            self.assertIsNone(call_kwargs['tags'])

        mock_run_inst.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_generate_uuid_when_id_not_provided(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
        """Test that init() generates a UUID when id is not provided."""
        # Pre-condition
        mock_run_inst, _, _ = self._wire_init_mocks(
            mock_run_class, mock_scenario_class, mock_config_class,
            mock_api_client_class,
        )

        # In-test
        with humalab.init():
            # Post-condition
            call_kwargs = mock_run_class.call_args.kwargs
            run_id = call_kwargs['id']
            # Verify it's a valid UUID
            uuid.UUID(run_id)  # Will raise ValueError if not valid

        mock_run_inst.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_initialize_scenario_with_seed(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
        """Test that init() initializes scenario with provided seed."""
        # Pre-condition
        seed = 42
        scenario_data = {"key": "value"}

        mock_run_inst, mock_scenario_inst, _ = self._wire_init_mocks(
            mock_run_class, mock_scenario_class, mock_config_class,
            mock_api_client_class,
        )

        # In-test
        with humalab.init(scenario=scenario_data, seed=seed):
            # Post-condition
            mock_scenario_inst.init.assert_called_once()
            call_kwargs = mock_scenario_inst.init.call_args.kwargs
            self.assertEqual(call_kwargs['seed'], seed)
            self.assertEqual(call_kwargs['scenario'], scenario_data)

        mock_run_inst.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_pull_scenario_from_api_when_scenario_id_provided(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
        """Test that init() pulls scenario from API when scenario_id is provided."""
        # Pre-condition
        scenario_id = "test-scenario-id"
        yaml_content = "scenario: from_api"

        mock_run_inst, mock_scenario_inst, mock_api_client = self._wire_init_mocks(
            mock_run_class, mock_scenario_class, mock_config_class,
            mock_api_client_class,
        )
        mock_api_client.get_scenario.return_value = {"yaml_content": yaml_content}

        # In-test
        with humalab.init(scenario_id=scenario_id):
            # Post-condition
            mock_api_client.get_scenario.assert_called_once_with(uuid=scenario_id)
            mock_scenario_inst.init.assert_called_once()
            call_kwargs = mock_scenario_inst.init.call_args.kwargs
            self.assertEqual(call_kwargs['scenario'], yaml_content)

        mock_run_inst.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_set_global_cur_run(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
        """Test that init() sets the global _cur_run variable."""
        # Pre-condition
        mock_run_inst, _, _ = self._wire_init_mocks(
            mock_run_class, mock_scenario_class, mock_config_class,
            mock_api_client_class,
        )

        # In-test
        self.assertIsNone(humalab._cur_run)
        with humalab.init():
            # Post-condition
            self.assertEqual(humalab._cur_run, mock_run_inst)

        mock_run_inst.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_call_finish_on_exception(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
        """Test that init() calls finish even when exception occurs in context."""
        # Pre-condition
        mock_run_inst, _, _ = self._wire_init_mocks(
            mock_run_class, mock_scenario_class, mock_config_class,
            mock_api_client_class,
        )

        # In-test & Post-condition
        with self.assertRaises(RuntimeError):
            with humalab.init():
                raise RuntimeError("Test exception")

        # Verify finish was still called
        mock_run_inst.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_create_api_client_with_custom_parameters(self, mock_run_class, mock_scenario_class, mock_config_class, mock_api_client_class):
        """Test that init() creates API client with custom base_url, api_key, and timeout."""
        # Pre-condition
        base_url = "http://custom:9000"
        api_key = "custom_key"
        timeout = 120.0

        # Config defaults deliberately differ from the custom values above.
        mock_run_inst, _, _ = self._wire_init_mocks(
            mock_run_class, mock_scenario_class, mock_config_class,
            mock_api_client_class, api_key="default_key",
        )

        # In-test
        with humalab.init(base_url=base_url, api_key=api_key, timeout=timeout):
            # Post-condition
            mock_api_client_class.assert_called_once_with(
                base_url=base_url,
                api_key=api_key,
                timeout=timeout
            )

        mock_run_inst.finish.assert_called_once()

    # Tests for finish function

    def test_finish_should_call_finish_on_current_run_with_default_status(self):
        """Test that finish() calls finish on the current run with default status."""
        # Pre-condition
        mock_run = Mock()
        humalab._cur_run = mock_run

        # In-test
        humalab.finish()

        # Post-condition
        mock_run.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=None)

    def test_finish_should_call_finish_on_current_run_with_custom_status(self):
        """Test that finish() calls finish on the current run with custom status."""
        # Pre-condition
        mock_run = Mock()
        humalab._cur_run = mock_run
        status = EpisodeStatus.FAILED

        # In-test
        humalab.finish(status=status)

        # Post-condition
        mock_run.finish.assert_called_once_with(status=status, quiet=None)

    def test_finish_should_call_finish_on_current_run_with_quiet_parameter(self):
        """Test that finish() calls finish on the current run with quiet parameter."""
        # Pre-condition
        mock_run = Mock()
        humalab._cur_run = mock_run
        quiet = True

        # In-test
        humalab.finish(quiet=quiet)

        # Post-condition
        mock_run.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=quiet)

    def test_finish_should_do_nothing_when_no_current_run(self):
        """Test that finish() does nothing when _cur_run is None."""
        # Pre-condition
        humalab._cur_run = None

        # In-test
        humalab.finish()  # Should not raise any exception

        # Post-condition
        # No exception means success
        self.assertIsNone(humalab._cur_run)

    # Tests for login function

    @patch('humalab_sdk.humalab.HumalabConfig')
    def test_login_should_set_api_key_when_provided(self, mock_config_class):
        """Test that login() sets the api_key when provided."""
        # Pre-condition
        mock_config = Mock()
        mock_config.api_key = "old_key"
        mock_config.base_url = "http://localhost:8000"
        mock_config.timeout = 30.0
        mock_config_class.return_value = mock_config

        new_key = "new_api_key"

        # In-test
        result = humalab.login(api_key=new_key)

        # Post-condition
        self.assertTrue(result)
        self.assertEqual(mock_config.api_key, new_key)

    @patch('humalab_sdk.humalab.HumalabConfig')
    def test_login_should_keep_existing_key_when_not_provided(self, mock_config_class):
        """Test that login() keeps existing api_key when key is not provided."""
        # Pre-condition
        existing_key = "existing_key"
        existing_url = "http://localhost:8000"
        existing_timeout = 30.0
        mock_config = Mock()
        mock_config.api_key = existing_key
        mock_config.base_url = existing_url
        mock_config.timeout = existing_timeout
        mock_config_class.return_value = mock_config

        # In-test
        result = humalab.login()

        # Post-condition
        self.assertTrue(result)
        self.assertEqual(mock_config.api_key, existing_key)
        self.assertEqual(mock_config.base_url, existing_url)
        self.assertEqual(mock_config.timeout, existing_timeout)

    @patch('humalab_sdk.humalab.HumalabConfig')
    def test_login_should_return_true(self, mock_config_class):
        """Test that login() always returns True."""
        # Pre-condition
        mock_config = Mock()
        mock_config.api_key = "test_key"
        mock_config.base_url = "http://localhost:8000"
        mock_config.timeout = 30.0
        mock_config_class.return_value = mock_config

        # In-test
        result = humalab.login()

        # Post-condition
        self.assertTrue(result)

    @patch('humalab_sdk.humalab.HumalabConfig')
    def test_login_should_accept_optional_parameters(self, mock_config_class):
        """Test that login() accepts optional parameters without errors."""
        # Pre-condition
        mock_config = Mock()
        mock_config.api_key = "old_key"
        mock_config.base_url = "http://old:8000"
        mock_config.timeout = 30.0
        mock_config_class.return_value = mock_config

        # In-test
        result = humalab.login(
            api_key="test_key",
            relogin=True,
            host="http://localhost:8000",
            force=True,
            timeout=60.0
        )

        # Post-condition
        self.assertTrue(result)
        self.assertEqual(mock_config.api_key, "test_key")
        self.assertEqual(mock_config.base_url, "http://localhost:8000")
        self.assertEqual(mock_config.timeout, 60.0)
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
if __name__ == "__main__":
|
|
510
|
+
unittest.main()
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
|
|
2
|
+
from humalab_sdk.metrics.metric import MetricGranularity, Metrics, MetricType
|
|
3
|
+
|
|
4
|
+
class DistributionMetric(Metrics):
    """A ``Metrics`` of type ``MetricType.DISTRIBUTION`` tagged with a distribution type."""

    def __init__(self,
                 name: str,
                 distribution_type: str,
                 episode_id: str,
                 run_id: str,
                 granularity: MetricGranularity = MetricGranularity.EPISODE) -> None:
        """
        Build the metric and remember which distribution it describes.

        Args:
            name (str): The name of the metric.
            distribution_type (str): The type of distribution (e.g., "normal", "uniform").
            episode_id (str): The ID of the episode.
            run_id (str): The ID of the run.
            granularity (MetricGranularity): The granularity of the metric;
                defaults to one value per episode.
        """
        # Identifiers and granularity are forwarded unchanged to the base class;
        # the metric type is always DISTRIBUTION for this subclass.
        base_options = {
            "episode_id": episode_id,
            "run_id": run_id,
            "granularity": granularity,
        }
        super().__init__(name, MetricType.DISTRIBUTION, **base_options)
        self.distribution_type = distribution_type
|
|
@@ -0,0 +1,129 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
from typing import Any
|
|
3
|
+
from humalab_sdk.constants import EpisodeStatus
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MetricType(Enum):
|
|
7
|
+
DEFAULT = "default"
|
|
8
|
+
STREAM = "stream"
|
|
9
|
+
DISTRIBUTION = "distribution"
|
|
10
|
+
SUMMARY = "summary"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class MetricGranularity(Enum):
|
|
14
|
+
STEP = "step"
|
|
15
|
+
EPISODE = "episode"
|
|
16
|
+
RUN = "run"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Metrics:
|
|
20
|
+
def __init__(self,
|
|
21
|
+
name: str,
|
|
22
|
+
metric_type: MetricType,
|
|
23
|
+
episode_id: str,
|
|
24
|
+
run_id: str,
|
|
25
|
+
granularity: MetricGranularity = MetricGranularity.STEP) -> None:
|
|
26
|
+
"""
|
|
27
|
+
Base class for different types of metrics.
|
|
28
|
+
|
|
29
|
+
Args:
|
|
30
|
+
name (str): The name of the metric.
|
|
31
|
+
metric_type (MetricType): The type of the metric.
|
|
32
|
+
episode_id (str): The ID of the episode.
|
|
33
|
+
run_id (str): The ID of the run.
|
|
34
|
+
granularity (MetricGranularity): The granularity of the metric.
|
|
35
|
+
"""
|
|
36
|
+
self._name = name
|
|
37
|
+
self._metric_type = metric_type
|
|
38
|
+
self._granularity = granularity
|
|
39
|
+
self._values = []
|
|
40
|
+
self._x_values = []
|
|
41
|
+
self._episode_id = episode_id
|
|
42
|
+
self._run_id = run_id
|
|
43
|
+
self._last_step = -1
|
|
44
|
+
|
|
45
|
+
def reset(self,
|
|
46
|
+
episode_id: str | None = None) -> None:
|
|
47
|
+
"""Reset the metric for a new episode or run.
|
|
48
|
+
|
|
49
|
+
Args:
|
|
50
|
+
episode_id (str | None): Optional new episode ID. If None, keeps the current episode ID.
|
|
51
|
+
"""
|
|
52
|
+
if self._granularity != MetricGranularity.RUN:
|
|
53
|
+
self._submit()
|
|
54
|
+
self._values = []
|
|
55
|
+
self._x_values = []
|
|
56
|
+
self._last_step = -1
|
|
57
|
+
self._episode_id = episode_id
|
|
58
|
+
|
|
59
|
+
@property
|
|
60
|
+
def name(self) -> str:
|
|
61
|
+
"""The name of the metric.
|
|
62
|
+
|
|
63
|
+
Returns:
|
|
64
|
+
str: The name of the metric.
|
|
65
|
+
"""
|
|
66
|
+
return self._name
|
|
67
|
+
|
|
68
|
+
@property
|
|
69
|
+
def metric_type(self) -> MetricType:
|
|
70
|
+
"""The type of the metric.
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
MetricType: The type of the metric.
|
|
74
|
+
"""
|
|
75
|
+
return self._metric_type
|
|
76
|
+
|
|
77
|
+
@property
|
|
78
|
+
def granularity(self) -> MetricGranularity:
|
|
79
|
+
"""The granularity of the metric.
|
|
80
|
+
|
|
81
|
+
Returns:
|
|
82
|
+
MetricGranularity: The granularity of the metric.
|
|
83
|
+
"""
|
|
84
|
+
return self._granularity
|
|
85
|
+
|
|
86
|
+
def log(self, data: Any, step: int | None = None, replace: bool = True) -> None:
|
|
87
|
+
"""Log a new data point for the metric. The behavior depends on the granularity.
|
|
88
|
+
|
|
89
|
+
Args:
|
|
90
|
+
data (Any): The data point to log.
|
|
91
|
+
step (int | None): The step number for STEP granularity. Must be provided if granularity is STEP.
|
|
92
|
+
replace (bool): Whether to replace the last logged value if logging at the same step/episode/run.
|
|
93
|
+
"""
|
|
94
|
+
if self._granularity == MetricGranularity.STEP:
|
|
95
|
+
if step is None:
|
|
96
|
+
raise ValueError("step Must be provided!")
|
|
97
|
+
if step == self._last_step:
|
|
98
|
+
if replace:
|
|
99
|
+
self._values[-1] = data
|
|
100
|
+
else:
|
|
101
|
+
raise ValueError("Cannot log the data at the same step.")
|
|
102
|
+
else:
|
|
103
|
+
self._values.append(data)
|
|
104
|
+
self._x_values.append(step)
|
|
105
|
+
elif self._granularity == MetricGranularity.EPISODE:
|
|
106
|
+
if len(self._x_values) > 0 and not replace:
|
|
107
|
+
raise ValueError("Cannot log the data at the same episode.")
|
|
108
|
+
self._values = [data]
|
|
109
|
+
self._x_values = [self._episode_id]
|
|
110
|
+
else: # MetricGranularity.RUN
|
|
111
|
+
if len(self._values) > 0 and not replace:
|
|
112
|
+
raise ValueError("Cannot log the data at the same run.")
|
|
113
|
+
self._values = [data]
|
|
114
|
+
self._x_values = [self._run_id]
|
|
115
|
+
|
|
116
|
+
def _submit(self) -> None:
|
|
117
|
+
if not self._values:
|
|
118
|
+
# If there is no data to submit, then return.
|
|
119
|
+
return
|
|
120
|
+
# TODO: Implement commit logic
|
|
121
|
+
|
|
122
|
+
# Clear data after the submission.
|
|
123
|
+
self._values = []
|
|
124
|
+
self._x_values = []
|
|
125
|
+
|
|
126
|
+
def finish(self) -> None:
|
|
127
|
+
"""Finish the metric logging and submit the final data."""
|
|
128
|
+
self.reset()
|
|
129
|
+
self._submit()
|