mlrun 1.8.0rc10__py3-none-any.whl → 1.8.0rc12__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. Click here for more details.

mlrun/serving/routers.py CHANGED
@@ -30,7 +30,6 @@ import mlrun.common.model_monitoring
30
30
  import mlrun.common.schemas.model_monitoring
31
31
  from mlrun.utils import logger, now_date
32
32
 
33
- from ..common.schemas.model_monitoring import ModelEndpointSchema
34
33
  from .server import GraphServer
35
34
  from .utils import RouterToDict, _extract_input_data, _update_result_body
36
35
  from .v2_serving import _ModelLogPusher
@@ -110,7 +109,7 @@ class BaseModelRouter(RouterToDict):
110
109
 
111
110
  return parsed_event
112
111
 
113
- def post_init(self, mode="sync"):
112
+ def post_init(self, mode="sync", **kwargs):
114
113
  self.context.logger.info(f"Loaded {list(self.routes.keys())}")
115
114
 
116
115
  def get_metadata(self):
@@ -610,7 +609,7 @@ class VotingEnsemble(ParallelRun):
610
609
  self.model_endpoint_uid = None
611
610
  self.shard_by_endpoint = shard_by_endpoint
612
611
 
613
- def post_init(self, mode="sync"):
612
+ def post_init(self, mode="sync", **kwargs):
614
613
  server = getattr(self.context, "_server", None) or getattr(
615
614
  self.context, "server", None
616
615
  )
@@ -619,7 +618,9 @@ class VotingEnsemble(ParallelRun):
619
618
  return
620
619
 
621
620
  if not self.context.is_mock or self.context.monitoring_mock:
622
- self.model_endpoint_uid = _init_endpoint_record(server, self)
621
+ self.model_endpoint_uid = _init_endpoint_record(
622
+ server, self, creation_strategy=kwargs.get("creation_strategy")
623
+ )
623
624
 
624
625
  self._update_weights(self.weights)
625
626
 
@@ -1001,7 +1002,10 @@ class VotingEnsemble(ParallelRun):
1001
1002
 
1002
1003
 
1003
1004
  def _init_endpoint_record(
1004
- graph_server: GraphServer, voting_ensemble: VotingEnsemble
1005
+ graph_server: GraphServer,
1006
+ voting_ensemble: VotingEnsemble,
1007
+ creation_strategy: str,
1008
+ endpoint_type: mlrun.common.schemas.EndpointType,
1005
1009
  ) -> Union[str, None]:
1006
1010
  """
1007
1011
  Initialize model endpoint record and write it into the DB. In general, this method retrieve the unique model
@@ -1011,61 +1015,44 @@ def _init_endpoint_record(
1011
1015
  :param graph_server: A GraphServer object which will be used for getting the function uri.
1012
1016
  :param voting_ensemble: Voting ensemble serving class. It contains important details for the model endpoint record
1013
1017
  such as model name, model path, model version, and the ids of the children model endpoints.
1014
-
1018
+ :param creation_strategy: model endpoint creation strategy :
1019
+ * overwrite - Create a new model endpoint and delete the last old one if it exists.
1020
+ * inplace - Use the existing model endpoint if it already exists (default).
1021
+ * archive - Preserve the old model endpoint and create a new one,
1022
+ tagging it as the latest.
1023
+ :param endpoint_type: model endpoint type
1015
1024
  :return: Model endpoint unique ID.
1016
1025
  """
1017
1026
 
1018
1027
  logger.info("Initializing endpoint records")
1019
- try:
1020
- model_endpoint = mlrun.get_run_db().get_model_endpoint(
1021
- project=graph_server.project,
1022
- name=voting_ensemble.name,
1023
- function_name=graph_server.function_name,
1024
- function_tag=graph_server.function_tag or "latest",
1025
- )
1026
- except mlrun.errors.MLRunNotFoundError:
1027
- model_endpoint = None
1028
- except mlrun.errors.MLRunBadRequestError as err:
1029
- logger.info(
1030
- "Cannot get the model endpoints store", err=mlrun.errors.err_to_str(err)
1031
- )
1032
- return
1033
-
1034
- function = mlrun.get_run_db().get_function(
1035
- name=graph_server.function_name,
1036
- project=graph_server.project,
1037
- tag=graph_server.function_tag or "latest",
1038
- )
1039
- function_uid = function.get("metadata", {}).get("uid")
1040
- # Get the children model endpoints ids
1041
1028
  children_uids = []
1042
1029
  children_names = []
1043
1030
  for _, c in voting_ensemble.routes.items():
1044
1031
  if hasattr(c, "endpoint_uid"):
1045
1032
  children_uids.append(c.endpoint_uid)
1046
1033
  children_names.append(c.name)
1047
- if not model_endpoint and voting_ensemble.context.server.track_models:
1034
+ try:
1048
1035
  logger.info(
1049
- "Creating a new model endpoint record",
1036
+ "Creating Or Updating a new model endpoint record",
1050
1037
  name=voting_ensemble.name,
1051
1038
  project=graph_server.project,
1052
1039
  function_name=graph_server.function_name,
1053
1040
  function_tag=graph_server.function_tag or "latest",
1054
- function_uid=function_uid,
1055
1041
  model_class=voting_ensemble.__class__.__name__,
1042
+ creation_strategy=creation_strategy,
1056
1043
  )
1057
1044
  model_endpoint = mlrun.common.schemas.ModelEndpoint(
1058
1045
  metadata=mlrun.common.schemas.ModelEndpointMetadata(
1059
1046
  project=graph_server.project,
1060
1047
  name=voting_ensemble.name,
1061
- endpoint_type=mlrun.common.schemas.model_monitoring.EndpointType.ROUTER,
1048
+ endpoint_type=endpoint_type,
1062
1049
  ),
1063
1050
  spec=mlrun.common.schemas.ModelEndpointSpec(
1064
1051
  function_name=graph_server.function_name,
1065
- function_uid=function_uid,
1066
1052
  function_tag=graph_server.function_tag or "latest",
1067
1053
  model_class=voting_ensemble.__class__.__name__,
1068
- children_uids=list(voting_ensemble.routes.keys()),
1054
+ children_uids=children_uids,
1055
+ children=children_names,
1069
1056
  ),
1070
1057
  status=mlrun.common.schemas.ModelEndpointStatus(
1071
1058
  monitoring_mode=mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
@@ -1074,59 +1061,12 @@ def _init_endpoint_record(
1074
1061
  ),
1075
1062
  )
1076
1063
  db = mlrun.get_run_db()
1077
- db.create_model_endpoint(model_endpoint=model_endpoint)
1078
-
1079
- elif model_endpoint:
1080
- attributes = {}
1081
- if function_uid != model_endpoint.spec.function_uid:
1082
- attributes[ModelEndpointSchema.FUNCTION_UID] = function_uid
1083
- if children_uids != model_endpoint.spec.children_uids:
1084
- attributes[ModelEndpointSchema.CHILDREN_UIDS] = children_uids
1085
- if (
1086
- model_endpoint.status.monitoring_mode
1087
- == mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
1088
- ) != voting_ensemble.context.server.track_models:
1089
- attributes[ModelEndpointSchema.MONITORING_MODE] = (
1090
- mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
1091
- if voting_ensemble.context.server.track_models
1092
- else mlrun.common.schemas.model_monitoring.ModelMonitoringMode.disabled
1093
- )
1094
- if attributes:
1095
- db = mlrun.get_run_db()
1096
- logger.info(
1097
- "Updating model endpoint attributes",
1098
- attributes=attributes,
1099
- project=model_endpoint.metadata.project,
1100
- name=model_endpoint.metadata.name,
1101
- function_name=model_endpoint.spec.function_name,
1102
- )
1103
- model_endpoint = db.patch_model_endpoint(
1104
- project=model_endpoint.metadata.project,
1105
- name=model_endpoint.metadata.name,
1106
- endpoint_id=model_endpoint.metadata.uid,
1107
- attributes=attributes,
1108
- )
1109
- else:
1110
- logger.info(
1111
- "Did not create a new model endpoint record, monitoring is disabled"
1064
+ db.create_model_endpoint(
1065
+ model_endpoint=model_endpoint, creation_strategy=creation_strategy
1112
1066
  )
1067
+ except mlrun.errors.MLRunInvalidArgumentError as e:
1068
+ logger.info("Failed to create model endpoint record", error=e)
1113
1069
  return None
1114
-
1115
- # Update model endpoint children type
1116
- logger.info(
1117
- "Updating children model endpoint type",
1118
- children_uids=children_uids,
1119
- children_names=children_names,
1120
- )
1121
- for uid, name in zip(children_uids, children_names):
1122
- mlrun.get_run_db().patch_model_endpoint(
1123
- name=name,
1124
- project=graph_server.project,
1125
- endpoint_id=uid,
1126
- attributes={
1127
- ModelEndpointSchema.ENDPOINT_TYPE: mlrun.common.schemas.model_monitoring.EndpointType.LEAF_EP
1128
- },
1129
- )
1130
1070
  return model_endpoint.metadata.uid
1131
1071
 
1132
1072
 
@@ -1192,7 +1132,7 @@ class EnrichmentModelRouter(ModelRouter):
1192
1132
 
1193
1133
  self._feature_service = None
1194
1134
 
1195
- def post_init(self, mode="sync"):
1135
+ def post_init(self, mode="sync", **kwargs):
1196
1136
  from ..feature_store import get_feature_vector
1197
1137
 
1198
1138
  super().post_init(mode)
@@ -1342,7 +1282,7 @@ class EnrichmentVotingEnsemble(VotingEnsemble):
1342
1282
 
1343
1283
  self._feature_service = None
1344
1284
 
1345
- def post_init(self, mode="sync"):
1285
+ def post_init(self, mode="sync", **kwargs):
1346
1286
  from ..feature_store import get_feature_vector
1347
1287
 
1348
1288
  super().post_init(mode)
mlrun/serving/server.py CHANGED
@@ -367,7 +367,9 @@ def _set_callbacks(server, context):
367
367
 
368
368
  async def termination_callback():
369
369
  context.logger.info("Termination callback called")
370
- server.wait_for_completion()
370
+ maybe_coroutine = server.wait_for_completion()
371
+ if asyncio.iscoroutine(maybe_coroutine):
372
+ await maybe_coroutine
371
373
  context.logger.info("Termination of async flow is completed")
372
374
 
373
375
  context.platform.set_termination_callback(termination_callback)
@@ -379,7 +381,9 @@ def _set_callbacks(server, context):
379
381
 
380
382
  async def drain_callback():
381
383
  context.logger.info("Drain callback called")
382
- server.wait_for_completion()
384
+ maybe_coroutine = server.wait_for_completion()
385
+ if asyncio.iscoroutine(maybe_coroutine):
386
+ await maybe_coroutine
383
387
  context.logger.info(
384
388
  "Termination of async flow is completed. Rerunning async flow."
385
389
  )
mlrun/serving/states.py CHANGED
@@ -25,11 +25,12 @@ import pathlib
25
25
  import traceback
26
26
  from copy import copy, deepcopy
27
27
  from inspect import getfullargspec, signature
28
- from typing import Any, Optional, Union
28
+ from typing import Any, Optional, Union, cast
29
29
 
30
30
  import storey.utils
31
31
 
32
32
  import mlrun
33
+ import mlrun.common.schemas as schemas
33
34
 
34
35
  from ..config import config
35
36
  from ..datastore import get_stream_pusher
@@ -81,22 +82,28 @@ _task_step_fields = [
81
82
  "responder",
82
83
  "input_path",
83
84
  "result_path",
85
+ "model_endpoint_creation_strategy",
86
+ "endpoint_type",
84
87
  ]
85
88
 
86
89
 
87
90
  MAX_ALLOWED_STEPS = 4500
88
91
 
89
92
 
90
- def new_model_endpoint(class_name, model_path, handler=None, **class_args):
91
- class_args = deepcopy(class_args)
92
- class_args["model_path"] = model_path
93
- return TaskStep(class_name, class_args, handler=handler)
94
-
95
-
96
- def new_remote_endpoint(url, **class_args):
93
+ def new_remote_endpoint(
94
+ url: str,
95
+ creation_strategy: schemas.ModelEndpointCreationStrategy,
96
+ endpoint_type: schemas.EndpointType,
97
+ **class_args,
98
+ ):
97
99
  class_args = deepcopy(class_args)
98
100
  class_args["url"] = url
99
- return TaskStep("$remote", class_args)
101
+ return TaskStep(
102
+ "$remote",
103
+ class_args=class_args,
104
+ model_endpoint_creation_strategy=creation_strategy,
105
+ endpoint_type=endpoint_type,
106
+ )
100
107
 
101
108
 
102
109
  class BaseStep(ModelObj):
@@ -419,6 +426,10 @@ class TaskStep(BaseStep):
419
426
  responder: Optional[bool] = None,
420
427
  input_path: Optional[str] = None,
421
428
  result_path: Optional[str] = None,
429
+ model_endpoint_creation_strategy: Optional[
430
+ schemas.ModelEndpointCreationStrategy
431
+ ] = schemas.ModelEndpointCreationStrategy.INPLACE,
432
+ endpoint_type: Optional[schemas.EndpointType] = schemas.EndpointType.NODE_EP,
422
433
  ):
423
434
  super().__init__(name, after)
424
435
  self.class_name = class_name
@@ -438,6 +449,8 @@ class TaskStep(BaseStep):
438
449
  self.on_error = None
439
450
  self._inject_context = False
440
451
  self._call_with_event = False
452
+ self.model_endpoint_creation_strategy = model_endpoint_creation_strategy
453
+ self.endpoint_type = endpoint_type
441
454
 
442
455
  def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
443
456
  self.context = context
@@ -554,7 +567,11 @@ class TaskStep(BaseStep):
554
567
 
555
568
  def _post_init(self, mode="sync"):
556
569
  if self._object and hasattr(self._object, "post_init"):
557
- self._object.post_init(mode)
570
+ self._object.post_init(
571
+ mode,
572
+ creation_strategy=self.model_endpoint_creation_strategy,
573
+ endpoint_type=self.endpoint_type,
574
+ )
558
575
  if hasattr(self._object, "model_endpoint_uid"):
559
576
  self.endpoint_uid = self._object.model_endpoint_uid
560
577
  if hasattr(self._object, "name"):
@@ -705,6 +722,7 @@ class RouterStep(TaskStep):
705
722
  )
706
723
  self._routes: ObjectDict = None
707
724
  self.routes = routes
725
+ self.endpoint_type = schemas.EndpointType.ROUTER
708
726
 
709
727
  def get_children(self):
710
728
  """get child steps (routes)"""
@@ -726,6 +744,7 @@ class RouterStep(TaskStep):
726
744
  class_name=None,
727
745
  handler=None,
728
746
  function=None,
747
+ creation_strategy: schemas.ModelEndpointCreationStrategy = schemas.ModelEndpointCreationStrategy.INPLACE,
729
748
  **class_args,
730
749
  ):
731
750
  """add child route step or class to the router
@@ -736,12 +755,23 @@ class RouterStep(TaskStep):
736
755
  :param class_args: class init arguments
737
756
  :param handler: class handler to invoke on run/event
738
757
  :param function: function this step should run in
758
+ :param creation_strategy: model endpoint creation strategy :
759
+ * overwrite - Create a new model endpoint and delete the last old one if it exists.
760
+ * inplace - Use the existing model endpoint if it already exists (default).
761
+ * archive - Preserve the old model endpoint and create a new one,
762
+ tagging it as the latest.
739
763
  """
740
764
 
741
765
  if not route and not class_name and not handler:
742
766
  raise MLRunInvalidArgumentError("route or class_name must be specified")
743
767
  if not route:
744
- route = TaskStep(class_name, class_args, handler=handler)
768
+ route = TaskStep(
769
+ class_name,
770
+ class_args,
771
+ handler=handler,
772
+ model_endpoint_creation_strategy=creation_strategy,
773
+ endpoint_type=schemas.EndpointType.NODE_EP,
774
+ )
745
775
  route.function = function or route.function
746
776
 
747
777
  if len(self._routes) >= MAX_ALLOWED_STEPS:
@@ -805,6 +835,106 @@ class RouterStep(TaskStep):
805
835
  )
806
836
 
807
837
 
838
+ class Model(storey.ParallelExecutionRunnable):
839
+ def load(self) -> None:
840
+ """Override to load model if needed."""
841
+ pass
842
+
843
+ def init(self):
844
+ self.load()
845
+
846
+ def predict(self, body: Any) -> Any:
847
+ """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
848
+ return body
849
+
850
+ async def predict_async(self, body: Any) -> Any:
851
+ """Override to implement prediction logic if the logic requires asyncio."""
852
+ return body
853
+
854
+ def run(self, body: Any, path: str) -> Any:
855
+ return self.predict(body)
856
+
857
+ async def run_async(self, body: Any, path: str) -> Any:
858
+ return self.predict(body)
859
+
860
+
861
+ class ModelSelector:
862
+ """Used to select which models to run on each event."""
863
+
864
+ def select(
865
+ self, event, available_models: list[Model]
866
+ ) -> Union[list[str], list[Model]]:
867
+ """
868
+ Given an event, returns a list of model names or a list of model objects to run on the event.
869
+ If None is returned, all models will be run.
870
+
871
+ :param event: The full event
872
+ :param available_models: List of available models
873
+ """
874
+ pass
875
+
876
+
877
+ class ModelRunner(storey.ParallelExecution):
878
+ """
879
+ Runs multiple Models on each event. See ModelRunnerStep.
880
+
881
+ :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
882
+ event. Optional. If not passed, all models will be run.
883
+ """
884
+
885
+ def __init__(self, *args, model_selector: Optional[ModelSelector] = None, **kwargs):
886
+ super().__init__(*args, **kwargs)
887
+ self.model_selector = model_selector or ModelSelector()
888
+
889
+ def select_runnables(self, event):
890
+ models = cast(list[Model], self.runnables)
891
+ return self.model_selector.select(event, models)
892
+
893
+
894
+ class ModelRunnerStep(TaskStep):
895
+ """
896
+ Runs multiple Models on each event.
897
+
898
+ example::
899
+
900
+ model_runner_step = ModelRunnerStep(name="my_model_runner")
901
+ model_runner_step.add_model(MyModel(name="my_model"))
902
+ graph.to(model_runner_step)
903
+
904
+ :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
905
+ event. Optional. If not passed, all models will be run.
906
+ """
907
+
908
+ kind = "model_runner"
909
+
910
+ def __init__(
911
+ self,
912
+ *args,
913
+ model_selector: Optional[Union[str, ModelSelector]] = None,
914
+ **kwargs,
915
+ ):
916
+ self._models = []
917
+ super().__init__(
918
+ *args,
919
+ class_name="mlrun.serving.ModelRunner",
920
+ class_args=dict(runnables=self._models, model_selector=model_selector),
921
+ **kwargs,
922
+ )
923
+
924
+ def add_model(self, model: Model) -> None:
925
+ """Add a Model to this ModelRunner."""
926
+ self._models.append(model)
927
+
928
+ def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
929
+ model_selector = self.class_args.get("model_selector")
930
+ if isinstance(model_selector, str):
931
+ model_selector = get_class(model_selector, namespace)()
932
+ self._async_object = ModelRunner(
933
+ self.class_args.get("runnables"),
934
+ model_selector=model_selector,
935
+ )
936
+
937
+
808
938
  class QueueStep(BaseStep):
809
939
  """queue step, implement an async queue or represent a stream"""
810
940
 
@@ -1344,8 +1474,9 @@ class FlowStep(BaseStep):
1344
1474
 
1345
1475
  if self._controller:
1346
1476
  if hasattr(self._controller, "terminate"):
1347
- self._controller.terminate()
1348
- return self._controller.await_termination()
1477
+ return self._controller.terminate(wait=True)
1478
+ else:
1479
+ return self._controller.await_termination()
1349
1480
 
1350
1481
  def plot(self, filename=None, format=None, source=None, targets=None, **kw):
1351
1482
  """plot/save graph using graphviz
