mlrun 1.10.0rc16__py3-none-any.whl → 1.10.0rc42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; see the package registry's advisory for more details.

Files changed (98) hide show
  1. mlrun/__init__.py +22 -2
  2. mlrun/artifacts/document.py +6 -1
  3. mlrun/artifacts/llm_prompt.py +21 -15
  4. mlrun/artifacts/model.py +3 -3
  5. mlrun/common/constants.py +9 -0
  6. mlrun/common/formatters/artifact.py +1 -0
  7. mlrun/common/model_monitoring/helpers.py +86 -0
  8. mlrun/common/schemas/__init__.py +2 -0
  9. mlrun/common/schemas/auth.py +2 -0
  10. mlrun/common/schemas/function.py +10 -0
  11. mlrun/common/schemas/hub.py +30 -18
  12. mlrun/common/schemas/model_monitoring/__init__.py +2 -0
  13. mlrun/common/schemas/model_monitoring/constants.py +30 -6
  14. mlrun/common/schemas/model_monitoring/functions.py +13 -4
  15. mlrun/common/schemas/model_monitoring/model_endpoints.py +11 -0
  16. mlrun/common/schemas/pipeline.py +1 -1
  17. mlrun/common/schemas/serving.py +3 -0
  18. mlrun/common/schemas/workflow.py +1 -0
  19. mlrun/common/secrets.py +22 -1
  20. mlrun/config.py +32 -10
  21. mlrun/datastore/__init__.py +11 -3
  22. mlrun/datastore/azure_blob.py +162 -47
  23. mlrun/datastore/datastore.py +9 -4
  24. mlrun/datastore/datastore_profile.py +61 -5
  25. mlrun/datastore/model_provider/huggingface_provider.py +363 -0
  26. mlrun/datastore/model_provider/mock_model_provider.py +87 -0
  27. mlrun/datastore/model_provider/model_provider.py +211 -74
  28. mlrun/datastore/model_provider/openai_provider.py +243 -71
  29. mlrun/datastore/s3.py +24 -2
  30. mlrun/datastore/storeytargets.py +2 -3
  31. mlrun/datastore/utils.py +15 -3
  32. mlrun/db/base.py +27 -19
  33. mlrun/db/httpdb.py +57 -48
  34. mlrun/db/nopdb.py +25 -10
  35. mlrun/execution.py +55 -13
  36. mlrun/hub/__init__.py +15 -0
  37. mlrun/hub/module.py +181 -0
  38. mlrun/k8s_utils.py +105 -16
  39. mlrun/launcher/base.py +13 -6
  40. mlrun/launcher/local.py +2 -0
  41. mlrun/model.py +9 -3
  42. mlrun/model_monitoring/api.py +66 -27
  43. mlrun/model_monitoring/applications/__init__.py +1 -1
  44. mlrun/model_monitoring/applications/base.py +372 -136
  45. mlrun/model_monitoring/applications/context.py +2 -4
  46. mlrun/model_monitoring/applications/results.py +4 -7
  47. mlrun/model_monitoring/controller.py +239 -101
  48. mlrun/model_monitoring/db/_schedules.py +36 -13
  49. mlrun/model_monitoring/db/_stats.py +4 -3
  50. mlrun/model_monitoring/db/tsdb/base.py +29 -9
  51. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +4 -5
  52. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +154 -50
  53. mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +51 -0
  54. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +17 -4
  55. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +245 -51
  56. mlrun/model_monitoring/helpers.py +28 -5
  57. mlrun/model_monitoring/stream_processing.py +45 -14
  58. mlrun/model_monitoring/writer.py +220 -1
  59. mlrun/platforms/__init__.py +3 -2
  60. mlrun/platforms/iguazio.py +7 -3
  61. mlrun/projects/operations.py +6 -1
  62. mlrun/projects/pipelines.py +2 -2
  63. mlrun/projects/project.py +128 -45
  64. mlrun/run.py +94 -17
  65. mlrun/runtimes/__init__.py +18 -0
  66. mlrun/runtimes/base.py +14 -6
  67. mlrun/runtimes/daskjob.py +1 -0
  68. mlrun/runtimes/local.py +5 -2
  69. mlrun/runtimes/mounts.py +20 -2
  70. mlrun/runtimes/nuclio/__init__.py +1 -0
  71. mlrun/runtimes/nuclio/application/application.py +147 -17
  72. mlrun/runtimes/nuclio/function.py +70 -27
  73. mlrun/runtimes/nuclio/serving.py +85 -4
  74. mlrun/runtimes/pod.py +213 -21
  75. mlrun/runtimes/utils.py +49 -9
  76. mlrun/secrets.py +54 -13
  77. mlrun/serving/remote.py +79 -6
  78. mlrun/serving/routers.py +23 -41
  79. mlrun/serving/server.py +211 -40
  80. mlrun/serving/states.py +536 -156
  81. mlrun/serving/steps.py +62 -0
  82. mlrun/serving/system_steps.py +136 -81
  83. mlrun/serving/v2_serving.py +9 -10
  84. mlrun/utils/helpers.py +212 -82
  85. mlrun/utils/logger.py +3 -1
  86. mlrun/utils/notifications/notification/base.py +18 -0
  87. mlrun/utils/notifications/notification/git.py +2 -4
  88. mlrun/utils/notifications/notification/slack.py +2 -4
  89. mlrun/utils/notifications/notification/webhook.py +2 -5
  90. mlrun/utils/notifications/notification_pusher.py +1 -1
  91. mlrun/utils/version/version.json +2 -2
  92. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/METADATA +44 -45
  93. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/RECORD +97 -92
  94. mlrun/api/schemas/__init__.py +0 -259
  95. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/WHEEL +0 -0
  96. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/entry_points.txt +0 -0
  97. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/licenses/LICENSE +0 -0
  98. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/top_level.txt +0 -0
