mlrun 1.7.0rc6__py3-none-any.whl → 1.7.0rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. Click here for more details.

Files changed (59) hide show
  1. mlrun/common/constants.py +6 -0
  2. mlrun/common/schemas/__init__.py +2 -0
  3. mlrun/common/schemas/model_monitoring/__init__.py +4 -0
  4. mlrun/common/schemas/model_monitoring/constants.py +35 -18
  5. mlrun/common/schemas/project.py +1 -0
  6. mlrun/common/types.py +7 -1
  7. mlrun/config.py +11 -4
  8. mlrun/data_types/data_types.py +4 -0
  9. mlrun/datastore/alibaba_oss.py +130 -0
  10. mlrun/datastore/azure_blob.py +4 -5
  11. mlrun/datastore/base.py +22 -16
  12. mlrun/datastore/datastore.py +4 -0
  13. mlrun/datastore/google_cloud_storage.py +1 -1
  14. mlrun/datastore/sources.py +2 -3
  15. mlrun/db/base.py +14 -6
  16. mlrun/db/httpdb.py +61 -56
  17. mlrun/db/nopdb.py +3 -0
  18. mlrun/model.py +1 -0
  19. mlrun/model_monitoring/__init__.py +1 -1
  20. mlrun/model_monitoring/api.py +104 -295
  21. mlrun/model_monitoring/controller.py +25 -25
  22. mlrun/model_monitoring/db/__init__.py +16 -0
  23. mlrun/model_monitoring/{stores → db/stores}/__init__.py +43 -34
  24. mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
  25. mlrun/model_monitoring/{stores/model_endpoint_store.py → db/stores/base/store.py} +47 -6
  26. mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
  27. mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +49 -0
  28. mlrun/model_monitoring/{stores → db/stores/sqldb}/models/base.py +76 -3
  29. mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +68 -0
  30. mlrun/model_monitoring/{stores → db/stores/sqldb}/models/sqlite.py +13 -1
  31. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +662 -0
  32. mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
  33. mlrun/model_monitoring/{stores/kv_model_endpoint_store.py → db/stores/v3io_kv/kv_store.py} +134 -3
  34. mlrun/model_monitoring/helpers.py +0 -2
  35. mlrun/model_monitoring/stream_processing.py +41 -9
  36. mlrun/model_monitoring/tracking_policy.py +7 -1
  37. mlrun/model_monitoring/writer.py +4 -36
  38. mlrun/projects/pipelines.py +13 -1
  39. mlrun/projects/project.py +109 -101
  40. mlrun/run.py +3 -1
  41. mlrun/runtimes/base.py +6 -0
  42. mlrun/runtimes/nuclio/api_gateway.py +188 -61
  43. mlrun/runtimes/nuclio/function.py +3 -0
  44. mlrun/runtimes/nuclio/serving.py +28 -32
  45. mlrun/runtimes/pod.py +26 -0
  46. mlrun/serving/server.py +4 -6
  47. mlrun/serving/states.py +34 -14
  48. mlrun/utils/helpers.py +34 -0
  49. mlrun/utils/version/version.json +2 -2
  50. {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc7.dist-info}/METADATA +14 -5
  51. {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc7.dist-info}/RECORD +55 -51
  52. mlrun/model_monitoring/batch.py +0 -933
  53. mlrun/model_monitoring/stores/models/__init__.py +0 -27
  54. mlrun/model_monitoring/stores/models/mysql.py +0 -34
  55. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -382
  56. {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc7.dist-info}/LICENSE +0 -0
  57. {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc7.dist-info}/WHEEL +0 -0
  58. {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc7.dist-info}/entry_points.txt +0 -0
  59. {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc7.dist-info}/top_level.txt +0 -0
@@ -12,15 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  import base64
15
+ import typing
15
16
  from typing import Optional, Union
16
17
  from urllib.parse import urljoin
17
18
 
18
19
  import requests
20
+ from requests.auth import HTTPBasicAuth
19
21
 
