mlrun 1.10.0rc18__py3-none-any.whl → 1.11.0rc16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (167)
  1. mlrun/__init__.py +24 -3
  2. mlrun/__main__.py +0 -4
  3. mlrun/artifacts/dataset.py +2 -2
  4. mlrun/artifacts/document.py +6 -1
  5. mlrun/artifacts/llm_prompt.py +21 -15
  6. mlrun/artifacts/model.py +3 -3
  7. mlrun/artifacts/plots.py +1 -1
  8. mlrun/{model_monitoring/db/tsdb/tdengine → auth}/__init__.py +2 -3
  9. mlrun/auth/nuclio.py +89 -0
  10. mlrun/auth/providers.py +429 -0
  11. mlrun/auth/utils.py +415 -0
  12. mlrun/common/constants.py +14 -0
  13. mlrun/common/model_monitoring/helpers.py +123 -0
  14. mlrun/common/runtimes/constants.py +28 -0
  15. mlrun/common/schemas/__init__.py +14 -3
  16. mlrun/common/schemas/alert.py +2 -2
  17. mlrun/common/schemas/api_gateway.py +3 -0
  18. mlrun/common/schemas/auth.py +12 -10
  19. mlrun/common/schemas/client_spec.py +4 -0
  20. mlrun/common/schemas/constants.py +25 -0
  21. mlrun/common/schemas/frontend_spec.py +1 -8
  22. mlrun/common/schemas/function.py +34 -0
  23. mlrun/common/schemas/hub.py +33 -20
  24. mlrun/common/schemas/model_monitoring/__init__.py +2 -1
  25. mlrun/common/schemas/model_monitoring/constants.py +12 -15
  26. mlrun/common/schemas/model_monitoring/functions.py +13 -4
  27. mlrun/common/schemas/model_monitoring/model_endpoints.py +11 -0
  28. mlrun/common/schemas/pipeline.py +1 -1
  29. mlrun/common/schemas/secret.py +17 -2
  30. mlrun/common/secrets.py +95 -1
  31. mlrun/common/types.py +10 -10
  32. mlrun/config.py +69 -19
  33. mlrun/data_types/infer.py +2 -2
  34. mlrun/datastore/__init__.py +12 -5
  35. mlrun/datastore/azure_blob.py +162 -47
  36. mlrun/datastore/base.py +274 -10
  37. mlrun/datastore/datastore.py +7 -2
  38. mlrun/datastore/datastore_profile.py +84 -22
  39. mlrun/datastore/model_provider/huggingface_provider.py +225 -41
  40. mlrun/datastore/model_provider/mock_model_provider.py +87 -0
  41. mlrun/datastore/model_provider/model_provider.py +206 -74
  42. mlrun/datastore/model_provider/openai_provider.py +226 -66
  43. mlrun/datastore/s3.py +39 -18
  44. mlrun/datastore/sources.py +1 -1
  45. mlrun/datastore/store_resources.py +4 -4
  46. mlrun/datastore/storeytargets.py +17 -12
  47. mlrun/datastore/targets.py +1 -1
  48. mlrun/datastore/utils.py +25 -6
  49. mlrun/datastore/v3io.py +1 -1
  50. mlrun/db/base.py +63 -32
  51. mlrun/db/httpdb.py +373 -153
  52. mlrun/db/nopdb.py +54 -21
  53. mlrun/errors.py +4 -2
  54. mlrun/execution.py +66 -25
  55. mlrun/feature_store/api.py +1 -1
  56. mlrun/feature_store/common.py +1 -1
  57. mlrun/feature_store/feature_vector_utils.py +1 -1
  58. mlrun/feature_store/steps.py +8 -6
  59. mlrun/frameworks/_common/utils.py +3 -3
  60. mlrun/frameworks/_dl_common/loggers/logger.py +1 -1
  61. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +2 -1
  62. mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +1 -1
  63. mlrun/frameworks/_ml_common/utils.py +2 -1
  64. mlrun/frameworks/auto_mlrun/auto_mlrun.py +4 -3
  65. mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +2 -1
  66. mlrun/frameworks/onnx/dataset.py +2 -1
  67. mlrun/frameworks/onnx/mlrun_interface.py +2 -1
  68. mlrun/frameworks/pytorch/callbacks/logging_callback.py +5 -4
  69. mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +2 -1
  70. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +2 -1
  71. mlrun/frameworks/pytorch/utils.py +2 -1
  72. mlrun/frameworks/sklearn/metric.py +2 -1
  73. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +5 -4
  74. mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +2 -1
  75. mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +2 -1
  76. mlrun/hub/__init__.py +52 -0
  77. mlrun/hub/base.py +142 -0
  78. mlrun/hub/module.py +172 -0
  79. mlrun/hub/step.py +113 -0
  80. mlrun/k8s_utils.py +105 -16
  81. mlrun/launcher/base.py +15 -7
  82. mlrun/launcher/local.py +4 -1
  83. mlrun/model.py +14 -4
  84. mlrun/model_monitoring/__init__.py +0 -1
  85. mlrun/model_monitoring/api.py +65 -28
  86. mlrun/model_monitoring/applications/__init__.py +1 -1
  87. mlrun/model_monitoring/applications/base.py +299 -128
  88. mlrun/model_monitoring/applications/context.py +2 -4
  89. mlrun/model_monitoring/controller.py +132 -58
  90. mlrun/model_monitoring/db/_schedules.py +38 -29
  91. mlrun/model_monitoring/db/_stats.py +6 -16
  92. mlrun/model_monitoring/db/tsdb/__init__.py +9 -7
  93. mlrun/model_monitoring/db/tsdb/base.py +29 -9
  94. mlrun/model_monitoring/db/tsdb/preaggregate.py +234 -0
  95. mlrun/model_monitoring/db/tsdb/stream_graph_steps.py +63 -0
  96. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_metrics_queries.py +414 -0
  97. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_predictions_queries.py +376 -0
  98. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_results_queries.py +590 -0
  99. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connection.py +434 -0
  100. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connector.py +541 -0
  101. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_operations.py +808 -0
  102. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_schema.py +502 -0
  103. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream.py +163 -0
  104. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream_graph_steps.py +60 -0
  105. mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_dataframe_processor.py +141 -0
  106. mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_query_builder.py +585 -0
  107. mlrun/model_monitoring/db/tsdb/timescaledb/writer_graph_steps.py +73 -0
  108. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +20 -9
  109. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +235 -51
  110. mlrun/model_monitoring/features_drift_table.py +2 -1
  111. mlrun/model_monitoring/helpers.py +30 -6
  112. mlrun/model_monitoring/stream_processing.py +34 -28
  113. mlrun/model_monitoring/writer.py +224 -4
  114. mlrun/package/__init__.py +2 -1
  115. mlrun/platforms/__init__.py +0 -43
  116. mlrun/platforms/iguazio.py +8 -4
  117. mlrun/projects/operations.py +17 -11
  118. mlrun/projects/pipelines.py +2 -2
  119. mlrun/projects/project.py +187 -123
  120. mlrun/run.py +95 -21
  121. mlrun/runtimes/__init__.py +2 -186
  122. mlrun/runtimes/base.py +103 -25
  123. mlrun/runtimes/constants.py +225 -0
  124. mlrun/runtimes/daskjob.py +5 -2
  125. mlrun/runtimes/databricks_job/databricks_runtime.py +2 -1
  126. mlrun/runtimes/local.py +5 -2
  127. mlrun/runtimes/mounts.py +20 -2
  128. mlrun/runtimes/nuclio/__init__.py +12 -7
  129. mlrun/runtimes/nuclio/api_gateway.py +36 -6
  130. mlrun/runtimes/nuclio/application/application.py +339 -40
  131. mlrun/runtimes/nuclio/function.py +222 -72
  132. mlrun/runtimes/nuclio/serving.py +132 -42
  133. mlrun/runtimes/pod.py +213 -21
  134. mlrun/runtimes/utils.py +49 -9
  135. mlrun/secrets.py +99 -14
  136. mlrun/serving/__init__.py +2 -0
  137. mlrun/serving/remote.py +84 -11
  138. mlrun/serving/routers.py +26 -44
  139. mlrun/serving/server.py +138 -51
  140. mlrun/serving/serving_wrapper.py +6 -2
  141. mlrun/serving/states.py +997 -283
  142. mlrun/serving/steps.py +62 -0
  143. mlrun/serving/system_steps.py +149 -95
  144. mlrun/serving/v2_serving.py +9 -10
  145. mlrun/track/trackers/mlflow_tracker.py +29 -31
  146. mlrun/utils/helpers.py +292 -94
  147. mlrun/utils/http.py +9 -2
  148. mlrun/utils/notifications/notification/base.py +18 -0
  149. mlrun/utils/notifications/notification/git.py +3 -5
  150. mlrun/utils/notifications/notification/mail.py +39 -16
  151. mlrun/utils/notifications/notification/slack.py +2 -4
  152. mlrun/utils/notifications/notification/webhook.py +2 -5
  153. mlrun/utils/notifications/notification_pusher.py +3 -3
  154. mlrun/utils/version/version.json +2 -2
  155. mlrun/utils/version/version.py +3 -4
  156. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/METADATA +63 -74
  157. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/RECORD +161 -143
  158. mlrun/api/schemas/__init__.py +0 -259
  159. mlrun/db/auth_utils.py +0 -152
  160. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +0 -344
  161. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +0 -75
  162. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connection.py +0 -281
  163. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +0 -1266
  164. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/WHEEL +0 -0
  165. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/entry_points.txt +0 -0
  166. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/licenses/LICENSE +0 -0
  167. {mlrun-1.10.0rc18.dist-info → mlrun-1.11.0rc16.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py CHANGED
@@ -24,12 +24,15 @@ import inspect
 import os
 import pathlib
 import traceback
+import warnings
 from abc import ABC
+from collections.abc import Collection
 from copy import copy, deepcopy
 from inspect import getfullargspec, signature
 from typing import Any, Optional, Union, cast

 import storey.utils
+from deprecated import deprecated
 from storey import ParallelExecutionMechanisms

 import mlrun
@@ -38,17 +41,21 @@ import mlrun.common.schemas as schemas
 from mlrun.artifacts.llm_prompt import LLMPromptArtifact, PlaceholderDefaultDict
 from mlrun.artifacts.model import ModelArtifact
 from mlrun.datastore.datastore_profile import (
-    DatastoreProfileKafkaSource,
+    DatastoreProfileKafkaStream,
     DatastoreProfileKafkaTarget,
     DatastoreProfileV3io,
     datastore_profile_read,
 )
-from mlrun.datastore.model_provider.model_provider import ModelProvider
+from mlrun.datastore.model_provider.model_provider import (
+    InvokeResponseFormat,
+    ModelProvider,
+    UsageResponseKeys,
+)
 from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
-from mlrun.utils import get_data_from_path, logger, split_path
+from mlrun.utils import get_data_from_path, logger, set_data_by_path, split_path

 from ..config import config
-from ..datastore import get_stream_pusher
+from ..datastore import _DummyStream, get_stream_pusher
 from ..datastore.utils import (
     get_kafka_brokers_from_dict,
     parse_kafka_url,
@@ -85,25 +92,6 @@ class StepKinds:
     model_runner = "model_runner"


-_task_step_fields = [
-    "kind",
-    "class_name",
-    "class_args",
-    "handler",
-    "skip_context",
-    "after",
-    "function",
-    "comment",
-    "shape",
-    "full_event",
-    "on_error",
-    "responder",
-    "input_path",
-    "result_path",
-    "model_endpoint_creation_strategy",
-    "endpoint_type",
-]
-
 _default_fields_to_strip_from_step = [
     "model_endpoint_creation_strategy",
     "endpoint_type",
@@ -129,7 +117,14 @@ def new_remote_endpoint(
 class BaseStep(ModelObj):
     kind = "BaseStep"
     default_shape = "ellipse"
-    _dict_fields = ["kind", "comment", "after", "on_error"]
+    _dict_fields = [
+        "kind",
+        "comment",
+        "after",
+        "on_error",
+        "max_iterations",
+        "cycle_from",
+    ]
     _default_fields_to_strip = _default_fields_to_strip_from_step

     def __init__(
@@ -137,6 +132,7 @@ class BaseStep(ModelObj):
         name: Optional[str] = None,
         after: Optional[list] = None,
         shape: Optional[str] = None,
+        max_iterations: Optional[int] = None,
     ):
         self.name = name
         self._parent = None
@@ -150,6 +146,8 @@ class BaseStep(ModelObj):
         self.model_endpoint_creation_strategy = (
             schemas.ModelEndpointCreationStrategy.SKIP
         )
+        self._max_iterations = max_iterations
+        self.cycle_from = []

     def get_shape(self):
         """graphviz shape"""
@@ -343,6 +341,8 @@ class BaseStep(ModelObj):
         model_endpoint_creation_strategy: Optional[
             schemas.ModelEndpointCreationStrategy
         ] = None,
+        cycle_to: Optional[list[str]] = None,
+        max_iterations: Optional[int] = None,
         **class_args,
     ):
         """add a step right after this step and return the new step
@@ -372,21 +372,17 @@
                              to event["y"] resulting in {"x": 5, "y": <result>}
         :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:

-            * **overwrite**:
-
-              1. If model endpoints with the same name exist, delete the `latest` one.
-              2. Create a new model endpoint entry and set it as `latest`.
+            * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+              create a new model endpoint entry and set it as `latest`.

-            * **inplace** (default):
+            * **inplace** (default): If model endpoints with the same name exist, update the `latest`
+              entry; otherwise, create a new entry.

-              1. If model endpoints with the same name exist, update the `latest` entry.
-              2. Otherwise, create a new entry.
-
-            * **archive**:
-
-              1. If model endpoints with the same name exist, preserve them.
-              2. Create a new model endpoint with the same name and set it to `latest`.
+            * **archive**: If model endpoints with the same name exist, preserve them;
+              create a new model endpoint with the same name and set it to `latest`.

+        :param cycle_to: list of step names to create a cycle to (for cyclic graphs)
+        :param max_iterations: maximum number of iterations for this step in case of a cyclic graph
         :param class_args: class init arguments
         """
         if hasattr(self, "steps"):
@@ -421,8 +417,39 @@
         # check that it's not the root; todo: in the future we may have nested flows
         step.after_step(self.name)
         parent._last_added = step
+        step.cycle_to(cycle_to or [])
+        step._max_iterations = max_iterations
         return step

+    def cycle_to(self, step_names: Union[str, list[str]]):
+        """create a cycle in the graph to the specified step names
+
+        example:
+            in the example below, a cycle is created from 'step3' to 'step1':
+
+                graph.to('step1')\
+                    .to('step2')\
+                    .to('step3')\
+                    .cycle_to(['step1'])  # creates a cycle from step3 to step1
+
+        :param step_names: list of step names to create a cycle to (for cyclic graphs)
+        """
+        root = self._extract_root_step()
+        if not isinstance(root, RootFlowStep):
+            raise GraphError("cycle_to() can only be called on a step within a graph")
+        if not root.allow_cyclic and step_names:
+            raise GraphError("cyclic graphs are not allowed, enable allow_cyclic")
+        step_names = [step_names] if isinstance(step_names, str) else step_names
+
+        for step_name in step_names:
+            if step_name not in root:
+                raise GraphError(
+                    f"step {step_name} doesnt exist in the graph under {self._parent.fullname}"
+                )
+            root[step_name].after_step(self.name, append=True)
+            root[step_name].cycle_from.append(self.name)
+
+        return self
+
     def set_flow(
         self,
         steps: list[Union[str, StepToDict, dict[str, Any]]],
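
The cycle_to()/max_iterations API added above enables bounded loops in serving graphs. A minimal sketch of how it might be wired up; the function name, handlers, and the way allow_cyclic is enabled on the root step are illustrative assumptions, not confirmed API:

    import mlrun

    fn = mlrun.code_to_function("agent", kind="serving", filename="steps.py")
    graph = fn.set_topology("flow", engine="async")
    graph.allow_cyclic = True  # assumption: the root flow step must opt in, per the check in cycle_to()

    # step3 loops back to step1; max_iterations bounds how many times the step re-runs in the cycle
    graph.to(name="step1", handler="plan") \
        .to(name="step2", handler="act") \
        .to(name="step3", handler="review", max_iterations=5) \
        .cycle_to(["step1"])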
@@ -517,7 +544,9 @@

         root = self._extract_root_step()

-        if not isinstance(root, RootFlowStep):
+        if not isinstance(root, RootFlowStep) or (
+            isinstance(root, RootFlowStep) and root.engine != "async"
+        ):
             raise GraphError(
                 "ModelRunnerStep can be added to 'Flow' topology graph only"
             )
@@ -541,8 +570,8 @@
         # Update model endpoints names in the root step
         root.update_model_endpoints_names(step_model_endpoints_names)

-    @staticmethod
     def _verify_shared_models(
+        self,
         root: "RootFlowStep",
         step: "ModelRunnerStep",
         step_model_endpoints_names: list[str],
@@ -571,35 +600,41 @@
             prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
             # if the model artifact is a prompt, we need to get the model URI
             # to ensure that the shared runnable name is correct
+            llm_artifact_uri = None
             if prefix == mlrun.utils.StorePrefix.LLMPrompt:
                 llm_artifact, _ = mlrun.store_manager.get_store_artifact(
                     model_artifact_uri
                 )
+                llm_artifact_uri = llm_artifact.uri
                 model_artifact_uri = mlrun.utils.remove_tag_from_artifact_uri(
                     llm_artifact.spec.parent_uri
                 )
-            actual_shared_name = root.get_shared_model_name_by_artifact_uri(
-                model_artifact_uri
+            actual_shared_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )

-            if not shared_runnable_name:
-                if not actual_shared_name:
-                    raise GraphError(
-                        f"Can't find shared model for {name} model endpoint"
-                    )
-                else:
-                    step.class_args[schemas.ModelRunnerStepData.MODELS][name][
-                        schemas.ModelsData.MODEL_PARAMETERS.value
-                    ]["shared_runnable_name"] = actual_shared_name
-                    shared_models.append(actual_shared_name)
+            if not actual_shared_name:
+                raise GraphError(
+                    f"Can't find shared model named {shared_runnable_name}"
+                )
+            elif not shared_runnable_name:
+                step.class_args[schemas.ModelRunnerStepData.MODELS][name][
+                    schemas.ModelsData.MODEL_PARAMETERS.value
+                ]["shared_runnable_name"] = actual_shared_name
             elif actual_shared_name != shared_runnable_name:
                 raise GraphError(
                     f"Model endpoint {name} shared runnable name mismatch: "
                     f"expected {actual_shared_name}, got {shared_runnable_name}"
                 )
-            else:
-                shared_models.append(actual_shared_name)
-
+            shared_models.append(actual_shared_name)
+            self._edit_proxy_model_data(
+                step,
+                name,
+                actual_shared_name,
+                shared_model_params,
+                shared_model_class,
+                llm_artifact_uri or model_artifact_uri,
+            )
         undefined_shared_models = list(
             set(shared_models) - set(root.shared_models.keys())
         )
@@ -608,12 +643,71 @@
                 f"The following shared models are not defined in the graph: {undefined_shared_models}."
             )

+    @staticmethod
+    def _edit_proxy_model_data(
+        step: "ModelRunnerStep",
+        name: str,
+        actual_shared_name: str,
+        shared_model_params: dict,
+        shared_model_class: Any,
+        artifact: Union[ModelArtifact, LLMPromptArtifact, str],
+    ):
+        monitoring_data = step.class_args.setdefault(
+            schemas.ModelRunnerStepData.MONITORING_DATA, {}
+        )
+
+        # edit monitoring data according to the shared model parameters
+        monitoring_data[name][schemas.MonitoringData.INPUT_PATH] = shared_model_params[
+            "input_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.RESULT_PATH] = shared_model_params[
+            "result_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.INPUTS] = shared_model_params[
+            "inputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.OUTPUTS] = shared_model_params[
+            "outputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.MODEL_CLASS] = (
+            shared_model_class
+            if isinstance(shared_model_class, str)
+            else shared_model_class.__class__.__name__
+        )
+        if actual_shared_name and actual_shared_name not in step._shared_proxy_mapping:
+            step._shared_proxy_mapping[actual_shared_name] = {
+                name: artifact.uri
+                if isinstance(artifact, ModelArtifact | LLMPromptArtifact)
+                else artifact
+            }
+        elif actual_shared_name:
+            step._shared_proxy_mapping[actual_shared_name].update(
+                {
+                    name: artifact.uri
+                    if isinstance(artifact, ModelArtifact | LLMPromptArtifact)
+                    else artifact
+                }
+            )
+

 class TaskStep(BaseStep):
     """task execution step, runs a class or handler"""

     kind = "task"
-    _dict_fields = _task_step_fields
+    _dict_fields = BaseStep._dict_fields + [
+        "class_name",
+        "class_args",
+        "handler",
+        "skip_context",
+        "function",
+        "shape",
+        "full_event",
+        "responder",
+        "input_path",
+        "result_path",
+        "model_endpoint_creation_strategy",
+        "endpoint_type",
+    ]
     _default_class = ""

     def __init__(
@@ -639,6 +733,7 @@ class TaskStep(BaseStep):
         self.handler = handler
         self.function = function
         self._handler = None
+        self._outlets_selector = None
         self._object = None
         self._async_object = None
         self.skip_context = None
@@ -706,6 +801,8 @@
             handler = "do"
         if handler:
             self._handler = getattr(self._object, handler, None)
+        if hasattr(self._object, "select_outlets"):
+            self._outlets_selector = self._object.select_outlets

         self._set_error_handler()
         if mode != "skip":
@@ -879,7 +976,7 @@ class ErrorStep(TaskStep):
     """error execution step, runs a class or handler"""

     kind = "error_step"
-    _dict_fields = _task_step_fields + ["before", "base_step"]
+    _dict_fields = TaskStep._dict_fields + ["before", "base_step"]
     _default_class = ""

     def __init__(
@@ -916,7 +1013,7 @@ class RouterStep(TaskStep):

     kind = "router"
     default_shape = "doubleoctagon"
-    _dict_fields = _task_step_fields + ["routes", "name"]
+    _dict_fields = TaskStep._dict_fields + ["routes", "name"]
     _default_class = "mlrun.serving.ModelRouter"

     def __init__(
@@ -983,20 +1080,14 @@
         :param function: function this step should run in
         :param creation_strategy: Strategy for creating or updating the model endpoint:

-            * **overwrite**:
-
-              1. If model endpoints with the same name exist, delete the `latest` one.
-              2. Create a new model endpoint entry and set it as `latest`.
-
-            * **inplace** (default):
+            * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+              create a new model endpoint entry and set it as `latest`.

-              1. If model endpoints with the same name exist, update the `latest` entry.
-              2. Otherwise, create a new entry.
+            * **inplace** (default): If model endpoints with the same name exist, update the `latest`
+              entry; otherwise, create a new entry.

-            * **archive**:
-
-              1. If model endpoints with the same name exist, preserve them.
-              2. Create a new model endpoint with the same name and set it to `latest`.
+            * **archive**: If model endpoints with the same name exist, preserve them;
+              create a new model endpoint with the same name and set it to `latest`.

         """
         if len(self.routes.keys()) >= MAX_MODELS_PER_ROUTER and key not in self.routes:
@@ -1090,6 +1181,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         "artifact_uri",
         "shared_runnable_name",
         "shared_proxy_mapping",
+        "execution_mechanism",
     ]
     kind = "model"

@@ -1111,6 +1203,8 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         self.invocation_artifact: Optional[LLMPromptArtifact] = None
         self.model_artifact: Optional[ModelArtifact] = None
         self.model_provider: Optional[ModelProvider] = None
+        self._artifact_were_loaded = False
+        self._execution_mechanism = None

     def __init_subclass__(cls):
         super().__init_subclass__()
@@ -1130,13 +1224,33 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
             raise_missing_schema_exception=False,
         )

-    def _load_artifacts(self) -> None:
-        artifact = self._get_artifact_object()
-        if isinstance(artifact, LLMPromptArtifact):
-            self.invocation_artifact = artifact
-            self.model_artifact = self.invocation_artifact.model_artifact
+        # Check if the relevant predict method is implemented when trying to initialize the model
+        if self._execution_mechanism == storey.ParallelExecutionMechanisms.asyncio:
+            if self.__class__.predict_async is Model.predict_async:
+                raise mlrun.errors.ModelRunnerError(
+                    {
+                        self.name: f"is running with {self._execution_mechanism} "
+                        f"execution_mechanism but predict_async() is not implemented"
+                    }
+                )
         else:
-            self.model_artifact = artifact
+            if self.__class__.predict is Model.predict:
+                raise mlrun.errors.ModelRunnerError(
+                    {
+                        self.name: f"is running with {self._execution_mechanism} execution_mechanism but predict() "
+                        f"is not implemented"
+                    }
+                )
+
+    def _load_artifacts(self) -> None:
+        if not self._artifact_were_loaded:
+            artifact = self._get_artifact_object()
+            if isinstance(artifact, LLMPromptArtifact):
+                self.invocation_artifact = artifact
+                self.model_artifact = self.invocation_artifact.model_artifact
+            else:
+                self.model_artifact = artifact
+            self._artifact_were_loaded = True

     def _get_artifact_object(
         self, proxy_uri: Optional[str] = None
@@ -1144,7 +1258,9 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         uri = proxy_uri or self.artifact_uri
         if uri:
             if mlrun.datastore.is_store_uri(uri):
-                artifact, _ = mlrun.store_manager.get_store_artifact(uri)
+                artifact, _ = mlrun.store_manager.get_store_artifact(
+                    uri, allow_empty_resources=True
+                )
                 return artifact
             else:
                 raise ValueError(
1158
1274
 
1159
1275
  def predict(self, body: Any, **kwargs) -> Any:
1160
1276
  """Override to implement prediction logic. If the logic requires asyncio, override predict_async() instead."""
1161
- return body
1277
+ raise NotImplementedError("predict() method not implemented")
1162
1278
 
1163
1279
  async def predict_async(self, body: Any, **kwargs) -> Any:
1164
1280
  """Override to implement prediction logic if the logic requires asyncio."""
1165
- return body
1281
+ raise NotImplementedError("predict_async() method not implemented")
1166
1282
 
1167
1283
  def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
1284
+ if isinstance(body, list):
1285
+ body = self.format_batch(body)
1168
1286
  return self.predict(body)
1169
1287
 
1170
1288
  async def run_async(
@@ -1203,28 +1321,117 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
                 return model_file, extra_dataitems
         return None, None

+    @staticmethod
+    def format_batch(body: Any):
+        return body
+

 class LLModel(Model):
+    """
+    A model wrapper for handling LLM (Large Language Model) prompt-based inference.
+
+    This class extends the base `Model` to provide specialized handling for
+    `LLMPromptArtifact` objects, enabling both synchronous and asynchronous
+    invocation of language models.
+
+    **Model Invocation**:
+
+    - The execution of enriched prompts is delegated to the `model_provider`
+      configured for the model (e.g., **Hugging Face** or **OpenAI**).
+    - The `model_provider` is responsible for sending the prompt to the correct
+      backend API and returning the generated output.
+    - Users can override the `predict` and `predict_async` methods to customize
+      the behavior of the model invocation.
+
+    **Prompt Enrichment Overview**:
+
+    - If an `LLMPromptArtifact` is found, load its prompt template and fill in
+      placeholders using values from the request body.
+    - If the artifact is not an `LLMPromptArtifact`, skip formatting and attempt
+      to retrieve `messages` directly from the request body using the input path.
+
+    **Simplified Example**:
+
+    Input body::
+
+        {"city": "Paris", "days": 3}
+
+    Prompt template in artifact::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a {{days}}-day itinerary for {{city}}."},
+        ]
+
+    Result after enrichment::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a 3-day itinerary for Paris."},
+        ]
+
+    :param name:        Name of the model.
+    :param input_path:  Path in the request body where input data is located.
+    :param result_path: Path in the response body where model outputs and the statistics
+                        will be stored.
+    """
+
+    _dict_fields = Model._dict_fields + ["result_path", "input_path"]
+
     def __init__(
-        self, name: str, input_path: Optional[Union[str, list[str]]] = None, **kwargs
+        self,
+        name: str,
+        input_path: Optional[Union[str, list[str]]] = None,
+        result_path: Optional[Union[str, list[str]]] = None,
+        **kwargs,
     ):
         super().__init__(name, **kwargs)
         self._input_path = split_path(input_path)
+        self._result_path = split_path(result_path)
+        logger.info(
+            "LLModel initialized",
+            model_name=name,
+            input_path=input_path,
+            result_path=result_path,
+        )

     def predict(
         self,
         body: Any,
         messages: Optional[list[dict]] = None,
-        model_configuration: Optional[dict] = None,
+        invocation_config: Optional[dict] = None,
         **kwargs,
     ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
         if isinstance(
-            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
         ) and isinstance(self.model_provider, ModelProvider):
-            body["result"] = self.model_provider.invoke(
+            logger.debug(
+                "Invoking model provider",
+                model_name=self.name,
                 messages=messages,
-                as_str=True,
-                **(model_configuration or {}),
+                invocation_config=invocation_config,
+            )
+            response_with_stats = self.model_provider.invoke(
+                messages=messages,
+                invoke_response_format=InvokeResponseFormat.USAGE,
+                **(invocation_config or {}),
+            )
+            set_data_by_path(
+                path=self._result_path, data=body, value=response_with_stats
+            )
+            logger.debug(
+                "LLModel prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
            )
         return body

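With the changes above, LLModel reads placeholder values from input_path and writes the provider response into the event body at result_path via set_data_by_path(); judging by the logged fields, the stored value carries "answer" and "usage" keys. A minimal sketch of constructing such a model (names and paths are illustrative):

    from mlrun.serving.states import LLModel

    model = LLModel(
        name="travel-llm",
        input_path="request",        # placeholder values are read from body["request"]
        result_path="response.llm",  # response (answer + usage) lands in body["response"]["llm"]
    )
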
@@ -1232,61 +1439,130 @@ class LLModel(Model):
         self,
         body: Any,
         messages: Optional[list[dict]] = None,
-        model_configuration: Optional[dict] = None,
+        invocation_config: Optional[dict] = None,
         **kwargs,
     ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
         if isinstance(
-            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
         ) and isinstance(self.model_provider, ModelProvider):
-            body["result"] = await self.model_provider.async_invoke(
+            logger.debug(
+                "Async invoking model provider",
+                model_name=self.name,
                 messages=messages,
-                as_str=True,
-                **(model_configuration or {}),
+                invocation_config=invocation_config,
+            )
+            response_with_stats = await self.model_provider.async_invoke(
+                messages=messages,
+                invoke_response_format=InvokeResponseFormat.USAGE,
+                **(invocation_config or {}),
+            )
+            set_data_by_path(
+                path=self._result_path, data=body, value=response_with_stats
+            )
+            logger.debug(
+                "LLModel async prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping async prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
             )
         return body

+    def init(self):
+        super().init()
+
+        if not self.model_provider:
+            if self._execution_mechanism != storey.ParallelExecutionMechanisms.asyncio:
+                unchanged_predict = self.__class__.predict is LLModel.predict
+                predict_function_name = "predict"
+            else:
+                unchanged_predict = (
+                    self.__class__.predict_async is LLModel.predict_async
+                )
+                predict_function_name = "predict_async"
+            if unchanged_predict:
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Model provider could not be determined for model '{self.name}',"
+                    f" and the {predict_function_name} function was not overridden."
+                )
+
     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
-        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, invocation_config = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
         return self.predict(
-            body, messages=messages, model_configuration=model_configuration
+            body,
+            messages=messages,
+            invocation_config=invocation_config,
+            llm_prompt_artifact=llm_prompt_artifact,
         )

     async def run_async(
         self, body: Any, path: str, origin_name: Optional[str] = None
     ) -> Any:
-        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, invocation_config = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel async predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
         return await self.predict_async(
-            body, messages=messages, model_configuration=model_configuration
+            body,
+            messages=messages,
+            invocation_config=invocation_config,
+            llm_prompt_artifact=llm_prompt_artifact,
         )

     def enrich_prompt(
-        self, body: dict, origin_name: str
+        self,
+        body: dict,
+        origin_name: str,
+        llm_prompt_artifact: Optional[LLMPromptArtifact] = None,
     ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
-        if origin_name and self.shared_proxy_mapping:
-            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
-            if isinstance(llm_prompt_artifact, str):
-                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
-                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
-        else:
-            llm_prompt_artifact = (
-                self.invocation_artifact or self._get_artifact_object()
-            )
-        if not (
+        logger.info(
+            "Enriching prompt",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+        )
+        if not llm_prompt_artifact or not (
             llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
         ):
             logger.warning(
-                "LLMModel must be provided with LLMPromptArtifact",
+                "LLModel must be provided with LLMPromptArtifact",
+                model_name=self.name,
+                artifact_type=type(llm_prompt_artifact).__name__,
                 llm_prompt_artifact=llm_prompt_artifact,
             )
-            return None, None
-        prompt_legend = llm_prompt_artifact.spec.prompt_legend
-        prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+            prompt_legend, prompt_template, invocation_config = {}, [], {}
+        else:
+            prompt_legend = llm_prompt_artifact.spec.prompt_legend
+            prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+            invocation_config = llm_prompt_artifact.spec.invocation_config
         input_data = copy(get_data_from_path(self._input_path, body))
-        if isinstance(input_data, dict):
+        if isinstance(input_data, dict) and prompt_template:
             kwargs = (
                 {
                     place_holder: input_data.get(body_map["field"])
                     for place_holder, body_map in prompt_legend.items()
+                    if input_data.get(body_map["field"])
                 }
                 if prompt_legend
                 else {}
@@ -1298,23 +1574,124 @@ class LLModel(Model):
                     message["content"] = message["content"].format(**input_data)
                 except KeyError as e:
                     logger.warning(
-                        "Input data was missing a placeholder, placeholder stay unformatted",
-                        key_error=e,
+                        "Input data missing placeholder, content stays unformatted",
+                        model_name=self.name,
+                        key_error=mlrun.errors.err_to_str(e),
                     )
                     message["content"] = message["content"].format_map(
                         default_place_holders
                     )
+        elif isinstance(input_data, dict) and not prompt_template:
+            # If there is no prompt template, we assume the input data is already in the correct format.
+            logger.debug("Attempting to retrieve messages from the request body.")
+            prompt_template = input_data.get("messages", [])
         else:
             logger.warning(
-                f"Expected input data to be a dict, but received input data from type {type(input_data)} prompt "
-                f"template stay unformatted",
+                "Expected input data to be a dict, prompt template stays unformatted",
+                model_name=self.name,
+                input_data_type=type(input_data).__name__,
             )
-        return prompt_template, llm_prompt_artifact.spec.model_configuration
+        return prompt_template, invocation_config
+
+    def _get_invocation_artifact(
+        self, origin_name: Optional[str] = None
+    ) -> Union[LLMPromptArtifact, None]:
+        """
+        Get the LLMPromptArtifact object for this model.
+
+        :param origin_name: Optional; name of the originating model endpoint, used to look up the proxy artifact.
+        :return: LLMPromptArtifact object or None if not found.
+        """
+        if origin_name and self.shared_proxy_mapping:
+            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+            if isinstance(llm_prompt_artifact, str):
+                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+        elif self._artifact_were_loaded:
+            llm_prompt_artifact = self.invocation_artifact
+        else:
+            self._load_artifacts()
+            llm_prompt_artifact = self.invocation_artifact
+        return llm_prompt_artifact
+
+
+class ModelRunnerSelector(ModelObj):
+    """
+    Strategy for controlling model selection and output routing in ModelRunnerStep.
+
+    Subclass this to implement custom logic for agent workflows:
+    - `select_models()`: Called BEFORE execution to choose which models run
+    - `select_outlets()`: Called AFTER execution to route output to downstream steps
+
+    Return `None` from either method to use default behavior (all models / all outlets).
+
+    Example::
+
+        class ToolSelector(ModelRunnerSelector):
+            def select_outlets(self, event):
+                tool = event.get("tool_call")
+                return [tool] if tool else ["final"]
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        cls._dict_fields = list(
+            set(cls._dict_fields)
+            | set(inspect.signature(cls.__init__).parameters.keys())
+        )
+        cls._dict_fields.remove("self")
+
+    def select_models(
+        self,
+        event: Any,
+        available_models: list[Model],
+    ) -> Optional[Union[list[str], list[Model]]]:
+        """
+        Called before model execution.
+
+        :param event:            The full event
+        :param available_models: List of available models
+
+        Returns the models to execute (by name or Model objects).
+        """
+        return None
+
+    def select_outlets(
+        self,
+        event: Any,
+    ) -> Optional[list[str]]:
+        """
+        Called after model execution.
+
+        :param event: The event body after model execution
+        :return: Returns the downstream outlets to route the event to.
+        """
+        return None


-class ModelSelector:
+# TODO: Remove in 1.13.0
+@deprecated(
+    version="1.11.0",
+    reason="ModelSelector will be removed in 1.13.0, use ModelRunnerSelector instead",
+    category=FutureWarning,
+)
+class ModelSelector(ModelObj):
     """Used to select which models to run on each event."""

+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        cls._dict_fields = list(
+            set(cls._dict_fields)
+            | set(inspect.signature(cls.__init__).parameters.keys())
+        )
+        cls._dict_fields.remove("self")
+
     def select(
         self, event, available_models: list[Model]
     ) -> Union[list[str], list[Model]]:
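
The new ModelRunnerSelector above splits selection into a pre-execution hook (select_models) and a post-execution routing hook (select_outlets), while the old single-hook ModelSelector is deprecated. A minimal sketch of a custom selector attached to a ModelRunnerStep (class and step names are illustrative; ModelRunnerStep accepts either an instance or a class name, per its __init__ below):

    from mlrun.serving.states import ModelRunnerSelector, ModelRunnerStep

    class ToolSelector(ModelRunnerSelector):
        def select_models(self, event, available_models):
            # returning None runs all models (the default behavior)
            return None

        def select_outlets(self, event):
            # after execution, route to the step named in the event, or to "final"
            tool = event.get("tool_call")
            return [tool] if tool else ["final"]

    runner = ModelRunnerStep(name="runner", model_runner_selector=ToolSelector())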
@@ -1332,16 +1709,22 @@ class ModelRunner(storey.ParallelExecution):
     """
     Runs multiple Models on each event. See ModelRunnerStep.

-    :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
-                           event. Optional. If not passed, all models will be run.
+    :param model_runner_selector: ModelRunnerSelector instance whose select_models() method will be used to select
+                                  models to run on each event. Optional. If not passed, all models will be run.
     """

     def __init__(
-        self, *args, context, model_selector: Optional[ModelSelector] = None, **kwargs
+        self,
+        *args,
+        context,
+        model_runner_selector: Optional[ModelRunnerSelector] = None,
+        raise_exception: bool = True,
+        **kwargs,
     ):
         super().__init__(*args, **kwargs)
-        self.model_selector = model_selector or ModelSelector()
+        self.model_runner_selector = model_runner_selector or ModelRunnerSelector()
         self.context = context
+        self._raise_exception = raise_exception

     def preprocess_event(self, event):
         if not hasattr(event, "_metadata"):
@@ -1354,7 +1737,31 @@ class ModelRunner(storey.ParallelExecution):

     def select_runnables(self, event):
         models = cast(list[Model], self.runnables)
-        return self.model_selector.select(event, models)
+        return self.model_runner_selector.select_models(event, models)
+
+    def select_outlets(self, event) -> Optional[Collection[str]]:
+        sys_outlets = [f"{self.name}_error_raise"]
+        if "background_task_status_step" in self._name_to_outlet:
+            sys_outlets.append("background_task_status_step")
+        if self._raise_exception and self._is_error(event):
+            return sys_outlets
+        user_outlets = self.model_runner_selector.select_outlets(event)
+        if user_outlets:
+            return (
+                user_outlets if isinstance(user_outlets, list) else [user_outlets]
+            ) + sys_outlets
+        return None
+
+    def _is_error(self, event: dict) -> bool:
+        if len(self.runnables) == 1:
+            if isinstance(event, dict):
+                return event.get("error") is not None
+        else:
+            for model in event:
+                body_by_model = event.get(model)
+                if isinstance(body_by_model, dict) and "error" in body_by_model:
+                    return True
+        return False


class MonitoredStep(ABC, TaskStep, StepToDict):
@@ -1406,34 +1813,122 @@ class ModelRunnerStep(MonitoredStep):
         model_runner_step.add_model(..., model_class=MyModel(name="my_model"))
         graph.to(model_runner_step)

-    :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
-                           event. Optional. If not passed, all models will be run.
-    :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised
-                            an error. If False, the error will appear in the output event.
+    Note: when ModelRunnerStep is used in a graph, MLRun automatically imports
+    the default language model class (LLModel) during function deployment.
+
+    Note: ModelRunnerStep can only be added to a graph that has the flow topology and runs with the async engine.

-    :raise ModelRunnerError - when a model raise an error the ModelRunnerStep will handle it, collect errors and outputs
-                              from added models, If raise_exception is True will raise ModelRunnerError Else will add
-                              the error msg as part of the event body mapped by model name if more than one model was
-                              added to the ModelRunnerStep
+    Note: see the configure_pool_resource method documentation for the default number of max threads and max processes.
+
+    :raise ModelRunnerError: when a model raises an error, the ModelRunnerStep will handle it and collect errors and
+                             outputs from the added models. If raise_exception is True, it will raise ModelRunnerError;
+                             otherwise it will add the error msg as part of the event body, mapped by model name, if
+                             more than one model was added to the ModelRunnerStep
     """

     kind = "model_runner"
-    _dict_fields = MonitoredStep._dict_fields + ["_shared_proxy_mapping"]
+    _dict_fields = MonitoredStep._dict_fields + [
+        "_shared_proxy_mapping",
+        "max_processes",
+        "max_threads",
+        "pool_factor",
+    ]

     def __init__(
         self,
         *args,
         name: Optional[str] = None,
+        model_runner_selector: Optional[Union[str, ModelRunnerSelector]] = None,
+        model_runner_selector_parameters: Optional[dict] = None,
         model_selector: Optional[Union[str, ModelSelector]] = None,
+        model_selector_parameters: Optional[dict] = None,
         raise_exception: bool = True,
         **kwargs,
     ):
+        """
+
+        :param name:                             The name of the ModelRunnerStep.
+        :param model_runner_selector:            ModelRunnerSelector instance whose select_models()
+                                                 and select_outlets() methods will be used
+                                                 to select models to run on each event and outlets to
+                                                 route the event to.
+        :param model_runner_selector_parameters: Parameters for the model_runner_selector; if model_runner_selector
+                                                 is the class name, we will use this param when
+                                                 initializing the selector.
+        :param model_selector:                   (Deprecated)
+        :param model_selector_parameters:        (Deprecated)
+        :param raise_exception:                  Determines whether to raise ModelRunnerError when one or more models
+                                                 raise an error during execution.
+                                                 If False, errors will be added to the event body.
+        """
+        self.max_processes = None
+        self.max_threads = None
+        self.pool_factor = None
+
+        if (model_selector or model_selector_parameters) and (
+            model_runner_selector or model_runner_selector_parameters
+        ):
+            raise GraphError(
+                "Cannot provide both `model_selector`/`model_selector_parameters` "
+                "and `model_runner_selector`/`model_runner_selector_parameters`. "
+                "Please use only the latter pair."
+            )
+        if model_selector or model_selector_parameters:
+            warnings.warn(
+                "`model_selector` and `model_selector_parameters` are deprecated, "
+                "please use `model_runner_selector` and `model_runner_selector_parameters` instead.",
+                # TODO: Remove this in 1.13.0
+                FutureWarning,
+            )
+            if isinstance(model_selector, ModelSelector) and model_selector_parameters:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    "Cannot provide a model_selector object as argument to `model_selector` and also provide "
+                    "`model_selector_parameters`."
+                )
+            if model_selector:
+                model_selector_parameters = model_selector_parameters or (
+                    model_selector.to_dict()
+                    if isinstance(model_selector, ModelSelector)
+                    else {}
+                )
+                model_selector = (
+                    model_selector
+                    if isinstance(model_selector, str)
+                    else model_selector.__class__.__name__
+                )
+        else:
+            if (
+                isinstance(model_runner_selector, ModelRunnerSelector)
+                and model_runner_selector_parameters
+            ):
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    "Cannot provide a model_runner_selector object as argument to `model_runner_selector` "
+                    "and also provide `model_runner_selector_parameters`."
+                )
+            if model_runner_selector:
+                model_runner_selector_parameters = model_runner_selector_parameters or (
+                    model_runner_selector.to_dict()
+                    if isinstance(model_runner_selector, ModelRunnerSelector)
+                    else {}
+                )
+                model_runner_selector = (
+                    model_runner_selector
+                    if isinstance(model_runner_selector, str)
+                    else model_runner_selector.__class__.__name__
+                )
+
         super().__init__(
             *args,
             name=name,
             raise_exception=raise_exception,
             class_name="mlrun.serving.ModelRunner",
-            class_args=dict(model_selector=model_selector),
+            class_args=dict(
+                model_selector=(model_selector, model_selector_parameters),
+                model_runner_selector=(
+                    model_runner_selector,
+                    model_runner_selector_parameters,
+                ),
+            ),
             **kwargs,
         )
         self.raise_exception = raise_exception
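
Per the deprecation handling above, an existing ModelSelector subclass maps onto ModelRunnerSelector roughly as follows (a migration sketch, not an official guide):

    # before: deprecated in 1.11.0, removal planned for 1.13.0 (emits FutureWarning)
    class MySelector(ModelSelector):
        def select(self, event, available_models):
            return available_models

    # after: the equivalent hook is select_models(); select_outlets() is a new, optional hook
    class MySelector(ModelRunnerSelector):
        def select_models(self, event, available_models):
            return available_models

    # ModelRunnerStep(..., model_selector=MySelector())         # old keyword, warns
    # ModelRunnerStep(..., model_runner_selector=MySelector())  # new keyword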
@@ -1449,10 +1944,6 @@ class ModelRunnerStep(MonitoredStep):
         model_endpoint_creation_strategy: Optional[
             schemas.ModelEndpointCreationStrategy
         ] = schemas.ModelEndpointCreationStrategy.INPLACE,
-        inputs: Optional[list[str]] = None,
-        outputs: Optional[list[str]] = None,
-        input_path: Optional[str] = None,
-        result_path: Optional[str] = None,
         override: bool = False,
     ) -> None:
         """
@@ -1465,28 +1956,18 @@ class ModelRunnerStep(MonitoredStep):
         :param shared_model_name: str, the name of the shared model that is already defined within the graph
         :param labels: model endpoint labels, should be list of str or mapping of str:str
         :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
-            * **overwrite**:
-              1. If model endpoints with the same name exist, delete the `latest` one.
-              2. Create a new model endpoint entry and set it as `latest`.
-            * **inplace** (default):
-              1. If model endpoints with the same name exist, update the `latest` entry.
-              2. Otherwise, create a new entry.
-            * **archive**:
-              1. If model endpoints with the same name exist, preserve them.
-              2. Create a new model endpoint with the same name and set it to `latest`.

-        :param inputs: list of the model inputs (e.g. features), if provided will override the inputs
-                       that have been configured in the model artifact; please note that those inputs need to
-                       be equal in length and order to the inputs that the model_class predict method expects
-        :param outputs: list of the model outputs (e.g. labels), if provided will override the outputs
-                        that have been configured in the model artifact; please note that those outputs need to
-                        be equal to the model_class predict method outputs (length, and order)
-        :param input_path: input path inside the user event, expects scopes to be defined by dot notation
-                           (e.g. "inputs.my_model_inputs"); expects a list or dictionary type object in path.
-        :param result_path: result path inside the user output event, expects scopes to be defined by dot
-                            notation (e.g. "outputs.my_model_outputs"); expects a list or dictionary type object
-                            in path.
+            * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+              create a new model endpoint entry and set it as `latest`.
+
+            * **inplace** (default): If model endpoints with the same name exist, update the `latest` entry;
+              otherwise, create a new entry.
+
+            * **archive**: If model endpoints with the same name exist, preserve them;
+              create a new model endpoint with the same name and set it to `latest`.
+
         :param override: bool allow override existing model on the current ModelRunnerStep.
+        :raise GraphError: when the shared model is not found in the root flow step shared models.
         """
         model_class, model_params = (
             "mlrun.serving.Model",
@@ -1503,11 +1984,21 @@ class ModelRunnerStep(MonitoredStep):
                 "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
             )
         root = self._extract_root_step()
+        shared_model_params = {}
         if isinstance(root, RootFlowStep):
-            shared_model_name = (
-                shared_model_name
-                or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
+            actual_shared_model_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )
+            if not actual_shared_model_name or (
+                shared_model_name and actual_shared_model_name != shared_model_name
+            ):
+                raise GraphError(
+                    f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
+                    f"model {shared_model_name} is not in the shared models."
+                )
+            elif not shared_model_name:
+                shared_model_name = actual_shared_model_name
+            model_params["shared_runnable_name"] = shared_model_name
             if not root.shared_models or (
                 root.shared_models
                 and shared_model_name
@@ -1517,17 +2008,31 @@ class ModelRunnerStep(MonitoredStep):
  f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
  f"model {shared_model_name} is not in the shared models."
  )
- if shared_model_name not in self._shared_proxy_mapping:
+ monitoring_data = self.class_args.get(
+ schemas.ModelRunnerStepData.MONITORING_DATA, {}
+ )
+ monitoring_data.setdefault(endpoint_name, {})[
+ schemas.MonitoringData.MODEL_CLASS
+ ] = (
+ shared_model_class
+ if isinstance(shared_model_class, str)
+ else shared_model_class.__class__.__name__
+ )
+ self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = (
+ monitoring_data
+ )
+
+ if shared_model_name and shared_model_name not in self._shared_proxy_mapping:
  self._shared_proxy_mapping[shared_model_name] = {
  endpoint_name: model_artifact.uri
- if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+ if isinstance(model_artifact, ModelArtifact | LLMPromptArtifact)
  else model_artifact
  }
- else:
+ elif override and shared_model_name:
  self._shared_proxy_mapping[shared_model_name].update(
  {
  endpoint_name: model_artifact.uri
- if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+ if isinstance(model_artifact, ModelArtifact | LLMPromptArtifact)
  else model_artifact
  }
  )
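The resolution contract in the hunks above is easier to read outside the diff. A standalone sketch of what the new code enforces (plain Python mirroring the diff; the GraphError stand-in exists only to keep the sketch runnable, and root is any object exposing get_shared_model_by_artifact_uri):

    class GraphError(Exception):
        """Stand-in for mlrun's GraphError, only to make this sketch self-contained."""

    def resolve_shared_model(root, model_artifact_uri, shared_model_name=None):
        # New return shape: (name, model_class, model_params) or (None, None, None).
        actual_name, shared_class, shared_params = root.get_shared_model_by_artifact_uri(
            model_artifact_uri
        )
        # A proxy is valid only if the artifact is registered on the root flow step,
        # and an explicitly requested name must match the registered one.
        if not actual_name or (shared_model_name and actual_name != shared_model_name):
            raise GraphError(f"model {shared_model_name} is not in the shared models.")
        return shared_model_name or actual_name, shared_class, shared_params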
@@ -1538,11 +2043,11 @@ class ModelRunnerStep(MonitoredStep):
  model_artifact=model_artifact,
  labels=labels,
  model_endpoint_creation_strategy=model_endpoint_creation_strategy,
+ inputs=shared_model_params.get("inputs"),
+ outputs=shared_model_params.get("outputs"),
+ input_path=shared_model_params.get("input_path"),
+ result_path=shared_model_params.get("result_path"),
  override=override,
- inputs=inputs,
- outputs=outputs,
- input_path=input_path,
- result_path=result_path,
  **model_params,
  )

@@ -1567,48 +2072,52 @@ class ModelRunnerStep(MonitoredStep):
  Add a Model to this ModelRunner.

  :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
- :param model_class: Model class name
+ :param model_class: Model class name. If LLModel is chosen
+ (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+ outputs will be overridden with UsageResponseKeys fields.
  :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
- * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
- intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
- Lock (GIL).
- * "dedicated_process" To run in a separate dedicated process. This is appropriate for CPU or GPU intensive
- tasks that also require significant Runnable-specific initialization (e.g. a large model).
- * "thread_pool" To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
- otherwise block the main event loop thread.
- * "asyncio" To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
- event loop to continue running while waiting for a response.
- * "shared_executor" Reuses an external executor (typically managed by the flow or context) to execute the
- runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
- useful when:
- - You want to share a heavy resource like a large model loaded onto a GPU.
- - You want to centralize task scheduling or coordination for multiple lightweight tasks.
- - You aim to minimize overhead from creating new executors or processes/threads per runnable.
- The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
- memory and hardware accelerators.
- * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
- It means that the runnable will not actually be run in parallel to anything else.
-
- :param model_artifact: model artifact or mlrun model artifact uri
- :param labels: model endpoint labels, should be list of str or mapping of str:str
- :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
- * **overwrite**:
- 1. If model endpoints with the same name exist, delete the `latest` one.
- 2. Create a new model endpoint entry and set it as `latest`.
- * **inplace** (default):
- 1. If model endpoints with the same name exist, update the `latest` entry.
- 2. Otherwise, create a new entry.
- * **archive**:
- 1. If model endpoints with the same name exist, preserve them.
- 2. Create a new model endpoint with the same name and set it to `latest`.
-
- :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
+
+ * **process_pool**: To run in a separate process from a process pool. This is appropriate
+ for CPU or GPU intensive tasks as they would otherwise block the main process by holding
+ Python's Global Interpreter Lock (GIL).
+
+ * **dedicated_process**: To run in a separate dedicated process. This is appropriate for CPU
+ or GPU intensive tasks that also require significant Runnable-specific initialization
+ (e.g. a large model).
+
+ * **thread_pool**: To run in a separate thread. This is appropriate for blocking I/O tasks,
+ as they would otherwise block the main event loop thread.
+
+ * **asyncio**: To run in an asyncio task. This is appropriate for I/O tasks that use
+ asyncio, allowing the event loop to continue running while waiting for a response.
+
+ * **naive**: To run in the main event loop. This is appropriate only for trivial computation
+ and/or file I/O. It means that the runnable will not actually be run in parallel to
+ anything else.
+
+ :param model_artifact: model artifact or mlrun model artifact uri
+ :param labels: model endpoint labels, should be list of str or mapping of str:str
+ :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
+
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+ create a new model endpoint entry and set it as `latest`.
+
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
+ entry; otherwise, create a new entry.
+
+ * **archive**: If model endpoints with the same name exist, preserve them;
+ create a new model endpoint with the same name and set it to `latest`.
+
+ :param inputs: list of the model inputs (e.g. features), if provided will override the inputs
  that have been configured in the model artifact, please note that those inputs need to
  be equal in length and order to the inputs that the model_class predict method expects
- :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
+ :param outputs: list of the model outputs (e.g. labels), if provided will override the outputs
  that have been configured in the model artifact, please note that those outputs need to
  be equal to the model_class predict method outputs (length and order)
- :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
+
+ When using LLModel, the output will be overridden with UsageResponseKeys.fields().
+
+ :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
  this requires that the event body behaves like a dict; scopes are expected to be
  defined by dot notation (e.g "data.d").
  examples: input_path="data.b"
@@ -1618,7 +2127,7 @@ class ModelRunnerStep(MonitoredStep):
  be {"f0": [1, 2]}.
  if a ``list`` or ``list of lists`` is provided, it must follow the order and
  size defined by the input schema.
- :param result_path: when specified selects the key/path in the output event to use as model monitoring
+ :param result_path: when specified selects the key/path in the output event to use as model monitoring
  outputs; this requires that the output event body behaves like a dict,
  with scopes defined by dot notation (e.g "data.d").
  examples: result_path="out.b"
@@ -1629,14 +2138,22 @@ class ModelRunnerStep(MonitoredStep):
  if a ``list`` or ``list of lists`` is provided, it must follow the order and
  size defined by the output schema.

- :param override: bool allow override existing model on the current ModelRunnerStep.
- :param model_parameters: Parameters for model instantiation
+ :param override: bool allow override existing model on the current ModelRunnerStep.
+ :param model_parameters: Parameters for model instantiation
  """
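A concrete event shape makes the input_path/result_path dot notation documented above easier to picture (values are illustrative, taken from the docstring's own examples):

    event_body = {"data": {"b": {"f0": [1, 2]}}, "out": {}}
    # input_path="data.b"  -> monitoring reads event_body["data"]["b"], i.e. {"f0": [1, 2]}
    # result_path="out.b"  -> the model's result is monitored under event_body["out"]["b"]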
  if isinstance(model_class, Model) and model_parameters:
  raise mlrun.errors.MLRunInvalidArgumentError(
  "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
  )
-
+ if type(model_class) is LLModel or (
+ isinstance(model_class, str)
+ and model_class.split(".")[-1] == LLModel.__name__
+ ):
+ if outputs:
+ warnings.warn(
+ "LLModel with existing outputs detected, overriding to default"
+ )
+ outputs = UsageResponseKeys.fields()
  model_parameters = model_parameters or (
  model_class.to_dict() if isinstance(model_class, Model) else {}
  )
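The LLModel branch above matches either an LLModel instance or any class-path string whose last segment is the class name, then forces the outputs to the usage schema. A self-contained sketch of the same rule (the stand-in class and field names are assumptions, not mlrun's real ones):

    import warnings

    class LLModel:
        """Stand-in for mlrun.serving.states.LLModel."""

    USAGE_FIELDS = ["prompt_tokens", "completion_tokens", "total_tokens"]  # assumed fields

    def normalize_outputs(model_class, outputs):
        is_llmodel = type(model_class) is LLModel or (
            isinstance(model_class, str)
            and model_class.split(".")[-1] == LLModel.__name__
        )
        if is_llmodel:
            if outputs:
                warnings.warn("LLModel with existing outputs detected, overriding to default")
            outputs = list(USAGE_FIELDS)
        return outputs

    assert normalize_outputs("mlrun.serving.states.LLModel", ["label"]) == USAGE_FIELDS
    assert normalize_outputs("my.pkg.CustomModel", ["label"]) == ["label"]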
@@ -1652,8 +2169,6 @@ class ModelRunnerStep(MonitoredStep):
  except mlrun.errors.MLRunNotFoundError:
  raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.")

- outputs = outputs or self._get_model_output_schema(model_artifact)
-
  model_artifact = (
  model_artifact.uri
  if isinstance(model_artifact, mlrun.artifacts.Artifact)
@@ -1719,28 +2234,13 @@ class ModelRunnerStep(MonitoredStep):
  self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data

  @staticmethod
- def _get_model_output_schema(
- model_artifact: Union[ModelArtifact, LLMPromptArtifact],
- ) -> Optional[list[str]]:
- if isinstance(
- model_artifact,
- ModelArtifact,
- ):
- return [feature.name for feature in model_artifact.spec.outputs]
- elif isinstance(
- model_artifact,
- LLMPromptArtifact,
- ):
- _model_artifact = model_artifact.model_artifact
- return [feature.name for feature in _model_artifact.spec.outputs]
-
- @staticmethod
- def _get_model_endpoint_output_schema(
+ def _get_model_endpoint_schema(
  name: str,
  project: str,
  uid: str,
- ) -> list[str]:
+ ) -> tuple[list[str], list[str]]:
  output_schema = None
+ input_schema = None
  try:
  model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = (
  mlrun.db.get_run_db().get_model_endpoint(
@@ -1751,6 +2251,7 @@ class ModelRunnerStep(MonitoredStep):
  )
  )
  output_schema = model_endpoint.spec.label_names
+ input_schema = model_endpoint.spec.feature_names
  except (
  mlrun.errors.MLRunNotFoundError,
  mlrun.errors.MLRunInvalidArgumentError,
@@ -1759,7 +2260,7 @@ class ModelRunnerStep(MonitoredStep):
  f"Model endpoint not found, using default output schema for model {name}",
  error=f"{type(ex).__name__}: {ex}",
  )
- return output_schema
+ return input_schema, output_schema
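Callers unpack the new pair and fall back per side, as _calculate_monitoring_data does below. A condensed, runnable sketch of the fallback rule (names shortened for clarity):

    def merge_schemas(explicit_inputs, explicit_outputs, mep_inputs, mep_outputs):
        # Per-side fallback: an explicitly configured schema wins over the
        # schema recovered from the model endpoint.
        return explicit_inputs or mep_inputs, explicit_outputs or mep_outputs

    assert merge_schemas(["f0"], None, ["a"], ["b"]) == (["f0"], ["b"])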

  def _calculate_monitoring_data(self) -> dict[str, dict[str, str]]:
  monitoring_data = deepcopy(
@@ -1775,47 +2276,154 @@ class ModelRunnerStep(MonitoredStep):
  monitoring_data[model][schemas.MonitoringData.RESULT_PATH] = split_path(
  monitoring_data[model][schemas.MonitoringData.RESULT_PATH]
  )
+
+ mep_output_schema, mep_input_schema = None, None
+
+ output_schema = self.class_args[
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+ ][model][schemas.MonitoringData.OUTPUTS]
+ input_schema = self.class_args[
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+ ][model][schemas.MonitoringData.INPUTS]
+ if not output_schema or not input_schema:
+ # if output or input schema is not provided, try to get it from the model endpoint
+ mep_input_schema, mep_output_schema = (
+ self._get_model_endpoint_schema(
+ model,
+ self.context.project,
+ monitoring_data[model].get(
+ schemas.MonitoringData.MODEL_ENDPOINT_UID, ""
+ ),
+ )
+ )
+ self.class_args[
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+ ][model][schemas.MonitoringData.OUTPUTS] = (
+ output_schema or mep_output_schema
+ )
+ self.class_args[
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA
+ ][model][schemas.MonitoringData.INPUTS] = (
+ input_schema or mep_input_schema
+ )
  return monitoring_data
  else:
  raise mlrun.errors.MLRunInvalidArgumentError(
  "Monitoring data must be a dictionary."
  )

+ def configure_pool_resource(
+ self,
+ max_processes: Optional[int] = None,
+ max_threads: Optional[int] = None,
+ pool_factor: Optional[int] = None,
+ ) -> None:
+ """
+ Configure the resource limits for the models in this ModelRunnerStep.
+
+ :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
+ Defaults to the number of CPUs or 16 if undetectable.
+ :param max_threads: Maximum number of threads to spawn. Defaults to 32.
+ :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
+ """
+ self.max_processes = max_processes
+ self.max_threads = max_threads
+ self.pool_factor = pool_factor
+
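A hedged usage sketch of this new per-step knob (runner_step stands for a ModelRunnerStep already added to a graph):

    # Cap this runner's worker pools before deployment; unset values fall back to
    # the documented defaults (CPU count or 16 processes, 32 threads, factor 1).
    runner_step.configure_pool_resource(
        max_processes=4,
        max_threads=16,
        pool_factor=2,  # scales process/thread workers per runnable
    )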
  def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
  self.context = context
  if not self._is_local_function(context):
  # skip init of non local functions
  return
- model_selector = self.class_args.get("model_selector")
+ model_selector, model_selector_params = self.class_args.get(
+ "model_selector", (None, None)
+ )
+ model_runner_selector, model_runner_selector_params = self.class_args.get(
+ "model_runner_selector", (None, None)
+ )
  execution_mechanism_by_model_name = self.class_args.get(
  schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
  )
  models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
- if isinstance(model_selector, str):
- model_selector = get_class(model_selector, namespace)()
+ if model_selector:
+ model_selector = get_class(model_selector, namespace).from_dict(
+ model_selector_params, init_with_params=True
+ )
+ model_runner_selector = (
+ self._convert_model_selector_to_model_runner_selector(
+ model_selector=model_selector
+ )
+ )
+ elif model_runner_selector:
+ model_runner_selector = get_class(
+ model_runner_selector, namespace
+ ).from_dict(model_runner_selector_params, init_with_params=True)
  model_objects = []
  for model, model_params in models.values():
+ model_name = model_params.get("name")
  model_params[schemas.MonitoringData.INPUT_PATH] = (
  self.class_args.get(
  mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
  )
- .get(model_params.get("name"), {})
+ .get(model_name, {})
  .get(schemas.MonitoringData.INPUT_PATH)
  )
+ model_params[schemas.MonitoringData.RESULT_PATH] = (
+ self.class_args.get(
+ mlrun.common.schemas.ModelRunnerStepData.MONITORING_DATA, {}
+ )
+ .get(model_name, {})
+ .get(schemas.MonitoringData.RESULT_PATH)
+ )
  model = get_class(model, namespace).from_dict(
  model_params, init_with_params=True
  )
  model._raise_exception = False
+ model._execution_mechanism = execution_mechanism_by_model_name.get(
+ model_name
+ )
  model_objects.append(model)
  self._async_object = ModelRunner(
- model_selector=model_selector,
+ model_runner_selector=model_runner_selector,
  runnables=model_objects,
  execution_mechanism_by_runnable_name=execution_mechanism_by_model_name,
  shared_proxy_mapping=self._shared_proxy_mapping or None,
  name=self.name,
  context=context,
+ max_processes=self.max_processes,
+ max_threads=self.max_threads,
+ pool_factor=self.pool_factor,
+ raise_exception=self.raise_exception,
+ **extra_kwargs,
  )

+ def _convert_model_selector_to_model_runner_selector(
+ self,
+ model_selector,
+ ) -> "ModelRunnerSelector":
+ """
+ Wrap a ModelSelector into a ModelRunnerSelector for backward compatibility.
+ """
+
+ class Adapter(ModelRunnerSelector):
+ def __init__(self):
+ self.selector = model_selector
+
+ def select_models(
+ self, event, available_models
+ ) -> Union[list[str], list[Model]]:
+ # Call old ModelSelector logic
+ return self.selector.select(event, available_models)
+
+ def select_outlets(
+ self,
+ event,
+ ) -> Optional[list[str]]:
+ # By default, return all outlets (old ModelSelector didn't control routing)
+ return None
+
+ return Adapter()
+
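The Adapter above keeps legacy ModelSelector subclasses working; new code can implement ModelRunnerSelector directly. A minimal sketch (the selector name and routing rule are invented for illustration; ModelRunnerSelector is the base class defined in this module):

    class ByEventKeySelector(ModelRunnerSelector):
        """Route each event to the model named in its body, else to all models."""

        def select_models(self, event, available_models):
            wanted = event.body.get("model") if isinstance(event.body, dict) else None
            return [wanted] if wanted else available_models

        def select_outlets(self, event):
            # None keeps the default behavior: forward to all outlets.
            return None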

  class ModelRunnerErrorRaiser(storey.MapClass):
  def __init__(self, raise_exception: bool, models_names: list[str], **kwargs):
@@ -1828,11 +2436,15 @@ class ModelRunnerErrorRaiser(storey.MapClass):
  errors = {}
  should_raise = False
  if len(self._models_names) == 1:
- should_raise = event.body.get("error") is not None
- errors[self._models_names[0]] = event.body.get("error")
+ if isinstance(event.body, dict):
+ should_raise = event.body.get("error") is not None
+ errors[self._models_names[0]] = event.body.get("error")
  else:
  for model in event.body:
- errors[model] = event.body.get(model).get("error")
+ body_by_model = event.body.get(model)
+ errors[model] = None
+ if isinstance(body_by_model, dict):
+ errors[model] = body_by_model.get("error")
  if errors[model] is not None:
  should_raise = True
  if should_raise:
@@ -1902,6 +2514,8 @@ class QueueStep(BaseStep, StepToDict):
  model_endpoint_creation_strategy: Optional[
  schemas.ModelEndpointCreationStrategy
  ] = None,
+ cycle_to: Optional[list[str]] = None,
+ max_iterations: Optional[int] = None,
  **class_args,
  ):
  if not function:
@@ -1919,6 +2533,8 @@ class QueueStep(BaseStep, StepToDict):
  input_path,
  result_path,
  model_endpoint_creation_strategy,
+ cycle_to,
+ max_iterations,
  **class_args,
  )

@@ -1954,8 +2570,10 @@ class FlowStep(BaseStep):
  after: Optional[list] = None,
  engine=None,
  final_step=None,
+ allow_cyclic: bool = False,
+ max_iterations: Optional[int] = None,
  ):
- super().__init__(name, after)
+ super().__init__(name, after, max_iterations=max_iterations)
  self._steps = None
  self.steps = steps
  self.engine = engine
@@ -1967,6 +2585,7 @@ class FlowStep(BaseStep):
  self._wait_for_result = False
  self._source = None
  self._start_steps = []
+ self._allow_cyclic = allow_cyclic

  def get_children(self):
  return self._steps.values()
@@ -2000,6 +2619,8 @@ class FlowStep(BaseStep):
  model_endpoint_creation_strategy: Optional[
  schemas.ModelEndpointCreationStrategy
  ] = None,
+ cycle_to: Optional[list[str]] = None,
+ max_iterations: Optional[int] = None,
  **class_args,
  ):
  """add task, queue or router step/class to the flow
@@ -2033,21 +2654,17 @@ class FlowStep(BaseStep):
  to event["y"] resulting in {"x": 5, "y": <result>}
  :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:

- * **overwrite**:
-
- 1. If model endpoints with the same name exist, delete the `latest` one.
- 2. Create a new model endpoint entry and set it as `latest`.
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
+ create a new model endpoint entry and set it as `latest`.

- * **inplace** (default):
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
+ entry; otherwise, create a new entry.

- 1. If model endpoints with the same name exist, update the `latest` entry.
- 2. Otherwise, create a new entry.
-
- * **archive**:
-
- 1. If model endpoints with the same name exist, preserve them.
- 2. Create a new model endpoint with the same name and set it to `latest`.
+ * **archive**: If model endpoints with the same name exist, preserve them;
+ create a new model endpoint with the same name and set it to `latest`.

+ :param cycle_to: list of step names to create a cycle to (for cyclic graphs)
+ :param max_iterations: maximum number of iterations for this step in case of a cyclic graph
  :param class_args: class init arguments
  """
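A hedged sketch of a feedback loop built with these parameters (step classes and names are invented; whether set_topology forwards allow_cyclic to the root flow step is an assumption based on the RootFlowStep changes later in this diff):

    graph = fn.set_topology("flow", allow_cyclic=True)  # assumed kwarg pass-through
    graph.add_step("GenerateStep", name="generate")
    graph.add_step("ValidateStep", name="validate", after="generate")
    # Loop back to "generate" on retry; per this diff, max_iterations caps how
    # often an event may revisit a step in a cyclic graph.
    graph.add_step(
        "RetryStep",
        name="retry",
        after="validate",
        cycle_to=["generate"],
        max_iterations=5,
    )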
@@ -2073,6 +2690,8 @@ class FlowStep(BaseStep):
  after_list = after if isinstance(after, list) else [after]
  for after in after_list:
  self.insert_step(name, step, after, before)
+ step.cycle_to(cycle_to or [])
+ step._max_iterations = max_iterations
  return step

  def insert_step(self, key, step, after, before=None):
@@ -2165,13 +2784,24 @@ class FlowStep(BaseStep):
  for step in self._steps.values():
  step._next = None
  step._visited = False
- if step.after:
+ if step.after and not step.cycle_from:
+ has_illegal_branches = len(step.after) > 1 and self.engine == "sync"
+ if has_illegal_branches:
+ raise GraphError(
+ f"synchronous flow engine doesn't support branches, use async for step {step.name}"
+ )
  loop_step = has_loop(step, [])
- if loop_step:
+ if loop_step and not self.allow_cyclic:
  raise GraphError(
  f"Error, loop detected in step {loop_step}, graph must be acyclic (DAG)"
  )
- else:
+ elif (
+ step.after
+ and step.cycle_from
+ and set(step.after) == set(step.cycle_from)
+ ):
+ start_steps.append(step.name)
+ elif not step.cycle_from:
  start_steps.append(step.name)

  responders = []
@@ -2268,6 +2898,9 @@ class FlowStep(BaseStep):
  def process_step(state, step, root):
  if not state._is_local_function(self.context) or state._visited:
  return
+ state._visited = (
+ True # mark visited to avoid re-visit in case of multiple uplinks
+ )
  for item in state.next or []:
  next_state = root[item]
  if next_state.async_object:
@@ -2278,7 +2911,7 @@ class FlowStep(BaseStep):
  )

  default_source, self._wait_for_result = _init_async_objects(
- self.context, self._steps.values()
+ self.context, self._steps.values(), self
  )

  source = self._source or default_source
@@ -2509,6 +3142,8 @@ class RootFlowStep(FlowStep):
  "shared_models",
  "shared_models_mechanism",
  "pool_factor",
+ "allow_cyclic",
+ "max_iterations",
  ]

  def __init__(
@@ -2518,13 +3153,11 @@ class RootFlowStep(FlowStep):
  after: Optional[list] = None,
  engine=None,
  final_step=None,
+ allow_cyclic: bool = False,
+ max_iterations: Optional[int] = 10_000,
  ):
  super().__init__(
- name,
- steps,
- after,
- engine,
- final_step,
+ name, steps, after, engine, final_step, allow_cyclic, max_iterations
  )
  self._models = set()
  self._route_models = set()
@@ -2535,48 +3168,102 @@ class RootFlowStep(FlowStep):
  self._shared_max_threads = None
  self._pool_factor = None

+ @property
+ def max_iterations(self) -> int:
+ return self._max_iterations
+
+ @max_iterations.setter
+ def max_iterations(self, max_iterations: int):
+ self._max_iterations = max_iterations
+
+ @property
+ def allow_cyclic(self) -> bool:
+ return self._allow_cyclic
+
+ @allow_cyclic.setter
+ def allow_cyclic(self, allow_cyclic: bool):
+ self._allow_cyclic = allow_cyclic
+
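The new properties also allow toggling cyclic execution after construction; a short sketch (graph stands for the RootFlowStep, e.g. the topology object of a serving function):

    graph.allow_cyclic = True     # permit cycles at graph-validation time
    graph.max_iterations = 1_000  # graph-wide cap; steps may override it per-step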
  def add_shared_model(
  self,
  name: str,
  model_class: Union[str, Model],
  execution_mechanism: Union[str, ParallelExecutionMechanisms],
  model_artifact: Union[str, ModelArtifact],
+ inputs: Optional[list[str]] = None,
+ outputs: Optional[list[str]] = None,
+ input_path: Optional[str] = None,
+ result_path: Optional[str] = None,
  override: bool = False,
  **model_parameters,
  ) -> None:
  """
  Add a shared model to the graph; this model will be available to all the ModelRunners in the graph
  :param name: Name of the shared model (should be unique in the graph)
- :param model_class: Model class name
+ :param model_class: Model class name. If LLModel is chosen
+ (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
+ outputs will be overridden with UsageResponseKeys fields.
  :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
- * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
+
+ * **process_pool**: To run in a separate process from a process pool. This is appropriate for CPU or GPU
  intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
  Lock (GIL).
- * "dedicated_process" – To run in a separate dedicated process. This is appropriate for CPU or GPU intensive
- tasks that also require significant Runnable-specific initialization (e.g. a large model).
- * "thread_pool" To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
+
+ * **dedicated_process**: To run in a separate dedicated process. This is appropriate for CPU or GPU
+ intensive tasks that also require significant Runnable-specific initialization (e.g. a large model).
+
+ * **thread_pool**: To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
  otherwise block the main event loop thread.
- * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
+
+ * **asyncio**: To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
  event loop to continue running while waiting for a response.
- * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
- runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
- useful when:
+
+ * **shared_executor**: Reuses an external executor (typically managed by the flow or context) to execute
+ the runnable. Should be used only if you have multiple `ParallelExecution` in the same flow and
+ especially useful when:
+
  - You want to share a heavy resource like a large model loaded onto a GPU.
+
  - You want to centralize task scheduling or coordination for multiple lightweight tasks.
+
  - You aim to minimize overhead from creating new executors or processes/threads per runnable.
+
  The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
  memory and hardware accelerators.
- * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
- It means that the runnable will not actually be run in parallel to anything else.

- :param model_artifact: model artifact or mlrun model artifact uri
- :param override: bool allow override existing model on the current ModelRunnerStep.
- :param model_parameters: Parameters for model instantiation
+ * **naive**: To run in the main event loop. This is appropriate only for trivial computation and/or file
+ I/O. It means that the runnable will not actually be run in parallel to anything else.
+
+ :param model_artifact: model artifact or mlrun model artifact uri
+ :param inputs: list of the model inputs (e.g. features), if provided will override the inputs
+ that have been configured in the model artifact, please note that those inputs need
+ to be equal in length and order to the inputs that the model_class
+ predict method expects
+ :param outputs: list of the model outputs (e.g. labels), if provided will override the outputs
+ that have been configured in the model artifact, please note that those outputs need
+ to be equal to the model_class
+ predict method outputs (length and order)
+ :param input_path: input path inside the user event, expects scopes to be defined by dot notation
+ (e.g "inputs.my_model_inputs"); expects a list or dictionary type object in the path.
+ :param result_path: result path inside the user output event, expects scopes to be defined by dot
+ notation (e.g "outputs.my_model_outputs"); expects a list or dictionary type object
+ in the path.
+ :param override: bool allow override existing model on the current ModelRunnerStep.
+ :param model_parameters: Parameters for model instantiation
  """
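A hedged usage sketch of the extended signature (the artifact URI and schema names are placeholders):

    graph.add_shared_model(
        name="sentiment",
        model_class="mlrun.serving.Model",
        execution_mechanism="process_pool",
        model_artifact="store://models/myproj/sentiment#0:latest",  # placeholder URI
        inputs=["text"],                       # overrides the artifact's input schema
        outputs=["label", "score"],            # must match predict() length and order
        input_path="inputs.my_model_inputs",   # dot-notation scope inside the event
        result_path="outputs.my_model_outputs",
    )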
  if isinstance(model_class, Model) and model_parameters:
  raise mlrun.errors.MLRunInvalidArgumentError(
  "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
  )
+ if type(model_class) is LLModel or (
+ isinstance(model_class, str)
+ and model_class.split(".")[-1] == LLModel.__name__
+ ):
+ if outputs:
+ warnings.warn(
+ "LLModel with existing outputs detected, overriding to default"
+ )
+ outputs = UsageResponseKeys.fields()

  if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
  raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2604,6 +3291,14 @@ class RootFlowStep(FlowStep):
  "Inconsistent name for the added model."
  )
  model_parameters["name"] = name
+ model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
+ model_parameters["outputs"] = outputs or model_parameters.get("outputs", [])
+ model_parameters["input_path"] = input_path or model_parameters.get(
+ "input_path"
+ )
+ model_parameters["result_path"] = result_path or model_parameters.get(
+ "result_path"
+ )

  if name in self.shared_models and not override:
  raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2618,7 +3313,9 @@ class RootFlowStep(FlowStep):
  self.shared_models[name] = (model_class, model_parameters)
  self.shared_models_mechanism[name] = execution_mechanism

- def get_shared_model_name_by_artifact_uri(self, artifact_uri: str) -> Optional[str]:
+ def get_shared_model_by_artifact_uri(
+ self, artifact_uri: str
+ ) -> Union[tuple[str, str, dict], tuple[None, None, None]]:
  """
  Get a shared model by its artifact URI.
  :param artifact_uri: The artifact URI of the model.
@@ -2626,10 +3323,10 @@ class RootFlowStep(FlowStep):
  """
  for model_name, (model_class, model_params) in self.shared_models.items():
  if model_params.get("artifact_uri") == artifact_uri:
- return model_name
- return None
+ return model_name, model_class, model_params
+ return None, None, None

- def config_pool_resource(
+ def configure_shared_pool_resource(
  self,
  max_processes: Optional[int] = None,
  max_threads: Optional[int] = None,
@@ -2637,8 +3334,9 @@ class RootFlowStep(FlowStep):
  ) -> None:
  """
  Configure the resource limits for the shared models in the graph.
+
  :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
- Defaults to the number of CPUs or 16 if undetectable.
+ Defaults to the number of CPUs or 16 if undetectable.
  :param max_threads: Maximum number of threads to spawn. Defaults to 32.
  :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
  """
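The renamed helper is invoked like the per-step variant shown earlier; a one-call sketch (graph stands for the RootFlowStep):

    graph.configure_shared_pool_resource(max_processes=8, max_threads=32, pool_factor=1)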
@@ -2677,6 +3375,7 @@ class RootFlowStep(FlowStep):
  model_params, init_with_params=True
  )
  model._raise_exception = False
+ model._execution_mechanism = self._shared_models_mechanism[model.name]
  self.context.executor.add_runnable(
  model, self._shared_models_mechanism[model.name]
  )
@@ -2796,12 +3495,10 @@ def _add_graphviz_router(graph, step, source=None, **kwargs):
  graph.edge(step.fullname, route.fullname)


- def _add_graphviz_model_runner(graph, step, source=None):
+ def _add_graphviz_model_runner(graph, step, source=None, is_monitored=False):
  if source:
  graph.node("_start", source.name, shape=source.shape, style="filled")
  graph.edge("_start", step.fullname)
-
- is_monitored = step._extract_root_step().track_models
  m_cell = '<FONT POINT-SIZE="9">🄼</FONT>' if is_monitored else ""

  number_of_models = len(
@@ -2840,6 +3537,7 @@ def _add_graphviz_flow(
  allow_empty=True
  )
  graph.node("_start", source.name, shape=source.shape, style="filled")
+ is_monitored = step.track_models if isinstance(step, RootFlowStep) else False
  for start_step in start_steps:
  graph.edge("_start", start_step.fullname)
  for child in step.get_children():
@@ -2848,7 +3546,7 @@ def _add_graphviz_flow(
  with graph.subgraph(name="cluster_" + child.fullname) as sg:
  _add_graphviz_router(sg, child)
  elif kind == StepKinds.model_runner:
- _add_graphviz_model_runner(graph, child)
+ _add_graphviz_model_runner(graph, child, is_monitored=is_monitored)
  else:
  graph.node(child.fullname, label=child.name, shape=child.get_shape())
  _add_edges(child.after or [], step, graph, child)
@@ -3034,7 +3732,7 @@ def params_to_step(
  return name, step


- def _init_async_objects(context, steps):
+ def _init_async_objects(context, steps, root):
  try:
  import storey
  except ImportError:
@@ -3049,6 +3747,7 @@ def _init_async_objects(context, steps):

  for step in steps:
  if hasattr(step, "async_object") and step._is_local_function(context):
+ max_iterations = step._max_iterations or root.max_iterations
  if step.kind == StepKinds.queue:
  skip_stream = context.is_mock and step.next
  if step.path and not skip_stream:
@@ -3067,23 +3766,25 @@ def _init_async_objects(context, steps):
  datastore_profile = datastore_profile_read(stream_path)
  if isinstance(
  datastore_profile,
- (DatastoreProfileKafkaTarget, DatastoreProfileKafkaSource),
+ DatastoreProfileKafkaTarget | DatastoreProfileKafkaStream,
  ):
  step._async_object = KafkaStoreyTarget(
  path=stream_path,
  context=context,
+ max_iterations=max_iterations,
  **options,
  )
  elif isinstance(datastore_profile, DatastoreProfileV3io):
  step._async_object = StreamStoreyTarget(
  stream_path=stream_path,
  context=context,
+ max_iterations=max_iterations,
  **options,
  )
  else:
  raise mlrun.errors.MLRunValueError(
  f"Received an unexpected stream profile type: {type(datastore_profile)}\n"
- "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaSource`."
+ "Expects `DatastoreProfileV3io` or `DatastoreProfileKafkaStream`."
  )
  elif stream_path.startswith("kafka://") or kafka_brokers:
  topic, brokers = parse_kafka_url(stream_path, kafka_brokers)
@@ -3097,6 +3798,13 @@ def _init_async_objects(context, steps):
  brokers=brokers,
  producer_options=kafka_producer_options,
  context=context,
+ max_iterations=max_iterations,
+ **options,
+ )
+ elif stream_path.startswith("dummy://"):
+ step._async_object = _DummyStream(
+ context=context,
+ max_iterations=max_iterations,
  **options,
  )
  else:
@@ -3107,10 +3815,14 @@ def _init_async_objects(context, steps):
  storey.V3ioDriver(endpoint or config.v3io_api),
  stream_path,
  context=context,
+ max_iterations=max_iterations,
  **options,
  )
  else:
- step._async_object = storey.Map(lambda x: x)
+ step._async_object = storey.Map(
+ lambda x: x,
+ max_iterations=max_iterations,
+ )

  elif not step.async_object or not hasattr(step.async_object, "_outlets"):
  # if regular class, wrap with storey Map
@@ -3122,6 +3834,8 @@ def _init_async_objects(context, steps):
  name=step.name,
  context=context,
  pass_context=step._inject_context,
+ fn_select_outlets=step._outlets_selector,
+ max_iterations=max_iterations,
  )
  if (
  respond_supported