mlrun/serving/steps.py ADDED
@@ -0,0 +1,62 @@
1
# Copyright 2025 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import storey

import mlrun.errors


class ChoiceByField(storey.Choice):
    """
    Selects downstream outlets to route each event based on a predetermined field.

    The event is expected to be a mapping; the value stored under ``field_name``
    may be a single outlet (step) name or a list/tuple of outlet names.

    :param field_name: event field (key) name that contains the step name or
        names of the desired outlet or outlets
    """

    # NOTE: field_name is annotated as `str` (not `Union[str, list[str]]`):
    # `select_outlets` uses it directly as a mapping key (`in` test and
    # subscript), so a list field name would raise TypeError (unhashable)
    # and was never actually supported. The field *value* may be a list.
    def __init__(self, field_name: str, **kwargs):
        self.field_name = field_name
        super().__init__(**kwargs)

    def select_outlets(self, event) -> list:
        """Return the list of outlet names this event should be routed to.

        :param event: mapping-like event body holding ``self.field_name``
        :raises mlrun.errors.MLRunRuntimeError: if the field is missing, or its
            value is an empty list/tuple
        :raises mlrun.errors.MLRunInvalidArgumentError: if the value is None
        :raises mlrun.errors.MLRunInvalidArgumentTypeError: if the value is not
            a string, list, or tuple
        """
        # Case 1: Missing field
        if self.field_name not in event:
            raise mlrun.errors.MLRunRuntimeError(
                f"Field '{self.field_name}' is not contained in the event keys {list(event.keys())}."
            )

        outlet = event[self.field_name]

        # Case 2: Field exists but is None
        if outlet is None:
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"Field '{self.field_name}' exists but its value is None."
            )

        # Case 3: Invalid type
        if not isinstance(outlet, (str, list, tuple)):
            raise mlrun.errors.MLRunInvalidArgumentTypeError(
                f"Field '{self.field_name}' must be a string or list of strings "
                f"but is instead of type '{type(outlet).__name__}'."
            )

        # Normalize a single name to a one-element list; lists/tuples pass through.
        outlets = [outlet] if isinstance(outlet, str) else outlet

        # Case 4: Empty list or tuple
        if not outlets:
            raise mlrun.errors.MLRunRuntimeError(
                f"The value of the key '{self.field_name}' cannot be an empty {type(outlets).__name__}."
            )

        return outlets
@@ -11,8 +11,8 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
14
  import random
15
+ from copy import copy
16
16
  from datetime import timedelta
17
17
  from typing import Any, Optional, Union
18
18
 
@@ -22,11 +22,29 @@ import storey
22
22
  import mlrun
23
23
  import mlrun.artifacts
24
24
  import mlrun.common.schemas.model_monitoring as mm_schemas
25
+ import mlrun.feature_store
25
26
  import mlrun.serving
27
+ from mlrun.common.model_monitoring.helpers import (
28
+ get_model_endpoints_creation_task_status,
29
+ )
26
30
  from mlrun.common.schemas import MonitoringData
