mlrun-1.8.0rc10-py3-none-any.whl → mlrun-1.8.0rc12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of mlrun might be problematic. Click here for more details.
- mlrun/common/constants.py +1 -0
- mlrun/common/schemas/__init__.py +2 -0
- mlrun/common/schemas/model_monitoring/__init__.py +1 -0
- mlrun/common/schemas/model_monitoring/constants.py +6 -0
- mlrun/common/schemas/model_monitoring/model_endpoints.py +35 -0
- mlrun/db/base.py +2 -0
- mlrun/db/httpdb.py +12 -0
- mlrun/db/nopdb.py +2 -0
- mlrun/feature_store/steps.py +1 -1
- mlrun/model_monitoring/api.py +26 -19
- mlrun/model_monitoring/applications/base.py +42 -4
- mlrun/projects/project.py +18 -16
- mlrun/runtimes/nuclio/serving.py +28 -5
- mlrun/serving/__init__.py +8 -0
- mlrun/serving/merger.py +1 -1
- mlrun/serving/remote.py +17 -5
- mlrun/serving/routers.py +27 -87
- mlrun/serving/server.py +6 -2
- mlrun/serving/states.py +154 -13
- mlrun/serving/v2_serving.py +34 -82
- mlrun/utils/helpers.py +6 -0
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc12.dist-info}/METADATA +10 -10
- {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc12.dist-info}/RECORD +28 -28
- {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc12.dist-info}/LICENSE +0 -0
- {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc12.dist-info}/WHEEL +0 -0
- {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc12.dist-info}/entry_points.txt +0 -0
- {mlrun-1.8.0rc10.dist-info → mlrun-1.8.0rc12.dist-info}/top_level.txt +0 -0
mlrun/serving/routers.py
CHANGED
|
@@ -30,7 +30,6 @@ import mlrun.common.model_monitoring
|
|
|
30
30
|
import mlrun.common.schemas.model_monitoring
|
|
31
31
|
from mlrun.utils import logger, now_date
|
|
32
32
|
|
|
33
|
-
from ..common.schemas.model_monitoring import ModelEndpointSchema
|
|
34
33
|
from .server import GraphServer
|
|
35
34
|
from .utils import RouterToDict, _extract_input_data, _update_result_body
|
|
36
35
|
from .v2_serving import _ModelLogPusher
|
|
@@ -110,7 +109,7 @@ class BaseModelRouter(RouterToDict):
|
|
|
110
109
|
|
|
111
110
|
return parsed_event
|
|
112
111
|
|
|
113
|
-
def post_init(self, mode="sync"):
|
|
112
|
+
def post_init(self, mode="sync", **kwargs):
|
|
114
113
|
self.context.logger.info(f"Loaded {list(self.routes.keys())}")
|
|
115
114
|
|
|
116
115
|
def get_metadata(self):
|
|
@@ -610,7 +609,7 @@ class VotingEnsemble(ParallelRun):
|
|
|
610
609
|
self.model_endpoint_uid = None
|
|
611
610
|
self.shard_by_endpoint = shard_by_endpoint
|
|
612
611
|
|
|
613
|
-
def post_init(self, mode="sync"):
|
|
612
|
+
def post_init(self, mode="sync", **kwargs):
|
|
614
613
|
server = getattr(self.context, "_server", None) or getattr(
|
|
615
614
|
self.context, "server", None
|
|
616
615
|
)
|
|
@@ -619,7 +618,9 @@ class VotingEnsemble(ParallelRun):
|
|
|
619
618
|
return
|
|
620
619
|
|
|
621
620
|
if not self.context.is_mock or self.context.monitoring_mock:
|
|
622
|
-
self.model_endpoint_uid = _init_endpoint_record(
|
|
621
|
+
self.model_endpoint_uid = _init_endpoint_record(
|
|
622
|
+
server, self, creation_strategy=kwargs.get("creation_strategy")
|
|
623
|
+
)
|
|
623
624
|
|
|
624
625
|
self._update_weights(self.weights)
|
|
625
626
|
|
|
@@ -1001,7 +1002,10 @@ class VotingEnsemble(ParallelRun):
|
|
|
1001
1002
|
|
|
1002
1003
|
|
|
1003
1004
|
def _init_endpoint_record(
|
|
1004
|
-
graph_server: GraphServer,
|
|
1005
|
+
graph_server: GraphServer,
|
|
1006
|
+
voting_ensemble: VotingEnsemble,
|
|
1007
|
+
creation_strategy: str,
|
|
1008
|
+
endpoint_type: mlrun.common.schemas.EndpointType,
|
|
1005
1009
|
) -> Union[str, None]:
|
|
1006
1010
|
"""
|
|
1007
1011
|
Initialize model endpoint record and write it into the DB. In general, this method retrieve the unique model
|
|
@@ -1011,61 +1015,44 @@ def _init_endpoint_record(
|
|
|
1011
1015
|
:param graph_server: A GraphServer object which will be used for getting the function uri.
|
|
1012
1016
|
:param voting_ensemble: Voting ensemble serving class. It contains important details for the model endpoint record
|
|
1013
1017
|
such as model name, model path, model version, and the ids of the children model endpoints.
|
|
1014
|
-
|
|
1018
|
+
:param creation_strategy: model endpoint creation strategy :
|
|
1019
|
+
* overwrite - Create a new model endpoint and delete the last old one if it exists.
|
|
1020
|
+
* inplace - Use the existing model endpoint if it already exists (default).
|
|
1021
|
+
* archive - Preserve the old model endpoint and create a new one,
|
|
1022
|
+
tagging it as the latest.
|
|
1023
|
+
:param endpoint_type: model endpoint type
|
|
1015
1024
|
:return: Model endpoint unique ID.
|
|
1016
1025
|
"""
|
|
1017
1026
|
|
|
1018
1027
|
logger.info("Initializing endpoint records")
|
|
1019
|
-
try:
|
|
1020
|
-
model_endpoint = mlrun.get_run_db().get_model_endpoint(
|
|
1021
|
-
project=graph_server.project,
|
|
1022
|
-
name=voting_ensemble.name,
|
|
1023
|
-
function_name=graph_server.function_name,
|
|
1024
|
-
function_tag=graph_server.function_tag or "latest",
|
|
1025
|
-
)
|
|
1026
|
-
except mlrun.errors.MLRunNotFoundError:
|
|
1027
|
-
model_endpoint = None
|
|
1028
|
-
except mlrun.errors.MLRunBadRequestError as err:
|
|
1029
|
-
logger.info(
|
|
1030
|
-
"Cannot get the model endpoints store", err=mlrun.errors.err_to_str(err)
|
|
1031
|
-
)
|
|
1032
|
-
return
|
|
1033
|
-
|
|
1034
|
-
function = mlrun.get_run_db().get_function(
|
|
1035
|
-
name=graph_server.function_name,
|
|
1036
|
-
project=graph_server.project,
|
|
1037
|
-
tag=graph_server.function_tag or "latest",
|
|
1038
|
-
)
|
|
1039
|
-
function_uid = function.get("metadata", {}).get("uid")
|
|
1040
|
-
# Get the children model endpoints ids
|
|
1041
1028
|
children_uids = []
|
|
1042
1029
|
children_names = []
|
|
1043
1030
|
for _, c in voting_ensemble.routes.items():
|
|
1044
1031
|
if hasattr(c, "endpoint_uid"):
|
|
1045
1032
|
children_uids.append(c.endpoint_uid)
|
|
1046
1033
|
children_names.append(c.name)
|
|
1047
|
-
|
|
1034
|
+
try:
|
|
1048
1035
|
logger.info(
|
|
1049
|
-
"Creating a new model endpoint record",
|
|
1036
|
+
"Creating Or Updating a new model endpoint record",
|
|
1050
1037
|
name=voting_ensemble.name,
|
|
1051
1038
|
project=graph_server.project,
|
|
1052
1039
|
function_name=graph_server.function_name,
|
|
1053
1040
|
function_tag=graph_server.function_tag or "latest",
|
|
1054
|
-
function_uid=function_uid,
|
|
1055
1041
|
model_class=voting_ensemble.__class__.__name__,
|
|
1042
|
+
creation_strategy=creation_strategy,
|
|
1056
1043
|
)
|
|
1057
1044
|
model_endpoint = mlrun.common.schemas.ModelEndpoint(
|
|
1058
1045
|
metadata=mlrun.common.schemas.ModelEndpointMetadata(
|
|
1059
1046
|
project=graph_server.project,
|
|
1060
1047
|
name=voting_ensemble.name,
|
|
1061
|
-
endpoint_type=
|
|
1048
|
+
endpoint_type=endpoint_type,
|
|
1062
1049
|
),
|
|
1063
1050
|
spec=mlrun.common.schemas.ModelEndpointSpec(
|
|
1064
1051
|
function_name=graph_server.function_name,
|
|
1065
|
-
function_uid=function_uid,
|
|
1066
1052
|
function_tag=graph_server.function_tag or "latest",
|
|
1067
1053
|
model_class=voting_ensemble.__class__.__name__,
|
|
1068
|
-
children_uids=
|
|
1054
|
+
children_uids=children_uids,
|
|
1055
|
+
children=children_names,
|
|
1069
1056
|
),
|
|
1070
1057
|
status=mlrun.common.schemas.ModelEndpointStatus(
|
|
1071
1058
|
monitoring_mode=mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
|
|
@@ -1074,59 +1061,12 @@ def _init_endpoint_record(
|
|
|
1074
1061
|
),
|
|
1075
1062
|
)
|
|
1076
1063
|
db = mlrun.get_run_db()
|
|
1077
|
-
db.create_model_endpoint(
|
|
1078
|
-
|
|
1079
|
-
elif model_endpoint:
|
|
1080
|
-
attributes = {}
|
|
1081
|
-
if function_uid != model_endpoint.spec.function_uid:
|
|
1082
|
-
attributes[ModelEndpointSchema.FUNCTION_UID] = function_uid
|
|
1083
|
-
if children_uids != model_endpoint.spec.children_uids:
|
|
1084
|
-
attributes[ModelEndpointSchema.CHILDREN_UIDS] = children_uids
|
|
1085
|
-
if (
|
|
1086
|
-
model_endpoint.status.monitoring_mode
|
|
1087
|
-
== mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
|
|
1088
|
-
) != voting_ensemble.context.server.track_models:
|
|
1089
|
-
attributes[ModelEndpointSchema.MONITORING_MODE] = (
|
|
1090
|
-
mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
|
|
1091
|
-
if voting_ensemble.context.server.track_models
|
|
1092
|
-
else mlrun.common.schemas.model_monitoring.ModelMonitoringMode.disabled
|
|
1093
|
-
)
|
|
1094
|
-
if attributes:
|
|
1095
|
-
db = mlrun.get_run_db()
|
|
1096
|
-
logger.info(
|
|
1097
|
-
"Updating model endpoint attributes",
|
|
1098
|
-
attributes=attributes,
|
|
1099
|
-
project=model_endpoint.metadata.project,
|
|
1100
|
-
name=model_endpoint.metadata.name,
|
|
1101
|
-
function_name=model_endpoint.spec.function_name,
|
|
1102
|
-
)
|
|
1103
|
-
model_endpoint = db.patch_model_endpoint(
|
|
1104
|
-
project=model_endpoint.metadata.project,
|
|
1105
|
-
name=model_endpoint.metadata.name,
|
|
1106
|
-
endpoint_id=model_endpoint.metadata.uid,
|
|
1107
|
-
attributes=attributes,
|
|
1108
|
-
)
|
|
1109
|
-
else:
|
|
1110
|
-
logger.info(
|
|
1111
|
-
"Did not create a new model endpoint record, monitoring is disabled"
|
|
1064
|
+
db.create_model_endpoint(
|
|
1065
|
+
model_endpoint=model_endpoint, creation_strategy=creation_strategy
|
|
1112
1066
|
)
|
|
1067
|
+
except mlrun.errors.MLRunInvalidArgumentError as e:
|
|
1068
|
+
logger.info("Failed to create model endpoint record", error=e)
|
|
1113
1069
|
return None
|
|
1114
|
-
|
|
1115
|
-
# Update model endpoint children type
|
|
1116
|
-
logger.info(
|
|
1117
|
-
"Updating children model endpoint type",
|
|
1118
|
-
children_uids=children_uids,
|
|
1119
|
-
children_names=children_names,
|
|
1120
|
-
)
|
|
1121
|
-
for uid, name in zip(children_uids, children_names):
|
|
1122
|
-
mlrun.get_run_db().patch_model_endpoint(
|
|
1123
|
-
name=name,
|
|
1124
|
-
project=graph_server.project,
|
|
1125
|
-
endpoint_id=uid,
|
|
1126
|
-
attributes={
|
|
1127
|
-
ModelEndpointSchema.ENDPOINT_TYPE: mlrun.common.schemas.model_monitoring.EndpointType.LEAF_EP
|
|
1128
|
-
},
|
|
1129
|
-
)
|
|
1130
1070
|
return model_endpoint.metadata.uid
|
|
1131
1071
|
|
|
1132
1072
|
|
|
@@ -1192,7 +1132,7 @@ class EnrichmentModelRouter(ModelRouter):
|
|
|
1192
1132
|
|
|
1193
1133
|
self._feature_service = None
|
|
1194
1134
|
|
|
1195
|
-
def post_init(self, mode="sync"):
|
|
1135
|
+
def post_init(self, mode="sync", **kwargs):
|
|
1196
1136
|
from ..feature_store import get_feature_vector
|
|
1197
1137
|
|
|
1198
1138
|
super().post_init(mode)
|
|
@@ -1342,7 +1282,7 @@ class EnrichmentVotingEnsemble(VotingEnsemble):
|
|
|
1342
1282
|
|
|
1343
1283
|
self._feature_service = None
|
|
1344
1284
|
|
|
1345
|
-
def post_init(self, mode="sync"):
|
|
1285
|
+
def post_init(self, mode="sync", **kwargs):
|
|
1346
1286
|
from ..feature_store import get_feature_vector
|
|
1347
1287
|
|
|
1348
1288
|
super().post_init(mode)
|
mlrun/serving/server.py
CHANGED
|
@@ -367,7 +367,9 @@ def _set_callbacks(server, context):
|
|
|
367
367
|
|
|
368
368
|
async def termination_callback():
|
|
369
369
|
context.logger.info("Termination callback called")
|
|
370
|
-
server.wait_for_completion()
|
|
370
|
+
maybe_coroutine = server.wait_for_completion()
|
|
371
|
+
if asyncio.iscoroutine(maybe_coroutine):
|
|
372
|
+
await maybe_coroutine
|
|
371
373
|
context.logger.info("Termination of async flow is completed")
|
|
372
374
|
|
|
373
375
|
context.platform.set_termination_callback(termination_callback)
|
|
@@ -379,7 +381,9 @@ def _set_callbacks(server, context):
|
|
|
379
381
|
|
|
380
382
|
async def drain_callback():
|
|
381
383
|
context.logger.info("Drain callback called")
|
|
382
|
-
server.wait_for_completion()
|
|
384
|
+
maybe_coroutine = server.wait_for_completion()
|
|
385
|
+
if asyncio.iscoroutine(maybe_coroutine):
|
|
386
|
+
await maybe_coroutine
|
|
383
387
|
context.logger.info(
|
|
384
388
|
"Termination of async flow is completed. Rerunning async flow."
|
|
385
389
|
)
|
mlrun/serving/states.py
CHANGED
|
@@ -25,11 +25,12 @@ import pathlib
|
|
|
25
25
|
import traceback
|
|
26
26
|
from copy import copy, deepcopy
|
|
27
27
|
from inspect import getfullargspec, signature
|
|
28
|
-
from typing import Any, Optional, Union
|
|
28
|
+
from typing import Any, Optional, Union, cast
|
|
29
29
|
|
|
30
30
|
import storey.utils
|
|
31
31
|
|
|
32
32
|
import mlrun
|
|
33
|
+
import mlrun.common.schemas as schemas
|
|
33
34
|
|
|
34
35
|
from ..config import config
|
|
35
36
|
from ..datastore import get_stream_pusher
|
|
@@ -81,22 +82,28 @@ _task_step_fields = [
|
|
|
81
82
|
"responder",
|
|
82
83
|
"input_path",
|
|
83
84
|
"result_path",
|
|
85
|
+
"model_endpoint_creation_strategy",
|
|
86
|
+
"endpoint_type",
|
|
84
87
|
]
|
|
85
88
|
|
|
86
89
|
|
|
87
90
|
MAX_ALLOWED_STEPS = 4500
|
|
88
91
|
|
|
89
92
|
|
|
90
|
-
def
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def new_remote_endpoint(url, **class_args):
|
|
93
|
+
def new_remote_endpoint(
|
|
94
|
+
url: str,
|
|
95
|
+
creation_strategy: schemas.ModelEndpointCreationStrategy,
|
|
96
|
+
endpoint_type: schemas.EndpointType,
|
|
97
|
+
**class_args,
|
|
98
|
+
):
|
|
97
99
|
class_args = deepcopy(class_args)
|
|
98
100
|
class_args["url"] = url
|
|
99
|
-
return TaskStep(
|
|
101
|
+
return TaskStep(
|
|
102
|
+
"$remote",
|
|
103
|
+
class_args=class_args,
|
|
104
|
+
model_endpoint_creation_strategy=creation_strategy,
|
|
105
|
+
endpoint_type=endpoint_type,
|
|
106
|
+
)
|
|
100
107
|
|
|
101
108
|
|
|
102
109
|
class BaseStep(ModelObj):
|
|
@@ -419,6 +426,10 @@ class TaskStep(BaseStep):
|
|
|
419
426
|
responder: Optional[bool] = None,
|
|
420
427
|
input_path: Optional[str] = None,
|
|
421
428
|
result_path: Optional[str] = None,
|
|
429
|
+
model_endpoint_creation_strategy: Optional[
|
|
430
|
+
schemas.ModelEndpointCreationStrategy
|
|
431
|
+
] = schemas.ModelEndpointCreationStrategy.INPLACE,
|
|
432
|
+
endpoint_type: Optional[schemas.EndpointType] = schemas.EndpointType.NODE_EP,
|
|
422
433
|
):
|
|
423
434
|
super().__init__(name, after)
|
|
424
435
|
self.class_name = class_name
|
|
@@ -438,6 +449,8 @@ class TaskStep(BaseStep):
|
|
|
438
449
|
self.on_error = None
|
|
439
450
|
self._inject_context = False
|
|
440
451
|
self._call_with_event = False
|
|
452
|
+
self.model_endpoint_creation_strategy = model_endpoint_creation_strategy
|
|
453
|
+
self.endpoint_type = endpoint_type
|
|
441
454
|
|
|
442
455
|
def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
|
|
443
456
|
self.context = context
|
|
@@ -554,7 +567,11 @@ class TaskStep(BaseStep):
|
|
|
554
567
|
|
|
555
568
|
def _post_init(self, mode="sync"):
|
|
556
569
|
if self._object and hasattr(self._object, "post_init"):
|
|
557
|
-
self._object.post_init(
|
|
570
|
+
self._object.post_init(
|
|
571
|
+
mode,
|
|
572
|
+
creation_strategy=self.model_endpoint_creation_strategy,
|
|
573
|
+
endpoint_type=self.endpoint_type,
|
|
574
|
+
)
|
|
558
575
|
if hasattr(self._object, "model_endpoint_uid"):
|
|
559
576
|
self.endpoint_uid = self._object.model_endpoint_uid
|
|
560
577
|
if hasattr(self._object, "name"):
|
|
@@ -705,6 +722,7 @@ class RouterStep(TaskStep):
|
|
|
705
722
|
)
|
|
706
723
|
self._routes: ObjectDict = None
|
|
707
724
|
self.routes = routes
|
|
725
|
+
self.endpoint_type = schemas.EndpointType.ROUTER
|
|
708
726
|
|
|
709
727
|
def get_children(self):
|
|
710
728
|
"""get child steps (routes)"""
|
|
@@ -726,6 +744,7 @@ class RouterStep(TaskStep):
|
|
|
726
744
|
class_name=None,
|
|
727
745
|
handler=None,
|
|
728
746
|
function=None,
|
|
747
|
+
creation_strategy: schemas.ModelEndpointCreationStrategy = schemas.ModelEndpointCreationStrategy.INPLACE,
|
|
729
748
|
**class_args,
|
|
730
749
|
):
|
|
731
750
|
"""add child route step or class to the router
|
|
@@ -736,12 +755,23 @@ class RouterStep(TaskStep):
|
|
|
736
755
|
:param class_args: class init arguments
|
|
737
756
|
:param handler: class handler to invoke on run/event
|
|
738
757
|
:param function: function this step should run in
|
|
758
|
+
:param creation_strategy: model endpoint creation strategy :
|
|
759
|
+
* overwrite - Create a new model endpoint and delete the last old one if it exists.
|
|
760
|
+
* inplace - Use the existing model endpoint if it already exists (default).
|
|
761
|
+
* archive - Preserve the old model endpoint and create a new one,
|
|
762
|
+
tagging it as the latest.
|
|
739
763
|
"""
|
|
740
764
|
|
|
741
765
|
if not route and not class_name and not handler:
|
|
742
766
|
raise MLRunInvalidArgumentError("route or class_name must be specified")
|
|
743
767
|
if not route:
|
|
744
|
-
route = TaskStep(
|
|
768
|
+
route = TaskStep(
|
|
769
|
+
class_name,
|
|
770
|
+
class_args,
|
|
771
|
+
handler=handler,
|
|
772
|
+
model_endpoint_creation_strategy=creation_strategy,
|
|
773
|
+
endpoint_type=schemas.EndpointType.NODE_EP,
|
|
774
|
+
)
|
|
745
775
|
route.function = function or route.function
|
|
746
776
|
|
|
747
777
|
if len(self._routes) >= MAX_ALLOWED_STEPS:
|
|
@@ -805,6 +835,106 @@ class RouterStep(TaskStep):
|
|
|
805
835
|
)
|
|
806
836
|
|
|
807
837
|
|
|
838
|
+
class Model(storey.ParallelExecutionRunnable):
|
|
839
|
+
def load(self) -> None:
|
|
840
|
+
"""Override to load model if needed."""
|
|
841
|
+
pass
|
|
842
|
+
|
|
843
|
+
def init(self):
|
|
844
|
+
self.load()
|
|
845
|
+
|
|
846
|
+
def predict(self, body: Any) -> Any:
|
|
847
|
+
"""Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
|
|
848
|
+
return body
|
|
849
|
+
|
|
850
|
+
async def predict_async(self, body: Any) -> Any:
|
|
851
|
+
"""Override to implement prediction logic if the logic requires asyncio."""
|
|
852
|
+
return body
|
|
853
|
+
|
|
854
|
+
def run(self, body: Any, path: str) -> Any:
|
|
855
|
+
return self.predict(body)
|
|
856
|
+
|
|
857
|
+
async def run_async(self, body: Any, path: str) -> Any:
|
|
858
|
+
return self.predict(body)
|
|
859
|
+
|
|
860
|
+
|
|
861
|
+
class ModelSelector:
|
|
862
|
+
"""Used to select which models to run on each event."""
|
|
863
|
+
|
|
864
|
+
def select(
|
|
865
|
+
self, event, available_models: list[Model]
|
|
866
|
+
) -> Union[list[str], list[Model]]:
|
|
867
|
+
"""
|
|
868
|
+
Given an event, returns a list of model names or a list of model objects to run on the event.
|
|
869
|
+
If None is returned, all models will be run.
|
|
870
|
+
|
|
871
|
+
:param event: The full event
|
|
872
|
+
:param available_models: List of available models
|
|
873
|
+
"""
|
|
874
|
+
pass
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
class ModelRunner(storey.ParallelExecution):
|
|
878
|
+
"""
|
|
879
|
+
Runs multiple Models on each event. See ModelRunnerStep.
|
|
880
|
+
|
|
881
|
+
:param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
|
|
882
|
+
event. Optional. If not passed, all models will be run.
|
|
883
|
+
"""
|
|
884
|
+
|
|
885
|
+
def __init__(self, *args, model_selector: Optional[ModelSelector] = None, **kwargs):
|
|
886
|
+
super().__init__(*args, **kwargs)
|
|
887
|
+
self.model_selector = model_selector or ModelSelector()
|
|
888
|
+
|
|
889
|
+
def select_runnables(self, event):
|
|
890
|
+
models = cast(list[Model], self.runnables)
|
|
891
|
+
return self.model_selector.select(event, models)
|
|
892
|
+
|
|
893
|
+
|
|
894
|
+
class ModelRunnerStep(TaskStep):
|
|
895
|
+
"""
|
|
896
|
+
Runs multiple Models on each event.
|
|
897
|
+
|
|
898
|
+
example::
|
|
899
|
+
|
|
900
|
+
model_runner_step = ModelRunnerStep(name="my_model_runner")
|
|
901
|
+
model_runner_step.add_model(MyModel(name="my_model"))
|
|
902
|
+
graph.to(model_runner_step)
|
|
903
|
+
|
|
904
|
+
:param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
|
|
905
|
+
event. Optional. If not passed, all models will be run.
|
|
906
|
+
"""
|
|
907
|
+
|
|
908
|
+
kind = "model_runner"
|
|
909
|
+
|
|
910
|
+
def __init__(
|
|
911
|
+
self,
|
|
912
|
+
*args,
|
|
913
|
+
model_selector: Optional[Union[str, ModelSelector]] = None,
|
|
914
|
+
**kwargs,
|
|
915
|
+
):
|
|
916
|
+
self._models = []
|
|
917
|
+
super().__init__(
|
|
918
|
+
*args,
|
|
919
|
+
class_name="mlrun.serving.ModelRunner",
|
|
920
|
+
class_args=dict(runnables=self._models, model_selector=model_selector),
|
|
921
|
+
**kwargs,
|
|
922
|
+
)
|
|
923
|
+
|
|
924
|
+
def add_model(self, model: Model) -> None:
|
|
925
|
+
"""Add a Model to this ModelRunner."""
|
|
926
|
+
self._models.append(model)
|
|
927
|
+
|
|
928
|
+
def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
|
|
929
|
+
model_selector = self.class_args.get("model_selector")
|
|
930
|
+
if isinstance(model_selector, str):
|
|
931
|
+
model_selector = get_class(model_selector, namespace)()
|
|
932
|
+
self._async_object = ModelRunner(
|
|
933
|
+
self.class_args.get("runnables"),
|
|
934
|
+
model_selector=model_selector,
|
|
935
|
+
)
|
|
936
|
+
|
|
937
|
+
|
|
808
938
|
class QueueStep(BaseStep):
|
|
809
939
|
"""queue step, implement an async queue or represent a stream"""
|
|
810
940
|
|
|
@@ -1344,8 +1474,9 @@ class FlowStep(BaseStep):
|
|
|
1344
1474
|
|
|
1345
1475
|
if self._controller:
|
|
1346
1476
|
if hasattr(self._controller, "terminate"):
|
|
1347
|
-
self._controller.terminate()
|
|
1348
|
-
|
|
1477
|
+
return self._controller.terminate(wait=True)
|
|
1478
|
+
else:
|
|
1479
|
+
return self._controller.await_termination()
|
|
1349
1480
|
|
|
1350
1481
|
def plot(self, filename=None, format=None, source=None, targets=None, **kw):
|
|
1351
1482
|
"""plot/save graph using graphviz
|
|
@@ -1433,6 +1564,7 @@ classes_map = {
|
|
|
1433
1564
|
"queue": QueueStep,
|
|
1434
1565
|
"error_step": ErrorStep,
|
|
1435
1566
|
"monitoring_application": MonitoringApplicationStep,
|
|
1567
|
+
"model_runner": ModelRunnerStep,
|
|
1436
1568
|
}
|
|
1437
1569
|
|
|
1438
1570
|
|
|
@@ -1572,6 +1704,10 @@ def params_to_step(
|
|
|
1572
1704
|
input_path: Optional[str] = None,
|
|
1573
1705
|
result_path: Optional[str] = None,
|
|
1574
1706
|
class_args=None,
|
|
1707
|
+
model_endpoint_creation_strategy: Optional[
|
|
1708
|
+
schemas.ModelEndpointCreationStrategy
|
|
1709
|
+
] = None,
|
|
1710
|
+
endpoint_type: Optional[schemas.EndpointType] = None,
|
|
1575
1711
|
):
|
|
1576
1712
|
"""return step object from provided params or classes/objects"""
|
|
1577
1713
|
|
|
@@ -1587,6 +1723,9 @@ def params_to_step(
|
|
|
1587
1723
|
step.full_event = full_event or step.full_event
|
|
1588
1724
|
step.input_path = input_path or step.input_path
|
|
1589
1725
|
step.result_path = result_path or step.result_path
|
|
1726
|
+
if kind == StepKinds.task:
|
|
1727
|
+
step.model_endpoint_creation_strategy = model_endpoint_creation_strategy
|
|
1728
|
+
step.endpoint_type = endpoint_type
|
|
1590
1729
|
|
|
1591
1730
|
elif class_name and class_name in queue_class_names:
|
|
1592
1731
|
if "path" not in class_args:
|
|
@@ -1627,6 +1766,8 @@ def params_to_step(
|
|
|
1627
1766
|
full_event=full_event,
|
|
1628
1767
|
input_path=input_path,
|
|
1629
1768
|
result_path=result_path,
|
|
1769
|
+
model_endpoint_creation_strategy=model_endpoint_creation_strategy,
|
|
1770
|
+
endpoint_type=endpoint_type,
|
|
1630
1771
|
)
|
|
1631
1772
|
else:
|
|
1632
1773
|
raise MLRunInvalidArgumentError("class_name or handler must be provided")
|
mlrun/serving/v2_serving.py
CHANGED
|
@@ -23,7 +23,6 @@ import mlrun.common.schemas.model_monitoring
|
|
|
23
23
|
import mlrun.model_monitoring
|
|
24
24
|
from mlrun.utils import logger, now_date
|
|
25
25
|
|
|
26
|
-
from ..common.schemas.model_monitoring import ModelEndpointSchema
|
|
27
26
|
from .server import GraphServer
|
|
28
27
|
from .utils import StepToDict, _extract_input_data, _update_result_body
|
|
29
28
|
|
|
@@ -130,7 +129,7 @@ class V2ModelServer(StepToDict):
|
|
|
130
129
|
self.ready = True
|
|
131
130
|
self.context.logger.info(f"model {self.name} was loaded")
|
|
132
131
|
|
|
133
|
-
def post_init(self, mode="sync"):
|
|
132
|
+
def post_init(self, mode="sync", **kwargs):
|
|
134
133
|
"""sync/async model loading, for internal use"""
|
|
135
134
|
if not self.ready:
|
|
136
135
|
if mode == "async":
|
|
@@ -149,7 +148,10 @@ class V2ModelServer(StepToDict):
|
|
|
149
148
|
|
|
150
149
|
if not self.context.is_mock or self.context.monitoring_mock:
|
|
151
150
|
self.model_endpoint_uid = _init_endpoint_record(
|
|
152
|
-
graph_server=server,
|
|
151
|
+
graph_server=server,
|
|
152
|
+
model=self,
|
|
153
|
+
creation_strategy=kwargs.get("creation_strategy"),
|
|
154
|
+
endpoint_type=kwargs.get("endpoint_type"),
|
|
153
155
|
)
|
|
154
156
|
self._model_logger = (
|
|
155
157
|
_ModelLogPusher(self, self.context)
|
|
@@ -554,7 +556,10 @@ class _ModelLogPusher:
|
|
|
554
556
|
|
|
555
557
|
|
|
556
558
|
def _init_endpoint_record(
|
|
557
|
-
graph_server: GraphServer,
|
|
559
|
+
graph_server: GraphServer,
|
|
560
|
+
model: V2ModelServer,
|
|
561
|
+
creation_strategy: str,
|
|
562
|
+
endpoint_type: mlrun.common.schemas.EndpointType,
|
|
558
563
|
) -> Union[str, None]:
|
|
559
564
|
"""
|
|
560
565
|
Initialize model endpoint record and write it into the DB. In general, this method retrieve the unique model
|
|
@@ -564,6 +569,12 @@ def _init_endpoint_record(
|
|
|
564
569
|
:param graph_server: A GraphServer object which will be used for getting the function uri.
|
|
565
570
|
:param model: Base model serving class (v2). It contains important details for the model endpoint record
|
|
566
571
|
such as model name, model path, and model version.
|
|
572
|
+
:param creation_strategy: model endpoint creation strategy :
|
|
573
|
+
* overwrite - Create a new model endpoint and delete the last old one if it exists.
|
|
574
|
+
* inplace - Use the existing model endpoint if it already exists (default).
|
|
575
|
+
* archive - Preserve the old model endpoint and create a new one,
|
|
576
|
+
tagging it as the latest.
|
|
577
|
+
:param endpoint_type model endpoint type
|
|
567
578
|
|
|
568
579
|
:return: Model endpoint unique ID.
|
|
569
580
|
"""
|
|
@@ -583,51 +594,30 @@ def _init_endpoint_record(
|
|
|
583
594
|
model_uid = None
|
|
584
595
|
model_tag = None
|
|
585
596
|
model_labels = {}
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
name=model.name,
|
|
590
|
-
function_name=graph_server.function_name,
|
|
591
|
-
function_tag=graph_server.function_tag or "latest",
|
|
592
|
-
)
|
|
593
|
-
except mlrun.errors.MLRunNotFoundError:
|
|
594
|
-
model_ep = None
|
|
595
|
-
except mlrun.errors.MLRunBadRequestError as err:
|
|
596
|
-
logger.info(
|
|
597
|
-
"Cannot get the model endpoints store", err=mlrun.errors.err_to_str(err)
|
|
598
|
-
)
|
|
599
|
-
return
|
|
600
|
-
|
|
601
|
-
function = mlrun.get_run_db().get_function(
|
|
602
|
-
name=graph_server.function_name,
|
|
597
|
+
logger.info(
|
|
598
|
+
"Creating Or Updating a new model endpoint record",
|
|
599
|
+
name=model.name,
|
|
603
600
|
project=graph_server.project,
|
|
604
|
-
|
|
601
|
+
function_name=graph_server.function_name,
|
|
602
|
+
function_tag=graph_server.function_tag or "latest",
|
|
603
|
+
model_name=model_name,
|
|
604
|
+
model_tag=model_tag,
|
|
605
|
+
model_db_key=model_db_key,
|
|
606
|
+
model_uid=model_uid,
|
|
607
|
+
model_class=model.__class__.__name__,
|
|
608
|
+
creation_strategy=creation_strategy,
|
|
609
|
+
endpoint_type=endpoint_type,
|
|
605
610
|
)
|
|
606
|
-
|
|
607
|
-
if not model_ep and model.context.server.track_models:
|
|
608
|
-
logger.info(
|
|
609
|
-
"Creating a new model endpoint record",
|
|
610
|
-
name=model.name,
|
|
611
|
-
project=graph_server.project,
|
|
612
|
-
function_name=graph_server.function_name,
|
|
613
|
-
function_tag=graph_server.function_tag or "latest",
|
|
614
|
-
function_uid=function_uid,
|
|
615
|
-
model_name=model_name,
|
|
616
|
-
model_tag=model_tag,
|
|
617
|
-
model_db_key=model_db_key,
|
|
618
|
-
model_uid=model_uid,
|
|
619
|
-
model_class=model.__class__.__name__,
|
|
620
|
-
)
|
|
611
|
+
try:
|
|
621
612
|
model_ep = mlrun.common.schemas.ModelEndpoint(
|
|
622
613
|
metadata=mlrun.common.schemas.ModelEndpointMetadata(
|
|
623
614
|
project=graph_server.project,
|
|
624
615
|
labels=model_labels,
|
|
625
616
|
name=model.name,
|
|
626
|
-
endpoint_type=
|
|
617
|
+
endpoint_type=endpoint_type,
|
|
627
618
|
),
|
|
628
619
|
spec=mlrun.common.schemas.ModelEndpointSpec(
|
|
629
620
|
function_name=graph_server.function_name,
|
|
630
|
-
function_uid=function_uid,
|
|
631
621
|
function_tag=graph_server.function_tag or "latest",
|
|
632
622
|
model_name=model_name,
|
|
633
623
|
model_db_key=model_db_key,
|
|
@@ -642,49 +632,11 @@ def _init_endpoint_record(
|
|
|
642
632
|
),
|
|
643
633
|
)
|
|
644
634
|
db = mlrun.get_run_db()
|
|
645
|
-
model_ep = db.create_model_endpoint(
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
attributes[ModelEndpointSchema.FUNCTION_UID] = function_uid
|
|
651
|
-
if model_name != model_ep.spec.model_name:
|
|
652
|
-
attributes[ModelEndpointSchema.MODEL_NAME] = model_name
|
|
653
|
-
if model_uid != model_ep.spec.model_uid:
|
|
654
|
-
attributes[ModelEndpointSchema.MODEL_UID] = model_uid
|
|
655
|
-
if model_tag != model_ep.spec.model_tag:
|
|
656
|
-
attributes[ModelEndpointSchema.MODEL_TAG] = model_tag
|
|
657
|
-
if model_db_key != model_ep.spec.model_db_key:
|
|
658
|
-
attributes[ModelEndpointSchema.MODEL_DB_KEY] = model_db_key
|
|
659
|
-
if model_labels != model_ep.metadata.labels:
|
|
660
|
-
attributes[ModelEndpointSchema.LABELS] = model_labels
|
|
661
|
-
if model.__class__.__name__ != model_ep.spec.model_class:
|
|
662
|
-
attributes[ModelEndpointSchema.MODEL_CLASS] = model.__class__.__name__
|
|
663
|
-
if (
|
|
664
|
-
model_ep.status.monitoring_mode
|
|
665
|
-
== mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
|
|
666
|
-
) != model.context.server.track_models:
|
|
667
|
-
attributes[ModelEndpointSchema.MONITORING_MODE] = (
|
|
668
|
-
mlrun.common.schemas.model_monitoring.ModelMonitoringMode.enabled
|
|
669
|
-
if model.context.server.track_models
|
|
670
|
-
else mlrun.common.schemas.model_monitoring.ModelMonitoringMode.disabled
|
|
671
|
-
)
|
|
672
|
-
if attributes:
|
|
673
|
-
logger.info(
|
|
674
|
-
"Updating model endpoint attributes",
|
|
675
|
-
attributes=attributes,
|
|
676
|
-
project=model_ep.metadata.project,
|
|
677
|
-
name=model_ep.metadata.name,
|
|
678
|
-
function_name=model_ep.spec.function_name,
|
|
679
|
-
)
|
|
680
|
-
db = mlrun.get_run_db()
|
|
681
|
-
model_ep = db.patch_model_endpoint(
|
|
682
|
-
project=model_ep.metadata.project,
|
|
683
|
-
name=model_ep.metadata.name,
|
|
684
|
-
endpoint_id=model_ep.metadata.uid,
|
|
685
|
-
attributes=attributes,
|
|
686
|
-
)
|
|
687
|
-
else:
|
|
635
|
+
model_ep = db.create_model_endpoint(
|
|
636
|
+
model_endpoint=model_ep, creation_strategy=creation_strategy
|
|
637
|
+
)
|
|
638
|
+
except mlrun.errors.MLRunBadRequestError as e:
|
|
639
|
+
logger.info("Failed to create model endpoint record", error=e)
|
|
688
640
|
return None
|
|
689
641
|
|
|
690
642
|
return model_ep.metadata.uid
|