mlrun 1.10.0rc12__py3-none-any.whl → 1.10.0rc14__py3-none-any.whl

This diff shows the changes between these publicly released package versions as they appear in their public registries, and is provided for informational purposes only.

Potentially problematic release.


This version of mlrun might be problematic.

mlrun/serving/states.py CHANGED
@@ -44,7 +44,6 @@ from mlrun.datastore.datastore_profile import (
     datastore_profile_read,
 )
 from mlrun.datastore.model_provider.model_provider import ModelProvider
-from mlrun.datastore.store_resources import get_store_resource
 from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
 from mlrun.utils import logger
 
@@ -518,7 +517,7 @@ class BaseStep(ModelObj):
                     "ModelRunnerStep can be added to 'Flow' topology graph only"
                 )
             step_model_endpoints_names = list(
-                step.class_args[schemas.ModelRunnerStepData.MODELS].keys()
+                step.class_args.get(schemas.ModelRunnerStepData.MODELS, {}).keys()
             )
             # Get all model_endpoints names that are in both lists
             common_endpoints_names = list(
@@ -530,8 +529,77 @@ class BaseStep(ModelObj):
                 raise GraphError(
                     f"The graph already contains the model endpoints named - {common_endpoints_names}."
                 )
+
+            # Check if shared models are defined in the graph
+            self._verify_shared_models(root, step, step_model_endpoints_names)
+            # Update model endpoints names in the root step
             root.update_model_endpoints_names(step_model_endpoints_names)
 
+    @staticmethod
+    def _verify_shared_models(
+        root: "RootFlowStep",
+        step: "ModelRunnerStep",
+        step_model_endpoints_names: list[str],
+    ) -> None:
+        proxy_endpoints = [
+            name
+            for name in step_model_endpoints_names
+            if step.class_args.get(
+                schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM, {}
+            ).get(name)
+            == ParallelExecutionMechanisms.shared_executor
+        ]
+        shared_models = []
+
+        for name in proxy_endpoints:
+            shared_runnable_name = (
+                step.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
+                .get(name, ["", {}])[schemas.ModelsData.MODEL_PARAMETERS.value]
+                .get("shared_runnable_name")
+            )
+            model_artifact_uri = (
+                step.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
+                .get(name, ["", {}])[schemas.ModelsData.MODEL_PARAMETERS.value]
+                .get("artifact_uri")
+            )
+            prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
+            # if the model artifact is a prompt, we need to get the model URI
+            # to ensure that the shared runnable name is correct
+            if prefix == mlrun.utils.StorePrefix.LLMPrompt:
+                llm_artifact, _ = mlrun.store_manager.get_store_artifact(
+                    model_artifact_uri
+                )
+                model_artifact_uri = llm_artifact.spec.parent_uri
+            actual_shared_name = root.get_shared_model_name_by_artifact_uri(
+                model_artifact_uri
+            )
+
+            if not shared_runnable_name:
+                if not actual_shared_name:
+                    raise GraphError(
+                        f"Can't find shared model for {name} model endpoint"
+                    )
+                else:
+                    step.class_args[schemas.ModelRunnerStepData.MODELS][name][
+                        schemas.ModelsData.MODEL_PARAMETERS.value
+                    ]["shared_runnable_name"] = actual_shared_name
+                    shared_models.append(actual_shared_name)
+            elif actual_shared_name != shared_runnable_name:
+                raise GraphError(
+                    f"Model endpoint {name} shared runnable name mismatch: "
+                    f"expected {actual_shared_name}, got {shared_runnable_name}"
+                )
+            else:
+                shared_models.append(actual_shared_name)
+
+        undefined_shared_models = list(
+            set(shared_models) - set(root.shared_models.keys())
+        )
+        if undefined_shared_models:
+            raise GraphError(
+                f"The following shared models are not defined in the graph: {undefined_shared_models}."
+            )
+
 
 class TaskStep(BaseStep):
     """task execution step, runs a class or handler"""
@@ -1008,19 +1076,30 @@ class RouterStep(TaskStep):
 
 
 class Model(storey.ParallelExecutionRunnable, ModelObj):
-    _dict_fields = ["name", "raise_exception", "artifact_uri"]
+    _dict_fields = [
+        "name",
+        "raise_exception",
+        "artifact_uri",
+        "shared_runnable_name",
+        "shared_proxy_mapping",
+    ]
+    kind = "model"
 
     def __init__(
         self,
         name: str,
         raise_exception: bool = True,
         artifact_uri: Optional[str] = None,
+        shared_proxy_mapping: Optional[dict] = None,
         **kwargs,
     ):
         super().__init__(name=name, raise_exception=raise_exception, **kwargs)
         if artifact_uri is not None and not isinstance(artifact_uri, str):
             raise MLRunInvalidArgumentError("'artifact_uri' argument must be a string")
         self.artifact_uri = artifact_uri
+        self.shared_proxy_mapping: dict[
+            str : Union[str, ModelArtifact, LLMPromptArtifact]
+        ] = shared_proxy_mapping
         self.invocation_artifact: Optional[LLMPromptArtifact] = None
         self.model_artifact: Optional[ModelArtifact] = None
         self.model_provider: Optional[ModelProvider] = None
@@ -1051,10 +1130,13 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         else:
             self.model_artifact = artifact
 
-    def _get_artifact_object(self) -> Union[ModelArtifact, LLMPromptArtifact, None]:
-        if self.artifact_uri:
-            if mlrun.datastore.is_store_uri(self.artifact_uri):
-                artifact, _ = mlrun.store_manager.get_store_artifact(self.artifact_uri)
+    def _get_artifact_object(
+        self, proxy_uri: Optional[str] = None
+    ) -> Union[ModelArtifact, LLMPromptArtifact, None]:
+        uri = proxy_uri or self.artifact_uri
+        if uri:
+            if mlrun.datastore.is_store_uri(uri):
+                artifact, _ = mlrun.store_manager.get_store_artifact(uri)
                 return artifact
             else:
                 raise ValueError(
@@ -1074,10 +1156,12 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         """Override to implement prediction logic if the logic requires asyncio."""
         return body
 
-    def run(self, body: Any, path: str) -> Any:
+    def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
         return self.predict(body)
 
-    async def run_async(self, body: Any, path: str) -> Any:
+    async def run_async(
+        self, body: Any, path: str, origin_name: Optional[str] = None
+    ) -> Any:
         return await self.predict_async(body)
 
     def get_local_model_path(self, suffix="") -> (str, dict):
@@ -1112,6 +1196,65 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         return None, None
 
 
+class LLModel(Model):
+    def __init__(self, name: str, **kwargs):
+        super().__init__(name, **kwargs)
+
+    def predict(
+        self, body: Any, messages: list[dict], model_configuration: dict
+    ) -> Any:
+        return body
+
+    async def predict_async(
+        self, body: Any, messages: list[dict], model_configuration: dict
+    ) -> Any:
+        return body
+
+    def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
+        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        return self.predict(
+            body, messages=messages, model_configuration=model_configuration
+        )
+
+    async def run_async(
+        self, body: Any, path: str, origin_name: Optional[str] = None
+    ) -> Any:
+        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        return await self.predict_async(
+            body, messages=messages, model_configuration=model_configuration
+        )
+
+    def enrich_prompt(
+        self, body: dict, origin_name: str
+    ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
+        if origin_name and self.shared_proxy_mapping:
+            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+            if isinstance(llm_prompt_artifact, str):
+                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+        else:
+            llm_prompt_artifact = (
+                self.invocation_artifact or self._get_artifact_object()
+            )
+        if not (
+            llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
+        ):
+            logger.warning(
+                "LLMModel must be provided with LLMPromptArtifact",
+                llm_prompt_artifact=llm_prompt_artifact,
+            )
+            return None, None
+        prompt_legend = llm_prompt_artifact.spec.prompt_legend
+        prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+        kwargs = {
+            place_holder: body.get(body_map["field"])
+            for place_holder, body_map in prompt_legend.items()
+        }
+        for d in prompt_template:
+            d["content"] = d["content"].format(**kwargs)
+        return prompt_template, llm_prompt_artifact.spec.model_configuration
+
+
 class ModelSelector:
     """Used to select which models to run on each event."""
 
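To see what `enrich_prompt` does with an `LLMPromptArtifact`, here is a minimal standalone sketch of the same templating step in plain Python; the legend and template shapes mirror what the method reads from `llm_prompt_artifact.spec.prompt_legend` and `read_prompt()`, while the field names and message contents are made-up placeholders:

from copy import deepcopy

# prompt_legend maps a template placeholder to {"field": <key in the event body>};
# the prompt template is a list of chat messages whose "content" holds {placeholder} slots.
prompt_legend = {"question": {"field": "user_question"}}
prompt_template = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Answer briefly: {question}"},
]
body = {"user_question": "What is MLRun?"}

# Build the format() kwargs from the event body, then fill each message content.
kwargs = {
    placeholder: body.get(body_map["field"])
    for placeholder, body_map in prompt_legend.items()
}
messages = deepcopy(prompt_template)
for message in messages:
    message["content"] = message["content"].format(**kwargs)
# messages[1]["content"] is now "Answer briefly: What is MLRun?"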
@@ -1218,6 +1361,7 @@ class ModelRunnerStep(MonitoredStep):
     """
 
     kind = "model_runner"
+    _dict_fields = MonitoredStep._dict_fields + ["_shared_proxy_mapping"]
 
     def __init__(
         self,
@@ -1237,17 +1381,122 @@ class ModelRunnerStep(MonitoredStep):
         )
         self.raise_exception = raise_exception
         self.shape = "folder"
+        self._shared_proxy_mapping = {}
+
+    def add_shared_model_proxy(
+        self,
+        endpoint_name: str,
+        model_artifact: Union[str, ModelArtifact, LLMPromptArtifact],
+        shared_model_name: Optional[str] = None,
+        labels: Optional[Union[list[str], dict[str, str]]] = None,
+        model_endpoint_creation_strategy: Optional[
+            schemas.ModelEndpointCreationStrategy
+        ] = schemas.ModelEndpointCreationStrategy.INPLACE,
+        inputs: Optional[list[str]] = None,
+        outputs: Optional[list[str]] = None,
+        input_path: Optional[str] = None,
+        result_path: Optional[str] = None,
+        override: bool = False,
+    ) -> None:
+        """
+        Add a proxy model to the ModelRunnerStep, which is a proxy for a model that is already defined as shared model
+        within the graph
+
+        :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
+        :param model_artifact: model artifact or mlrun model artifact uri, according to the model artifact
+                               we will match the model endpoint to the correct shared model.
+        :param shared_model_name: str, the name of the shared model that is already defined within the graph
+        :param labels: model endpoint labels, should be list of str or mapping of str:str
+        :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
+            * **overwrite**:
+                1. If model endpoints with the same name exist, delete the `latest` one.
+                2. Create a new model endpoint entry and set it as `latest`.
+            * **inplace** (default):
+                1. If model endpoints with the same name exist, update the `latest` entry.
+                2. Otherwise, create a new entry.
+            * **archive**:
+                1. If model endpoints with the same name exist, preserve them.
+                2. Create a new model endpoint with the same name and set it to `latest`.
+
+        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
+                       that been configured in the model artifact, please note that those inputs need to
+                       be equal in length and order to the inputs that model_class predict method expects
+        :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
+                        that been configured in the model artifact, please note that those outputs need to
+                        be equal to the model_class predict method outputs (length, and order)
+        :param input_path: input path inside the user event, expect scopes to be defined by dot notation
+                           (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
+        :param result_path: result path inside the user output event, expect scopes to be defined by dot
+                            notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
+                            in path.
+        :param override: bool allow override existing model on the current ModelRunnerStep.
+        """
+        model_class, model_params = (
+            "mlrun.serving.Model",
+            {"name": endpoint_name, "shared_runnable_name": shared_model_name},
+        )
+        if isinstance(model_artifact, str):
+            model_artifact_uri = model_artifact
+        elif isinstance(model_artifact, ModelArtifact):
+            model_artifact_uri = model_artifact.uri
+        elif isinstance(model_artifact, LLMPromptArtifact):
+            model_artifact_uri = model_artifact.model_artifact.uri
+        else:
+            raise MLRunInvalidArgumentError(
+                "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
+            )
+        root = self._extract_root_step()
+        if isinstance(root, RootFlowStep):
+            shared_model_name = (
+                shared_model_name
+                or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
+            )
+            if not root.shared_models or (
+                root.shared_models
+                and shared_model_name
+                and shared_model_name not in root.shared_models.keys()
+            ):
+                raise GraphError(
+                    f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
+                    f"model {shared_model_name} is not in the shared models."
+                )
+        if shared_model_name not in self._shared_proxy_mapping:
+            self._shared_proxy_mapping[shared_model_name] = {
+                endpoint_name: model_artifact.uri
+                if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                else model_artifact
+            }
+        else:
+            self._shared_proxy_mapping[shared_model_name].update(
+                {
+                    endpoint_name: model_artifact.uri
+                    if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                    else model_artifact
+                }
+            )
+        self.add_model(
+            endpoint_name=endpoint_name,
+            model_class=model_class,
+            execution_mechanism=ParallelExecutionMechanisms.shared_executor,
+            model_artifact=model_artifact,
+            labels=labels,
+            model_endpoint_creation_strategy=model_endpoint_creation_strategy,
+            override=override,
+            inputs=inputs,
+            outputs=outputs,
+            input_path=input_path,
+            result_path=result_path,
+            **model_params,
+        )
 
     def add_model(
         self,
         endpoint_name: str,
         model_class: Union[str, Model],
         execution_mechanism: Union[str, ParallelExecutionMechanisms],
-        model_artifact: Optional[
-            Union[str, mlrun.artifacts.ModelArtifact, mlrun.artifacts.LLMPromptArtifact]
-        ] = None,
+        model_artifact: Optional[Union[str, ModelArtifact, LLMPromptArtifact]] = None,
         labels: Optional[Union[list[str], dict[str, str]]] = None,
-        creation_strategy: Optional[
+        model_endpoint_creation_strategy: Optional[
            schemas.ModelEndpointCreationStrategy
         ] = schemas.ModelEndpointCreationStrategy.INPLACE,
         inputs: Optional[list[str]] = None,
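As a usage illustration of the method above, a `ModelRunnerStep` that already sits in a graph could attach a proxy endpoint to a previously registered shared model roughly as follows; `runner_step`, the project path, the artifact URI and the name `shared-llm` are placeholders, and the shared model itself is assumed to have been added on the root flow step (see the `RootFlowStep.add_shared_model` hunk further down):

# Hedged sketch: attach a proxy endpoint that reuses the shared model "shared-llm".
runner_step.add_shared_model_proxy(
    endpoint_name="support-bot",
    model_artifact="store://llm-prompts/my-project/support-prompt#0:latest",
    shared_model_name="shared-llm",  # may be omitted; then it is resolved via the artifact URI
)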
@@ -1285,7 +1534,7 @@ class ModelRunnerStep(MonitoredStep):
 
         :param model_artifact: model artifact or mlrun model artifact uri
         :param labels: model endpoint labels, should be list of str or mapping of str:str
-        :param creation_strategy: Strategy for creating or updating the model endpoint:
+        :param model_endpoint_creation_strategy: Strategy for creating or updating the model endpoint:
             * **overwrite**:
                 1. If model endpoints with the same name exist, delete the `latest` one.
                 2. Create a new model endpoint entry and set it as `latest`.
@@ -1310,7 +1559,6 @@ class ModelRunnerStep(MonitoredStep):
         :param override: bool allow override existing model on the current ModelRunnerStep.
         :param model_parameters: Parameters for model instantiation
         """
-
         if isinstance(model_class, Model) and model_parameters:
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
@@ -1319,10 +1567,20 @@ class ModelRunnerStep(MonitoredStep):
         model_parameters = model_parameters or (
             model_class.to_dict() if isinstance(model_class, Model) else {}
         )
-        if outputs is None and isinstance(
-            model_artifact, mlrun.artifacts.ModelArtifact
+
+        if isinstance(
+            model_artifact,
+            str,
         ):
-            outputs = [feature.name for feature in model_artifact.spec.outputs]
+            try:
+                model_artifact, _ = mlrun.store_manager.get_store_artifact(
+                    model_artifact
+                )
+            except mlrun.errors.MLRunNotFoundError:
+                raise mlrun.errors.MLRunInvalidArgumentError("Artifact not found.")
+
+        outputs = outputs or self._get_model_output_schema(model_artifact)
+
 
         model_artifact = (
             model_artifact.uri
@@ -1369,7 +1627,7 @@ class ModelRunnerStep(MonitoredStep):
             schemas.MonitoringData.OUTPUTS: outputs,
             schemas.MonitoringData.INPUT_PATH: input_path,
             schemas.MonitoringData.RESULT_PATH: result_path,
-            schemas.MonitoringData.CREATION_STRATEGY: creation_strategy,
+            schemas.MonitoringData.CREATION_STRATEGY: model_endpoint_creation_strategy,
             schemas.MonitoringData.LABELS: labels,
             schemas.MonitoringData.MODEL_PATH: model_artifact,
             schemas.MonitoringData.MODEL_CLASS: model_class,
@@ -1379,14 +1637,44 @@ class ModelRunnerStep(MonitoredStep):
 
     @staticmethod
     def _get_model_output_schema(
-        model: str, monitoring_data: dict[str, dict[str, str]]
+        model_artifact: Union[ModelArtifact, LLMPromptArtifact],
+    ) -> Optional[list[str]]:
+        if isinstance(
+            model_artifact,
+            ModelArtifact,
+        ):
+            return [feature.name for feature in model_artifact.spec.outputs]
+        elif isinstance(
+            model_artifact,
+            LLMPromptArtifact,
+        ):
+            _model_artifact = model_artifact.model_artifact
+            return [feature.name for feature in _model_artifact.spec.outputs]
+
+    @staticmethod
+    def _get_model_endpoint_output_schema(
+        name: str,
+        project: str,
+        uid: str,
     ) -> list[str]:
         output_schema = None
-        if monitoring_data[model].get(schemas.MonitoringData.MODEL_PATH) is not None:
-            artifact = get_store_resource(
-                monitoring_data[model].get(schemas.MonitoringData.MODEL_PATH)
+        try:
+            model_endpoint: mlrun.common.schemas.model_monitoring.ModelEndpoint = (
+                mlrun.db.get_run_db().get_model_endpoint(
+                    name=name,
+                    project=project,
+                    endpoint_id=uid,
+                    tsdb_metrics=False,
+                )
+            )
+            output_schema = model_endpoint.spec.label_names
+        except (
+            mlrun.errors.MLRunNotFoundError,
+            mlrun.errors.MLRunInvalidArgumentError,
+        ):
+            logger.warning(
+                f"Model endpoint not found, using default output schema for model {name}"
             )
-            output_schema = [feature.name for feature in artifact.spec.outputs]
         return output_schema
 
     @staticmethod
@@ -1407,8 +1695,14 @@ class ModelRunnerStep(MonitoredStep):
         if isinstance(monitoring_data, dict):
             for model in monitoring_data:
                 monitoring_data[model][schemas.MonitoringData.OUTPUTS] = (
-                    monitoring_data[model][schemas.MonitoringData.OUTPUTS]
-                    or self._get_model_output_schema(model, monitoring_data)
+                    monitoring_data.get(model, {}).get(schemas.MonitoringData.OUTPUTS)
+                    or self._get_model_endpoint_output_schema(
+                        name=model,
+                        project=self.context.project if self.context else None,
+                        uid=monitoring_data.get(model, {}).get(
+                            mlrun.common.schemas.MonitoringData.MODEL_ENDPOINT_UID
+                        ),
+                    )
                 )
                 # Prevent calling _get_model_output_schema for same model more than once
                 self.class_args[
@@ -1429,6 +1723,7 @@ class ModelRunnerStep(MonitoredStep):
         return monitoring_data
 
     def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
+        self.context = context
         if not self._is_local_function(context):
             # skip init of non local functions
             return
@@ -1450,6 +1745,7 @@ class ModelRunnerStep(MonitoredStep):
             model_selector=model_selector,
             runnables=model_objects,
             execution_mechanism_by_runnable_name=execution_mechanism_by_model_name,
+            shared_proxy_mapping=self._shared_proxy_mapping or None,
             name=self.name,
             context=context,
         )
@@ -1773,7 +2069,7 @@ class FlowStep(BaseStep):
         self._insert_all_error_handlers()
         self.check_and_process_graph()
 
-        for step in self._steps.values():
+        for step in self.steps.values():
             step.set_parent(self)
             step.init_object(context, namespace, mode, reset=reset)
         self._set_error_handler()
@@ -2136,6 +2432,11 @@ class RootFlowStep(FlowStep):
         "model_endpoints_names",
         "model_endpoints_routes_names",
         "track_models",
+        "shared_max_processes",
+        "shared_max_threads",
+        "shared_models",
+        "shared_models_mechanism",
+        "pool_factor",
     ]
 
     def __init__(
@@ -2156,6 +2457,157 @@ class RootFlowStep(FlowStep):
         self._models = set()
         self._route_models = set()
         self._track_models = False
+        self._shared_models: dict[str, tuple[str, dict]] = {}
+        self._shared_models_mechanism: dict[str, ParallelExecutionMechanisms] = {}
+        self._shared_max_processes = None
+        self._shared_max_threads = None
+        self._pool_factor = None
+
+    def add_shared_model(
+        self,
+        name: str,
+        model_class: Union[str, Model],
+        execution_mechanism: Union[str, ParallelExecutionMechanisms],
+        model_artifact: Optional[Union[str, ModelArtifact]],
+        override: bool = False,
+        **model_parameters,
+    ) -> None:
+        """
+        Add a shared model to the graph, this model will be available to all the ModelRunners in the graph
+        :param name: Name of the shared model (should be unique in the graph)
+        :param model_class: Model class name
+        :param execution_mechanism: Parallel execution mechanism to be used to execute this model. Must be one of:
+            * "process_pool" – To run in a separate process from a process pool. This is appropriate for CPU or GPU
+              intensive tasks as they would otherwise block the main process by holding Python's Global Interpreter
+              Lock (GIL).
+            * "dedicated_process" – To run in a separate dedicated process. This is appropriate for CPU or GPU intensive
+              tasks that also require significant Runnable-specific initialization (e.g. a large model).
+            * "thread_pool" – To run in a separate thread. This is appropriate for blocking I/O tasks, as they would
+              otherwise block the main event loop thread.
+            * "asyncio" – To run in an asyncio task. This is appropriate for I/O tasks that use asyncio, allowing the
+              event loop to continue running while waiting for a response.
+            * "shared_executor" – Reuses an external executor (typically managed by the flow or context) to execute the
+              runnable. Should be used only if you have multiply `ParallelExecution` in the same flow and especially
+              useful when:
+              - You want to share a heavy resource like a large model loaded onto a GPU.
+              - You want to centralize task scheduling or coordination for multiple lightweight tasks.
+              - You aim to minimize overhead from creating new executors or processes/threads per runnable.
+              The runnable is expected to be pre-initialized and reused across events, enabling efficient use of
+              memory and hardware accelerators.
+            * "naive" – To run in the main event loop. This is appropriate only for trivial computation and/or file I/O.
+              It means that the runnable will not actually be run in parallel to anything else.
+
+        :param model_artifact: model artifact or mlrun model artifact uri
+        :param override: bool allow override existing model on the current ModelRunnerStep.
+        :param model_parameters: Parameters for model instantiation
+        """
+        if isinstance(model_class, Model) and model_parameters:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
+            )
+
+        if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Cannot add a shared model with execution mechanism 'shared_executor'"
+            )
+        ParallelExecutionMechanisms.validate(execution_mechanism)
+
+        model_parameters = model_parameters or (
+            model_class.to_dict() if isinstance(model_class, Model) else {}
+        )
+        model_artifact = (
+            model_artifact.uri
+            if isinstance(model_artifact, mlrun.artifacts.Artifact)
+            else model_artifact
+        )
+        model_parameters["artifact_uri"] = model_parameters.get(
+            "artifact_uri", model_artifact
+        )
+
+        if model_parameters.get("name", name) != name or (
+            isinstance(model_class, Model) and model_class.name != name
+        ):
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Inconsistent name for the added model."
+            )
+        model_parameters["name"] = name
+
+        if name in self.shared_models and not override:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                f"Model with name {name} already exists in this graph."
+            )
+
+        model_class = (
+            model_class
+            if isinstance(model_class, str)
+            else model_class.__class__.__name__
+        )
+        self.shared_models[name] = (model_class, model_parameters)
+        self.shared_models_mechanism[name] = execution_mechanism
+
+    def get_shared_model_name_by_artifact_uri(self, artifact_uri: str) -> Optional[str]:
+        """
+        Get a shared model by its artifact URI.
+        :param artifact_uri: The artifact URI of the model.
+        :return: A tuple of (model_class, model_parameters) if found, otherwise None.
+        """
+        for model_name, (model_class, model_params) in self.shared_models.items():
+            if model_params.get("artifact_uri") == artifact_uri:
+                return model_name
+        return None
+
+    def config_pool_resource(
+        self,
+        max_processes: Optional[int] = None,
+        max_threads: Optional[int] = None,
+        pool_factor: Optional[int] = None,
+    ) -> None:
+        """
+        Configure the resource limits for the shared models in the graph.
+        :param max_processes: Maximum number of processes to spawn (excluding dedicated processes).
+                              Defaults to the number of CPUs or 16 if undetectable.
+        :param max_threads: Maximum number of threads to spawn. Defaults to 32.
+        :param pool_factor: Multiplier to scale the number of process/thread workers per runnable. Defaults to 1.
+        """
+        self.shared_max_processes = max_processes
+        self.shared_max_threads = max_threads
+        self.pool_factor = pool_factor
+
+    def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
+        self.context = context
+        if self.shared_models:
+            self.context.executor = storey.flow.RunnableExecutor(
+                max_processes=self.shared_max_processes,
+                max_threads=self.shared_max_threads,
+                pool_factor=self.pool_factor,
+            )
+            monitored_steps = self.get_monitored_steps().values()
+            for monitored_step in monitored_steps:
+                if isinstance(monitored_step, ModelRunnerStep):
+                    for model, model_params in self.shared_models.values():
+                        if "shared_proxy_mapping" in model_params:
+                            model_params["shared_proxy_mapping"].update(
+                                deepcopy(
+                                    monitored_step._shared_proxy_mapping.get(
+                                        model_params.get("name"), {}
+                                    )
+                                )
+                            )
+                        else:
+                            model_params["shared_proxy_mapping"] = deepcopy(
+                                monitored_step._shared_proxy_mapping.get(
+                                    model_params.get("name"), {}
+                                )
+                            )
+            for model, model_params in self.shared_models.values():
+                model = get_class(model, namespace).from_dict(
+                    model_params, init_with_params=True
+                )
+                model._raise_exception = False
+                self.context.executor.add_runnable(
+                    model, self._shared_models_mechanism[model.name]
+                )
+        super().init_object(context, namespace, mode, reset=reset, **extra_kwargs)
 
     @property
     def model_endpoints_names(self) -> list[str]:
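Taken together with `add_shared_model_proxy` above, the new `RootFlowStep` methods might be wired roughly like this; `graph` stands for the root flow step of a serving function, and the class path, URI and pool numbers are placeholders rather than values from this release:

# Register one heavy model on the root flow step so every ModelRunnerStep can share it.
graph.add_shared_model(
    name="shared-llm",
    model_class="MyLLModel",                  # e.g. an LLModel subclass defined in the function code
    execution_mechanism="dedicated_process",  # any mechanism except "shared_executor"
    model_artifact="store://models/my-project/my-llm#0:latest",
)
# Optionally bound the executor pools that back the shared models.
graph.config_pool_resource(max_processes=2, max_threads=8, pool_factor=1)
# Individual ModelRunnerStep instances then attach lightweight proxies to "shared-llm"
# via add_shared_model_proxy(), as sketched earlier.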
@@ -2184,6 +2636,48 @@ class RootFlowStep(FlowStep):
     def track_models(self, track_models: bool):
         self._track_models = track_models
 
+    @property
+    def shared_models(self) -> dict[str, tuple[str, dict]]:
+        return self._shared_models
+
+    @shared_models.setter
+    def shared_models(self, shared_models: dict[str, tuple[str, dict]]):
+        self._shared_models = shared_models
+
+    @property
+    def shared_models_mechanism(self) -> dict[str, ParallelExecutionMechanisms]:
+        return self._shared_models_mechanism
+
+    @shared_models_mechanism.setter
+    def shared_models_mechanism(
+        self, shared_models_mechanism: dict[str, ParallelExecutionMechanisms]
+    ):
+        self._shared_models_mechanism = shared_models_mechanism
+
+    @property
+    def shared_max_processes(self) -> Optional[int]:
+        return self._shared_max_processes
+
+    @shared_max_processes.setter
+    def shared_max_processes(self, max_processes: Optional[int]):
+        self._shared_max_processes = max_processes
+
+    @property
+    def shared_max_threads(self) -> Optional[int]:
+        return self._shared_max_threads
+
+    @shared_max_threads.setter
+    def shared_max_threads(self, max_threads: Optional[int]):
+        self._shared_max_threads = max_threads
+
+    @property
+    def pool_factor(self) -> Optional[int]:
+        return self._pool_factor
+
+    @pool_factor.setter
+    def pool_factor(self, pool_factor: Optional[int]):
+        self._pool_factor = pool_factor
+
     def update_model_endpoints_routes_names(self, model_endpoints_names: list):
         self._route_models.update(model_endpoints_names)
 