mlrun 1.10.0rc11__py3-none-any.whl → 1.10.0rc13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. See the advisory details below for more information.

Files changed (59)
  1. mlrun/__init__.py +2 -1
  2. mlrun/__main__.py +7 -1
  3. mlrun/artifacts/base.py +9 -3
  4. mlrun/artifacts/dataset.py +2 -1
  5. mlrun/artifacts/llm_prompt.py +6 -2
  6. mlrun/artifacts/model.py +2 -2
  7. mlrun/common/constants.py +1 -0
  8. mlrun/common/runtimes/constants.py +10 -1
  9. mlrun/common/schemas/__init__.py +1 -1
  10. mlrun/common/schemas/model_monitoring/model_endpoints.py +1 -1
  11. mlrun/common/schemas/serving.py +7 -0
  12. mlrun/config.py +21 -2
  13. mlrun/datastore/__init__.py +3 -1
  14. mlrun/datastore/alibaba_oss.py +1 -1
  15. mlrun/datastore/azure_blob.py +1 -1
  16. mlrun/datastore/base.py +6 -31
  17. mlrun/datastore/datastore.py +109 -33
  18. mlrun/datastore/datastore_profile.py +31 -0
  19. mlrun/datastore/dbfs_store.py +1 -1
  20. mlrun/datastore/google_cloud_storage.py +2 -2
  21. mlrun/datastore/model_provider/__init__.py +13 -0
  22. mlrun/datastore/model_provider/model_provider.py +160 -0
  23. mlrun/datastore/model_provider/openai_provider.py +144 -0
  24. mlrun/datastore/remote_client.py +65 -0
  25. mlrun/datastore/s3.py +1 -1
  26. mlrun/datastore/storeytargets.py +1 -1
  27. mlrun/datastore/utils.py +22 -0
  28. mlrun/datastore/v3io.py +1 -1
  29. mlrun/db/base.py +1 -1
  30. mlrun/db/httpdb.py +9 -4
  31. mlrun/db/nopdb.py +1 -1
  32. mlrun/execution.py +28 -7
  33. mlrun/launcher/base.py +23 -13
  34. mlrun/launcher/local.py +3 -1
  35. mlrun/launcher/remote.py +4 -2
  36. mlrun/model.py +65 -0
  37. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +175 -8
  38. mlrun/package/packagers_manager.py +2 -0
  39. mlrun/projects/operations.py +8 -1
  40. mlrun/projects/pipelines.py +40 -18
  41. mlrun/projects/project.py +28 -5
  42. mlrun/run.py +42 -2
  43. mlrun/runtimes/__init__.py +6 -0
  44. mlrun/runtimes/base.py +24 -6
  45. mlrun/runtimes/daskjob.py +1 -0
  46. mlrun/runtimes/databricks_job/databricks_runtime.py +1 -0
  47. mlrun/runtimes/local.py +1 -6
  48. mlrun/serving/server.py +1 -2
  49. mlrun/serving/states.py +438 -23
  50. mlrun/serving/system_steps.py +27 -29
  51. mlrun/utils/helpers.py +13 -2
  52. mlrun/utils/notifications/notification_pusher.py +15 -0
  53. mlrun/utils/version/version.json +2 -2
  54. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc13.dist-info}/METADATA +2 -2
  55. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc13.dist-info}/RECORD +59 -55
  56. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc13.dist-info}/WHEEL +0 -0
  57. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc13.dist-info}/entry_points.txt +0 -0
  58. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc13.dist-info}/licenses/LICENSE +0 -0
  59. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc13.dist-info}/top_level.txt +0 -0
mlrun/serving/states.py CHANGED
@@ -35,6 +35,7 @@ from storey import ParallelExecutionMechanisms
35
35
  import mlrun
36
36
  import mlrun.artifacts
37
37
  import mlrun.common.schemas as schemas
38
+ from mlrun.artifacts.llm_prompt import LLMPromptArtifact
38
39
  from mlrun.artifacts.model import ModelArtifact