20
22
  import mlrun
21
23
  import mlrun.common.schemas
22
24
 
23
- from .function import RemoteRuntime
25
+ from .function import RemoteRuntime, get_fullname
24
26
  from .serving import ServingRuntime
25
27
 
26
28
  NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_BASIC_AUTH = "basicAuth"
@@ -28,6 +30,67 @@ NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_NONE = "none"
28
30
  PROJECT_NAME_LABEL = "nuclio.io/project-name"
29
31
 
30
32
 
33
+ class APIGatewayAuthenticator(typing.Protocol):
34
+ @property
35
+ def authentication_mode(self) -> str:
36
+ return NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_NONE
37
+
38
+ @classmethod
39
+ def from_scheme(cls, api_gateway_spec: mlrun.common.schemas.APIGatewaySpec):
40
+ if (
41
+ api_gateway_spec.authenticationMode
42
+ == NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_BASIC_AUTH
43
+ ):
44
+ if api_gateway_spec.authentication:
45
+ return BasicAuth(
46
+ username=api_gateway_spec.authentication.get("username", ""),
47
+ password=api_gateway_spec.authentication.get("password", ""),
48
+ )
49
+ else:
50
+ return BasicAuth()
51
+ else:
52
+ return NoneAuth()
53
+
54
+ def to_scheme(
55
+ self,
56
+ ) -> Optional[dict[str, Optional[mlrun.common.schemas.APIGatewayBasicAuth]]]:
57
+ return None
58
+
59
+
60
+ class NoneAuth(APIGatewayAuthenticator):
61
+ """
62
+ An API gateway authenticator with no authentication.
63
+ """
64
+
65
+ pass
66
+
67
+
68
+ class BasicAuth(APIGatewayAuthenticator):
69
+ """
70
+ An API gateway authenticator with basic authentication.
71
+
72
+ :param username: (str) The username for basic authentication.
73
+ :param password: (str) The password for basic authentication.
74
+ """
75
+
76
+ def __init__(self, username=None, password=None):
77
+ self._username = username
78
+ self._password = password
79
+
80
+ @property
81
+ def authentication_mode(self) -> str:
82
+ return NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_BASIC_AUTH
83
+
84
+ def to_scheme(
85
+ self,
86
+ ) -> Optional[dict[str, Optional[mlrun.common.schemas.APIGatewayBasicAuth]]]:
87
+ return {
88
+ "authentication": mlrun.common.schemas.APIGatewayBasicAuth(
89
+ username=self._username, password=self._password
90
+ )
91
+ }
92
+
93
+
31
94
  class APIGateway:
32
95
  def __init__(
33
96
  self,
@@ -47,22 +110,34 @@ class APIGateway:
47
110
  ],
48
111
  description: str = "",
49
112
  path: str = "/",
50
- authentication_mode: Optional[
51
- str
52
- ] = NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_NONE,
113
+ authentication: Optional[APIGatewayAuthenticator] = NoneAuth(),
53
114
  host: Optional[str] = None,
54
115
  canary: Optional[list[int]] = None,
55
- username: Optional[str] = None,
56
- password: Optional[str] = None,
57
116
  ):
117
+ """
118
+ Initialize the APIGateway instance.
119
+
120
+ :param project: The project name
121
+ :param name: The name of the API gateway
122
+ :param functions: The list of functions associated with the API gateway
123
+ Can be a list of function names (["my-func1", "my-func2"])
124
+ or a list or a single entity of
125
+ :py:class:`~mlrun.runtimes.nuclio.function.RemoteRuntime` OR
126
+ :py:class:`~mlrun.runtimes.nuclio.serving.ServingRuntime`
127
+
128
+ :param description: Optional description of the API gateway
129
+ :param path: Optional path of the API gateway, default value is "/"
130
+ :param authentication: The authentication for the API gateway of type
131
+ :py:class:`~mlrun.runtimes.nuclio.api_gateway.BasicAuth`
132
+ :param host: The host of the API gateway (optional). If not set, it will be automatically generated
133
+ :param canary: The canary percents for the API gateway of type list[int]; for instance: [20,80]
134
+ """
58
135
  self.functions = None
