humalab 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of humalab might be problematic. Click here for more details.

humalab/dists/uniform.py CHANGED
@@ -4,6 +4,11 @@ from typing import Any
4
4
  import numpy as np
5
5
 
6
6
  class Uniform(Distribution):
7
+ """Uniform distribution over a continuous or discrete range.
8
+
9
+ Samples values uniformly from the half-open interval [low, high). Supports
10
+ scalar outputs as well as multi-dimensional arrays with 1D, 2D, or 3D variants.
11
+ """
7
12
  def __init__(self,
8
13
  generator: np.random.Generator,
9
14
  low: float | Any,
@@ -25,6 +30,15 @@ class Uniform(Distribution):
25
30
 
26
31
  @staticmethod
27
32
  def validate(dimensions: int, *args) -> bool:
33
+ """Validate distribution parameters for the given dimensions.
34
+
35
+ Args:
36
+ dimensions (int): The number of dimensions (0 for scalar, -1 for any).
37
+ *args: The distribution parameters (low, high).
38
+
39
+ Returns:
40
+ bool: True if parameters are valid, False otherwise.
41
+ """
28
42
  arg1 = args[0]
29
43
  arg2 = args[1]
30
44
  if dimensions == 0:
@@ -46,9 +60,19 @@ class Uniform(Distribution):
46
60
  return True
47
61
 
48
62
  def _sample(self) -> int | float | np.ndarray:
63
+ """Generate a sample from the uniform distribution.
64
+
65
+ Returns:
66
+ int | float | np.ndarray: Sampled value(s) from [low, high).
67
+ """
49
68
  return self._generator.uniform(self._low, self._high, size=self._size)
50
69
 
51
70
  def __repr__(self) -> str:
71
+ """String representation of the uniform distribution.
72
+
73
+ Returns:
74
+ str: String representation showing low, high, and size.
75
+ """
52
76
  return f"Uniform(low={self._low}, high={self._high}, size={self._size})"
53
77
 
54
78
  @staticmethod
humalab/episode.py CHANGED
@@ -12,6 +12,31 @@ from humalab.utils import is_standard_type
12
12
 
13
13
 
14
14
  class Episode:
15
+ """Represents a single episode within a run.
16
+
17
+ An Episode is a context manager that tracks a single execution instance of a
18
+ scenario. It provides access to scenario configuration values, supports metric
19
+ logging, and manages episode lifecycle with various completion statuses.
20
+
21
+ Episodes can be finished with different statuses:
22
+ - SUCCESS: Episode completed successfully
23
+ - FAILED: Episode failed
24
+ - CANCELED: Episode was discarded/canceled
25
+ - ERRORED: Episode encountered an error
26
+
27
+ Use as a context manager to automatically handle episode lifecycle:
28
+ with episode:
29
+ # Your code here
30
+ pass
31
+
32
+ Attributes:
33
+ run_id (str): The unique identifier of the parent run.
34
+ episode_id (str): The unique identifier for this episode.
35
+ scenario (DictConfig | ListConfig): The resolved scenario configuration.
36
+ status (EpisodeStatus): The current status of the episode.
37
+ episode_vals (dict): The sampled values from scenario distributions.
38
+ is_finished (bool): Whether the episode has been finalized.
39
+ """
15
40
  def __init__(self,
16
41
  run_id: str,
17
42
  episode_id: str,
@@ -35,32 +60,64 @@ class Episode:
35
60
 
36
61
  @property
37
62
  def run_id(self) -> str:
63
+ """The unique identifier of the parent run.
64
+
65
+ Returns:
66
+ str: The run ID.
67
+ """
38
68
  return self._run_id
39
-
69
+
40
70
  @property
41
71
  def episode_id(self) -> str:
72
+ """The unique identifier for this episode.
73
+
74
+ Returns:
75
+ str: The episode ID.
76
+ """
42
77
  return self._episode_id
43
78
 
44
79
  @property
45
80
  def scenario(self) -> DictConfig | ListConfig:
81
+ """The resolved scenario configuration for this episode.
82
+
83
+ Returns:
84
+ DictConfig | ListConfig: The scenario configuration.
85
+ """
46
86
  return self._scenario_conf
47
-
87
+
48
88
  @property
49
89
  def status(self) -> EpisodeStatus:
90
+ """The current status of the episode.
91
+
92
+ Returns:
93
+ EpisodeStatus: The episode status.
94
+ """
50
95
  return self._episode_status
51
-
96
+
52
97
  @property
53
98
  def episode_vals(self) -> dict:
99
+ """The sampled values from scenario distributions.
100
+
101
+ Returns:
102
+ dict: Dictionary mapping scenario variable names to their sampled values.
103
+ """
54
104
  return self._episode_vals
55
-
105
+
56
106
  @property
57
107
  def is_finished(self) -> bool:
108
+ """Whether the episode has been finalized.
109
+
110
+ Returns:
111
+ bool: True if the episode is finished, False otherwise.
112
+ """
58
113
  return self._is_finished
59
114
 
60
115
  def __enter__(self):
116
+ """Enter the episode context."""
61
117
  return self
62
118
 
63
119
  def __exit__(self, exception_type, exception_value, exception_traceback):
120
+ """Exit the episode context and finalize the episode."""
64
121
  if self._is_finished:
65
122
  return
66
123
  if exception_type is not None:
@@ -70,16 +127,51 @@ class Episode:
70
127
  self.finish(status=EpisodeStatus.SUCCESS)
71
128
 
72
129
  def __getattr__(self, name: Any) -> Any:
130
+ """Access scenario configuration values as attributes.
131
+
132
+ Allows accessing scenario configuration using dot notation (e.g., episode.my_param).
133
+
134
+ Args:
135
+ name (Any): The attribute/key name from scenario configuration.
136
+
137
+ Returns:
138
+ Any: The value from scenario configuration.
139
+
140
+ Raises:
141
+ AttributeError: If the attribute is not in scenario configuration.
142
+ """
73
143
  if name in self._scenario_conf:
74
144
  return self._scenario_conf[name]
75
145
  raise AttributeError(f"'Scenario' object has no attribute '{name}'")
76
146
 
77
147
  def __getitem__(self, key: Any) -> Any:
148
+ """Access scenario configuration values using subscript notation.
149
+
150
+ Allows accessing scenario configuration using bracket notation (e.g., episode['my_param']).
151
+
152
+ Args:
153
+ key (Any): The key name from scenario configuration.
154
+
155
+ Returns:
156
+ Any: The value from scenario configuration.
157
+
158
+ Raises:
159
+ KeyError: If the key is not in scenario configuration.
160
+ """
78
161
  if key in self._scenario_conf:
79
162
  return self._scenario_conf[key]
80
163
  raise KeyError(f"'Scenario' object has no key '{key}'")
81
164
 
82
165
  def add_metric(self, name: str, metric: Metrics) -> None:
166
+ """Add a metric to track for this episode.
167
+
168
+ Args:
169
+ name (str): The name of the metric.
170
+ metric (Metrics): The metric instance to add.
171
+
172
+ Raises:
173
+ ValueError: If the name is already used.
174
+ """
83
175
  if name in self._logs:
84
176
  raise ValueError(f"{name} is a reserved name and is not allowed.")
85
177
  self._logs[name] = metric
@@ -101,6 +193,16 @@ class Episode:
101
193
  )
102
194
 
103
195
  def log(self, data: dict, x: dict | None = None, replace: bool = False) -> None:
196
+ """Log data points or values for the episode.
197
+
198
+ Args:
199
+ data (dict): Dictionary of key-value pairs to log.
200
+ x (dict | None): Optional dictionary of x-axis values for each key.
201
+ replace (bool): Whether to replace existing values. Defaults to False.
202
+
203
+ Raises:
204
+ ValueError: If a key is reserved or logging fails.
205
+ """
104
206
  for key, value in data.items():
105
207
  if key in RESERVED_NAMES:
106
208
  raise ValueError(f"{key} is a reserved name and is not allowed.")
@@ -127,15 +229,24 @@ class Episode:
127
229
  return OmegaConf.to_yaml(self._scenario_conf)
128
230
 
129
231
  def discard(self) -> None:
232
+ """Mark the episode as discarded/canceled."""
130
233
  self._finish(EpisodeStatus.CANCELED)
131
234
 
132
235
  def success(self) -> None:
236
+ """Mark the episode as successfully completed."""
133
237
  self._finish(EpisodeStatus.SUCCESS)
134
-
238
+
135
239
  def fail(self) -> None:
240
+ """Mark the episode as failed."""
136
241
  self._finish(EpisodeStatus.FAILED)
137
242
 
138
243
  def finish(self, status: EpisodeStatus, err_msg: str | None = None) -> None:
244
+ """Finish the episode with a specific status.
245
+
246
+ Args:
247
+ status (EpisodeStatus): The final status of the episode.
248
+ err_msg (str | None): Optional error message if the episode errored.
249
+ """
139
250
  if self._is_finished:
140
251
  return
141
252
  self._is_finished = True
@@ -155,7 +266,7 @@ class Episode:
155
266
  pickled = pickle.dumps(metric_val["value"])
156
267
  self._api_client.upload_python(
157
268
  artifact_key=key,
158
- run_id=self._id,
269
+ run_id=self._run_id,
159
270
  episode_id=self._episode_id,
160
271
  pickled_bytes=pickled
161
272
  )
@@ -164,11 +275,10 @@ class Episode:
164
275
  pickled = pickle.dumps(metric_val)
165
276
  self._api_client.upload_metrics(
166
277
  artifact_key=key,
167
- run_id=self._id,
278
+ run_id=self._run_id,
168
279
  episode_id=self._episode_id,
169
280
  pickled_bytes=pickled,
170
281
  graph_type=value.graph_type.value,
171
- metric_dim_type=value.metric_dim_type.value
172
282
  )
173
283
  elif isinstance(value, Code):
174
284
  self._api_client.upload_code(
humalab/humalab.py CHANGED
@@ -20,8 +20,20 @@ _cur_run: Run | None = None
20
20
 
21
21
  def _pull_scenario(client: HumaLabApiClient,
22
22
  project: str,
23
- scenario: str | list | dict | None = None,
24
- scenario_id: str | None = None,) -> str | list | dict | None:
23
+ seed: int | None = None,
24
+ scenario: str | list | dict | Scenario | None = None,
25
+ scenario_id: str | None = None,) -> Scenario:
26
+ """Pull a scenario from the server if scenario_id is provided.
27
+
28
+ Args:
29
+ client (HumaLabApiClient): API client instance.
30
+ project (str): Project name.
31
+ scenario (str | list | dict | Scenario | None): Local scenario configuration or an existing Scenario instance.
32
+ scenario_id (str | None): ID of scenario to pull from server.
33
+
34
+ Returns:
35
+ Scenario: The scenario instance (a passed-in Scenario is returned unchanged; otherwise one is built from the pulled or local configuration).
36
+ """
25
37
  if scenario_id is not None:
26
38
  scenario_arr = scenario_id.split(":")
27
39
  if len(scenario_arr) < 1:
@@ -31,9 +43,22 @@ def _pull_scenario(client: HumaLabApiClient,
31
43
 
32
44
  scenario_response = client.get_scenario(
33
45
  project_name=project,
34
- uuid=scenario_real_id, version=scenario_version)
35
- return scenario_response["yaml_content"]
36
- return scenario
46
+ uuid=scenario_real_id,
47
+ version=scenario_version)
48
+ final_scenario = scenario_response["yaml_content"]
49
+ else:
50
+ final_scenario = scenario
51
+
52
+ if isinstance(final_scenario, Scenario):
53
+ scenario_inst = final_scenario
54
+ else:
55
+ scenario_inst = Scenario()
56
+ scenario_inst.init(scenario=final_scenario,
57
+ seed=seed,
58
+ scenario_id=scenario_id,
59
+ #num_env=num_env,
60
+ )
61
+ return scenario_inst
37
62
 
38
63
  @contextmanager
39
64
  def init(project: str | None = None,
@@ -41,7 +66,7 @@ def init(project: str | None = None,
41
66
  description: str | None = None,
42
67
  id: str | None = None,
43
68
  tags: list[str] | None = None,
44
- scenario: str | list | dict | None = None,
69
+ scenario: str | list | dict | Scenario | None = None,
45
70
  scenario_id: str | None = None,
46
71
  seed: int | None=None,
47
72
  auto_create_scenario: bool = False,
@@ -80,19 +105,14 @@ def init(project: str | None = None,
80
105
  api_client = HumaLabApiClient(base_url=base_url,
81
106
  api_key=api_key,
82
107
  timeout=timeout)
83
- final_scenario = _pull_scenario(client=api_client,
108
+ scenario_inst = _pull_scenario(client=api_client,
84
109
  project=project,
110
+ seed=seed,
85
111
  scenario=scenario,
86
112
  scenario_id=scenario_id)
87
113
 
88
114
  project_resp = api_client.create_project(name=project)
89
-
90
- scenario_inst = Scenario()
91
- scenario_inst.init(scenario=final_scenario,
92
- seed=seed,
93
- scenario_id=scenario_id,
94
- #num_env=num_env,
95
- )
115
+
96
116
  if scenario_id is None and scenario is not None and auto_create_scenario:
97
117
  scenario_response = api_client.create_scenario(
98
118
  project_name=project_resp['name'],
@@ -159,10 +179,17 @@ def init(project: str | None = None,
159
179
  finish(status=RunStatus.FINISHED)
160
180
 
161
181
  def discard() -> None:
182
+ """Discard the current run by finishing it with CANCELED status."""
162
183
  finish(status=RunStatus.CANCELED)
163
184
 
164
185
  def finish(status: RunStatus = RunStatus.FINISHED,
165
186
  err_msg: str | None = None) -> None:
187
+ """Finish the current run.
188
+
189
+ Args:
190
+ status (RunStatus): The final status of the run. Defaults to FINISHED.
191
+ err_msg (str | None): Optional error message if the run errored.
192
+ """
166
193
  global _cur_run
167
194
  if _cur_run:
168
195
  _cur_run.finish(status=status, err_msg=err_msg)
@@ -173,6 +200,18 @@ def login(api_key: str | None = None,
173
200
  host: str | None = None,
174
201
  force: bool | None = None,
175
202
  timeout: float | None = None) -> bool:
203
+ """Configure HumaLab authentication and connection settings.
204
+
205
+ Args:
206
+ api_key (str | None): API key for authentication.
207
+ relogin (bool | None): Unused parameter (for compatibility).
208
+ host (str | None): API host URL.
209
+ force (bool | None): Unused parameter (for compatibility).
210
+ timeout (float | None): Request timeout in seconds.
211
+
212
+ Returns:
213
+ bool: Always returns True.
214
+ """
176
215
  humalab_config = HumalabConfig()
177
216
  humalab_config.api_key = api_key or humalab_config.api_key
178
217
  humalab_config.base_url = host or humalab_config.base_url
@@ -38,13 +38,13 @@ class HumaLabApiClient:
38
38
  Initialize the HumaLab API client.
39
39
 
40
40
  Args:
41
- base_url: Base URL for the HumaLab service (defaults to localhost:8000)
41
+ base_url: Base URL for the HumaLab service (defaults to the saved config value, then the HUMALAB_SERVICE_URL env var, then https://api.humalab.ai)
42
42
  api_key: API key for authentication (defaults to the saved config value, then the HUMALAB_API_KEY env var)
43
43
  timeout: Request timeout in seconds
44
44
  """
45
45
  humalab_config = HumalabConfig()
46
- self.base_url = base_url or os.getenv("HUMALAB_SERVICE_URL", "http://localhost:8000") or humalab_config.base_url
47
- self.api_key = api_key or os.getenv("HUMALAB_API_KEY") or humalab_config.api_key
46
+ self.base_url = base_url or humalab_config.base_url or os.getenv("HUMALAB_SERVICE_URL", "https://api.humalab.ai")
47
+ self.api_key = api_key or humalab_config.api_key or os.getenv("HUMALAB_API_KEY")
48
48
  self.timeout = timeout or humalab_config.timeout or 30.0 # Default timeout of 30 seconds
49
49
 
50
50
  # Ensure base_url ends without trailing slash
@@ -885,7 +885,6 @@ class HumaLabApiClient:
885
885
  run_id: str,
886
886
  pickled_bytes: bytes,
887
887
  graph_type: str,
888
- metric_dim_type: str
889
888
  ) -> Dict[str, Any]:
890
889
  """
891
890
  Upload scenario stats artifact (pickled Python dict data).
@@ -905,7 +904,6 @@ class HumaLabApiClient:
905
904
  data = {
906
905
  'artifact_key': artifact_key,
907
906
  'run_id': run_id,
908
- 'metric_dim_type': metric_dim_type,
909
907
  'graph_type': graph_type
910
908
  }
911
909
 
@@ -940,7 +938,6 @@ class HumaLabApiClient:
940
938
  artifact_key: str,
941
939
  pickled_bytes: bytes,
942
940
  graph_type: str,
943
- metric_dim_type: str,
944
941
  episode_id: str | None = None,
945
942
  ) -> Dict[str, Any]:
946
943
  """
@@ -951,7 +948,6 @@ class HumaLabApiClient:
951
948
  artifact_key: Artifact key
952
949
  pickled_bytes: Pickled metrics data as bytes
953
950
  graph_type: Optional new graph type
954
- metric_dim_type: Optional new metric dimension type
955
951
  episode_id: Optional new episode ID
956
952
 
957
953
  Returns:
@@ -960,7 +956,6 @@ class HumaLabApiClient:
960
956
  data = {
961
957
  "run_id": run_id,
962
958
  "artifact_key": artifact_key,
963
- 'metric_dim_type': metric_dim_type,
964
959
  'graph_type': graph_type
965
960
  }
966
961
  files = {'file': pickled_bytes}
humalab/humalab_config.py CHANGED
@@ -3,6 +3,18 @@ import yaml
3
3
  import os
4
4
 
5
5
  class HumalabConfig:
6
+ """Manages HumaLab SDK configuration settings.
7
+
8
+ Configuration is stored in ~/.humalab/config.yaml and includes workspace path,
9
+ API credentials, and connection settings. Values are automatically loaded on
10
+ initialization and saved when modified through property setters.
11
+
12
+ Attributes:
13
+ workspace_path (str): The local workspace directory path.
14
+ base_url (str): The HumaLab API base URL.
15
+ api_key (str): The API key for authentication.
16
+ timeout (float): Request timeout in seconds.
17
+ """
6
18
  def __init__(self):
7
19
  self._config = {
8
20
  "workspace_path": "",
@@ -17,6 +29,7 @@ class HumalabConfig:
17
29
  self._load_config()
18
30
 
19
31
  def _load_config(self):
32
+ """Load configuration from ~/.humalab/config.yaml."""
20
33
  home_path = Path.home()
21
34
  config_path = home_path / ".humalab" / "config.yaml"
22
35
  if not config_path.exists():
@@ -30,10 +43,16 @@ class HumalabConfig:
30
43
  self._timeout = self._config["timeout"] if self._config and "timeout" in self._config else 30.0
31
44
 
32
45
  def _save(self) -> None:
46
+ """Save current configuration to ~/.humalab/config.yaml."""
33
47
  yaml.dump(self._config, open(Path.home() / ".humalab" / "config.yaml", "w"))
34
48
 
35
49
  @property
36
50
  def workspace_path(self) -> str:
51
+ """The local workspace directory path.
52
+
53
+ Returns:
54
+ str: The workspace path.
55
+ """
37
56
  return str(self._workspace_path)
38
57
 
39
58
  @workspace_path.setter
@@ -44,30 +63,60 @@ class HumalabConfig:
44
63
 
45
64
  @property
46
65
  def base_url(self) -> str:
66
+ """The HumaLab API base URL.
67
+
68
+ Returns:
69
+ str: The base URL.
70
+ """
47
71
  return str(self._base_url)
48
72
 
49
73
  @base_url.setter
50
74
  def base_url(self, base_url: str) -> None:
75
+ """Set the HumaLab API base URL and save to config.
76
+
77
+ Args:
78
+ base_url (str): The new base URL.
79
+ """
51
80
  self._base_url = base_url
52
81
  self._config["base_url"] = base_url
53
82
  self._save()
54
83
 
55
84
  @property
56
85
  def api_key(self) -> str:
86
+ """The API key for authentication.
87
+
88
+ Returns:
89
+ str: The API key.
90
+ """
57
91
  return str(self._api_key)
58
92
 
59
93
  @api_key.setter
60
94
  def api_key(self, api_key: str) -> None:
95
+ """Set the API key and save to config.
96
+
97
+ Args:
98
+ api_key (str): The new API key.
99
+ """
61
100
  self._api_key = api_key
62
101
  self._config["api_key"] = api_key
63
102
  self._save()
64
103
 
65
104
  @property
66
105
  def timeout(self) -> float:
106
+ """Request timeout in seconds.
107
+
108
+ Returns:
109
+ float: The timeout value.
110
+ """
67
111
  return self._timeout
68
112
 
69
113
  @timeout.setter
70
114
  def timeout(self, timeout: float) -> None:
115
+ """Set the request timeout and save to config.
116
+
117
+ Args:
118
+ timeout (float): The new timeout in seconds.
119
+ """
71
120
  self._timeout = timeout
72
121
  self._config["timeout"] = timeout
73
122
  self._save()
@@ -1,3 +1,9 @@
1
+ """Metrics tracking and management.
2
+
3
+ This module provides classes for tracking various types of metrics during runs and episodes,
4
+ including general metrics, summary statistics, code artifacts, and scenario statistics.
5
+ """
6
+
1
7
  from .metric import Metrics
2
8
  from .code import Code
3
9
  from .scenario_stats import ScenarioStats
humalab/metrics/code.py CHANGED
@@ -1,7 +1,18 @@
1
1
  class Code:
2
- """Class for logging code artifacts."""
3
- def __init__(self,
4
- run_id: str,
2
+ """Class for logging code artifacts.
3
+
4
+ Code artifacts capture source code or configuration files associated with
5
+ runs or episodes. They are stored as text content and can be retrieved later
6
+ for reproducibility and debugging purposes.
7
+
8
+ Attributes:
9
+ run_id (str): The unique identifier of the associated run.
10
+ key (str): The artifact key/name for this code.
11
+ code_content (str): The actual code or text content.
12
+ episode_id (str | None): Optional episode identifier if scoped to an episode.
13
+ """
14
+ def __init__(self,
15
+ run_id: str,
5
16
  key: str,
6
17
  code_content: str,
7
18
  episode_id: str | None = None) -> None:
@@ -13,16 +24,36 @@ class Code:
13
24
 
14
25
  @property
15
26
  def run_id(self) -> str:
27
+ """The unique identifier of the associated run.
28
+
29
+ Returns:
30
+ str: The run ID.
31
+ """
16
32
  return self._run_id
17
-
33
+
18
34
  @property
19
35
  def key(self) -> str:
36
+ """The artifact key/name for this code.
37
+
38
+ Returns:
39
+ str: The artifact key.
40
+ """
20
41
  return self._key
21
-
42
+
22
43
  @property
23
44
  def code_content(self) -> str:
45
+ """The actual code or text content.
46
+
47
+ Returns:
48
+ str: The code content.
49
+ """
24
50
  return self._code_content
25
-
51
+
26
52
  @property
27
53
  def episode_id(self) -> str | None:
54
+ """Optional episode identifier if scoped to an episode.
55
+
56
+ Returns:
57
+ str | None: The episode ID, or None if run-scoped.
58
+ """
28
59
  return self._episode_id
humalab/metrics/metric.py CHANGED
@@ -1,36 +1,70 @@
1
1
  from typing import Any
2
2
  from humalab.constants import MetricDimType, GraphType
3
3
 
4
+ GRAPH_TO_DIM_TYPE = {
5
+ GraphType.NUMERIC: MetricDimType.ZERO_D,
6
+ GraphType.LINE: MetricDimType.ONE_D,
7
+ GraphType.HISTOGRAM: MetricDimType.ONE_D,
8
+ GraphType.BAR: MetricDimType.ONE_D,
9
+ GraphType.GAUSSIAN: MetricDimType.ONE_D,
10
+ GraphType.SCATTER: MetricDimType.TWO_D,
11
+ GraphType.HEATMAP: MetricDimType.TWO_D,
12
+ GraphType.THREE_D_MAP: MetricDimType.THREE_D,
13
+ }
14
+
4
15
 
5
16
  class Metrics:
17
+ """Base class for tracking and logging metrics during runs and episodes.
18
+
19
+ Metrics provide a flexible way to log time-series data or aggregated values
20
+ during experiments. Data points are collected with optional x-axis values
21
+ and can be visualized using different graph types.
22
+
23
+ Subclasses should override _finalize() to implement custom processing logic.
24
+
25
+ Attributes:
26
+ graph_type (GraphType): The type of graph used for visualization.
27
+ """
6
28
  def __init__(self,
7
- metric_dim_type: MetricDimType= MetricDimType.ONE_D,
8
29
  graph_type: GraphType=GraphType.LINE) -> None:
9
- """
10
- Base class for different types of metrics.
30
+ """Initialize a new Metrics instance.
31
+
32
+ Args:
33
+ graph_type (GraphType): The type of graph to use for visualization
34
+ (e.g., LINE, BAR, HISTOGRAM, SCATTER). Defaults to LINE.
11
35
  """
12
36
  self._values = []
13
37
  self._x_values = []
14
38
  self._step = -1
15
- self._metric_dim_type = metric_dim_type
39
+ self._metric_dim_type = GRAPH_TO_DIM_TYPE.get(graph_type, MetricDimType.ONE_D)
16
40
  self._graph_type = graph_type
17
41
 
18
42
  @property
19
43
  def metric_dim_type(self) -> MetricDimType:
44
+ """The dimensionality of the metric data.
45
+
46
+ Returns:
47
+ MetricDimType: The metric dimension type.
48
+ """
20
49
  return self._metric_dim_type
21
-
50
+
22
51
  @property
23
52
  def graph_type(self) -> GraphType:
53
+ """The type of graph used for visualization.
54
+
55
+ Returns:
56
+ GraphType: The graph type.
57
+ """
24
58
  return self._graph_type
25
59
 
26
60
  def log(self, data: Any, x: Any = None, replace: bool = False) -> None:
27
- """Log a new data point for the metric. The behavior depends on the granularity.
61
+ """Log a new data point for the metric.
28
62
 
29
63
  Args:
30
64
  data (Any): The data point to log.
31
65
  x (Any | None): The x-axis value associated with the data point.
32
- if None, the current step is used.
33
- replace (bool): Whether to replace the last logged value.
66
+ If None, uses an auto-incrementing step counter.
67
+ replace (bool): Whether to replace the last logged value. Defaults to False.
34
68
  """
35
69
  if replace:
36
70
  self._values[-1] = data