39
40
  from mlrun.datastore.datastore_profile import (
40
41
  DatastoreProfileKafkaSource,
@@ -42,7 +43,7 @@ from mlrun.datastore.datastore_profile import (
42
43
  DatastoreProfileV3io,
43
44
  datastore_profile_read,
44
45
  )
45
- from mlrun.datastore.store_resources import get_store_resource
46
+ from mlrun.datastore.model_provider.model_provider import ModelProvider
46
47
  from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
47
48
  from mlrun.utils import logger
48
49
 
@@ -516,7 +517,7 @@ class BaseStep(ModelObj):
516
517
  "ModelRunnerStep can be added to 'Flow' topology graph only"
517
518
  )
518
519
  step_model_endpoints_names = list(
519
- step.class_args[schemas.ModelRunnerStepData.MODELS].keys()
520
+ step.class_args.get(schemas.ModelRunnerStepData.MODELS, {}).keys()
520
521
  )
521
522
  # Get all model_endpoints names that are in both lists
522
523
  common_endpoints_names = list(
@@ -528,8 +529,77 @@ class BaseStep(ModelObj):
528
529
  raise GraphError(
529
530
  f"The graph already contains the model endpoints named - {common_endpoints_names}."
530
531
  )
532
+
533
+ # Check if shared models are defined in the graph
534
+ self._verify_shared_models(root, step, step_model_endpoints_names)
535
+ # Update model endpoints names in the root step
531
536
  root.update_model_endpoints_names(step_model_endpoints_names)
532
537
 
538
+ @staticmethod
539
+ def _verify_shared_models(
540
+ root: "RootFlowStep",
541
+ step: "ModelRunnerStep",
542
+ step_model_endpoints_names: list[str],
543
+ ) -> None:
544
+ proxy_endpoints = [
545
+ name
546
+ for name in step_model_endpoints_names
547
+ if step.class_args.get(
548
+ schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM, {}
549
+ ).get(name)
550
+ == ParallelExecutionMechanisms.shared_executor
551
+ ]
552
+ shared_models = []
553
+
554
+ for name in proxy_endpoints:
555
+ shared_runnable_name = (
556
+ step.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
557
+ .get(name, ["", {}])[schemas.ModelsData.MODEL_PARAMETERS.value]
558
+ .get("shared_runnable_name")
559
+ )
560
+ model_artifact_uri = (
561
+ step.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
562
+ .get(name, ["", {}])[schemas.ModelsData.MODEL_PARAMETERS.value]
563
+ .get("artifact_uri")
564
+ )
565
+ prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
566
+ # if the model artifact is a prompt, we need to get the model URI
567
+ # to ensure that the shared runnable name is correct
568
+ if prefix == mlrun.utils.StorePrefix.LLMPrompt:
569
+ llm_artifact, _ = mlrun.store_manager.get_store_artifact(
570
+ model_artifact_uri
571
+ )
572
+ model_artifact_uri = llm_artifact.spec.parent_uri
573
+ actual_shared_name = root.get_shared_model_name_by_artifact_uri(
574
+ model_artifact_uri
575
+ )
576
+
577
+ if not shared_runnable_name:
578
+ if not actual_shared_name:
579
+ raise GraphError(
580
+ f"Can't find shared model for {name} model endpoint"
581
+ )
582
+ else:
583
+ step.class_args[schemas.ModelRunnerStepData.MODELS][name][
584
+ schemas.ModelsData.MODEL_PARAMETERS.value
585
+ ]["shared_runnable_name"] = actual_shared_name
586
+ shared_models.append(actual_shared_name)
587
+ elif actual_shared_name != shared_runnable_name:
588
+ raise GraphError(
589
+ f"Model endpoint {name} shared runnable name mismatch: "
590
+ f"expected {actual_shared_name}, got {shared_runnable_name}"
591
+ )
592
+ else:
593
+ shared_models.append(actual_shared_name)
594
+
595
+ undefined_shared_models = list(
596
+ set(shared_models) - set(root.shared_models.keys())
597
+ )
598
+ if undefined_shared_models:
599
+ raise GraphError(
600
+ f"The following shared models are not defined in the graph: {undefined_shared_models}."
601
+ )
602
+
533
603
 
534
604
  class TaskStep(BaseStep):
535
605
  """task execution step, runs a class or handler"""
@@ -1006,7 +1076,13 @@ class RouterStep(TaskStep):
1006
1076
 
1007
1077
 
1008
1078
  class Model(storey.ParallelExecutionRunnable, ModelObj):
1009
- _dict_fields = ["name", "raise_exception", "artifact_uri"]
1079
+ _dict_fields = [
1080
+ "name",
1081
+ "raise_exception",
1082
+ "artifact_uri",
1083
+ "shared_runnable_name",
1084
+ ]
1085
+ kind = "model"
1010
1086
 
1011
1087
  def __init__(
1012
1088
  self,
@@ -1019,6 +1095,9 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1019
1095
  if artifact_uri is not None and not isinstance(artifact_uri, str):
1020
1096
  raise MLRunInvalidArgumentError("'artifact_uri' argument must be a string")
1021
1097
  self.artifact_uri = artifact_uri
1098
+ self.invocation_artifact: Optional[LLMPromptArtifact] = None
1099
+ self.model_artifact: Optional[ModelArtifact] = None
1100
+ self.model_provider: Optional[ModelProvider] = None
1022
1101
 
1023
1102
  def __init_subclass__(cls):
1024
1103
  super().__init_subclass__()
@@ -1030,12 +1109,27 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1030
1109
 
1031
1110
  def load(self) -> None:
1032
1111
  """Override to load model if needed."""
1033
- pass
1112
+ self._load_artifacts()
1113
+ if self.model_artifact:
1114
+ self.model_provider = mlrun.get_model_provider(
1115
+ url=self.model_artifact.model_url,
1116
+ default_invoke_kwargs=self.model_artifact.default_config,
1117
+ raise_missing_schema_exception=False,
1118
+ )
1119
+
1120
+ def _load_artifacts(self) -> None:
1121
+ artifact = self._get_artifact_object()
1122
+ if isinstance(artifact, LLMPromptArtifact):
1123
+ self.invocation_artifact = artifact
1124
+ self.model_artifact = self.invocation_artifact.model_artifact
1125
+ else:
1126
+ self.model_artifact = artifact
1034
1127
 
1035
- def _get_artifact_object(self) -> Union[ModelArtifact, None]:
1128
+ def _get_artifact_object(self) -> Union[ModelArtifact, LLMPromptArtifact, None]:
1036
1129
  if self.artifact_uri:
1037
1130
  if mlrun.datastore.is_store_uri(self.artifact_uri):
1038
- return get_store_resource(self.artifact_uri)
1131
+ artifact, _ = mlrun.store_manager.get_store_artifact(self.artifact_uri)
1132
+ return artifact
1039
1133
  else:
1040
1134
  raise ValueError(
1041
1135
  "Could not get artifact, 'artifact_uri' must be a valid artifact store URI"
@@ -1058,7 +1152,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1058
1152
  return self.predict(body)
1059
1153
 
1060
1154
  async def run_async(self, body: Any, path: str) -> Any:
1061
- return self.predict(body)
1155
+ return await self.predict_async(body)
1062
1156
 
1063
1157
  def get_local_model_path(self, suffix="") -> (str, dict):
1064
1158
  """
@@ -1218,14 +1312,105 @@ class ModelRunnerStep(MonitoredStep):
1218
1312
  self.raise_exception = raise_exception
1219
1313
  self.shape = "folder"
1220
1314
 
1315
+ def add_shared_model_proxy(
1316
+ self,
1317
+ endpoint_name: str,
1318
+ model_artifact: Union[str, ModelArtifact, LLMPromptArtifact],
1319
+ shared_model_name: Optional[str] = None,
1320
+ labels: Optional[Union[list[str], dict[str, str]]] = None,
1321
+ model_endpoint_creation_strategy: Optional[
1322
+ schemas.ModelEndpointCreationStrategy
1323
+ ] = schemas.ModelEndpointCreationStrategy.INPLACE,
1324
+ inputs: Optional[list[str]] = None,
1325
+ outputs: Optional[list[str]] = None,
1326
+ input_path: Optional[str] = None,
1327
+ result_path: Optional[str] = None,
1328
+ override: bool = False,
1329
+ ) -> None:
1330
+ """
1331
+ Add a proxy model to the ModelRunnerStep, which is a proxy for a model that is already defined as shared model
1332
+ within the graph
1333
+
1334
+ :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
1335
+ :param model_artifact: model artifact or mlrun model artifact uri, according to the model artifact
1336
+ we will match the model endpoint to the correct shared model.
1337
+ :param shared_model_name: str, the name of the shared model that is already defined within the graph
1338
+ :param labels: model endpoint labels, should be list of str or mapping of str:str
1339
+ :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
1340
+ * **overwrite**:
1341
+ 1. If model endpoints with the same name exist, delete the `latest` one.
1342
+ 2. Create a new model endpoint entry and set it as `latest`.
1343
+ * **inplace** (default):
1344
+ 1. If model endpoints with the same name exist, update the `latest` entry.
1345
+ 2. Otherwise, create a new entry.
1346
+ * **archive**:
1347
+ 1. If model endpoints with the same name exist, preserve them.
1348
+ 2. Create a new model endpoint with the same name and set it to `latest`.
1349
+
1350
+ :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
1351
+ that been configured in the model artifact, please note that those inputs need to
1352
+ be equal in length and order to the inputs that model_class predict method expects
1353
+ :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
1354
+ that been configured in the model artifact, please note that those outputs need to
1355
+ be equal to the model_class predict method outputs (length, and order)
1356
+ :param input_path: input path inside the user event, expect scopes to be defined by dot notation
1357
+ (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
1358
+ :param result_path: result path inside the user output event, expect scopes to be defined by dot
1359
+ notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
1360
+ in path.
1361
+ :param override: bool allow override existing model on the current ModelRunnerStep.
1362
+ """
1363
+ model_class = Model(
1364
+ name=endpoint_name,
1365
+ shared_runnable_name=shared_model_name,
1366
+ )
1367
+ if isinstance(model_artifact, str):
1368
+ model_artifact_uri = model_artifact
1369
+ elif isinstance(model_artifact, ModelArtifact):
1370
+ model_artifact_uri = model_artifact.uri
1371
+ elif isinstance(model_artifact, LLMPromptArtifact):
1372
+ model_artifact_uri = model_artifact.model_artifact.uri
1373
+ else:
1374
+ raise MLRunInvalidArgumentError(
1375
+ "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
1376
+ )
1377
+ root = self._extract_root_step()
1378
+ if isinstance(root, RootFlowStep):
1379
+ shared_model_name = (
1380
+ shared_model_name
1381
+ or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
1382
+ )
1383
+ if not root.shared_models or (
1384
+ root.shared_models
1385
+ and shared_model_name
1386
+ and shared_model_name not in root.shared_models.keys()
1387
+ ):
1388
+ raise GraphError(
1389
+ f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
1390
+ f"model {shared_model_name} is not in the shared models."
1391
+ )
1392
+ self.add_model(
1393
+ endpoint_name=endpoint_name,
1394
+ model_class=model_class,
1395
+ execution_mechanism=ParallelExecutionMechanisms.shared_executor,
1396
+ model_artifact=model_artifact,
1397
+ labels=labels,
1398
+ model_endpoint_creation_strategy=model_endpoint_creation_strategy,
1399
+ override=override,
1400
+ inputs=inputs,
1401
+ outputs=outputs,
1402
+ input_path=input_path,
1403
+ result_path=result_path,
1404
+ )
1405
+
1221
1406
  def add_model(
1222
1407
  self,
1223
1408
  endpoint_name: str,
1224
1409
  model_class: Union[str, Model],
1225
1410
  execution_mechanism: Union[str, ParallelExecutionMechanisms],
1226
- model_artifact: Optional[Union[str, mlrun.artifacts.ModelArtifact]] = None,
1411
+ model_artifact: Optional[Union[str, ModelArtifact, LLMPromptArtifact]] = None,
1227
1412
  labels: Optional[Union[list[str], dict[str, str]]] = None,
1228
- creation_strategy: Optional[
1413
+ model_endpoint_creation_strategy: Optional[
1229
1414
  schemas.ModelEndpointCreationStrategy
1230
1415
  ] = schemas.ModelEndpointCreationStrategy.INPLACE,
1231
1416
  inputs: Optional[list[str]] = None,
@@ -1263,7 +1448,7 @@ class ModelRunnerStep(MonitoredStep):
1263
1448
 
1264
1449
  :param model_artifact: model artifact or mlrun model artifact uri
1265
1450
  :param labels: model endpoint labels, should be list of str or mapping of str:str
1266
- :param creation_strategy: Strategy for creating or updating the model endpoint:
1451
+ :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
1267
1452
  * **overwrite**:
1268
1453
  1. If model endpoints with the same name exist, delete the `latest` one.
1269
1454
  2. Create a new model endpoint entry and set it as `latest`.
@@ -1288,7 +1473,6 @@ class ModelRunnerStep(MonitoredStep):
1288
1473
  :param override: bool allow override existing model on the current ModelRunnerStep.
1289
1474
  :param model_parameters: Parameters for model instantiation
1290
1475
  """
1291
-
1292
1476
  if isinstance(model_class, Model) and model_parameters:
1293
1477
  raise mlrun.errors.MLRunInvalidArgumentError(
1294
1478
  "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
@@ -1297,10 +1481,20 @@ class ModelRunnerStep(MonitoredStep):
1297
1481
  model_parameters = model_parameters or (
1298
1482
  model_class.to_dict() if isinstance(model_class, Model) else {}
1299
1483
  )
1300
- if outputs is None and isinstance(
1301
- model_artifact, mlrun.artifacts.ModelArtifact
1484
+
1485
+ if isinstance(
1486
+ model_artifact,
1487
+ str,
1302
1488
  ):
1303
- outputs = [feature.name for feature in model_artifact.spec.outputs]
1489
+ try:
1490
+ model_artifact, _ = mlrun.store_manager.get_store_artifact(
1491
+ model_artifact
1492
+ )
1493
+ except mlrun.errors.MLRunNotFoundError:
1494
+ raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.")
1495
+
1496
+ outputs = outputs or self._get_model_output_schema(model_artifact)
1497
+
1304
1498
  model_artifact = (
1305
1499
  model_artifact.uri
1306
1500
  if isinstance(model_artifact, mlrun.artifacts.Artifact)
@@ -1347,7 +1541,7 @@ class ModelRunnerStep(MonitoredStep):
1347
1541
  schemas.MonitoringData.OUTPUTS: outputs,
1348
1542
  schemas.MonitoringData.INPUT_PATH: input_path,
1349
1543
  schemas.MonitoringData.RESULT_PATH: result_path,
1350
- schemas.MonitoringData.CREATION_STRATEGY: creation_strategy,
1544
+ schemas.MonitoringData.CREATION_STRATEGY: model_endpoint_creation_strategy,
1351
1545
  schemas.MonitoringData.LABELS: labels,
1352
1546
  schemas.MonitoringData.MODEL_PATH: model_artifact,
1353
1547
  schemas.MonitoringData.MODEL_CLASS: model_class,
@@ -1357,14 +1551,44 @@ class ModelRunnerStep(MonitoredStep):
1357
1551
 
1358
1552
  @staticmethod
1359
1553
  def _get_model_output_schema(
1360
- model: str, monitoring_data: dict[str, dict[str, str]]
1554
+ model_artifact: Union[ModelArtifact, LLMPromptArtifact],
1555
+ ) -> Optional[list[str]]:
1556
+ if isinstance(
1557
+ model_artifact,
1558
+ ModelArtifact,
1559
+ ):
1560
+ return [feature.name for feature in model_artifact.spec.outputs]
1561
+ elif isinstance(
1562
+ model_artifact,
1563
+ LLMPromptArtifact,
1564
+ ):
1565
+ _model_artifact = model_artifact.model_artifact
1566
+ return [feature.name for feature in _model_artifact.spec.outputs]
1567
+
1568
+ @staticmethod
1569
+ def _get_model_endpoint_output_schema(
1570
+ name: str,
1571
+ project: str,
1572
+ uid: str,
1361
1573
  ) -> list[str]:
1362
1574
  output_schema = None
1363
- if monitoring_data[model].get(schemas.MonitoringData.MODEL_PATH) is not None:
1364
- artifact = get_store_resource(
1365
- monitoring_data[model].get(schemas.MonitoringData.MODEL_PATH)
1575
+ try:
1576
+ model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = (
1577
+ mlrun.db.get_run_db().get_model_endpoint(
1578
+ name=name,
1579
+ project=project,
1580
+ endpoint_id=uid,
1581
+ tsdb_metrics=False,
1582
+ )
1583
+ )
1584
+ output_schema = model_endpoint.spec.label_names
1585
+ except (
1586
+ mlrun.errors.MLRunNotFoundError,
1587
+ mlrun.errors.MLRunInvalidArgumentError,
1588
+ ):
1589
+ logger.warning(
1590
+ f"Model endpoint not found, using default output schema for model {name}"
1366
1591
  )
1367
- output_schema = [feature.name for feature in artifact.spec.outputs]
1368
1592
  return output_schema
1369
1593
 
1370
1594
  @staticmethod
@@ -1385,8 +1609,14 @@ class ModelRunnerStep(MonitoredStep):
1385
1609
  if isinstance(monitoring_data, dict):
1386
1610
  for model in monitoring_data:
1387
1611
  monitoring_data[model][schemas.MonitoringData.OUTPUTS] = (
1388
- monitoring_data[model][schemas.MonitoringData.OUTPUTS]
1389
- or self._get_model_output_schema(model, monitoring_data)
1612
+ monitoring_data.get(model, {}).get(schemas.MonitoringData.OUTPUTS)
1613
+ or self._get_model_endpoint_output_schema(
1614
+ name=model,
1615
+ project=self.context.project if self.context else None,
1616
+ uid=monitoring_data.get(model, {}).get(
1617
+ mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID
1618
+ ),
1619
+ )
1390
1620
  )
1391
1621
  # Prevent calling _get_model_output_schema for same model more than once
1392
1622
  self.class_args[
@@ -1407,6 +1637,10 @@ class ModelRunnerStep(MonitoredStep):
1407
1637
  return monitoring_data
1408
1638
 
1409
1639
  def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
1640
+ self.context = context
1641
+ if not self._is_local_function(context):
1642
+ # skip init of non local functions
1643
+ return
1410
1644
  model_selector = self.class_args.get("model_selector")
1411
1645
  execution_mechanism_by_model_name = self.class_args.get(
1412
1646
  schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
@@ -1748,7 +1982,7 @@ class FlowStep(BaseStep):
1748
1982
  self._insert_all_error_handlers()
1749
1983
  self.check_and_process_graph()
1750
1984
 
1751
- for step in self._steps.values():
1985
+ for step in self.steps.values():
1752
1986
  step.set_parent(self)
1753
1987
  step.init_object(context, namespace, mode, reset=reset)
1754
1988
  self._set_error_handler()
@@ -2111,6 +2345,11 @@ class RootFlowStep(FlowStep):
2111
2345
  "model_endpoints_names",
2112
2346
  "model_endpoints_routes_names",
2113
2347
  "track_models",
2348
+ "shared_max_processes",
2349
+ "shared_max_threads",
2350
+ "shared_models",
2351
+ "shared_models_mechanism",
2352
+ "pool_factor",
2114
2353
  ]
2115
2354
 
2116
2355
  def __init__(
@@ -2131,6 +2370,140 @@ class RootFlowStep(FlowStep):
2131
2370
  self._models = set()
2132
2371
  self._route_models = set()
2133
2372
  self._track_models = False
2373
+ self._shared_models: dict[str, tuple[str, dict]] = {}
2374
+ self._shared_models_mechanism: dict[str, ParallelExecutionMechanisms] = {}
2375
+ self._shared_max_processes = None
2376
+ self._shared_max_threads = None
2377
+ self._pool_factor = None
2378
+
2379
+ def add_shared_model(
2380
+ self,
2381
+ name: str,
2382
+ model_class: Union[str, Model],
2383
+ execution_mechanism: Union[str, ParallelExecutionMechanisms],
2384
+ model_artifact: Optional[Union[str, ModelArtifact]],
2385
+ override: bool = False,
2386
+ **model_parameters,
2387
+ ) -> None:
2388
+ """
2389
+ Add a shared model to the graph, this model will be available to all the ModelRunners in the graph
2390
+ :param name: Name of the shared model (should be unique in the graph)
2391
+ :param model_class: Model class name
2392
+ :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
2393
+ * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
2394
+ intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
2395
+ Lock (GIL).
2396
+ * "dedicated_process" – To run in a separate dedicated process. This is appropriate for CPU or GPU intensive
2397
+ tasks that also require significant Runnable-specific initialization (e.g. a large model).
2398
+ * "thread_pool" – To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
2399
+ otherwise block the main event loop thread.
2400
+ * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
2401
+ event loop to continue running while waiting for a response.
2402
+ * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
2403
+ runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
2404
+ useful when:
2405
+ - You want to share a heavy resource like a large model loaded onto a GPU.
2406
+ - You want to centralize task scheduling or coordination for multiple lightweight tasks.
2407
+ - You aim to minimize overhead from creating new executors or processes/threads per runnable.
2408
+ The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
2409
+ memory and hardware accelerators.
2410
+ * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
2411
+ It means that the runnable will not actually be run in parallel to anything else.
2412
+
2413
+ :param model_artifact: model artifact or mlrun model artifact uri
2414
+ :param override: bool allow override existing model on the current ModelRunnerStep.
2415
+ :param model_parameters: Parameters for model instantiation
2416
+ """
2417
+ if isinstance(model_class, Model) and model_parameters:
2418
+ raise mlrun.errors.MLRunInvalidArgumentError(
2419
+ "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
2420
+ )
2421
+
2422
+ if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
2423
+ raise mlrun.errors.MLRunInvalidArgumentError(
2424
+ "Cannot add a shared model with execution mechanism 'shared_executor'"
2425
+ )
2426
+ ParallelExecutionMechanisms.validate(execution_mechanism)
2427
+
2428
+ model_parameters = model_parameters or (
2429
+ model_class.to_dict() if isinstance(model_class, Model) else {}
2430
+ )
2431
+ model_artifact = (
2432
+ model_artifact.uri
2433
+ if isinstance(model_artifact, mlrun.artifacts.Artifact)
2434
+ else model_artifact
2435
+ )
2436
+ model_parameters["artifact_uri"] = model_parameters.get(
2437
+ "artifact_uri", model_artifact
2438
+ )
2439
+
2440
+ if model_parameters.get("name", name) != name or (
2441
+ isinstance(model_class, Model) and model_class.name != name
2442
+ ):
2443
+ raise mlrun.errors.MLRunInvalidArgumentError(
2444
+ "Inconsistent name for the added model."
2445
+ )
2446
+ model_parameters["name"] = name
2447
+
2448
+ if name in self.shared_models and not override:
2449
+ raise mlrun.errors.MLRunInvalidArgumentError(
2450
+ f"Model with name {name} already exists in this graph."
2451
+ )
2452
+
2453
+ model_class = (
2454
+ model_class
2455
+ if isinstance(model_class, str)
2456
+ else model_class.__class__.__name__
2457
+ )
2458
+ self.shared_models[name] = (model_class, model_parameters)
2459
+ self.shared_models_mechanism[name] = execution_mechanism
2460
+
2461
+ def get_shared_model_name_by_artifact_uri(self, artifact_uri: str) -> Optional[str]:
2462
+ """
2463
+ Get a shared model by its artifact URI.
2464
+ :param artifact_uri: The artifact URI of the model.
2465
+ :return: A tuple of (model_class, model_parameters) if found, otherwise None.
2466
+ """
2467
+ for model_name, (model_class, model_params) in self.shared_models.items():
2468
+ if model_params.get("artifact_uri") == artifact_uri:
2469
+ return model_name
2470
+ return None
2471
+
2472
+ def config_pool_resource(
2473
+ self,
2474
+ max_processes: Optional[int] = None,
2475
+ max_threads: Optional[int] = None,
2476
+ pool_factor: Optional[int] = None,
2477
+ ) -> None:
2478
+ """
2479
+ Configure the resource limits for the shared models in the graph.
2480
+ :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
2481
+ Defaults to the number of CPUs or 16 if undetectable.
2482
+ :param max_threads: Maximum number of threads to spawn. Defaults to 32.
2483
+ :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
2484
+ """
2485
+ self.shared_max_processes = max_processes
2486
+ self.shared_max_threads = max_threads
2487
+ self.pool_factor = pool_factor
2488
+
2489
+ def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
2490
+ self.context = context
2491
+ if self.shared_models:
2492
+ self.context.executor = storey.flow.RunnableExecutor(
2493
+ max_processes=self.shared_max_processes,
2494
+ max_threads=self.shared_max_threads,
2495
+ pool_factor=self.pool_factor,
2496
+ )
2497
+
2498
+ for model, model_params in self.shared_models.values():
2499
+ model = get_class(model, namespace).from_dict(
2500
+ model_params, init_with_params=True
2501
+ )
2502
+ model._raise_exception = False
2503
+ self.context.executor.add_runnable(
2504
+ model, self._shared_models_mechanism[model.name]
2505
+ )
2506
+ super().init_object(context, namespace, mode, reset=reset, **extra_kwargs)
2134
2507
 
2135
2508
  @property
2136
2509
  def model_endpoints_names(self) -> list[str]:
@@ -2159,6 +2532,48 @@ class RootFlowStep(FlowStep):
2159
2532
  def track_models(self, track_models: bool):
2160
2533
  self._track_models = track_models
2161
2534
 
2535
+ @property
2536
+ def shared_models(self) -> dict[str, tuple[str, dict]]:
2537
+ return self._shared_models
2538
+
2539
+ @shared_models.setter
2540
+ def shared_models(self, shared_models: dict[str, tuple[str, dict]]):
2541
+ self._shared_models = shared_models
2542
+
2543
+ @property
2544
+ def shared_models_mechanism(self) -> dict[str, ParallelExecutionMechanisms]:
2545
+ return self._shared_models_mechanism
2546
+
2547
+ @shared_models_mechanism.setter
2548
+ def shared_models_mechanism(
2549
+ self, shared_models_mechanism: dict[str, ParallelExecutionMechanisms]
2550
+ ):
2551
+ self._shared_models_mechanism = shared_models_mechanism
2552
+
2553
+ @property
2554
+ def shared_max_processes(self) -> Optional[int]:
2555
+ return self._shared_max_processes
2556
+
2557
+ @shared_max_processes.setter
2558
+ def shared_max_processes(self, max_processes: Optional[int]):
2559
+ self._shared_max_processes = max_processes
2560
+
2561
+ @property
2562
+ def shared_max_threads(self) -> Optional[int]:
2563
+ return self._shared_max_threads
2564
+
2565
+ @shared_max_threads.setter
2566
+ def shared_max_threads(self, max_threads: Optional[int]):
2567
+ self._shared_max_threads = max_threads
2568
+
2569
+ @property
2570
+ def pool_factor(self) -> Optional[int]:
2571
+ return self._pool_factor
2572
+
2573
+ @pool_factor.setter
2574
+ def pool_factor(self, pool_factor: Optional[int]):
2575
+ self._pool_factor = pool_factor
2576
+
2162
2577
  def update_model_endpoints_routes_names(self, model_endpoints_names: list):
2163
2578
  self._route_models.update(model_endpoints_names)
2164
2579