mlrun 1.10.0rc40__py3-none-any.whl → 1.11.0rc16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; see the details below.

Files changed (150)
  1. mlrun/__init__.py +3 -2
  2. mlrun/__main__.py +0 -4
  3. mlrun/artifacts/dataset.py +2 -2
  4. mlrun/artifacts/plots.py +1 -1
  5. mlrun/{model_monitoring/db/tsdb/tdengine → auth}/__init__.py +2 -3
  6. mlrun/auth/nuclio.py +89 -0
  7. mlrun/auth/providers.py +429 -0
  8. mlrun/auth/utils.py +415 -0
  9. mlrun/common/constants.py +7 -0
  10. mlrun/common/model_monitoring/helpers.py +41 -4
  11. mlrun/common/runtimes/constants.py +28 -0
  12. mlrun/common/schemas/__init__.py +13 -3
  13. mlrun/common/schemas/alert.py +2 -2
  14. mlrun/common/schemas/api_gateway.py +3 -0
  15. mlrun/common/schemas/auth.py +10 -10
  16. mlrun/common/schemas/client_spec.py +4 -0
  17. mlrun/common/schemas/constants.py +25 -0
  18. mlrun/common/schemas/frontend_spec.py +1 -8
  19. mlrun/common/schemas/function.py +24 -0
  20. mlrun/common/schemas/hub.py +3 -2
  21. mlrun/common/schemas/model_monitoring/__init__.py +1 -1
  22. mlrun/common/schemas/model_monitoring/constants.py +2 -2
  23. mlrun/common/schemas/secret.py +17 -2
  24. mlrun/common/secrets.py +95 -1
  25. mlrun/common/types.py +10 -10
  26. mlrun/config.py +53 -15
  27. mlrun/data_types/infer.py +2 -2
  28. mlrun/datastore/__init__.py +2 -3
  29. mlrun/datastore/base.py +274 -10
  30. mlrun/datastore/datastore.py +1 -1
  31. mlrun/datastore/datastore_profile.py +49 -17
  32. mlrun/datastore/model_provider/huggingface_provider.py +6 -2
  33. mlrun/datastore/model_provider/model_provider.py +2 -2
  34. mlrun/datastore/model_provider/openai_provider.py +2 -2
  35. mlrun/datastore/s3.py +15 -16
  36. mlrun/datastore/sources.py +1 -1
  37. mlrun/datastore/store_resources.py +4 -4
  38. mlrun/datastore/storeytargets.py +16 -10
  39. mlrun/datastore/targets.py +1 -1
  40. mlrun/datastore/utils.py +16 -3
  41. mlrun/datastore/v3io.py +1 -1
  42. mlrun/db/base.py +36 -12
  43. mlrun/db/httpdb.py +316 -101
  44. mlrun/db/nopdb.py +29 -11
  45. mlrun/errors.py +4 -2
  46. mlrun/execution.py +11 -12
  47. mlrun/feature_store/api.py +1 -1
  48. mlrun/feature_store/common.py +1 -1
  49. mlrun/feature_store/feature_vector_utils.py +1 -1
  50. mlrun/feature_store/steps.py +8 -6
  51. mlrun/frameworks/_common/utils.py +3 -3
  52. mlrun/frameworks/_dl_common/loggers/logger.py +1 -1
  53. mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +2 -1
  54. mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +1 -1
  55. mlrun/frameworks/_ml_common/utils.py +2 -1
  56. mlrun/frameworks/auto_mlrun/auto_mlrun.py +4 -3
  57. mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +2 -1
  58. mlrun/frameworks/onnx/dataset.py +2 -1
  59. mlrun/frameworks/onnx/mlrun_interface.py +2 -1
  60. mlrun/frameworks/pytorch/callbacks/logging_callback.py +5 -4
  61. mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +2 -1
  62. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +2 -1
  63. mlrun/frameworks/pytorch/utils.py +2 -1
  64. mlrun/frameworks/sklearn/metric.py +2 -1
  65. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +5 -4
  66. mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +2 -1
  67. mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +2 -1
  68. mlrun/hub/__init__.py +37 -0
  69. mlrun/hub/base.py +142 -0
  70. mlrun/hub/module.py +67 -76
  71. mlrun/hub/step.py +113 -0
  72. mlrun/launcher/base.py +2 -1
  73. mlrun/launcher/local.py +2 -1
  74. mlrun/model.py +12 -2
  75. mlrun/model_monitoring/__init__.py +0 -1
  76. mlrun/model_monitoring/api.py +2 -2
  77. mlrun/model_monitoring/applications/base.py +20 -6
  78. mlrun/model_monitoring/applications/context.py +1 -0
  79. mlrun/model_monitoring/controller.py +7 -17
  80. mlrun/model_monitoring/db/_schedules.py +2 -16
  81. mlrun/model_monitoring/db/_stats.py +2 -13
  82. mlrun/model_monitoring/db/tsdb/__init__.py +9 -7
  83. mlrun/model_monitoring/db/tsdb/base.py +2 -4
  84. mlrun/model_monitoring/db/tsdb/preaggregate.py +234 -0
  85. mlrun/model_monitoring/db/tsdb/stream_graph_steps.py +63 -0
  86. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_metrics_queries.py +414 -0
  87. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_predictions_queries.py +376 -0
  88. mlrun/model_monitoring/db/tsdb/timescaledb/queries/timescaledb_results_queries.py +590 -0
  89. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connection.py +434 -0
  90. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_connector.py +541 -0
  91. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_operations.py +808 -0
  92. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_schema.py +502 -0
  93. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream.py +163 -0
  94. mlrun/model_monitoring/db/tsdb/timescaledb/timescaledb_stream_graph_steps.py +60 -0
  95. mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_dataframe_processor.py +141 -0
  96. mlrun/model_monitoring/db/tsdb/timescaledb/utils/timescaledb_query_builder.py +585 -0
  97. mlrun/model_monitoring/db/tsdb/timescaledb/writer_graph_steps.py +73 -0
  98. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +4 -6
  99. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +147 -79
  100. mlrun/model_monitoring/features_drift_table.py +2 -1
  101. mlrun/model_monitoring/helpers.py +2 -1
  102. mlrun/model_monitoring/stream_processing.py +18 -16
  103. mlrun/model_monitoring/writer.py +4 -3
  104. mlrun/package/__init__.py +2 -1
  105. mlrun/platforms/__init__.py +0 -44
  106. mlrun/platforms/iguazio.py +1 -1
  107. mlrun/projects/operations.py +11 -10
  108. mlrun/projects/project.py +81 -82
  109. mlrun/run.py +4 -7
  110. mlrun/runtimes/__init__.py +2 -204
  111. mlrun/runtimes/base.py +89 -21
  112. mlrun/runtimes/constants.py +225 -0
  113. mlrun/runtimes/daskjob.py +4 -2
  114. mlrun/runtimes/databricks_job/databricks_runtime.py +2 -1
  115. mlrun/runtimes/mounts.py +5 -0
  116. mlrun/runtimes/nuclio/__init__.py +12 -8
  117. mlrun/runtimes/nuclio/api_gateway.py +36 -6
  118. mlrun/runtimes/nuclio/application/application.py +200 -32
  119. mlrun/runtimes/nuclio/function.py +154 -49
  120. mlrun/runtimes/nuclio/serving.py +55 -42
  121. mlrun/runtimes/pod.py +59 -10
  122. mlrun/secrets.py +46 -2
  123. mlrun/serving/__init__.py +2 -0
  124. mlrun/serving/remote.py +5 -5
  125. mlrun/serving/routers.py +3 -3
  126. mlrun/serving/server.py +46 -43
  127. mlrun/serving/serving_wrapper.py +6 -2
  128. mlrun/serving/states.py +554 -207
  129. mlrun/serving/steps.py +1 -1
  130. mlrun/serving/system_steps.py +42 -33
  131. mlrun/track/trackers/mlflow_tracker.py +29 -31
  132. mlrun/utils/helpers.py +89 -16
  133. mlrun/utils/http.py +9 -2
  134. mlrun/utils/notifications/notification/git.py +1 -1
  135. mlrun/utils/notifications/notification/mail.py +39 -16
  136. mlrun/utils/notifications/notification_pusher.py +2 -2
  137. mlrun/utils/version/version.json +2 -2
  138. mlrun/utils/version/version.py +3 -4
  139. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/METADATA +39 -49
  140. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/RECORD +144 -130
  141. mlrun/db/auth_utils.py +0 -152
  142. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +0 -343
  143. mlrun/model_monitoring/db/tsdb/tdengine/stream_graph_steps.py +0 -75
  144. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connection.py +0 -281
  145. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +0 -1368
  146. mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +0 -51
  147. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/WHEEL +0 -0
  148. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/entry_points.txt +0 -0
  149. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/licenses/LICENSE +0 -0
  150. {mlrun-1.10.0rc40.dist-info → mlrun-1.11.0rc16.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py CHANGED
@@ -26,11 +26,13 @@ import pathlib
26
26
  import traceback
27
27
  import warnings
28
28
  from abc import ABC
29
+ from collections.abc import Collection
29
30
  from copy import copy, deepcopy
30
31
  from inspect import getfullargspec, signature
31
32
  from typing import Any, Optional, Union, cast
32
33
 
33
34
  import storey.utils
35
+ from deprecated import deprecated
34
36
  from storey import ParallelExecutionMechanisms
35
37
 
36
38
  import mlrun
@@ -90,25 +92,6 @@ class StepKinds:
90
92
  model_runner = "model_runner"
91
93
 
92
94
 
93
- _task_step_fields = [
94
- "kind",
95
- "class_name",
96
- "class_args",
97
- "handler",
98
- "skip_context",
99
- "after",
100
- "function",
101
- "comment",
102
- "shape",
103
- "full_event",
104
- "on_error",
105
- "responder",
106
- "input_path",
107
- "result_path",
108
- "model_endpoint_creation_strategy",
109
- "endpoint_type",
110
- ]
111
-
112
95
  _default_fields_to_strip_from_step = [
113
96
  "model_endpoint_creation_strategy",
114
97
  "endpoint_type",
@@ -134,7 +117,14 @@ def new_remote_endpoint(
134
117
  class BaseStep(ModelObj):
135
118
  kind = "BaseStep"
136
119
  default_shape = "ellipse"
137
- _dict_fields = ["kind", "comment", "after", "on_error"]
120
+ _dict_fields = [
121
+ "kind",
122
+ "comment",
123
+ "after",
124
+ "on_error",
125
+ "max_iterations",
126
+ "cycle_from",
127
+ ]
138
128
  _default_fields_to_strip = _default_fields_to_strip_from_step
139
129
 
140
130
  def __init__(
@@ -142,6 +132,7 @@ class BaseStep(ModelObj):
142
132
  name: Optional[str] = None,
143
133
  after: Optional[list] = None,
144
134
  shape: Optional[str] = None,
135
+ max_iterations: Optional[int] = None,
145
136
  ):
146
137
  self.name = name
147
138
  self._parent = None
@@ -155,6 +146,8 @@ class BaseStep(ModelObj):
155
146
  self.model_endpoint_creation_strategy = (
156
147
  schemas.ModelEndpointCreationStrategy.SKIP
157
148
  )
149
+ self._max_iterations = max_iterations
150
+ self.cycle_from = []
158
151
 
159
152
  def get_shape(self):
160
153
  """graphviz shape"""
@@ -348,6 +341,8 @@ class BaseStep(ModelObj):
348
341
  model_endpoint_creation_strategy: Optional[
349
342
  schemas.ModelEndpointCreationStrategy
350
343
  ] = None,
344
+ cycle_to: Optional[list[str]] = None,
345
+ max_iterations: Optional[int] = None,
351
346
  **class_args,
352
347
  ):
353
348
  """add a step right after this step and return the new step
@@ -377,21 +372,17 @@ class BaseStep(ModelObj):
377
372
  to event["y"] resulting in {"x": 5, "y": <result>}
378
373
  :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
379
374
 
380
- * **overwrite**:
375
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
376
+ create a new model endpoint entry and set it as `latest`.
381
377
 
382
- 1. If model endpoints with the same name exist, delete the `latest` one.
383
- 2. Create a new model endpoint entry and set it as `latest`.
378
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
379
+ entry; otherwise, create a new entry.
384
380
 
385
- * **inplace** (default):
386
-
387
- 1. If model endpoints with the same name exist, update the `latest` entry.
388
- 2. Otherwise, create a new entry.
389
-
390
- * **archive**:
391
-
392
- 1. If model endpoints with the same name exist, preserve them.
393
- 2. Create a new model endpoint with the same name and set it to `latest`.
381
+ * **archive**: If model endpoints with the same name exist, preserve them;
382
+ create a new model endpoint with the same name and set it to `latest`.
394
383
 
384
+ :param cycle_to: list of step names to create a cycle to (for cyclic graphs)
385
+ :param max_iterations: maximum number of iterations for this step in case of a cycle graph
395
386
  :param class_args: class init arguments
396
387
  """
397
388
  if hasattr(self, "steps"):
@@ -426,8 +417,39 @@ class BaseStep(ModelObj):
426
417
  # check that its not the root, todo: in future may gave nested flows
427
418
  step.after_step(self.name)
428
419
  parent._last_added = step
420
+ step.cycle_to(cycle_to or [])
421
+ step._max_iterations = max_iterations
429
422
  return step
430
423
 
424
+ def cycle_to(self, step_names: Union[str, list[str]]):
425
+ """create a cycle in the graph to the specified step names
426
+
427
+ example:
428
+ in the below example, a cycle is created from 'step3' to 'step1':
429
+ graph.to('step1')\
430
+ .to('step2')\
431
+ .to('step3')\
432
+ .cycle_to(['step1']) # creates a cycle from step3 to step1
433
+
434
+ :param step_names: list of step names to create a cycle to (for cyclic graphs)
435
+ """
436
+ root = self._extract_root_step()
437
+ if not isinstance(root, RootFlowStep):
438
+ raise GraphError("cycle_to() can only be called on a step within a graph")
439
+ if not root.allow_cyclic and step_names:
440
+ raise GraphError("cyclic graphs are not allowed, enable allow_cyclic")
441
+ step_names = [step_names] if isinstance(step_names, str) else step_names
442
+
443
+ for step_name in step_names:
444
+ if step_name not in root:
445
+ raise GraphError(
446
+ f"step {step_name} doesnt exist in the graph under {self._parent.fullname}"
447
+ )
448
+ root[step_name].after_step(self.name, append=True)
449
+ root[step_name].cycle_from.append(self.name)
450
+
451
+ return self
452
+
431
453
  def set_flow(
432
454
  self,
433
455
  steps: list[Union[str, StepToDict, dict[str, Any]]],
@@ -591,15 +613,14 @@ class BaseStep(ModelObj):
591
613
  root.get_shared_model_by_artifact_uri(model_artifact_uri)
592
614
  )
593
615
 
594
- if not shared_runnable_name:
595
- if not actual_shared_name:
596
- raise GraphError(
597
- f"Can't find shared model for {name} model endpoint"
598
- )
599
- else:
600
- step.class_args[schemas.ModelRunnerStepData.MODELS][name][
601
- schemas.ModelsData.MODEL_PARAMETERS.value
602
- ]["shared_runnable_name"] = actual_shared_name
616
+ if not actual_shared_name:
617
+ raise GraphError(
618
+ f"Can't find shared model named {shared_runnable_name}"
619
+ )
620
+ elif not shared_runnable_name:
621
+ step.class_args[schemas.ModelRunnerStepData.MODELS][name][
622
+ schemas.ModelsData.MODEL_PARAMETERS.value
623
+ ]["shared_runnable_name"] = actual_shared_name
603
624
  elif actual_shared_name != shared_runnable_name:
604
625
  raise GraphError(
605
626
  f"Model endpoint {name} shared runnable name mismatch: "
@@ -656,14 +677,14 @@ class BaseStep(ModelObj):
656
677
  if actual_shared_name and actual_shared_name not in step._shared_proxy_mapping:
657
678
  step._shared_proxy_mapping[actual_shared_name] = {
658
679
  name: artifact.uri
659
- if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
680
+ if isinstance(artifact, ModelArtifact | LLMPromptArtifact)
660
681
  else artifact
661
682
  }
662
683
  elif actual_shared_name:
663
684
  step._shared_proxy_mapping[actual_shared_name].update(
664
685
  {
665
686
  name: artifact.uri
666
- if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
687
+ if isinstance(artifact, ModelArtifact | LLMPromptArtifact)
667
688
  else artifact
668
689
  }
669
690
  )
@@ -673,7 +694,20 @@ class TaskStep(BaseStep):
673
694
  """task execution step, runs a class or handler"""
674
695
 
675
696
  kind = "task"
676
- _dict_fields = _task_step_fields
697
+ _dict_fields = BaseStep._dict_fields + [
698
+ "class_name",
699
+ "class_args",
700
+ "handler",
701
+ "skip_context",
702
+ "function",
703
+ "shape",
704
+ "full_event",
705
+ "responder",
706
+ "input_path",
707
+ "result_path",
708
+ "model_endpoint_creation_strategy",
709
+ "endpoint_type",
710
+ ]
677
711
  _default_class = ""
678
712
 
679
713
  def __init__(
@@ -699,6 +733,7 @@ class TaskStep(BaseStep):
699
733
  self.handler = handler
700
734
  self.function = function
701
735
  self._handler = None
736
+ self._outlets_selector = None
702
737
  self._object = None
703
738
  self._async_object = None
704
739
  self.skip_context = None
@@ -766,6 +801,8 @@ class TaskStep(BaseStep):
766
801
  handler = "do"
767
802
  if handler:
768
803
  self._handler = getattr(self._object, handler, None)
804
+ if hasattr(self._object, "select_outlets"):
805
+ self._outlets_selector = self._object.select_outlets
769
806
 
770
807
  self._set_error_handler()
771
808
  if mode != "skip":
@@ -939,7 +976,7 @@ class ErrorStep(TaskStep):
939
976
  """error execution step, runs a class or handler"""
940
977
 
941
978
  kind = "error_step"
942
- _dict_fields = _task_step_fields + ["before", "base_step"]
979
+ _dict_fields = TaskStep._dict_fields + ["before", "base_step"]
943
980
  _default_class = ""
944
981
 
945
982
  def __init__(
@@ -976,7 +1013,7 @@ class RouterStep(TaskStep):
976
1013
 
977
1014
  kind = "router"
978
1015
  default_shape = "doubleoctagon"
979
- _dict_fields = _task_step_fields + ["routes", "name"]
1016
+ _dict_fields = TaskStep._dict_fields + ["routes", "name"]
980
1017
  _default_class = "mlrun.serving.ModelRouter"
981
1018
 
982
1019
  def __init__(
@@ -1043,20 +1080,14 @@ class RouterStep(TaskStep):
1043
1080
  :param function: function this step should run in
1044
1081
  :param creation_strategy: Strategy for creating or updating the model endpoint:
1045
1082
 
1046
- * **overwrite**:
1047
-
1048
- 1. If model endpoints with the same name exist, delete the `latest` one.
1049
- 2. Create a new model endpoint entry and set it as `latest`.
1050
-
1051
- * **inplace** (default):
1083
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
1084
+ create a new model endpoint entry and set it as `latest`.
1052
1085
 
1053
- 1. If model endpoints with the same name exist, update the `latest` entry.
1054
- 2. Otherwise, create a new entry.
1086
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
1087
+ entry; otherwise, create a new entry.
1055
1088
 
1056
- * **archive**:
1057
-
1058
- 1. If model endpoints with the same name exist, preserve them.
1059
- 2. Create a new model endpoint with the same name and set it to `latest`.
1089
+ * **archive**: If model endpoints with the same name exist, preserve them;
1090
+ create a new model endpoint with the same name and set it to `latest`.
1060
1091
 
1061
1092
  """
1062
1093
  if len(self.routes.keys()) >= MAX_MODELS_PER_ROUTER and key not in self.routes:
@@ -1197,14 +1228,18 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1197
1228
  if self._execution_mechanism == storey.ParallelExecutionMechanisms.asyncio:
1198
1229
  if self.__class__.predict_async is Model.predict_async:
1199
1230
  raise mlrun.errors.ModelRunnerError(
1200
- f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict_async() "
1201
- f"is not implemented"
1231
+ {
1232
+ self.name: f"is running with {self._execution_mechanism} "
1233
+ f"execution_mechanism but predict_async() is not implemented"
1234
+ }
1202
1235
  )
1203
1236
  else:
1204
1237
  if self.__class__.predict is Model.predict:
1205
1238
  raise mlrun.errors.ModelRunnerError(
1206
- f"{self.name} is running with {self._execution_mechanism} execution_mechanism but predict() "
1207
- f"is not implemented"
1239
+ {
1240
+ self.name: f"is running with {self._execution_mechanism} execution_mechanism but predict() "
1241
+ f"is not implemented"
1242
+ }
1208
1243
  )
1209
1244
 
1210
1245
  def _load_artifacts(self) -> None:
@@ -1223,7 +1258,9 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1223
1258
  uri = proxy_uri or self.artifact_uri
1224
1259
  if uri:
1225
1260
  if mlrun.datastore.is_store_uri(uri):
1226
- artifact, _ = mlrun.store_manager.get_store_artifact(uri)
1261
+ artifact, _ = mlrun.store_manager.get_store_artifact(
1262
+ uri, allow_empty_resources=True
1263
+ )
1227
1264
  return artifact
1228
1265
  else:
1229
1266
  raise ValueError(
@@ -1244,6 +1281,8 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1244
1281
  raise NotImplementedError("predict_async() method not implemented")
1245
1282
 
1246
1283
  def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
1284
+ if isinstance(body, list):
1285
+ body = self.format_batch(body)
1247
1286
  return self.predict(body)
1248
1287
 
1249
1288
  async def run_async(
@@ -1282,6 +1321,10 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1282
1321
  return model_file, extra_dataitems
1283
1322
  return None, None
1284
1323
 
1324
+ @staticmethod
1325
+ def format_batch(body: Any):
1326
+ return body
1327
+
1285
1328
 
1286
1329
  class LLModel(Model):
1287
1330
  """
@@ -1432,6 +1475,24 @@ class LLModel(Model):
1432
1475
  )
1433
1476
  return body
1434
1477
 
1478
+ def init(self):
1479
+ super().init()
1480
+
1481
+ if not self.model_provider:
1482
+ if self._execution_mechanism != storey.ParallelExecutionMechanisms.asyncio:
1483
+ unchanged_predict = self.__class__.predict is LLModel.predict
1484
+ predict_function_name = "predict"
1485
+ else:
1486
+ unchanged_predict = (
1487
+ self.__class__.predict_async is LLModel.predict_async
1488
+ )
1489
+ predict_function_name = "predict_async"
1490
+ if unchanged_predict:
1491
+ raise mlrun.errors.MLRunRuntimeError(
1492
+ f"Model provider could not be determined for model '{self.name}',"
1493
+ f" and the {predict_function_name} function was not overridden."
1494
+ )
1495
+
1435
1496
  def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
1436
1497
  llm_prompt_artifact = self._get_invocation_artifact(origin_name)
1437
1498
  messages, invocation_config = self.enrich_prompt(
@@ -1554,6 +1615,69 @@ class LLModel(Model):
1554
1615
  return llm_prompt_artifact
1555
1616
 
1556
1617
 
1618
+ class ModelRunnerSelector(ModelObj):
1619
+ """
1620
+ Strategy for controlling model selection and output routing in ModelRunnerStep.
1621
+
1622
+ Subclass this to implement custom logic for agent workflows:
1623
+ - `select_models()`: Called BEFORE execution to choose which models run
1624
+ - `select_outlets()`: Called AFTER execution to route output to downstream steps
1625
+
1626
+ Return `None` from either method to use default behavior (all models / all outlets).
1627
+
1628
+ Example::
1629
+
1630
+ class ToolSelector(ModelRunnerSelector):
1631
+ def select_outlets(self, event):
1632
+ tool = event.get("tool_call")
1633
+ return [tool] if tool else ["final"]
1634
+ """
1635
+
1636
+ def __init__(self, **kwargs):
1637
+ super().__init__()
1638
+
1639
+ def __init_subclass__(cls):
1640
+ super().__init_subclass__()
1641
+ cls._dict_fields = list(
1642
+ set(cls._dict_fields)
1643
+ | set(inspect.signature(cls.__init__).parameters.keys())
1644
+ )
1645
+ cls._dict_fields.remove("self")
1646
+
1647
+ def select_models(
1648
+ self,
1649
+ event: Any,
1650
+ available_models: list[Model],
1651
+ ) -> Optional[Union[list[str], list[Model]]]:
1652
+ """
1653
+ Called before model execution.
1654
+
1655
+ :param event: The full event
1656
+ :param available_models: List of available models
1657
+
1658
+ Returns the models to execute (by name or Model objects).
1659
+ """
1660
+ return None
1661
+
1662
+ def select_outlets(
1663
+ self,
1664
+ event: Any,
1665
+ ) -> Optional[list[str]]:
1666
+ """
1667
+ Called after model execution.
1668
+
1669
+ :param event: The event body after model execution
1670
+ :return: Returns the downstream outlets to route the event to.
1671
+ """
1672
+ return None
1673
+
1674
+
1675
+ # TODO: Remove in 1.13.0
1676
+ @deprecated(
1677
+ version="1.11.0",
1678
+ reason="ModelSelector will be removed in 1.13.0, use ModelRunnerSelector instead",
1679
+ category=FutureWarning,
1680
+ )
1557
1681
  class ModelSelector(ModelObj):
1558
1682
  """Used to select which models to run on each event."""
1559
1683
 
@@ -1585,16 +1709,22 @@ class ModelRunner(storey.ParallelExecution):
1585
1709
  """
1586
1710
  Runs multiple Models on each event. See ModelRunnerStep.
1587
1711
 
1588
- :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
1589
- event. Optional. If not passed, all models will be run.
1712
+ :param model_runner_selector: ModelSelector instance whose select() method will be used to select models
1713
+ to run on each event. Optional. If not passed, all models will be run.
1590
1714
  """
1591
1715
 
1592
1716
  def __init__(
1593
- self, *args, context, model_selector: Optional[ModelSelector] = None, **kwargs
1717
+ self,
1718
+ *args,
1719
+ context,
1720
+ model_runner_selector: Optional[ModelRunnerSelector] = None,
1721
+ raise_exception: bool = True,
1722
+ **kwargs,
1594
1723
  ):
1595
1724
  super().__init__(*args, **kwargs)
1596
- self.model_selector = model_selector or ModelSelector()
1725
+ self.model_runner_selector = model_runner_selector or ModelRunnerSelector()
1597
1726
  self.context = context
1727
+ self._raise_exception = raise_exception
1598
1728
 
1599
1729
  def preprocess_event(self, event):
1600
1730
  if not hasattr(event, "_metadata"):
@@ -1607,7 +1737,31 @@ class ModelRunner(storey.ParallelExecution):
1607
1737
 
1608
1738
  def select_runnables(self, event):
1609
1739
  models = cast(list[Model], self.runnables)
1610
- return self.model_selector.select(event, models)
1740
+ return self.model_runner_selector.select_models(event, models)
1741
+
1742
+ def select_outlets(self, event) -> Optional[Collection[str]]:
1743
+ sys_outlets = [f"{self.name}_error_raise"]
1744
+ if "background_task_status_step" in self._name_to_outlet:
1745
+ sys_outlets.append("background_task_status_step")
1746
+ if self._raise_exception and self._is_error(event):
1747
+ return sys_outlets
1748
+ user_outlets = self.model_runner_selector.select_outlets(event)
1749
+ if user_outlets:
1750
+ return (
1751
+ user_outlets if isinstance(user_outlets, list) else [user_outlets]
1752
+ ) + sys_outlets
1753
+ return None
1754
+
1755
+ def _is_error(self, event: dict) -> bool:
1756
+ if len(self.runnables) == 1:
1757
+ if isinstance(event, dict):
1758
+ return event.get("error") is not None
1759
+ else:
1760
+ for model in event:
1761
+ body_by_model = event.get(model)
1762
+ if isinstance(body_by_model, dict) and "error" in body_by_model:
1763
+ return True
1764
+ return False
1611
1765
 
1612
1766
 
1613
1767
  class MonitoredStep(ABC, TaskStep, StepToDict):
@@ -1664,52 +1818,117 @@ class ModelRunnerStep(MonitoredStep):
1664
1818
 
1665
1819
  Note ModelRunnerStep can only be added to a graph that has the flow topology and running with async engine.
1666
1820
 
1667
- :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
1668
- event. Optional. If not passed, all models will be run.
1669
- :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised
1670
- an error. If False, the error will appear in the output event.
1821
+ Note see configure_pool_resource method documentation for default number of max threads and max processes.
1671
1822
 
1672
- :raise ModelRunnerError - when a model raise an error the ModelRunnerStep will handle it, collect errors and outputs
1673
- from added models, If raise_exception is True will raise ModelRunnerError Else will add
1674
- the error msg as part of the event body mapped by model name if more than one model was
1675
- added to the ModelRunnerStep
1823
+ :raise ModelRunnerError: when a model raises an error the ModelRunnerStep will handle it, collect errors and
1824
+ outputs from added models. If raise_exception is True will raise ModelRunnerError. Else
1825
+ will add the error msg as part of the event body mapped by model name if more than
1826
+ one model was added to the ModelRunnerStep
1676
1827
  """
1677
1828
 
1678
1829
  kind = "model_runner"
1679
- _dict_fields = MonitoredStep._dict_fields + ["_shared_proxy_mapping"]
1830
+ _dict_fields = MonitoredStep._dict_fields + [
1831
+ "_shared_proxy_mapping",
1832
+ "max_processes",
1833
+ "max_threads",
1834
+ "pool_factor",
1835
+ ]
1680
1836
 
1681
1837
  def __init__(
1682
1838
  self,
1683
1839
  *args,
1684
1840
  name: Optional[str] = None,
1841
+ model_runner_selector: Optional[Union[str, ModelRunnerSelector]] = None,
1842
+ model_runner_selector_parameters: Optional[dict] = None,
1685
1843
  model_selector: Optional[Union[str, ModelSelector]] = None,
1686
1844
  model_selector_parameters: Optional[dict] = None,
1687
1845
  raise_exception: bool = True,
1688
1846
  **kwargs,
1689
1847
  ):
1690
- if isinstance(model_selector, ModelSelector) and model_selector_parameters:
1691
- raise mlrun.errors.MLRunInvalidArgumentError(
1692
- "Cannot provide a model_selector object as argument to `model_selector` and also provide "
1693
- "`model_selector_parameters`."
1694
- )
1695
- if model_selector:
1696
- model_selector_parameters = model_selector_parameters or (
1697
- model_selector.to_dict()
1698
- if isinstance(model_selector, ModelSelector)
1699
- else {}
1700
- )
1701
- model_selector = (
1702
- model_selector
1703
- if isinstance(model_selector, str)
1704
- else model_selector.__class__.__name__
1705
- )
1848
+ """
1849
+
1850
+ :param name: The name of the ModelRunnerStep.
1851
+ :param model_runner_selector: ModelRunnerSelector instance whose select_models()
1852
+ and select_outlets() methods will be used
1853
+ to select models to run on each event and outlets to
1854
+ route the event to.
1855
+ :param model_runner_selector_parameters: Parameters for the model_runner_selector, if model_runner_selector
1856
+ is the class name we will use this param when
1857
+ initializing the selector.
1858
+ :param model_selector: (Deprecated)
1859
+ :param model_selector_parameters: (Deprecated)
1860
+ :param raise_exception: Determines whether to raise ModelRunnerError when one or more models
1861
+ raise an error during execution.
1862
+ If False, errors will be added to the event body.
1863
+ """
1864
+ self.max_processes = None
1865
+ self.max_threads = None
1866
+ self.pool_factor = None
1867
+
1868
+ if (model_selector or model_selector_parameters) and (
1869
+ model_runner_selector or model_runner_selector_parameters
1870
+ ):
1871
+ raise GraphError(
1872
+ "Cannot provide both `model_selector`/`model_selector_parameters` "
1873
+ "and `model_runner_selector`/`model_runner_selector_parameters`. "
1874
+ "Please use only the latter pair."
1875
+ )
1876
+ if model_selector or model_selector_parameters:
1877
+ warnings.warn(
1878
+ "`model_selector` and `model_selector_parameters` are deprecated, "
1879
+ "please use `model_runner_selector` and `model_runner_selector_parameters` instead.",
1880
+ # TODO: Remove this in 1.13.0
1881
+ FutureWarning,
1882
+ )
1883
+ if isinstance(model_selector, ModelSelector) and model_selector_parameters:
1884
+ raise mlrun.errors.MLRunInvalidArgumentError(
1885
+ "Cannot provide a model_selector object as argument to `model_selector` and also provide "
1886
+ "`model_selector_parameters`."
1887
+ )
1888
+ if model_selector:
1889
+ model_selector_parameters = model_selector_parameters or (
1890
+ model_selector.to_dict()
1891
+ if isinstance(model_selector, ModelSelector)
1892
+ else {}
1893
+ )
1894
+ model_selector = (
1895
+ model_selector
1896
+ if isinstance(model_selector, str)
1897
+ else model_selector.__class__.__name__
1898
+ )
1899
+ else:
1900
+ if (
1901
+ isinstance(model_runner_selector, ModelRunnerSelector)
1902
+ and model_runner_selector_parameters
1903
+ ):
1904
+ raise mlrun.errors.MLRunInvalidArgumentError(
1905
+ "Cannot provide a model_runner_selector object as argument to `model_runner_selector` "
1906
+ "and also provide `model_runner_selector_parameters`."
1907
+ )
1908
+ if model_runner_selector:
1909
+ model_runner_selector_parameters = model_runner_selector_parameters or (
1910
+ model_runner_selector.to_dict()
1911
+ if isinstance(model_runner_selector, ModelRunnerSelector)
1912
+ else {}
1913
+ )
1914
+ model_runner_selector = (
1915
+ model_runner_selector
1916
+ if isinstance(model_runner_selector, str)
1917
+ else model_runner_selector.__class__.__name__
1918
+ )
1706
1919
 
1707
1920
  super().__init__(
1708
1921
  *args,
1709
1922
  name=name,
1710
1923
  raise_exception=raise_exception,
1711
1924
  class_name="mlrun.serving.ModelRunner",
1712
- class_args=dict(model_selector=(model_selector, model_selector_parameters)),
1925
+ class_args=dict(
1926
+ model_selector=(model_selector, model_selector_parameters),
1927
+ model_runner_selector=(
1928
+ model_runner_selector,
1929
+ model_runner_selector_parameters,
1930
+ ),
1931
+ ),
1713
1932
  **kwargs,
1714
1933
  )
1715
1934
  self.raise_exception = raise_exception
@@ -1737,17 +1956,18 @@ class ModelRunnerStep(MonitoredStep):
1737
1956
  :param shared_model_name: str, the name of the shared model that is already defined within the graph
1738
1957
  :param labels: model endpoint labels, should be list of str or mapping of str:str
1739
1958
  :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
1740
- * **overwrite**:
1741
- 1. If model endpoints with the same name exist, delete the `latest` one.
1742
- 2. Create a new model endpoint entry and set it as `latest`.
1743
- * **inplace** (default):
1744
- 1. If model endpoints with the same name exist, update the `latest` entry.
1745
- 2. Otherwise, create a new entry.
1746
- * **archive**:
1747
- 1. If model endpoints with the same name exist, preserve them.
1748
- 2. Create a new model endpoint with the same name and set it to `latest`.
1959
+
1960
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
1961
+ create a new model endpoint entry and set it as `latest`.
1962
+
1963
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest` entry;
1964
+ otherwise, create a new entry.
1965
+
1966
+ * **archive**: If model endpoints with the same name exist, preserve them;
1967
+ create a new model endpoint with the same name and set it to `latest`.
1749
1968
 
1750
1969
  :param override: bool allow override existing model on the current ModelRunnerStep.
1970
+ :raise GraphError: when the shared model is not found in the root flow step shared models.
1751
1971
  """
1752
1972
  model_class, model_params = (
1753
1973
  "mlrun.serving.Model",
@@ -1805,14 +2025,14 @@ class ModelRunnerStep(MonitoredStep):
1805
2025
  if shared_model_name and shared_model_name not in self._shared_proxy_mapping:
1806
2026
  self._shared_proxy_mapping[shared_model_name] = {
1807
2027
  endpoint_name: model_artifact.uri
1808
- if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
2028
+ if isinstance(model_artifact, ModelArtifact | LLMPromptArtifact)
1809
2029
  else model_artifact
1810
2030
  }
1811
2031
  elif override and shared_model_name:
1812
2032
  self._shared_proxy_mapping[shared_model_name].update(
1813
2033
  {
1814
2034
  endpoint_name: model_artifact.uri
1815
- if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
2035
+ if isinstance(model_artifact, ModelArtifact | LLMPromptArtifact)
1816
2036
  else model_artifact
1817
2037
  }
1818
2038
  )
@@ -1856,49 +2076,48 @@ class ModelRunnerStep(MonitoredStep):
1856
2076
  (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
1857
2077
  outputs will be overridden with UsageResponseKeys fields.
1858
2078
  :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
1859
- * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
1860
- intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
1861
- Lock (GIL).
1862
- * "dedicated_process" To run in a separate dedicated process. This is appropriate for CPU or GPU intensive
1863
- tasks that also require significant Runnable-specific initialization (e.g. a large model).
1864
- * "thread_pool" To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
1865
- otherwise block the main event loop thread.
1866
- * "asyncio" To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
1867
- event loop to continue running while waiting for a response.
1868
- * "shared_executor" Reuses an external executor (typically managed by the flow or context) to execute the
1869
- runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
1870
- useful when:
1871
- - You want to share a heavy resource like a large model loaded onto a GPU.
1872
- - You want to centralize task scheduling or coordination for multiple lightweight tasks.
1873
- - You aim to minimize overhead from creating new executors or processes/threads per runnable.
1874
- The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
1875
- memory and hardware accelerators.
1876
- * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
1877
- It means that the runnable will not actually be run in parallel to anything else.
1878
-
1879
- :param model_artifact: model artifact or mlrun model artifact uri
1880
- :param labels: model endpoint labels, should be list of str or mapping of str:str
1881
- :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
1882
- * **overwrite**:
1883
- 1. If model endpoints with the same name exist, delete the `latest` one.
1884
- 2. Create a new model endpoint entry and set it as `latest`.
1885
- * **inplace** (default):
1886
- 1. If model endpoints with the same name exist, update the `latest` entry.
1887
- 2. Otherwise, create a new entry.
1888
- * **archive**:
1889
- 1. If model endpoints with the same name exist, preserve them.
1890
- 2. Create a new model endpoint with the same name and set it to `latest`.
1891
-
1892
- :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
2079
+
2080
+ * **process_pool**: To run in a separate process from a process pool. This is appropriate
2081
+ for CPU or GPU intensive tasks as they would otherwise block the main process by holding
2082
+ Python's Global Interpreter Lock (GIL).
2083
+
2084
+ * **dedicated_process**: To run in a separate dedicated process. This is appropriate for CPU
2085
+ or GPU intensive tasks that also require significant Runnable-specific initialization
2086
+ (e.g. a large model).
2087
+
2088
+ * **thread_pool**: To run in a separate thread. This is appropriate for blocking I/O tasks,
2089
+ as they would otherwise block the main event loop thread.
2090
+
2091
+ * **asyncio**: To run in an asyncio task. This is appropriate for I/O tasks that use
2092
+ asyncio, allowing the event loop to continue running while waiting for a response.
2093
+
2094
+ * **naive**: To run in the main event loop. This is appropriate only for trivial computation
2095
+ and/or file I/O. It means that the runnable will not actually be run in parallel to
2096
+ anything else.
2097
+
2098
+ :param model_artifact: model artifact or mlrun model artifact uri
2099
+ :param labels: model endpoint labels, should be list of str or mapping of str:str
2100
+ :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
2101
+
2102
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
2103
+ create a new model endpoint entry and set it as `latest`.
2104
+
2105
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
2106
+ entry; otherwise, create a new entry.
2107
+
2108
+ * **archive**: If model endpoints with the same name exist, preserve them;
2109
+ create a new model endpoint with the same name and set it to `latest`.
2110
+
2111
+ :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
1893
2112
  that been configured in the model artifact, please note that those inputs need to
1894
2113
  be equal in length and order to the inputs that model_class predict method expects
1895
- :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
2114
+ :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
1896
2115
  that been configured in the model artifact, please note that those outputs need to
1897
2116
  be equal to the model_class predict method outputs (length, and order)
1898
2117
 
1899
2118
  When using LLModel, the output will be overridden with UsageResponseKeys.fields().
1900
2119
 
1901
- :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
2120
+ :param input_path: when specified selects the key/path in the event to use as model monitoring inputs
1902
2121
  this require that the event body will behave like a dict, expects scopes to be
1903
2122
  defined by dot notation (e.g "data.d").
1904
2123
  examples: input_path="data.b"
@@ -1908,7 +2127,7 @@ class ModelRunnerStep(MonitoredStep):
1908
2127
  be {"f0": [1, 2]}.
1909
2128
  if a ``list`` or ``list of lists`` is provided, it must follow the order and
1910
2129
  size defined by the input schema.
1911
- :param result_path: when specified selects the key/path in the output event to use as model monitoring
2130
+ :param result_path: when specified selects the key/path in the output event to use as model monitoring
1912
2131
  outputs this require that the output event body will behave like a dict,
1913
2132
  expects scopes to be defined by dot notation (e.g "data.d").
1914
2133
  examples: result_path="out.b"
@@ -1919,8 +2138,8 @@ class ModelRunnerStep(MonitoredStep):
1919
2138
  if a ``list`` or ``list of lists`` is provided, it must follow the order and
1920
2139
  size defined by the output schema.
1921
2140
 
1922
- :param override: bool allow override existing model on the current ModelRunnerStep.
1923
- :param model_parameters: Parameters for model instantiation
2141
+ :param override: bool allow override existing model on the current ModelRunnerStep.
2142
+ :param model_parameters: Parameters for model instantiation
1924
2143
  """
1925
2144
  if isinstance(model_class, Model) and model_parameters:
1926
2145
  raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2093,6 +2312,24 @@ class ModelRunnerStep(MonitoredStep):
2093
2312
  "Monitoring data must be a dictionary."
2094
2313
  )
2095
2314
 
2315
+ def configure_pool_resource(
2316
+ self,
2317
+ max_processes: Optional[int] = None,
2318
+ max_threads: Optional[int] = None,
2319
+ pool_factor: Optional[int] = None,
2320
+ ) -> None:
2321
+ """
2322
+ Configure the resource limits for the shared models in the graph.
2323
+
2324
+ :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
2325
+ Defaults to the number of CPUs or 16 if undetectable.
2326
+ :param max_threads: Maximum number of threads to spawn. Defaults to 32.
2327
+ :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
2328
+ """
2329
+ self.max_processes = max_processes
2330
+ self.max_threads = max_threads
2331
+ self.pool_factor = pool_factor
2332
+
2096
2333
  def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
2097
2334
  self.context = context
2098
2335
  if not self._is_local_function(context):
@@ -2101,6 +2338,9 @@ class ModelRunnerStep(MonitoredStep):
2101
2338
  model_selector, model_selector_params = self.class_args.get(
2102
2339
  "model_selector", (None, None)
2103
2340
  )
2341
+ model_runner_selector, model_runner_selector_params = self.class_args.get(
2342
+ "model_runner_selector", (None, None)
2343
+ )
2104
2344
  execution_mechanism_by_model_name = self.class_args.get(
2105
2345
  schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
2106
2346
  )
@@ -2109,6 +2349,15 @@ class ModelRunnerStep(MonitoredStep):
2109
2349
  model_selector = get_class(model_selector, namespace).from_dict(
2110
2350
  model_selector_params, init_with_params=True
2111
2351
  )
2352
+ model_runner_selector = (
2353
+ self._convert_model_selector_to_model_runner_selector(
2354
+ model_selector=model_selector
2355
+ )
2356
+ )
2357
+ elif model_runner_selector:
2358
+ model_runner_selector = get_class(
2359
+ model_runner_selector, namespace
2360
+ ).from_dict(model_runner_selector_params, init_with_params=True)
2112
2361
  model_objects = []
2113
2362
  for model, model_params in models.values():
2114
2363
  model_name = model_params.get("name")
@@ -2135,14 +2384,46 @@ class ModelRunnerStep(MonitoredStep):
2135
2384
  )
2136
2385
  model_objects.append(model)
2137
2386
  self._async_object = ModelRunner(
2138
- model_selector=model_selector,
2387
+ model_runner_selector=model_runner_selector,
2139
2388
  runnables=model_objects,
2140
2389
  execution_mechanism_by_runnable_name=execution_mechanism_by_model_name,
2141
2390
  shared_proxy_mapping=self._shared_proxy_mapping or None,
2142
2391
  name=self.name,
2143
2392
  context=context,
2393
+ max_processes=self.max_processes,
2394
+ max_threads=self.max_threads,
2395
+ pool_factor=self.pool_factor,
2396
+ raise_exception=self.raise_exception,
2397
+ **extra_kwargs,
2144
2398
  )
2145
2399
 
2400
+ def _convert_model_selector_to_model_runner_selector(
2401
+ self,
2402
+ model_selector,
2403
+ ) -> "ModelRunnerSelector":
2404
+ """
2405
+ Wrap a ModelSelector into a ModelRunnerSelector for backward compatibility.
2406
+ """
2407
+
2408
+ class Adapter(ModelRunnerSelector):
2409
+ def __init__(self):
2410
+ self.selector = model_selector
2411
+
2412
+ def select_models(
2413
+ self, event, available_models
2414
+ ) -> Union[list[str], list[Model]]:
2415
+ # Call old ModelSelector logic
2416
+ return self.selector.select(event, available_models)
2417
+
2418
+ def select_outlets(
2419
+ self,
2420
+ event,
2421
+ ) -> Optional[list[str]]:
2422
+ # By default, return all outlets (old ModelSelector didn't control routing)
2423
+ return None
2424
+
2425
+ return Adapter()
2426
+
2146
2427
 
2147
2428
  class ModelRunnerErrorRaiser(storey.MapClass):
2148
2429
  def __init__(self, raise_exception: bool, models_names: list[str], **kwargs):
@@ -2155,11 +2436,15 @@ class ModelRunnerErrorRaiser(storey.MapClass):
2155
2436
  errors = {}
2156
2437
  should_raise = False
2157
2438
  if len(self._models_names) == 1:
2158
- should_raise = event.body.get("error") is not None
2159
- errors[self._models_names[0]] = event.body.get("error")
2439
+ if isinstance(event.body, dict):
2440
+ should_raise = event.body.get("error") is not None
2441
+ errors[self._models_names[0]] = event.body.get("error")
2160
2442
  else:
2161
2443
  for model in event.body:
2162
- errors[model] = event.body.get(model).get("error")
2444
+ body_by_model = event.body.get(model)
2445
+ errors[model] = None
2446
+ if isinstance(body_by_model, dict):
2447
+ errors[model] = body_by_model.get("error")
2163
2448
  if errors[model] is not None:
2164
2449
  should_raise = True
2165
2450
  if should_raise:
@@ -2229,6 +2514,8 @@ class QueueStep(BaseStep, StepToDict):
2229
2514
  model_endpoint_creation_strategy: Optional[
2230
2515
  schemas.ModelEndpointCreationStrategy
2231
2516
  ] = None,
2517
+ cycle_to: Optional[list[str]] = None,
2518
+ max_iterations: Optional[int] = None,
2232
2519
  **class_args,
2233
2520
  ):
2234
2521
  if not function:
@@ -2246,6 +2533,8 @@ class QueueStep(BaseStep, StepToDict):
2246
2533
  input_path,
2247
2534
  result_path,
2248
2535
  model_endpoint_creation_strategy,
2536
+ cycle_to,
2537
+ max_iterations,
2249
2538
  **class_args,
2250
2539
  )
2251
2540
 
@@ -2281,8 +2570,10 @@ class FlowStep(BaseStep):
2281
2570
  after: Optional[list] = None,
2282
2571
  engine=None,
2283
2572
  final_step=None,
2573
+ allow_cyclic: bool = False,
2574
+ max_iterations: Optional[int] = None,
2284
2575
  ):
2285
- super().__init__(name, after)
2576
+ super().__init__(name, after, max_iterations=max_iterations)
2286
2577
  self._steps = None
2287
2578
  self.steps = steps
2288
2579
  self.engine = engine
@@ -2294,6 +2585,7 @@ class FlowStep(BaseStep):
2294
2585
  self._wait_for_result = False
2295
2586
  self._source = None
2296
2587
  self._start_steps = []
2588
+ self._allow_cyclic = allow_cyclic
2297
2589
 
2298
2590
  def get_children(self):
2299
2591
  return self._steps.values()
@@ -2327,6 +2619,8 @@ class FlowStep(BaseStep):
2327
2619
  model_endpoint_creation_strategy: Optional[
2328
2620
  schemas.ModelEndpointCreationStrategy
2329
2621
  ] = None,
2622
+ cycle_to: Optional[list[str]] = None,
2623
+ max_iterations: Optional[int] = None,
2330
2624
  **class_args,
2331
2625
  ):
2332
2626
  """add task, queue or router step/class to the flow
@@ -2360,21 +2654,17 @@ class FlowStep(BaseStep):
2360
2654
  to event["y"] resulting in {"x": 5, "y": <result>}
2361
2655
  :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
2362
2656
 
2363
- * **overwrite**:
2364
-
2365
- 1. If model endpoints with the same name exist, delete the `latest` one.
2366
- 2. Create a new model endpoint entry and set it as `latest`.
2367
-
2368
- * **inplace** (default):
2657
+ * **overwrite**: If model endpoints with the same name exist, delete the `latest` one;
2658
+ create a new model endpoint entry and set it as `latest`.
2369
2659
 
2370
- 1. If model endpoints with the same name exist, update the `latest` entry.
2371
- 2. Otherwise, create a new entry.
2660
+ * **inplace** (default): If model endpoints with the same name exist, update the `latest`
2661
+ entry; otherwise, create a new entry.
2372
2662
 
2373
- * **archive**:
2374
-
2375
- 1. If model endpoints with the same name exist, preserve them.
2376
- 2. Create a new model endpoint with the same name and set it to `latest`.
2663
+ * **archive**: If model endpoints with the same name exist, preserve them;
2664
+ create a new model endpoint with the same name and set it to `latest`.
2377
2665
 
2666
+ :param cycle_to: list of step names to create a cycle to (for cyclic graphs)
2667
+ :param max_iterations: maximum number of iterations for this step in case of a cycle graph
2378
2668
  :param class_args: class init arguments
2379
2669
  """
2380
2670
 
@@ -2400,6 +2690,8 @@ class FlowStep(BaseStep):
2400
2690
  after_list = after if isinstance(after, list) else [after]
2401
2691
  for after in after_list:
2402
2692
  self.insert_step(name, step, after, before)
2693
+ step.cycle_to(cycle_to or [])
2694
+ step._max_iterations = max_iterations
2403
2695
  return step
2404
2696
 
2405
2697
  def insert_step(self, key, step, after, before=None):
@@ -2492,13 +2784,24 @@ class FlowStep(BaseStep):
2492
2784
  for step in self._steps.values():
2493
2785
  step._next = None
2494
2786
  step._visited = False
2495
- if step.after:
2787
+ if step.after and not step.cycle_from:
2788
+ has_illegal_branches = len(step.after) > 1 and self.engine == "sync"
2789
+ if has_illegal_branches:
2790
+ raise GraphError(
2791
+ f"synchronous flow engine doesnt support branches use async for step {step.name}"
2792
+ )
2496
2793
  loop_step = has_loop(step, [])
2497
- if loop_step:
2794
+ if loop_step and not self.allow_cyclic:
2498
2795
  raise GraphError(
2499
2796
  f"Error, loop detected in step {loop_step}, graph must be acyclic (DAG)"
2500
2797
  )
2501
- else:
2798
+ elif (
2799
+ step.after
2800
+ and step.cycle_from
2801
+ and set(step.after) == set(step.cycle_from)
2802
+ ):
2803
+ start_steps.append(step.name)
2804
+ elif not step.cycle_from:
2502
2805
  start_steps.append(step.name)
2503
2806
 
2504
2807
  responders = []
@@ -2595,6 +2898,9 @@ class FlowStep(BaseStep):
2595
2898
  def process_step(state, step, root):
2596
2899
  if not state._is_local_function(self.context) or state._visited:
2597
2900
  return
2901
+ state._visited = (
2902
+ True # mark visited to avoid re-visit in case of multiple uplinks
2903
+ )
2598
2904
  for item in state.next or []:
2599
2905
  next_state = root[item]
2600
2906
  if next_state.async_object:
@@ -2605,7 +2911,7 @@ class FlowStep(BaseStep):
2605
2911
  )
2606
2912
 
2607
2913
  default_source, self._wait_for_result = _init_async_objects(
2608
- self.context, self._steps.values()
2914
+ self.context, self._steps.values(), self
2609
2915
  )
2610
2916
 
2611
2917
  source = self._source or default_source
@@ -2836,6 +3142,8 @@ class RootFlowStep(FlowStep):
2836
3142
  "shared_models",
2837
3143
  "shared_models_mechanism",
2838
3144
  "pool_factor",
3145
+ "allow_cyclic",
3146
+ "max_iterations",
2839
3147
  ]
2840
3148
 
2841
3149
  def __init__(
@@ -2845,13 +3153,11 @@ class RootFlowStep(FlowStep):
2845
3153
  after: Optional[list] = None,
2846
3154
  engine=None,
2847
3155
  final_step=None,
3156
+ allow_cyclic: bool = False,
3157
+ max_iterations: Optional[int] = 10_000,
2848
3158
  ):
2849
3159
  super().__init__(
2850
- name,
2851
- steps,
2852
- after,
2853
- engine,
2854
- final_step,
3160
+ name, steps, after, engine, final_step, allow_cyclic, max_iterations
2855
3161
  )
2856
3162
  self._models = set()
2857
3163
  self._route_models = set()
@@ -2862,6 +3168,22 @@ class RootFlowStep(FlowStep):
2862
3168
  self._shared_max_threads = None
2863
3169
  self._pool_factor = None
2864
3170
 
3171
+ @property
3172
+ def max_iterations(self) -> int:
3173
+ return self._max_iterations
3174
+
3175
+ @max_iterations.setter
3176
+ def max_iterations(self, max_iterations: int):
3177
+ self._max_iterations = max_iterations
3178
+
3179
+ @property
3180
+ def allow_cyclic(self) -> bool:
3181
+ return self._allow_cyclic
3182
+
3183
+ @allow_cyclic.setter
3184
+ def allow_cyclic(self, allow_cyclic: bool):
3185
+ self._allow_cyclic = allow_cyclic
3186
+
2865
3187
  def add_shared_model(
2866
3188
  self,
2867
3189
  name: str,
@@ -2879,45 +3201,55 @@ class RootFlowStep(FlowStep):
2879
3201
  Add a shared model to the graph, this model will be available to all the ModelRunners in the graph
2880
3202
  :param name: Name of the shared model (should be unique in the graph)
2881
3203
  :param model_class: Model class name. If LLModel is chosen
2882
- (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
2883
- outputs will be overridden with UsageResponseKeys fields.
3204
+ (either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
3205
+ outputs will be overridden with UsageResponseKeys fields.
2884
3206
  :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
2885
- * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
3207
+
3208
+ * **process_pool**: To run in a separate process from a process pool. This is appropriate for CPU or GPU
2886
3209
  intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
2887
3210
  Lock (GIL).
2888
- * "dedicated_process" – To run in a separate dedicated process. This is appropriate for CPU or GPU intensive
2889
- tasks that also require significant Runnable-specific initialization (e.g. a large model).
2890
- * "thread_pool" To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
3211
+
3212
+ * **dedicated_process**: To run in a separate dedicated process. This is appropriate for CPU or GPU
3213
+ intensive tasks that also require significant Runnable-specific initialization (e.g. a large model).
3214
+
3215
+ * **thread_pool**: To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
2891
3216
  otherwise block the main event loop thread.
2892
- * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
3217
+
3218
+ * **asyncio**: To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
2893
3219
  event loop to continue running while waiting for a response.
2894
- * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
2895
- runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
2896
- useful when:
3220
+
3221
+ * **shared_executor**: Reuses an external executor (typically managed by the flow or context) to execute
3222
+ the runnable. Should be used only if you have multiple `ParallelExecution` in the same flow and
3223
+ especially useful when:
3224
+
2897
3225
  - You want to share a heavy resource like a large model loaded onto a GPU.
3226
+
2898
3227
  - You want to centralize task scheduling or coordination for multiple lightweight tasks.
3228
+
2899
3229
  - You aim to minimize overhead from creating new executors or processes/threads per runnable.
3230
+
2900
3231
  The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
2901
3232
  memory and hardware accelerators.
2902
- * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
2903
- It means that the runnable will not actually be run in parallel to anything else.
2904
-
2905
- :param model_artifact: model artifact or mlrun model artifact uri
2906
- :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
2907
- that been configured in the model artifact, please note that those inputs need
2908
- to be equal in length and order to the inputs that model_class
2909
- predict method expects
2910
- :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
2911
- that been configured in the model artifact, please note that those outputs need
2912
- to be equal to the model_class
2913
- predict method outputs (length, and order)
2914
- :param input_path: input path inside the user event, expect scopes to be defined by dot notation
2915
- (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
2916
- :param result_path: result path inside the user output event, expect scopes to be defined by dot
2917
- notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
2918
- in path.
2919
- :param override: bool allow override existing model on the current ModelRunnerStep.
2920
- :param model_parameters: Parameters for model instantiation
3233
+
3234
+ * **naive**: To run in the main event loop. This is appropriate only for trivial computation and/or file
3235
+ I/O. It means that the runnable will not actually be run in parallel to anything else.
3236
+
3237
+ :param model_artifact: model artifact or mlrun model artifact uri
3238
+ :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
3239
+ that been configured in the model artifact, please note that those inputs need
3240
+ to be equal in length and order to the inputs that model_class
3241
+ predict method expects
3242
+ :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
3243
+ that been configured in the model artifact, please note that those outputs need
3244
+ to be equal to the model_class
3245
+ predict method outputs (length, and order)
3246
+ :param input_path: input path inside the user event, expect scopes to be defined by dot notation
3247
+ (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
3248
+ :param result_path: result path inside the user output event, expect scopes to be defined by dot
3249
+ notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
3250
+ in path.
3251
+ :param override: bool allow override existing model on the current ModelRunnerStep.
3252
+ :param model_parameters: Parameters for model instantiation
2921
3253
  """
2922
3254
  if isinstance(model_class, Model) and model_parameters:
2923
3255
  raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2983,7 +3315,7 @@ class RootFlowStep(FlowStep):
2983
3315
 
2984
3316
  def get_shared_model_by_artifact_uri(
2985
3317
  self, artifact_uri: str
2986
- ) -> Optional[tuple[str, str, dict]]:
3318
+ ) -> Union[tuple[str, str, dict], tuple[None, None, None]]:
2987
3319
  """
2988
3320
  Get a shared model by its artifact URI.
2989
3321
  :param artifact_uri: The artifact URI of the model.
@@ -2992,9 +3324,9 @@ class RootFlowStep(FlowStep):
2992
3324
  for model_name, (model_class, model_params) in self.shared_models.items():
2993
3325
  if model_params.get("artifact_uri") == artifact_uri:
2994
3326
  return model_name, model_class, model_params
2995
- return None
3327
+ return None, None, None
2996
3328
 
2997
- def config_pool_resource(
3329
+ def configure_shared_pool_resource(
2998
3330
  self,
2999
3331
  max_processes: Optional[int] = None,
3000
3332
  max_threads: Optional[int] = None,
@@ -3002,8 +3334,9 @@ class RootFlowStep(FlowStep):
3002
3334
  ) -> None:
3003
3335
  """
3004
3336
  Configure the resource limits for the shared models in the graph.
3337
+
3005
3338
  :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
3006
- Defaults to the number of CPUs or 16 if undetectable.
3339
+ Defaults to the number of CPUs or 16 if undetectable.
3007
3340
  :param max_threads: Maximum number of threads to spawn. Defaults to 32.
3008
3341
  :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
3009
3342
  """
@@ -3399,7 +3732,7 @@ def params_to_step(
3399
3732
  return name, step
3400
3733
 
3401
3734
 
3402
- def _init_async_objects(context, steps):
3735
+ def _init_async_objects(context, steps, root):
3403
3736
  try:
3404
3737
  import storey
3405
3738
  except ImportError:
@@ -3414,6 +3747,7 @@ def _init_async_objects(context, steps):
3414
3747
 
3415
3748
  for step in steps:
3416
3749
  if hasattr(step, "async_object") and step._is_local_function(context):
3750
+ max_iterations = step._max_iterations or root.max_iterations
3417
3751
  if step.kind == StepKinds.queue:
3418
3752
  skip_stream = context.is_mock and step.next
3419
3753
  if step.path and not skip_stream:
@@ -3432,17 +3766,19 @@ def _init_async_objects(context, steps):
3432
3766
  datastore_profile = datastore_profile_read(stream_path)
3433
3767
  if isinstance(
3434
3768
  datastore_profile,
3435
- (DatastoreProfileKafkaTarget, DatastoreProfileKafkaStream),
3769
+ DatastoreProfileKafkaTarget | DatastoreProfileKafkaStream,
3436
3770
  ):
3437
3771
  step._async_object = KafkaStoreyTarget(
3438
3772
  path=stream_path,
3439
3773
  context=context,
3774
+ max_iterations=max_iterations,
3440
3775
  **options,
3441
3776
  )
3442
3777
  elif isinstance(datastore_profile, DatastoreProfileV3io):
3443
3778
  step._async_object = StreamStoreyTarget(
3444
3779
  stream_path=stream_path,
3445
3780
  context=context,
3781
+ max_iterations=max_iterations,
3446
3782
  **options,
3447
3783
  )
3448
3784
  else:
@@ -3462,10 +3798,15 @@ def _init_async_objects(context, steps):
3462
3798
  brokers=brokers,
3463
3799
  producer_options=kafka_producer_options,
3464
3800
  context=context,
3801
+ max_iterations=max_iterations,
3465
3802
  **options,
3466
3803
  )
3467
3804
  elif stream_path.startswith("dummy://"):
3468
- step._async_object = _DummyStream(context=context, **options)
3805
+ step._async_object = _DummyStream(
3806
+ context=context,
3807
+ max_iterations=max_iterations,
3808
+ **options,
3809
+ )
3469
3810
  else:
3470
3811
  if stream_path.startswith("v3io://"):
3471
3812
  endpoint, stream_path = parse_path(step.path)
@@ -3474,10 +3815,14 @@ def _init_async_objects(context, steps):
3474
3815
  storey.V3ioDriver(endpoint or config.v3io_api),
3475
3816
  stream_path,
3476
3817
  context=context,
3818
+ max_iterations=max_iterations,
3477
3819
  **options,
3478
3820
  )
3479
3821
  else:
3480
- step._async_object = storey.Map(lambda x: x)
3822
+ step._async_object = storey.Map(
3823
+ lambda x: x,
3824
+ max_iterations=max_iterations,
3825
+ )
3481
3826
 
3482
3827
  elif not step.async_object or not hasattr(step.async_object, "_outlets"):
3483
3828
  # if regular class, wrap with storey Map
@@ -3489,6 +3834,8 @@ def _init_async_objects(context, steps):
3489
3834
  name=step.name,
3490
3835
  context=context,
3491
3836
  pass_context=step._inject_context,
3837
+ fn_select_outlets=step._outlets_selector,
3838
+ max_iterations=max_iterations,
3492
3839
  )
3493
3840
  if (
3494
3841
  respond_supported