digitalhub 0.10.0b5__py3-none-any.whl → 0.10.0b6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of digitalhub might be problematic. Click here for more details.

Files changed (49) hide show
  1. digitalhub/client/_base/api_builder.py +1 -1
  2. digitalhub/client/_base/client.py +22 -0
  3. digitalhub/client/_base/params_builder.py +16 -0
  4. digitalhub/client/dhcore/api_builder.py +4 -3
  5. digitalhub/client/dhcore/client.py +4 -0
  6. digitalhub/client/dhcore/configurator.py +22 -0
  7. digitalhub/client/dhcore/params_builder.py +174 -0
  8. digitalhub/client/local/api_builder.py +4 -1
  9. digitalhub/client/local/client.py +6 -0
  10. digitalhub/client/local/params_builder.py +116 -0
  11. digitalhub/entities/_base/context/entity.py +4 -4
  12. digitalhub/entities/_base/executable/entity.py +2 -2
  13. digitalhub/entities/_base/material/entity.py +3 -3
  14. digitalhub/entities/_base/unversioned/entity.py +2 -2
  15. digitalhub/entities/_base/versioned/entity.py +2 -2
  16. digitalhub/entities/_commons/enums.py +1 -0
  17. digitalhub/entities/_commons/metrics.py +164 -0
  18. digitalhub/entities/_commons/utils.py +0 -26
  19. digitalhub/entities/_processors/base.py +527 -0
  20. digitalhub/entities/{_operations/processor.py → _processors/context.py} +85 -739
  21. digitalhub/entities/_processors/utils.py +158 -0
  22. digitalhub/entities/artifact/crud.py +10 -10
  23. digitalhub/entities/dataitem/crud.py +10 -10
  24. digitalhub/entities/function/crud.py +9 -9
  25. digitalhub/entities/model/_base/entity.py +26 -78
  26. digitalhub/entities/model/_base/status.py +1 -1
  27. digitalhub/entities/model/crud.py +10 -10
  28. digitalhub/entities/project/_base/entity.py +317 -9
  29. digitalhub/entities/project/crud.py +10 -9
  30. digitalhub/entities/run/_base/entity.py +32 -84
  31. digitalhub/entities/run/_base/status.py +1 -1
  32. digitalhub/entities/run/crud.py +8 -8
  33. digitalhub/entities/secret/_base/entity.py +3 -3
  34. digitalhub/entities/secret/crud.py +9 -9
  35. digitalhub/entities/task/_base/entity.py +4 -4
  36. digitalhub/entities/task/_base/models.py +10 -0
  37. digitalhub/entities/task/crud.py +8 -8
  38. digitalhub/entities/workflow/crud.py +9 -9
  39. digitalhub/stores/s3/enums.py +7 -7
  40. digitalhub/stores/sql/enums.py +6 -6
  41. digitalhub/utils/git_utils.py +16 -9
  42. {digitalhub-0.10.0b5.dist-info → digitalhub-0.10.0b6.dist-info}/METADATA +1 -4
  43. {digitalhub-0.10.0b5.dist-info → digitalhub-0.10.0b6.dist-info}/RECORD +46 -43
  44. digitalhub/entities/_base/project/entity.py +0 -341
  45. digitalhub/entities/_commons/models.py +0 -13
  46. digitalhub/entities/_operations/__init__.py +0 -0
  47. /digitalhub/entities/{_base/project → _processors}/__init__.py +0 -0
  48. {digitalhub-0.10.0b5.dist-info → digitalhub-0.10.0b6.dist-info}/WHEEL +0 -0
  49. {digitalhub-0.10.0b5.dist-info → digitalhub-0.10.0b6.dist-info}/licenses/LICENSE.txt +0 -0
@@ -1,11 +1,15 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import typing
4
+ from pathlib import Path
4
5
  from typing import Any
5
6
 