27
31
  from mlrun.utils import get_data_from_path, logger
28
32
 
29
33
 
34
+ class MatchingEndpointsState(mlrun.common.types.StrEnum):
35
+ all_matched = "all_matched"
36
+ not_all_matched = "not_all_matched"
37
+ no_check_needed = "no_check_needed"
38
+ not_yet_checked = "not_yet_matched"
39
+
40
+ @staticmethod
41
+ def success_states() -> list[str]:
42
+ return [
43
+ MatchingEndpointsState.all_matched,
44
+ MatchingEndpointsState.no_check_needed,
45
+ ]
46
+
47
+
30
48
  class MonitoringPreProcessor(storey.MapClass):
31
49
  """preprocess step, reconstructs the serving output event body to StreamProcessingEvent schema"""
32
50
 
@@ -45,33 +63,20 @@ class MonitoringPreProcessor(storey.MapClass):
45
63
  result_path = model_monitoring_data.get(MonitoringData.RESULT_PATH)
46
64
  input_path = model_monitoring_data.get(MonitoringData.INPUT_PATH)
47
65
 
48
- result = get_data_from_path(result_path, event.body.get(model, event.body))
49
66
  output_schema = model_monitoring_data.get(MonitoringData.OUTPUTS)
50
67
  input_schema = model_monitoring_data.get(MonitoringData.INPUTS)
51
- logger.debug("output schema retrieved", output_schema=output_schema)
52
- if isinstance(result, dict):
53
- # transpose by key the outputs:
54
- outputs = self.transpose_by_key(result, output_schema)
55
- if not output_schema:
56
- logger.warn(
57
- "Output schema was not provided using Project:log_model or by ModelRunnerStep:add_model order "
58
- "may not preserved"
59
- )
60
- else:
61
- outputs = result
68
+ logger.debug(
69
+ "output and input schema retrieved",
70
+ output_schema=output_schema,
71
+ input_schema=input_schema,
72
+ )
62
73
 
63
- event_inputs = event._metadata.get("inputs", {})
64
- event_inputs = get_data_from_path(input_path, event_inputs)
65
- if isinstance(event_inputs, dict):
66
- # transpose by key the inputs:
67
- inputs = self.transpose_by_key(event_inputs, input_schema)
68
- if not input_schema:
69
- logger.warn(
70
- "Input schema was not provided using by ModelRunnerStep:add_model, order "
71
- "may not preserved"
72
- )
73
- else:
74
- inputs = event_inputs
74
+ outputs, new_output_schema = self.get_listed_data(
75
+ event.body.get(model, event.body), result_path, output_schema
76
+ )
77
+ inputs, new_input_schema = self.get_listed_data(
78
+ event._metadata.get("inputs", {}), input_path, input_schema
79
+ )
75
80
 
76
81
  if outputs and isinstance(outputs[0], list):
77
82
  if output_schema and len(output_schema) != len(outputs[0]):
@@ -96,15 +101,43 @@ class MonitoringPreProcessor(storey.MapClass):
96
101
  "outputs and inputs are not in the same length check 'input_path' and "
97
102
  "'output_path' was specified if needed"
98
103
  )
99
- request = {"inputs": inputs, "id": getattr(event, "id", None)}
100
- resp = {"outputs": outputs}
104
+ request = {
105
+ "inputs": inputs,
106
+ "id": getattr(event, "id", None),
107
+ "input_schema": new_input_schema,
108
+ }
109
+ resp = {"outputs": outputs, "output_schema": new_output_schema}
101
110
 
102
111
  return request, resp
103
112
 
113
+ def get_listed_data(
114
+ self,
115
+ raw_data: dict,
116
+ data_path: Optional[Union[list[str], str]] = None,
117
+ schema: Optional[list[str]] = None,
118
+ ):
119
+ """Get data from a path and transpose it by keys if dict is provided."""
120
+ new_schema = None
121
+ data_from_path = get_data_from_path(data_path, raw_data)
122
+ if isinstance(data_from_path, dict):
123
+ # transpose by key the inputs:
124
+ listed_data, new_schema = self.transpose_by_key(data_from_path, schema)
125
+ new_schema = new_schema or schema
126
+ if not schema:
127
+ logger.warn(
128
+ f"No schema provided through add_model(); the order of {data_from_path} "
129
+ "may not be preserved."
130
+ )
131
+ elif not isinstance(data_from_path, list):
132
+ listed_data = [data_from_path]
133
+ else:
134
+ listed_data = data_from_path
135
+ return listed_data, new_schema
136
+
104
137
  @staticmethod
