mlrun 1.8.0rc9__py3-none-any.whl → 1.8.0rc12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. See the package registry page for more details.

Files changed (37)
  1. mlrun/artifacts/__init__.py +1 -1
  2. mlrun/artifacts/document.py +53 -11
  3. mlrun/common/constants.py +1 -0
  4. mlrun/common/schemas/__init__.py +2 -0
  5. mlrun/common/schemas/model_monitoring/__init__.py +1 -0
  6. mlrun/common/schemas/model_monitoring/constants.py +7 -0
  7. mlrun/common/schemas/model_monitoring/model_endpoints.py +36 -0
  8. mlrun/config.py +1 -0
  9. mlrun/data_types/data_types.py +1 -0
  10. mlrun/data_types/spark.py +3 -2
  11. mlrun/data_types/to_pandas.py +11 -2
  12. mlrun/datastore/__init__.py +2 -0
  13. mlrun/datastore/targets.py +2 -1
  14. mlrun/datastore/vectorstore.py +21 -15
  15. mlrun/db/base.py +2 -0
  16. mlrun/db/httpdb.py +12 -0
  17. mlrun/db/nopdb.py +2 -0
  18. mlrun/feature_store/steps.py +1 -1
  19. mlrun/model_monitoring/api.py +30 -21
  20. mlrun/model_monitoring/applications/base.py +42 -4
  21. mlrun/projects/project.py +18 -16
  22. mlrun/runtimes/nuclio/serving.py +28 -5
  23. mlrun/serving/__init__.py +8 -0
  24. mlrun/serving/merger.py +1 -1
  25. mlrun/serving/remote.py +17 -5
  26. mlrun/serving/routers.py +27 -87
  27. mlrun/serving/server.py +6 -2
  28. mlrun/serving/states.py +154 -13
  29. mlrun/serving/v2_serving.py +38 -79
  30. mlrun/utils/helpers.py +6 -0
  31. mlrun/utils/version/version.json +2 -2
  32. {mlrun-1.8.0rc9.dist-info → mlrun-1.8.0rc12.dist-info}/METADATA +10 -10
  33. {mlrun-1.8.0rc9.dist-info → mlrun-1.8.0rc12.dist-info}/RECORD +37 -37
  34. {mlrun-1.8.0rc9.dist-info → mlrun-1.8.0rc12.dist-info}/LICENSE +0 -0
  35. {mlrun-1.8.0rc9.dist-info → mlrun-1.8.0rc12.dist-info}/WHEEL +0 -0
  36. {mlrun-1.8.0rc9.dist-info → mlrun-1.8.0rc12.dist-info}/entry_points.txt +0 -0
  37. {mlrun-1.8.0rc9.dist-info → mlrun-1.8.0rc12.dist-info}/top_level.txt +0 -0
mlrun/projects/project.py CHANGED
@@ -2117,13 +2117,13 @@ class MlrunProject(ModelObj):
2117
2117
 