59
136
  self._validate(
60
137
  project=project,
61
138
  functions=functions,
62
139
  name=name,
63
140
  canary=canary,
64
- username=username,
65
- password=password,
66
141
  )
67
142
  self.project = project
68
143
  self.name = name
@@ -70,14 +145,8 @@ class APIGateway:
70
145
 
71
146
  self.path = path
72
147
  self.description = description
73
- self.authentication_mode = (
74
- authentication_mode
75
- if authentication_mode
76
- else self._enrich_authentication_mode(username=username, password=password)
77
- )
78
148
  self.canary = canary
79
- self._username = username
80
- self._password = password
149
+ self.authentication = authentication
81
150
 
82
151
  def invoke(
83
152
  self,
@@ -86,24 +155,94 @@ class APIGateway:
86
155
  auth: Optional[tuple[str, str]] = None,
87
156
  **kwargs,
88
157
  ):
158
+ """
159
+ Invoke the API gateway.
160
+
161
+ :param method: (str, optional) The HTTP method for the invocation.
162
+ :param headers: (dict, optional) The HTTP headers for the invocation.
163
+ :param auth: (Optional[tuple[str, str]], optional) The authentication creds for the invocation if required.
164
+ :param kwargs: (dict) Additional keyword arguments.
165
+
166
+ :return: The response from the API gateway invocation.
167
+ """
89
168
  if not self.invoke_url:
90
- raise mlrun.errors.MLRunInvalidArgumentError(
91
- "Invocation url is not set. Set up gateway's `invoke_url` attribute."
92
- )
169
+ # try to resolve invoke_url before fail
170
+ self.sync()
171
+ if not self.invoke_url:
172
+ raise mlrun.errors.MLRunInvalidArgumentError(
173
+ "Invocation url is not set. Set up gateway's `invoke_url` attribute."
174
+ )
93
175
  if (
94
- self.authentication_mode
176
+ self.authentication.authentication_mode
95
177
  == NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_BASIC_AUTH
96
178
  and not auth
97
179
  ):
98
180
  raise mlrun.errors.MLRunInvalidArgumentError(
99
181
  "API Gateway invocation requires authentication. Please pass credentials"
100
182
  )
101
- if auth:
102
- headers["Authorization"] = self._generate_basic_auth(*auth)
103
183
  return requests.request(
104
- method=method, url=self.invoke_url, headers=headers, **kwargs
184
+ method=method,
185
+ url=self.invoke_url,
186
+ headers=headers,
187
+ **kwargs,
188
+ auth=HTTPBasicAuth(*auth) if auth else None,
105
189
  )
106
190
 
191
+ def sync(self):
192
+ """
193
+ Synchronize the API gateway from the server.
194
+ """
195
+ synced_gateway = mlrun.get_run_db().get_api_gateway(self.name, self.project)
196
+ synced_gateway = self.from_scheme(synced_gateway)
197
+
198
+ self.host = synced_gateway.host
199
+ self.path = synced_gateway.path
200
+ self.authentication = synced_gateway.authentication
201
+ self.functions = synced_gateway.functions
202
+ self.canary = synced_gateway.canary
203
+ self.description = synced_gateway.description
204
+
205
+ def with_basic_auth(self, username: str, password: str):
206
+ """
207
+ Set basic authentication for the API gateway.
208
+
209
+ :param username: (str) The username for basic authentication.
210
+ :param password: (str) The password for basic authentication.
211
+ """
212
+ self.authentication = BasicAuth(username=username, password=password)
213
+
214
+ def with_canary(
215
+ self,
216
+ functions: Union[
217
+ list[str],
218
+ list[
219
+ Union[
220
+ RemoteRuntime,
221
+ ServingRuntime,
222
+ ]
223
+ ],
224
+ ],
225
+ canary: list[int],
226
+ ):
227
+ """
228
+ Set canary function for the API gateway
229
+
230
+ :param functions: The list of functions associated with the API gateway
231
+ Can be a list of function names (["my-func1", "my-func2"])
232
+ or a list of nuclio functions of types
233
+ :py:class:`~mlrun.runtimes.nuclio.function.RemoteRuntime` OR
234
+ :py:class:`~mlrun.runtimes.nuclio.serving.ServingRuntime`
235
+ :param canary: The canary percents for the API gateway of type list[int]; for instance: [20,80]
236
+
237
+ """
238
+ if len(functions) != 2:
239
+ raise mlrun.errors.MLRunInvalidArgumentError(
240
+ f"Gateway with canary can be created only with two functions, "
241
+ f"the number of functions passed is {len(functions)}"
242
+ )
243
+ self.functions = self._validate_functions(self.project, functions)
244
+ self.canary = self._validate_canary(canary)
245
+
107
246
  @classmethod