105
138
  def transpose_by_key(
106
139
  data: dict, schema: Optional[Union[str, list[str]]] = None
107
- ) -> Union[list[float], list[list[float]]]:
140
+ ) -> tuple[Union[list[Any], list[list[Any]]], list[str]]:
108
141
  """
109
142
  Transpose values from a dictionary by keys.
110
143
 
@@ -136,17 +169,28 @@ class MonitoringPreProcessor(storey.MapClass):
136
169
  * If result is a matrix, returns a list of lists.
137
170
 
138
171
  :raises ValueError: If the values include a mix of scalars and lists, or if the list lengths do not match.
172
+ mlrun.MLRunInvalidArgumentError if the schema keys are not contained in the data keys.
139
173
  """
140
-
174
+ new_schema = None
175
+ # Normalize keys in data:
176
+ normalize_data = {
177
+ mlrun.feature_store.api.norm_column_name(k): copy(v)
178
+ for k, v in data.items()
179
+ }
141
180
  # Normalize schema to list
142
181
  if not schema:
143
- keys = list(data.keys())
182
+ keys = list(normalize_data.keys())
183
+ new_schema = keys
144
184
  elif isinstance(schema, str):
145
- keys = [schema]
185
+ keys = [mlrun.feature_store.api.norm_column_name(schema)]
146
186
  else:
147
- keys = schema
187
+ keys = [mlrun.feature_store.api.norm_column_name(key) for key in schema]
148
188
 
149
- values = [data[key] for key in keys]
189
+ values = [normalize_data[key] for key in keys if key in normalize_data]
190
+ if len(values) != len(keys):
191
+ raise mlrun.MLRunInvalidArgumentError(
192
+ f"Schema keys {keys} are not contained in the data keys {list(data.keys())}."
193
+ )
150
194
 
151
195
  # Detect if all are scalars ie: int,float,str
152
196
  all_scalars = all(not isinstance(v, (list, tuple, np.ndarray)) for v in values)
@@ -158,18 +202,18 @@ class MonitoringPreProcessor(storey.MapClass):
158
202
  )
159
203
 
160
204
  if all_scalars:
161
- transposed = np.array([values])
205
+ transposed = np.array([values], dtype=object)
162
206
  elif all_lists and len(keys) > 1:
163
- arrays = [np.array(v) for v in values]
207
+ arrays = [np.array(v, dtype=object) for v in values]
164
208
  mat = np.stack(arrays, axis=0)
165
209
  transposed = mat.T
166
210
  else:
167
- return values[0]
211
+ return values[0], new_schema
168
212
 
169
213
  if transposed.shape[1] == 1 and transposed.shape[0] == 1:
170
214
  # Transform [[0]] -> [0]:
171
- return transposed[:, 0].tolist()
172
- return transposed.tolist()
215
+ return transposed[:, 0].tolist(), new_schema
216
+ return transposed.tolist(), new_schema
173
217
 
174
218
  def do(self, event):
175
219
  monitoring_event_list = []
@@ -192,6 +236,12 @@ class MonitoringPreProcessor(storey.MapClass):
192
236
  request, resp = self.reconstruct_request_resp_fields(
193
237
  event, model, monitoring_data[model]
194
238
  )