6
- from digitalhub.entities._base.project.entity import ProjectEntity
7
+ from digitalhub.client.api import get_client
8
+ from digitalhub.context.api import build_context
9
+ from digitalhub.entities._base.entity.entity import Entity
7
10
  from digitalhub.entities._commons.enums import EntityTypes
8
- from digitalhub.entities._operations.processor import processor
11
+ from digitalhub.entities._processors.base import base_processor
12
+ from digitalhub.entities._processors.context import context_processor
9
13
  from digitalhub.entities.artifact.crud import (
10
14
  delete_artifact,
11
15
  get_artifact,
@@ -64,9 +68,14 @@ from digitalhub.entities.workflow.crud import (
64
68
  new_workflow,
65
69
  update_workflow,
66
70
  )
67
- from digitalhub.utils.exceptions import EntityError
71
+ from digitalhub.factory.api import build_entity_from_dict
72
+ from digitalhub.utils.exceptions import BackendError, EntityAlreadyExistsError, EntityError
73
+ from digitalhub.utils.generic_utils import get_timestamp
74
+ from digitalhub.utils.io_utils import write_yaml
75
+ from digitalhub.utils.uri_utils import has_local_scheme
68
76
 
69
77
  if typing.TYPE_CHECKING:
78
+ from digitalhub.entities._base.context.entity import ContextEntity
70
79
  from digitalhub.entities._base.entity.metadata import Metadata
71
80
  from digitalhub.entities.artifact._base.entity import Artifact
72
81
  from digitalhub.entities.dataitem._base.entity import Dataitem
@@ -79,7 +88,7 @@ if typing.TYPE_CHECKING:
79
88
  from digitalhub.entities.workflow._base.entity import Workflow
80
89
 
81
90
 
82
- class Project(ProjectEntity):
91
+ class Project(Entity):
83
92
  """
84
93
  A class representing a project.
85
94
  """
@@ -96,9 +105,308 @@ class Project(ProjectEntity):
96
105
  user: str | None = None,
97
106
  local: bool = False,
98
107
  ) -> None:
99
- super().__init__(name, kind, metadata, spec, status, user, local)
100
- self.spec: ProjectSpec
101
- self.status: ProjectStatus
108
+ super().__init__(kind, metadata, spec, status, user)
109
+ self.id = name
110
+ self.name = name
111
+ self.key = base_processor.build_project_key(self.name, local=local)
112
+
113
+ self._obj_attr.extend(["id", "name"])
114
+
115
+ # Set client
116
+ self._client = get_client(local)
117
+
118
+ # Set context
119
+ build_context(self)
120
+
121
+ ##############################
122
+ # Save / Refresh / Export
123
+ ##############################
124
+
125
+ def save(self, update: bool = False) -> Project:
126
+ """
127
+ Save entity into backend.
128
+
129
+ Parameters
130
+ ----------
131
+ update : bool
132
+ If True, the object will be updated.
133
+
134
+ Returns
135
+ -------
136
+ Project
137
+ Entity saved.
138
+ """
139
+ if update:
140
+ if self._client.is_local():
141
+ self.metadata.updated = get_timestamp()
142
+ new_obj = base_processor.update_project_entity(
143
+ entity_type=self.ENTITY_TYPE,
144
+ entity_name=self.name,
145
+ entity_dict=self.to_dict(),
146
+ local=self._client.is_local(),
147
+ )
148
+ else:
149
+ new_obj = base_processor.create_project_entity(_entity=self)
150
+ self._update_attributes(new_obj)
151
+ return self
152
+
153
+ def refresh(self) -> Project:
154
+ """
155
+ Refresh object from backend.
156
+
157
+ Returns
158
+ -------
159
+ Project
160
+ Project object.
161
+ """
162
+ new_obj = base_processor.read_project_entity(
163
+ entity_type=self.ENTITY_TYPE,
164
+ entity_name=self.name,
165
+ local=self._client.is_local(),
166
+ )
167
+ self._update_attributes(new_obj)
168
+ return self
169
+
170
+ def search_entity(
171
+ self,
172
+ query: str | None = None,
173
+ entity_types: list[str] | None = None,
174
+ name: str | None = None,
175
+ kind: str | None = None,
176
+ created: str | None = None,
177
+ updated: str | None = None,
178
+ description: str | None = None,
179
+ labels: list[str] | None = None,
180
+ **kwargs,
181
+ ) -> list[ContextEntity]:
182
+ """
183
+ Search objects from backend.
184
+
185
+ Parameters
186
+ ----------
187
+ query : str
188
+ Search query.
189
+ entity_types : list[str]
190
+ Entity types.
191
+ name : str
192
+ Entity name.
193
+ kind : str
194
+ Entity kind.
195
+ created : str
196
+ Entity creation date.
197
+ updated : str
198
+ Entity update date.
199
+ description : str
200
+ Entity description.
201
+ labels : list[str]
202
+ Entity labels.
203
+ **kwargs : dict
204
+ Parameters to pass to the API call.
205
+
206
+ Returns
207
+ -------
208
+ list[ContextEntity]
209
+ List of object instances.
210
+ """
211
+ objs = context_processor.search_entity(
212
+ self.name,
213
+ query=query,
214
+ entity_types=entity_types,
215
+ name=name,
216
+ kind=kind,
217
+ created=created,
218
+ updated=updated,
219
+ description=description,
220
+ labels=labels,
221
+ **kwargs,
222
+ )
223
+ self.refresh()
224
+ return objs
225
+
226
+ def export(self) -> str:
227
+ """
228
+ Export object as a YAML file in the context folder.
229
+ If the objects are not embedded, the objects are exported as a YAML file.
230
+
231
+ Returns
232
+ -------
233
+ str
234
+ Exported filepath.
235
+ """
236
+ obj = self._refresh_to_dict()
237
+ pth = Path(self.spec.context) / f"{self.ENTITY_TYPE}s-{self.name}.yaml"
238
+ obj = self._export_not_embedded(obj)
239
+ write_yaml(pth, obj)
240
+ return str(pth)
241
+
242
+ def _refresh_to_dict(self) -> dict:
243
+ """
244
+ Try to refresh object to collect entities related to project.
245
+
246
+ Returns
247
+ -------
248
+ dict
249
+ Entity object in dictionary format.
250
+ """
251
+ try:
252
+ return self.refresh().to_dict()
253
+ except BackendError:
254
+ return self.to_dict()
255
+
256
+ def _export_not_embedded(self, obj: dict) -> dict:
257
+ """
258
+ Export project objects if not embedded.
259
+
260
+ Parameters
261
+ ----------
262
+ obj : dict
263
+ Project object in dictionary format.
264
+
265
+ Returns
266
+ -------
267
+ dict
268
+ Updatated project object in dictionary format with referenced entities.
269
+ """
270
+ # Cycle over entity types
271
+ for entity_type in self._get_entity_types():
272
+ # Entity types are stored as a list of entities
273
+ for idx, entity in enumerate(obj.get("spec", {}).get(entity_type, [])):
274
+ # Export entity if not embedded is in metadata, else do nothing
275
+ if not self._is_embedded(entity):
276
+ # Get entity object from backend
277
+ ent = context_processor.read_context_entity(entity["key"])
278
+
279
+ # Export and store ref in object metadata inside project
280
+ pth = ent.export()
281
+ obj["spec"][entity_type][idx]["metadata"]["ref"] = pth
282
+
283
+ # Return updated object
284
+ return obj
285
+
286
+ def _import_entities(self, obj: dict) -> None:
287
+ """
288
+ Import project entities.
289
+
290
+ Parameters
291
+ ----------
292
+ obj : dict
293
+ Project object in dictionary format.
294
+
295
+ Returns
296
+ -------
297
+ None
298
+ """
299
+ entity_types = self._get_entity_types()
300
+
301
+ # Cycle over entity types
302
+ for entity_type in entity_types:
303
+ # Entity types are stored as a list of entities
304
+ for entity in obj.get("spec", {}).get(entity_type, []):
305
+ embedded = self._is_embedded(entity)
306
+ ref = entity["metadata"].get("ref")
307
+
308
+ # Import entity if not embedded and there is a ref
309
+ if not embedded and ref is not None:
310
+ # Import entity from local ref
311
+ if has_local_scheme(ref):
312
+ try:
313
+ # Artifacts, Dataitems and Models
314
+ if entity_type in entity_types[:3]:
315
+ context_processor.import_context_entity(ref)
316
+
317
+ # Functions and Workflows
318
+ elif entity_type in entity_types[3:]:
319
+ context_processor.import_executable_entity(ref)
320
+
321
+ except FileNotFoundError:
322
+ msg = f"File not found: {ref}."
323
+ raise EntityError(msg)
324
+
325
+ # If entity is embedded, create it and try to save
326
+ elif embedded:
327
+ # It's possible that embedded field in metadata is not shown
328
+ if entity["metadata"].get("embedded") is None:
329
+ entity["metadata"]["embedded"] = True
330
+
331
+ try:
332
+ build_entity_from_dict(entity).save()
333
+ except EntityAlreadyExistsError:
334
+ pass
335
+
336
+ def _load_entities(self, obj: dict) -> None:
337
+ """
338
+ Load project entities.
339
+
340
+ Parameters
341
+ ----------
342
+ obj : dict
343
+ Project object in dictionary format.
344
+
345
+ Returns
346
+ -------
347
+ None
348
+ """
349
+ entity_types = self._get_entity_types()
350
+
351
+ # Cycle over entity types
352
+ for entity_type in entity_types:
353
+ # Entity types are stored as a list of entities
354
+ for entity in obj.get("spec", {}).get(entity_type, []):
355
+ embedded = self._is_embedded(entity)
356
+ ref = entity["metadata"].get("ref")
357
+
358
+ # Load entity if not embedded and there is a ref
359
+ if not embedded and ref is not None:
360
+ # Load entity from local ref
361
+ if has_local_scheme(ref):
362
+ try:
363
+ # Artifacts, Dataitems and Models
364
+ if entity_type in entity_types[:3]:
365
+ context_processor.load_context_entity(ref)
366
+
367
+ # Functions and Workflows
368
+ elif entity_type in entity_types[3:]:
369
+ context_processor.load_executable_entity(ref)
370
+
371
+ except FileNotFoundError:
372
+ msg = f"File not found: {ref}."
373
+ raise EntityError(msg)
374
+
375
+ def _is_embedded(self, entity: dict) -> bool:
376
+ """
377
+ Check if entity is embedded.
378
+
379
+ Parameters
380
+ ----------
381
+ entity : dict
382
+ Entity in dictionary format.
383
+
384
+ Returns
385
+ -------
386
+ bool
387
+ True if entity is embedded.
388
+ """
389
+ metadata_embedded = entity["metadata"].get("embedded", False)
390
+ no_status = entity.get("status", None) is None
391
+ no_spec = entity.get("spec", None) is None
392
+ return metadata_embedded or not (no_status and no_spec)
393
+
394
+ def _get_entity_types(self) -> list[str]:
395
+ """
396
+ Get entity types.
397
+
398
+ Returns
399
+ -------
400
+ list
401
+ Entity types.
402
+ """
403
+ return [
404
+ f"{EntityTypes.ARTIFACT.value}s",
405
+ f"{EntityTypes.DATAITEM.value}s",
406
+ f"{EntityTypes.MODEL.value}s",
407
+ f"{EntityTypes.FUNCTION.value}s",
408
+ f"{EntityTypes.WORKFLOW.value}s",
409
+ ]
102
410
 
103
411
  ##############################
104
412
  # Artifacts
@@ -1859,7 +2167,7 @@ class Project(ProjectEntity):
1859
2167
  -------
1860
2168
  None
1861
2169
  """
1862
- return processor.share_project_entity(
2170
+ return base_processor.share_project_entity(
1863
2171
  entity_type=self.ENTITY_TYPE,
1864
2172
  entity_name=self.name,
1865
2173
  user=user,
@@ -1879,7 +2187,7 @@ class Project(ProjectEntity):
1879
2187
  -------
1880
2188
  None
1881
2189
  """
1882
- return processor.share_project_entity(
2190
+ return base_processor.share_project_entity(
1883
2191
  entity_type=self.ENTITY_TYPE,
1884
2192
  entity_name=self.name,
1885
2193
  user=user,
@@ -3,7 +3,8 @@ from __future__ import annotations
3
3
  import typing
4
4
 
5
5
  from digitalhub.entities._commons.enums import EntityTypes
6
- from digitalhub.entities._operations.processor import processor
6
+ from digitalhub.entities._processors.base import base_processor
7
+ from digitalhub.entities._processors.context import context_processor
7
8
  from digitalhub.entities.project.utils import setup_project
8
9
  from digitalhub.utils.exceptions import BackendError
9
10
 
@@ -58,7 +59,7 @@ def new_project(
58
59
  """
59
60
  if context is None:
60
61
  context = "./"
61
- obj = processor.create_project_entity(
62
+ obj = base_processor.create_project_entity(
62
63
  name=name,
63
64
  kind="project",
64
65
  description=description,
@@ -103,7 +104,7 @@ def get_project(
103
104
  --------
104
105
  >>> obj = get_project("my-project")
105
106
  """
106
- obj = processor.read_project_entity(
107
+ obj = base_processor.read_project_entity(
107
108
  entity_type=ENTITY_TYPE,
108
109
  entity_name=name,
109
110
  local=local,
@@ -142,7 +143,7 @@ def import_project(
142
143
  --------
143
144
  >>> obj = import_project("my-project.yaml")
144
145
  """
145
- obj = processor.import_project_entity(file=file, local=local, config=config)
146
+ obj = base_processor.import_project_entity(file=file, local=local, config=config)
146
147
  return setup_project(obj, setup_kwargs)
147
148
 
148
149
 
@@ -175,7 +176,7 @@ def load_project(
175
176
  --------
176
177
  >>> obj = load_project("my-project.yaml")
177
178
  """
178
- obj = processor.load_project_entity(file=file, local=local, config=config)
179
+ obj = base_processor.load_project_entity(file=file, local=local, config=config)
179
180
  return setup_project(obj, setup_kwargs)
180
181
 
181
182
 
@@ -195,7 +196,7 @@ def list_projects(local: bool = False, **kwargs) -> list[Project]:
195
196
  list
196
197
  List of objects.
197
198
  """
198
- return processor.list_project_entities(local=local, **kwargs)
199
+ return base_processor.list_project_entities(local=local, **kwargs)
199
200
 
200
201
 
201
202
  def get_or_create_project(
@@ -268,7 +269,7 @@ def update_project(entity: Project, **kwargs) -> Project:
268
269
  --------
269
270
  >>> obj = update_project(obj)
270
271
  """
271
- return processor.update_project_entity(
272
+ return base_processor.update_project_entity(
272
273
  entity_type=entity.ENTITY_TYPE,
273
274
  entity_name=entity.name,
274
275
  entity_dict=entity.to_dict(),
@@ -309,7 +310,7 @@ def delete_project(
309
310
  --------
310
311
  >>> delete_project("my-project")
311
312
  """
312
- return processor.delete_project_entity(
313
+ return base_processor.delete_project_entity(
313
314
  entity_type=ENTITY_TYPE,
314
315
  entity_name=name,
315
316
  local=local,
@@ -362,7 +363,7 @@ def search_entity(
362
363
  list[ContextEntity]
363
364
  List of object instances.
364
365
  """
365
- return processor.search_entity(
366
+ return context_processor.search_entity(
366
367
  project_name,
367
368
  query=query,
368
369
  entity_types=entity_types,
@@ -5,8 +5,8 @@ import typing
5
5
 
6
6
  from digitalhub.entities._base.unversioned.entity import UnversionedEntity
7
7
  from digitalhub.entities._commons.enums import EntityTypes, State
8
- from digitalhub.entities._commons.utils import validate_metric_value
9
- from digitalhub.entities._operations.processor import processor
8
+ from digitalhub.entities._commons.metrics import MetricType, set_metrics, validate_metric_value
9
+ from digitalhub.entities._processors.context import context_processor
10
10
  from digitalhub.factory.api import (
11
11
  build_runtime,
12
12
  build_spec,
@@ -46,9 +46,6 @@ class Run(UnversionedEntity):
46
46
  self.spec: RunSpec
47
47
  self.status: RunStatus
48
48
 
49
- # Initialize metrics
50
- self._init_metrics()
51
-
52
49
  ##############################
53
50
  # Run Methods
54
51
  ##############################
@@ -141,7 +138,7 @@ class Run(UnversionedEntity):
141
138
  dict
142
139
  Run logs.
143
140
  """
144
- return processor.read_run_logs(self.project, self.ENTITY_TYPE, self.id)
141
+ return context_processor.read_run_logs(self.project, self.ENTITY_TYPE, self.id)
145
142
 
146
143
  def stop(self) -> None:
147
144
  """
@@ -152,7 +149,7 @@ class Run(UnversionedEntity):
152
149
  None
153
150
  """
154
151
  if not self.spec.local_execution:
155
- return processor.stop_run(self.project, self.ENTITY_TYPE, self.id)
152
+ return context_processor.stop_run(self.project, self.ENTITY_TYPE, self.id)
156
153
 
157
154
  def resume(self) -> None:
158
155
  """
@@ -163,12 +160,12 @@ class Run(UnversionedEntity):
163
160
  None
164
161
  """
165
162
  if not self.spec.local_execution:
166
- return processor.resume_run(self.project, self.ENTITY_TYPE, self.id)
163
+ return context_processor.resume_run(self.project, self.ENTITY_TYPE, self.id)
167
164
 
168
165
  def log_metric(
169
166
  self,
170
167
  key: str,
171
- value: list[float | int] | float | int,
168
+ value: MetricType,
172
169
  overwrite: bool = False,
173
170
  single_value: bool = False,
174
171
  ) -> None:
@@ -182,7 +179,7 @@ class Run(UnversionedEntity):
182
179
  ----------
183
180
  key : str
184
181
  Key of the metric.
185
- value : list[float | int] | float | int
182
+ value : MetricType
186
183
  Value of the metric.
187
184
  overwrite : bool
188
185
  If True, overwrite existing metric.
@@ -210,16 +207,8 @@ class Run(UnversionedEntity):
210
207
  Log a list of values and overwrite existing metric:
211
208
  >>> entity.log_metric("accuracy", [0.8, 0.9], overwrite=True)
212
209
  """
213
- value = validate_metric_value(value)
214
-
215
- if isinstance(value, list):
216
- self._handle_metric_list(key, value, overwrite)
217
- elif single_value:
218
- self._handle_metric_single(key, value, overwrite)
219
- else:
220
- self._handle_metric_list_append(key, value, overwrite)
221
-
222
- processor.update_metric(self.project, self.ENTITY_TYPE, self.id, key, self.status.metrics[key])
210
+ self._set_metrics(key, value, overwrite, single_value)
211
+ context_processor.update_metric(self.project, self.ENTITY_TYPE, self.id, key, self.status.metrics[key])
223
212
 
224
213
  ##############################
225
214
  # Helpers
@@ -340,7 +329,7 @@ class Run(UnversionedEntity):
340
329
  exec_type = get_entity_type_from_kind(exec_kind)
341
330
  string_to_split = getattr(self.spec, exec_type)
342
331
  exec_name, exec_id = string_to_split.split("://")[-1].split("/")[-1].split(":")
343
- return processor.read_context_entity(
332
+ return context_processor.read_context_entity(
344
333
  exec_name,
345
334
  entity_type=exec_type,
346
335
  project=self.project,
@@ -358,97 +347,56 @@ class Run(UnversionedEntity):
358
347
  Task from backend.
359
348
  """
360
349
  task_id = self.spec.task.split("://")[-1].split("/")[-1]
361
- return processor.read_unversioned_entity(
350
+ return context_processor.read_unversioned_entity(
362
351
  task_id,
363
352
  entity_type=EntityTypes.TASK.value,
364
353
  project=self.project,
365
354
  ).to_dict()
366
355
 
367
- def _init_metrics(self) -> None:
368
- """
369
- Initialize metrics.
370
-
371
- Returns
372
- -------
373
- None
374
- """
375
- if self.status.metrics is None:
376
- self.status.metrics = {}
377
-
378
356
  def _get_metrics(self) -> None:
379
357
  """
380
- Get run metrics from backend.
358
+ Get model metrics from backend.
381
359
 
382
360
  Returns
383
361
  -------
384
362
  None
385
363
  """
386
- self.status.metrics = processor.read_metrics(
364
+ self.status.metrics = context_processor.read_metrics(
387
365
  project=self.project,
388
366
  entity_type=self.ENTITY_TYPE,
389
367
  entity_id=self.id,
390
368
  )
391
369
 
392
- def _handle_metric_single(self, key: str, value: float | int, overwrite: bool = False) -> None:
393
- """
394
- Handle metric single value.
395
-
396
- Parameters
397
- ----------
398
- key : str
399
- Key of the metric.
400
- value : float
401
- Value of the metric.
402
- overwrite : bool
403
- If True, overwrite existing metric.
404
-
405
- Returns
406
- -------
407
- None
408
- """
409
- if key not in self.status.metrics or overwrite:
410
- self.status.metrics[key] = value
411
-
412
- def _handle_metric_list_append(self, key: str, value: float | int, overwrite: bool = False) -> None:
413
- """
414
- Handle metric list append.
415
-
416
- Parameters
417
- ----------
418
- key : str
419
- Key of the metric.
420
- value : float
421
- Value of the metric.
422
- overwrite : bool
423
- If True, overwrite existing metric.
424
-
425
- Returns
426
- -------
427
- None
428
- """
429
- if key not in self.status.metrics or overwrite:
430
- self.status.metrics[key] = [value]
431
- else:
432
- self.status.metrics[key].append(value)
433
-
434
- def _handle_metric_list(self, key: str, value: list[int | float], overwrite: bool = False) -> None:
370
+ def _set_metrics(
371
+ self,
372
+ key: str,
373
+ value: MetricType,
374
+ overwrite: bool,
375
+ single_value: bool,
376
+ ) -> None:
435
377
  """
436
- Handle metric list.
378
+ Set model metrics.
437
379
 
438
380
  Parameters
439
381
  ----------
440
382
  key : str
441
383
  Key of the metric.
442
- value : list[int | float]
384
+ value : MetricType
443
385
  Value of the metric.
444
386
  overwrite : bool
445
387
  If True, overwrite existing metric.
388
+ single_value : bool
389
+ If True, value is a single value.
446
390
 
447
391
  Returns
448
392
  -------
449
393
  None
450
394
  """
451
- if key not in self.status.metrics or overwrite:
452
- self.status.metrics[key] = value
453
- else:
454
- self.status.metrics[key].extend(value)
395
+ value = validate_metric_value(value)
396
+ self.status.metrics = set_metrics(
397
+ self.status.metrics,
398
+ key,
399
+ value,
400
+ overwrite,
401
+ single_value,
402
+ )
@@ -18,4 +18,4 @@ class RunStatus(Status):
18
18
  **kwargs,
19
19
  ) -> None:
20
20
  super().__init__(state, message, transitions, k8s, **kwargs)
21
- self.metrics = metrics
21
+ self.metrics = metrics if metrics is not None else {}