mlrun 1.10.0rc16__py3-none-any.whl → 1.10.1rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (101)
  1. mlrun/__init__.py +22 -2
  2. mlrun/artifacts/document.py +6 -1
  3. mlrun/artifacts/llm_prompt.py +21 -15
  4. mlrun/artifacts/model.py +3 -3
  5. mlrun/common/constants.py +9 -0
  6. mlrun/common/formatters/artifact.py +1 -0
  7. mlrun/common/model_monitoring/helpers.py +86 -0
  8. mlrun/common/schemas/__init__.py +2 -0
  9. mlrun/common/schemas/auth.py +2 -0
  10. mlrun/common/schemas/function.py +10 -0
  11. mlrun/common/schemas/hub.py +30 -18
  12. mlrun/common/schemas/model_monitoring/__init__.py +2 -0
  13. mlrun/common/schemas/model_monitoring/constants.py +30 -6
  14. mlrun/common/schemas/model_monitoring/functions.py +13 -4
  15. mlrun/common/schemas/model_monitoring/model_endpoints.py +11 -0
  16. mlrun/common/schemas/pipeline.py +1 -1
  17. mlrun/common/schemas/serving.py +3 -0
  18. mlrun/common/schemas/workflow.py +1 -0
  19. mlrun/common/secrets.py +22 -1
  20. mlrun/config.py +34 -21
  21. mlrun/datastore/__init__.py +11 -3
  22. mlrun/datastore/azure_blob.py +162 -47
  23. mlrun/datastore/base.py +265 -7
  24. mlrun/datastore/datastore.py +10 -5
  25. mlrun/datastore/datastore_profile.py +61 -5
  26. mlrun/datastore/model_provider/huggingface_provider.py +367 -0
  27. mlrun/datastore/model_provider/mock_model_provider.py +87 -0
  28. mlrun/datastore/model_provider/model_provider.py +211 -74
  29. mlrun/datastore/model_provider/openai_provider.py +243 -71
  30. mlrun/datastore/s3.py +24 -2
  31. mlrun/datastore/store_resources.py +4 -4
  32. mlrun/datastore/storeytargets.py +2 -3
  33. mlrun/datastore/utils.py +15 -3
  34. mlrun/db/base.py +27 -19
  35. mlrun/db/httpdb.py +57 -48
  36. mlrun/db/nopdb.py +25 -10
  37. mlrun/execution.py +55 -13
  38. mlrun/hub/__init__.py +15 -0
  39. mlrun/hub/module.py +181 -0
  40. mlrun/k8s_utils.py +105 -16
  41. mlrun/launcher/base.py +13 -6
  42. mlrun/launcher/local.py +2 -0
  43. mlrun/model.py +9 -3
  44. mlrun/model_monitoring/api.py +66 -27
  45. mlrun/model_monitoring/applications/__init__.py +1 -1
  46. mlrun/model_monitoring/applications/base.py +388 -138
  47. mlrun/model_monitoring/applications/context.py +2 -4
  48. mlrun/model_monitoring/applications/results.py +4 -7
  49. mlrun/model_monitoring/controller.py +239 -101
  50. mlrun/model_monitoring/db/_schedules.py +36 -13
  51. mlrun/model_monitoring/db/_stats.py +4 -3
  52. mlrun/model_monitoring/db/tsdb/base.py +29 -9
  53. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +4 -5
  54. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +154 -50
  55. mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +51 -0
  56. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +17 -4
  57. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +245 -51
  58. mlrun/model_monitoring/helpers.py +28 -5
  59. mlrun/model_monitoring/stream_processing.py +45 -14
  60. mlrun/model_monitoring/writer.py +220 -1
  61. mlrun/platforms/__init__.py +3 -2
  62. mlrun/platforms/iguazio.py +7 -3
  63. mlrun/projects/operations.py +16 -11
  64. mlrun/projects/pipelines.py +2 -2
  65. mlrun/projects/project.py +157 -69
  66. mlrun/run.py +97 -20
  67. mlrun/runtimes/__init__.py +18 -0
  68. mlrun/runtimes/base.py +14 -6
  69. mlrun/runtimes/daskjob.py +1 -0
  70. mlrun/runtimes/local.py +5 -2
  71. mlrun/runtimes/mounts.py +20 -2
  72. mlrun/runtimes/nuclio/__init__.py +1 -0
  73. mlrun/runtimes/nuclio/application/application.py +147 -17
  74. mlrun/runtimes/nuclio/function.py +72 -27
  75. mlrun/runtimes/nuclio/serving.py +102 -20
  76. mlrun/runtimes/pod.py +213 -21
  77. mlrun/runtimes/utils.py +49 -9
  78. mlrun/secrets.py +54 -13
  79. mlrun/serving/remote.py +79 -6
  80. mlrun/serving/routers.py +23 -41
  81. mlrun/serving/server.py +230 -40
  82. mlrun/serving/states.py +605 -232
  83. mlrun/serving/steps.py +62 -0
  84. mlrun/serving/system_steps.py +136 -81
  85. mlrun/serving/v2_serving.py +9 -10
  86. mlrun/utils/helpers.py +215 -83
  87. mlrun/utils/logger.py +3 -1
  88. mlrun/utils/notifications/notification/base.py +18 -0
  89. mlrun/utils/notifications/notification/git.py +2 -4
  90. mlrun/utils/notifications/notification/mail.py +38 -15
  91. mlrun/utils/notifications/notification/slack.py +2 -4
  92. mlrun/utils/notifications/notification/webhook.py +2 -5
  93. mlrun/utils/notifications/notification_pusher.py +1 -1
  94. mlrun/utils/version/version.json +2 -2
  95. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.1rc4.dist-info}/METADATA +51 -50
  96. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.1rc4.dist-info}/RECORD +100 -95
  97. mlrun/api/schemas/__init__.py +0 -259
  98. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.1rc4.dist-info}/WHEEL +0 -0
  99. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.1rc4.dist-info}/entry_points.txt +0 -0
  100. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.1rc4.dist-info}/licenses/LICENSE +0 -0
  101. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.1rc4.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py CHANGED
@@ -24,6 +24,7 @@ import inspect
24
24
  import os
25
25
  import pathlib
26
26
  import traceback
27
+ import warnings
27
28
  from abc import ABC
28
29
  from copy import copy, deepcopy
29
30
  from inspect import getfullargspec, signature
@@ -38,17 +39,21 @@ import mlrun.common.schemas as schemas
38
39
  from mlrun.artifacts.llm_prompt import LLMPromptArtifact, PlaceholderDefaultDict
39
40
  from mlrun.artifacts.model import ModelArtifact
40
41
  from mlrun.datastore.datastore_profile import (
41
- DatastoreProfileKafkaSource,
42
+ DatastoreProfileKafkaStream,
42
43
  DatastoreProfileKafkaTarget,
43
44
  DatastoreProfileV3io,
44
45
  datastore_profile_read,
45
46
  )
46
- from mlrun.datastore.model_provider.model_provider import ModelProvider
47
+ from mlrun.datastore.model_provider.model_provider import (
48
+ InvokeResponseFormat,
49
+ ModelProvider,
50
+ UsageResponseKeys,
51
+ )
47
52
  from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
48
- from mlrun.utils import get_data_from_path, logger, split_path
53
+ from mlrun.utils import get_data_from_path, logger, set_data_by_path, split_path
49
54
 
50
55
  from ..config import config
51
- from ..datastore import get_stream_pusher
56
+ from ..datastore import _DummyStream, get_stream_pusher
52
57
  from ..datastore.utils import (
53
58
  get_kafka_brokers_from_dict,
54
59
  parse_kafka_url,
@@ -372,20 +377,14 @@ class BaseStep(ModelObj):
372
377
  to event["y"] resulting in {"x": 5, "y": <result>}
373
378
  :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
374
379
 
375
- * **overwrite**:
376
-
377
- 1. If model endpoints with the same name exist, delete the `latest` one.
378
- 2. Create a new model endpoint entry and set it as `latest`.
379
-
380
- * **inplace** (default):
381
-
382
- 1. If model endpoints with the same name exist, update the `latest` entry.
383
- 2. Otherwise, create a new entry.
380
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
381
+ create a new model endpoint entry and set it as `latest`.
384
382
 
385
- * **archive**:
383
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
384
+ entry; otherwise, create a new entry.
386
385
 
387
- 1. If model endpoints with the same name exist, preserve them.
388
- 2. Create a new model endpoint with the same name and set it to `latest`.
386
+ * **archive**: If model endpoints with the same name exist, preserve them;
387
+ create a new model endpoint with the same name and set it to `latest`.
389
388
 
390
389
  :param class_args: class init arguments
391
390
  """
@@ -517,7 +516,9 @@ class BaseStep(ModelObj):
517
516
 
518
517
  root = self._extract_root_step()
519
518
 
520
- if not isinstance(root, RootFlowStep):
519
+ if not isinstance(root, RootFlowStep) or (
520
+ isinstance(root, RootFlowStep) and root.engine != "async"
521
+ ):
521
522
  raise GraphError(
522
523
  "ModelRunnerStep can be added to 'Flow' topology graph only"
523
524
  )
@@ -541,8 +542,8 @@ class BaseStep(ModelObj):
541
542
  # Update model endpoints names in the root step
542
543
  root.update_model_endpoints_names(step_model_endpoints_names)
543
544
 
544
- @staticmethod
545
545
  def _verify_shared_models(
546
+ self,
546
547
  root: "RootFlowStep",
547
548
  step: "ModelRunnerStep",
548
549
  step_model_endpoints_names: list[str],
@@ -571,35 +572,41 @@ class BaseStep(ModelObj):
571
572
  prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
572
573
  # if the model artifact is a prompt, we need to get the model URI
573
574
  # to ensure that the shared runnable name is correct
575
+ llm_artifact_uri = None
574
576
  if prefix == mlrun.utils.StorePrefix.LLMPrompt:
575
577
  llm_artifact, _ = mlrun.store_manager.get_store_artifact(
576
578
  model_artifact_uri
577
579
  )
580
+ llm_artifact_uri = llm_artifact.uri
578
581
  model_artifact_uri = mlrun.utils.remove_tag_from_artifact_uri(
579
582
  llm_artifact.spec.parent_uri
580
583
  )
581
- actual_shared_name = root.get_shared_model_name_by_artifact_uri(
582
- model_artifact_uri
584
+ actual_shared_name, shared_model_class, shared_model_params = (
585
+ root.get_shared_model_by_artifact_uri(model_artifact_uri)
583
586
  )
584
587
 
585
- if not shared_runnable_name:
586
- if not actual_shared_name:
587
- raise GraphError(
588
- f"Can't find shared model for {name} model endpoint"
589
- )
590
- else:
591
- step.class_args[schemas.ModelRunnerStepData.MODELS][name][
592
- schemas.ModelsData.MODEL_PARAMETERS.value
593
- ]["shared_runnable_name"] = actual_shared_name
594
- shared_models.append(actual_shared_name)
588
+ if not actual_shared_name:
589
+ raise GraphError(
590
+ f"Can't find shared model named {shared_runnable_name}"
591
+ )
592
+ elif not shared_runnable_name:
593
+ step.class_args[schemas.ModelRunnerStepData.MODELS][name][
594
+ schemas.ModelsData.MODEL_PARAMETERS.value
595
+ ]["shared_runnable_name"] = actual_shared_name
595
596
  elif actual_shared_name != shared_runnable_name:
596
597
  raise GraphError(
597
598
  f"Model endpoint {name} shared runnable name mismatch: "
598
599
  f"expected {actual_shared_name}, got {shared_runnable_name}"
599
600
  )
600
- else:
601
- shared_models.append(actual_shared_name)
602
-
601
+ shared_models.append(actual_shared_name)
602
+ self._edit_proxy_model_data(
603
+ step,
604
+ name,
605
+ actual_shared_name,
606
+ shared_model_params,
607
+ shared_model_class,
608
+ llm_artifact_uri or model_artifact_uri,
609
+ )
603
610
  undefined_shared_models = list(
604
611
  set(shared_models) - set(root.shared_models.keys())
605
612
  )
@@ -608,6 +615,52 @@ class BaseStep(ModelObj):
608
615
  f"The following shared models are not defined in the graph: {undefined_shared_models}."
609
616
  )
610
617
 
618
+ @staticmethod
619
+ def _edit_proxy_model_data(
620
+ step: "ModelRunnerStep",
621
+ name: str,
622
+ actual_shared_name: str,
623
+ shared_model_params: dict,
624
+ shared_model_class: Any,
625
+ artifact: Union[ModelArtifact, LLMPromptArtifact, str],
626
+ ):
627
+ monitoring_data = step.class_args.setdefault(
628
+ schemas.ModelRunnerStepData.MONITORING_DATA, {}
629
+ )
630
+
631
+ # edit monitoring data according to the shared model parameters
632
+ monitoring_data[name][schemas.MonitoringData.INPUT_PATH] = shared_model_params[
633
+ "input_path"
634
+ ]
635
+ monitoring_data[name][schemas.MonitoringData.RESULT_PATH] = shared_model_params[
636
+ "result_path"
637
+ ]
638
+ monitoring_data[name][schemas.MonitoringData.INPUTS] = shared_model_params[
639
+ "inputs"
640
+ ]
641
+ monitoring_data[name][schemas.MonitoringData.OUTPUTS] = shared_model_params[
642
+ "outputs"
643
+ ]
644
+ monitoring_data[name][schemas.MonitoringData.MODEL_CLASS] = (
645
+ shared_model_class
646
+ if isinstance(shared_model_class, str)
647
+ else shared_model_class.__class__.__name__
648
+ )
649
+ if actual_shared_name and actual_shared_name not in step._shared_proxy_mapping:
650
+ step._shared_proxy_mapping[actual_shared_name] = {
651
+ name: artifact.uri
652
+ if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
653
+ else artifact
654
+ }
655
+ elif actual_shared_name:
656
+ step._shared_proxy_mapping[actual_shared_name].update(
657
+ {
658
+ name: artifact.uri
659
+ if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
660
+ else artifact
661
+ }
662
+ )
663
+
611
664
 
612
665
  class TaskStep(BaseStep):
613
666
  """task execution step, runs a class or handler"""
@@ -983,20 +1036,14 @@ class RouterStep(TaskStep):
983
1036
  :param function: function this step should run in
984
1037
  :param creation_strategy: Strategy for creating or updating the model endpoint:
985
1038
 
986
- * **overwrite**:
1039
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
1040
+ create a new model endpoint entry and set it as `latest`.
987
1041
 
988
- 1. If model endpoints with the same name exist, delete the `latest` one.
989
- 2. Create a new model endpoint entry and set it as `latest`.
1042
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
1043
+ entry; otherwise, create a new entry.
990
1044
 
991
- * **inplace** (default):
992
-
993
- 1. If model endpoints with the same name exist, update the `latest` entry.
994
- 2. Otherwise, create a new entry.
995
-
996
- * **archive**:
997
-
998
- 1. If model endpoints with the same name exist, preserve them.
999
- 2. Create a new model endpoint with the same name and set it to `latest`.
1045
+ * **archive**: If model endpoints with the same name exist, preserve them;
1046
+ create a new model endpoint with the same name and set it to `latest`.
1000
1047
 
1001
1048
  """
1002
1049
  if len(self.routes.keys()) >= MAX_MODELS_PER_ROUTER and key not in self.routes:
@@ -1090,6 +1137,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1090
1137
  "artifact_uri",
1091
1138
  "shared_runnable_name",
1092
1139
  "shared_proxy_mapping",
1140
+ "execution_mechanism",
1093
1141
  ]
1094
1142
  kind = "model"
1095
1143
 
@@ -1111,6 +1159,8 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1111
1159
  self.invocation_artifact: Optional[LLMPromptArtifact] = None
1112
1160
  self.model_artifact: Optional[ModelArtifact] = None
1113
1161
  self.model_provider: Optional[ModelProvider] = None
1162
+ self._artifact_were_loaded = False
1163
+ self._execution_mechanism = None
1114
1164
 
1115
1165
  def __init_subclass__(cls):
1116
1166
  super().__init_subclass__()
@@ -1130,13 +1180,29 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1130
1180
  raise_missing_schema_exception=False,
1131
1181
  )
1132
1182
 
1133
- def _load_artifacts(self) -> None:
1134
- artifact = self._get_artifact_object()
1135
- if isinstance(artifact, LLMPromptArtifact):
1136
- self.invocation_artifact = artifact
1137
- self.model_artifact = self.invocation_artifact.model_artifact
1183
+ # Check if the relevant predict method is implemented when trying to initialize the model
1184
+ if self._execution_mechanism == storey.ParallelExecutionMechanisms.asyncio:
1185
+ if self.__class__.predict_async is Model.predict_async:
1186
+ raise mlrun.errors.ModelRunnerError(
1187
+ f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict_async() "
1188
+ f"is not implemented"
1189
+ )
1138
1190
  else:
1139
- self.model_artifact = artifact
1191
+ if self.__class__.predict is Model.predict:
1192
+ raise mlrun.errors.ModelRunnerError(
1193
+ f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict() "
1194
+ f"is not implemented"
1195
+ )
1196
+
1197
+ def _load_artifacts(self) -> None:
1198
+ if not self._artifact_were_loaded:
1199
+ artifact = self._get_artifact_object()
1200
+ if isinstance(artifact, LLMPromptArtifact):
1201
+ self.invocation_artifact = artifact
1202
+ self.model_artifact = self.invocation_artifact.model_artifact
1203
+ else:
1204
+ self.model_artifact = artifact
1205
+ self._artifact_were_loaded = True
1140
1206
 
1141
1207
  def _get_artifact_object(
1142
1208
  self, proxy_uri: Optional[str] = None
@@ -1158,11 +1224,11 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1158
1224
 
1159
1225
  def predict(self, body: Any, **kwargs) -> Any:
1160
1226
  """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
1161
- return body
1227
+ raise NotImplementedError("predict() method not implemented")
1162
1228
 
1163
1229
  async def predict_async(self, body: Any, **kwargs) -> Any:
1164
1230
  """Override to implement prediction logic if the logic requires asyncio."""
1165
- return body
1231
+ raise NotImplementedError("predict_async() method not implemented")
1166
1232
 
1167
1233
  def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
1168
1234
  return self.predict(body)
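
Because the defaults above now raise NotImplementedError, a subclass has to implement the variant matching its execution mechanism; a minimal sketch (class and field names are illustrative):

    from typing import Any

    class EchoModel(Model):  # Model is the base class defined in this module
        """Toy model that echoes its input back into the event body."""

        def predict(self, body: Any, **kwargs) -> Any:
            # synchronous path, used by process/thread pool, dedicated process and naive execution
            body["result"] = {"echo": body.get("inputs")}
            return body

        async def predict_async(self, body: Any, **kwargs) -> Any:
            # asyncio path, required when execution_mechanism="asyncio"
            return self.predict(body, **kwargs)
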
@@ -1205,26 +1271,111 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1205
1271
 
1206
1272
 
1207
1273
  class LLModel(Model):
1274
+ """
1275
+ A model wrapper for handling LLM (Large Language Model) prompt-based inference.
1276
+
1277
+ This class extends the base `Model` to provide specialized handling for
1278
+ `LLMPromptArtifact` objects, enabling both synchronous and asynchronous
1279
+ invocation of language models.
1280
+
1281
+ **Model Invocation**:
1282
+
1283
+ - The execution of enriched prompts is delegated to the `model_provider`
1284
+ configured for the model (e.g., **Hugging Face** or **OpenAI**).
1285
+ - The `model_provider` is responsible for sending the prompt to the correct
1286
+ backend API and returning the generated output.
1287
+ - Users can override the `predict` and `predict_async` methods to customize
1288
+ the behavior of the model invocation.
1289
+
1290
+ **Prompt Enrichment Overview**:
1291
+
1292
+ - If an `LLMPromptArtifact` is found, load its prompt template and fill in
1293
+ placeholders using values from the request body.
1294
+ - If the artifact is not an `LLMPromptArtifact`, skip formatting and attempt
1295
+ to retrieve `messages` directly from the request body using the input path.
1296
+
1297
+ **Simplified Example**:
1298
+
1299
+ Input body::
1300
+
1301
+ {"city": "Paris", "days": 3}
1302
+
1303
+ Prompt template in artifact::
1304
+
1305
+ [
1306
+ {"role": "system", "content": "You are a travel planning assistant."},
1307
+ {"role": "user", "content": "Create a {{days}}-day itinerary for {{city}}."},
1308
+ ]
1309
+
1310
+ Result after enrichment::
1311
+
1312
+ [
1313
+ {"role": "system", "content": "You are a travel planning assistant."},
1314
+ {"role": "user", "content": "Create a 3-day itinerary for Paris."},
1315
+ ]
1316
+
1317
+ :param name: Name of the model.
1318
+ :param input_path: Path in the request body where input data is located.
1319
+ :param result_path: Path in the response body where model outputs and the statistics
1320
+ will be stored.
1321
+ """
1322
+
1323
+ _dict_fields = Model._dict_fields + ["result_path", "input_path"]
1324
+
1208
1325
  def __init__(
1209
- self, name: str, input_path: Optional[Union[str, list[str]]], **kwargs
1326
+ self,
1327
+ name: str,
1328
+ input_path: Optional[Union[str, list[str]]] = None,
1329
+ result_path: Optional[Union[str, list[str]]] = None,
1330
+ **kwargs,
1210
1331
  ):
1211
1332
  super().__init__(name, **kwargs)
1212
1333
  self._input_path = split_path(input_path)
1334
+ self._result_path = split_path(result_path)
1335
+ logger.info(
1336
+ "LLModel initialized",
1337
+ model_name=name,
1338
+ input_path=input_path,
1339
+ result_path=result_path,
1340
+ )
1213
1341
 
1214
1342
  def predict(
1215
1343
  self,
1216
1344
  body: Any,
1217
1345
  messages: Optional[list[dict]] = None,
1218
- model_configuration: Optional[dict] = None,
1346
+ invocation_config: Optional[dict] = None,
1219
1347
  **kwargs,
1220
1348
  ) -> Any:
1349
+ llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
1221
1350
  if isinstance(
1222
- self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
1351
+ llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
1223
1352
  ) and isinstance(self.model_provider, ModelProvider):
1224
- body["result"] = self.model_provider.invoke(
1353
+ logger.debug(
1354
+ "Invoking model provider",
1355
+ model_name=self.name,
1356
+ messages=messages,
1357
+ invocation_config=invocation_config,
1358
+ )
1359
+ response_with_stats = self.model_provider.invoke(
1225
1360
  messages=messages,
1226
- as_str=True,
1227
- **(model_configuration or {}),
1361
+ invoke_response_format=InvokeResponseFormat.USAGE,
1362
+ **(invocation_config or {}),
1363
+ )
1364
+ set_data_by_path(
1365
+ path=self._result_path, data=body, value=response_with_stats
1366
+ )
1367
+ logger.debug(
1368
+ "LLModel prediction completed",
1369
+ model_name=self.name,
1370
+ answer=response_with_stats.get("answer"),
1371
+ usage=response_with_stats.get("usage"),
1372
+ )
1373
+ else:
1374
+ logger.warning(
1375
+ "LLModel invocation artifact or model provider not set, skipping prediction",
1376
+ model_name=self.name,
1377
+ invocation_artifact_type=type(llm_prompt_artifact).__name__,
1378
+ model_provider_type=type(self.model_provider).__name__,
1228
1379
  )
1229
1380
  return body
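
The result_path handling above relies on the dot-notation path helpers (split_path / set_data_by_path); a plain-Python illustration of the effect, where the response shape with "answer" and "usage" keys is inferred from the log fields and should be treated as an assumption:

    # what writing a response at result_path="response.llm" amounts to
    body = {"request": {"city": "Paris", "days": 3}}
    response_with_stats = {"answer": "Day 1: ...", "usage": {"total_tokens": 120}}

    path = "response.llm".split(".")      # roughly what split_path() produces
    node = body
    for key in path[:-1]:
        node = node.setdefault(key, {})   # create intermediate scopes as needed
    node[path[-1]] = response_with_stats  # roughly what set_data_by_path() does

    # body == {"request": {...}, "response": {"llm": {"answer": ..., "usage": {...}}}}
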
1230
1381
 
@@ -1232,61 +1383,112 @@ class LLModel(Model):
1232
1383
  self,
1233
1384
  body: Any,
1234
1385
  messages: Optional[list[dict]] = None,
1235
- model_configuration: Optional[dict] = None,
1386
+ invocation_config: Optional[dict] = None,
1236
1387
  **kwargs,
1237
1388
  ) -> Any:
1389
+ llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
1238
1390
  if isinstance(
1239
- self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
1391
+ llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
1240
1392
  ) and isinstance(self.model_provider, ModelProvider):
1241
- body["result"] = await self.model_provider.async_invoke(
1393
+ logger.debug(
1394
+ "Async invoking model provider",
1395
+ model_name=self.name,
1396
+ messages=messages,
1397
+ invocation_config=invocation_config,
1398
+ )
1399
+ response_with_stats = await self.model_provider.async_invoke(
1242
1400
  messages=messages,
1243
- as_str=True,
1244
- **(model_configuration or {}),
1401
+ invoke_response_format=InvokeResponseFormat.USAGE,
1402
+ **(invocation_config or {}),
1403
+ )
1404
+ set_data_by_path(
1405
+ path=self._result_path, data=body, value=response_with_stats
1406
+ )
1407
+ logger.debug(
1408
+ "LLModel async prediction completed",
1409
+ model_name=self.name,
1410
+ answer=response_with_stats.get("answer"),
1411
+ usage=response_with_stats.get("usage"),
1412
+ )
1413
+ else:
1414
+ logger.warning(
1415
+ "LLModel invocation artifact or model provider not set, skipping async prediction",
1416
+ model_name=self.name,
1417
+ invocation_artifact_type=type(llm_prompt_artifact).__name__,
1418
+ model_provider_type=type(self.model_provider).__name__,
1245
1419
  )
1246
1420
  return body
1247
1421
 
1248
1422
  def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
1249
- messages, model_configuration = self.enrich_prompt(body, origin_name)
1423
+ llm_prompt_artifact = self._get_invocation_artifact(origin_name)
1424
+ messages, invocation_config = self.enrich_prompt(
1425
+ body, origin_name, llm_prompt_artifact
1426
+ )
1427
+ logger.info(
1428
+ "Calling LLModel predict",
1429
+ model_name=self.name,
1430
+ model_endpoint_name=origin_name,
1431
+ messages_len=len(messages) if messages else 0,
1432
+ )
1250
1433
  return self.predict(
1251
- body, messages=messages, model_configuration=model_configuration
1434
+ body,
1435
+ messages=messages,
1436
+ invocation_config=invocation_config,
1437
+ llm_prompt_artifact=llm_prompt_artifact,
1252
1438
  )
1253
1439
 
1254
1440
  async def run_async(
1255
1441
  self, body: Any, path: str, origin_name: Optional[str] = None
1256
1442
  ) -> Any:
1257
- messages, model_configuration = self.enrich_prompt(body, origin_name)
1443
+ llm_prompt_artifact = self._get_invocation_artifact(origin_name)
1444
+ messages, invocation_config = self.enrich_prompt(
1445
+ body, origin_name, llm_prompt_artifact
1446
+ )
1447
+ logger.info(
1448
+ "Calling LLModel async predict",
1449
+ model_name=self.name,
1450
+ model_endpoint_name=origin_name,
1451
+ messages_len=len(messages) if messages else 0,
1452
+ )
1258
1453
  return await self.predict_async(
1259
- body, messages=messages, model_configuration=model_configuration
1454
+ body,
1455
+ messages=messages,
1456
+ invocation_config=invocation_config,
1457
+ llm_prompt_artifact=llm_prompt_artifact,
1260
1458
  )
1261
1459
 
1262
1460
  def enrich_prompt(
1263
- self, body: dict, origin_name: str
1461
+ self,
1462
+ body: dict,
1463
+ origin_name: str,
1464
+ llm_prompt_artifact: Optional[LLMPromptArtifact] = None,
1264
1465
  ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
1265
- if origin_name and self.shared_proxy_mapping:
1266
- llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
1267
- if isinstance(llm_prompt_artifact, str):
1268
- llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
1269
- self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
1270
- else:
1271
- llm_prompt_artifact = (
1272
- self.invocation_artifact or self._get_artifact_object()
1273
- )
1274
- if not (
1466
+ logger.info(
1467
+ "Enriching prompt",
1468
+ model_name=self.name,
1469
+ model_endpoint_name=origin_name,
1470
+ )
1471
+ if not llm_prompt_artifact or not (
1275
1472
  llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
1276
1473
  ):
1277
1474
  logger.warning(
1278
- "LLMModel must be provided with LLMPromptArtifact",
1475
+ "LLModel must be provided with LLMPromptArtifact",
1476
+ model_name=self.name,
1477
+ artifact_type=type(llm_prompt_artifact).__name__,
1279
1478
  llm_prompt_artifact=llm_prompt_artifact,
1280
1479
  )
1281
- return None, None
1282
- prompt_legend = llm_prompt_artifact.spec.prompt_legend
1283
- prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
1480
+ prompt_legend, prompt_template, invocation_config = {}, [], {}
1481
+ else:
1482
+ prompt_legend = llm_prompt_artifact.spec.prompt_legend
1483
+ prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
1484
+ invocation_config = llm_prompt_artifact.spec.invocation_config
1284
1485
  input_data = copy(get_data_from_path(self._input_path, body))
1285
- if isinstance(input_data, dict):
1486
+ if isinstance(input_data, dict) and prompt_template:
1286
1487
  kwargs = (
1287
1488
  {
1288
1489
  place_holder: input_data.get(body_map["field"])
1289
1490
  for place_holder, body_map in prompt_legend.items()
1491
+ if input_data.get(body_map["field"])
1290
1492
  }
1291
1493
  if prompt_legend
1292
1494
  else {}
@@ -1298,23 +1500,61 @@ class LLModel(Model):
1298
1500
  message["content"] = message["content"].format(**input_data)
1299
1501
  except KeyError as e:
1300
1502
  logger.warning(
1301
- "Input data was missing a placeholder, placeholder stay unformatted",
1302
- key_error=e,
1503
+ "Input data missing placeholder, content stays unformatted",
1504
+ model_name=self.name,
1505
+ key_error=mlrun.errors.err_to_str(e),
1303
1506
  )
1304
1507
  message["content"] = message["content"].format_map(
1305
1508
  default_place_holders
1306
1509
  )
1510
+ elif isinstance(input_data, dict) and not prompt_template:
1511
+ # If there is no prompt template, we assume the input data is already in the correct format.
1512
+ logger.debug("Attempting to retrieve messages from the request body.")
1513
+ prompt_template = input_data.get("messages", [])
1307
1514
  else:
1308
1515
  logger.warning(
1309
- f"Expected input data to be a dict, but received input data from type {type(input_data)} prompt "
1310
- f"template stay unformatted",
1516
+ "Expected input data to be a dict, prompt template stays unformatted",
1517
+ model_name=self.name,
1518
+ input_data_type=type(input_data).__name__,
1311
1519
  )
1312
- return prompt_template, llm_prompt_artifact.spec.model_configuration
1520
+ return prompt_template, invocation_config
1521
+
1522
+ def _get_invocation_artifact(
1523
+ self, origin_name: Optional[str] = None
1524
+ ) -> Union[LLMPromptArtifact, None]:
1525
+ """
1526
+ Get the LLMPromptArtifact object for this model.
1527
+
1528
+ :param proxy_uri: Optional; URI to the proxy artifact.
1529
+ :return: LLMPromptArtifact object or None if not found.
1530
+ """
1531
+ if origin_name and self.shared_proxy_mapping:
1532
+ llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
1533
+ if isinstance(llm_prompt_artifact, str):
1534
+ llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
1535
+ self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
1536
+ elif self._artifact_were_loaded:
1537
+ llm_prompt_artifact = self.invocation_artifact
1538
+ else:
1539
+ self._load_artifacts()
1540
+ llm_prompt_artifact = self.invocation_artifact
1541
+ return llm_prompt_artifact
1313
1542
 
1314
1543
 
1315
- class ModelSelector:
1544
+ class ModelSelector(ModelObj):
1316
1545
  """Used to select which models to run on each event."""
1317
1546
 
1547
+ def __init__(self, **kwargs):
1548
+ super().__init__()
1549
+
1550
+ def __init_subclass__(cls):
1551
+ super().__init_subclass__()
1552
+ cls._dict_fields = list(
1553
+ set(cls._dict_fields)
1554
+ | set(inspect.signature(cls.__init__).parameters.keys())
1555
+ )
1556
+ cls._dict_fields.remove("self")
1557
+
1318
1558
  def select(
1319
1559
  self, event, available_models: list[Model]
1320
1560
  ) -> Union[list[str], list[Model]]:
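
With ModelSelector now deriving from ModelObj, constructor arguments round-trip through to_dict/from_dict, so a custom selector is a small subclass; a sketch in which the routing rule and the event.body access are illustrative assumptions:

    from typing import Union

    class FieldSelector(ModelSelector):
        """Select the model whose name matches a field in the event body."""

        def __init__(self, field: str = "model", **kwargs):
            super().__init__(**kwargs)
            self.field = field  # captured into _dict_fields by __init_subclass__

        def select(self, event, available_models: list[Model]) -> Union[list[str], list[Model]]:
            wanted = event.body.get(self.field) if isinstance(event.body, dict) else None
            matches = [model for model in available_models if model.name == wanted]
            return matches or available_models  # fall back to running every model

Such a selector can be handed to ModelRunnerStep either as an instance or, via the new model_selector_parameters argument, by class name plus a parameter dict.
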
@@ -1406,34 +1646,68 @@ class ModelRunnerStep(MonitoredStep):
1406
1646
  model_runner_step.add_model(..., model_class=MyModel(name="my_model"))
1407
1647
  graph.to(model_runner_step)
1408
1648
 
1649
+ Note when ModelRunnerStep is used in a graph, MLRun automatically imports
1650
+ the default language model class (LLModel) during function deployment.
1651
+
1652
+ Note ModelRunnerStep can only be added to a graph that has the flow topology and runs with the async engine.
1653
+
1654
+ Note: see the configure_pool_resource method documentation for the default numbers of max threads and max processes.
1655
+
1409
1656
  :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
1410
1657
  event. Optional. If not passed, all models will be run.
1411
1658
  :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised
1412
1659
  an error. If False, the error will appear in the output event.
1413
1660
 
1414
- :raise ModelRunnerError - when a model raise an error the ModelRunnerStep will handle it, collect errors and outputs
1415
- from added models, If raise_exception is True will raise ModelRunnerError Else will add
1416
- the error msg as part of the event body mapped by model name if more than one model was
1417
- added to the ModelRunnerStep
1661
+ :raise ModelRunnerError: when a model raises an error the ModelRunnerStep will handle it, collect errors and
1662
+ outputs from added models. If raise_exception is True will raise ModelRunnerError. Else
1663
+ will add the error msg as part of the event body mapped by model name if more than
1664
+ one model was added to the ModelRunnerStep
1418
1665
  """
1419
1666
 
1420
1667
  kind = "model_runner"
1421
- _dict_fields = MonitoredStep._dict_fields + ["_shared_proxy_mapping"]
1668
+ _dict_fields = MonitoredStep._dict_fields + [
1669
+ "_shared_proxy_mapping",
1670
+ "max_processes",
1671
+ "max_threads",
1672
+ "pool_factor",
1673
+ ]
1422
1674
 
1423
1675
  def __init__(
1424
1676
  self,
1425
1677
  *args,
1426
1678
  name: Optional[str] = None,
1427
1679
  model_selector: Optional[Union[str, ModelSelector]] = None,
1680
+ model_selector_parameters: Optional[dict] = None,
1428
1681
  raise_exception: bool = True,
1429
1682
  **kwargs,
1430
1683
  ):
1684
+ self.max_processes = None
1685
+ self.max_threads = None
1686
+ self.pool_factor = None
1687
+
1688
+ if isinstance(model_selector, ModelSelector) and model_selector_parameters:
1689
+ raise mlrun.errors.MLRunInvalidArgumentError(
1690
+ "Cannot provide a model_selector object as argument to `model_selector` and also provide "
1691
+ "`model_selector_parameters`."
1692
+ )
1693
+ if model_selector:
1694
+ model_selector_parameters = model_selector_parameters or (
1695
+ model_selector.to_dict()
1696
+ if isinstance(model_selector, ModelSelector)
1697
+ else {}
1698
+ )
1699
+ model_selector = (
1700
+ model_selector
1701
+ if isinstance(model_selector, str)
1702
+ else model_selector.__class__.__name__
1703
+ )
1704
+
1431
1705
  super().__init__(
1432
1706
  *args,
1433
1707
  name=name,
1434
1708
  raise_exception=raise_exception,
1435
1709
  class_name="mlrun.serving.ModelRunner",
1436
- class_args=dict(model_selector=model_selector),
1710
+ class_args=dict(model_selector=(model_selector, model_selector_parameters)),
1437
1711
  **kwargs,
1438
1712
  )
1439
1713
  self.raise_exception = raise_exception
@@ -1449,10 +1723,6 @@ class ModelRunnerStep(MonitoredStep):
1449
1723
  model_endpoint_creation_strategy: Optional[
1450
1724
  schemas.ModelEndpointCreationStrategy
1451
1725
  ] = schemas.ModelEndpointCreationStrategy.INPLACE,
1452
- inputs: Optional[list[str]] = None,
1453
- outputs: Optional[list[str]] = None,
1454
- input_path: Optional[str] = None,
1455
- result_path: Optional[str] = None,
1456
1726
  override: bool = False,
1457
1727
  ) -> None:
1458
1728
  """
@@ -1465,28 +1735,18 @@ class ModelRunnerStep(MonitoredStep):
1465
1735
  :param shared_model_name: str, the name of the shared model that is already defined within the graph
1466
1736
  :param labels: model endpoint labels, should be list of str or mapping of str:str
1467
1737
  :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
1468
- * **overwrite**:
1469
- 1. If model endpoints with the same name exist, delete the `latest` one.
1470
- 2. Create a new model endpoint entry and set it as `latest`.
1471
- * **inplace** (default):
1472
- 1. If model endpoints with the same name exist, update the `latest` entry.
1473
- 2. Otherwise, create a new entry.
1474
- * **archive**:
1475
- 1. If model endpoints with the same name exist, preserve them.
1476
- 2. Create a new model endpoint with the same name and set it to `latest`.
1477
1738
 
1478
- :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
1479
- that been configured in the model artifact, please note that those inputs need to
1480
- be equal in length and order to the inputs that model_class predict method expects
1481
- :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
1482
- that been configured in the model artifact, please note that those outputs need to
1483
- be equal to the model_class predict method outputs (length, and order)
1484
- :param input_path: input path inside the user event, expect scopes to be defined by dot notation
1485
- (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
1486
- :param result_path: result path inside the user output event, expect scopes to be defined by dot
1487
- notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
1488
- in path.
1739
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
1740
+ create a new model endpoint entry and set it as `latest`.
1741
+
1742
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest` entry;
1743
+ otherwise, create a new entry.
1744
+
1745
+ * **archive**: If model endpoints with the same name exist, preserve them;
1746
+ create a new model endpoint with the same name and set it to `latest`.
1747
+
1489
1748
  :param override: bool allow override existing model on the current ModelRunnerStep.
1749
+ :raise GraphError: when the shared model is not found in the root flow step shared models.
1490
1750
  """
1491
1751
  model_class, model_params = (
1492
1752
  "mlrun.serving.Model",
@@ -1503,11 +1763,21 @@ class ModelRunnerStep(MonitoredStep):
1503
1763
  "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
1504
1764
  )
1505
1765
  root = self._extract_root_step()
1766
+ shared_model_params = {}
1506
1767
  if isinstance(root, RootFlowStep):
1507
- shared_model_name = (
1508
- shared_model_name
1509
- or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
1768
+ actual_shared_model_name, shared_model_class, shared_model_params = (
1769
+ root.get_shared_model_by_artifact_uri(model_artifact_uri)
1510
1770
  )
1771
+ if not actual_shared_model_name or (
1772
+ shared_model_name and actual_shared_model_name != shared_model_name
1773
+ ):
1774
+ raise GraphError(
1775
+ f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
1776
+ f"model {shared_model_name} is not in the shared models."
1777
+ )
1778
+ elif not shared_model_name:
1779
+ shared_model_name = actual_shared_model_name
1780
+ model_params["shared_runnable_name"] = shared_model_name
1511
1781
  if not root.shared_models or (
1512
1782
  root.shared_models
1513
1783
  and shared_model_name
@@ -1517,13 +1787,27 @@ class ModelRunnerStep(MonitoredStep):
1517
1787
  f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
1518
1788
  f"model {shared_model_name} is not in the shared models."
1519
1789
  )
1520
- if shared_model_name not in self._shared_proxy_mapping:
1790
+ monitoring_data = self.class_args.get(
1791
+ schemas.ModelRunnerStepData.MONITORING_DATA, {}
1792
+ )
1793
+ monitoring_data.setdefault(endpoint_name, {})[
1794
+ schemas.MonitoringData.MODEL_CLASS
1795
+ ] = (
1796
+ shared_model_class
1797
+ if isinstance(shared_model_class, str)
1798
+ else shared_model_class.__class__.__name__
1799
+ )
1800
+ self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = (
1801
+ monitoring_data
1802
+ )
1803
+
1804
+ if shared_model_name and shared_model_name not in self._shared_proxy_mapping:
1521
1805
  self._shared_proxy_mapping[shared_model_name] = {
1522
1806
  endpoint_name: model_artifact.uri
1523
1807
  if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
1524
1808
  else model_artifact
1525
1809
  }
1526
- else:
1810
+ elif override and shared_model_name:
1527
1811
  self._shared_proxy_mapping[shared_model_name].update(
1528
1812
  {
1529
1813
  endpoint_name: model_artifact.uri
@@ -1538,11 +1822,11 @@ class ModelRunnerStep(MonitoredStep):
1538
1822
  model_artifact=model_artifact,
1539
1823
  labels=labels,
1540
1824
  model_endpoint_creation_strategy=model_endpoint_creation_strategy,
1825
+ inputs=shared_model_params.get("inputs"),
1826
+ outputs=shared_model_params.get("outputs"),
1827
+ input_path=shared_model_params.get("input_path"),
1828
+ result_path=shared_model_params.get("result_path"),
1541
1829
  override=override,
1542
- inputs=inputs,
1543
- outputs=outputs,
1544
- input_path=input_path,
1545
- result_path=result_path,
1546
1830
  **model_params,
1547
1831
  )
1548
1832
 
@@ -1567,8 +1851,11 @@ class ModelRunnerStep(MonitoredStep):
1567
1851
  Add a Model to this ModelRunner.
1568
1852
 
1569
1853
  :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
1570
- :param model_class: Model class name
1854
+ :param model_class: Model class name. If LLModel is chosen
1855
+ (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
1856
+ outputs will be overridden with UsageResponseKeys fields.
1571
1857
  :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
1858
+
1572
1859
  * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
1573
1860
  intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
1574
1861
  Lock (GIL).
@@ -1578,37 +1865,32 @@ class ModelRunnerStep(MonitoredStep):
1578
1865
  otherwise block the main event loop thread.
1579
1866
  * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
1580
1867
  event loop to continue running while waiting for a response.
1581
- * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
1582
- runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
1583
- useful when:
1584
- - You want to share a heavy resource like a large model loaded onto a GPU.
1585
- - You want to centralize task scheduling or coordination for multiple lightweight tasks.
1586
- - You aim to minimize overhead from creating new executors or processes/threads per runnable.
1587
- The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
1588
- memory and hardware accelerators.
1589
1868
  * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
1590
1869
  It means that the runnable will not actually be run in parallel to anything else.
1591
1870
 
1592
- :param model_artifact: model artifact or mlrun model artifact uri
1593
- :param labels: model endpoint labels, should be list of str or mapping of str:str
1594
- :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
1595
- * **overwrite**:
1596
- 1. If model endpoints with the same name exist, delete the `latest` one.
1597
- 2. Create a new model endpoint entry and set it as `latest`.
1598
- * **inplace** (default):
1599
- 1. If model endpoints with the same name exist, update the `latest` entry.
1600
- 2. Otherwise, create a new entry.
1601
- * **archive**:
1602
- 1. If model endpoints with the same name exist, preserve them.
1603
- 2. Create a new model endpoint with the same name and set it to `latest`.
1604
-
1605
- :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
1871
+ :param model_artifact: model artifact or mlrun model artifact uri
1872
+ :param labels: model endpoint labels, should be list of str or mapping of str:str
1873
+ :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
1874
+
1875
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
1876
+ create a new model endpoint entry and set it as `latest`.
1877
+
1878
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
1879
+ entry; otherwise, create a new entry.
1880
+
1881
+ * **archive**: If model endpoints with the same name exist, preserve them;
1882
+ create a new model endpoint with the same name and set it to `latest`.
1883
+
1884
+ :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
1606
1885
  that been configured in the model artifact, please note that those inputs need to
1607
1886
  be equal in length and order to the inputs that model_class predict method expects
1608
- :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
1887
+ :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
1609
1888
  that been configured in the model artifact, please note that those outputs need to
1610
1889
  be equal to the model_class predict method outputs (length, and order)
1611
- :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
1890
+
1891
+ When using LLModel, the output will be overridden with UsageResponseKeys.fields().
1892
+
1893
+ :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
1612
1894
  this require that the event body will behave like a dict, expects scopes to be
1613
1895
  defined by dot notation (e.g "data.d").
1614
1896
  examples: input_path="data.b"
@@ -1618,7 +1900,7 @@ class ModelRunnerStep(MonitoredStep):
1618
1900
  be {"f0": [1, 2]}.
1619
1901
  if a ``list`` or ``list of lists`` is provided, it must follow the order and
1620
1902
  size defined by the input schema.
1621
- :param result_path: when specified selects the key/path in the output event to use as model monitoring
1903
+ :param result_path: when specified selects the key/path in the output event to use as model monitoring
1622
1904
  outputs this require that the output event body will behave like a dict,
1623
1905
  expects scopes to be defined by dot notation (e.g "data.d").
1624
1906
  examples: result_path="out.b"
@@ -1629,14 +1911,22 @@ class ModelRunnerStep(MonitoredStep):
1629
1911
  if a ``list`` or ``list of lists`` is provided, it must follow the order and
1630
1912
  size defined by the output schema.
1631
1913
 
1632
- :param override: bool allow override existing model on the current ModelRunnerStep.
1633
- :param model_parameters: Parameters for model instantiation
1914
+ :param override: bool allow override existing model on the current ModelRunnerStep.
1915
+ :param model_parameters: Parameters for model instantiation
1634
1916
  """
1635
1917
  if isinstance(model_class, Model) and model_parameters:
1636
1918
  raise mlrun.errors.MLRunInvalidArgumentError(
1637
1919
  "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
1638
1920
  )
1639
-
1921
+ if type(model_class) is LLModel or (
1922
+ isinstance(model_class, str)
1923
+ and model_class.split(".")[-1] == LLModel.__name__
1924
+ ):
1925
+ if outputs:
1926
+ warnings.warn(
1927
+ "LLModel with existing outputs detected, overriding to default"
1928
+ )
1929
+ outputs = UsageResponseKeys.fields()
1640
1930
  model_parameters = model_parameters or (
1641
1931
  model_class.to_dict() if isinstance(model_class, Model) else {}
1642
1932
  )
@@ -1652,8 +1942,6 @@ class ModelRunnerStep(MonitoredStep):
1652
1942
  except mlrun.errors.MLRunNotFoundError:
1653
1943
  raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.")
1654
1944
 
1655
- outputs = outputs or self._get_model_output_schema(model_artifact)
1656
-
1657
1945
  model_artifact = (
1658
1946
  model_artifact.uri
1659
1947
  if isinstance(model_artifact, mlrun.artifacts.Artifact)
@@ -1719,28 +2007,13 @@ class ModelRunnerStep(MonitoredStep):
1719
2007
  self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data
1720
2008
 
1721
2009
  @staticmethod
1722
- def _get_model_output_schema(
1723
- model_artifact: Union[ModelArtifact, LLMPromptArtifact],
1724
- ) -> Optional[list[str]]:
1725
- if isinstance(
1726
- model_artifact,
1727
- ModelArtifact,
1728
- ):
1729
- return [feature.name for feature in model_artifact.spec.outputs]
1730
- elif isinstance(
1731
- model_artifact,
1732
- LLMPromptArtifact,
1733
- ):
1734
- _model_artifact = model_artifact.model_artifact
1735
- return [feature.name for feature in _model_artifact.spec.outputs]
1736
-
1737
- @staticmethod
1738
- def _get_model_endpoint_output_schema(
2010
+ def _get_model_endpoint_schema(
1739
2011
  name: str,
1740
2012
  project: str,
1741
2013
  uid: str,
1742
- ) -> list[str]:
2014
+ ) -> tuple[list[str], list[str]]:
1743
2015
  output_schema = None
2016
+ input_schema = None
1744
2017
  try:
1745
2018
  model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = (
1746
2019
  mlrun.db.get_run_db().get_model_endpoint(
@@ -1751,14 +2024,16 @@ class ModelRunnerStep(MonitoredStep):
1751
2024
  )
1752
2025
  )
1753
2026
  output_schema = model_endpoint.spec.label_names
2027
+ input_schema = model_endpoint.spec.feature_names
1754
2028
  except (
1755
2029
  mlrun.errors.MLRunNotFoundError,
1756
2030
  mlrun.errors.MLRunInvalidArgumentError,
1757
- ):
2031
+ ) as ex:
1758
2032
  logger.warning(
1759
- f"Model endpoint not found, using default output schema for model {name}"
2033
+ f"Model endpoint not found, using default output schema for model {name}",
2034
+ error=f"{type(ex).__name__}: {ex}",
1760
2035
  )
1761
- return output_schema
2036
+ return input_schema, output_schema
1762
2037
 
1763
2038
  def _calculate_monitoring_data(self) -> dict[str, dict[str, str]]:
1764
2039
  monitoring_data = deepcopy(
@@ -1768,55 +2043,106 @@ class ModelRunnerStep(MonitoredStep):
1768
2043
  )
1769
2044
  if isinstance(monitoring_data, dict):
1770
2045
  for model in monitoring_data:
1771
- monitoring_data[model][schemas.MonitoringData.OUTPUTS] = (
1772
- monitoring_data.get(model, {}).get(schemas.MonitoringData.OUTPUTS)
1773
- or self._get_model_endpoint_output_schema(
1774
- name=model,
1775
- project=self.context.project if self.context else None,
1776
- uid=monitoring_data.get(model, {}).get(
1777
- mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID
1778
- ),
1779
- )
1780
- )
1781
- # Prevent calling _get_model_output_schema for same model more than once
1782
- self.class_args[
1783
- mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
1784
- ][model][schemas.MonitoringData.OUTPUTS] = monitoring_data[model][
1785
- schemas.MonitoringData.OUTPUTS
1786
- ]
1787
2046
  monitoring_data[model][schemas.MonitoringData.INPUT_PATH] = split_path(
1788
2047
  monitoring_data[model][schemas.MonitoringData.INPUT_PATH]
1789
2048
  )
1790
2049
  monitoring_data[model][schemas.MonitoringData.RESULT_PATH] = split_path(
1791
2050
  monitoring_data[model][schemas.MonitoringData.RESULT_PATH]
1792
2051
  )
2052
+
2053
+ mep_output_schema, mep_input_schema = None, None
2054
+
2055
+ output_schema = self.class_args[
2056
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
2057
+ ][model][schemas.MonitoringData.OUTPUTS]
2058
+ input_schema = self.class_args[
2059
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
2060
+ ][model][schemas.MonitoringData.INPUTS]
2061
+ if not output_schema or not input_schema:
2062
+ # if output or input schema is not provided, try to get it from the model endpoint
2063
+ mep_input_schema, mep_output_schema = (
2064
+ self._get_model_endpoint_schema(
2065
+ model,
2066
+ self.context.project,
2067
+ monitoring_data[model].get(
2068
+ schemas.MonitoringData.MODEL_ENDPOINT_UID, ""
2069
+ ),
2070
+ )
2071
+ )
2072
+ self.class_args[
2073
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
2074
+ ][model][schemas.MonitoringData.OUTPUTS] = (
2075
+ output_schema or mep_output_schema
2076
+ )
2077
+ self.class_args[
2078
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
2079
+ ][model][schemas.MonitoringData.INPUTS] = (
2080
+ input_schema or mep_input_schema
2081
+ )
1793
2082
  return monitoring_data
2083
+ else:
2084
+ raise mlrun.errors.MLRunInvalidArgumentError(
2085
+ "Monitoring data must be a dictionary."
2086
+ )
2087
+
2088
+ def configure_pool_resource(
2089
+ self,
2090
+ max_processes: Optional[int] = None,
2091
+ max_threads: Optional[int] = None,
2092
+ pool_factor: Optional[int] = None,
2093
+ ) -> None:
2094
+ """
2095
+ Configure the resource limits for the shared models in the graph.
2096
+
2097
+ :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
2098
+ Defaults to the number of CPUs or 16 if undetectable.
2099
+ :param max_threads: Maximum number of threads to spawn. Defaults to 32.
2100
+ :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
2101
+ """
2102
+ self.max_processes = max_processes
2103
+ self.max_threads = max_threads
2104
+ self.pool_factor = pool_factor
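
A short usage sketch for the new pool configuration (values are illustrative):

    # cap the worker pools created for this step's models before deployment
    model_runner_step.configure_pool_resource(
        max_processes=4,   # process-pool workers, excluding dedicated processes
        max_threads=16,    # thread-pool workers
        pool_factor=2,     # scale the number of workers per runnable
    )
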
1794
2105
 
1795
2106
  def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
1796
2107
  self.context = context
1797
2108
  if not self._is_local_function(context):
1798
2109
  # skip init of non local functions
1799
2110
  return
1800
- model_selector = self.class_args.get("model_selector")
2111
+ model_selector, model_selector_params = self.class_args.get(
2112
+ "model_selector", (None, None)
2113
+ )
1801
2114
  execution_mechanism_by_model_name = self.class_args.get(
1802
2115
  schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
1803
2116
  )
1804
2117
  models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
1805
- if isinstance(model_selector, str):
1806
- model_selector = get_class(model_selector, namespace)()
2118
+ if model_selector:
2119
+ model_selector = get_class(model_selector, namespace).from_dict(
2120
+ model_selector_params, init_with_params=True
2121
+ )
1807
2122
  model_objects = []
1808
2123
  for model, model_params in models.values():
2124
+ model_name = model_params.get("name")
1809
2125
  model_params[schemas.MonitoringData.INPUT_PATH] = (
1810
2126
  self.class_args.get(
1811
2127
  mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
1812
2128
  )
1813
- .get(model_params.get("name"), {})
2129
+ .get(model_name, {})
1814
2130
  .get(schemas.MonitoringData.INPUT_PATH)
1815
2131
  )
2132
+ model_params[schemas.MonitoringData.RESULT_PATH] = (
2133
+ self.class_args.get(
2134
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
2135
+ )
2136
+ .get(model_name, {})
2137
+ .get(schemas.MonitoringData.RESULT_PATH)
2138
+ )
1816
2139
  model = get_class(model, namespace).from_dict(
1817
2140
  model_params, init_with_params=True
1818
2141
  )
1819
2142
  model._raise_exception = False
2143
+ model._execution_mechanism = execution_mechanism_by_model_name.get(
2144
+ model_name
2145
+ )
1820
2146
  model_objects.append(model)
1821
2147
  self._async_object = ModelRunner(
1822
2148
  model_selector=model_selector,
@@ -1825,6 +2151,9 @@ class ModelRunnerStep(MonitoredStep):
1825
2151
  shared_proxy_mapping=self._shared_proxy_mapping or None,
1826
2152
  name=self.name,
1827
2153
  context=context,
2154
+ max_processes=self.max_processes,
2155
+ max_threads=self.max_threads,
2156
+ pool_factor=self.pool_factor,
1828
2157
  )
1829
2158
 
1830
2159
 
@@ -2044,20 +2373,14 @@ class FlowStep(BaseStep):
2044
2373
  to event["y"] resulting in {"x": 5, "y": <result>}
2045
2374
  :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
2046
2375
 
2047
- * **overwrite**:
2048
-
2049
- 1. If model endpoints with the same name exist, delete the `latest` one.
2050
- 2. Create a new model endpoint entry and set it as `latest`.
2051
-
2052
- * **inplace** (default):
2376
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
2377
+ create a new model endpoint entry and set it as `latest`.
2053
2378
 
2054
- 1. If model endpoints with the same name exist, update the `latest` entry.
2055
- 2. Otherwise, create a new entry.
2379
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
2380
+ entry; otherwise, create a new entry.
2056
2381
 
2057
- * **archive**:
2058
-
2059
- 1. If model endpoints with the same name exist, preserve them.
2060
- 2. Create a new model endpoint with the same name and set it to `latest`.
2382
+ * **archive**: If model endpoints with the same name exist, preserve them;
2383
+ create a new model endpoint with the same name and set it to `latest`.
2061
2384
 
2062
2385
  :param class_args: class init arguments
2063
2386
  """
@@ -2552,35 +2875,64 @@ class RootFlowStep(FlowStep):
2552
2875
  model_class: Union[str, Model],
2553
2876
  execution_mechanism: Union[str, ParallelExecutionMechanisms],
2554
2877
  model_artifact: Union[str, ModelArtifact],
2878
+ inputs: Optional[list[str]] = None,
2879
+ outputs: Optional[list[str]] = None,
2880
+ input_path: Optional[str] = None,
2881
+ result_path: Optional[str] = None,
2555
2882
  override: bool = False,
2556
2883
  **model_parameters,
2557
2884
  ) -> None:
2558
2885
  """
2559
2886
  Add a shared model to the graph, this model will be available to all the ModelRunners in the graph
2560
2887
  :param name: Name of the shared model (should be unique in the graph)
2561
- :param model_class: Model class name
2888
+ :param model_class: Model class name. If LLModel is chosen
2889
+ (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
2890
+ outputs will be overridden with UsageResponseKeys fields.
2562
2891
  :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
- * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
+
+ * **process_pool**: To run in a separate process from a process pool. This is appropriate for CPU or GPU
  intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
  Lock (GIL).
- * "dedicated_process" – To run in a separate dedicated process. This is appropriate for CPU or GPU intensive
- tasks that also require significant Runnable-specific initialization (e.g. a large model).
- * "thread_pool" To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
+
+ * **dedicated_process**: To run in a separate dedicated process. This is appropriate for CPU or GPU
+ intensive tasks that also require significant Runnable-specific initialization (e.g. a large model).
+
+ * **thread_pool**: To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
  otherwise block the main event loop thread.
- * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
+
+ * **asyncio**: To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
  event loop to continue running while waiting for a response.
- * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
+
+ * **shared_executor**: Reuses an external executor (typically managed by the flow or context) to execute the
  runnable. Should be used only if you have multiple `ParallelExecution` in the same flow and especially
  useful when:
+
  - You want to share a heavy resource like a large model loaded onto a GPU.
+
  - You want to centralize task scheduling or coordination for multiple lightweight tasks.
+
  - You aim to minimize overhead from creating new executors or processes/threads per runnable.
+
  The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
  memory and hardware accelerators.
- * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
- It means that the runnable will not actually be run in parallel to anything else.
+
+ * **naive**: To run in the main event loop. This is appropriate only for trivial computation and/or file
+ I/O. It means that the runnable will not actually be run in parallel to anything else.

  :param model_artifact: model artifact or mlrun model artifact uri
+ :param inputs: list of the model inputs (e.g. features); if provided, overrides the inputs
+ configured in the model artifact. Note that these inputs must match, in length and
+ order, the inputs that the model_class predict method expects.
+ :param outputs: list of the model outputs (e.g. labels); if provided, overrides the outputs
+ configured in the model artifact. Note that these outputs must match, in length and
+ order, the outputs of the model_class predict method.
+ :param input_path: input path inside the user event, with scopes defined by dot notation
+ (e.g. "inputs.my_model_inputs"). Expects a list or dictionary object at the path.
+ :param result_path: result path inside the user output event, with scopes defined by dot
+ notation (e.g. "outputs.my_model_outputs"). Expects a list or dictionary object
+ at the path.
  :param override: allow overriding an existing model on the current ModelRunnerStep.
  :param model_parameters: Parameters for model instantiation
  """
@@ -2588,6 +2940,15 @@ class RootFlowStep(FlowStep):
  raise mlrun.errors.MLRunInvalidArgumentError(
  "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
  )
+ if type(model_class) is LLModel or (
+ isinstance(model_class, str)
+ and model_class.split(".")[-1] == LLModel.__name__
+ ):
+ if outputs:
+ warnings.warn(
+ "LLModel with existing outputs detected, overriding to default"
+ )
+ outputs = UsageResponseKeys.fields()

  if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
  raise mlrun.errors.MLRunInvalidArgumentError(
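As the `model_class` docstring above notes, the new check accepts either an `LLModel` object or a string whose last dotted segment is the class name; a small sketch of the string case (the helper name is invented for illustration):

```python
# Illustrative: string forms that trigger the LLModel outputs override above.
def names_llmodel(model_class: str) -> bool:
    return model_class.split(".")[-1] == "LLModel"

assert names_llmodel("LLModel")
assert names_llmodel("mlrun.serving.states.LLModel")
assert not names_llmodel("MyCustomModel")
```

When the check matches, any user-supplied `outputs` list is replaced with `UsageResponseKeys.fields()`, with a warning if outputs were already set.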
@@ -2615,6 +2976,14 @@ class RootFlowStep(FlowStep):
  "Inconsistent name for the added model."
  )
  model_parameters["name"] = name
+ model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
+ model_parameters["outputs"] = outputs or model_parameters.get("outputs", [])
+ model_parameters["input_path"] = input_path or model_parameters.get(
+ "input_path"
+ )
+ model_parameters["result_path"] = result_path or model_parameters.get(
+ "result_path"
+ )

  if name in self.shared_models and not override:
  raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2629,7 +2998,9 @@ class RootFlowStep(FlowStep):
  self.shared_models[name] = (model_class, model_parameters)
  self.shared_models_mechanism[name] = execution_mechanism

- def get_shared_model_name_by_artifact_uri(self, artifact_uri: str) -> Optional[str]:
+ def get_shared_model_by_artifact_uri(
+ self, artifact_uri: str
+ ) -> Union[tuple[str, str, dict], tuple[None, None, None]]:
  """
  Get a shared model by its artifact URI.
  :param artifact_uri: The artifact URI of the model.
@@ -2637,10 +3008,10 @@ class RootFlowStep(FlowStep):
  """
  for model_name, (model_class, model_params) in self.shared_models.items():
  if model_params.get("artifact_uri") == artifact_uri:
- return model_name
- return None
+ return model_name, model_class, model_params
+ return None, None, None

- def config_pool_resource(
+ def configure_shared_pool_resource(
  self,
  max_processes: Optional[int] = None,
  max_threads: Optional[int] = None,
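Two call-site adjustments follow from this hunk: the artifact-URI lookup now returns a `(name, class, params)` triple instead of a bare name, and the pool-sizing helper was renamed. A hedged usage sketch (`root_step` and the URI are placeholders):

```python
# Hypothetical call site; `root_step` stands for a RootFlowStep instance.
name, model_class, model_params = root_step.get_shared_model_by_artifact_uri(
    "store://models/my-project/my-model:latest"  # placeholder artifact URI
)
if name is None:
    print("no shared model is registered for this artifact")

# Renamed from config_pool_resource; keyword arguments as shown in the hunk above.
root_step.configure_shared_pool_resource(max_processes=2, max_threads=8)
```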
@@ -2688,6 +3059,7 @@ class RootFlowStep(FlowStep):
  model_params, init_with_params=True
  )
  model._raise_exception = False
+ model._execution_mechanism = self._shared_models_mechanism[model.name]
  self.context.executor.add_runnable(
  model, self._shared_models_mechanism[model.name]
  )
@@ -2807,12 +3179,10 @@ def _add_graphviz_router(graph, step, source=None, **kwargs):
  graph.edge(step.fullname, route.fullname)


- def _add_graphviz_model_runner(graph, step, source=None):
+ def _add_graphviz_model_runner(graph, step, source=None, is_monitored=False):
  if source:
  graph.node("_start", source.name, shape=source.shape, style="filled")
  graph.edge("_start", step.fullname)
-
- is_monitored = step._extract_root_step().track_models
  m_cell = '<FONT POINT-SIZE="9">🄼</FONT>' if is_monitored else ""

  number_of_models = len(
@@ -2851,6 +3221,7 @@ def _add_graphviz_flow(
  allow_empty=True
  )
  graph.node("_start", source.name, shape=source.shape, style="filled")
+ is_monitored = step.track_models if isinstance(step, RootFlowStep) else False
  for start_step in start_steps:
  graph.edge("_start", start_step.fullname)
  for child in step.get_children():
@@ -2859,7 +3230,7 @@ def _add_graphviz_flow(
  with graph.subgraph(name="cluster_" + child.fullname) as sg:
  _add_graphviz_router(sg, child)
  elif kind == StepKinds.model_runner:
- _add_graphviz_model_runner(graph, child)
+ _add_graphviz_model_runner(graph, child, is_monitored=is_monitored)
  else:
  graph.node(child.fullname, label=child.name, shape=child.get_shape())
  _add_edges(child.after or [], step, graph, child)
@@ -3078,7 +3449,7 @@ def _init_async_objects(context, steps):
  datastore_profile = datastore_profile_read(stream_path)
  if isinstance(
  datastore_profile,
- (DatastoreProfileKafkaTarget, DatastoreProfileKafkaSource),
+ (DatastoreProfileKafkaTarget, DatastoreProfileKafkaStream),
  ):
  step._async_object = KafkaStoreyTarget(
  path=stream_path,
@@ -3094,7 +3465,7 @@ def _init_async_objects(context, steps):
  else:
  raise mlrun.errors.MLRunValueError(
  f"Received an unexpected stream profile type: {type(datastore_profile)}\n"
- "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaSource`."
+ "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaStream`."
  )
  elif stream_path.startswith("kafka://") or kafka_brokers:
  topic, brokers = parse_kafka_url(stream_path, kafka_brokers)
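Taken together, the two hunks above narrow the profile branch of the stream dispatch: Kafka-backed profiles now match `DatastoreProfileKafkaStream` (alongside `DatastoreProfileKafkaTarget`), and any other profile type raises. A simplified sketch of that branch, with only the class names taken from the diff (the import path is assumed and the return values are placeholders, not the real targets):

```python
# Simplified illustration of the dispatch being edited above.
# Import path assumed to be the datastore_profile module touched in this release.
from mlrun.datastore.datastore_profile import (
    DatastoreProfileKafkaStream,
    DatastoreProfileKafkaTarget,
    DatastoreProfileV3io,
)

def classify_stream_profile(datastore_profile) -> str:
    if isinstance(
        datastore_profile, (DatastoreProfileKafkaTarget, DatastoreProfileKafkaStream)
    ):
        return "kafka"  # handled by KafkaStoreyTarget in the real code
    if isinstance(datastore_profile, DatastoreProfileV3io):
        return "v3io"
    raise ValueError(
        f"Received an unexpected stream profile type: {type(datastore_profile)}\n"
        "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaStream`."
    )
```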
@@ -3110,6 +3481,8 @@ def _init_async_objects(context, steps):
  context=context,
  **options,
  )
+ elif stream_path.startswith("dummy://"):
+ step._async_object = _DummyStream(context=context, **options)
  else:
  if stream_path.startswith("v3io://"):
  endpoint, stream_path = parse_path(step.path)