108
247
  def from_scheme(cls, api_gateway: mlrun.common.schemas.APIGateway):
109
248
  project = api_gateway.metadata.labels.get(PROJECT_NAME_LABEL)
@@ -114,7 +253,7 @@ class APIGateway:
114
253
  name=api_gateway.spec.name,
115
254
  host=api_gateway.spec.host,
116
255
  path=api_gateway.spec.path,
117
- authentication_mode=str(api_gateway.spec.authenticationMode),
256
+ authentication=APIGatewayAuthenticator.from_scheme(api_gateway.spec),
118
257
  functions=functions,
119
258
  canary=canary,
120
259
  )
@@ -141,26 +280,26 @@ class APIGateway:
141
280
  spec=mlrun.common.schemas.APIGatewaySpec(
142
281
  name=self.name,
143
282
  description=self.description,
283
+ host=self.host,
144
284
  path=self.path,
145
- authentication_mode=mlrun.common.schemas.APIGatewayAuthenticationMode.from_str(
146
- self.authentication_mode
285
+ authenticationMode=mlrun.common.schemas.APIGatewayAuthenticationMode.from_str(
286
+ self.authentication.authentication_mode
147
287
  ),
148
288
  upstreams=upstreams,
149
289
  ),
150
290
  )
151
- if (
152
- self.authentication_mode
153
- is NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_BASIC_AUTH
154
- ):
155
- api_gateway.spec.authentication = mlrun.common.schemas.APIGatewayBasicAuth(
156
- username=self._username, password=self._password
157
- )
291
+ api_gateway.spec.authentication = self.authentication.to_scheme()
158
292
  return api_gateway
159
293
 
160
294
  @property
161
295
  def invoke_url(
162
296
  self,
163
297
  ):
298
+ """
299
+ Get the invoke URL.
300
+
301
+ :return: (str) The invoke URL.
302
+ """
164
303
  return urljoin(self.host, self.path)
165
304
 
166
305
  def _validate(
@@ -180,8 +319,6 @@ class APIGateway:
180
319
  ],
181
320
  ],
182
321
  canary: Optional[list[int]] = None,
183
- username: Optional[str] = None,
184
- password: Optional[str] = None,
185
322
  ):
186
323
  if not name:
187
324
  raise mlrun.errors.MLRunInvalidArgumentError(
@@ -192,26 +329,23 @@ class APIGateway:
192
329
 
193
330
  # validating canary
194
331
  if canary:
195
- if len(self.functions) != len(canary):
196
- raise mlrun.errors.MLRunInvalidArgumentError(
197
- "Function and canary lists lengths do not match"
198
- )
199
- for canary_percent in canary:
200
- if canary_percent < 0 or canary_percent > 100:
201
- raise mlrun.errors.MLRunInvalidArgumentError(
202
- "The percentage value must be in the range from 0 to 100"
203
- )
204
- if sum(canary) != 100:
332
+ self._validate_canary(canary)
333
+
334
+ def _validate_canary(self, canary: list[int]):
335
+ if len(self.functions) != len(canary):
336
+ raise mlrun.errors.MLRunInvalidArgumentError(
337
+ "Function and canary lists lengths do not match"
338
+ )
339
+ for canary_percent in canary:
340
+ if canary_percent < 0 or canary_percent > 100:
205
341
  raise mlrun.errors.MLRunInvalidArgumentError(
206
- "The sum of canary function percents should be equal to 100"
342
+ "The percentage value must be in the range from 0 to 100"
207
343
  )
208
-
209
- # validating auth
210
- if username and not password:
211
- raise mlrun.errors.MLRunInvalidArgumentError("Password is not specified")
212
-
213
- if password and not username:
214
- raise mlrun.errors.MLRunInvalidArgumentError("Username is not specified")
344
+ if sum(canary) != 100:
345
+ raise mlrun.errors.MLRunInvalidArgumentError(
346
+ "The sum of canary function percents should be equal to 100"
347
+ )
348
+ return canary
215
349
 
216
350
  @staticmethod
217
351
  def _validate_functions(
@@ -257,17 +391,10 @@ class APIGateway:
257
391
  f"input function {function_name} "
258
392
  f"does not belong to this project"
259
393
  )
260
- function_names.append(func.uri)
394
+ nuclio_name = get_fullname(function_name, project, func.metadata.tag)
395
+ function_names.append(nuclio_name)
261
396
  return function_names
262
397
 
263
- @staticmethod
264
- def _enrich_authentication_mode(username, password):
265
- return (
266
- NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_NONE
267
- if username is not None and password is not None
268
- else NUCLIO_API_GATEWAY_AUTHENTICATION_MODE_BASIC_AUTH
269
- )
270
-
271
398
  @staticmethod
272
399
  def _generate_basic_auth(username: str, password: str):
273
400
  token = base64.b64encode(f"{username}:{password}".encode()).decode()
@@ -775,6 +775,9 @@ class RemoteRuntime(KubeResource):
775
775
  ] = self.metadata.credentials.access_key
776
776
  return runtime_env
777
777
 
778
+ def _get_serving_spec(self):
779
+ return None
780
+
778
781
  def _get_nuclio_config_spec_env(self):
779
782
  env_dict = {}
780
783
  external_source_env_dict = {}
@@ -14,8 +14,9 @@
14
14
 
15
15
  import json
16
16
  import os
17
+ import warnings
17
18
  from copy import deepcopy
18
- from typing import Union
19
+ from typing import TYPE_CHECKING, Optional, Union
19
20
 
20
21
  import nuclio
21
22
  from nuclio import KafkaTrigger
@@ -24,7 +25,6 @@ import mlrun
24
25
  import mlrun.common.schemas
25
26
  from mlrun.datastore import parse_kafka_url
26
27
  from mlrun.model import ObjectList
27
- from mlrun.model_monitoring.tracking_policy import TrackingPolicy
28
28
  from mlrun.runtimes.function_reference import FunctionReference
29
29
  from mlrun.secrets import SecretsStore
30
30
  from mlrun.serving.server import GraphServer, create_graph_server
@@ -43,6 +43,10 @@ from .function import NuclioSpec, RemoteRuntime
43
43
 
44
44
  serving_subkind = "serving_v2"
45
45
 
46
+ if TYPE_CHECKING:
47
+ # remove this block in 1.9.0
48
+ from mlrun.model_monitoring import TrackingPolicy
49
+
46
50
 
47
51
  def new_v2_model_server(
48
52
  name,
@@ -291,7 +295,9 @@ class ServingRuntime(RemoteRuntime):
291
295
  "provided class is not a router step, must provide a router class in router topology"
292
296
  )
293
297
  else:
294
- step = RouterStep(class_name=class_name, class_args=class_args)
298
+ step = RouterStep(
299
+ class_name=class_name, class_args=class_args, engine=engine
300
+ )
295
301
  self.spec.graph = step
296
302
  elif topology == StepKinds.flow:
297
303
  self.spec.graph = RootFlowStep(engine=engine)
@@ -303,12 +309,12 @@ class ServingRuntime(RemoteRuntime):
303
309
 
304
310
  def set_tracking(
305
311
  self,
306
- stream_path: str = None,
307
- batch: int = None,
308
- sample: int = None,
309
- stream_args: dict = None,
310
- tracking_policy: Union[TrackingPolicy, dict] = None,
311
- ):
312
+ stream_path: Optional[str] = None,
313
+ batch: Optional[int] = None,
314
+ sample: Optional[int] = None,
315
+ stream_args: Optional[dict] = None,
316
+ tracking_policy: Optional[Union["TrackingPolicy", dict]] = None,
317
+ ) -> None:
312
318
  """apply on your serving function to monitor a deployed model, including real-time dashboards to detect drift
313
319
  and analyze performance.
314
320
 
@@ -317,31 +323,17 @@ class ServingRuntime(RemoteRuntime):
317
323
  :param batch: Micro batch size (send micro batches of N records at a time).
318
324
  :param sample: Sample size (send only one of N records).
319
325
  :param stream_args: Stream initialization parameters, e.g. shards, retention_in_hours, ..
320
- :param tracking_policy: Tracking policy object or a dictionary that will be converted into a tracking policy
321
- object. By using TrackingPolicy, the user can apply his model monitoring requirements,
322
- such as setting the scheduling policy of the model monitoring batch job or changing
323
- the image of the model monitoring stream.
324
326
 
325
327
  example::
326
328
 
327
329
  # initialize a new serving function
328
330
  serving_fn = mlrun.import_function("hub://v2-model-server", new_name="serving")
329
- # apply model monitoring and set monitoring batch job to run every 3 hours
330
- tracking_policy = {'default_batch_intervals':"0 */3 * * *"}
331
- serving_fn.set_tracking(tracking_policy=tracking_policy)
331
+ # apply model monitoring
332
+ serving_fn.set_tracking()
332
333
 
333
334
  """
334
-
335
335
  # Applying model monitoring configurations
336
336
  self.spec.track_models = True
337
- self.spec.tracking_policy = None
338
- if tracking_policy:
339
- if isinstance(tracking_policy, dict):
340
- # Convert tracking policy dictionary into `model_monitoring.TrackingPolicy` object
341
- self.spec.tracking_policy = TrackingPolicy.from_dict(tracking_policy)
342
- else:
343
- # Tracking_policy is already a `model_monitoring.TrackingPolicy` object
344
- self.spec.tracking_policy = tracking_policy
345
337
 
346
338
  if stream_path:
347
339
  self.spec.parameters["log_stream"] = stream_path
@@ -351,6 +343,14 @@ class ServingRuntime(RemoteRuntime):
351
343
  self.spec.parameters["log_stream_sample"] = sample
352
344
  if stream_args:
353
345
  self.spec.parameters["stream_args"] = stream_args
346
+ if tracking_policy is not None:
347
+ warnings.warn(
348
+ "The `tracking_policy` argument is deprecated from version 1.7.0 "
349
+ "and has no effect. It will be removed in 1.9.0.\n"
350
+ "To set the desired model monitoring time window and schedule, use "
351
+ "the `base_period` argument in `project.enable_model_monitoring()`.",
352
+ FutureWarning,
353
+ )
354
354
 
355
355
  def add_model(
356
356
  self,
@@ -644,8 +644,7 @@ class ServingRuntime(RemoteRuntime):
644
644
  force_build=force_build,
645
645
  )
646
646
 
647
- def _get_runtime_env(self):
648
- env = super()._get_runtime_env()
647
+ def _get_serving_spec(self):
649
648
  function_name_uri_map = {f.name: f.uri(self) for f in self.spec.function_refs}
650
649
 
651
650
  serving_spec = {
@@ -658,9 +657,7 @@ class ServingRuntime(RemoteRuntime):
658
657
  "graph_initializer": self.spec.graph_initializer,
659
658
  "error_stream": self.spec.error_stream,
660
659
  "track_models": self.spec.track_models,
661
- "tracking_policy": self.spec.tracking_policy.to_dict()
662
- if self.spec.tracking_policy
663
- else None,
660
+ "tracking_policy": None,
664
661
  "default_content_type": self.spec.default_content_type,
665
662
  }
666
663
 
@@ -668,8 +665,7 @@ class ServingRuntime(RemoteRuntime):
668
665
  self._secrets = SecretsStore.from_list(self.spec.secret_sources)
669
666
  serving_spec["secret_sources"] = self._secrets.to_serial()
670
667
 
671
- env["SERVING_SPEC_ENV"] = json.dumps(serving_spec)
672
- return env
668
+ return json.dumps(serving_spec)
673
669
 
674
670
  def to_mock_server(
675
671
  self,
mlrun/runtimes/pod.py CHANGED
@@ -1057,6 +1057,32 @@ class KubeResource(BaseRuntime):
1057
1057
  return True
1058
1058
  return False
1059
1059
 
1060
+ def enrich_runtime_spec(
1061
+ self,
1062
+ project_node_selector: dict[str, str],
1063
+ ):
1064
+ """
1065
+ Enriches the runtime spec with the project-level node selector.
1066
+
1067
+ This method merges the project-level node selector with the existing function node_selector.
1068
+ The merge logic used here combines the two dictionaries, giving precedence to
1069
+ the keys in the runtime node_selector. If there are conflicting keys between the
1070
+ two dictionaries, the values from self.spec.node_selector will overwrite the
1071
+ values from project_node_selector.
1072
+
1073
+ Example:
1074
+ Suppose self.spec.node_selector = {"type": "gpu", "zone": "us-east-1"}
1075
+ and project_node_selector = {"type": "cpu", "environment": "production"}.
1076
+ After the merge, the resulting node_selector will be:
1077
+ {"type": "gpu", "zone": "us-east-1", "environment": "production"}
1078
+
1079
+ Note:
1080
+ - The merge uses the ** operator, also known as the "unpacking" operator in Python,
1081
+ combining key-value pairs from each dictionary. Later dictionaries take precedence
1082
+ when there are conflicting keys.
1083
+ """
1084
+ self.spec.node_selector = {**project_node_selector, **self.spec.node_selector}
1085
+
1060
1086
  def _set_env(self, name, value=None, value_from=None):
1061
1087
  new_var = k8s_client.V1EnvVar(name=name, value=value, value_from=value_from)
1062
1088
  i = 0
mlrun/serving/server.py CHANGED
@@ -23,6 +23,7 @@ import uuid
23
23
  from typing import Optional, Union
24
24
 
25
25
  import mlrun
26
+ import mlrun.common.constants
26
27
  import mlrun.common.helpers
27
28
  import mlrun.model_monitoring
28
29
  from mlrun.config import config
@@ -311,11 +312,8 @@ class GraphServer(ModelObj):
311
312
  def v2_serving_init(context, namespace=None):
312
313
  """hook for nuclio init_context()"""
313
314
 
314
- data = os.environ.get("SERVING_SPEC_ENV", "")
315
- if not data:
316
- raise MLRunInvalidArgumentError("failed to find spec env var")
317
- spec = json.loads(data)
318
315
  context.logger.info("Initializing server from spec")
316
+ spec = mlrun.utils.get_serving_spec()
319
317
  server = GraphServer.from_dict(spec)
320
318
  if config.log_level.lower() == "debug":
321
319
  server.verbose = True
@@ -355,7 +353,7 @@ def v2_serving_init(context, namespace=None):
355
353
 
356
354
  async def termination_callback():
357
355
  context.logger.info("Termination callback called")
358
- await server.wait_for_completion()
356
+ server.wait_for_completion()
359
357
  context.logger.info("Termination of async flow is completed")
360
358
 
361
359
  context.platform.set_termination_callback(termination_callback)
@@ -367,7 +365,7 @@ def v2_serving_init(context, namespace=None):
367
365
 
368
366
  async def drain_callback():
369
367
  context.logger.info("Drain callback called")
370
- await server.wait_for_completion()
368
+ server.wait_for_completion()
371
369
  context.logger.info(
372
370
  "Termination of async flow is completed. Rerunning async flow."
373
371
  )
mlrun/serving/states.py CHANGED
@@ -14,7 +14,6 @@
14
14
 
15
15
  __all__ = ["TaskStep", "RouterStep", "RootFlowStep", "ErrorStep"]
16
16
 
17
- import asyncio
18
17
  import os
19
18
  import pathlib
20
19
  import traceback
@@ -591,7 +590,7 @@ class RouterStep(TaskStep):
591
590
 
592
591
  kind = "router"
593
592
  default_shape = "doubleoctagon"
594
- _dict_fields = _task_step_fields + ["routes"]
593
+ _dict_fields = _task_step_fields + ["routes", "engine"]
595
594
  _default_class = "mlrun.serving.ModelRouter"
596
595
 
597
596
  def __init__(
@@ -604,6 +603,7 @@ class RouterStep(TaskStep):
604
603
  function: str = None,
605
604
  input_path: str = None,
606
605
  result_path: str = None,
606
+ engine: str = None,
607
607
  ):
608
608
  super().__init__(
609
609
  class_name,
@@ -616,6 +616,8 @@ class RouterStep(TaskStep):
616
616
  )
617
617
  self._routes: ObjectDict = None
618
618
  self.routes = routes
619
+ self.engine = engine
620
+ self._controller = None
619
621
 
620
622
  def get_children(self):
621
623
  """get child steps (routes)"""
@@ -685,6 +687,33 @@ class RouterStep(TaskStep):
685
687
  self._set_error_handler()
686
688
  self._post_init(mode)
687
689
 
690
+ if self.engine == "async":
691
+ self._build_async_flow()
692
+ self._run_async_flow()
693
+
694
+ def _build_async_flow(self):
695
+ """initialize and build the async/storey DAG"""
696
+
697
+ self.respond()
698
+ source, self._wait_for_result = _init_async_objects(self.context, [self])
699
+ source.to(self.async_object)
700
+
701
+ self._async_flow = source
702
+
703
+ def _run_async_flow(self):
704
+ self._controller = self._async_flow.run()
705
+
706
+ def run(self, event, *args, **kwargs):
707
+ if self._controller:
708
+ # async flow (using storey)
709
+ event._awaitable_result = None
710
+ resp = self._controller.emit(
711
+ event, return_awaitable_result=self._wait_for_result
712
+ )
713
+ return resp.await_result()
714
+
715
+ return super().run(event, *args, **kwargs)
716
+
688
717
  def __getitem__(self, name):
689
718
  return self._routes[name]
690
719
 
@@ -1205,18 +1234,9 @@ class FlowStep(BaseStep):
1205
1234
  """wait for completion of run in async flows"""
1206
1235
 
1207
1236
  if self._controller:
1208
- if asyncio.iscoroutinefunction(self._controller.await_termination):
1209
-
1210
- async def terminate_and_await_termination():
1211
- if hasattr(self._controller, "terminate"):
1212
- await self._controller.terminate()
1213
- return await self._controller.await_termination()
1214
-
1215
- return terminate_and_await_termination()
1216
- else:
1217
- if hasattr(self._controller, "terminate"):
1218
- self._controller.terminate()
1219
- return self._controller.await_termination()
1237
+ if hasattr(self._controller, "terminate"):
1238
+ self._controller.terminate()
1239
+ return self._controller.await_termination()
1220
1240
 
1221
1241
  def plot(self, filename=None, format=None, source=None, targets=None, **kw):
1222
1242
  """plot/save graph using graphviz