239
+ if hasattr(event, "_original_timestamp"):
240
+ when = event._original_timestamp
241
+ else:
242
+ when = event._metadata.get(model, {}).get(
243
+ mm_schemas.StreamProcessingEvent.WHEN
244
+ )
195
245
  monitoring_event_list.append(
196
246
  {
197
247
  mm_schemas.StreamProcessingEvent.MODEL: model,
@@ -201,17 +251,16 @@ class MonitoringPreProcessor(storey.MapClass):
201
251
  mm_schemas.StreamProcessingEvent.MICROSEC: event._metadata.get(
202
252
  model, {}
203
253
  ).get(mm_schemas.StreamProcessingEvent.MICROSEC),
204
- mm_schemas.StreamProcessingEvent.WHEN: event._metadata.get(
205
- model, {}
206
- ).get(mm_schemas.StreamProcessingEvent.WHEN),
254
+ mm_schemas.StreamProcessingEvent.WHEN: when,
207
255
  mm_schemas.StreamProcessingEvent.ENDPOINT_ID: monitoring_data[
208
256
  model
209
257
  ].get(
210
258
  mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID
211
259
  ),
212
- mm_schemas.StreamProcessingEvent.LABELS: monitoring_data[
260
+ mm_schemas.StreamProcessingEvent.LABELS: event.body[
213
261
  model
214
- ].get(mlrun.common.schemas.MonitoringData.OUTPUTS),
262
+ ].get("labels")
263
+ or {},
215
264
  mm_schemas.StreamProcessingEvent.FUNCTION_URI: self.server.function_uri
216
265
  if self.server
217
266
  else None,
@@ -236,6 +285,10 @@ class MonitoringPreProcessor(storey.MapClass):
236
285
  request, resp = self.reconstruct_request_resp_fields(
237
286
  event, model, monitoring_data[model]
238
287
  )
288
+ if hasattr(event, "_original_timestamp"):
289
+ when = event._original_timestamp
290
+ else:
291
+ when = event._metadata.get(mm_schemas.StreamProcessingEvent.WHEN)
239
292
  monitoring_event_list.append(
240
293
  {
241
294
  mm_schemas.StreamProcessingEvent.MODEL: model,
@@ -245,25 +298,20 @@ class MonitoringPreProcessor(storey.MapClass):
245
298
  mm_schemas.StreamProcessingEvent.MICROSEC: event._metadata.get(
246
299
  mm_schemas.StreamProcessingEvent.MICROSEC
247
300
  ),
248
- mm_schemas.StreamProcessingEvent.WHEN: event._metadata.get(
249
- mm_schemas.StreamProcessingEvent.WHEN
250
- ),
301
+ mm_schemas.StreamProcessingEvent.WHEN: when,
251
302
  mm_schemas.StreamProcessingEvent.ENDPOINT_ID: monitoring_data[
252
303
  model
253
304
  ].get(mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID),
254
- mm_schemas.StreamProcessingEvent.LABELS: monitoring_data[model].get(
255
- mlrun.common.schemas.MonitoringData.OUTPUTS
256
- ),
305
+ mm_schemas.StreamProcessingEvent.LABELS: event.body.get("labels")
306
+ or {},
257
307
  mm_schemas.StreamProcessingEvent.FUNCTION_URI: self.server.function_uri
258
308
  if self.server
259
309
  else None,
260
310
  mm_schemas.StreamProcessingEvent.REQUEST: request,
261
311
  mm_schemas.StreamProcessingEvent.RESPONSE: resp,
262
- mm_schemas.StreamProcessingEvent.ERROR: event.body[
312
+ mm_schemas.StreamProcessingEvent.ERROR: event.body.get(
263
313
  mm_schemas.StreamProcessingEvent.ERROR
264
- ]
265
- if mm_schemas.StreamProcessingEvent.ERROR in event.body
266
- else None,
314
+ ),
267
315
  mm_schemas.StreamProcessingEvent.METRICS: event.body[
268
316
  mm_schemas.StreamProcessingEvent.METRICS
269
317
  ]
@@ -283,6 +331,9 @@ class BackgroundTaskStatus(storey.MapClass):
283
331
 
284
332
  def __init__(self, **kwargs):
285
333
  super().__init__(**kwargs)
334
+ self.matching_endpoints = MatchingEndpointsState.not_yet_checked
335
+ self.graph_model_endpoint_uids: set = set()
336
+ self.listed_model_endpoint_uids: set = set()
286
337
  self.server: mlrun.serving.GraphServer = (
287
338
  getattr(self.context, "server", None) if self.context else None
288
339
  )
@@ -303,43 +354,47 @@ class BackgroundTaskStatus(storey.MapClass):
303
354
  )
304
355
  )
305
356
  ):
306
- background_task = mlrun.get_run_db().get_project_background_task(
307
- self.server.project, self.server.model_endpoint_creation_task_name
308
- )
309
- self._background_task_check_timestamp = mlrun.utils.now_date()
310
- self._log_background_task_state(background_task.status.state)
311
- self._background_task_state = background_task.status.state
357
+ (
358
+ self._background_task_state,
359
+ self._background_task_check_timestamp,
360
+ self.listed_model_endpoint_uids,
361
+ ) = get_model_endpoints_creation_task_status(self.server)
362
+ if (
363
+ self.listed_model_endpoint_uids
364
+ and self.matching_endpoints == MatchingEndpointsState.not_yet_checked
365
+ ):
366
+ if not self.graph_model_endpoint_uids:
367
+ self.graph_model_endpoint_uids = collect_model_endpoint_uids(
368
+ self.server
369
+ )
370
+
371
+ if self.graph_model_endpoint_uids.issubset(self.listed_model_endpoint_uids):
372
+ self.matching_endpoints = MatchingEndpointsState.all_matched
373
+ elif self.listed_model_endpoint_uids is None:
374
+ self.matching_endpoints = MatchingEndpointsState.no_check_needed
312
375
 
313
376
  if (
314
377
  self._background_task_state
315
378
  == mlrun.common.schemas.BackgroundTaskState.succeeded
379
+ and self.matching_endpoints in MatchingEndpointsState.success_states()
316
380
  ):
317
381
  return event
318
382
  else:
319
383
  return None
320
384
 
321
- def _log_background_task_state(
322
- self, background_task_state: mlrun.common.schemas.BackgroundTaskState
323
- ):
324
- logger.info(
325
- "Checking model endpoint creation task status",
326
- task_name=self.server.model_endpoint_creation_task_name,
327
- )
328
- if (
329
- background_task_state
330
- in mlrun.common.schemas.BackgroundTaskState.terminal_states()
331
- ):
332
- logger.info(
333
- f"Model endpoint creation task completed with state {background_task_state}"
334
- )
335
- else: # in progress
336
- logger.info(
337
- f"Model endpoint creation task is still in progress with the current state: "
338
- f"{background_task_state}. Events will not be monitored for the next "
339
- f"{mlrun.mlconf.model_endpoint_monitoring.model_endpoint_creation_check_period} seconds",
340
- name=self.name,
341
- background_task_check_timestamp=self._background_task_check_timestamp.isoformat(),
342
- )
385
+
386
+ def collect_model_endpoint_uids(server: mlrun.serving.GraphServer) -> set[str]:
387
+ """Collects all model endpoint UIDs from the server's graph steps."""
388
+ model_endpoint_uids = set()
389
+ for step in server.graph.steps.values():
390
+ if hasattr(step, "monitoring_data"):
391
+ for model in step.monitoring_data.keys():
392
+ uid = step.monitoring_data[model].get(
393
+ mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID
394
+ )
395
+ if uid:
396
+ model_endpoint_uids.add(uid)
397
+ return model_endpoint_uids
343
398
 
344
399
 
345
400
  class SamplingStep(storey.MapClass):
@@ -24,6 +24,9 @@ import mlrun.common.schemas.model_monitoring
24
24
  import mlrun.model_monitoring
25
25
  from mlrun.utils import logger, now_date
26
26
 
27
+ from ..common.model_monitoring.helpers import (
28
+ get_model_endpoints_creation_task_status,
29
+ )
27
30
  from .utils import StepToDict, _extract_input_data, _update_result_body
28
31
 
29
32
 
@@ -474,22 +477,18 @@ class V2ModelServer(StepToDict):
474
477
  ) or getattr(self.context, "server", None)
