humalab 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of humalab might be problematic; see the package registry's advisory for more details.

@@ -0,0 +1,510 @@
1
+ import unittest
2
+ from unittest.mock import patch, MagicMock, Mock
3
+ import uuid
4
+
5
+ from humalab_sdk import humalab
6
+ from humalab_sdk.run import Run
7
+ from humalab_sdk.scenario import Scenario
8
+ from humalab_sdk.humalab_config import HumalabConfig
9
+ from humalab_sdk.humalab_api_client import HumaLabApiClient
10
+ from humalab_sdk.constants import EpisodeStatus
11
+
12
+
13
class HumalabTest(unittest.TestCase):
    """Unit tests for the module-level helpers in ``humalab_sdk.humalab``.

    The ``init()`` tests patch ``Run``, ``Scenario``, ``HumalabConfig`` and
    ``HumaLabApiClient``. Stacked ``@patch`` decorators inject their mocks
    bottom-up, so those tests receive ``(run_cls, scenario_cls, config_cls,
    client_cls)``.
    """

    def setUp(self):
        """Begin each test with no active run registered globally."""
        humalab._cur_run = None

    def tearDown(self):
        """Clear the global active run so tests cannot leak state."""
        humalab._cur_run = None

    # ------------------------------------------------------------------
    # Shared fixtures
    # ------------------------------------------------------------------

    @staticmethod
    def _wire_init_mocks(run_cls, scenario_cls, config_cls, client_cls,
                         entity="test_entity",
                         base_url="http://localhost:8000",
                         api_key="test_key",
                         timeout=30.0):
        """Make each patched class return a fresh Mock instance.

        Returns:
            tuple: ``(config, api_client, scenario, run)`` mock instances.
        """
        config = Mock()
        config.entity = entity
        config.base_url = base_url
        config.api_key = api_key
        config.timeout = timeout
        config_cls.return_value = config

        api_client = Mock()
        client_cls.return_value = api_client

        scenario = Mock()
        scenario_cls.return_value = scenario

        run = Mock()
        run_cls.return_value = run
        return config, api_client, scenario, run

    @staticmethod
    def _stub_config(config_cls, api_key,
                     base_url="http://localhost:8000", timeout=30.0):
        """Make the patched ``HumalabConfig`` return a configured Mock."""
        config = Mock()
        config.api_key = api_key
        config.base_url = base_url
        config.timeout = timeout
        config_cls.return_value = config
        return config

    @staticmethod
    def _install_run_mock():
        """Register a Mock as the current global run and return it."""
        active = Mock()
        humalab._cur_run = active
        return active

    # ------------------------------------------------------------------
    # _pull_scenario
    # ------------------------------------------------------------------

    def test_pull_scenario_should_return_scenario_when_no_scenario_id(self):
        """Without a scenario_id the inline scenario is returned untouched."""
        api = Mock()
        inline = {"key": "value"}

        pulled = humalab._pull_scenario(client=api, scenario=inline, scenario_id=None)

        self.assertEqual(pulled, inline)
        api.get_scenario.assert_not_called()

    def test_pull_scenario_should_fetch_scenario_from_client_when_scenario_id_provided(self):
        """A scenario_id causes the YAML content to be fetched via the client."""
        api = Mock()
        wanted_id = "test-scenario-id"
        body = "scenario: test"
        api.get_scenario.return_value = {"yaml_content": body}

        pulled = humalab._pull_scenario(client=api, scenario=None, scenario_id=wanted_id)

        self.assertEqual(pulled, body)
        api.get_scenario.assert_called_once_with(uuid=wanted_id)

    def test_pull_scenario_should_prefer_scenario_id_over_scenario(self):
        """When both are given, the remote scenario (by id) takes precedence."""
        api = Mock()
        inline = {"key": "value"}
        wanted_id = "test-scenario-id"
        body = "scenario: from_api"
        api.get_scenario.return_value = {"yaml_content": body}

        pulled = humalab._pull_scenario(client=api, scenario=inline, scenario_id=wanted_id)

        self.assertEqual(pulled, body)
        api.get_scenario.assert_called_once_with(uuid=wanted_id)

    # ------------------------------------------------------------------
    # init() context manager
    # ------------------------------------------------------------------

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_create_run_with_provided_parameters(self, run_cls, scenario_cls, config_cls, client_cls):
        """Explicit init() arguments are forwarded to the Run constructor."""
        _, _, scenario, run_mock = self._wire_init_mocks(
            run_cls, scenario_cls, config_cls, client_cls, entity="default_entity")
        # Checked against Run's kwargs in this order.
        forwarded = {
            "entity": "test_entity",
            "project": "test_project",
            "name": "test_name",
            "description": "test_description",
            "id": "test_id",
            "tags": ["tag1", "tag2"],
        }

        with humalab.init(**forwarded, scenario={"key": "value"}) as run:
            self.assertEqual(run, run_mock)
            run_cls.assert_called_once()
            received = run_cls.call_args.kwargs
            for key, expected in forwarded.items():
                self.assertEqual(received[key], expected)
            self.assertEqual(received['scenario'], scenario)

        # Leaving the context must finish the run.
        run_mock.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_use_config_defaults_when_parameters_not_provided(self, run_cls, scenario_cls, config_cls, client_cls):
        """Missing init() arguments fall back to config values and defaults."""
        *_, run_mock = self._wire_init_mocks(
            run_cls, scenario_cls, config_cls, client_cls,
            entity="config_entity",
            base_url="http://config:8000",
            api_key="config_key",
            timeout=60.0,
        )

        with humalab.init():
            received = run_cls.call_args.kwargs
            self.assertEqual(received['entity'], "config_entity")
            self.assertEqual(received['project'], "default")
            self.assertEqual(received['name'], "")
            self.assertEqual(received['description'], "")
            self.assertIsNotNone(received['id'])  # a generated UUID
            self.assertIsNone(received['tags'])

        run_mock.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_generate_uuid_when_id_not_provided(self, run_cls, scenario_cls, config_cls, client_cls):
        """Without an explicit id, init() passes a valid UUID string to Run."""
        *_, run_mock = self._wire_init_mocks(run_cls, scenario_cls, config_cls, client_cls)

        with humalab.init():
            generated = run_cls.call_args.kwargs['id']
            uuid.UUID(generated)  # raises ValueError for a malformed id

        run_mock.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_initialize_scenario_with_seed(self, run_cls, scenario_cls, config_cls, client_cls):
        """The seed and inline scenario are passed to Scenario.init()."""
        _, _, scenario, run_mock = self._wire_init_mocks(run_cls, scenario_cls, config_cls, client_cls)
        seed = 42
        inline = {"key": "value"}

        with humalab.init(scenario=inline, seed=seed):
            scenario.init.assert_called_once()
            init_kwargs = scenario.init.call_args.kwargs
            self.assertEqual(init_kwargs['seed'], seed)
            self.assertEqual(init_kwargs['scenario'], inline)

        run_mock.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_pull_scenario_from_api_when_scenario_id_provided(self, run_cls, scenario_cls, config_cls, client_cls):
        """A scenario_id makes init() fetch the scenario through the API client."""
        _, api_client, scenario, run_mock = self._wire_init_mocks(
            run_cls, scenario_cls, config_cls, client_cls)
        wanted_id = "test-scenario-id"
        body = "scenario: from_api"
        api_client.get_scenario.return_value = {"yaml_content": body}

        with humalab.init(scenario_id=wanted_id):
            api_client.get_scenario.assert_called_once_with(uuid=wanted_id)
            scenario.init.assert_called_once()
            self.assertEqual(scenario.init.call_args.kwargs['scenario'], body)

        run_mock.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_set_global_cur_run(self, run_cls, scenario_cls, config_cls, client_cls):
        """Inside the context, the new run is registered as humalab._cur_run."""
        *_, run_mock = self._wire_init_mocks(run_cls, scenario_cls, config_cls, client_cls)

        self.assertIsNone(humalab._cur_run)
        with humalab.init():
            self.assertEqual(humalab._cur_run, run_mock)

        run_mock.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_call_finish_on_exception(self, run_cls, scenario_cls, config_cls, client_cls):
        """An exception raised in the context still finishes the run."""
        *_, run_mock = self._wire_init_mocks(run_cls, scenario_cls, config_cls, client_cls)

        with self.assertRaises(RuntimeError):
            with humalab.init():
                raise RuntimeError("Test exception")

        run_mock.finish.assert_called_once()

    @patch('humalab_sdk.humalab.HumaLabApiClient')
    @patch('humalab_sdk.humalab.HumalabConfig')
    @patch('humalab_sdk.humalab.Scenario')
    @patch('humalab_sdk.humalab.Run')
    def test_init_should_create_api_client_with_custom_parameters(self, run_cls, scenario_cls, config_cls, client_cls):
        """Explicit base_url/api_key/timeout override the config for the client."""
        *_, run_mock = self._wire_init_mocks(
            run_cls, scenario_cls, config_cls, client_cls, api_key="default_key")
        base_url = "http://custom:9000"
        api_key = "custom_key"
        timeout = 120.0

        with humalab.init(base_url=base_url, api_key=api_key, timeout=timeout):
            client_cls.assert_called_once_with(
                base_url=base_url,
                api_key=api_key,
                timeout=timeout,
            )

        run_mock.finish.assert_called_once()

    # ------------------------------------------------------------------
    # finish()
    # ------------------------------------------------------------------

    def test_finish_should_call_finish_on_current_run_with_default_status(self):
        """finish() defaults to EpisodeStatus.PASS and quiet=None."""
        active = self._install_run_mock()

        humalab.finish()

        active.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=None)

    def test_finish_should_call_finish_on_current_run_with_custom_status(self):
        """finish() forwards an explicit status to the active run."""
        active = self._install_run_mock()
        wanted_status = EpisodeStatus.FAILED

        humalab.finish(status=wanted_status)

        active.finish.assert_called_once_with(status=wanted_status, quiet=None)

    def test_finish_should_call_finish_on_current_run_with_quiet_parameter(self):
        """finish() forwards the quiet flag to the active run."""
        active = self._install_run_mock()

        humalab.finish(quiet=True)

        active.finish.assert_called_once_with(status=EpisodeStatus.PASS, quiet=True)

    def test_finish_should_do_nothing_when_no_current_run(self):
        """finish() is a harmless no-op when no run is active."""
        humalab._cur_run = None

        humalab.finish()  # must not raise

        self.assertIsNone(humalab._cur_run)

    # ------------------------------------------------------------------
    # login()
    # ------------------------------------------------------------------

    @patch('humalab_sdk.humalab.HumalabConfig')
    def test_login_should_set_api_key_when_provided(self, config_cls):
        """login(api_key=...) stores the new key on the config."""
        config = self._stub_config(config_cls, "old_key")
        fresh_key = "new_api_key"

        self.assertTrue(humalab.login(api_key=fresh_key))
        self.assertEqual(config.api_key, fresh_key)

    @patch('humalab_sdk.humalab.HumalabConfig')
    def test_login_should_keep_existing_key_when_not_provided(self, config_cls):
        """login() without arguments leaves the stored settings intact."""
        config = self._stub_config(config_cls, "existing_key")

        self.assertTrue(humalab.login())
        self.assertEqual(config.api_key, "existing_key")
        self.assertEqual(config.base_url, "http://localhost:8000")
        self.assertEqual(config.timeout, 30.0)

    @patch('humalab_sdk.humalab.HumalabConfig')
    def test_login_should_return_true(self, config_cls):
        """login() reports success by returning True."""
        self._stub_config(config_cls, "test_key")

        self.assertTrue(humalab.login())

    @patch('humalab_sdk.humalab.HumalabConfig')
    def test_login_should_accept_optional_parameters(self, config_cls):
        """login() accepts relogin/host/force/timeout and applies key, host, timeout."""
        config = self._stub_config(config_cls, "old_key", base_url="http://old:8000")

        outcome = humalab.login(
            api_key="test_key",
            relogin=True,
            host="http://localhost:8000",
            force=True,
            timeout=60.0,
        )

        self.assertTrue(outcome)
        self.assertEqual(config.api_key, "test_key")
        self.assertEqual(config.base_url, "http://localhost:8000")
        self.assertEqual(config.timeout, 60.0)