@@ -1433,6 +1564,7 @@ classes_map = {
1433
1564
  "queue": QueueStep,
1434
1565
  "error_step": ErrorStep,
1435
1566
  "monitoring_application": MonitoringApplicationStep,
1567
+ "model_runner": ModelRunnerStep,
1436
1568
  }
1437
1569
 
1438
1570
 
@@ -1572,6 +1704,10 @@ def params_to_step(
1572
1704
  input_path: Optional[str] = None,
1573
1705
  result_path: Optional[str] = None,
1574
1706
  class_args=None,
1707
+ model_endpoint_creation_strategy: Optional[
1708
+ schemas.ModelEndpointCreationStrategy
1709
+ ] = None,
1710
+ endpoint_type: Optional[schemas.EndpointType] = None,
1575
1711
  ):
1576
1712
  """return step object from provided params or classes/objects"""
1577
1713
 
@@ -1587,6 +1723,9 @@ def params_to_step(
1587
1723
  step.full_event = full_event or step.full_event
1588
1724
  step.input_path = input_path or step.input_path
1589
1725
  step.result_path = result_path or step.result_path
1726
+ if kind == StepKinds.task:
1727
+ step.model_endpoint_creation_strategy = model_endpoint_creation_strategy
1728
+ step.endpoint_type = endpoint_type
1590
1729
 
1591
1730
  elif class_name and class_name in queue_class_names:
1592
1731
  if "path" not in class_args:
@@ -1627,6 +1766,8 @@ def params_to_step(
1627
1766
  full_event=full_event,
1628
1767
  input_path=input_path,
1629
1768
  result_path=result_path,
1769
+ model_endpoint_creation_strategy=model_endpoint_creation_strategy,
1770
+ endpoint_type=endpoint_type,
1630
1771
  )
1631
1772
  else:
1632
1773
  raise MLRunInvalidArgumentError("class_name or handler must be provided")
@@ -23,7 +23,6 @@ import mlrun.common.schemas.model_monitoring
23
23
  import mlrun.model_monitoring
24
24
  from mlrun.utils import logger, now_date
25
25
 
26
- from ..common.schemas.model_monitoring import ModelEndpointSchema
27
26
  from .server import GraphServer
28
27
  from .utils import StepToDict, _extract_input_data, _update_result_body
29
28
 
@@ -130,7 +129,7 @@ class V2ModelServer(StepToDict):
130
129
  self.ready = True
131
130
  self.context.logger.info(f"model {self.name} was loaded")
132
131
 
133
- def post_init(self, mode="sync"):
132
+ def post_init(self, mode="sync", **kwargs):
134
133
  """sync/async model loading, for internal use"""
135
134
  if not self.ready:
136
135
  if mode == "async":
@@ -149,7 +148,10 @@ class V2ModelServer(StepToDict):
149
148
 
150
149
  if not self.context.is_mock or self.context.monitoring_mock:
151
150
  self.model_endpoint_uid = _init_endpoint_record(
152
- graph_server=server, model=self
151
+ graph_server=server,
152
+ model=self,
153
+ creation_strategy=kwargs.get("creation_strategy"),
154
+ endpoint_type=kwargs.get("endpoint_type"),
153
155
  )
154
156
  self._model_logger = (
155
157
  _ModelLogPusher(self, self.context)
@@ -554,7 +556,10 @@ class _ModelLogPusher:
554
556
 
555
557
 
556
558
  def _init_endpoint_record(
557
- graph_server: GraphServer, model: V2ModelServer
559
+ graph_server: GraphServer,
560
+ model: V2ModelServer,
561
+ creation_strategy: str,
562
+ endpoint_type: mlrun.common.schemas.EndpointType,
558
563
  ) -> Union[str, None]:
559
564
  """
560
565
  Initialize model endpoint record and write it into the DB. In general, this method retrieve the unique model
@@ -564,6 +569,12 @@ def _init_endpoint_record(
564
569
  :param graph_server: A GraphServer object which will be used for getting the function uri.
565
570
  :param model: Base model serving class (v2). It contains important details for the model endpoint record
566
571
  such as model name, model path, and model version.
572
+ :param creation_strategy: model endpoint creation strategy :
573
+ * overwrite - Create a new model endpoint and delete the last old one if it exists.
574
+ * inplace - Use the existing model endpoint if it already exists (default).
575
+ * archive - Preserve the old model endpoint and create a new one,
576
+ tagging it as the latest.
577
+ :param endpoint_type model endpoint type
567
578
 
568
579
  :return: Model endpoint unique ID.
569
580
  """
@@ -583,51 +594,30 @@ def _init_endpoint_record(
583
594
  model_uid = None
584
595
  model_tag = None
585
596
  model_labels = {}
586
- try:
587
- model_ep = mlrun.get_run_db().get_model_endpoint(
588
- project=graph_server.project,
589
- name=model.name,
590
- function_name=graph_server.function_name,
591
- function_tag=graph_server.function_tag or "latest",
592
- )
593
- except mlrun.errors.MLRunNotFoundError:
594
- model_ep = None
595
- except mlrun.errors.MLRunBadRequestError as err:
596
- logger.info(
597
- "Cannot get the model endpoints store", err=mlrun.errors.err_to_str(err)
598
- )
599
- return
600
-
601
- function = mlrun.get_run_db().get_function(
602
- name=graph_server.function_name,
597
+ logger.info(
598
+ "Creating Or Updating a new model endpoint record",
599
+ name=model.name,
603
600
  project=graph_server.project,
604
- tag=graph_server.function_tag or "latest",
601
+ function_name=graph_server.function_name,
602
+ function_tag=graph_server.function_tag or "latest",
603
+ model_name=model_name,
604
+ model_tag=model_tag,
605
+ model_db_key=model_db_key,
606
+ model_uid=model_uid,
607
+ model_class=model.__class__.__name__,
608
+ creation_strategy=creation_strategy,
609
+ endpoint_type=endpoint_type,
605
610
  )
606
- function_uid = function.get("metadata", {}).get("uid")
607
- if not model_ep and model.context.server.track_models:
608
- logger.info(
609
- "Creating a new model endpoint record",
610
- name=model.name,
611
- project=graph_server.project,
612
- function_name=graph_server.function_name,
613
- function_tag=graph_server.function_tag or "latest",
614
- function_uid=function_uid,
615
- model_name=model_name,
616
- model_tag=model_tag,
617
- model_db_key=model_db_key,
618
- model_uid=model_uid,
619
- model_class=model.__class__.__name__,
620
- )
611
+ try:
621
612
  model_ep = mlrun.common.schemas.ModelEndpoint(
622
613
  metadata=mlrun.common.schemas.ModelEndpointMetadata(
623
614
  project=graph_server.project,
624
615
  labels=model_labels,
625
616
  name=model.name,
626
- endpoint_type=mlrun.common.schemas.model_monitoring.EndpointType.NODE_EP,
617
+ endpoint_type=endpoint_type,
627
618
  ),
628
619
  spec=mlrun.common.schemas.ModelEndpointSpec(
629
620
  function_name=graph_server.function_name,
630
- function_uid=function_uid,
631
621
  function_tag=graph_server.function_tag or "latest",
632
622
  model_name=model_name,
633
623
  model_db_key=model_db_key,
@@ -642,49 +632,11 @@ def _init_endpoint_record(
642
632
  ),
643
633
  )
644
634
  db = mlrun.get_run_db()
645
- model_ep = db.create_model_endpoint(model_endpoint=model_ep)
646
-
647
- elif model_ep:
648
- attributes = {}
649
- if function_uid != model_ep.spec.function_uid:
650
- attributes[ModelEndpointSchema.FUNCTION_UID] = function_uid
651
- if model_name != model_ep.spec.model_name:
652
- attributes[ModelEndpointSchema.MODEL_NAME] = model_name
653
- if model_uid != model_ep.spec.model_uid:
654
- attributes[ModelEndpointSchema.MODEL_UID] = model_uid
655
- if model_tag != model_ep.spec.model_tag:
656
- attributes[ModelEndpointSchema.MODEL_TAG] = model_tag
657
- if model_db_key != model_ep.spec.model_db_key:
658
- attributes[ModelEndpointSchema.MODEL_DB_KEY] = model_db_key
659
- if model_labels != model_ep.metadata.labels:
660
- attributes[ModelEndpointSchema.LABELS] = model_labels
661
- if model.__class__.__name__ != model_ep.spec.model_class:
662
- attributes[ModelEndpointSchema.MODEL_CLASS] = model.__class__.__name__
663
- if (
664
- model_ep.status.monitoring_mode
665
- == mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
666
- ) != model.context.server.track_models:
667
- attributes[ModelEndpointSchema.MONITORING_MODE] = (
668
- mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
669
- if model.context.server.track_models
670
- else mlrun.common.schemas.model_monitoring.ModelMonitoringMode.disabled
671
- )
672
- if attributes:
673
- logger.info(
674
- "Updating model endpoint attributes",
675
- attributes=attributes,
676
- project=model_ep.metadata.project,
677
- name=model_ep.metadata.name,
678
- function_name=model_ep.spec.function_name,
679
- )
680
- db = mlrun.get_run_db()
681
- model_ep = db.patch_model_endpoint(
682
- project=model_ep.metadata.project,
683
- name=model_ep.metadata.name,
684
- endpoint_id=model_ep.metadata.uid,
685
- attributes=attributes,
686
- )
687
- else:
635
+ model_ep = db.create_model_endpoint(
636
+ model_endpoint=model_ep, creation_strategy=creation_strategy
637
+ )
638
+ except mlrun.errors.MLRunBadRequestError as e:
639
+ logger.info("Failed to create model endpoint record", error=e)
688
640
  return None
689
641
 
690
642
  return model_ep.metadata.uid