2118
2118
  def set_model_monitoring_function(
2119
2119
  self,
2120
+ name: str,
2120
2121
  func: typing.Union[str, mlrun.runtimes.RemoteRuntime, None] = None,
2121
2122
  application_class: typing.Union[
2122
2123
  str, mm_app.ModelMonitoringApplicationBase, None
2123
2124
  ] = None,
2124
- name: Optional[str] = None,
2125
2125
  image: Optional[str] = None,
2126
- handler=None,
2126
+ handler: Optional[str] = None,
2127
2127
  with_repo: Optional[bool] = None,
2128
2128
  tag: Optional[str] = None,
2129
2129
  requirements: Optional[typing.Union[str, list[str]]] = None,
@@ -2135,7 +2135,7 @@ class MlrunProject(ModelObj):
2135
2135
  Note: to deploy the function after linking it to the project,
2136
2136
  call `fn.deploy()` where `fn` is the object returned by this method.
2137
2137
 
2138
- examples::
2138
+ Example::
2139
2139
 
2140
2140
  project.set_model_monitoring_function(
2141
2141
  name="myApp", application_class="MyApp", image="mlrun/mlrun"
@@ -2144,8 +2144,7 @@ class MlrunProject(ModelObj):
2144
2144
  :param func: Remote function object or spec/code URL. :code:`None` refers to the current
2145
2145
  notebook.
2146
2146
  :param name: Name of the function (under the project), can be specified with a tag to support
2147
- versions (e.g. myfunc:v1)
2148
- Default: job
2147
+ versions (e.g. myfunc:v1).
2149
2148
  :param image: Docker image to be used, can also be specified in
2150
2149
  the function object/yaml
2151
2150
  :param handler: Default function handler to invoke (can only be set with .py/.ipynb files)
@@ -2183,12 +2182,13 @@ class MlrunProject(ModelObj):
2183
2182
 
2184
2183
  def create_model_monitoring_function(
2185
2184
  self,
2185
+ name: str,
2186
2186
  func: Optional[str] = None,
2187
2187
  application_class: typing.Union[
2188
2188
  str,
2189
2189
  mm_app.ModelMonitoringApplicationBase,
2190
+ None,
2190
2191
  ] = None,
2191
- name: Optional[str] = None,
2192
2192
  image: Optional[str] = None,
2193
2193
  handler: Optional[str] = None,
2194
2194
  with_repo: Optional[bool] = None,
@@ -2200,16 +2200,15 @@ class MlrunProject(ModelObj):
2200
2200
  """
2201
2201
  Create a monitoring function object without setting it to the project
2202
2202
 
2203
- examples::
2203
+ Example::
2204
2204
 
2205
2205
  project.create_model_monitoring_function(
2206
- application_class_name="MyApp", image="mlrun/mlrun", name="myApp"
2206
+ name="myApp", application_class="MyApp", image="mlrun/mlrun"
2207
2207
  )
2208
2208
 
2209
2209
  :param func: The function's code URL. :code:`None` refers to the current notebook.
2210
2210
  :param name: Name of the function, can be specified with a tag to support
2211
- versions (e.g. myfunc:v1)
2212
- Default: job
2211
+ versions (e.g. myfunc:v1).
2213
2212
  :param image: Docker image to be used, can also be specified in
2214
2213
  the function object/yaml
2215
2214
  :param handler: Default function handler to invoke (can only be set with .py/.ipynb files)
@@ -3551,6 +3550,7 @@ class MlrunProject(ModelObj):
3551
3550
  self,
3552
3551
  name: Optional[str] = None,
3553
3552
  model_name: Optional[str] = None,
3553
+ model_tag: Optional[str] = None,
3554
3554
  function_name: Optional[str] = None,
3555
3555
  function_tag: Optional[str] = None,
3556
3556
  labels: Optional[list[str]] = None,
@@ -3565,12 +3565,13 @@ class MlrunProject(ModelObj):
3565
3565
  model endpoint. This functions supports filtering by the following parameters:
3566
3566
  1) name
3567
3567
  2) model_name
3568
- 3) function_name
3569
- 4) function_tag
3570
- 5) labels
3571
- 6) top level
3572
- 7) uids
3573
- 8) start and end time, corresponding to the `created` field.
3568
+ 3) model_tag
3569
+ 4) function_name
3570
+ 5) function_tag
3571
+ 6) labels
3572
+ 7) top level
3573
+ 8) uids
3574
+ 9) start and end time, corresponding to the `created` field.
3574
3575
  By default, when no filters are applied, all available endpoints for the given project will be listed.
3575
3576
 
3576
3577
  In addition, this functions provides a facade for listing endpoint related metrics. This facade is time-based
@@ -3599,6 +3600,7 @@ class MlrunProject(ModelObj):
3599
3600
  project=self.name,
3600
3601
  name=name,
3601
3602
  model_name=model_name,
3603
+ model_tag=model_tag,
3602
3604
  function_name=function_name,
3603
3605
  function_tag=function_tag,
3604
3606
  labels=labels,
@@ -22,7 +22,7 @@ import nuclio
22
22
  from nuclio import KafkaTrigger
23
23
 
24
24
  import mlrun
25
- import mlrun.common.schemas
25
+ import mlrun.common.schemas as schemas
26
26
  from mlrun.datastore import get_kafka_brokers_from_dict, parse_kafka_url
27
27
  from mlrun.model import ObjectList
28
28
  from mlrun.runtimes.function_reference import FunctionReference
@@ -362,6 +362,9 @@ class ServingRuntime(RemoteRuntime):
362
362
  handler: Optional[str] = None,
363
363
  router_step: Optional[str] = None,
364
364
  child_function: Optional[str] = None,
365
+ creation_strategy: Optional[
366
+ schemas.ModelEndpointCreationStrategy
367
+ ] = schemas.ModelEndpointCreationStrategy.INPLACE,
365
368
  **class_args,
366
369
  ):
367
370
  """add ml model and/or route to the function.
@@ -384,6 +387,11 @@ class ServingRuntime(RemoteRuntime):
384
387
  :param router_step: router step name (to determine which router we add the model to in graphs
385
388
  with multiple router steps)
386
389
  :param child_function: child function name, when the model runs in a child function
390
+ :param creation_strategy: model endpoint creation strategy :
391
+ * overwrite - Create a new model endpoint and delete the last old one if it exists.
392
+ * inplace - Use the existing model endpoint if it already exists (default).
393
+ * archive - Preserve the old model endpoint and create a new one,
394
+ tagging it as the latest.
387
395
  :param class_args: extra kwargs to pass to the model serving class __init__
388
396
  (can be read in the model using .get_param(key) method)
389
397
  """
@@ -419,7 +427,12 @@ class ServingRuntime(RemoteRuntime):
419
427
  if class_name and hasattr(class_name, "to_dict"):
420
428
  if model_path:
421
429
  class_name.model_path = model_path
422
- key, state = params_to_step(class_name, key)
430
+ key, state = params_to_step(
431
+ class_name,
432
+ key,
433
+ model_endpoint_creation_strategy=creation_strategy,
434
+ endpoint_type=schemas.EndpointType.LEAF_EP,
435
+ )
423
436
  else:
424
437
  class_name = class_name or self.spec.default_class
425
438
  if class_name and not isinstance(class_name, str):
@@ -432,12 +445,22 @@ class ServingRuntime(RemoteRuntime):
432
445
  model_path = str(model_path)
433
446
 
434
447
  if model_url:
435
- state = new_remote_endpoint(model_url, **class_args)
448
+ state = new_remote_endpoint(
449
+ model_url,
450
+ creation_strategy=creation_strategy,
451
+ endpoint_type=schemas.EndpointType.LEAF_EP,
452
+ **class_args,
453
+ )
436
454
  else:
437
455
  class_args = deepcopy(class_args)
438
456
  class_args["model_path"] = model_path
439
457
  state = TaskStep(
440
- class_name, class_args, handler=handler, function=child_function
458
+ class_name,
459
+ class_args,
460
+ handler=handler,
461
+ function=child_function,
462
+ model_endpoint_creation_strategy=creation_strategy,
463
+ endpoint_type=schemas.EndpointType.LEAF_EP,
441
464
  )
442
465
 
443
466
  return graph.add_route(key, state)
@@ -581,7 +604,7 @@ class ServingRuntime(RemoteRuntime):
581
604
  project="",
582
605
  tag="",
583
606
  verbose=False,
584
- auth_info: mlrun.common.schemas.AuthInfo = None,
607
+ auth_info: schemas.AuthInfo = None,
585
608
  builder_env: Optional[dict] = None,
586
609
  force_build: bool = False,
587
610
  ):
mlrun/serving/__init__.py CHANGED
@@ -23,6 +23,10 @@ __all__ = [
23
23
  "QueueStep",
24
24
  "ErrorStep",
25
25
  "MonitoringApplicationStep",
26
+ "ModelRunnerStep",
27
+ "ModelRunner",
28
+ "Model",
29
+ "ModelSelector",
26
30
  ]
27
31
 
28
32
  from .routers import ModelRouter, VotingEnsemble # noqa
@@ -33,6 +37,10 @@ from .states import (
33
37
  RouterStep,
34
38
  TaskStep,
35
39
  MonitoringApplicationStep,
40
+ ModelRunnerStep,
41
+ ModelRunner,
42
+ Model,
43
+ ModelSelector,
36
44
  ) # noqa
37
45
  from .v1_serving import MLModelServer, new_v1_model_server # noqa
38
46
  from .v2_serving import V2ModelServer # noqa
mlrun/serving/merger.py CHANGED
@@ -74,7 +74,7 @@ class Merge(storey.Flow):
74
74
  self._queue_len = max_behind or 64 # default queue is 64 entries
75
75
  self._keys_queue = []
76
76
 
77
- def post_init(self, mode="sync"):
77
+ def post_init(self, mode="sync", **kwargs):
78
78
  # auto detect number of uplinks or use user specified value
79
79
  self._uplinks = self.expected_num_events or (
80
80
  len(self._graph_step.after) if self._graph_step else 0
mlrun/serving/remote.py CHANGED
@@ -14,6 +14,7 @@
14
14
  #
15
15
  import asyncio
16
16
  import json
17
+ from copy import copy
17
18
  from typing import Optional
18
19
 
19
20
  import aiohttp
@@ -53,6 +54,7 @@ class RemoteStep(storey.SendToHttp):
53
54
  retries=None,
54
55
  backoff_factor=None,
55
56
  timeout=None,
57
+ headers_expression: Optional[str] = None,
56
58
  **kwargs,
57
59
  ):
58
60
  """class for calling remote endpoints
@@ -86,6 +88,7 @@ class RemoteStep(storey.SendToHttp):
86
88
  :param retries: number of retries (in exponential backoff)
87
89
  :param backoff_factor: A backoff factor in seconds to apply between attempts after the second try
88
90
  :param timeout: How long to wait for the server to send data before giving up, float in seconds
91
+ :param headers_expression: an expression for getting the request headers from the event, e.g. "event['headers']"
89
92
  """
90
93
  # init retry args for storey
91
94
  retries = default_retries if retries is None else retries
@@ -102,6 +105,7 @@ class RemoteStep(storey.SendToHttp):
102
105
  self.url = url
103
106
  self.url_expression = url_expression
104
107
  self.body_expression = body_expression
108
+ self.headers_expression = headers_expression
105
109
  self.headers = headers
106
110
  self.method = method
107
111
  self.return_json = return_json
@@ -114,8 +118,9 @@ class RemoteStep(storey.SendToHttp):
114
118
  self._session = None
115
119
  self._url_function_handler = None
116
120
  self._body_function_handler = None
121
+ self._headers_function_handler = None
117
122
 
118
- def post_init(self, mode="sync"):
123
+ def post_init(self, mode="sync", **kwargs):
119
124
  self._endpoint = self.url
120
125
  if self.url and self.context:
121
126
  self._endpoint = self.context.get_remote_endpoint(self.url).strip("/")
@@ -131,6 +136,12 @@ class RemoteStep(storey.SendToHttp):
131
136
  {"endpoint": self._endpoint, "context": self.context},
132
137
  {},
133
138
  )
139
+ if self.headers_expression:
140
+ self._headers_function_handler = eval(
141
+ "lambda event: " + self.headers_expression,
142
+ {"context": self.context},
143
+ {},
144
+ )
134
145
  elif self.subpath:
135
146
  self._append_event_path = self.subpath == "$path"
136
147
  if not self._append_event_path:
@@ -205,7 +216,10 @@ class RemoteStep(storey.SendToHttp):
205
216
 
206
217
  def _generate_request(self, event, body):
207
218
  method = self.method or event.method or "POST"
208
- headers = self.headers or {}
219
+ if self._headers_function_handler:
220
+ headers = self._headers_function_handler(body)
221
+ else:
222
+ headers = copy(self.headers) or {}
209
223
 
210
224
  if self._url_function_handler:
211
225
  url = self._url_function_handler(body)
@@ -216,10 +230,8 @@ class RemoteStep(storey.SendToHttp):
216
230
  url = url + "/" + striped_path
217
231
  if striped_path:
218
232
  headers[event_path_key] = event.path
219
-
220
233
  if event.id:
221
234
  headers[event_id_key] = event.id
222
-
223
235
  if method == "GET":
224
236
  body = None
225
237
  elif body is not None and not isinstance(body, (str, bytes)):
@@ -334,7 +346,7 @@ class BatchHttpRequests(_ConcurrentJobExecution):
334
346
  async def _cleanup(self):
335
347
  await self._client_session.close()
336
348
 
337
- def post_init(self, mode="sync"):
349
+ def post_init(self, mode="sync", **kwargs):
338
350
  self._endpoint = self.url
339
351
  if self.url and self.context:
340
352
  self._endpoint = self.context.get_remote_endpoint(self.url).strip("/")
mlrun/serving/routers.py CHANGED
@@ -30,7 +30,6 @@ import mlrun.common.model_monitoring
30
30
  import mlrun.common.schemas.model_monitoring
31
31
  from mlrun.utils import logger, now_date
32
32
 
33
- from ..common.schemas.model_monitoring import ModelEndpointSchema
34
33
  from .server import GraphServer
35
34
  from .utils import RouterToDict, _extract_input_data, _update_result_body
36
35
  from .v2_serving import _ModelLogPusher
@@ -110,7 +109,7 @@ class BaseModelRouter(RouterToDict):
110
109
 
111
110
  return parsed_event
112
111
 
113
- def post_init(self, mode="sync"):
112
+ def post_init(self, mode="sync", **kwargs):
114
113
  self.context.logger.info(f"Loaded {list(self.routes.keys())}")
115
114
 
116
115
  def get_metadata(self):
@@ -610,7 +609,7 @@ class VotingEnsemble(ParallelRun):
610
609
  self.model_endpoint_uid = None
611
610
  self.shard_by_endpoint = shard_by_endpoint
612
611
 
613
- def post_init(self, mode="sync"):
612
+ def post_init(self, mode="sync", **kwargs):
614
613
  server = getattr(self.context, "_server", None) or getattr(
615
614
  self.context, "server", None
616
615
  )
@@ -619,7 +618,9 @@ class VotingEnsemble(ParallelRun):
619
618
  return
620
619
 
621
620
  if not self.context.is_mock or self.context.monitoring_mock:
622
- self.model_endpoint_uid = _init_endpoint_record(server, self)
621
+ self.model_endpoint_uid = _init_endpoint_record(
622
+ server, self, creation_strategy=kwargs.get("creation_strategy")
623
+ )
623
624
 
624
625
  self._update_weights(self.weights)
625
626
 
@@ -1001,7 +1002,10 @@ class VotingEnsemble(ParallelRun):
1001
1002
 
1002
1003
 
1003
1004
  def _init_endpoint_record(
1004
- graph_server: GraphServer, voting_ensemble: VotingEnsemble
1005
+ graph_server: GraphServer,
1006
+ voting_ensemble: VotingEnsemble,
1007
+ creation_strategy: str,
1008
+ endpoint_type: mlrun.common.schemas.EndpointType,
1005
1009
  ) -> Union[str, None]:
1006
1010
  """
1007
1011
  Initialize model endpoint record and write it into the DB. In general, this method retrieve the unique model
@@ -1011,61 +1015,44 @@ def _init_endpoint_record(
1011
1015
  :param graph_server: A GraphServer object which will be used for getting the function uri.
1012
1016
  :param voting_ensemble: Voting ensemble serving class. It contains important details for the model endpoint record
1013
1017
  such as model name, model path, model version, and the ids of the children model endpoints.
1014
-
1018
+ :param creation_strategy: model endpoint creation strategy :
1019
+ * overwrite - Create a new model endpoint and delete the last old one if it exists.
1020
+ * inplace - Use the existing model endpoint if it already exists (default).
1021
+ * archive - Preserve the old model endpoint and create a new one,
1022
+ tagging it as the latest.
1023
+ :param endpoint_type: model endpoint type
1015
1024
  :return: Model endpoint unique ID.
1016
1025
  """
1017
1026
 
1018
1027
  logger.info("Initializing endpoint records")
1019
- try:
1020
- model_endpoint = mlrun.get_run_db().get_model_endpoint(
1021
- project=graph_server.project,
1022
- name=voting_ensemble.name,
1023
- function_name=graph_server.function_name,
1024
- function_tag=graph_server.function_tag or "latest",
1025
- )
1026
- except mlrun.errors.MLRunNotFoundError:
1027
- model_endpoint = None
1028
- except mlrun.errors.MLRunBadRequestError as err:
1029
- logger.info(
1030
- "Cannot get the model endpoints store", err=mlrun.errors.err_to_str(err)
1031
- )
1032
- return
1033
-
1034
- function = mlrun.get_run_db().get_function(
1035
- name=graph_server.function_name,
1036
- project=graph_server.project,
1037
- tag=graph_server.function_tag or "latest",
1038
- )
1039
- function_uid = function.get("metadata", {}).get("uid")
1040
- # Get the children model endpoints ids
1041
1028
  children_uids = []
1042
1029
  children_names = []
1043
1030
  for _, c in voting_ensemble.routes.items():
1044
1031
  if hasattr(c, "endpoint_uid"):
1045
1032
  children_uids.append(c.endpoint_uid)
1046
1033
  children_names.append(c.name)
1047
- if not model_endpoint and voting_ensemble.context.server.track_models:
1034
+ try:
1048
1035
  logger.info(
1049
- "Creating a new model endpoint record",
1036
+ "Creating Or Updating a new model endpoint record",
1050
1037
  name=voting_ensemble.name,
1051
1038
  project=graph_server.project,
1052
1039
  function_name=graph_server.function_name,
1053
1040
  function_tag=graph_server.function_tag or "latest",
1054
- function_uid=function_uid,
1055
1041
  model_class=voting_ensemble.__class__.__name__,
1042
+ creation_strategy=creation_strategy,
1056
1043
  )
1057
1044
  model_endpoint = mlrun.common.schemas.ModelEndpoint(
1058
1045
  metadata=mlrun.common.schemas.ModelEndpointMetadata(
1059
1046
  project=graph_server.project,
1060
1047
  name=voting_ensemble.name,
1061
- endpoint_type=mlrun.common.schemas.model_monitoring.EndpointType.ROUTER,
1048
+ endpoint_type=endpoint_type,
1062
1049
  ),
1063
1050
  spec=mlrun.common.schemas.ModelEndpointSpec(
1064
1051
  function_name=graph_server.function_name,
1065
- function_uid=function_uid,
1066
1052
  function_tag=graph_server.function_tag or "latest",
1067
1053
  model_class=voting_ensemble.__class__.__name__,
1068
- children_uids=list(voting_ensemble.routes.keys()),
1054
+ children_uids=children_uids,
1055
+ children=children_names,
1069
1056
  ),
1070
1057
  status=mlrun.common.schemas.ModelEndpointStatus(
1071
1058
  monitoring_mode=mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
@@ -1074,59 +1061,12 @@ def _init_endpoint_record(
1074
1061
  ),
1075
1062
  )
1076
1063
  db = mlrun.get_run_db()
1077
- db.create_model_endpoint(model_endpoint=model_endpoint)
1078
-
1079
- elif model_endpoint:
1080
- attributes = {}
1081
- if function_uid != model_endpoint.spec.function_uid:
1082
- attributes[ModelEndpointSchema.FUNCTION_UID] = function_uid
1083
- if children_uids != model_endpoint.spec.children_uids:
1084
- attributes[ModelEndpointSchema.CHILDREN_UIDS] = children_uids
1085
- if (
1086
- model_endpoint.status.monitoring_mode
1087
- == mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
1088
- ) != voting_ensemble.context.server.track_models:
1089
- attributes[ModelEndpointSchema.MONITORING_MODE] = (
1090
- mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
1091
- if voting_ensemble.context.server.track_models
1092
- else mlrun.common.schemas.model_monitoring.ModelMonitoringMode.disabled
1093
- )
1094
- if attributes:
1095
- db = mlrun.get_run_db()
1096
- logger.info(
1097
- "Updating model endpoint attributes",
1098
- attributes=attributes,
1099
- project=model_endpoint.metadata.project,
1100
- name=model_endpoint.metadata.name,
1101
- function_name=model_endpoint.spec.function_name,
1102
- )
1103
- model_endpoint = db.patch_model_endpoint(
1104
- project=model_endpoint.metadata.project,
1105
- name=model_endpoint.metadata.name,
1106
- endpoint_id=model_endpoint.metadata.uid,
1107
- attributes=attributes,
1108
- )
1109
- else:
1110
- logger.info(
1111
- "Did not create a new model endpoint record, monitoring is disabled"
1064
+ db.create_model_endpoint(
1065
+ model_endpoint=model_endpoint, creation_strategy=creation_strategy
1112
1066
  )
1067
+ except mlrun.errors.MLRunInvalidArgumentError as e:
1068
+ logger.info("Failed to create model endpoint record", error=e)
1113
1069
  return None
1114
-
1115
- # Update model endpoint children type
1116
- logger.info(
1117
- "Updating children model endpoint type",
1118
- children_uids=children_uids,
1119
- children_names=children_names,
1120
- )
1121
- for uid, name in zip(children_uids, children_names):
1122
- mlrun.get_run_db().patch_model_endpoint(
1123
- name=name,
1124
- project=graph_server.project,
1125
- endpoint_id=uid,
1126
- attributes={
1127
- ModelEndpointSchema.ENDPOINT_TYPE: mlrun.common.schemas.model_monitoring.EndpointType.LEAF_EP
1128
- },
1129
- )
1130
1070
  return model_endpoint.metadata.uid
1131
1071
 
1132
1072
 
@@ -1192,7 +1132,7 @@ class EnrichmentModelRouter(ModelRouter):
1192
1132
 
1193
1133
  self._feature_service = None
1194
1134
 
1195
- def post_init(self, mode="sync"):
1135
+ def post_init(self, mode="sync", **kwargs):
1196
1136
  from ..feature_store import get_feature_vector
1197
1137
 
1198
1138
  super().post_init(mode)
@@ -1342,7 +1282,7 @@ class EnrichmentVotingEnsemble(VotingEnsemble):
1342
1282
 
1343
1283
  self._feature_service = None
1344
1284
 
1345
- def post_init(self, mode="sync"):
1285
+ def post_init(self, mode="sync", **kwargs):
1346
1286
  from ..feature_store import get_feature_vector
1347
1287
 
1348
1288
  super().post_init(mode)
mlrun/serving/server.py CHANGED
@@ -367,7 +367,9 @@ def _set_callbacks(server, context):
367
367
 
368
368
  async def termination_callback():
369
369
  context.logger.info("Termination callback called")
370
- server.wait_for_completion()
370
+ maybe_coroutine = server.wait_for_completion()
371
+ if asyncio.iscoroutine(maybe_coroutine):
372
+ await maybe_coroutine
371
373
  context.logger.info("Termination of async flow is completed")
372
374
 
373
375
  context.platform.set_termination_callback(termination_callback)
@@ -379,7 +381,9 @@ def _set_callbacks(server, context):
379
381
 
380
382
  async def drain_callback():
381
383
  context.logger.info("Drain callback called")
382
- server.wait_for_completion()
384
+ maybe_coroutine = server.wait_for_completion()
385
+ if asyncio.iscoroutine(maybe_coroutine):
386
+ await maybe_coroutine
383
387
  context.logger.info(
384
388
  "Termination of async flow is completed. Rerunning async flow."
385
389
  )