if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,11 @@
1
+ from .metric import MetricGranularity, MetricType, Metrics
2
+ from .dist_metric import DistributionMetric
3
+ from .summary import Summary
4
+
5
# Public names re-exported by the metrics package (order matches the imports).
__all__ = ["MetricGranularity", "MetricType", "Metrics", "DistributionMetric", "Summary"]
@@ -0,0 +1,22 @@
1
+
2
+ from humalab_sdk.metrics.metric import MetricGranularity, Metrics, MetricType
3
+
4
class DistributionMetric(Metrics):
    """A ``Metrics`` subclass whose type is ``MetricType.DISTRIBUTION``.

    On top of the base-class bookkeeping it records which distribution
    family the metric describes in ``distribution_type``. Unlike the base
    class (STEP by default), it defaults to EPISODE granularity.
    """

    def __init__(
        self,
        name: str,
        distribution_type: str,
        episode_id: str,
        run_id: str,
        granularity: MetricGranularity = MetricGranularity.EPISODE,
    ) -> None:
        """Create a distribution metric.

        Args:
            name (str): Metric name.
            distribution_type (str): Distribution family, e.g. "normal" or "uniform".
            episode_id (str): ID of the episode the metric starts in.
            run_id (str): ID of the owning run.
            granularity (MetricGranularity): Logging granularity; EPISODE by default.
        """
        super().__init__(
            name=name,
            metric_type=MetricType.DISTRIBUTION,
            episode_id=episode_id,
            run_id=run_id,
            granularity=granularity,
        )
        self.distribution_type = distribution_type
@@ -0,0 +1,129 @@
1
+ from enum import Enum
2
+ from typing import Any
3
+ from humalab_sdk.constants import EpisodeStatus
4
+
5
+
6
class MetricType(Enum):
    """Kind of metric; each member's value is its lowercase string name."""

    DEFAULT = "default"
    STREAM = "stream"
    DISTRIBUTION = "distribution"
    SUMMARY = "summary"
11
+
12
+
13
class MetricGranularity(Enum):
    """Scope at which a metric records values: per step, per episode, or per run."""

    STEP = "step"
    EPISODE = "episode"
    RUN = "run"
17
+
18
+
19
class Metrics:
    """Base class for metrics buffered during a run.

    Logged values live in ``_values`` alongside their x-axis keys in
    ``_x_values``. The x-axis key depends on the granularity: the step number
    for STEP, the current episode ID for EPISODE, and the run ID for RUN.
    """

    def __init__(self,
                 name: str,
                 metric_type: MetricType,
                 episode_id: str,
                 run_id: str,
                 granularity: MetricGranularity = MetricGranularity.STEP) -> None:
        """
        Base class for different types of metrics.

        Args:
            name (str): The name of the metric.
            metric_type (MetricType): The type of the metric.
            episode_id (str): The ID of the episode.
            run_id (str): The ID of the run.
            granularity (MetricGranularity): The granularity of the metric.
        """
        self._name = name
        self._metric_type = metric_type
        self._granularity = granularity
        self._values: list[Any] = []
        self._x_values: list[Any] = []
        self._episode_id = episode_id
        self._run_id = run_id
        # Step of the most recent STEP-granularity log; -1 means "none yet".
        self._last_step = -1

    def reset(self,
              episode_id: str | None = None) -> None:
        """Reset the metric for a new episode.

        STEP- and EPISODE-granularity metrics submit their buffered data and
        start empty. RUN-granularity data spans the whole run, so it is kept
        here and submitted by ``finish``.

        Args:
            episode_id (str | None): Optional new episode ID. If None, keeps the current episode ID.
        """
        if self._granularity != MetricGranularity.RUN:
            self._submit()
            self._values = []
            self._x_values = []
        self._last_step = -1
        # Honour the documented contract: None keeps the current episode ID.
        if episode_id is not None:
            self._episode_id = episode_id

    @property
    def name(self) -> str:
        """The name of the metric.

        Returns:
            str: The name of the metric.
        """
        return self._name

    @property
    def metric_type(self) -> MetricType:
        """The type of the metric.

        Returns:
            MetricType: The type of the metric.
        """
        return self._metric_type

    @property
    def granularity(self) -> MetricGranularity:
        """The granularity of the metric.

        Returns:
            MetricGranularity: The granularity of the metric.
        """
        return self._granularity

    def log(self, data: Any, step: int | None = None, replace: bool = True) -> None:
        """Log a new data point for the metric. The behavior depends on the granularity.

        STEP: a log at the most recently logged step replaces that value, or
        raises when ``replace`` is False. Any other step is appended.
        EPISODE / RUN: keeps a single value keyed by the episode / run ID. A
        second log replaces it, or raises when ``replace`` is False.

        Args:
            data (Any): The data point to log.
            step (int | None): The step number for STEP granularity. Must be provided if granularity is STEP.
            replace (bool): Whether to replace the last logged value if logging at the same step/episode/run.

        Raises:
            ValueError: If ``step`` is missing for STEP granularity, or a value
                already exists for the same step/episode/run and ``replace`` is False.
        """
        if self._granularity == MetricGranularity.STEP:
            if step is None:
                raise ValueError("step Must be provided!")
            # Require a buffered value so a first log at step -1 cannot hit
            # the "same step" path and index an empty list.
            if self._x_values and step == self._last_step:
                if not replace:
                    raise ValueError("Cannot log the data at the same step.")
                self._values[-1] = data
            else:
                self._values.append(data)
                self._x_values.append(step)
                self._last_step = step
        elif self._granularity == MetricGranularity.EPISODE:
            if self._x_values and not replace:
                raise ValueError("Cannot log the data at the same episode.")
            self._values = [data]
            self._x_values = [self._episode_id]
        else:  # MetricGranularity.RUN
            if self._values and not replace:
                raise ValueError("Cannot log the data at the same run.")
            self._values = [data]
            self._x_values = [self._run_id]

    def _submit(self) -> None:
        """Submit buffered data, then clear the buffers.

        The submission itself is not implemented yet (see TODO). At present
        this only clears ``_values`` and ``_x_values`` when they hold data.
        """
        if not self._values:
            # If there is no data to submit, then return.
            return
        # TODO: Implement commit logic

        # Clear data after the submission.
        self._values = []
        self._x_values = []

    def finish(self) -> None:
        """Finish the metric logging and submit the final data.

        ``reset`` submits STEP/EPISODE data. The trailing ``_submit`` flushes
        RUN-granularity data, which ``reset`` deliberately keeps.
        """
        self.reset()
        self._submit()