mlrun 1.10.0rc16__py3-none-any.whl → 1.10.0rc42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (98)
  1. mlrun/__init__.py +22 -2
  2. mlrun/artifacts/document.py +6 -1
  3. mlrun/artifacts/llm_prompt.py +21 -15
  4. mlrun/artifacts/model.py +3 -3
  5. mlrun/common/constants.py +9 -0
  6. mlrun/common/formatters/artifact.py +1 -0
  7. mlrun/common/model_monitoring/helpers.py +86 -0
  8. mlrun/common/schemas/__init__.py +2 -0
  9. mlrun/common/schemas/auth.py +2 -0
  10. mlrun/common/schemas/function.py +10 -0
  11. mlrun/common/schemas/hub.py +30 -18
  12. mlrun/common/schemas/model_monitoring/__init__.py +2 -0
  13. mlrun/common/schemas/model_monitoring/constants.py +30 -6
  14. mlrun/common/schemas/model_monitoring/functions.py +13 -4
  15. mlrun/common/schemas/model_monitoring/model_endpoints.py +11 -0
  16. mlrun/common/schemas/pipeline.py +1 -1
  17. mlrun/common/schemas/serving.py +3 -0
  18. mlrun/common/schemas/workflow.py +1 -0
  19. mlrun/common/secrets.py +22 -1
  20. mlrun/config.py +32 -10
  21. mlrun/datastore/__init__.py +11 -3
  22. mlrun/datastore/azure_blob.py +162 -47
  23. mlrun/datastore/datastore.py +9 -4
  24. mlrun/datastore/datastore_profile.py +61 -5
  25. mlrun/datastore/model_provider/huggingface_provider.py +363 -0
  26. mlrun/datastore/model_provider/mock_model_provider.py +87 -0
  27. mlrun/datastore/model_provider/model_provider.py +211 -74
  28. mlrun/datastore/model_provider/openai_provider.py +243 -71
  29. mlrun/datastore/s3.py +24 -2
  30. mlrun/datastore/storeytargets.py +2 -3
  31. mlrun/datastore/utils.py +15 -3
  32. mlrun/db/base.py +27 -19
  33. mlrun/db/httpdb.py +57 -48
  34. mlrun/db/nopdb.py +25 -10
  35. mlrun/execution.py +55 -13
  36. mlrun/hub/__init__.py +15 -0
  37. mlrun/hub/module.py +181 -0
  38. mlrun/k8s_utils.py +105 -16
  39. mlrun/launcher/base.py +13 -6
  40. mlrun/launcher/local.py +2 -0
  41. mlrun/model.py +9 -3
  42. mlrun/model_monitoring/api.py +66 -27
  43. mlrun/model_monitoring/applications/__init__.py +1 -1
  44. mlrun/model_monitoring/applications/base.py +372 -136
  45. mlrun/model_monitoring/applications/context.py +2 -4
  46. mlrun/model_monitoring/applications/results.py +4 -7
  47. mlrun/model_monitoring/controller.py +239 -101
  48. mlrun/model_monitoring/db/_schedules.py +36 -13
  49. mlrun/model_monitoring/db/_stats.py +4 -3
  50. mlrun/model_monitoring/db/tsdb/base.py +29 -9
  51. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +4 -5
  52. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +154 -50
  53. mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +51 -0
  54. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +17 -4
  55. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +245 -51
  56. mlrun/model_monitoring/helpers.py +28 -5
  57. mlrun/model_monitoring/stream_processing.py +45 -14
  58. mlrun/model_monitoring/writer.py +220 -1
  59. mlrun/platforms/__init__.py +3 -2
  60. mlrun/platforms/iguazio.py +7 -3
  61. mlrun/projects/operations.py +6 -1
  62. mlrun/projects/pipelines.py +2 -2
  63. mlrun/projects/project.py +128 -45
  64. mlrun/run.py +94 -17
  65. mlrun/runtimes/__init__.py +18 -0
  66. mlrun/runtimes/base.py +14 -6
  67. mlrun/runtimes/daskjob.py +1 -0
  68. mlrun/runtimes/local.py +5 -2
  69. mlrun/runtimes/mounts.py +20 -2
  70. mlrun/runtimes/nuclio/__init__.py +1 -0
  71. mlrun/runtimes/nuclio/application/application.py +147 -17
  72. mlrun/runtimes/nuclio/function.py +70 -27
  73. mlrun/runtimes/nuclio/serving.py +85 -4
  74. mlrun/runtimes/pod.py +213 -21
  75. mlrun/runtimes/utils.py +49 -9
  76. mlrun/secrets.py +54 -13
  77. mlrun/serving/remote.py +79 -6
  78. mlrun/serving/routers.py +23 -41
  79. mlrun/serving/server.py +211 -40
  80. mlrun/serving/states.py +536 -156
  81. mlrun/serving/steps.py +62 -0
  82. mlrun/serving/system_steps.py +136 -81
  83. mlrun/serving/v2_serving.py +9 -10
  84. mlrun/utils/helpers.py +212 -82
  85. mlrun/utils/logger.py +3 -1
  86. mlrun/utils/notifications/notification/base.py +18 -0
  87. mlrun/utils/notifications/notification/git.py +2 -4
  88. mlrun/utils/notifications/notification/slack.py +2 -4
  89. mlrun/utils/notifications/notification/webhook.py +2 -5
  90. mlrun/utils/notifications/notification_pusher.py +1 -1
  91. mlrun/utils/version/version.json +2 -2
  92. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/METADATA +44 -45
  93. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/RECORD +97 -92
  94. mlrun/api/schemas/__init__.py +0 -259
  95. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/WHEEL +0 -0
  96. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/entry_points.txt +0 -0
  97. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/licenses/LICENSE +0 -0
  98. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py CHANGED
@@ -24,6 +24,7 @@ import inspect
 import os
 import pathlib
 import traceback
+import warnings
 from abc import ABC
 from copy import copy, deepcopy
 from inspect import getfullargspec, signature
@@ -38,17 +39,21 @@ import mlrun.common.schemas as schemas
 from mlrun.artifacts.llm_prompt import LLMPromptArtifact, PlaceholderDefaultDict
 from mlrun.artifacts.model import ModelArtifact
 from mlrun.datastore.datastore_profile import (
-    DatastoreProfileKafkaSource,
+    DatastoreProfileKafkaStream,
     DatastoreProfileKafkaTarget,
     DatastoreProfileV3io,
     datastore_profile_read,
 )
-from mlrun.datastore.model_provider.model_provider import ModelProvider
+from mlrun.datastore.model_provider.model_provider import (
+    InvokeResponseFormat,
+    ModelProvider,
+    UsageResponseKeys,
+)
 from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
-from mlrun.utils import get_data_from_path, logger, split_path
+from mlrun.utils import get_data_from_path, logger, set_data_by_path, split_path

 from ..config import config
-from ..datastore import get_stream_pusher
+from ..datastore import _DummyStream, get_stream_pusher
 from ..datastore.utils import (
     get_kafka_brokers_from_dict,
     parse_kafka_url,
@@ -517,7 +522,9 @@ class BaseStep(ModelObj):

         root = self._extract_root_step()

-        if not isinstance(root, RootFlowStep):
+        if not isinstance(root, RootFlowStep) or (
+            isinstance(root, RootFlowStep) and root.engine != "async"
+        ):
             raise GraphError(
                 "ModelRunnerStep can be added to 'Flow' topology graph only"
             )
@@ -541,8 +548,8 @@
         # Update model endpoints names in the root step
         root.update_model_endpoints_names(step_model_endpoints_names)

-    @staticmethod
     def _verify_shared_models(
+        self,
         root: "RootFlowStep",
         step: "ModelRunnerStep",
         step_model_endpoints_names: list[str],
@@ -571,35 +578,41 @@
             prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
             # if the model artifact is a prompt, we need to get the model URI
             # to ensure that the shared runnable name is correct
+            llm_artifact_uri = None
             if prefix == mlrun.utils.StorePrefix.LLMPrompt:
                 llm_artifact, _ = mlrun.store_manager.get_store_artifact(
                     model_artifact_uri
                 )
+                llm_artifact_uri = llm_artifact.uri
                 model_artifact_uri = mlrun.utils.remove_tag_from_artifact_uri(
                     llm_artifact.spec.parent_uri
                 )
-            actual_shared_name = root.get_shared_model_name_by_artifact_uri(
-                model_artifact_uri
+            actual_shared_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )

-            if not shared_runnable_name:
-                if not actual_shared_name:
-                    raise GraphError(
-                        f"Can't find shared model for {name} model endpoint"
-                    )
-                else:
-                    step.class_args[schemas.ModelRunnerStepData.MODELS][name][
-                        schemas.ModelsData.MODEL_PARAMETERS.value
-                    ]["shared_runnable_name"] = actual_shared_name
-                    shared_models.append(actual_shared_name)
+            if not actual_shared_name:
+                raise GraphError(
+                    f"Can't find shared model named {shared_runnable_name}"
+                )
+            elif not shared_runnable_name:
+                step.class_args[schemas.ModelRunnerStepData.MODELS][name][
+                    schemas.ModelsData.MODEL_PARAMETERS.value
+                ]["shared_runnable_name"] = actual_shared_name
             elif actual_shared_name != shared_runnable_name:
                 raise GraphError(
                     f"Model endpoint {name} shared runnable name mismatch: "
                     f"expected {actual_shared_name}, got {shared_runnable_name}"
                 )
-            else:
-                shared_models.append(actual_shared_name)
-
+            shared_models.append(actual_shared_name)
+            self._edit_proxy_model_data(
+                step,
+                name,
+                actual_shared_name,
+                shared_model_params,
+                shared_model_class,
+                llm_artifact_uri or model_artifact_uri,
+            )
         undefined_shared_models = list(
             set(shared_models) - set(root.shared_models.keys())
         )
@@ -608,6 +621,52 @@
                 f"The following shared models are not defined in the graph: {undefined_shared_models}."
             )

+    @staticmethod
+    def _edit_proxy_model_data(
+        step: "ModelRunnerStep",
+        name: str,
+        actual_shared_name: str,
+        shared_model_params: dict,
+        shared_model_class: Any,
+        artifact: Union[ModelArtifact, LLMPromptArtifact, str],
+    ):
+        monitoring_data = step.class_args.setdefault(
+            schemas.ModelRunnerStepData.MONITORING_DATA, {}
+        )
+
+        # edit monitoring data according to the shared model parameters
+        monitoring_data[name][schemas.MonitoringData.INPUT_PATH] = shared_model_params[
+            "input_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.RESULT_PATH] = shared_model_params[
+            "result_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.INPUTS] = shared_model_params[
+            "inputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.OUTPUTS] = shared_model_params[
+            "outputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.MODEL_CLASS] = (
+            shared_model_class
+            if isinstance(shared_model_class, str)
+            else shared_model_class.__class__.__name__
+        )
+        if actual_shared_name and actual_shared_name not in step._shared_proxy_mapping:
+            step._shared_proxy_mapping[actual_shared_name] = {
+                name: artifact.uri
+                if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
+                else artifact
+            }
+        elif actual_shared_name:
+            step._shared_proxy_mapping[actual_shared_name].update(
+                {
+                    name: artifact.uri
+                    if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
+                    else artifact
+                }
+            )
+

 class TaskStep(BaseStep):
     """task execution step, runs a class or handler"""
@@ -1090,6 +1149,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         "artifact_uri",
         "shared_runnable_name",
         "shared_proxy_mapping",
+        "execution_mechanism",
     ]
     kind = "model"

@@ -1111,6 +1171,8 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         self.invocation_artifact: Optional[LLMPromptArtifact] = None
         self.model_artifact: Optional[ModelArtifact] = None
         self.model_provider: Optional[ModelProvider] = None
+        self._artifact_were_loaded = False
+        self._execution_mechanism = None

     def __init_subclass__(cls):
         super().__init_subclass__()
@@ -1130,13 +1192,29 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
             raise_missing_schema_exception=False,
         )

-    def _load_artifacts(self) -> None:
-        artifact = self._get_artifact_object()
-        if isinstance(artifact, LLMPromptArtifact):
-            self.invocation_artifact = artifact
-            self.model_artifact = self.invocation_artifact.model_artifact
+        # Check if the relevant predict method is implemented when trying to initialize the model
+        if self._execution_mechanism == storey.ParallelExecutionMechanisms.asyncio:
+            if self.__class__.predict_async is Model.predict_async:
+                raise mlrun.errors.ModelRunnerError(
+                    f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict_async() "
+                    f"is not implemented"
+                )
         else:
-            self.model_artifact = artifact
+            if self.__class__.predict is Model.predict:
+                raise mlrun.errors.ModelRunnerError(
+                    f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict() "
+                    f"is not implemented"
+                )
+
+    def _load_artifacts(self) -> None:
+        if not self._artifact_were_loaded:
+            artifact = self._get_artifact_object()
+            if isinstance(artifact, LLMPromptArtifact):
+                self.invocation_artifact = artifact
+                self.model_artifact = self.invocation_artifact.model_artifact
+            else:
+                self.model_artifact = artifact
+            self._artifact_were_loaded = True

     def _get_artifact_object(
         self, proxy_uri: Optional[str] = None
@@ -1158,11 +1236,11 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):

     def predict(self, body: Any, **kwargs) -> Any:
         """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
-        return body
+        raise NotImplementedError("predict() method not implemented")

     async def predict_async(self, body: Any, **kwargs) -> Any:
         """Override to implement prediction logic if the logic requires asyncio."""
-        return body
+        raise NotImplementedError("predict_async() method not implemented")

     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
         return self.predict(body)
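Reviewer note (illustrative, not part of the diff): predict() and predict_async() no longer fall back to returning the body unchanged, and the check added above raises ModelRunnerError when the method matching the configured execution mechanism is not overridden. A minimal custom model consistent with that contract might look like the sketch below; the class and field names are placeholders.

    import mlrun.serving

    class MyModel(mlrun.serving.Model):
        def predict(self, body, **kwargs):
            # blocking prediction logic, used by thread/process/naive mechanisms
            body["result"] = sum(body.get("inputs", []))
            return body

        async def predict_async(self, body, **kwargs):
            # override this one instead when the model runs with the "asyncio" mechanism
            body["result"] = sum(body.get("inputs", []))
            return body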
@@ -1205,26 +1283,111 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):


 class LLModel(Model):
+    """
+    A model wrapper for handling LLM (Large Language Model) prompt-based inference.
+
+    This class extends the base `Model` to provide specialized handling for
+    `LLMPromptArtifact` objects, enabling both synchronous and asynchronous
+    invocation of language models.
+
+    **Model Invocation**:
+
+    - The execution of enriched prompts is delegated to the `model_provider`
+      configured for the model (e.g., **Hugging Face** or **OpenAI**).
+    - The `model_provider` is responsible for sending the prompt to the correct
+      backend API and returning the generated output.
+    - Users can override the `predict` and `predict_async` methods to customize
+      the behavior of the model invocation.
+
+    **Prompt Enrichment Overview**:
+
+    - If an `LLMPromptArtifact` is found, load its prompt template and fill in
+      placeholders using values from the request body.
+    - If the artifact is not an `LLMPromptArtifact`, skip formatting and attempt
+      to retrieve `messages` directly from the request body using the input path.
+
+    **Simplified Example**:
+
+    Input body::
+
+        {"city": "Paris", "days": 3}
+
+    Prompt template in artifact::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a {{days}}-day itinerary for {{city}}."},
+        ]
+
+    Result after enrichment::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a 3-day itinerary for Paris."},
+        ]
+
+    :param name: Name of the model.
+    :param input_path: Path in the request body where input data is located.
+    :param result_path: Path in the response body where model outputs and the statistics
+        will be stored.
+    """
+
+    _dict_fields = Model._dict_fields + ["result_path", "input_path"]
+
     def __init__(
-        self, name: str, input_path: Optional[Union[str, list[str]]], **kwargs
+        self,
+        name: str,
+        input_path: Optional[Union[str, list[str]]] = None,
+        result_path: Optional[Union[str, list[str]]] = None,
+        **kwargs,
     ):
         super().__init__(name, **kwargs)
         self._input_path = split_path(input_path)
+        self._result_path = split_path(result_path)
+        logger.info(
+            "LLModel initialized",
+            model_name=name,
+            input_path=input_path,
+            result_path=result_path,
+        )

     def predict(
         self,
         body: Any,
         messages: Optional[list[dict]] = None,
-        model_configuration: Optional[dict] = None,
+        invocation_config: Optional[dict] = None,
         **kwargs,
     ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
         if isinstance(
-            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
         ) and isinstance(self.model_provider, ModelProvider):
-            body["result"] = self.model_provider.invoke(
+            logger.debug(
+                "Invoking model provider",
+                model_name=self.name,
                 messages=messages,
-                as_str=True,
-                **(model_configuration or {}),
+                invocation_config=invocation_config,
+            )
+            response_with_stats = self.model_provider.invoke(
+                messages=messages,
+                invoke_response_format=InvokeResponseFormat.USAGE,
+                **(invocation_config or {}),
+            )
+            set_data_by_path(
+                path=self._result_path, data=body, value=response_with_stats
+            )
+            logger.debug(
+                "LLModel prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
             )
         return body
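Reviewer note (illustrative, not part of the diff): LLModel now writes the full provider response, including usage statistics, into the event body at result_path via set_data_by_path, instead of always writing to body["result"]. A rough standalone sketch of the dot-notation behaviour (not mlrun's actual helper); the path and response values are placeholders, and the "answer"/"usage" keys are implied by the debug logging above.

    def set_by_dot_path(data: dict, path: str, value) -> None:
        # walk/create nested dicts for every segment except the last, then assign
        keys = path.split(".")
        for key in keys[:-1]:
            data = data.setdefault(key, {})
        data[keys[-1]] = value

    body = {"question": "Plan a trip"}
    response_with_stats = {"answer": "...", "usage": {"total_tokens": 42}}
    set_by_dot_path(body, "outputs.llm", response_with_stats)
    # body == {"question": "Plan a trip", "outputs": {"llm": {"answer": "...", "usage": {"total_tokens": 42}}}}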
@@ -1232,61 +1395,112 @@ class LLModel(Model):
         self,
         body: Any,
         messages: Optional[list[dict]] = None,
-        model_configuration: Optional[dict] = None,
+        invocation_config: Optional[dict] = None,
         **kwargs,
     ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
         if isinstance(
-            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
         ) and isinstance(self.model_provider, ModelProvider):
-            body["result"] = await self.model_provider.async_invoke(
+            logger.debug(
+                "Async invoking model provider",
+                model_name=self.name,
                 messages=messages,
-                as_str=True,
-                **(model_configuration or {}),
+                invocation_config=invocation_config,
+            )
+            response_with_stats = await self.model_provider.async_invoke(
+                messages=messages,
+                invoke_response_format=InvokeResponseFormat.USAGE,
+                **(invocation_config or {}),
+            )
+            set_data_by_path(
+                path=self._result_path, data=body, value=response_with_stats
+            )
+            logger.debug(
+                "LLModel async prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping async prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
             )
         return body

     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
-        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, invocation_config = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
         return self.predict(
-            body, messages=messages, model_configuration=model_configuration
+            body,
+            messages=messages,
+            invocation_config=invocation_config,
+            llm_prompt_artifact=llm_prompt_artifact,
         )

     async def run_async(
         self, body: Any, path: str, origin_name: Optional[str] = None
     ) -> Any:
-        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, invocation_config = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel async predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
         return await self.predict_async(
-            body, messages=messages, model_configuration=model_configuration
+            body,
+            messages=messages,
+            invocation_config=invocation_config,
+            llm_prompt_artifact=llm_prompt_artifact,
         )

     def enrich_prompt(
-        self, body: dict, origin_name: str
+        self,
+        body: dict,
+        origin_name: str,
+        llm_prompt_artifact: Optional[LLMPromptArtifact] = None,
     ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
-        if origin_name and self.shared_proxy_mapping:
-            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
-            if isinstance(llm_prompt_artifact, str):
-                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
-                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
-        else:
-            llm_prompt_artifact = (
-                self.invocation_artifact or self._get_artifact_object()
-            )
-        if not (
+        logger.info(
+            "Enriching prompt",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+        )
+        if not llm_prompt_artifact or not (
             llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
         ):
             logger.warning(
-                "LLMModel must be provided with LLMPromptArtifact",
+                "LLModel must be provided with LLMPromptArtifact",
+                model_name=self.name,
+                artifact_type=type(llm_prompt_artifact).__name__,
                 llm_prompt_artifact=llm_prompt_artifact,
             )
-            return None, None
-        prompt_legend = llm_prompt_artifact.spec.prompt_legend
-        prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+            prompt_legend, prompt_template, invocation_config = {}, [], {}
+        else:
+            prompt_legend = llm_prompt_artifact.spec.prompt_legend
+            prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+            invocation_config = llm_prompt_artifact.spec.invocation_config
         input_data = copy(get_data_from_path(self._input_path, body))
-        if isinstance(input_data, dict):
+        if isinstance(input_data, dict) and prompt_template:
             kwargs = (
                 {
                     place_holder: input_data.get(body_map["field"])
                     for place_holder, body_map in prompt_legend.items()
+                    if input_data.get(body_map["field"])
                 }
                 if prompt_legend
                 else {}
@@ -1298,23 +1512,61 @@ class LLModel(Model):
                     message["content"] = message["content"].format(**input_data)
                 except KeyError as e:
                     logger.warning(
-                        "Input data was missing a placeholder, placeholder stay unformatted",
-                        key_error=e,
+                        "Input data missing placeholder, content stays unformatted",
+                        model_name=self.name,
+                        key_error=mlrun.errors.err_to_str(e),
                     )
                     message["content"] = message["content"].format_map(
                         default_place_holders
                     )
+        elif isinstance(input_data, dict) and not prompt_template:
+            # If there is no prompt template, we assume the input data is already in the correct format.
+            logger.debug("Attempting to retrieve messages from the request body.")
+            prompt_template = input_data.get("messages", [])
         else:
             logger.warning(
-                f"Expected input data to be a dict, but received input data from type {type(input_data)} prompt "
-                f"template stay unformatted",
+                "Expected input data to be a dict, prompt template stays unformatted",
+                model_name=self.name,
+                input_data_type=type(input_data).__name__,
             )
-        return prompt_template, llm_prompt_artifact.spec.model_configuration
+        return prompt_template, invocation_config
+
+    def _get_invocation_artifact(
+        self, origin_name: Optional[str] = None
+    ) -> Union[LLMPromptArtifact, None]:
+        """
+        Get the LLMPromptArtifact object for this model.
+
+        :param proxy_uri: Optional; URI to the proxy artifact.
+        :return: LLMPromptArtifact object or None if not found.
+        """
+        if origin_name and self.shared_proxy_mapping:
+            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+            if isinstance(llm_prompt_artifact, str):
+                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+        elif self._artifact_were_loaded:
+            llm_prompt_artifact = self.invocation_artifact
+        else:
+            self._load_artifacts()
+            llm_prompt_artifact = self.invocation_artifact
+        return llm_prompt_artifact


-class ModelSelector:
+class ModelSelector(ModelObj):
     """Used to select which models to run on each event."""

+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        cls._dict_fields = list(
+            set(cls._dict_fields)
+            | set(inspect.signature(cls.__init__).parameters.keys())
+        )
+        cls._dict_fields.remove("self")
+
     def select(
         self, event, available_models: list[Model]
     ) -> Union[list[str], list[Model]]:
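Reviewer note (illustrative, not part of the diff): since ModelSelector now derives from ModelObj and __init_subclass__ folds the constructor parameters into _dict_fields, a parameterized selector can round-trip through to_dict()/from_dict(). A hypothetical subclass sketch:

    from mlrun.serving.states import Model, ModelSelector

    class PrefixSelector(ModelSelector):
        def __init__(self, prefix: str = "", **kwargs):
            super().__init__(**kwargs)
            self.prefix = prefix

        def select(self, event, available_models: list[Model]):
            # run only the models whose name starts with the configured prefix
            return [m for m in available_models if m.name.startswith(self.prefix)]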
@@ -1406,6 +1658,13 @@ class ModelRunnerStep(MonitoredStep):
         model_runner_step.add_model(..., model_class=MyModel(name="my_model"))
         graph.to(model_runner_step)

+    Note when ModelRunnerStep is used in a graph, MLRun automatically imports
+    the default language model class (LLModel) during function deployment.
+
+    Note ModelRunnerStep can only be added to a graph that has the flow topology and running with async engine.
+
+    Note see config_pool_resource method documentation for default number of max threads and max processes.
+
     :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
         event. Optional. If not passed, all models will be run.
     :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised
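Reviewer note (illustrative, not part of the diff): per the notes above, a ModelRunnerStep only validates against a flow-topology graph running on the async engine. A minimal setup sketch with placeholder names, file, and store URI:

    import mlrun
    from mlrun.serving.states import ModelRunnerStep

    fn = mlrun.code_to_function("my-serving", kind="serving", filename="my_models.py")
    graph = fn.set_topology("flow", engine="async")

    model_runner = ModelRunnerStep(name="runner")
    model_runner.add_model(
        "my-endpoint",
        model_class="MyModel",
        execution_mechanism="asyncio",
        model_artifact="store://models/my-project/my-model:latest",
    )
    graph.to(model_runner)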
@@ -1418,22 +1677,49 @@ class ModelRunnerStep(MonitoredStep):
     """

     kind = "model_runner"
-    _dict_fields = MonitoredStep._dict_fields + ["_shared_proxy_mapping"]
+    _dict_fields = MonitoredStep._dict_fields + [
+        "_shared_proxy_mapping",
+        "max_processes",
+        "max_threads",
+        "pool_factor",
+    ]

     def __init__(
         self,
         *args,
         name: Optional[str] = None,
         model_selector: Optional[Union[str, ModelSelector]] = None,
+        model_selector_parameters: Optional[dict] = None,
         raise_exception: bool = True,
         **kwargs,
     ):
+        self.max_processes = None
+        self.max_threads = None
+        self.pool_factor = None
+
+        if isinstance(model_selector, ModelSelector) and model_selector_parameters:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Cannot provide a model_selector object as argument to `model_selector` and also provide "
+                "`model_selector_parameters`."
+            )
+        if model_selector:
+            model_selector_parameters = model_selector_parameters or (
+                model_selector.to_dict()
+                if isinstance(model_selector, ModelSelector)
+                else {}
+            )
+            model_selector = (
+                model_selector
+                if isinstance(model_selector, str)
+                else model_selector.__class__.__name__
+            )
+
         super().__init__(
             *args,
             name=name,
             raise_exception=raise_exception,
             class_name="mlrun.serving.ModelRunner",
-            class_args=dict(model_selector=model_selector),
+            class_args=dict(model_selector=(model_selector, model_selector_parameters)),
             **kwargs,
         )
         self.raise_exception = raise_exception
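Reviewer note (illustrative, not part of the diff): the constructor now serializes the selector as a (class name, parameters) pair. Reusing the hypothetical PrefixSelector sketched earlier, the two forms below end up equivalent, while passing an instance together with model_selector_parameters raises MLRunInvalidArgumentError:

    runner = ModelRunnerStep(
        name="runner",
        model_selector=PrefixSelector(prefix="prod-"),  # parameters taken from to_dict()
    )
    # or by class name plus explicit parameters:
    runner = ModelRunnerStep(
        name="runner",
        model_selector="PrefixSelector",
        model_selector_parameters={"prefix": "prod-"},
    )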
@@ -1449,10 +1735,6 @@
         model_endpoint_creation_strategy: Optional[
             schemas.ModelEndpointCreationStrategy
         ] = schemas.ModelEndpointCreationStrategy.INPLACE,
-        inputs: Optional[list[str]] = None,
-        outputs: Optional[list[str]] = None,
-        input_path: Optional[str] = None,
-        result_path: Optional[str] = None,
         override: bool = False,
     ) -> None:
         """
@@ -1475,18 +1757,8 @@
            1. If model endpoints with the same name exist, preserve them.
            2. Create a new model endpoint with the same name and set it to `latest`.

-        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
-            that been configured in the model artifact, please note that those inputs need to
-            be equal in length and order to the inputs that model_class predict method expects
-        :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
-            that been configured in the model artifact, please note that those outputs need to
-            be equal to the model_class predict method outputs (length, and order)
-        :param input_path: input path inside the user event, expect scopes to be defined by dot notation
-            (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
-        :param result_path: result path inside the user output event, expect scopes to be defined by dot
-            notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
-            in path.
         :param override: bool allow override existing model on the current ModelRunnerStep.
+        :raise GraphError: when the shared model is not found in the root flow step shared models.
         """
         model_class, model_params = (
             "mlrun.serving.Model",
@@ -1503,11 +1775,21 @@
                 "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
             )
         root = self._extract_root_step()
+        shared_model_params = {}
         if isinstance(root, RootFlowStep):
-            shared_model_name = (
-                shared_model_name
-                or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
+            actual_shared_model_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )
+            if not actual_shared_model_name or (
+                shared_model_name and actual_shared_model_name != shared_model_name
+            ):
+                raise GraphError(
+                    f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
+                    f"model {shared_model_name} is not in the shared models."
+                )
+            elif not shared_model_name:
+                shared_model_name = actual_shared_model_name
+            model_params["shared_runnable_name"] = shared_model_name
             if not root.shared_models or (
                 root.shared_models
                 and shared_model_name
@@ -1517,13 +1799,27 @@
                     f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
                     f"model {shared_model_name} is not in the shared models."
                 )
-            if shared_model_name not in self._shared_proxy_mapping:
+            monitoring_data = self.class_args.get(
+                schemas.ModelRunnerStepData.MONITORING_DATA, {}
+            )
+            monitoring_data.setdefault(endpoint_name, {})[
+                schemas.MonitoringData.MODEL_CLASS
+            ] = (
+                shared_model_class
+                if isinstance(shared_model_class, str)
+                else shared_model_class.__class__.__name__
+            )
+            self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = (
+                monitoring_data
+            )
+
+            if shared_model_name and shared_model_name not in self._shared_proxy_mapping:
                 self._shared_proxy_mapping[shared_model_name] = {
                     endpoint_name: model_artifact.uri
                     if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
                     else model_artifact
                 }
-            else:
+            elif override and shared_model_name:
                 self._shared_proxy_mapping[shared_model_name].update(
                     {
                         endpoint_name: model_artifact.uri
@@ -1538,11 +1834,11 @@
             model_artifact=model_artifact,
             labels=labels,
             model_endpoint_creation_strategy=model_endpoint_creation_strategy,
+            inputs=shared_model_params.get("inputs"),
+            outputs=shared_model_params.get("outputs"),
+            input_path=shared_model_params.get("input_path"),
+            result_path=shared_model_params.get("result_path"),
             override=override,
-            inputs=inputs,
-            outputs=outputs,
-            input_path=input_path,
-            result_path=result_path,
             **model_params,
         )

@@ -1567,7 +1863,9 @@
         Add a Model to this ModelRunner.

         :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
-        :param model_class: Model class name
+        :param model_class: Model class name. If LLModel is chosen
+            (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+            outputs will be overridden with UsageResponseKeys fields.
         :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
             * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
               intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
@@ -1578,14 +1876,6 @@
              otherwise block the main event loop thread.
            * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
              event loop to continue running while waiting for a response.
-            * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
-              runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
-              useful when:
-              - You want to share a heavy resource like a large model loaded onto a GPU.
-              - You want to centralize task scheduling or coordination for multiple lightweight tasks.
-              - You aim to minimize overhead from creating new executors or processes/threads per runnable.
-              The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
-              memory and hardware accelerators.
            * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
              It means that the runnable will not actually be run in parallel to anything else.

@@ -1608,6 +1898,9 @@
         :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
            that been configured in the model artifact, please note that those outputs need to
            be equal to the model_class predict method outputs (length, and order)
+
+            When using LLModel, the output will be overridden with UsageResponseKeys.fields().
+
         :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
            this require that the event body will behave like a dict, expects scopes to be
            defined by dot notation (e.g "data.d").
@@ -1636,7 +1929,15 @@
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
             )
-
+        if type(model_class) is LLModel or (
+            isinstance(model_class, str)
+            and model_class.split(".")[-1] == LLModel.__name__
+        ):
+            if outputs:
+                warnings.warn(
+                    "LLModel with existing outputs detected, overriding to default"
+                )
+            outputs = UsageResponseKeys.fields()
         model_parameters = model_parameters or (
             model_class.to_dict() if isinstance(model_class, Model) else {}
         )
@@ -1652,8 +1953,6 @@
         except mlrun.errors.MLRunNotFoundError:
             raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.")

-        outputs = outputs or self._get_model_output_schema(model_artifact)
-
         model_artifact = (
             model_artifact.uri
             if isinstance(model_artifact, mlrun.artifacts.Artifact)
@@ -1719,28 +2018,13 @@
         self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data

     @staticmethod
-    def _get_model_output_schema(
-        model_artifact: Union[ModelArtifact, LLMPromptArtifact],
-    ) -> Optional[list[str]]:
-        if isinstance(
-            model_artifact,
-            ModelArtifact,
-        ):
-            return [feature.name for feature in model_artifact.spec.outputs]
-        elif isinstance(
-            model_artifact,
-            LLMPromptArtifact,
-        ):
-            _model_artifact = model_artifact.model_artifact
-            return [feature.name for feature in _model_artifact.spec.outputs]
-
-    @staticmethod
-    def _get_model_endpoint_output_schema(
+    def _get_model_endpoint_schema(
         name: str,
         project: str,
         uid: str,
-    ) -> list[str]:
+    ) -> tuple[list[str], list[str]]:
         output_schema = None
+        input_schema = None
         try:
             model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = (
                 mlrun.db.get_run_db().get_model_endpoint(
@@ -1751,14 +2035,16 @@
                 )
             )
             output_schema = model_endpoint.spec.label_names
+            input_schema = model_endpoint.spec.feature_names
         except (
             mlrun.errors.MLRunNotFoundError,
             mlrun.errors.MLRunInvalidArgumentError,
-        ):
+        ) as ex:
             logger.warning(
-                f"Model endpoint not found, using default output schema for model {name}"
+                f"Model endpoint not found, using default output schema for model {name}",
+                error=f"{type(ex).__name__}: {ex}",
             )
-        return output_schema
+        return input_schema, output_schema

     def _calculate_monitoring_data(self) -> dict[str, dict[str, str]]:
         monitoring_data = deepcopy(
@@ -1768,55 +2054,106 @@
         )
         if isinstance(monitoring_data, dict):
             for model in monitoring_data:
-                monitoring_data[model][schemas.MonitoringData.OUTPUTS] = (
-                    monitoring_data.get(model, {}).get(schemas.MonitoringData.OUTPUTS)
-                    or self._get_model_endpoint_output_schema(
-                        name=model,
-                        project=self.context.project if self.context else None,
-                        uid=monitoring_data.get(model, {}).get(
-                            mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID
-                        ),
-                    )
-                )
-                # Prevent calling _get_model_output_schema for same model more than once
-                self.class_args[
-                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
-                ][model][schemas.MonitoringData.OUTPUTS] = monitoring_data[model][
-                    schemas.MonitoringData.OUTPUTS
-                ]
                 monitoring_data[model][schemas.MonitoringData.INPUT_PATH] = split_path(
                     monitoring_data[model][schemas.MonitoringData.INPUT_PATH]
                 )
                 monitoring_data[model][schemas.MonitoringData.RESULT_PATH] = split_path(
                     monitoring_data[model][schemas.MonitoringData.RESULT_PATH]
                 )
+
+                mep_output_schema, mep_input_schema = None, None
+
+                output_schema = self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.OUTPUTS]
+                input_schema = self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.INPUTS]
+                if not output_schema or not input_schema:
+                    # if output or input schema is not provided, try to get it from the model endpoint
+                    mep_input_schema, mep_output_schema = (
+                        self._get_model_endpoint_schema(
+                            model,
+                            self.context.project,
+                            monitoring_data[model].get(
+                                schemas.MonitoringData.MODEL_ENDPOINT_UID, ""
+                            ),
+                        )
+                    )
+                self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.OUTPUTS] = (
+                    output_schema or mep_output_schema
+                )
+                self.class_args[
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+                ][model][schemas.MonitoringData.INPUTS] = (
+                    input_schema or mep_input_schema
+                )
             return monitoring_data
+        else:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Monitoring data must be a dictionary."
+            )
+
+    def configure_pool_resource(
+        self,
+        max_processes: Optional[int] = None,
+        max_threads: Optional[int] = None,
+        pool_factor: Optional[int] = None,
+    ) -> None:
+        """
+        Configure the resource limits for the shared models in the graph.
+
+        :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
+            Defaults to the number of CPUs or 16 if undetectable.
+        :param max_threads: Maximum number of threads to spawn. Defaults to 32.
+        :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
+        """
+        self.max_processes = max_processes
+        self.max_threads = max_threads
+        self.pool_factor = pool_factor

     def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
         self.context = context
         if not self._is_local_function(context):
             # skip init of non local functions
             return
-        model_selector = self.class_args.get("model_selector")
+        model_selector, model_selector_params = self.class_args.get(
+            "model_selector", (None, None)
+        )
         execution_mechanism_by_model_name = self.class_args.get(
             schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
         )
         models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
-        if isinstance(model_selector, str):
-            model_selector = get_class(model_selector, namespace)()
+        if model_selector:
+            model_selector = get_class(model_selector, namespace).from_dict(
+                model_selector_params, init_with_params=True
+            )
         model_objects = []
         for model, model_params in models.values():
+            model_name = model_params.get("name")
             model_params[schemas.MonitoringData.INPUT_PATH] = (
                 self.class_args.get(
                     mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
                 )
-                .get(model_params.get("name"), {})
+                .get(model_name, {})
                 .get(schemas.MonitoringData.INPUT_PATH)
             )
+            model_params[schemas.MonitoringData.RESULT_PATH] = (
+                self.class_args.get(
+                    mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
+                )
+                .get(model_name, {})
+                .get(schemas.MonitoringData.RESULT_PATH)
+            )
             model = get_class(model, namespace).from_dict(
                 model_params, init_with_params=True
             )
             model._raise_exception = False
+            model._execution_mechanism = execution_mechanism_by_model_name.get(
+                model_name
+            )
             model_objects.append(model)
         self._async_object = ModelRunner(
             model_selector=model_selector,
@@ -1825,6 +2162,9 @@
             shared_proxy_mapping=self._shared_proxy_mapping or None,
             name=self.name,
             context=context,
+            max_processes=self.max_processes,
+            max_threads=self.max_threads,
+            pool_factor=self.pool_factor,
         )
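Reviewer note (illustrative, not part of the diff): the new per-step configure_pool_resource() values are forwarded to the ModelRunner in init_object() above. A short sketch, continuing the hypothetical runner step from the earlier notes:

    runner.configure_pool_resource(max_processes=4, max_threads=16, pool_factor=2)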
@@ -2552,13 +2892,19 @@ class RootFlowStep(FlowStep):
         model_class: Union[str, Model],
         execution_mechanism: Union[str, ParallelExecutionMechanisms],
         model_artifact: Union[str, ModelArtifact],
+        inputs: Optional[list[str]] = None,
+        outputs: Optional[list[str]] = None,
+        input_path: Optional[str] = None,
+        result_path: Optional[str] = None,
         override: bool = False,
         **model_parameters,
     ) -> None:
         """
         Add a shared model to the graph, this model will be available to all the ModelRunners in the graph
         :param name: Name of the shared model (should be unique in the graph)
-        :param model_class: Model class name
+        :param model_class: Model class name. If LLModel is chosen
+            (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+            outputs will be overridden with UsageResponseKeys fields.
         :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
             * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
               intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
@@ -2581,6 +2927,19 @@ class RootFlowStep(FlowStep):
              It means that the runnable will not actually be run in parallel to anything else.

         :param model_artifact: model artifact or mlrun model artifact uri
+        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
+            that been configured in the model artifact, please note that those inputs need
+            to be equal in length and order to the inputs that model_class
+            predict method expects
+        :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
+            that been configured in the model artifact, please note that those outputs need
+            to be equal to the model_class
+            predict method outputs (length, and order)
+        :param input_path: input path inside the user event, expect scopes to be defined by dot notation
+            (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
+        :param result_path: result path inside the user output event, expect scopes to be defined by dot
+            notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
+            in path.
         :param override: bool allow override existing model on the current ModelRunnerStep.
         :param model_parameters: Parameters for model instantiation
         """
@@ -2588,6 +2947,15 @@ class RootFlowStep(FlowStep):
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
             )
+        if type(model_class) is LLModel or (
+            isinstance(model_class, str)
+            and model_class.split(".")[-1] == LLModel.__name__
+        ):
+            if outputs:
+                warnings.warn(
+                    "LLModel with existing outputs detected, overriding to default"
+                )
+            outputs = UsageResponseKeys.fields()

         if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
             raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2615,6 +2983,14 @@ class RootFlowStep(FlowStep):
                 "Inconsistent name for the added model."
             )
         model_parameters["name"] = name
+        model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
+        model_parameters["outputs"] = outputs or model_parameters.get("outputs", [])
+        model_parameters["input_path"] = input_path or model_parameters.get(
+            "input_path"
+        )
+        model_parameters["result_path"] = result_path or model_parameters.get(
+            "result_path"
+        )

         if name in self.shared_models and not override:
             raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2629,7 +3005,9 @@ class RootFlowStep(FlowStep):
         self.shared_models[name] = (model_class, model_parameters)
         self.shared_models_mechanism[name] = execution_mechanism

-    def get_shared_model_name_by_artifact_uri(self, artifact_uri: str) -> Optional[str]:
+    def get_shared_model_by_artifact_uri(
+        self, artifact_uri: str
+    ) -> Union[tuple[str, str, dict], tuple[None, None, None]]:
         """
         Get a shared model by its artifact URI.
         :param artifact_uri: The artifact URI of the model.
@@ -2637,10 +3015,10 @@ class RootFlowStep(FlowStep):
         """
         for model_name, (model_class, model_params) in self.shared_models.items():
             if model_params.get("artifact_uri") == artifact_uri:
-                return model_name
-        return None
+                return model_name, model_class, model_params
+        return None, None, None

-    def config_pool_resource(
+    def configure_shared_pool_resource(
         self,
         max_processes: Optional[int] = None,
         max_threads: Optional[int] = None,
@@ -2688,6 +3066,7 @@ class RootFlowStep(FlowStep):
                 model_params, init_with_params=True
             )
             model._raise_exception = False
+            model._execution_mechanism = self._shared_models_mechanism[model.name]
             self.context.executor.add_runnable(
                 model, self._shared_models_mechanism[model.name]
             )
@@ -2807,12 +3186,10 @@ def _add_graphviz_router(graph, step, source=None, **kwargs):
             graph.edge(step.fullname, route.fullname)


-def _add_graphviz_model_runner(graph, step, source=None):
+def _add_graphviz_model_runner(graph, step, source=None, is_monitored=False):
     if source:
         graph.node("_start", source.name, shape=source.shape, style="filled")
         graph.edge("_start", step.fullname)
-
-    is_monitored = step._extract_root_step().track_models
     m_cell = '<FONT POINT-SIZE="9">🄼</FONT>' if is_monitored else ""

     number_of_models = len(
@@ -2851,6 +3228,7 @@ def _add_graphviz_flow(
         allow_empty=True
     )
     graph.node("_start", source.name, shape=source.shape, style="filled")
+    is_monitored = step.track_models if isinstance(step, RootFlowStep) else False
     for start_step in start_steps:
         graph.edge("_start", start_step.fullname)
     for child in step.get_children():
@@ -2859,7 +3237,7 @@ def _add_graphviz_flow(
             with graph.subgraph(name="cluster_" + child.fullname) as sg:
                 _add_graphviz_router(sg, child)
         elif kind == StepKinds.model_runner:
-            _add_graphviz_model_runner(graph, child)
+            _add_graphviz_model_runner(graph, child, is_monitored=is_monitored)
         else:
             graph.node(child.fullname, label=child.name, shape=child.get_shape())
         _add_edges(child.after or [], step, graph, child)
@@ -3078,7 +3456,7 @@ def _init_async_objects(context, steps):
                 datastore_profile = datastore_profile_read(stream_path)
                 if isinstance(
                     datastore_profile,
-                    (DatastoreProfileKafkaTarget, DatastoreProfileKafkaSource),
+                    (DatastoreProfileKafkaTarget, DatastoreProfileKafkaStream),
                 ):
                     step._async_object = KafkaStoreyTarget(
                         path=stream_path,
@@ -3094,7 +3472,7 @@ def _init_async_objects(context, steps):
                 else:
                     raise mlrun.errors.MLRunValueError(
                         f"Received an unexpected stream profile type: {type(datastore_profile)}\n"
-                        "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaSource`."
+                        "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaStream`."
                     )
             elif stream_path.startswith("kafka://") or kafka_brokers:
                 topic, brokers = parse_kafka_url(stream_path, kafka_brokers)
@@ -3110,6 +3488,8 @@ def _init_async_objects(context, steps):
                     context=context,
                     **options,
                 )
+            elif stream_path.startswith("dummy://"):
+                step._async_object = _DummyStream(context=context, **options)
             else:
                 if stream_path.startswith("v3io://"):
                     endpoint, stream_path = parse_path(step.path)