mlrun 1.10.0rc13__py3-none-any.whl → 1.10.0rc42__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mlrun might be problematic.
- mlrun/__init__.py +22 -2
- mlrun/artifacts/base.py +0 -31
- mlrun/artifacts/document.py +6 -1
- mlrun/artifacts/llm_prompt.py +123 -25
- mlrun/artifacts/manager.py +0 -5
- mlrun/artifacts/model.py +3 -3
- mlrun/common/constants.py +10 -1
- mlrun/common/formatters/artifact.py +1 -0
- mlrun/common/model_monitoring/helpers.py +86 -0
- mlrun/common/schemas/__init__.py +3 -0
- mlrun/common/schemas/auth.py +2 -0
- mlrun/common/schemas/function.py +10 -0
- mlrun/common/schemas/hub.py +30 -18
- mlrun/common/schemas/model_monitoring/__init__.py +3 -0
- mlrun/common/schemas/model_monitoring/constants.py +30 -6
- mlrun/common/schemas/model_monitoring/functions.py +14 -5
- mlrun/common/schemas/model_monitoring/model_endpoints.py +21 -0
- mlrun/common/schemas/pipeline.py +1 -1
- mlrun/common/schemas/serving.py +3 -0
- mlrun/common/schemas/workflow.py +3 -1
- mlrun/common/secrets.py +22 -1
- mlrun/config.py +33 -11
- mlrun/datastore/__init__.py +11 -3
- mlrun/datastore/azure_blob.py +162 -47
- mlrun/datastore/datastore.py +9 -4
- mlrun/datastore/datastore_profile.py +61 -5
- mlrun/datastore/model_provider/huggingface_provider.py +363 -0
- mlrun/datastore/model_provider/mock_model_provider.py +87 -0
- mlrun/datastore/model_provider/model_provider.py +230 -65
- mlrun/datastore/model_provider/openai_provider.py +295 -42
- mlrun/datastore/s3.py +24 -2
- mlrun/datastore/storeytargets.py +2 -3
- mlrun/datastore/utils.py +15 -3
- mlrun/db/base.py +47 -19
- mlrun/db/httpdb.py +120 -56
- mlrun/db/nopdb.py +38 -10
- mlrun/execution.py +70 -19
- mlrun/hub/__init__.py +15 -0
- mlrun/hub/module.py +181 -0
- mlrun/k8s_utils.py +105 -16
- mlrun/launcher/base.py +13 -6
- mlrun/launcher/local.py +15 -0
- mlrun/model.py +24 -3
- mlrun/model_monitoring/__init__.py +1 -0
- mlrun/model_monitoring/api.py +66 -27
- mlrun/model_monitoring/applications/__init__.py +1 -1
- mlrun/model_monitoring/applications/base.py +509 -117
- mlrun/model_monitoring/applications/context.py +2 -4
- mlrun/model_monitoring/applications/results.py +4 -7
- mlrun/model_monitoring/controller.py +239 -101
- mlrun/model_monitoring/db/_schedules.py +116 -33
- mlrun/model_monitoring/db/_stats.py +4 -3
- mlrun/model_monitoring/db/tsdb/base.py +100 -9
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +11 -6
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +191 -50
- mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +51 -0
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +17 -4
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +259 -40
- mlrun/model_monitoring/helpers.py +54 -9
- mlrun/model_monitoring/stream_processing.py +45 -14
- mlrun/model_monitoring/writer.py +220 -1
- mlrun/platforms/__init__.py +3 -2
- mlrun/platforms/iguazio.py +7 -3
- mlrun/projects/operations.py +6 -1
- mlrun/projects/pipelines.py +46 -26
- mlrun/projects/project.py +166 -58
- mlrun/run.py +94 -17
- mlrun/runtimes/__init__.py +18 -0
- mlrun/runtimes/base.py +14 -6
- mlrun/runtimes/daskjob.py +7 -0
- mlrun/runtimes/local.py +5 -2
- mlrun/runtimes/mounts.py +20 -2
- mlrun/runtimes/mpijob/abstract.py +6 -0
- mlrun/runtimes/mpijob/v1.py +6 -0
- mlrun/runtimes/nuclio/__init__.py +1 -0
- mlrun/runtimes/nuclio/application/application.py +149 -17
- mlrun/runtimes/nuclio/function.py +76 -27
- mlrun/runtimes/nuclio/serving.py +97 -15
- mlrun/runtimes/pod.py +234 -21
- mlrun/runtimes/remotesparkjob.py +6 -0
- mlrun/runtimes/sparkjob/spark3job.py +6 -0
- mlrun/runtimes/utils.py +49 -11
- mlrun/secrets.py +54 -13
- mlrun/serving/__init__.py +2 -0
- mlrun/serving/remote.py +79 -6
- mlrun/serving/routers.py +23 -41
- mlrun/serving/server.py +320 -80
- mlrun/serving/states.py +725 -157
- mlrun/serving/steps.py +62 -0
- mlrun/serving/system_steps.py +200 -119
- mlrun/serving/v2_serving.py +9 -10
- mlrun/utils/helpers.py +288 -88
- mlrun/utils/logger.py +3 -1
- mlrun/utils/notifications/notification/base.py +18 -0
- mlrun/utils/notifications/notification/git.py +2 -4
- mlrun/utils/notifications/notification/slack.py +2 -4
- mlrun/utils/notifications/notification/webhook.py +2 -5
- mlrun/utils/notifications/notification_pusher.py +1 -1
- mlrun/utils/retryer.py +15 -2
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/METADATA +45 -51
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/RECORD +106 -101
- mlrun/api/schemas/__init__.py +0 -259
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/WHEEL +0 -0
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/entry_points.txt +0 -0
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/licenses/LICENSE +0 -0
- {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py
CHANGED
(Removed lines whose original text was truncated by the diff viewer appear below as bare "-" markers.)

@@ -24,6 +24,7 @@ import inspect
 import os
 import pathlib
 import traceback
+import warnings
 from abc import ABC
 from copy import copy, deepcopy
 from inspect import getfullargspec, signature
@@ -35,20 +36,24 @@ from storey import ParallelExecutionMechanisms
 import mlrun
 import mlrun.artifacts
 import mlrun.common.schemas as schemas
-from mlrun.artifacts.llm_prompt import LLMPromptArtifact
+from mlrun.artifacts.llm_prompt import LLMPromptArtifact, PlaceholderDefaultDict
 from mlrun.artifacts.model import ModelArtifact
 from mlrun.datastore.datastore_profile import (
-
+    DatastoreProfileKafkaStream,
     DatastoreProfileKafkaTarget,
     DatastoreProfileV3io,
     datastore_profile_read,
 )
-from mlrun.datastore.model_provider.model_provider import
+from mlrun.datastore.model_provider.model_provider import (
+    InvokeResponseFormat,
+    ModelProvider,
+    UsageResponseKeys,
+)
 from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
-from mlrun.utils import logger
+from mlrun.utils import get_data_from_path, logger, set_data_by_path, split_path
 
 from ..config import config
-from ..datastore import get_stream_pusher
+from ..datastore import _DummyStream, get_stream_pusher
 from ..datastore.utils import (
     get_kafka_brokers_from_dict,
     parse_kafka_url,
@@ -501,10 +506,15 @@ class BaseStep(ModelObj):
     def verify_model_runner_step(
         self,
         step: "ModelRunnerStep",
+        step_model_endpoints_names: Optional[list[str]] = None,
+        verify_shared_models: bool = True,
     ):
         """
         Verify ModelRunnerStep, can be part of Flow graph and models can not repeat in graph.
-        :param step:
+        :param step: ModelRunnerStep to verify
+        :param step_model_endpoints_names: List of model endpoints names that are in the step.
+            if provided will ignore step models and verify only the models on list.
+        :param verify_shared_models: If True, verify that shared models are defined in the graph.
         """
 
         if not isinstance(step, ModelRunnerStep):
@@ -512,11 +522,13 @@ class BaseStep(ModelObj):
 
         root = self._extract_root_step()
 
-        if not isinstance(root, RootFlowStep)
+        if not isinstance(root, RootFlowStep) or (
+            isinstance(root, RootFlowStep) and root.engine != "async"
+        ):
             raise GraphError(
                 "ModelRunnerStep can be added to 'Flow' topology graph only"
             )
-        step_model_endpoints_names = list(
+        step_model_endpoints_names = step_model_endpoints_names or list(
             step.class_args.get(schemas.ModelRunnerStepData.MODELS, {}).keys()
         )
         # Get all model_endpoints names that are in both lists
@@ -530,13 +542,14 @@ class BaseStep(ModelObj):
                 f"The graph already contains the model endpoints named - {common_endpoints_names}."
             )
 
-
-
+        if verify_shared_models:
+            # Check if shared models are defined in the graph
+            self._verify_shared_models(root, step, step_model_endpoints_names)
         # Update model endpoints names in the root step
         root.update_model_endpoints_names(step_model_endpoints_names)
 
-    @staticmethod
     def _verify_shared_models(
+        self,
         root: "RootFlowStep",
         step: "ModelRunnerStep",
         step_model_endpoints_names: list[str],
@@ -565,33 +578,41 @@ class BaseStep(ModelObj):
             prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
             # if the model artifact is a prompt, we need to get the model URI
             # to ensure that the shared runnable name is correct
+            llm_artifact_uri = None
             if prefix == mlrun.utils.StorePrefix.LLMPrompt:
                 llm_artifact, _ = mlrun.store_manager.get_store_artifact(
                     model_artifact_uri
                 )
-
-
-
+                llm_artifact_uri = llm_artifact.uri
+                model_artifact_uri = mlrun.utils.remove_tag_from_artifact_uri(
+                    llm_artifact.spec.parent_uri
+                )
+            actual_shared_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )
 
-            if not
-
-
-
-
-
-
-
-            ]["shared_runnable_name"] = actual_shared_name
-            shared_models.append(actual_shared_name)
+            if not actual_shared_name:
+                raise GraphError(
+                    f"Can't find shared model named {shared_runnable_name}"
+                )
+            elif not shared_runnable_name:
+                step.class_args[schemas.ModelRunnerStepData.MODELS][name][
+                    schemas.ModelsData.MODEL_PARAMETERS.value
+                ]["shared_runnable_name"] = actual_shared_name
             elif actual_shared_name != shared_runnable_name:
                 raise GraphError(
                     f"Model endpoint {name} shared runnable name mismatch: "
                     f"expected {actual_shared_name}, got {shared_runnable_name}"
                 )
-
-
-
+            shared_models.append(actual_shared_name)
+            self._edit_proxy_model_data(
+                step,
+                name,
+                actual_shared_name,
+                shared_model_params,
+                shared_model_class,
+                llm_artifact_uri or model_artifact_uri,
+            )
         undefined_shared_models = list(
             set(shared_models) - set(root.shared_models.keys())
         )
@@ -600,6 +621,52 @@ class BaseStep(ModelObj):
                 f"The following shared models are not defined in the graph: {undefined_shared_models}."
             )
 
+    @staticmethod
+    def _edit_proxy_model_data(
+        step: "ModelRunnerStep",
+        name: str,
+        actual_shared_name: str,
+        shared_model_params: dict,
+        shared_model_class: Any,
+        artifact: Union[ModelArtifact, LLMPromptArtifact, str],
+    ):
+        monitoring_data = step.class_args.setdefault(
+            schemas.ModelRunnerStepData.MONITORING_DATA, {}
+        )
+
+        # edit monitoring data according to the shared model parameters
+        monitoring_data[name][schemas.MonitoringData.INPUT_PATH] = shared_model_params[
+            "input_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.RESULT_PATH] = shared_model_params[
+            "result_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.INPUTS] = shared_model_params[
+            "inputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.OUTPUTS] = shared_model_params[
+            "outputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.MODEL_CLASS] = (
+            shared_model_class
+            if isinstance(shared_model_class, str)
+            else shared_model_class.__class__.__name__
+        )
+        if actual_shared_name and actual_shared_name not in step._shared_proxy_mapping:
+            step._shared_proxy_mapping[actual_shared_name] = {
+                name: artifact.uri
+                if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
+                else artifact
+            }
+        elif actual_shared_name:
+            step._shared_proxy_mapping[actual_shared_name].update(
+                {
+                    name: artifact.uri
+                    if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
+                    else artifact
+                }
+            )
+
 
 class TaskStep(BaseStep):
     """task execution step, runs a class or handler"""
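The new `_edit_proxy_model_data` helper copies the shared model's paths and schemas into the proxy endpoint's monitoring data and records which artifact each proxy endpoint resolves to. A hypothetical illustration of the structures it fills in (key names follow the enums used above, the concrete values are invented for illustration):

```python
# Hypothetical illustration only; real keys come from schemas.MonitoringData.
monitoring_data = {
    "my-endpoint": {
        "input_path": ["data", "b"],   # copied from the shared model's parameters
        "result_path": ["out", "b"],
        "inputs": ["f0", "f1"],
        "outputs": ["answer", "usage"],
        "model_class": "LLModel",
    }
}
# step._shared_proxy_mapping maps a shared runnable to the proxy endpoints it serves
shared_proxy_mapping = {"shared-llm": {"my-endpoint": "store://llm-prompts/proj/my-prompt"}}
```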
@@ -1081,6 +1148,8 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         "raise_exception",
         "artifact_uri",
         "shared_runnable_name",
+        "shared_proxy_mapping",
+        "execution_mechanism",
     ]
     kind = "model"
 
@@ -1089,15 +1158,21 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         name: str,
         raise_exception: bool = True,
         artifact_uri: Optional[str] = None,
+        shared_proxy_mapping: Optional[dict] = None,
         **kwargs,
     ):
         super().__init__(name=name, raise_exception=raise_exception, **kwargs)
         if artifact_uri is not None and not isinstance(artifact_uri, str):
             raise MLRunInvalidArgumentError("'artifact_uri' argument must be a string")
         self.artifact_uri = artifact_uri
+        self.shared_proxy_mapping: dict[
+            str : Union[str, ModelArtifact, LLMPromptArtifact]
+        ] = shared_proxy_mapping
         self.invocation_artifact: Optional[LLMPromptArtifact] = None
         self.model_artifact: Optional[ModelArtifact] = None
         self.model_provider: Optional[ModelProvider] = None
+        self._artifact_were_loaded = False
+        self._execution_mechanism = None
 
     def __init_subclass__(cls):
         super().__init_subclass__()
@@ -1117,18 +1192,37 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
             raise_missing_schema_exception=False,
         )
 
-
-
-
-
-
+        # Check if the relevant predict method is implemented when trying to initialize the model
+        if self._execution_mechanism == storey.ParallelExecutionMechanisms.asyncio:
+            if self.__class__.predict_async is Model.predict_async:
+                raise mlrun.errors.ModelRunnerError(
+                    f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict_async() "
+                    f"is not implemented"
+                )
         else:
-            self.
+            if self.__class__.predict is Model.predict:
+                raise mlrun.errors.ModelRunnerError(
+                    f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict() "
+                    f"is not implemented"
+                )
 
-    def
-        if self.
-
-
+    def _load_artifacts(self) -> None:
+        if not self._artifact_were_loaded:
+            artifact = self._get_artifact_object()
+            if isinstance(artifact, LLMPromptArtifact):
+                self.invocation_artifact = artifact
+                self.model_artifact = self.invocation_artifact.model_artifact
+            else:
+                self.model_artifact = artifact
+            self._artifact_were_loaded = True
+
+    def _get_artifact_object(
+        self, proxy_uri: Optional[str] = None
+    ) -> Union[ModelArtifact, LLMPromptArtifact, None]:
+        uri = proxy_uri or self.artifact_uri
+        if uri:
+            if mlrun.datastore.is_store_uri(uri):
+                artifact, _ = mlrun.store_manager.get_store_artifact(uri)
                 return artifact
             else:
                 raise ValueError(
@@ -1140,18 +1234,20 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
     def init(self):
         self.load()
 
-    def predict(self, body: Any) -> Any:
+    def predict(self, body: Any, **kwargs) -> Any:
         """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
-
+        raise NotImplementedError("predict() method not implemented")
 
-    async def predict_async(self, body: Any) -> Any:
+    async def predict_async(self, body: Any, **kwargs) -> Any:
         """Override to implement prediction logic if the logic requires asyncio."""
-
+        raise NotImplementedError("predict_async() method not implemented")
 
-    def run(self, body: Any, path: str) -> Any:
+    def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
         return self.predict(body)
 
-    async def run_async(
+    async def run_async(
+        self, body: Any, path: str, origin_name: Optional[str] = None
+    ) -> Any:
         return await self.predict_async(body)
 
     def get_local_model_path(self, suffix="") -> (str, dict):
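Taken together, the changes above mean a custom model receives its execution mechanism before initialization and must implement the matching predict variant, or initialization fails with ModelRunnerError. A minimal sketch of two subclasses (class names and scoring logic are hypothetical; only the `mlrun.serving.Model` base and the method signatures come from the diff):

```python
import mlrun.serving


class MyScorer(mlrun.serving.Model):
    # Sync path: suits "process_pool", "thread_pool" and "naive" execution mechanisms.
    def predict(self, body, **kwargs):
        body["score"] = len(str(body))  # placeholder scoring logic
        return body


class MyAsyncScorer(mlrun.serving.Model):
    # Async path: required when the model runs with the "asyncio" mechanism.
    async def predict_async(self, body, **kwargs):
        body["score"] = len(str(body))
        return body
```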
@@ -1186,9 +1282,291 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         return None, None
 
 
-class
+class LLModel(Model):
+    """
+    A model wrapper for handling LLM (Large Language Model) prompt-based inference.
+
+    This class extends the base `Model` to provide specialized handling for
+    `LLMPromptArtifact` objects, enabling both synchronous and asynchronous
+    invocation of language models.
+
+    **Model Invocation**:
+
+    - The execution of enriched prompts is delegated to the `model_provider`
+      configured for the model (e.g., **Hugging Face** or **OpenAI**).
+    - The `model_provider` is responsible for sending the prompt to the correct
+      backend API and returning the generated output.
+    - Users can override the `predict` and `predict_async` methods to customize
+      the behavior of the model invocation.
+
+    **Prompt Enrichment Overview**:
+
+    - If an `LLMPromptArtifact` is found, load its prompt template and fill in
+      placeholders using values from the request body.
+    - If the artifact is not an `LLMPromptArtifact`, skip formatting and attempt
+      to retrieve `messages` directly from the request body using the input path.
+
+    **Simplified Example**:
+
+    Input body::
+
+        {"city": "Paris", "days": 3}
+
+    Prompt template in artifact::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a {{days}}-day itinerary for {{city}}."},
+        ]
+
+    Result after enrichment::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a 3-day itinerary for Paris."},
+        ]
+
+    :param name: Name of the model.
+    :param input_path: Path in the request body where input data is located.
+    :param result_path: Path in the response body where model outputs and the statistics
+        will be stored.
+    """
+
+    _dict_fields = Model._dict_fields + ["result_path", "input_path"]
+
+    def __init__(
+        self,
+        name: str,
+        input_path: Optional[Union[str, list[str]]] = None,
+        result_path: Optional[Union[str, list[str]]] = None,
+        **kwargs,
+    ):
+        super().__init__(name, **kwargs)
+        self._input_path = split_path(input_path)
+        self._result_path = split_path(result_path)
+        logger.info(
+            "LLModel initialized",
+            model_name=name,
+            input_path=input_path,
+            result_path=result_path,
+        )
+
+    def predict(
+        self,
+        body: Any,
+        messages: Optional[list[dict]] = None,
+        invocation_config: Optional[dict] = None,
+        **kwargs,
+    ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
+        if isinstance(
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
+        ) and isinstance(self.model_provider, ModelProvider):
+            logger.debug(
+                "Invoking model provider",
+                model_name=self.name,
+                messages=messages,
+                invocation_config=invocation_config,
+            )
+            response_with_stats = self.model_provider.invoke(
+                messages=messages,
+                invoke_response_format=InvokeResponseFormat.USAGE,
+                **(invocation_config or {}),
+            )
+            set_data_by_path(
+                path=self._result_path, data=body, value=response_with_stats
+            )
+            logger.debug(
+                "LLModel prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
+            )
+        return body
+
+    async def predict_async(
+        self,
+        body: Any,
+        messages: Optional[list[dict]] = None,
+        invocation_config: Optional[dict] = None,
+        **kwargs,
+    ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
+        if isinstance(
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
+        ) and isinstance(self.model_provider, ModelProvider):
+            logger.debug(
+                "Async invoking model provider",
+                model_name=self.name,
+                messages=messages,
+                invocation_config=invocation_config,
+            )
+            response_with_stats = await self.model_provider.async_invoke(
+                messages=messages,
+                invoke_response_format=InvokeResponseFormat.USAGE,
+                **(invocation_config or {}),
+            )
+            set_data_by_path(
+                path=self._result_path, data=body, value=response_with_stats
+            )
+            logger.debug(
+                "LLModel async prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping async prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
+            )
+        return body
+
+    def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, invocation_config = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
+        return self.predict(
+            body,
+            messages=messages,
+            invocation_config=invocation_config,
+            llm_prompt_artifact=llm_prompt_artifact,
+        )
+
+    async def run_async(
+        self, body: Any, path: str, origin_name: Optional[str] = None
+    ) -> Any:
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, invocation_config = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel async predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
+        return await self.predict_async(
+            body,
+            messages=messages,
+            invocation_config=invocation_config,
+            llm_prompt_artifact=llm_prompt_artifact,
+        )
+
+    def enrich_prompt(
+        self,
+        body: dict,
+        origin_name: str,
+        llm_prompt_artifact: Optional[LLMPromptArtifact] = None,
+    ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
+        logger.info(
+            "Enriching prompt",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+        )
+        if not llm_prompt_artifact or not (
+            llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
+        ):
+            logger.warning(
+                "LLModel must be provided with LLMPromptArtifact",
+                model_name=self.name,
+                artifact_type=type(llm_prompt_artifact).__name__,
+                llm_prompt_artifact=llm_prompt_artifact,
+            )
+            prompt_legend, prompt_template, invocation_config = {}, [], {}
+        else:
+            prompt_legend = llm_prompt_artifact.spec.prompt_legend
+            prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+            invocation_config = llm_prompt_artifact.spec.invocation_config
+        input_data = copy(get_data_from_path(self._input_path, body))
+        if isinstance(input_data, dict) and prompt_template:
+            kwargs = (
+                {
+                    place_holder: input_data.get(body_map["field"])
+                    for place_holder, body_map in prompt_legend.items()
+                    if input_data.get(body_map["field"])
+                }
+                if prompt_legend
+                else {}
+            )
+            input_data.update(kwargs)
+            default_place_holders = PlaceholderDefaultDict(lambda: None, input_data)
+            for message in prompt_template:
+                try:
+                    message["content"] = message["content"].format(**input_data)
+                except KeyError as e:
+                    logger.warning(
+                        "Input data missing placeholder, content stays unformatted",
+                        model_name=self.name,
+                        key_error=mlrun.errors.err_to_str(e),
+                    )
+                    message["content"] = message["content"].format_map(
+                        default_place_holders
+                    )
+        elif isinstance(input_data, dict) and not prompt_template:
+            # If there is no prompt template, we assume the input data is already in the correct format.
+            logger.debug("Attempting to retrieve messages from the request body.")
+            prompt_template = input_data.get("messages", [])
+        else:
+            logger.warning(
+                "Expected input data to be a dict, prompt template stays unformatted",
+                model_name=self.name,
+                input_data_type=type(input_data).__name__,
+            )
+        return prompt_template, invocation_config
+
+    def _get_invocation_artifact(
+        self, origin_name: Optional[str] = None
+    ) -> Union[LLMPromptArtifact, None]:
+        """
+        Get the LLMPromptArtifact object for this model.
+
+        :param proxy_uri: Optional; URI to the proxy artifact.
+        :return: LLMPromptArtifact object or None if not found.
+        """
+        if origin_name and self.shared_proxy_mapping:
+            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+            if isinstance(llm_prompt_artifact, str):
+                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+        elif self._artifact_were_loaded:
+            llm_prompt_artifact = self.invocation_artifact
+        else:
+            self._load_artifacts()
+            llm_prompt_artifact = self.invocation_artifact
+        return llm_prompt_artifact
+
+
+class ModelSelector(ModelObj):
     """Used to select which models to run on each event."""
 
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        cls._dict_fields = list(
+            set(cls._dict_fields)
+            | set(inspect.signature(cls.__init__).parameters.keys())
+        )
+        cls._dict_fields.remove("self")
+
     def select(
         self, event, available_models: list[Model]
     ) -> Union[list[str], list[Model]]:
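The enrichment step in `LLModel.enrich_prompt` boils down to filling template placeholders with values taken from the event body. A standalone, runnable sketch of that idea (plain `str.format`, no MLRun objects; the artifact loading and `prompt_legend` remapping from the diff are omitted):

```python
from copy import deepcopy

prompt_template = [
    {"role": "system", "content": "You are a travel planning assistant."},
    {"role": "user", "content": "Create a {days}-day itinerary for {city}."},
]
body = {"inputs": {"city": "Paris", "days": 3}}

# what get_data_from_path(input_path, body) would return for input_path="inputs"
input_data = body["inputs"]
messages = deepcopy(prompt_template)
for message in messages:
    try:
        message["content"] = message["content"].format(**input_data)
    except KeyError:
        pass  # the real code falls back to a default dict that leaves missing keys unformatted

print(messages[1]["content"])  # "Create a 3-day itinerary for Paris."
```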
@@ -1280,6 +1658,13 @@ class ModelRunnerStep(MonitoredStep):
         model_runner_step.add_model(..., model_class=MyModel(name="my_model"))
         graph.to(model_runner_step)
 
+    Note when ModelRunnerStep is used in a graph, MLRun automatically imports
+    the default language model class (LLModel) during function deployment.
+
+    Note ModelRunnerStep can only be added to a graph that has the flow topology and running with async engine.
+
+    Note see config_pool_resource method documentation for default number of max threads and max processes.
+
     :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
                            event. Optional. If not passed, all models will be run.
     :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised
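As the notes above state, a ModelRunnerStep only fits a flow-topology graph running on the async engine. A minimal sketch of wiring one up (project, function and artifact names are hypothetical; exact `set_topology` arguments may differ across MLRun versions):

```python
import mlrun
from mlrun.serving.states import ModelRunnerStep

project = mlrun.get_or_create_project("demo", "./")
fn = project.set_function(name="serving", kind="serving", image="mlrun/mlrun")

graph = fn.set_topology("flow", engine="async")
runner = ModelRunnerStep(name="runner")
runner.add_model(
    endpoint_name="travel-llm",
    model_class="LLModel",
    model_artifact="store://llm-prompts/demo/travel-prompt",  # assumed artifact URI
)
graph.to(runner).respond()
```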
@@ -1292,25 +1677,54 @@ class ModelRunnerStep(MonitoredStep):
     """
 
     kind = "model_runner"
+    _dict_fields = MonitoredStep._dict_fields + [
+        "_shared_proxy_mapping",
+        "max_processes",
+        "max_threads",
+        "pool_factor",
+    ]
 
     def __init__(
         self,
         *args,
         name: Optional[str] = None,
         model_selector: Optional[Union[str, ModelSelector]] = None,
+        model_selector_parameters: Optional[dict] = None,
         raise_exception: bool = True,
         **kwargs,
     ):
+        self.max_processes = None
+        self.max_threads = None
+        self.pool_factor = None
+
+        if isinstance(model_selector, ModelSelector) and model_selector_parameters:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Cannot provide a model_selector object as argument to `model_selector` and also provide "
+                "`model_selector_parameters`."
+            )
+        if model_selector:
+            model_selector_parameters = model_selector_parameters or (
+                model_selector.to_dict()
+                if isinstance(model_selector, ModelSelector)
+                else {}
+            )
+            model_selector = (
+                model_selector
+                if isinstance(model_selector, str)
+                else model_selector.__class__.__name__
+            )
+
         super().__init__(
             *args,
             name=name,
             raise_exception=raise_exception,
             class_name="mlrun.serving.ModelRunner",
-            class_args=dict(model_selector=model_selector),
+            class_args=dict(model_selector=(model_selector, model_selector_parameters)),
            **kwargs,
         )
         self.raise_exception = raise_exception
         self.shape = "folder"
+        self._shared_proxy_mapping = {}
 
     def add_shared_model_proxy(
         self,
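With the constructor change above, a selector can be passed either as an instance (its `to_dict()` becomes the stored parameters) or as a class-name string plus explicit `model_selector_parameters`, but never both. A hedged sketch, with hypothetical class and field names and assumed import paths:

```python
from mlrun.serving.states import ModelRunnerStep, ModelSelector


class ThresholdSelector(ModelSelector):
    def __init__(self, threshold: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        self.threshold = threshold

    def select(self, event, available_models):
        # run every model only when the (hypothetical) confidence field passes the threshold
        if event.body.get("confidence", 0) >= self.threshold:
            return available_models
        return []


# Either pass an instance ...
step = ModelRunnerStep(name="runner", model_selector=ThresholdSelector(threshold=0.8))
# ... or a class name plus parameters (but not an instance together with parameters).
step = ModelRunnerStep(
    name="runner",
    model_selector="ThresholdSelector",
    model_selector_parameters={"threshold": 0.8},
)
```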
@@ -1321,10 +1735,6 @@ class ModelRunnerStep(MonitoredStep):
         model_endpoint_creation_strategy: Optional[
             schemas.ModelEndpointCreationStrategy
         ] = schemas.ModelEndpointCreationStrategy.INPLACE,
-        inputs: Optional[list[str]] = None,
-        outputs: Optional[list[str]] = None,
-        input_path: Optional[str] = None,
-        result_path: Optional[str] = None,
         override: bool = False,
     ) -> None:
         """
@@ -1347,22 +1757,12 @@ class ModelRunnerStep(MonitoredStep):
         1. If model endpoints with the same name exist, preserve them.
         2. Create a new model endpoint with the same name and set it to `latest`.
 
-        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
-            that been configured in the model artifact, please note that those inputs need to
-            be equal in length and order to the inputs that model_class predict method expects
-        :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
-            that been configured in the model artifact, please note that those outputs need to
-            be equal to the model_class predict method outputs (length, and order)
-        :param input_path: input path inside the user event, expect scopes to be defined by dot notation
-            (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
-        :param result_path: result path inside the user output event, expect scopes to be defined by dot
-            notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
-            in path.
         :param override: bool allow override existing model on the current ModelRunnerStep.
+        :raise GraphError: when the shared model is not found in the root flow step shared models.
         """
-        model_class =
-
-            shared_runnable_name
+        model_class, model_params = (
+            "mlrun.serving.Model",
+            {"name": endpoint_name, "shared_runnable_name": shared_model_name},
         )
         if isinstance(model_artifact, str):
             model_artifact_uri = model_artifact
@@ -1375,11 +1775,21 @@ class ModelRunnerStep(MonitoredStep):
                 "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
             )
         root = self._extract_root_step()
+        shared_model_params = {}
         if isinstance(root, RootFlowStep):
-
-
-                or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
+            actual_shared_model_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )
+            if not actual_shared_model_name or (
+                shared_model_name and actual_shared_model_name != shared_model_name
+            ):
+                raise GraphError(
+                    f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
+                    f"model {shared_model_name} is not in the shared models."
+                )
+            elif not shared_model_name:
+                shared_model_name = actual_shared_model_name
+            model_params["shared_runnable_name"] = shared_model_name
         if not root.shared_models or (
             root.shared_models
             and shared_model_name
@@ -1389,6 +1799,34 @@ class ModelRunnerStep(MonitoredStep):
                 f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
                 f"model {shared_model_name} is not in the shared models."
             )
+        monitoring_data = self.class_args.get(
+            schemas.ModelRunnerStepData.MONITORING_DATA, {}
+        )
+        monitoring_data.setdefault(endpoint_name, {})[
+            schemas.MonitoringData.MODEL_CLASS
+        ] = (
+            shared_model_class
+            if isinstance(shared_model_class, str)
+            else shared_model_class.__class__.__name__
+        )
+        self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = (
+            monitoring_data
+        )
+
+        if shared_model_name and shared_model_name not in self._shared_proxy_mapping:
+            self._shared_proxy_mapping[shared_model_name] = {
+                endpoint_name: model_artifact.uri
+                if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                else model_artifact
+            }
+        elif override and shared_model_name:
+            self._shared_proxy_mapping[shared_model_name].update(
+                {
+                    endpoint_name: model_artifact.uri
+                    if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                    else model_artifact
+                }
+            )
         self.add_model(
             endpoint_name=endpoint_name,
             model_class=model_class,
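In practice the proxy call now only needs the endpoint name, the shared model's name and the artifact; inputs, outputs and paths are inherited from the shared model's parameters as the code above shows. A short sketch under the assumption that `runner` is a ModelRunnerStep and the graph already defines a shared model named "shared-llm" (see the add_shared_model sketch further below):

```python
runner.add_shared_model_proxy(
    endpoint_name="travel-endpoint",
    shared_model_name="shared-llm",
    model_artifact="store://llm-prompts/demo/travel-prompt",  # assumed artifact URI
)
```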
@@ -1396,11 +1834,12 @@ class ModelRunnerStep(MonitoredStep):
             model_artifact=model_artifact,
             labels=labels,
             model_endpoint_creation_strategy=model_endpoint_creation_strategy,
+            inputs=shared_model_params.get("inputs"),
+            outputs=shared_model_params.get("outputs"),
+            input_path=shared_model_params.get("input_path"),
+            result_path=shared_model_params.get("result_path"),
             override=override,
-
-            outputs=outputs,
-            input_path=input_path,
-            result_path=result_path,
+            **model_params,
         )
 
     def add_model(
@@ -1424,7 +1863,9 @@ class ModelRunnerStep(MonitoredStep):
         Add a Model to this ModelRunner.
 
         :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
-        :param model_class: Model class name
+        :param model_class: Model class name. If LLModel is chosen
+            (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+            outputs will be overridden with UsageResponseKeys fields.
         :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
             * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
               intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
@@ -1435,14 +1876,6 @@ class ModelRunnerStep(MonitoredStep):
               otherwise block the main event loop thread.
             * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
               event loop to continue running while waiting for a response.
-            * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
-              runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
-              useful when:
-              - You want to share a heavy resource like a large model loaded onto a GPU.
-              - You want to centralize task scheduling or coordination for multiple lightweight tasks.
-              - You aim to minimize overhead from creating new executors or processes/threads per runnable.
-              The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
-              memory and hardware accelerators.
             * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
               It means that the runnable will not actually be run in parallel to anything else.
 
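The dot-notation lookup documented in the input_path/result_path parameters below can be reproduced with a few lines of plain Python. This is only an illustration of the lookup semantics, not the actual MLRun helper:

```python
def get_by_dot_path(event: dict, path: str):
    """Walk a nested dict using a dot-separated path, e.g. "data.b"."""
    value = event
    for key in path.split("."):
        value = value[key]
    return value


event = {"data": {"a": [5, 9], "b": [7, 8]}}
print(get_by_dot_path(event, "data.b"))  # [7, 8], the monitored body from the docstring example
```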
@@ -1465,11 +1898,30 @@ class ModelRunnerStep(MonitoredStep):
         :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
             that been configured in the model artifact, please note that those outputs need to
             be equal to the model_class predict method outputs (length, and order)
-
-
-
-
-
+
+            When using LLModel, the output will be overridden with UsageResponseKeys.fields().
+
+        :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
+            this require that the event body will behave like a dict, expects scopes to be
+            defined by dot notation (e.g "data.d").
+            examples: input_path="data.b"
+            event: {"data":{"a": 5, "b": 7}}, means monitored body will be 7.
+            event: {"data":{"a": [5, 9], "b": [7, 8]}} means monitored body will be [7,8].
+            event: {"data":{"a": "extra_data", "b": {"f0": [1, 2]}}} means monitored body will
+            be {"f0": [1, 2]}.
+            if a ``list`` or ``list of lists`` is provided, it must follow the order and
+            size defined by the input schema.
+        :param result_path: when specified selects the key/path in the output event to use as model monitoring
+            outputs this require that the output event body will behave like a dict,
+            expects scopes to be defined by dot notation (e.g "data.d").
+            examples: result_path="out.b"
+            event: {"out":{"a": 5, "b": 7}}, means monitored body will be 7.
+            event: {"out":{"a": [5, 9], "b": [7, 8]}} means monitored body will be [7,8]
+            event: {"out":{"a": "extra_data", "b": {"f0": [1, 2]}}} means monitored body will
+            be {"f0": [1, 2]}
+            if a ``list`` or ``list of lists`` is provided, it must follow the order and
+            size defined by the output schema.
+
         :param override: bool allow override existing model on the current ModelRunnerStep.
         :param model_parameters: Parameters for model instantiation
         """
@@ -1477,7 +1929,15 @@ class ModelRunnerStep(MonitoredStep):
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
             )
-
+        if type(model_class) is LLModel or (
+            isinstance(model_class, str)
+            and model_class.split(".")[-1] == LLModel.__name__
+        ):
+            if outputs:
+                warnings.warn(
+                    "LLModel with existing outputs detected, overriding to default"
+                )
+            outputs = UsageResponseKeys.fields()
         model_parameters = model_parameters or (
             model_class.to_dict() if isinstance(model_class, Model) else {}
         )
@@ -1488,18 +1948,21 @@ class ModelRunnerStep(MonitoredStep):
         ):
             try:
                 model_artifact, _ = mlrun.store_manager.get_store_artifact(
-                    model_artifact
+                    mlrun.utils.remove_tag_from_artifact_uri(model_artifact)
                 )
             except mlrun.errors.MLRunNotFoundError:
                 raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.")
 
-        outputs = outputs or self._get_model_output_schema(model_artifact)
-
         model_artifact = (
             model_artifact.uri
             if isinstance(model_artifact, mlrun.artifacts.Artifact)
             else model_artifact
         )
+        model_artifact = (
+            mlrun.utils.remove_tag_from_artifact_uri(model_artifact)
+            if model_artifact
+            else None
+        )
         model_parameters["artifact_uri"] = model_parameters.get(
             "artifact_uri", model_artifact
         )
@@ -1515,6 +1978,11 @@ class ModelRunnerStep(MonitoredStep):
             raise mlrun.errors.MLRunInvalidArgumentError(
                 f"Model with name {endpoint_name} already exists in this ModelRunnerStep."
             )
+        root = self._extract_root_step()
+        if isinstance(root, RootFlowStep):
+            self.verify_model_runner_step(
+                self, [endpoint_name], verify_shared_models=False
+            )
         ParallelExecutionMechanisms.validate(execution_mechanism)
         self.class_args[schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM] = (
             self.class_args.get(
@@ -1550,28 +2018,13 @@ class ModelRunnerStep(MonitoredStep):
         self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data
 
     @staticmethod
-    def
-        model_artifact: Union[ModelArtifact, LLMPromptArtifact],
-    ) -> Optional[list[str]]:
-        if isinstance(
-            model_artifact,
-            ModelArtifact,
-        ):
-            return [feature.name for feature in model_artifact.spec.outputs]
-        elif isinstance(
-            model_artifact,
-            LLMPromptArtifact,
-        ):
-            _model_artifact = model_artifact.model_artifact
-            return [feature.name for feature in _model_artifact.spec.outputs]
-
-    @staticmethod
-    def _get_model_endpoint_output_schema(
+    def _get_model_endpoint_schema(
         name: str,
         project: str,
         uid: str,
-    ) -> list[str]:
+    ) -> tuple[list[str], list[str]]:
         output_schema = None
+        input_schema = None
         try:
             model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = (
                 mlrun.db.get_run_db().get_model_endpoint(
@@ -1582,23 +2035,16 @@ class ModelRunnerStep(MonitoredStep):
                 )
             )
             output_schema = model_endpoint.spec.label_names
+            input_schema = model_endpoint.spec.feature_names
         except (
             mlrun.errors.MLRunNotFoundError,
             mlrun.errors.MLRunInvalidArgumentError,
-        ):
+        ) as ex:
             logger.warning(
-                f"Model endpoint not found, using default output schema for model {name}"
+                f"Model endpoint not found, using default output schema for model {name}",
+                error=f"{type(ex).__name__}: {ex}",
             )
-        return output_schema
-
-    @staticmethod
-    def _split_path(path: str) -> Union[str, list[str], None]:
-        if path is not None:
-            parsed_path = path.split(".")
-            if len(parsed_path) == 1:
-                parsed_path = parsed_path[0]
-            return parsed_path
-        return path
+        return input_schema, output_schema
 
     def _calculate_monitoring_data(self) -> dict[str, dict[str, str]]:
         monitoring_data = deepcopy(
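The removed `_split_path` helper, now replaced by `mlrun.utils.split_path` imported at the top of the module, shows the intended path handling: a dotted string becomes a list of keys, a single segment stays a plain string. For reference, the removed logic was equivalent to the following runnable sketch (whether `mlrun.utils.split_path` matches it exactly is an assumption):

```python
from typing import Optional, Union


def split_path(path: Optional[str]) -> Union[str, list[str], None]:
    # mirrors the removed ModelRunnerStep._split_path shown above
    if path is not None:
        parsed_path = path.split(".")
        if len(parsed_path) == 1:
            return parsed_path[0]
        return parsed_path
    return path


assert split_path("data.b") == ["data", "b"]
assert split_path("data") == "data"
assert split_path(None) is None
```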
@@ -1608,59 +2054,117 @@ class ModelRunnerStep(MonitoredStep):
         )
         if isinstance(monitoring_data, dict):
             for model in monitoring_data:
-                monitoring_data[model][schemas.MonitoringData.
-                    monitoring_data
-                    or self._get_model_endpoint_output_schema(
-                        name=model,
-                        project=self.context.project if self.context else None,
-                        uid=monitoring_data.get(model, {}).get(
-                            mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID
-                        ),
-                    )
+                monitoring_data[model][schemas.MonitoringData.INPUT_PATH] = split_path(
+                    monitoring_data[model][schemas.MonitoringData.INPUT_PATH]
                 )
-
-
+                monitoring_data[model][schemas.MonitoringData.RESULT_PATH] = split_path(
+                    monitoring_data[model][schemas.MonitoringData.RESULT_PATH]
+                )
+
+                mep_output_schema, mep_input_schema = None, None
+
+                output_schema = self.class_args[
                     mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
-                ][model][schemas.MonitoringData.OUTPUTS]
-
-
-
-
-
+                ][model][schemas.MonitoringData.OUTPUTS]
+                input_schema = self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.INPUTS]
+                if not output_schema or not input_schema:
+                    # if output or input schema is not provided, try to get it from the model endpoint
+                    mep_input_schema, mep_output_schema = (
+                        self._get_model_endpoint_schema(
+                            model,
+                            self.context.project,
+                            monitoring_data[model].get(
+                                schemas.MonitoringData.MODEL_ENDPOINT_UID, ""
+                            ),
+                        )
                     )
+                self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.OUTPUTS] = (
+                    output_schema or mep_output_schema
                 )
-
-
-
-
+                self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.INPUTS] = (
+                    input_schema or mep_input_schema
                 )
             return monitoring_data
+        else:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Monitoring data must be a dictionary."
+            )
+
+    def configure_pool_resource(
+        self,
+        max_processes: Optional[int] = None,
+        max_threads: Optional[int] = None,
+        pool_factor: Optional[int] = None,
+    ) -> None:
+        """
+        Configure the resource limits for the shared models in the graph.
+
+        :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
+            Defaults to the number of CPUs or 16 if undetectable.
+        :param max_threads: Maximum number of threads to spawn. Defaults to 32.
+        :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
+        """
+        self.max_processes = max_processes
+        self.max_threads = max_threads
+        self.pool_factor = pool_factor
 
     def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
         self.context = context
         if not self._is_local_function(context):
             # skip init of non local functions
             return
-        model_selector = self.class_args.get(
+        model_selector, model_selector_params = self.class_args.get(
+            "model_selector", (None, None)
+        )
         execution_mechanism_by_model_name = self.class_args.get(
             schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
         )
         models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
-        if
-            model_selector = get_class(model_selector, namespace)(
+        if model_selector:
+            model_selector = get_class(model_selector, namespace).from_dict(
+                model_selector_params, init_with_params=True
+            )
         model_objects = []
         for model, model_params in models.values():
+            model_name = model_params.get("name")
+            model_params[schemas.MonitoringData.INPUT_PATH] = (
+                self.class_args.get(
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
+                )
+                .get(model_name, {})
+                .get(schemas.MonitoringData.INPUT_PATH)
+            )
+            model_params[schemas.MonitoringData.RESULT_PATH] = (
+                self.class_args.get(
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
+                )
+                .get(model_name, {})
+                .get(schemas.MonitoringData.RESULT_PATH)
+            )
             model = get_class(model, namespace).from_dict(
                 model_params, init_with_params=True
             )
             model._raise_exception = False
+            model._execution_mechanism = execution_mechanism_by_model_name.get(
+                model_name
+            )
             model_objects.append(model)
         self._async_object = ModelRunner(
             model_selector=model_selector,
             runnables=model_objects,
             execution_mechanism_by_runnable_name=execution_mechanism_by_model_name,
+            shared_proxy_mapping=self._shared_proxy_mapping or None,
             name=self.name,
             context=context,
+            max_processes=self.max_processes,
+            max_threads=self.max_threads,
+            pool_factor=self.pool_factor,
         )
 
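Using the new pool configuration is a single call on the step before deployment. The values below are illustrative only; the documented defaults are CPU-count-or-16 processes, 32 threads and a pool factor of 1:

```python
from mlrun.serving.states import ModelRunnerStep

runner = ModelRunnerStep(name="runner")
# cap the worker pools shared by this step's models
runner.configure_pool_resource(max_processes=4, max_threads=16, pool_factor=2)
```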
@@ -2298,7 +2802,13 @@ class FlowStep(BaseStep):
         if not step.before and not any(
             [step.name in other_step.after for other_step in self._steps.values()]
         ):
-
+            if any(
+                [
+                    getattr(step_in_graph, "responder", False)
+                    for step_in_graph in self._steps.values()
+                ]
+            ):
+                step.responder = True
             return
 
         for step_name in step.before:
@@ -2381,14 +2891,20 @@ class RootFlowStep(FlowStep):
         name: str,
         model_class: Union[str, Model],
         execution_mechanism: Union[str, ParallelExecutionMechanisms],
-        model_artifact:
+        model_artifact: Union[str, ModelArtifact],
+        inputs: Optional[list[str]] = None,
+        outputs: Optional[list[str]] = None,
+        input_path: Optional[str] = None,
+        result_path: Optional[str] = None,
         override: bool = False,
         **model_parameters,
     ) -> None:
         """
         Add a shared model to the graph, this model will be available to all the ModelRunners in the graph
         :param name: Name of the shared model (should be unique in the graph)
-        :param model_class: Model class name
+        :param model_class: Model class name. If LLModel is chosen
+            (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+            outputs will be overridden with UsageResponseKeys fields.
         :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
             * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
               intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
@@ -2411,6 +2927,19 @@ class RootFlowStep(FlowStep):
               It means that the runnable will not actually be run in parallel to anything else.

         :param model_artifact: model artifact or mlrun model artifact uri
+        :param inputs: list of the model inputs (e.g. features); if provided, it overrides the inputs
+            that were configured in the model artifact. Please note that these inputs need
+            to be equal in length and order to the inputs that the model_class
+            predict method expects
+        :param outputs: list of the model outputs (e.g. labels); if provided, it overrides the outputs
+            that were configured in the model artifact. Please note that these outputs need
+            to be equal to the model_class
+            predict method outputs (length and order)
+        :param input_path: input path inside the user event; scopes are expected to be defined by dot notation
+            (e.g. "inputs.my_model_inputs"). Expects a list or dictionary type object in the path.
+        :param result_path: result path inside the user output event; scopes are expected to be defined by dot
+            notation (e.g. "outputs.my_model_outputs"). Expects a list or dictionary type object
+            in the path.
         :param override: bool allow override existing model on the current ModelRunnerStep.
         :param model_parameters: Parameters for model instantiation
         """
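A hedged usage sketch of the extended `add_shared_model` signature documented above; the serving function, source file, model class name, artifact URI, and feature/label names are placeholders rather than values from the package:

```python
import mlrun

# Hypothetical setup: a serving function whose flow graph exposes add_shared_model.
serving_fn = mlrun.code_to_function("my-serving", kind="serving", filename="model.py")
graph = serving_fn.set_topology("flow")

graph.add_shared_model(
    name="churn-model",
    model_class="MyModelClass",
    execution_mechanism="process_pool",
    model_artifact="store://models/my-project/churn",
    inputs=["age", "tenure"],                # overrides the artifact's configured inputs
    outputs=["churn_probability"],           # overrides the artifact's configured outputs
    input_path="inputs.my_model_inputs",     # dot-notation scope inside the incoming event
    result_path="outputs.my_model_outputs",  # dot-notation scope inside the output event
)
```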
@@ -2418,6 +2947,15 @@ class RootFlowStep(FlowStep):
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
             )
+        if type(model_class) is LLModel or (
+            isinstance(model_class, str)
+            and model_class.split(".")[-1] == LLModel.__name__
+        ):
+            if outputs:
+                warnings.warn(
+                    "LLModel with existing outputs detected, overriding to default"
+                )
+            outputs = UsageResponseKeys.fields()

         if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
             raise mlrun.errors.MLRunInvalidArgumentError(
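The new guard above matches `model_class` against `LLModel` whether it is passed as an instance or as a string (bare name or dotted path), then forces the outputs to `UsageResponseKeys.fields()`. A simplified, standalone sketch of the matching logic, with a stand-in `LLModel` class:

```python
# Illustrative only: detect LLModel whether it is given as an instance or as a
# string (bare class name or dotted path); LLModel here is a stand-in class.
class LLModel:
    pass


def is_llmodel(model_class) -> bool:
    if isinstance(model_class, str):
        return model_class.split(".")[-1] == LLModel.__name__
    return type(model_class) is LLModel


print(is_llmodel("mlrun.serving.states.LLModel"))  # True
print(is_llmodel("LLModel"))                       # True
print(is_llmodel("some.module.OtherModel"))        # False
```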
@@ -2433,6 +2971,7 @@ class RootFlowStep(FlowStep):
             if isinstance(model_artifact, mlrun.artifacts.Artifact)
             else model_artifact
         )
+        model_artifact = mlrun.utils.remove_tag_from_artifact_uri(model_artifact)
         model_parameters["artifact_uri"] = model_parameters.get(
             "artifact_uri", model_artifact
         )
@@ -2444,6 +2983,14 @@ class RootFlowStep(FlowStep):
                 "Inconsistent name for the added model."
             )
         model_parameters["name"] = name
+        model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
+        model_parameters["outputs"] = outputs or model_parameters.get("outputs", [])
+        model_parameters["input_path"] = input_path or model_parameters.get(
+            "input_path"
+        )
+        model_parameters["result_path"] = result_path or model_parameters.get(
+            "result_path"
+        )

         if name in self.shared_models and not override:
             raise mlrun.errors.MLRunInvalidArgumentError(
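A small sketch of the precedence the added lines establish: explicitly passed `inputs`/`outputs`/`input_path`/`result_path` win over values already present in `model_parameters`, which are otherwise preserved (helper name and sample values are illustrative):

```python
# Illustrative only: explicit arguments take precedence over values already in
# model_parameters; otherwise existing values (or the defaults) are kept.
def merge_io(
    model_parameters: dict,
    inputs=None,
    outputs=None,
    input_path=None,
    result_path=None,
) -> dict:
    model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
    model_parameters["outputs"] = outputs or model_parameters.get("outputs", [])
    model_parameters["input_path"] = input_path or model_parameters.get("input_path")
    model_parameters["result_path"] = result_path or model_parameters.get("result_path")
    return model_parameters


print(merge_io({"inputs": ["age"]}, outputs=["churn_probability"]))
# {'inputs': ['age'], 'outputs': ['churn_probability'], 'input_path': None, 'result_path': None}
```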
@@ -2458,7 +3005,9 @@ class RootFlowStep(FlowStep):
         self.shared_models[name] = (model_class, model_parameters)
         self.shared_models_mechanism[name] = execution_mechanism

-    def
+    def get_shared_model_by_artifact_uri(
+        self, artifact_uri: str
+    ) -> Union[tuple[str, str, dict], tuple[None, None, None]]:
         """
         Get a shared model by its artifact URI.
         :param artifact_uri: The artifact URI of the model.
@@ -2466,10 +3015,10 @@ class RootFlowStep(FlowStep):
         """
         for model_name, (model_class, model_params) in self.shared_models.items():
             if model_params.get("artifact_uri") == artifact_uri:
-                return model_name
-        return None
+                return model_name, model_class, model_params
+        return None, None, None

-    def
+    def configure_shared_pool_resource(
         self,
         max_processes: Optional[int] = None,
         max_threads: Optional[int] = None,
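Since `get_shared_model_by_artifact_uri` now returns a `(name, class, params)` triple, or `(None, None, None)` when nothing matches, callers unpack three values instead of one. A hypothetical caller, reusing the placeholder graph object and URI from the earlier sketch:

```python
# Hypothetical caller: the lookup now returns a (name, class, params) triple,
# or (None, None, None) when no shared model matches the given artifact URI.
name, model_class, model_params = graph.get_shared_model_by_artifact_uri(
    "store://models/my-project/churn"  # placeholder URI
)
if name is None:
    print("no shared model registered for this artifact")
else:
    print(name, model_class, model_params.get("artifact_uri"))
```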
@@ -2494,12 +3043,30 @@ class RootFlowStep(FlowStep):
             max_threads=self.shared_max_threads,
             pool_factor=self.pool_factor,
         )
-
+        monitored_steps = self.get_monitored_steps().values()
+        for monitored_step in monitored_steps:
+            if isinstance(monitored_step, ModelRunnerStep):
+                for model, model_params in self.shared_models.values():
+                    if "shared_proxy_mapping" in model_params:
+                        model_params["shared_proxy_mapping"].update(
+                            deepcopy(
+                                monitored_step._shared_proxy_mapping.get(
+                                    model_params.get("name"), {}
+                                )
+                            )
+                        )
+                    else:
+                        model_params["shared_proxy_mapping"] = deepcopy(
+                            monitored_step._shared_proxy_mapping.get(
+                                model_params.get("name"), {}
+                            )
+                        )
         for model, model_params in self.shared_models.values():
             model = get_class(model, namespace).from_dict(
                 model_params, init_with_params=True
             )
             model._raise_exception = False
+            model._execution_mechanism = self._shared_models_mechanism[model.name]
             self.context.executor.add_runnable(
                 model, self._shared_models_mechanism[model.name]
             )
@@ -2619,12 +3186,10 @@ def _add_graphviz_router(graph, step, source=None, **kwargs):
         graph.edge(step.fullname, route.fullname)


-def _add_graphviz_model_runner(graph, step, source=None):
+def _add_graphviz_model_runner(graph, step, source=None, is_monitored=False):
     if source:
         graph.node("_start", source.name, shape=source.shape, style="filled")
         graph.edge("_start", step.fullname)
-
-    is_monitored = step._extract_root_step().track_models
     m_cell = '<FONT POINT-SIZE="9">🄼</FONT>' if is_monitored else ""

     number_of_models = len(
@@ -2663,6 +3228,7 @@ def _add_graphviz_flow(
         allow_empty=True
     )
     graph.node("_start", source.name, shape=source.shape, style="filled")
+    is_monitored = step.track_models if isinstance(step, RootFlowStep) else False
     for start_step in start_steps:
         graph.edge("_start", start_step.fullname)
     for child in step.get_children():
@@ -2671,7 +3237,7 @@ def _add_graphviz_flow(
             with graph.subgraph(name="cluster_" + child.fullname) as sg:
                 _add_graphviz_router(sg, child)
         elif kind == StepKinds.model_runner:
-            _add_graphviz_model_runner(graph, child)
+            _add_graphviz_model_runner(graph, child, is_monitored=is_monitored)
         else:
             graph.node(child.fullname, label=child.name, shape=child.get_shape())
         _add_edges(child.after or [], step, graph, child)
@@ -2803,7 +3369,7 @@ def params_to_step(
         step = QueueStep(name, **class_args)

     elif class_name and hasattr(class_name, "to_dict"):
-        struct = class_name.to_dict()
+        struct = deepcopy(class_name.to_dict())
         kind = struct.get("kind", StepKinds.task)
         name = (
             name
@@ -2890,7 +3456,7 @@ def _init_async_objects(context, steps):
             datastore_profile = datastore_profile_read(stream_path)
             if isinstance(
                 datastore_profile,
-                (DatastoreProfileKafkaTarget,
+                (DatastoreProfileKafkaTarget, DatastoreProfileKafkaStream),
             ):
                 step._async_object = KafkaStoreyTarget(
                     path=stream_path,
@@ -2906,7 +3472,7 @@ def _init_async_objects(context, steps):
             else:
                 raise mlrun.errors.MLRunValueError(
                     f"Received an unexpected stream profile type: {type(datastore_profile)}\n"
-                    "Expects `DatastoreProfileV3io` or `
+                    "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaStream`."
                 )
         elif stream_path.startswith("kafka://") or kafka_brokers:
             topic, brokers = parse_kafka_url(stream_path, kafka_brokers)
@@ -2922,6 +3488,8 @@ def _init_async_objects(context, steps):
                 context=context,
                 **options,
             )
+        elif stream_path.startswith("dummy://"):
+            step._async_object = _DummyStream(context=context, **options)
         else:
             if stream_path.startswith("v3io://"):
                 endpoint, stream_path = parse_path(step.path)