mlrun 1.8.0rc10__py3-none-any.whl → 1.8.0rc13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (40)
  1. mlrun/artifacts/document.py +32 -6
  2. mlrun/common/constants.py +1 -0
  3. mlrun/common/formatters/artifact.py +1 -1
  4. mlrun/common/schemas/__init__.py +2 -0
  5. mlrun/common/schemas/model_monitoring/__init__.py +1 -0
  6. mlrun/common/schemas/model_monitoring/constants.py +6 -0
  7. mlrun/common/schemas/model_monitoring/model_endpoints.py +35 -0
  8. mlrun/common/schemas/partition.py +23 -18
  9. mlrun/datastore/vectorstore.py +69 -26
  10. mlrun/db/base.py +14 -0
  11. mlrun/db/httpdb.py +48 -1
  12. mlrun/db/nopdb.py +13 -0
  13. mlrun/execution.py +43 -11
  14. mlrun/feature_store/steps.py +1 -1
  15. mlrun/model_monitoring/api.py +26 -19
  16. mlrun/model_monitoring/applications/_application_steps.py +1 -1
  17. mlrun/model_monitoring/applications/base.py +44 -7
  18. mlrun/model_monitoring/applications/context.py +94 -71
  19. mlrun/projects/pipelines.py +6 -3
  20. mlrun/projects/project.py +95 -17
  21. mlrun/runtimes/nuclio/function.py +2 -1
  22. mlrun/runtimes/nuclio/serving.py +33 -5
  23. mlrun/serving/__init__.py +8 -0
  24. mlrun/serving/merger.py +1 -1
  25. mlrun/serving/remote.py +17 -5
  26. mlrun/serving/routers.py +36 -87
  27. mlrun/serving/server.py +6 -2
  28. mlrun/serving/states.py +162 -13
  29. mlrun/serving/v2_serving.py +39 -82
  30. mlrun/utils/helpers.py +6 -0
  31. mlrun/utils/notifications/notification/base.py +1 -1
  32. mlrun/utils/notifications/notification/webhook.py +13 -12
  33. mlrun/utils/notifications/notification_pusher.py +18 -23
  34. mlrun/utils/version/version.json +2 -2
  35. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/METADATA +10 -10
  36. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/RECORD +40 -40
  37. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/LICENSE +0 -0
  38. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/WHEEL +0 -0
  39. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/entry_points.txt +0 -0
  40. {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc13.dist-info}/top_level.txt +0 -0
mlrun/projects/project.py CHANGED
@@ -1873,6 +1873,34 @@ class MlrunProject(ModelObj):
         vector_store: "VectorStore",  # noqa: F821
         collection_name: Optional[str] = None,
     ) -> VectorStoreCollection:
+        """
+        Create a VectorStoreCollection wrapper for a given vector store instance.
+
+        This method wraps a vector store implementation (like Milvus, Chroma) with MLRun
+        integration capabilities. The wrapper provides access to the underlying vector
+        store's functionality while adding MLRun-specific features like document and
+        artifact management.
+
+        Args:
+            vector_store: The vector store instance to wrap (e.g., Milvus, Chroma).
+                This is the underlying implementation that will handle
+                vector storage and retrieval.
+            collection_name: Optional name for the collection. If not provided,
+                will attempt to extract it from the vector_store object
+                by looking for 'collection_name', '_collection_name',
+                'index_name', or '_index_name' attributes.
+
+        Returns:
+            VectorStoreCollection: A wrapped vector store instance with MLRun integration.
+                This wrapper provides both access to the original vector
+                store's capabilities and additional MLRun functionality.
+
+        Example:
+            >>> vector_store = Chroma(embedding_function=embeddings)
+            >>> collection = project.get_vector_store_collection(
+            ...     vector_store, collection_name="my_collection"
+            ... )
+        """
         return VectorStoreCollection(
             self,
             vector_store,
@@ -1899,12 +1927,39 @@ class MlrunProject(ModelObj):
         :param local_path: path to the local file we upload, will also be use
                            as the destination subpath (under "artifact_path")
         :param artifact_path: Target path for artifact storage
-        :param document_loader_spec: Spec to use to load the artifact as langchain document
+        :param document_loader_spec: Spec to use to load the artifact as langchain document.
+
+            By default, uses DocumentLoaderSpec() which initializes with:
+
+            * loader_class_name="langchain_community.document_loaders.TextLoader"
+            * src_name="file_path"
+            * kwargs=None
+
+            Can be customized for different document types, e.g.::
+
+                DocumentLoaderSpec(
+                    loader_class_name="langchain_community.document_loaders.PDFLoader",
+                    src_name="file_path",
+                    kwargs={"extract_images": True}
+                )
         :param upload: Whether to upload the artifact
         :param labels: Key-value labels
         :param target_path: Target file path
         :param kwargs: Additional keyword arguments
         :return: DocumentArtifact object
+
+        Example:
+            >>> # Log a PDF document with custom loader
+            >>> project.log_document(
+            ...     key="my_doc",
+            ...     local_path="path/to/doc.pdf",
+            ...     document_loader=DocumentLoaderSpec(
+            ...         loader_class_name="langchain_community.document_loaders.PDFLoader",
+            ...         src_name="file_path",
+            ...         kwargs={"extract_images": True},
+            ...     ),
+            ... )
+
         """
         doc_artifact = DocumentArtifact(
             key=key,
@@ -2117,13 +2172,13 @@ class MlrunProject(ModelObj):
 
     def set_model_monitoring_function(
         self,
+        name: str,
         func: typing.Union[str, mlrun.runtimes.RemoteRuntime, None] = None,
         application_class: typing.Union[
             str, mm_app.ModelMonitoringApplicationBase, None
         ] = None,
-        name: Optional[str] = None,
         image: Optional[str] = None,
-        handler=None,
+        handler: Optional[str] = None,
         with_repo: Optional[bool] = None,
         tag: Optional[str] = None,
         requirements: Optional[typing.Union[str, list[str]]] = None,
@@ -2135,7 +2190,7 @@ class MlrunProject(ModelObj):
         Note: to deploy the function after linking it to the project,
         call `fn.deploy()` where `fn` is the object returned by this method.
 
-        examples::
+        Example::
 
             project.set_model_monitoring_function(
                 name="myApp", application_class="MyApp", image="mlrun/mlrun"
@@ -2144,8 +2199,7 @@ class MlrunProject(ModelObj):
         :param func: Remote function object or spec/code URL. :code:`None` refers to the current
                      notebook.
         :param name: Name of the function (under the project), can be specified with a tag to support
-                     versions (e.g. myfunc:v1)
-                     Default: job
+                     versions (e.g. myfunc:v1).
         :param image: Docker image to be used, can also be specified in
                       the function object/yaml
         :param handler: Default function handler to invoke (can only be set with .py/.ipynb files)
@@ -2183,12 +2237,13 @@ class MlrunProject(ModelObj):
 
     def create_model_monitoring_function(
         self,
+        name: str,
         func: Optional[str] = None,
         application_class: typing.Union[
             str,
             mm_app.ModelMonitoringApplicationBase,
+            None,
         ] = None,
-        name: Optional[str] = None,
         image: Optional[str] = None,
         handler: Optional[str] = None,
         with_repo: Optional[bool] = None,
@@ -2200,16 +2255,15 @@ class MlrunProject(ModelObj):
         """
         Create a monitoring function object without setting it to the project
 
-        examples::
+        Example::
 
             project.create_model_monitoring_function(
-                application_class_name="MyApp", image="mlrun/mlrun", name="myApp"
+                name="myApp", application_class="MyApp", image="mlrun/mlrun"
             )
 
         :param func: The function's code URL. :code:`None` refers to the current notebook.
         :param name: Name of the function, can be specified with a tag to support
-                     versions (e.g. myfunc:v1)
-                     Default: job
+                     versions (e.g. myfunc:v1).
         :param image: Docker image to be used, can also be specified in
                       the function object/yaml
         :param handler: Default function handler to invoke (can only be set with .py/.ipynb files)
@@ -2587,6 +2641,24 @@ class MlrunProject(ModelObj):
         self._set_function(resolved_function_name, tag, function_object, func)
         return function_object
 
+    def push_run_notifications(
+        self,
+        uid,
+        timeout=45,
+    ):
+        """
+        Push notifications for a run.
+
+        :param uid: Unique ID of the run.
+        :returns: :py:class:`~mlrun.common.schemas.BackgroundTask`.
+        """
+        db = mlrun.db.get_run_db(secrets=self._secrets)
+        return db.push_run_notifications(
+            project=self.name,
+            uid=uid,
+            timeout=timeout,
+        )
+
     def _instantiate_function(
         self,
         func: typing.Union[str, mlrun.runtimes.BaseRuntime] = None,
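
A minimal usage sketch for the new `push_run_notifications` method (the project name and run UID are placeholders, and the field access on the returned object assumes the `mlrun.common.schemas.BackgroundTask` model named in the docstring)::

    import mlrun

    project = mlrun.get_or_create_project("my-project")
    # push pending notifications for a finished run; returns a BackgroundTask
    task = project.push_run_notifications(uid="48f1234abcd")
    print(task.status.state)
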
@@ -3240,6 +3312,7 @@ class MlrunProject(ModelObj):
         cleanup_ttl: Optional[int] = None,
         notifications: Optional[list[mlrun.model.Notification]] = None,
         workflow_runner_node_selector: typing.Optional[dict[str, str]] = None,
+        context: typing.Optional[mlrun.execution.MLClientCtx] = None,
     ) -> _PipelineRunStatus:
         """Run a workflow using kubeflow pipelines
 
@@ -3282,6 +3355,7 @@ class MlrunProject(ModelObj):
             This allows you to control and specify where the workflow runner pod will be scheduled.
             This setting is only relevant when the engine is set to 'remote' or for scheduled workflows,
             and it will be ignored if the workflow is not run on a remote engine.
+        :param context: mlrun context.
         :returns: ~py:class:`~mlrun.projects.pipelines._PipelineRunStatus` instance
         """
 
@@ -3368,6 +3442,7 @@ class MlrunProject(ModelObj):
             namespace=namespace,
             source=source,
             notifications=notifications,
+            context=context,
         )
         # run is None when scheduling
         if run and run.state == mlrun_pipelines.common.models.RunStatuses.failed:
@@ -3551,6 +3626,7 @@ class MlrunProject(ModelObj):
         self,
         name: Optional[str] = None,
         model_name: Optional[str] = None,
+        model_tag: Optional[str] = None,
         function_name: Optional[str] = None,
         function_tag: Optional[str] = None,
         labels: Optional[list[str]] = None,
@@ -3565,12 +3641,13 @@ class MlrunProject(ModelObj):
         model endpoint. This functions supports filtering by the following parameters:
         1) name
         2) model_name
-        3) function_name
-        4) function_tag
-        5) labels
-        6) top level
-        7) uids
-        8) start and end time, corresponding to the `created` field.
+        3) model_tag
+        4) function_name
+        5) function_tag
+        6) labels
+        7) top level
+        8) uids
+        9) start and end time, corresponding to the `created` field.
         By default, when no filters are applied, all available endpoints for the given project will be listed.
 
         In addition, this functions provides a facade for listing endpoint related metrics. This facade is time-based
@@ -3599,6 +3676,7 @@ class MlrunProject(ModelObj):
             project=self.name,
             name=name,
             model_name=model_name,
+            model_tag=model_tag,
             function_name=function_name,
             function_tag=function_tag,
             labels=labels,
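
A short sketch of the new `model_tag` filter on `list_model_endpoints`, assuming an existing `project` object (model name and tag are placeholders)::

    endpoints = project.list_model_endpoints(
        model_name="churn-model",
        model_tag="v2",  # new in this release: filter endpoints by the model's tag
    )
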
mlrun/runtimes/nuclio/function.py CHANGED
@@ -1036,9 +1036,10 @@ class RemoteRuntime(KubeResource):
         if args and sidecar.get("command"):
             sidecar["args"] = mlrun.utils.helpers.as_list(args)
 
-        # populate the sidecar resources from the function spec
+        # put the configured resources on the sidecar container instead of the reverse proxy container
        if self.spec.resources:
             sidecar["resources"] = self.spec.resources
+            self.spec.resources = None
 
     def _set_sidecar(self, name: str) -> dict:
         self.spec.config.setdefault("spec.sidecars", [])
mlrun/runtimes/nuclio/serving.py CHANGED
@@ -22,7 +22,7 @@ import nuclio
 from nuclio import KafkaTrigger
 
 import mlrun
-import mlrun.common.schemas
+import mlrun.common.schemas as schemas
 from mlrun.datastore import get_kafka_brokers_from_dict, parse_kafka_url
 from mlrun.model import ObjectList
 from mlrun.runtimes.function_reference import FunctionReference
@@ -362,6 +362,9 @@ class ServingRuntime(RemoteRuntime):
         handler: Optional[str] = None,
         router_step: Optional[str] = None,
         child_function: Optional[str] = None,
+        creation_strategy: Optional[
+            schemas.ModelEndpointCreationStrategy
+        ] = schemas.ModelEndpointCreationStrategy.INPLACE,
         **class_args,
     ):
         """add ml model and/or route to the function.
@@ -384,6 +387,16 @@ class ServingRuntime(RemoteRuntime):
         :param router_step: router step name (to determine which router we add the model to in graphs
                             with multiple router steps)
         :param child_function: child function name, when the model runs in a child function
+        :param creation_strategy: Strategy for creating or updating the model endpoint:
+            * **overwrite**:
+                1. If model endpoints with the same name exist, delete the `latest` one.
+                2. Create a new model endpoint entry and set it as `latest`.
+            * **inplace** (default):
+                1. If model endpoints with the same name exist, update the `latest` entry.
+                2. Otherwise, create a new entry.
+            * **archive**:
+                1. If model endpoints with the same name exist, preserve them.
+                2. Create a new model endpoint with the same name and set it to `latest`.
         :param class_args: extra kwargs to pass to the model serving class __init__
                            (can be read in the model using .get_param(key) method)
         """
@@ -419,7 +432,12 @@ class ServingRuntime(RemoteRuntime):
         if class_name and hasattr(class_name, "to_dict"):
             if model_path:
                 class_name.model_path = model_path
-            key, state = params_to_step(class_name, key)
+            key, state = params_to_step(
+                class_name,
+                key,
+                model_endpoint_creation_strategy=creation_strategy,
+                endpoint_type=schemas.EndpointType.LEAF_EP,
+            )
         else:
             class_name = class_name or self.spec.default_class
             if class_name and not isinstance(class_name, str):
@@ -432,12 +450,22 @@ class ServingRuntime(RemoteRuntime):
                 model_path = str(model_path)
 
             if model_url:
-                state = new_remote_endpoint(model_url, **class_args)
+                state = new_remote_endpoint(
+                    model_url,
+                    creation_strategy=creation_strategy,
+                    endpoint_type=schemas.EndpointType.LEAF_EP,
+                    **class_args,
+                )
             else:
                 class_args = deepcopy(class_args)
                 class_args["model_path"] = model_path
                 state = TaskStep(
-                    class_name, class_args, handler=handler, function=child_function
+                    class_name,
+                    class_args,
+                    handler=handler,
+                    function=child_function,
+                    model_endpoint_creation_strategy=creation_strategy,
+                    endpoint_type=schemas.EndpointType.LEAF_EP,
                 )
 
         return graph.add_route(key, state)
@@ -581,7 +609,7 @@ class ServingRuntime(RemoteRuntime):
         project="",
         tag="",
         verbose=False,
-        auth_info: mlrun.common.schemas.AuthInfo = None,
+        auth_info: schemas.AuthInfo = None,
         builder_env: Optional[dict] = None,
         force_build: bool = False,
     ):
mlrun/serving/__init__.py CHANGED
@@ -23,6 +23,10 @@ __all__ = [
     "QueueStep",
     "ErrorStep",
     "MonitoringApplicationStep",
+    "ModelRunnerStep",
+    "ModelRunner",
+    "Model",
+    "ModelSelector",
 ]
 
 from .routers import ModelRouter, VotingEnsemble  # noqa
@@ -33,6 +37,10 @@ from .states import (
     RouterStep,
     TaskStep,
     MonitoringApplicationStep,
+    ModelRunnerStep,
+    ModelRunner,
+    Model,
+    ModelSelector,
 )  # noqa
 from .v1_serving import MLModelServer, new_v1_model_server  # noqa
 from .v2_serving import V2ModelServer  # noqa
mlrun/serving/merger.py CHANGED
@@ -74,7 +74,7 @@ class Merge(storey.Flow):
         self._queue_len = max_behind or 64  # default queue is 64 entries
         self._keys_queue = []
 
-    def post_init(self, mode="sync"):
+    def post_init(self, mode="sync", **kwargs):
         # auto detect number of uplinks or use user specified value
         self._uplinks = self.expected_num_events or (
             len(self._graph_step.after) if self._graph_step else 0
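
The `post_init(self, mode="sync", **kwargs)` signature change above recurs across the serving steps in this diff; custom steps that override `post_init` should accept `**kwargs` too, since the framework may now pass extra keywords such as `creation_strategy` and `endpoint_type`. A minimal sketch of a forward-compatible override::

    class MyStep:
        def post_init(self, mode="sync", **kwargs):
            # consume the new keywords when present, ignore them otherwise
            self.creation_strategy = kwargs.get("creation_strategy")
            self.endpoint_type = kwargs.get("endpoint_type")
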
mlrun/serving/remote.py CHANGED
@@ -14,6 +14,7 @@
 #
 import asyncio
 import json
+from copy import copy
 from typing import Optional
 
 import aiohttp
@@ -53,6 +54,7 @@ class RemoteStep(storey.SendToHttp):
         retries=None,
         backoff_factor=None,
         timeout=None,
+        headers_expression: Optional[str] = None,
         **kwargs,
     ):
         """class for calling remote endpoints
  """class for calling remote endpoints
@@ -86,6 +88,7 @@ class RemoteStep(storey.SendToHttp):
86
88
  :param retries: number of retries (in exponential backoff)
87
89
  :param backoff_factor: A backoff factor in seconds to apply between attempts after the second try
88
90
  :param timeout: How long to wait for the server to send data before giving up, float in seconds
91
+ :param headers_expression: an expression for getting the request headers from the event, e.g. "event['headers']"
89
92
  """
90
93
  # init retry args for storey
91
94
  retries = default_retries if retries is None else retries
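
A sketch of the new `headers_expression` parameter (step name, URL, and header key are placeholders); as shown further down in this diff, the expression is compiled into `lambda event: <expression>` at `post_init` time and evaluated against each event body::

    from mlrun.serving.remote import RemoteStep

    step = RemoteStep(
        name="enrich",
        url="https://example.com/api",
        # build per-request headers from the event body
        headers_expression="{'x-request-id': event['request_id']}",
    )
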
@@ -102,6 +105,7 @@ class RemoteStep(storey.SendToHttp):
         self.url = url
         self.url_expression = url_expression
         self.body_expression = body_expression
+        self.headers_expression = headers_expression
         self.headers = headers
         self.method = method
         self.return_json = return_json
@@ -114,8 +118,9 @@ class RemoteStep(storey.SendToHttp):
         self._session = None
         self._url_function_handler = None
         self._body_function_handler = None
+        self._headers_function_handler = None
 
-    def post_init(self, mode="sync"):
+    def post_init(self, mode="sync", **kwargs):
         self._endpoint = self.url
         if self.url and self.context:
             self._endpoint = self.context.get_remote_endpoint(self.url).strip("/")
@@ -131,6 +136,12 @@ class RemoteStep(storey.SendToHttp):
                 {"endpoint": self._endpoint, "context": self.context},
                 {},
             )
+            if self.headers_expression:
+                self._headers_function_handler = eval(
+                    "lambda event: " + self.headers_expression,
+                    {"context": self.context},
+                    {},
+                )
         elif self.subpath:
             self._append_event_path = self.subpath == "$path"
             if not self._append_event_path:
@@ -205,7 +216,10 @@ class RemoteStep(storey.SendToHttp):
 
     def _generate_request(self, event, body):
         method = self.method or event.method or "POST"
-        headers = self.headers or {}
+        if self._headers_function_handler:
+            headers = self._headers_function_handler(body)
+        else:
+            headers = copy(self.headers) or {}
 
         if self._url_function_handler:
             url = self._url_function_handler(body)
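
The switch from `self.headers` to `copy(self.headers)` matters because the code below mutates `headers` in place (adding the event id and path keys); without the copy, those per-event keys would leak into the shared `self.headers` dict and be resent on every later request. A standalone illustration of the aliasing the copy avoids::

    from copy import copy

    shared = {"authorization": "Bearer token"}

    aliased = shared                    # old behavior: same dict object
    aliased["x-mlrun-event-id"] = "e1"  # mutates `shared` for all later events
    assert "x-mlrun-event-id" in shared

    shared = {"authorization": "Bearer token"}
    copied = copy(shared)               # new behavior: shallow copy per request
    copied["x-mlrun-event-id"] = "e2"
    assert "x-mlrun-event-id" not in shared
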
@@ -216,10 +230,8 @@ class RemoteStep(storey.SendToHttp):
                 url = url + "/" + striped_path
                 if striped_path:
                     headers[event_path_key] = event.path
-
         if event.id:
             headers[event_id_key] = event.id
-
         if method == "GET":
             body = None
         elif body is not None and not isinstance(body, (str, bytes)):
@@ -334,7 +346,7 @@ class BatchHttpRequests(_ConcurrentJobExecution):
     async def _cleanup(self):
         await self._client_session.close()
 
-    def post_init(self, mode="sync"):
+    def post_init(self, mode="sync", **kwargs):
         self._endpoint = self.url
         if self.url and self.context:
             self._endpoint = self.context.get_remote_endpoint(self.url).strip("/")
mlrun/serving/routers.py CHANGED
@@ -30,7 +30,6 @@ import mlrun.common.model_monitoring
 import mlrun.common.schemas.model_monitoring
 from mlrun.utils import logger, now_date
 
-from ..common.schemas.model_monitoring import ModelEndpointSchema
 from .server import GraphServer
 from .utils import RouterToDict, _extract_input_data, _update_result_body
 from .v2_serving import _ModelLogPusher
@@ -110,7 +109,7 @@ class BaseModelRouter(RouterToDict):
 
         return parsed_event
 
-    def post_init(self, mode="sync"):
+    def post_init(self, mode="sync", **kwargs):
         self.context.logger.info(f"Loaded {list(self.routes.keys())}")
 
     def get_metadata(self):
@@ -610,7 +609,7 @@ class VotingEnsemble(ParallelRun):
         self.model_endpoint_uid = None
         self.shard_by_endpoint = shard_by_endpoint
 
-    def post_init(self, mode="sync"):
+    def post_init(self, mode="sync", **kwargs):
         server = getattr(self.context, "_server", None) or getattr(
             self.context, "server", None
         )
@@ -619,7 +618,12 @@ class VotingEnsemble(ParallelRun):
             return
 
         if not self.context.is_mock or self.context.monitoring_mock:
-            self.model_endpoint_uid = _init_endpoint_record(server, self)
+            self.model_endpoint_uid = _init_endpoint_record(
+                server,
+                self,
+                creation_strategy=kwargs.get("creation_strategy"),
+                endpoint_type=kwargs.get("endpoint_type"),
+            )
 
         self._update_weights(self.weights)
@@ -1001,7 +1005,10 @@ class VotingEnsemble(ParallelRun):
 
 
 def _init_endpoint_record(
-    graph_server: GraphServer, voting_ensemble: VotingEnsemble
+    graph_server: GraphServer,
+    voting_ensemble: VotingEnsemble,
+    creation_strategy: mlrun.common.schemas.ModelEndpointCreationStrategy,
+    endpoint_type: mlrun.common.schemas.EndpointType,
 ) -> Union[str, None]:
     """
     Initialize model endpoint record and write it into the DB. In general, this method retrieve the unique model
@@ -1011,61 +1018,50 @@ def _init_endpoint_record(
     :param graph_server: A GraphServer object which will be used for getting the function uri.
     :param voting_ensemble: Voting ensemble serving class. It contains important details for the model endpoint record
                             such as model name, model path, model version, and the ids of the children model endpoints.
-
+    :param creation_strategy: Strategy for creating or updating the model endpoint:
+        * **overwrite**:
+            1. If model endpoints with the same name exist, delete the `latest` one.
+            2. Create a new model endpoint entry and set it as `latest`.
+        * **inplace** (default):
+            1. If model endpoints with the same name exist, update the `latest` entry.
+            2. Otherwise, create a new entry.
+        * **archive**:
+            1. If model endpoints with the same name exist, preserve them.
+            2. Create a new model endpoint with the same name and set it to `latest`.
+
+    :param endpoint_type: model endpoint type
    :return: Model endpoint unique ID.
     """
 
     logger.info("Initializing endpoint records")
-    try:
-        model_endpoint = mlrun.get_run_db().get_model_endpoint(
-            project=graph_server.project,
-            name=voting_ensemble.name,
-            function_name=graph_server.function_name,
-            function_tag=graph_server.function_tag or "latest",
-        )
-    except mlrun.errors.MLRunNotFoundError:
-        model_endpoint = None
-    except mlrun.errors.MLRunBadRequestError as err:
-        logger.info(
-            "Cannot get the model endpoints store", err=mlrun.errors.err_to_str(err)
-        )
-        return
-
-    function = mlrun.get_run_db().get_function(
-        name=graph_server.function_name,
-        project=graph_server.project,
-        tag=graph_server.function_tag or "latest",
-    )
-    function_uid = function.get("metadata", {}).get("uid")
-    # Get the children model endpoints ids
     children_uids = []
     children_names = []
     for _, c in voting_ensemble.routes.items():
         if hasattr(c, "endpoint_uid"):
             children_uids.append(c.endpoint_uid)
             children_names.append(c.name)
-    if not model_endpoint and voting_ensemble.context.server.track_models:
+    try:
         logger.info(
-            "Creating a new model endpoint record",
+            "Creating Or Updating a new model endpoint record",
             name=voting_ensemble.name,
             project=graph_server.project,
            function_name=graph_server.function_name,
             function_tag=graph_server.function_tag or "latest",
-            function_uid=function_uid,
             model_class=voting_ensemble.__class__.__name__,
+            creation_strategy=creation_strategy,
         )
         model_endpoint = mlrun.common.schemas.ModelEndpoint(
             metadata=mlrun.common.schemas.ModelEndpointMetadata(
                 project=graph_server.project,
                 name=voting_ensemble.name,
-                endpoint_type=mlrun.common.schemas.model_monitoring.EndpointType.ROUTER,
+                endpoint_type=endpoint_type,
             ),
             spec=mlrun.common.schemas.ModelEndpointSpec(
                 function_name=graph_server.function_name,
-                function_uid=function_uid,
                 function_tag=graph_server.function_tag or "latest",
                 model_class=voting_ensemble.__class__.__name__,
-                children_uids=list(voting_ensemble.routes.keys()),
+                children_uids=children_uids,
+                children=children_names,
             ),
             status=mlrun.common.schemas.ModelEndpointStatus(
                 monitoring_mode=mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
@@ -1074,59 +1070,12 @@ def _init_endpoint_record(
             ),
         )
         db = mlrun.get_run_db()
-        db.create_model_endpoint(model_endpoint=model_endpoint)
-
-    elif model_endpoint:
-        attributes = {}
-        if function_uid != model_endpoint.spec.function_uid:
-            attributes[ModelEndpointSchema.FUNCTION_UID] = function_uid
-        if children_uids != model_endpoint.spec.children_uids:
-            attributes[ModelEndpointSchema.CHILDREN_UIDS] = children_uids
-        if (
-            model_endpoint.status.monitoring_mode
-            == mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
-        ) != voting_ensemble.context.server.track_models:
-            attributes[ModelEndpointSchema.MONITORING_MODE] = (
-                mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
-                if voting_ensemble.context.server.track_models
-                else mlrun.common.schemas.model_monitoring.ModelMonitoringMode.disabled
-            )
-        if attributes:
-            db = mlrun.get_run_db()
-            logger.info(
-                "Updating model endpoint attributes",
-                attributes=attributes,
-                project=model_endpoint.metadata.project,
-                name=model_endpoint.metadata.name,
-                function_name=model_endpoint.spec.function_name,
-            )
-            model_endpoint = db.patch_model_endpoint(
-                project=model_endpoint.metadata.project,
-                name=model_endpoint.metadata.name,
-                endpoint_id=model_endpoint.metadata.uid,
-                attributes=attributes,
-            )
-    else:
-        logger.info(
-            "Did not create a new model endpoint record, monitoring is disabled"
+        db.create_model_endpoint(
+            model_endpoint=model_endpoint, creation_strategy=creation_strategy
         )
+    except mlrun.errors.MLRunInvalidArgumentError as e:
+        logger.info("Failed to create model endpoint record", error=e)
         return None
-
-    # Update model endpoint children type
-    logger.info(
-        "Updating children model endpoint type",
-        children_uids=children_uids,
-        children_names=children_names,
-    )
-    for uid, name in zip(children_uids, children_names):
-        mlrun.get_run_db().patch_model_endpoint(
-            name=name,
-            project=graph_server.project,
-            endpoint_id=uid,
-            attributes={
-                ModelEndpointSchema.ENDPOINT_TYPE: mlrun.common.schemas.model_monitoring.EndpointType.LEAF_EP
-            },
-        )
     return model_endpoint.metadata.uid
 
 
@@ -1192,7 +1141,7 @@ class EnrichmentModelRouter(ModelRouter):
 
         self._feature_service = None
 
-    def post_init(self, mode="sync"):
+    def post_init(self, mode="sync", **kwargs):
         from ..feature_store import get_feature_vector
 
         super().post_init(mode)
@@ -1342,7 +1291,7 @@ class EnrichmentVotingEnsemble(VotingEnsemble):
 
         self._feature_service = None
 
-    def post_init(self, mode="sync"):
+    def post_init(self, mode="sync", **kwargs):
         from ..feature_store import get_feature_vector
 
         super().post_init(mode)
mlrun/serving/server.py CHANGED
@@ -367,7 +367,9 @@ def _set_callbacks(server, context):
 
     async def termination_callback():
         context.logger.info("Termination callback called")
-        server.wait_for_completion()
+        maybe_coroutine = server.wait_for_completion()
+        if asyncio.iscoroutine(maybe_coroutine):
+            await maybe_coroutine
         context.logger.info("Termination of async flow is completed")
 
     context.platform.set_termination_callback(termination_callback)
@@ -379,7 +381,9 @@ def _set_callbacks(server, context):
 
     async def drain_callback():
         context.logger.info("Drain callback called")
-        server.wait_for_completion()
+        maybe_coroutine = server.wait_for_completion()
+        if asyncio.iscoroutine(maybe_coroutine):
+            await maybe_coroutine
         context.logger.info(
             "Termination of async flow is completed. Rerunning async flow."
         )
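
Both callback changes above follow the same pattern: `wait_for_completion()` may return a plain value (sync flow) or a coroutine (async flow), and the result is awaited only in the latter case. A self-contained sketch of that pattern::

    import asyncio

    def sync_wait():
        return "done"

    async def async_wait():
        return "done"

    async def call_either(fn):
        maybe_coroutine = fn()
        if asyncio.iscoroutine(maybe_coroutine):
            maybe_coroutine = await maybe_coroutine
        return maybe_coroutine

    print(asyncio.run(call_either(sync_wait)))   # works for a plain function
    print(asyncio.run(call_either(async_wait)))  # and for a coroutine function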