mlrun 1.10.0rc13__py3-none-any.whl → 1.10.0rc42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of mlrun might be problematic.

Files changed (107)
  1. mlrun/__init__.py +22 -2
  2. mlrun/artifacts/base.py +0 -31
  3. mlrun/artifacts/document.py +6 -1
  4. mlrun/artifacts/llm_prompt.py +123 -25
  5. mlrun/artifacts/manager.py +0 -5
  6. mlrun/artifacts/model.py +3 -3
  7. mlrun/common/constants.py +10 -1
  8. mlrun/common/formatters/artifact.py +1 -0
  9. mlrun/common/model_monitoring/helpers.py +86 -0
  10. mlrun/common/schemas/__init__.py +3 -0
  11. mlrun/common/schemas/auth.py +2 -0
  12. mlrun/common/schemas/function.py +10 -0
  13. mlrun/common/schemas/hub.py +30 -18
  14. mlrun/common/schemas/model_monitoring/__init__.py +3 -0
  15. mlrun/common/schemas/model_monitoring/constants.py +30 -6
  16. mlrun/common/schemas/model_monitoring/functions.py +14 -5
  17. mlrun/common/schemas/model_monitoring/model_endpoints.py +21 -0
  18. mlrun/common/schemas/pipeline.py +1 -1
  19. mlrun/common/schemas/serving.py +3 -0
  20. mlrun/common/schemas/workflow.py +3 -1
  21. mlrun/common/secrets.py +22 -1
  22. mlrun/config.py +33 -11
  23. mlrun/datastore/__init__.py +11 -3
  24. mlrun/datastore/azure_blob.py +162 -47
  25. mlrun/datastore/datastore.py +9 -4
  26. mlrun/datastore/datastore_profile.py +61 -5
  27. mlrun/datastore/model_provider/huggingface_provider.py +363 -0
  28. mlrun/datastore/model_provider/mock_model_provider.py +87 -0
  29. mlrun/datastore/model_provider/model_provider.py +230 -65
  30. mlrun/datastore/model_provider/openai_provider.py +295 -42
  31. mlrun/datastore/s3.py +24 -2
  32. mlrun/datastore/storeytargets.py +2 -3
  33. mlrun/datastore/utils.py +15 -3
  34. mlrun/db/base.py +47 -19
  35. mlrun/db/httpdb.py +120 -56
  36. mlrun/db/nopdb.py +38 -10
  37. mlrun/execution.py +70 -19
  38. mlrun/hub/__init__.py +15 -0
  39. mlrun/hub/module.py +181 -0
  40. mlrun/k8s_utils.py +105 -16
  41. mlrun/launcher/base.py +13 -6
  42. mlrun/launcher/local.py +15 -0
  43. mlrun/model.py +24 -3
  44. mlrun/model_monitoring/__init__.py +1 -0
  45. mlrun/model_monitoring/api.py +66 -27
  46. mlrun/model_monitoring/applications/__init__.py +1 -1
  47. mlrun/model_monitoring/applications/base.py +509 -117
  48. mlrun/model_monitoring/applications/context.py +2 -4
  49. mlrun/model_monitoring/applications/results.py +4 -7
  50. mlrun/model_monitoring/controller.py +239 -101
  51. mlrun/model_monitoring/db/_schedules.py +116 -33
  52. mlrun/model_monitoring/db/_stats.py +4 -3
  53. mlrun/model_monitoring/db/tsdb/base.py +100 -9
  54. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +11 -6
  55. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +191 -50
  56. mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +51 -0
  57. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +17 -4
  58. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +259 -40
  59. mlrun/model_monitoring/helpers.py +54 -9
  60. mlrun/model_monitoring/stream_processing.py +45 -14
  61. mlrun/model_monitoring/writer.py +220 -1
  62. mlrun/platforms/__init__.py +3 -2
  63. mlrun/platforms/iguazio.py +7 -3
  64. mlrun/projects/operations.py +6 -1
  65. mlrun/projects/pipelines.py +46 -26
  66. mlrun/projects/project.py +166 -58
  67. mlrun/run.py +94 -17
  68. mlrun/runtimes/__init__.py +18 -0
  69. mlrun/runtimes/base.py +14 -6
  70. mlrun/runtimes/daskjob.py +7 -0
  71. mlrun/runtimes/local.py +5 -2
  72. mlrun/runtimes/mounts.py +20 -2
  73. mlrun/runtimes/mpijob/abstract.py +6 -0
  74. mlrun/runtimes/mpijob/v1.py +6 -0
  75. mlrun/runtimes/nuclio/__init__.py +1 -0
  76. mlrun/runtimes/nuclio/application/application.py +149 -17
  77. mlrun/runtimes/nuclio/function.py +76 -27
  78. mlrun/runtimes/nuclio/serving.py +97 -15
  79. mlrun/runtimes/pod.py +234 -21
  80. mlrun/runtimes/remotesparkjob.py +6 -0
  81. mlrun/runtimes/sparkjob/spark3job.py +6 -0
  82. mlrun/runtimes/utils.py +49 -11
  83. mlrun/secrets.py +54 -13
  84. mlrun/serving/__init__.py +2 -0
  85. mlrun/serving/remote.py +79 -6
  86. mlrun/serving/routers.py +23 -41
  87. mlrun/serving/server.py +320 -80
  88. mlrun/serving/states.py +725 -157
  89. mlrun/serving/steps.py +62 -0
  90. mlrun/serving/system_steps.py +200 -119
  91. mlrun/serving/v2_serving.py +9 -10
  92. mlrun/utils/helpers.py +288 -88
  93. mlrun/utils/logger.py +3 -1
  94. mlrun/utils/notifications/notification/base.py +18 -0
  95. mlrun/utils/notifications/notification/git.py +2 -4
  96. mlrun/utils/notifications/notification/slack.py +2 -4
  97. mlrun/utils/notifications/notification/webhook.py +2 -5
  98. mlrun/utils/notifications/notification_pusher.py +1 -1
  99. mlrun/utils/retryer.py +15 -2
  100. mlrun/utils/version/version.json +2 -2
  101. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/METADATA +45 -51
  102. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/RECORD +106 -101
  103. mlrun/api/schemas/__init__.py +0 -259
  104. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/WHEEL +0 -0
  105. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/entry_points.txt +0 -0
  106. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/licenses/LICENSE +0 -0
  107. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py CHANGED
@@ -24,6 +24,7 @@ import inspect
  import os
  import pathlib
  import traceback
+ import warnings
  from abc import ABC
  from copy import copy, deepcopy
  from inspect import getfullargspec, signature
@@ -35,20 +36,24 @@ from storey import ParallelExecutionMechanisms
  import mlrun
  import mlrun.artifacts
  import mlrun.common.schemas as schemas
- from mlrun.artifacts.llm_prompt import LLMPromptArtifact
+ from mlrun.artifacts.llm_prompt import LLMPromptArtifact, PlaceholderDefaultDict
  from mlrun.artifacts.model import ModelArtifact
  from mlrun.datastore.datastore_profile import (
-     DatastoreProfileKafkaSource,
+     DatastoreProfileKafkaStream,
      DatastoreProfileKafkaTarget,
      DatastoreProfileV3io,
      datastore_profile_read,
  )
- from mlrun.datastore.model_provider.model_provider import ModelProvider
+ from mlrun.datastore.model_provider.model_provider import (
+     InvokeResponseFormat,
+     ModelProvider,
+     UsageResponseKeys,
+ )
  from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
- from mlrun.utils import logger
+ from mlrun.utils import get_data_from_path, logger, set_data_by_path, split_path

  from ..config import config
- from ..datastore import get_stream_pusher
+ from ..datastore import _DummyStream, get_stream_pusher
  from ..datastore.utils import (
      get_kafka_brokers_from_dict,
      parse_kafka_url,
@@ -501,10 +506,15 @@ class BaseStep(ModelObj):
      def verify_model_runner_step(
          self,
          step: "ModelRunnerStep",
+         step_model_endpoints_names: Optional[list[str]] = None,
+         verify_shared_models: bool = True,
      ):
          """
          Verify ModelRunnerStep, can be part of Flow graph and models can not repeat in graph.
-         :param step: ModelRunnerStep to verify
+         :param step: ModelRunnerStep to verify
+         :param step_model_endpoints_names: List of model endpoints names that are in the step.
+             if provided will ignore step models and verify only the models on list.
+         :param verify_shared_models: If True, verify that shared models are defined in the graph.
          """

          if not isinstance(step, ModelRunnerStep):
@@ -512,11 +522,13 @@ class BaseStep(ModelObj):

          root = self._extract_root_step()

-         if not isinstance(root, RootFlowStep):
+         if not isinstance(root, RootFlowStep) or (
+             isinstance(root, RootFlowStep) and root.engine != "async"
+         ):
              raise GraphError(
                  "ModelRunnerStep can be added to 'Flow' topology graph only"
              )
-         step_model_endpoints_names = list(
+         step_model_endpoints_names = step_model_endpoints_names or list(
              step.class_args.get(schemas.ModelRunnerStepData.MODELS, {}).keys()
          )
          # Get all model_endpoints names that are in both lists
@@ -530,13 +542,14 @@ class BaseStep(ModelObj):
                  f"The graph already contains the model endpoints named - {common_endpoints_names}."
              )

-         # Check if shared models are defined in the graph
-         self._verify_shared_models(root, step, step_model_endpoints_names)
+         if verify_shared_models:
+             # Check if shared models are defined in the graph
+             self._verify_shared_models(root, step, step_model_endpoints_names)
          # Update model endpoints names in the root step
          root.update_model_endpoints_names(step_model_endpoints_names)

-     @staticmethod
      def _verify_shared_models(
+         self,
          root: "RootFlowStep",
          step: "ModelRunnerStep",
          step_model_endpoints_names: list[str],
@@ -565,33 +578,41 @@ class BaseStep(ModelObj):
              prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
              # if the model artifact is a prompt, we need to get the model URI
              # to ensure that the shared runnable name is correct
+             llm_artifact_uri = None
              if prefix == mlrun.utils.StorePrefix.LLMPrompt:
                  llm_artifact, _ = mlrun.store_manager.get_store_artifact(
                      model_artifact_uri
                  )
-                 model_artifact_uri = llm_artifact.spec.parent_uri
-             actual_shared_name = root.get_shared_model_name_by_artifact_uri(
-                 model_artifact_uri
+                 llm_artifact_uri = llm_artifact.uri
+                 model_artifact_uri = mlrun.utils.remove_tag_from_artifact_uri(
+                     llm_artifact.spec.parent_uri
+                 )
+             actual_shared_name, shared_model_class, shared_model_params = (
+                 root.get_shared_model_by_artifact_uri(model_artifact_uri)
              )

-             if not shared_runnable_name:
-                 if not actual_shared_name:
-                     raise GraphError(
-                         f"Can't find shared model for {name} model endpoint"
-                     )
-                 else:
-                     step.class_args[schemas.ModelRunnerStepData.MODELS][name][
-                         schemas.ModelsData.MODEL_PARAMETERS.value
-                     ]["shared_runnable_name"] = actual_shared_name
-                 shared_models.append(actual_shared_name)
+             if not actual_shared_name:
+                 raise GraphError(
+                     f"Can't find shared model named {shared_runnable_name}"
+                 )
+             elif not shared_runnable_name:
+                 step.class_args[schemas.ModelRunnerStepData.MODELS][name][
+                     schemas.ModelsData.MODEL_PARAMETERS.value
+                 ]["shared_runnable_name"] = actual_shared_name
              elif actual_shared_name != shared_runnable_name:
                  raise GraphError(
                      f"Model endpoint {name} shared runnable name mismatch: "
                      f"expected {actual_shared_name}, got {shared_runnable_name}"
                  )
-             else:
-                 shared_models.append(actual_shared_name)
-
+             shared_models.append(actual_shared_name)
+             self._edit_proxy_model_data(
+                 step,
+                 name,
+                 actual_shared_name,
+                 shared_model_params,
+                 shared_model_class,
+                 llm_artifact_uri or model_artifact_uri,
+             )
          undefined_shared_models = list(
              set(shared_models) - set(root.shared_models.keys())
          )
@@ -600,6 +621,52 @@ class BaseStep(ModelObj):
                  f"The following shared models are not defined in the graph: {undefined_shared_models}."
              )

+     @staticmethod
+     def _edit_proxy_model_data(
+         step: "ModelRunnerStep",
+         name: str,
+         actual_shared_name: str,
+         shared_model_params: dict,
+         shared_model_class: Any,
+         artifact: Union[ModelArtifact, LLMPromptArtifact, str],
+     ):
+         monitoring_data = step.class_args.setdefault(
+             schemas.ModelRunnerStepData.MONITORING_DATA, {}
+         )
+
+         # edit monitoring data according to the shared model parameters
+         monitoring_data[name][schemas.MonitoringData.INPUT_PATH] = shared_model_params[
+             "input_path"
+         ]
+         monitoring_data[name][schemas.MonitoringData.RESULT_PATH] = shared_model_params[
+             "result_path"
+         ]
+         monitoring_data[name][schemas.MonitoringData.INPUTS] = shared_model_params[
+             "inputs"
+         ]
+         monitoring_data[name][schemas.MonitoringData.OUTPUTS] = shared_model_params[
+             "outputs"
+         ]
+         monitoring_data[name][schemas.MonitoringData.MODEL_CLASS] = (
+             shared_model_class
+             if isinstance(shared_model_class, str)
+             else shared_model_class.__class__.__name__
+         )
+         if actual_shared_name and actual_shared_name not in step._shared_proxy_mapping:
+             step._shared_proxy_mapping[actual_shared_name] = {
+                 name: artifact.uri
+                 if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
+                 else artifact
+             }
+         elif actual_shared_name:
+             step._shared_proxy_mapping[actual_shared_name].update(
+                 {
+                     name: artifact.uri
+                     if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
+                     else artifact
+                 }
+             )
+

  class TaskStep(BaseStep):
      """task execution step, runs a class or handler"""
@@ -1081,6 +1148,8 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
          "raise_exception",
          "artifact_uri",
          "shared_runnable_name",
+         "shared_proxy_mapping",
+         "execution_mechanism",
      ]
      kind = "model"

@@ -1089,15 +1158,21 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
          name: str,
          raise_exception: bool = True,
          artifact_uri: Optional[str] = None,
+         shared_proxy_mapping: Optional[dict] = None,
          **kwargs,
      ):
          super().__init__(name=name, raise_exception=raise_exception, **kwargs)
          if artifact_uri is not None and not isinstance(artifact_uri, str):
              raise MLRunInvalidArgumentError("'artifact_uri' argument must be a string")
          self.artifact_uri = artifact_uri
+         self.shared_proxy_mapping: dict[
+             str : Union[str, ModelArtifact, LLMPromptArtifact]
+         ] = shared_proxy_mapping
          self.invocation_artifact: Optional[LLMPromptArtifact] = None
          self.model_artifact: Optional[ModelArtifact] = None
          self.model_provider: Optional[ModelProvider] = None
+         self._artifact_were_loaded = False
+         self._execution_mechanism = None

      def __init_subclass__(cls):
          super().__init_subclass__()
@@ -1117,18 +1192,37 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
              raise_missing_schema_exception=False,
          )

-     def _load_artifacts(self) -> None:
-         artifact = self._get_artifact_object()
-         if isinstance(artifact, LLMPromptArtifact):
-             self.invocation_artifact = artifact
-             self.model_artifact = self.invocation_artifact.model_artifact
+         # Check if the relevant predict method is implemented when trying to initialize the model
+         if self._execution_mechanism == storey.ParallelExecutionMechanisms.asyncio:
+             if self.__class__.predict_async is Model.predict_async:
+                 raise mlrun.errors.ModelRunnerError(
+                     f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict_async() "
+                     f"is not implemented"
+                 )
          else:
-             self.model_artifact = artifact
+             if self.__class__.predict is Model.predict:
+                 raise mlrun.errors.ModelRunnerError(
+                     f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict() "
+                     f"is not implemented"
+                 )

-     def _get_artifact_object(self) -> Union[ModelArtifact, LLMPromptArtifact, None]:
-         if self.artifact_uri:
-             if mlrun.datastore.is_store_uri(self.artifact_uri):
-                 artifact, _ = mlrun.store_manager.get_store_artifact(self.artifact_uri)
+     def _load_artifacts(self) -> None:
+         if not self._artifact_were_loaded:
+             artifact = self._get_artifact_object()
+             if isinstance(artifact, LLMPromptArtifact):
+                 self.invocation_artifact = artifact
+                 self.model_artifact = self.invocation_artifact.model_artifact
+             else:
+                 self.model_artifact = artifact
+             self._artifact_were_loaded = True
+
+     def _get_artifact_object(
+         self, proxy_uri: Optional[str] = None
+     ) -> Union[ModelArtifact, LLMPromptArtifact, None]:
+         uri = proxy_uri or self.artifact_uri
+         if uri:
+             if mlrun.datastore.is_store_uri(uri):
+                 artifact, _ = mlrun.store_manager.get_store_artifact(uri)
                  return artifact
              else:
                  raise ValueError(
@@ -1140,18 +1234,20 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
      def init(self):
          self.load()

-     def predict(self, body: Any) -> Any:
+     def predict(self, body: Any, **kwargs) -> Any:
          """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
-         return body
+         raise NotImplementedError("predict() method not implemented")

-     async def predict_async(self, body: Any) -> Any:
+     async def predict_async(self, body: Any, **kwargs) -> Any:
          """Override to implement prediction logic if the logic requires asyncio."""
-         return body
+         raise NotImplementedError("predict_async() method not implemented")

-     def run(self, body: Any, path: str) -> Any:
+     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
          return self.predict(body)

-     async def run_async(self, body: Any, path: str) -> Any:
+     async def run_async(
+         self, body: Any, path: str, origin_name: Optional[str] = None
+     ) -> Any:
          return await self.predict_async(body)

      def get_local_model_path(self, suffix="") -> (str, dict):
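
Note: with this change, the base `predict()`/`predict_async()` raise NotImplementedError instead of echoing the body,
so a custom model must implement the variant matching its execution mechanism. A minimal sketch (class and field
names are illustrative, not from the package)::

    import mlrun.serving

    class MyModel(mlrun.serving.Model):
        def load(self):
            # load weights from self.get_local_model_path(), a database, etc.
            self.bias = 1.0

        def predict(self, body, **kwargs):
            # body is the (possibly path-scoped) event payload
            body["outputs"] = [x + self.bias for x in body["inputs"]]
            return body
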
@@ -1186,9 +1282,291 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
          return None, None


- class ModelSelector:
+ class LLModel(Model):
+     """
+     A model wrapper for handling LLM (Large Language Model) prompt-based inference.
+
+     This class extends the base `Model` to provide specialized handling for
+     `LLMPromptArtifact` objects, enabling both synchronous and asynchronous
+     invocation of language models.
+
+     **Model Invocation**:
+
+     - The execution of enriched prompts is delegated to the `model_provider`
+       configured for the model (e.g., **Hugging Face** or **OpenAI**).
+     - The `model_provider` is responsible for sending the prompt to the correct
+       backend API and returning the generated output.
+     - Users can override the `predict` and `predict_async` methods to customize
+       the behavior of the model invocation.
+
+     **Prompt Enrichment Overview**:
+
+     - If an `LLMPromptArtifact` is found, load its prompt template and fill in
+       placeholders using values from the request body.
+     - If the artifact is not an `LLMPromptArtifact`, skip formatting and attempt
+       to retrieve `messages` directly from the request body using the input path.
+
+     **Simplified Example**:
+
+     Input body::
+
+         {"city": "Paris", "days": 3}
+
+     Prompt template in artifact::
+
+         [
+             {"role": "system", "content": "You are a travel planning assistant."},
+             {"role": "user", "content": "Create a {{days}}-day itinerary for {{city}}."},
+         ]
+
+     Result after enrichment::
+
+         [
+             {"role": "system", "content": "You are a travel planning assistant."},
+             {"role": "user", "content": "Create a 3-day itinerary for Paris."},
+         ]
+
+     :param name:        Name of the model.
+     :param input_path:  Path in the request body where input data is located.
+     :param result_path: Path in the response body where model outputs and the statistics
+                         will be stored.
+     """
+
+     _dict_fields = Model._dict_fields + ["result_path", "input_path"]
+
+     def __init__(
+         self,
+         name: str,
+         input_path: Optional[Union[str, list[str]]] = None,
+         result_path: Optional[Union[str, list[str]]] = None,
+         **kwargs,
+     ):
+         super().__init__(name, **kwargs)
+         self._input_path = split_path(input_path)
+         self._result_path = split_path(result_path)
+         logger.info(
+             "LLModel initialized",
+             model_name=name,
+             input_path=input_path,
+             result_path=result_path,
+         )
+
+     def predict(
+         self,
+         body: Any,
+         messages: Optional[list[dict]] = None,
+         invocation_config: Optional[dict] = None,
+         **kwargs,
+     ) -> Any:
+         llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
+         if isinstance(
+             llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
+         ) and isinstance(self.model_provider, ModelProvider):
+             logger.debug(
+                 "Invoking model provider",
+                 model_name=self.name,
+                 messages=messages,
+                 invocation_config=invocation_config,
+             )
+             response_with_stats = self.model_provider.invoke(
+                 messages=messages,
+                 invoke_response_format=InvokeResponseFormat.USAGE,
+                 **(invocation_config or {}),
+             )
+             set_data_by_path(
+                 path=self._result_path, data=body, value=response_with_stats
+             )
+             logger.debug(
+                 "LLModel prediction completed",
+                 model_name=self.name,
+                 answer=response_with_stats.get("answer"),
+                 usage=response_with_stats.get("usage"),
+             )
+         else:
+             logger.warning(
+                 "LLModel invocation artifact or model provider not set, skipping prediction",
+                 model_name=self.name,
+                 invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                 model_provider_type=type(self.model_provider).__name__,
+             )
+         return body
+
+     async def predict_async(
+         self,
+         body: Any,
+         messages: Optional[list[dict]] = None,
+         invocation_config: Optional[dict] = None,
+         **kwargs,
+     ) -> Any:
+         llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
+         if isinstance(
+             llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
+         ) and isinstance(self.model_provider, ModelProvider):
+             logger.debug(
+                 "Async invoking model provider",
+                 model_name=self.name,
+                 messages=messages,
+                 invocation_config=invocation_config,
+             )
+             response_with_stats = await self.model_provider.async_invoke(
+                 messages=messages,
+                 invoke_response_format=InvokeResponseFormat.USAGE,
+                 **(invocation_config or {}),
+             )
+             set_data_by_path(
+                 path=self._result_path, data=body, value=response_with_stats
+             )
+             logger.debug(
+                 "LLModel async prediction completed",
+                 model_name=self.name,
+                 answer=response_with_stats.get("answer"),
+                 usage=response_with_stats.get("usage"),
+             )
+         else:
+             logger.warning(
+                 "LLModel invocation artifact or model provider not set, skipping async prediction",
+                 model_name=self.name,
+                 invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                 model_provider_type=type(self.model_provider).__name__,
+             )
+         return body
+
+     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
+         llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+         messages, invocation_config = self.enrich_prompt(
+             body, origin_name, llm_prompt_artifact
+         )
+         logger.info(
+             "Calling LLModel predict",
+             model_name=self.name,
+             model_endpoint_name=origin_name,
+             messages_len=len(messages) if messages else 0,
+         )
+         return self.predict(
+             body,
+             messages=messages,
+             invocation_config=invocation_config,
+             llm_prompt_artifact=llm_prompt_artifact,
+         )
+
+     async def run_async(
+         self, body: Any, path: str, origin_name: Optional[str] = None
+     ) -> Any:
+         llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+         messages, invocation_config = self.enrich_prompt(
+             body, origin_name, llm_prompt_artifact
+         )
+         logger.info(
+             "Calling LLModel async predict",
+             model_name=self.name,
+             model_endpoint_name=origin_name,
+             messages_len=len(messages) if messages else 0,
+         )
+         return await self.predict_async(
+             body,
+             messages=messages,
+             invocation_config=invocation_config,
+             llm_prompt_artifact=llm_prompt_artifact,
+         )
+
+     def enrich_prompt(
+         self,
+         body: dict,
+         origin_name: str,
+         llm_prompt_artifact: Optional[LLMPromptArtifact] = None,
+     ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
+         logger.info(
+             "Enriching prompt",
+             model_name=self.name,
+             model_endpoint_name=origin_name,
+         )
+         if not llm_prompt_artifact or not (
+             llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
+         ):
+             logger.warning(
+                 "LLModel must be provided with LLMPromptArtifact",
+                 model_name=self.name,
+                 artifact_type=type(llm_prompt_artifact).__name__,
+                 llm_prompt_artifact=llm_prompt_artifact,
+             )
+             prompt_legend, prompt_template, invocation_config = {}, [], {}
+         else:
+             prompt_legend = llm_prompt_artifact.spec.prompt_legend
+             prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+             invocation_config = llm_prompt_artifact.spec.invocation_config
+         input_data = copy(get_data_from_path(self._input_path, body))
+         if isinstance(input_data, dict) and prompt_template:
+             kwargs = (
+                 {
+                     place_holder: input_data.get(body_map["field"])
+                     for place_holder, body_map in prompt_legend.items()
+                     if input_data.get(body_map["field"])
+                 }
+                 if prompt_legend
+                 else {}
+             )
+             input_data.update(kwargs)
+             default_place_holders = PlaceholderDefaultDict(lambda: None, input_data)
+             for message in prompt_template:
+                 try:
+                     message["content"] = message["content"].format(**input_data)
+                 except KeyError as e:
+                     logger.warning(
+                         "Input data missing placeholder, content stays unformatted",
+                         model_name=self.name,
+                         key_error=mlrun.errors.err_to_str(e),
+                     )
+                     message["content"] = message["content"].format_map(
+                         default_place_holders
+                     )
+         elif isinstance(input_data, dict) and not prompt_template:
+             # If there is no prompt template, we assume the input data is already in the correct format.
+             logger.debug("Attempting to retrieve messages from the request body.")
+             prompt_template = input_data.get("messages", [])
+         else:
+             logger.warning(
+                 "Expected input data to be a dict, prompt template stays unformatted",
+                 model_name=self.name,
+                 input_data_type=type(input_data).__name__,
+             )
+         return prompt_template, invocation_config
+
+     def _get_invocation_artifact(
+         self, origin_name: Optional[str] = None
+     ) -> Union[LLMPromptArtifact, None]:
+         """
+         Get the LLMPromptArtifact object for this model.
+
+         :param proxy_uri: Optional; URI to the proxy artifact.
+         :return: LLMPromptArtifact object or None if not found.
+         """
+         if origin_name and self.shared_proxy_mapping:
+             llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+             if isinstance(llm_prompt_artifact, str):
+                 llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                 self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+         elif self._artifact_were_loaded:
+             llm_prompt_artifact = self.invocation_artifact
+         else:
+             self._load_artifacts()
+             llm_prompt_artifact = self.invocation_artifact
+         return llm_prompt_artifact
+
+
+ class ModelSelector(ModelObj):
      """Used to select which models to run on each event."""

+     def __init__(self, **kwargs):
+         super().__init__()
+
+     def __init_subclass__(cls):
+         super().__init_subclass__()
+         cls._dict_fields = list(
+             set(cls._dict_fields)
+             | set(inspect.signature(cls.__init__).parameters.keys())
+         )
+         cls._dict_fields.remove("self")
+
      def select(
          self, event, available_models: list[Model]
      ) -> Union[list[str], list[Model]]:
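
Note: `enrich_prompt` above fills `{placeholder}` fields in the prompt template via `str.format`/`format_map`, falling
back to a default-dict so that missing keys do not raise. A standalone approximation of that behaviour (this is not
the package's `PlaceholderDefaultDict` implementation)::

    template = [
        {"role": "system", "content": "You are a travel planning assistant."},
        {"role": "user", "content": "Create a {days}-day itinerary for {city}."},
    ]
    body = {"city": "Paris"}  # "days" is intentionally missing

    class _KeepMissing(dict):
        def __missing__(self, key):
            return "{" + key + "}"  # leave unknown placeholders unformatted

    messages = [
        {**m, "content": m["content"].format_map(_KeepMissing(body))} for m in template
    ]
    # -> "Create a {days}-day itinerary for Paris."
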
@@ -1280,6 +1658,13 @@ class ModelRunnerStep(MonitoredStep):
          model_runner_step.add_model(..., model_class=MyModel(name="my_model"))
          graph.to(model_runner_step)

+     Note when ModelRunnerStep is used in a graph, MLRun automatically imports
+     the default language model class (LLModel) during function deployment.
+
+     Note ModelRunnerStep can only be added to a graph that has the flow topology and running with async engine.
+
+     Note see config_pool_resource method documentation for default number of max threads and max processes.
+
      :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
          event. Optional. If not passed, all models will be run.
      :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised
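
Note: per the added docstring, ModelRunnerStep is only accepted on a "flow" topology running the async engine. A
hedged sketch of the expected wiring (function, class, and artifact names are placeholders)::

    import mlrun
    from mlrun.serving.states import ModelRunnerStep

    fn = mlrun.code_to_function("my-serving", kind="serving", filename="model.py")
    graph = fn.set_topology("flow", engine="async")
    runner = ModelRunnerStep(name="runner")
    runner.add_model(
        endpoint_name="my-endpoint",
        model_class="MyModel",
        execution_mechanism="asyncio",
        model_artifact="store://models/my-project/my-model",
    )
    graph.to(runner).respond()
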
@@ -1292,25 +1677,54 @@ class ModelRunnerStep(MonitoredStep):
      """

      kind = "model_runner"
+     _dict_fields = MonitoredStep._dict_fields + [
+         "_shared_proxy_mapping",
+         "max_processes",
+         "max_threads",
+         "pool_factor",
+     ]

      def __init__(
          self,
          *args,
          name: Optional[str] = None,
          model_selector: Optional[Union[str, ModelSelector]] = None,
+         model_selector_parameters: Optional[dict] = None,
          raise_exception: bool = True,
          **kwargs,
      ):
+         self.max_processes = None
+         self.max_threads = None
+         self.pool_factor = None
+
+         if isinstance(model_selector, ModelSelector) and model_selector_parameters:
+             raise mlrun.errors.MLRunInvalidArgumentError(
+                 "Cannot provide a model_selector object as argument to `model_selector` and also provide "
+                 "`model_selector_parameters`."
+             )
+         if model_selector:
+             model_selector_parameters = model_selector_parameters or (
+                 model_selector.to_dict()
+                 if isinstance(model_selector, ModelSelector)
+                 else {}
+             )
+             model_selector = (
+                 model_selector
+                 if isinstance(model_selector, str)
+                 else model_selector.__class__.__name__
+             )
+
          super().__init__(
              *args,
              name=name,
              raise_exception=raise_exception,
              class_name="mlrun.serving.ModelRunner",
-             class_args=dict(model_selector=model_selector),
+             class_args=dict(model_selector=(model_selector, model_selector_parameters)),
              **kwargs,
          )
          self.raise_exception = raise_exception
          self.shape = "folder"
+         self._shared_proxy_mapping = {}

      def add_shared_model_proxy(
          self,
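
Note: `model_selector` is now stored in `class_args` as a (class name, parameters) pair, so a selector can be passed
either as an instance or as a class-name string plus `model_selector_parameters`. A sketch (the selector class is
illustrative)::

    from mlrun.serving.states import ModelRunnerStep, ModelSelector

    class ABSelector(ModelSelector):
        def __init__(self, ratio: float = 0.5, **kwargs):
            super().__init__(**kwargs)
            self.ratio = ratio

        def select(self, event, available_models):
            # route to the first model; a real selector would use self.ratio
            return available_models[:1]

    # equivalent alternatives:
    step = ModelRunnerStep(name="runner", model_selector=ABSelector(ratio=0.2))
    step = ModelRunnerStep(
        name="runner",
        model_selector="ABSelector",
        model_selector_parameters={"ratio": 0.2},
    )
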
@@ -1321,10 +1735,6 @@ class ModelRunnerStep(MonitoredStep):
          model_endpoint_creation_strategy: Optional[
              schemas.ModelEndpointCreationStrategy
          ] = schemas.ModelEndpointCreationStrategy.INPLACE,
-         inputs: Optional[list[str]] = None,
-         outputs: Optional[list[str]] = None,
-         input_path: Optional[str] = None,
-         result_path: Optional[str] = None,
          override: bool = False,
      ) -> None:
          """
@@ -1347,22 +1757,12 @@ class ModelRunnerStep(MonitoredStep):
              1. If model endpoints with the same name exist, preserve them.
              2. Create a new model endpoint with the same name and set it to `latest`.

-         :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
-                        that been configured in the model artifact, please note that those inputs need to
-                        be equal in length and order to the inputs that model_class predict method expects
-         :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
-                         that been configured in the model artifact, please note that those outputs need to
-                         be equal to the model_class predict method outputs (length, and order)
-         :param input_path: input path inside the user event, expect scopes to be defined by dot notation
-                            (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
-         :param result_path: result path inside the user output event, expect scopes to be defined by dot
-                             notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
-                             in path.
          :param override: bool allow override existing model on the current ModelRunnerStep.
+         :raise GraphError: when the shared model is not found in the root flow step shared models.
          """
-         model_class = Model(
-             name=endpoint_name,
-             shared_runnable_name=shared_model_name,
+         model_class, model_params = (
+             "mlrun.serving.Model",
+             {"name": endpoint_name, "shared_runnable_name": shared_model_name},
          )
          if isinstance(model_artifact, str):
              model_artifact_uri = model_artifact
@@ -1375,11 +1775,21 @@ class ModelRunnerStep(MonitoredStep):
                  "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
              )
          root = self._extract_root_step()
+         shared_model_params = {}
          if isinstance(root, RootFlowStep):
-             shared_model_name = (
-                 shared_model_name
-                 or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
+             actual_shared_model_name, shared_model_class, shared_model_params = (
+                 root.get_shared_model_by_artifact_uri(model_artifact_uri)
              )
+             if not actual_shared_model_name or (
+                 shared_model_name and actual_shared_model_name != shared_model_name
+             ):
+                 raise GraphError(
+                     f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
+                     f"model {shared_model_name} is not in the shared models."
+                 )
+             elif not shared_model_name:
+                 shared_model_name = actual_shared_model_name
+             model_params["shared_runnable_name"] = shared_model_name
              if not root.shared_models or (
                  root.shared_models
                  and shared_model_name
@@ -1389,6 +1799,34 @@ class ModelRunnerStep(MonitoredStep):
                      f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
                      f"model {shared_model_name} is not in the shared models."
                  )
+             monitoring_data = self.class_args.get(
+                 schemas.ModelRunnerStepData.MONITORING_DATA, {}
+             )
+             monitoring_data.setdefault(endpoint_name, {})[
+                 schemas.MonitoringData.MODEL_CLASS
+             ] = (
+                 shared_model_class
+                 if isinstance(shared_model_class, str)
+                 else shared_model_class.__class__.__name__
+             )
+             self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = (
+                 monitoring_data
+             )
+
+             if shared_model_name and shared_model_name not in self._shared_proxy_mapping:
+                 self._shared_proxy_mapping[shared_model_name] = {
+                     endpoint_name: model_artifact.uri
+                     if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                     else model_artifact
+                 }
+             elif override and shared_model_name:
+                 self._shared_proxy_mapping[shared_model_name].update(
+                     {
+                         endpoint_name: model_artifact.uri
+                         if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                         else model_artifact
+                     }
+                 )
          self.add_model(
              endpoint_name=endpoint_name,
              model_class=model_class,
@@ -1396,11 +1834,12 @@ class ModelRunnerStep(MonitoredStep):
              model_artifact=model_artifact,
              labels=labels,
              model_endpoint_creation_strategy=model_endpoint_creation_strategy,
+             inputs=shared_model_params.get("inputs"),
+             outputs=shared_model_params.get("outputs"),
+             input_path=shared_model_params.get("input_path"),
+             result_path=shared_model_params.get("result_path"),
              override=override,
-             inputs=inputs,
-             outputs=outputs,
-             input_path=input_path,
-             result_path=result_path,
+             **model_params,
          )

      def add_model(
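
Note: proxy endpoints now inherit inputs/outputs/input_path/result_path from the shared model's parameters, and the
proxy-to-artifact mapping is tracked in `_shared_proxy_mapping`. A sketch of the intended flow (URIs and names are
placeholders; the prompt artifact's parent model is expected to match the shared model's artifact)::

    graph = fn.set_topology("flow", engine="async")
    graph.add_shared_model(
        name="shared-llm",
        model_class="LLModel",
        execution_mechanism="asyncio",
        model_artifact="store://models/my-project/base-llm",
        input_path="request",
        result_path="response",
    )
    runner = ModelRunnerStep(name="runner")
    graph.to(runner)
    runner.add_shared_model_proxy(
        endpoint_name="prompt-a",
        model_artifact="store://llm_prompts/my-project/prompt-a",
        shared_model_name="shared-llm",
    )
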
@@ -1424,7 +1863,9 @@ class ModelRunnerStep(MonitoredStep):
          Add a Model to this ModelRunner.

          :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
-         :param model_class: Model class name
+         :param model_class: Model class name. If LLModel is chosen
+             (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+             outputs will be overridden with UsageResponseKeys fields.
          :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
              * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
                intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
@@ -1435,14 +1876,6 @@ class ModelRunnerStep(MonitoredStep):
                otherwise block the main event loop thread.
              * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
                event loop to continue running while waiting for a response.
-             * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
-               runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
-               useful when:
-               - You want to share a heavy resource like a large model loaded onto a GPU.
-               - You want to centralize task scheduling or coordination for multiple lightweight tasks.
-               - You aim to minimize overhead from creating new executors or processes/threads per runnable.
-               The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
-               memory and hardware accelerators.
              * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
                It means that the runnable will not actually be run in parallel to anything else.

@@ -1465,11 +1898,30 @@ class ModelRunnerStep(MonitoredStep):
          :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
                          that been configured in the model artifact, please note that those outputs need to
                          be equal to the model_class predict method outputs (length, and order)
-         :param input_path: input path inside the user event, expect scopes to be defined by dot notation
-                            (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
-         :param result_path: result path inside the user output event, expect scopes to be defined by dot
-                             notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
-                             in path.
+
+                         When using LLModel, the output will be overridden with UsageResponseKeys.fields().
+
+         :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
+                            this require that the event body will behave like a dict, expects scopes to be
+                            defined by dot notation (e.g "data.d").
+                            examples: input_path="data.b"
+                            event: {"data":{"a": 5, "b": 7}}, means monitored body will be 7.
+                            event: {"data":{"a": [5, 9], "b": [7, 8]}} means monitored body will be [7,8].
+                            event: {"data":{"a": "extra_data", "b": {"f0": [1, 2]}}} means monitored body will
+                            be {"f0": [1, 2]}.
+                            if a ``list`` or ``list of lists`` is provided, it must follow the order and
+                            size defined by the input schema.
+         :param result_path: when specified selects the key/path in the output event to use as model monitoring
+                             outputs this require that the output event body will behave like a dict,
+                             expects scopes to be defined by dot notation (e.g "data.d").
+                             examples: result_path="out.b"
+                             event: {"out":{"a": 5, "b": 7}}, means monitored body will be 7.
+                             event: {"out":{"a": [5, 9], "b": [7, 8]}} means monitored body will be [7,8]
+                             event: {"out":{"a": "extra_data", "b": {"f0": [1, 2]}}} means monitored body will
+                             be {"f0": [1, 2]}
+                             if a ``list`` or ``list of lists`` is provided, it must follow the order and
+                             size defined by the output schema.
+
          :param override: bool allow override existing model on the current ModelRunnerStep.
          :param model_parameters: Parameters for model instantiation
          """
@@ -1477,7 +1929,15 @@ class ModelRunnerStep(MonitoredStep):
              raise mlrun.errors.MLRunInvalidArgumentError(
                  "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
              )
-
+         if type(model_class) is LLModel or (
+             isinstance(model_class, str)
+             and model_class.split(".")[-1] == LLModel.__name__
+         ):
+             if outputs:
+                 warnings.warn(
+                     "LLModel with existing outputs detected, overriding to default"
+                 )
+             outputs = UsageResponseKeys.fields()
          model_parameters = model_parameters or (
              model_class.to_dict() if isinstance(model_class, Model) else {}
          )
@@ -1488,18 +1948,21 @@ class ModelRunnerStep(MonitoredStep):
          ):
              try:
                  model_artifact, _ = mlrun.store_manager.get_store_artifact(
-                     model_artifact
+                     mlrun.utils.remove_tag_from_artifact_uri(model_artifact)
                  )
              except mlrun.errors.MLRunNotFoundError:
                  raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.")

-         outputs = outputs or self._get_model_output_schema(model_artifact)
-
          model_artifact = (
              model_artifact.uri
              if isinstance(model_artifact, mlrun.artifacts.Artifact)
              else model_artifact
          )
+         model_artifact = (
+             mlrun.utils.remove_tag_from_artifact_uri(model_artifact)
+             if model_artifact
+             else None
+         )
          model_parameters["artifact_uri"] = model_parameters.get(
              "artifact_uri", model_artifact
          )
@@ -1515,6 +1978,11 @@ class ModelRunnerStep(MonitoredStep):
              raise mlrun.errors.MLRunInvalidArgumentError(
                  f"Model with name {endpoint_name} already exists in this ModelRunnerStep."
              )
+         root = self._extract_root_step()
+         if isinstance(root, RootFlowStep):
+             self.verify_model_runner_step(
+                 self, [endpoint_name], verify_shared_models=False
+             )
          ParallelExecutionMechanisms.validate(execution_mechanism)
          self.class_args[schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM] = (
              self.class_args.get(
@@ -1550,28 +2018,13 @@ class ModelRunnerStep(MonitoredStep):
          self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data

      @staticmethod
-     def _get_model_output_schema(
-         model_artifact: Union[ModelArtifact, LLMPromptArtifact],
-     ) -> Optional[list[str]]:
-         if isinstance(
-             model_artifact,
-             ModelArtifact,
-         ):
-             return [feature.name for feature in model_artifact.spec.outputs]
-         elif isinstance(
-             model_artifact,
-             LLMPromptArtifact,
-         ):
-             _model_artifact = model_artifact.model_artifact
-             return [feature.name for feature in _model_artifact.spec.outputs]
-
-     @staticmethod
-     def _get_model_endpoint_output_schema(
+     def _get_model_endpoint_schema(
          name: str,
          project: str,
         uid: str,
-     ) -> list[str]:
+     ) -> tuple[list[str], list[str]]:
          output_schema = None
+         input_schema = None
          try:
              model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = (
                  mlrun.db.get_run_db().get_model_endpoint(
@@ -1582,23 +2035,16 @@ class ModelRunnerStep(MonitoredStep):
                  )
              )
              output_schema = model_endpoint.spec.label_names
+             input_schema = model_endpoint.spec.feature_names
          except (
              mlrun.errors.MLRunNotFoundError,
              mlrun.errors.MLRunInvalidArgumentError,
-         ):
+         ) as ex:
              logger.warning(
-                 f"Model endpoint not found, using default output schema for model {name}"
+                 f"Model endpoint not found, using default output schema for model {name}",
+                 error=f"{type(ex).__name__}: {ex}",
              )
-         return output_schema
-
-     @staticmethod
-     def _split_path(path: str) -> Union[str, list[str], None]:
-         if path is not None:
-             parsed_path = path.split(".")
-             if len(parsed_path) == 1:
-                 parsed_path = parsed_path[0]
-             return parsed_path
-         return path
+         return input_schema, output_schema

      def _calculate_monitoring_data(self) -> dict[str, dict[str, str]]:
          monitoring_data = deepcopy(
@@ -1608,59 +2054,117 @@ class ModelRunnerStep(MonitoredStep):
          )
          if isinstance(monitoring_data, dict):
              for model in monitoring_data:
-                 monitoring_data[model][schemas.MonitoringData.OUTPUTS] = (
-                     monitoring_data.get(model, {}).get(schemas.MonitoringData.OUTPUTS)
-                     or self._get_model_endpoint_output_schema(
-                         name=model,
-                         project=self.context.project if self.context else None,
-                         uid=monitoring_data.get(model, {}).get(
-                             mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID
-                         ),
-                     )
+                 monitoring_data[model][schemas.MonitoringData.INPUT_PATH] = split_path(
+                     monitoring_data[model][schemas.MonitoringData.INPUT_PATH]
                  )
-                 # Prevent calling _get_model_output_schema for same model more than once
-                 self.class_args[
+                 monitoring_data[model][schemas.MonitoringData.RESULT_PATH] = split_path(
+                     monitoring_data[model][schemas.MonitoringData.RESULT_PATH]
+                 )
+
+                 mep_output_schema, mep_input_schema = None, None
+
+                 output_schema = self.class_args[
                      mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
-                 ][model][schemas.MonitoringData.OUTPUTS] = monitoring_data[model][
-                     schemas.MonitoringData.OUTPUTS
-                 ]
-                 monitoring_data[model][schemas.MonitoringData.INPUT_PATH] = (
-                     self._split_path(
-                         monitoring_data[model][schemas.MonitoringData.INPUT_PATH]
+                 ][model][schemas.MonitoringData.OUTPUTS]
+                 input_schema = self.class_args[
+                     mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                 ][model][schemas.MonitoringData.INPUTS]
+                 if not output_schema or not input_schema:
+                     # if output or input schema is not provided, try to get it from the model endpoint
+                     mep_input_schema, mep_output_schema = (
+                         self._get_model_endpoint_schema(
+                             model,
+                             self.context.project,
+                             monitoring_data[model].get(
+                                 schemas.MonitoringData.MODEL_ENDPOINT_UID, ""
+                             ),
+                         )
                      )
+                 self.class_args[
+                     mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                 ][model][schemas.MonitoringData.OUTPUTS] = (
+                     output_schema or mep_output_schema
                  )
-                 monitoring_data[model][schemas.MonitoringData.RESULT_PATH] = (
-                     self._split_path(
-                         monitoring_data[model][schemas.MonitoringData.RESULT_PATH]
-                     )
+                 self.class_args[
+                     mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                 ][model][schemas.MonitoringData.INPUTS] = (
+                     input_schema or mep_input_schema
                  )
              return monitoring_data
+         else:
+             raise mlrun.errors.MLRunInvalidArgumentError(
+                 "Monitoring data must be a dictionary."
+             )
+
+     def configure_pool_resource(
+         self,
+         max_processes: Optional[int] = None,
+         max_threads: Optional[int] = None,
+         pool_factor: Optional[int] = None,
+     ) -> None:
+         """
+         Configure the resource limits for the shared models in the graph.
+
+         :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
+                               Defaults to the number of CPUs or 16 if undetectable.
+         :param max_threads:   Maximum number of threads to spawn. Defaults to 32.
+         :param pool_factor:   Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
+         """
+         self.max_processes = max_processes
+         self.max_threads = max_threads
+         self.pool_factor = pool_factor

      def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
          self.context = context
          if not self._is_local_function(context):
              # skip init of non local functions
              return
-         model_selector = self.class_args.get("model_selector")
+         model_selector, model_selector_params = self.class_args.get(
+             "model_selector", (None, None)
+         )
          execution_mechanism_by_model_name = self.class_args.get(
              schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
          )
          models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
-         if isinstance(model_selector, str):
-             model_selector = get_class(model_selector, namespace)()
+         if model_selector:
+             model_selector = get_class(model_selector, namespace).from_dict(
+                 model_selector_params, init_with_params=True
+             )
          model_objects = []
          for model, model_params in models.values():
+             model_name = model_params.get("name")
+             model_params[schemas.MonitoringData.INPUT_PATH] = (
+                 self.class_args.get(
+                     mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
+                 )
+                 .get(model_name, {})
+                 .get(schemas.MonitoringData.INPUT_PATH)
+             )
+             model_params[schemas.MonitoringData.RESULT_PATH] = (
+                 self.class_args.get(
+                     mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
+                 )
+                 .get(model_name, {})
+                 .get(schemas.MonitoringData.RESULT_PATH)
+             )
              model = get_class(model, namespace).from_dict(
                  model_params, init_with_params=True
              )
              model._raise_exception = False
+             model._execution_mechanism = execution_mechanism_by_model_name.get(
+                 model_name
+             )
              model_objects.append(model)
          self._async_object = ModelRunner(
              model_selector=model_selector,
              runnables=model_objects,
              execution_mechanism_by_runnable_name=execution_mechanism_by_model_name,
+             shared_proxy_mapping=self._shared_proxy_mapping or None,
              name=self.name,
              context=context,
+             max_processes=self.max_processes,
+             max_threads=self.max_threads,
+             pool_factor=self.pool_factor,
          )

@@ -2298,7 +2802,13 @@ class FlowStep(BaseStep):
          if not step.before and not any(
              [step.name in other_step.after for other_step in self._steps.values()]
          ):
-             step.responder = True
+             if any(
+                 [
+                     getattr(step_in_graph, "responder", False)
+                     for step_in_graph in self._steps.values()
+                 ]
+             ):
+                 step.responder = True
              return

          for step_name in step.before:
@@ -2381,14 +2891,20 @@ class RootFlowStep(FlowStep):
          name: str,
          model_class: Union[str, Model],
          execution_mechanism: Union[str, ParallelExecutionMechanisms],
-         model_artifact: Optional[Union[str, ModelArtifact]],
+         model_artifact: Union[str, ModelArtifact],
+         inputs: Optional[list[str]] = None,
+         outputs: Optional[list[str]] = None,
+         input_path: Optional[str] = None,
+         result_path: Optional[str] = None,
          override: bool = False,
          **model_parameters,
      ) -> None:
          """
          Add a shared model to the graph, this model will be available to all the ModelRunners in the graph
          :param name: Name of the shared model (should be unique in the graph)
-         :param model_class: Model class name
+         :param model_class: Model class name. If LLModel is chosen
+             (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+             outputs will be overridden with UsageResponseKeys fields.
          :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
              * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
                intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
@@ -2411,6 +2927,19 @@ class RootFlowStep(FlowStep):
                It means that the runnable will not actually be run in parallel to anything else.

          :param model_artifact: model artifact or mlrun model artifact uri
+         :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
+                        that been configured in the model artifact, please note that those inputs need
+                        to be equal in length and order to the inputs that model_class
+                        predict method expects
+         :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
+                         that been configured in the model artifact, please note that those outputs need
+                         to be equal to the model_class
+                         predict method outputs (length, and order)
+         :param input_path: input path inside the user event, expect scopes to be defined by dot notation
+                            (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
+         :param result_path: result path inside the user output event, expect scopes to be defined by dot
+                             notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
+                             in path.
          :param override: bool allow override existing model on the current ModelRunnerStep.
          :param model_parameters: Parameters for model instantiation
          """
@@ -2418,6 +2947,15 @@ class RootFlowStep(FlowStep):
              raise mlrun.errors.MLRunInvalidArgumentError(
                  "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
              )
+         if type(model_class) is LLModel or (
+             isinstance(model_class, str)
+             and model_class.split(".")[-1] == LLModel.__name__
+         ):
+             if outputs:
+                 warnings.warn(
+                     "LLModel with existing outputs detected, overriding to default"
+                 )
+             outputs = UsageResponseKeys.fields()

          if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
              raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2433,6 +2971,7 @@ class RootFlowStep(FlowStep):
              if isinstance(model_artifact, mlrun.artifacts.Artifact)
              else model_artifact
          )
+         model_artifact = mlrun.utils.remove_tag_from_artifact_uri(model_artifact)
          model_parameters["artifact_uri"] = model_parameters.get(
              "artifact_uri", model_artifact
          )
@@ -2444,6 +2983,14 @@ class RootFlowStep(FlowStep):
                  "Inconsistent name for the added model."
              )
          model_parameters["name"] = name
+         model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
+         model_parameters["outputs"] = outputs or model_parameters.get("outputs", [])
+         model_parameters["input_path"] = input_path or model_parameters.get(
+             "input_path"
+         )
+         model_parameters["result_path"] = result_path or model_parameters.get(
+             "result_path"
+         )

          if name in self.shared_models and not override:
              raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2458,7 +3005,9 @@ class RootFlowStep(FlowStep):
          self.shared_models[name] = (model_class, model_parameters)
          self.shared_models_mechanism[name] = execution_mechanism

-     def get_shared_model_name_by_artifact_uri(self, artifact_uri: str) -> Optional[str]:
+     def get_shared_model_by_artifact_uri(
+         self, artifact_uri: str
+     ) -> Union[tuple[str, str, dict], tuple[None, None, None]]:
          """
          Get a shared model by its artifact URI.
          :param artifact_uri: The artifact URI of the model.
@@ -2466,10 +3015,10 @@ class RootFlowStep(FlowStep):
          """
          for model_name, (model_class, model_params) in self.shared_models.items():
              if model_params.get("artifact_uri") == artifact_uri:
-                 return model_name
-         return None
+                 return model_name, model_class, model_params
+         return None, None, None

-     def config_pool_resource(
+     def configure_shared_pool_resource(
          self,
          max_processes: Optional[int] = None,
          max_threads: Optional[int] = None,
@@ -2494,12 +3043,30 @@ class RootFlowStep(FlowStep):
              max_threads=self.shared_max_threads,
              pool_factor=self.pool_factor,
          )
-
+         monitored_steps = self.get_monitored_steps().values()
+         for monitored_step in monitored_steps:
+             if isinstance(monitored_step, ModelRunnerStep):
+                 for model, model_params in self.shared_models.values():
+                     if "shared_proxy_mapping" in model_params:
+                         model_params["shared_proxy_mapping"].update(
+                             deepcopy(
+                                 monitored_step._shared_proxy_mapping.get(
+                                     model_params.get("name"), {}
+                                 )
+                             )
+                         )
+                     else:
+                         model_params["shared_proxy_mapping"] = deepcopy(
+                             monitored_step._shared_proxy_mapping.get(
+                                 model_params.get("name"), {}
+                             )
+                         )
          for model, model_params in self.shared_models.values():
              model = get_class(model, namespace).from_dict(
                  model_params, init_with_params=True
              )
              model._raise_exception = False
+             model._execution_mechanism = self._shared_models_mechanism[model.name]
              self.context.executor.add_runnable(
                  model, self._shared_models_mechanism[model.name]
              )
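
Note: `config_pool_resource` on the root flow step is renamed to `configure_shared_pool_resource` in this release,
and `ModelRunnerStep` gains its own per-step `configure_pool_resource`. A usage sketch (limit values are examples;
the root method is assumed to keep the same three arguments)::

    graph = fn.set_topology("flow", engine="async")
    # limits for the shared-model executor owned by the root flow step
    graph.configure_shared_pool_resource(max_processes=2, max_threads=8, pool_factor=1)

    runner = ModelRunnerStep(name="runner")
    # limits for the worker pools of this step's own models
    runner.configure_pool_resource(max_processes=4, max_threads=16, pool_factor=2)
    graph.to(runner)
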
@@ -2619,12 +3186,10 @@ def _add_graphviz_router(graph, step, source=None, **kwargs):
      graph.edge(step.fullname, route.fullname)


- def _add_graphviz_model_runner(graph, step, source=None):
+ def _add_graphviz_model_runner(graph, step, source=None, is_monitored=False):
      if source:
          graph.node("_start", source.name, shape=source.shape, style="filled")
          graph.edge("_start", step.fullname)
-
-     is_monitored = step._extract_root_step().track_models
      m_cell = '<FONT POINT-SIZE="9">🄼</FONT>' if is_monitored else ""

      number_of_models = len(
@@ -2663,6 +3228,7 @@ def _add_graphviz_flow(
          allow_empty=True
      )
      graph.node("_start", source.name, shape=source.shape, style="filled")
+     is_monitored = step.track_models if isinstance(step, RootFlowStep) else False
      for start_step in start_steps:
          graph.edge("_start", start_step.fullname)
      for child in step.get_children():
@@ -2671,7 +3237,7 @@ def _add_graphviz_flow(
          with graph.subgraph(name="cluster_" + child.fullname) as sg:
              _add_graphviz_router(sg, child)
          elif kind == StepKinds.model_runner:
-             _add_graphviz_model_runner(graph, child)
+             _add_graphviz_model_runner(graph, child, is_monitored=is_monitored)
          else:
              graph.node(child.fullname, label=child.name, shape=child.get_shape())
          _add_edges(child.after or [], step, graph, child)
@@ -2803,7 +3369,7 @@ def params_to_step(
          step = QueueStep(name, **class_args)

      elif class_name and hasattr(class_name, "to_dict"):
-         struct = class_name.to_dict()
+         struct = deepcopy(class_name.to_dict())
          kind = struct.get("kind", StepKinds.task)
          name = (
              name
@@ -2890,7 +3456,7 @@ def _init_async_objects(context, steps):
                  datastore_profile = datastore_profile_read(stream_path)
                  if isinstance(
                      datastore_profile,
-                     (DatastoreProfileKafkaTarget, DatastoreProfileKafkaSource),
+                     (DatastoreProfileKafkaTarget, DatastoreProfileKafkaStream),
                  ):
                      step._async_object = KafkaStoreyTarget(
                          path=stream_path,
@@ -2906,7 +3472,7 @@ def _init_async_objects(context, steps):
                  else:
                      raise mlrun.errors.MLRunValueError(
                          f"Received an unexpected stream profile type: {type(datastore_profile)}\n"
-                         "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaSource`."
+                         "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaStream`."
                      )
              elif stream_path.startswith("kafka://") or kafka_brokers:
                  topic, brokers = parse_kafka_url(stream_path, kafka_brokers)
@@ -2922,6 +3488,8 @@ def _init_async_objects(context, steps):
                      context=context,
                      **options,
                  )
+             elif stream_path.startswith("dummy://"):
+                 step._async_object = _DummyStream(context=context, **options)
              else:
                  if stream_path.startswith("v3io://"):
                      endpoint, stream_path = parse_path(step.path)