475
478
  if not self.context.is_mock or self.context.monitoring_mock:
476
479
  if server.model_endpoint_creation_task_name:
477
- background_task = mlrun.get_run_db().get_project_background_task(
478
- server.project, server.model_endpoint_creation_task_name
479
- )
480
- logger.debug(
481
- "Checking model endpoint creation task status",
482
- task_name=server.model_endpoint_creation_task_name,
480
+ background_task_state, _, _ = get_model_endpoints_creation_task_status(
481
+ server
483
482
  )
484
483
  if (
485
- background_task.status.state
484
+ background_task_state
486
485
  in mlrun.common.schemas.BackgroundTaskState.terminal_states()
487
486
  ):
488
487
  logger.debug(
489
- f"Model endpoint creation task completed with state {background_task.status.state}"
488
+ f"Model endpoint creation task completed with state {background_task_state}"
490
489
  )
491
490
  if (
492
- background_task.status.state
491
+ background_task_state
493
492
  == mlrun.common.schemas.BackgroundTaskState.succeeded
494
493
  ):
495
494
  self._model_logger = (
@@ -504,7 +503,7 @@ class V2ModelServer(StepToDict):
504
503
  else: # in progress
505
504
  logger.debug(
506
505
  f"Model endpoint creation task is still in progress with the current state: "
507
- f"{background_task.status.state}.",
506
+ f"{background_task_state}.",
508
507
  name=self.name,
509
508
  )
510
509
  else: