mlrun 1.10.0rc21__py3-none-any.whl → 1.10.0rc23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

mlrun/serving/states.py CHANGED
@@ -546,8 +546,8 @@ class BaseStep(ModelObj):
         # Update model endpoints names in the root step
         root.update_model_endpoints_names(step_model_endpoints_names)
 
-    @staticmethod
     def _verify_shared_models(
+        self,
         root: "RootFlowStep",
         step: "ModelRunnerStep",
         step_model_endpoints_names: list[str],
@@ -576,15 +576,17 @@ class BaseStep(ModelObj):
             prefix, _ = mlrun.datastore.parse_store_uri(model_artifact_uri)
             # if the model artifact is a prompt, we need to get the model URI
             # to ensure that the shared runnable name is correct
+            llm_artifact_uri = None
             if prefix == mlrun.utils.StorePrefix.LLMPrompt:
                 llm_artifact, _ = mlrun.store_manager.get_store_artifact(
                     model_artifact_uri
                 )
+                llm_artifact_uri = llm_artifact.uri
                 model_artifact_uri = mlrun.utils.remove_tag_from_artifact_uri(
                     llm_artifact.spec.parent_uri
                 )
-            actual_shared_name = root.get_shared_model_name_by_artifact_uri(
-                model_artifact_uri
+            actual_shared_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )
 
             if not shared_runnable_name:
@@ -596,15 +598,20 @@ class BaseStep(ModelObj):
                 step.class_args[schemas.ModelRunnerStepData.MODELS][name][
                     schemas.ModelsData.MODEL_PARAMETERS.value
                 ]["shared_runnable_name"] = actual_shared_name
-                shared_models.append(actual_shared_name)
             elif actual_shared_name != shared_runnable_name:
                 raise GraphError(
                     f"Model endpoint {name} shared runnable name mismatch: "
                     f"expected {actual_shared_name}, got {shared_runnable_name}"
                 )
-            else:
-                shared_models.append(actual_shared_name)
-
+            shared_models.append(actual_shared_name)
+            self._edit_proxy_model_data(
+                step,
+                name,
+                actual_shared_name,
+                shared_model_params,
+                shared_model_class,
+                llm_artifact_uri or model_artifact_uri,
+            )
         undefined_shared_models = list(
             set(shared_models) - set(root.shared_models.keys())
         )
@@ -613,6 +620,52 @@ class BaseStep(ModelObj):
                 f"The following shared models are not defined in the graph: {undefined_shared_models}."
             )
 
+    @staticmethod
+    def _edit_proxy_model_data(
+        step: "ModelRunnerStep",
+        name: str,
+        actual_shared_name: str,
+        shared_model_params: dict,
+        shared_model_class: Any,
+        artifact: Union[ModelArtifact, LLMPromptArtifact, str],
+    ):
+        monitoring_data = step.class_args.setdefault(
+            schemas.ModelRunnerStepData.MONITORING_DATA, {}
+        )
+
+        # edit monitoring data according to the shared model parameters
+        monitoring_data[name][schemas.MonitoringData.INPUT_PATH] = shared_model_params[
+            "input_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.RESULT_PATH] = shared_model_params[
+            "result_path"
+        ]
+        monitoring_data[name][schemas.MonitoringData.INPUTS] = shared_model_params[
+            "inputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.OUTPUTS] = shared_model_params[
+            "outputs"
+        ]
+        monitoring_data[name][schemas.MonitoringData.MODEL_CLASS] = (
+            shared_model_class
+            if isinstance(shared_model_class, str)
+            else shared_model_class.__class__.__name__
+        )
+        if actual_shared_name and actual_shared_name not in step._shared_proxy_mapping:
+            step._shared_proxy_mapping[actual_shared_name] = {
+                name: artifact.uri
+                if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
+                else artifact
+            }
+        elif actual_shared_name:
+            step._shared_proxy_mapping[actual_shared_name].update(
+                {
+                    name: artifact.uri
+                    if isinstance(artifact, (ModelArtifact, LLMPromptArtifact))
+                    else artifact
+                }
+            )
+
 
 class TaskStep(BaseStep):
     """task execution step, runs a class or handler"""
@@ -1116,6 +1169,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         self.invocation_artifact: Optional[LLMPromptArtifact] = None
         self.model_artifact: Optional[ModelArtifact] = None
         self.model_provider: Optional[ModelProvider] = None
+        self._artifact_were_loaded = False
 
     def __init_subclass__(cls):
         super().__init_subclass__()
@@ -1136,12 +1190,14 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         )
 
     def _load_artifacts(self) -> None:
-        artifact = self._get_artifact_object()
-        if isinstance(artifact, LLMPromptArtifact):
-            self.invocation_artifact = artifact
-            self.model_artifact = self.invocation_artifact.model_artifact
-        else:
-            self.model_artifact = artifact
+        if not self._artifact_were_loaded:
+            artifact = self._get_artifact_object()
+            if isinstance(artifact, LLMPromptArtifact):
+                self.invocation_artifact = artifact
+                self.model_artifact = self.invocation_artifact.model_artifact
+            else:
+                self.model_artifact = artifact
+            self._artifact_were_loaded = True
 
     def _get_artifact_object(
         self, proxy_uri: Optional[str] = None
@@ -1210,6 +1266,57 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
 
 
 class LLModel(Model):
+    """
+    A model wrapper for handling LLM (Large Language Model) prompt-based inference.
+
+    This class extends the base `Model` to provide specialized handling for
+    `LLMPromptArtifact` objects, enabling both synchronous and asynchronous
+    invocation of language models.
+
+    **Model Invocation**:
+
+    - The execution of enriched prompts is delegated to the `model_provider`
+      configured for the model (e.g., **Hugging Face** or **OpenAI**).
+    - The `model_provider` is responsible for sending the prompt to the correct
+      backend API and returning the generated output.
+    - Users can override the `predict` and `predict_async` methods to customize
+      the behavior of the model invocation.
+
+    **Prompt Enrichment Overview**:
+
+    - If an `LLMPromptArtifact` is found, load its prompt template and fill in
+      placeholders using values from the request body.
+    - If the artifact is not an `LLMPromptArtifact`, skip formatting and attempt
+      to retrieve `messages` directly from the request body using the input path.
+
+    **Simplified Example**:
+
+    Input body::
+
+        {"city": "Paris", "days": 3}
+
+    Prompt template in artifact::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a {{days}}-day itinerary for {{city}}."},
+        ]
+
+    Result after enrichment::
+
+        [
+            {"role": "system", "content": "You are a travel planning assistant."},
+            {"role": "user", "content": "Create a 3-day itinerary for Paris."},
+        ]
+
+    :param name:        Name of the model.
+    :param input_path:  Path in the request body where input data is located.
+    :param result_path: Path in the response body where model outputs and the statistics
+                        will be stored.
+    """
+
+    _dict_fields = Model._dict_fields + ["result_path", "input_path"]
+
     def __init__(
         self,
         name: str,
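
Note: the enrichment described in the docstring above amounts to filling template placeholders from the request body; a minimal standalone sketch of that idea (not the mlrun implementation; it uses single-brace `str.format` placeholders and invented values):

    from copy import deepcopy

    prompt_template = [
        {"role": "system", "content": "You are a travel planning assistant."},
        {"role": "user", "content": "Create a {days}-day itinerary for {city}."},
    ]
    body = {"city": "Paris", "days": 3}

    messages = deepcopy(prompt_template)
    for message in messages:
        try:
            message["content"] = message["content"].format(**body)
        except KeyError:
            # mirror the real behavior: leave unknown placeholders unformatted
            pass

    # messages[1]["content"] == "Create a 3-day itinerary for Paris."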
@@ -1220,6 +1327,12 @@ class LLModel(Model):
         super().__init__(name, **kwargs)
         self._input_path = split_path(input_path)
         self._result_path = split_path(result_path)
+        logger.info(
+            "LLModel initialized",
+            model_name=name,
+            input_path=input_path,
+            result_path=result_path,
+        )
 
     def predict(
         self,
@@ -1228,9 +1341,16 @@
         model_configuration: Optional[dict] = None,
         **kwargs,
     ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
         if isinstance(
-            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
         ) and isinstance(self.model_provider, ModelProvider):
+            logger.debug(
+                "Invoking model provider",
+                model_name=self.name,
+                messages=messages,
+                model_configuration=model_configuration,
+            )
             response_with_stats = self.model_provider.invoke(
                 messages=messages,
                 invoke_response_format=InvokeResponseFormat.USAGE,
@@ -1239,6 +1359,19 @@
             set_data_by_path(
                 path=self._result_path, data=body, value=response_with_stats
             )
+            logger.debug(
+                "LLModel prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
+            )
         return body
 
     async def predict_async(
@@ -1248,9 +1381,16 @@
         model_configuration: Optional[dict] = None,
         **kwargs,
     ) -> Any:
+        llm_prompt_artifact = kwargs.get("llm_prompt_artifact")
         if isinstance(
-            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+            llm_prompt_artifact, mlrun.artifacts.LLMPromptArtifact
         ) and isinstance(self.model_provider, ModelProvider):
+            logger.debug(
+                "Async invoking model provider",
+                model_name=self.name,
+                messages=messages,
+                model_configuration=model_configuration,
+            )
             response_with_stats = await self.model_provider.async_invoke(
                 messages=messages,
                 invoke_response_format=InvokeResponseFormat.USAGE,
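
Note: the docstring states that `predict`/`predict_async` can be overridden; a hedged sketch of such an override (the subclass and its post-processing are invented, and `llm_prompt_artifact` is simply forwarded through `**kwargs` as the new code expects):

    from mlrun.serving.states import LLModel

    class TaggedLLModel(LLModel):
        """Hypothetical subclass: reuse the provider invocation, then tag the response."""

        def predict(self, body, messages=None, model_configuration=None, **kwargs):
            body = super().predict(
                body,
                messages=messages,
                model_configuration=model_configuration,
                **kwargs,  # keeps the llm_prompt_artifact kwarg passed by run()
            )
            # invented post-processing step
            body.setdefault("metadata", {})["generated_by"] = self.name
            return body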
@@ -1259,46 +1399,86 @@
             set_data_by_path(
                 path=self._result_path, data=body, value=response_with_stats
             )
+            logger.debug(
+                "LLModel async prediction completed",
+                model_name=self.name,
+                answer=response_with_stats.get("answer"),
+                usage=response_with_stats.get("usage"),
+            )
+        else:
+            logger.warning(
+                "LLModel invocation artifact or model provider not set, skipping async prediction",
+                model_name=self.name,
+                invocation_artifact_type=type(llm_prompt_artifact).__name__,
+                model_provider_type=type(self.model_provider).__name__,
+            )
         return body
 
     def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
-        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, model_configuration = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
         return self.predict(
-            body, messages=messages, model_configuration=model_configuration
+            body,
+            messages=messages,
+            model_configuration=model_configuration,
+            llm_prompt_artifact=llm_prompt_artifact,
         )
 
     async def run_async(
         self, body: Any, path: str, origin_name: Optional[str] = None
     ) -> Any:
-        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        llm_prompt_artifact = self._get_invocation_artifact(origin_name)
+        messages, model_configuration = self.enrich_prompt(
+            body, origin_name, llm_prompt_artifact
+        )
+        logger.info(
+            "Calling LLModel async predict",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+            messages_len=len(messages) if messages else 0,
+        )
         return await self.predict_async(
-            body, messages=messages, model_configuration=model_configuration
+            body,
+            messages=messages,
+            model_configuration=model_configuration,
+            llm_prompt_artifact=llm_prompt_artifact,
        )
 
     def enrich_prompt(
-        self, body: dict, origin_name: str
+        self,
+        body: dict,
+        origin_name: str,
+        llm_prompt_artifact: Optional[LLMPromptArtifact] = None,
     ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
-        if origin_name and self.shared_proxy_mapping:
-            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
-            if isinstance(llm_prompt_artifact, str):
-                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
-                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
-        else:
-            llm_prompt_artifact = (
-                self.invocation_artifact or self._get_artifact_object()
-            )
-        if not (
+        logger.info(
+            "Enriching prompt",
+            model_name=self.name,
+            model_endpoint_name=origin_name,
+        )
+        if not llm_prompt_artifact or not (
             llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
         ):
             logger.warning(
-                "LLMModel must be provided with LLMPromptArtifact",
+                "LLModel must be provided with LLMPromptArtifact",
+                model_name=self.name,
+                artifact_type=type(llm_prompt_artifact).__name__,
                 llm_prompt_artifact=llm_prompt_artifact,
            )
-            return None, None
-        prompt_legend = llm_prompt_artifact.spec.prompt_legend
-        prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+            prompt_legend, prompt_template, model_configuration = {}, [], {}
+        else:
+            prompt_legend = llm_prompt_artifact.spec.prompt_legend
+            prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+            model_configuration = llm_prompt_artifact.spec.model_configuration
         input_data = copy(get_data_from_path(self._input_path, body))
-        if isinstance(input_data, dict):
+        if isinstance(input_data, dict) and prompt_template:
             kwargs = (
                 {
                     place_holder: input_data.get(body_map["field"])
@@ -1315,23 +1495,61 @@
                     message["content"] = message["content"].format(**input_data)
                 except KeyError as e:
                     logger.warning(
-                        "Input data was missing a placeholder, placeholder stay unformatted",
-                        key_error=e,
+                        "Input data missing placeholder, content stays unformatted",
+                        model_name=self.name,
+                        key_error=mlrun.errors.err_to_str(e),
                     )
                     message["content"] = message["content"].format_map(
                         default_place_holders
                     )
+        elif isinstance(input_data, dict) and not prompt_template:
+            # If there is no prompt template, we assume the input data is already in the correct format.
+            logger.debug("Attempting to retrieve messages from the request body.")
+            prompt_template = input_data.get("messages", [])
         else:
             logger.warning(
-                f"Expected input data to be a dict, but received input data from type {type(input_data)} prompt "
-                f"template stay unformatted",
+                "Expected input data to be a dict, prompt template stays unformatted",
+                model_name=self.name,
+                input_data_type=type(input_data).__name__,
             )
-        return prompt_template, llm_prompt_artifact.spec.model_configuration
+        return prompt_template, model_configuration
+
+    def _get_invocation_artifact(
+        self, origin_name: Optional[str] = None
+    ) -> Union[LLMPromptArtifact, None]:
+        """
+        Get the LLMPromptArtifact object for this model.
+
+        :param proxy_uri: Optional; URI to the proxy artifact.
+        :return: LLMPromptArtifact object or None if not found.
+        """
+        if origin_name and self.shared_proxy_mapping:
+            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+            if isinstance(llm_prompt_artifact, str):
+                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+        elif self._artifact_were_loaded:
+            llm_prompt_artifact = self.invocation_artifact
+        else:
+            self._load_artifacts()
+            llm_prompt_artifact = self.invocation_artifact
+        return llm_prompt_artifact
 
 
-class ModelSelector:
+class ModelSelector(ModelObj):
     """Used to select which models to run on each event."""
 
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+        cls._dict_fields = list(
+            set(cls._dict_fields)
+            | set(inspect.signature(cls.__init__).parameters.keys())
+        )
+        cls._dict_fields.remove("self")
+
     def select(
         self, event, available_models: list[Model]
     ) -> Union[list[str], list[Model]]:
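
Note: because `ModelSelector.__init_subclass__` now folds the subclass `__init__` parameters into `_dict_fields`, a selector's constructor arguments survive serialization and can be rebuilt on the worker; a minimal sketch with an invented subclass:

    from mlrun.serving.states import ModelSelector

    class TopKSelector(ModelSelector):
        """Hypothetical selector: return the first top_k of the available models."""

        def __init__(self, top_k: int = 1, **kwargs):
            super().__init__(**kwargs)
            self.top_k = top_k

        def select(self, event, available_models):
            return available_models[: self.top_k]

    # "top_k" was added to _dict_fields from the __init__ signature, so it is part
    # of the serialized selector state.
    assert "top_k" in TopKSelector._dict_fields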
@@ -1442,15 +1660,33 @@ class ModelRunnerStep(MonitoredStep):
         *args,
         name: Optional[str] = None,
         model_selector: Optional[Union[str, ModelSelector]] = None,
+        model_selector_parameters: Optional[dict] = None,
         raise_exception: bool = True,
         **kwargs,
     ):
+        if isinstance(model_selector, ModelSelector) and model_selector_parameters:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Cannot provide a model_selector object as argument to `model_selector` and also provide "
+                "`model_selector_parameters`."
+            )
+        if model_selector:
+            model_selector_parameters = model_selector_parameters or (
+                model_selector.to_dict()
+                if isinstance(model_selector, ModelSelector)
+                else {}
+            )
+            model_selector = (
+                model_selector
+                if isinstance(model_selector, str)
+                else model_selector.__class__.__name__
+            )
+
         super().__init__(
             *args,
             name=name,
             raise_exception=raise_exception,
             class_name="mlrun.serving.ModelRunner",
-            class_args=dict(model_selector=model_selector),
+            class_args=dict(model_selector=(model_selector, model_selector_parameters)),
             **kwargs,
         )
         self.raise_exception = raise_exception
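
Note: the step now stores the selector as a (class name, parameters) pair instead of a live object; a hedged sketch of the two accepted call shapes, reusing the hypothetical `TopKSelector` from the previous note:

    from mlrun.serving.states import ModelRunnerStep

    # Option 1: pass an instance; its parameters are captured via to_dict().
    runner = ModelRunnerStep(name="runner", model_selector=TopKSelector(top_k=2))

    # Option 2: pass the class name plus explicit parameters.
    runner = ModelRunnerStep(
        name="runner",
        model_selector="TopKSelector",
        model_selector_parameters={"top_k": 2},
    )

    # Either way class_args["model_selector"] holds a (class_name, parameters) tuple;
    # passing an instance together with model_selector_parameters raises
    # MLRunInvalidArgumentError.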
@@ -1466,10 +1702,6 @@
         model_endpoint_creation_strategy: Optional[
             schemas.ModelEndpointCreationStrategy
         ] = schemas.ModelEndpointCreationStrategy.INPLACE,
-        inputs: Optional[list[str]] = None,
-        outputs: Optional[list[str]] = None,
-        input_path: Optional[str] = None,
-        result_path: Optional[str] = None,
         override: bool = False,
     ) -> None:
         """
@@ -1492,17 +1724,6 @@
            1. If model endpoints with the same name exist, preserve them.
            2. Create a new model endpoint with the same name and set it to `latest`.
 
-        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
-                       that been configured in the model artifact, please note that those inputs need to
-                       be equal in length and order to the inputs that model_class predict method expects
-        :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
-                        that been configured in the model artifact, please note that those outputs need to
-                        be equal to the model_class predict method outputs (length, and order)
-        :param input_path: input path inside the user event, expect scopes to be defined by dot notation
-                           (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
-        :param result_path: result path inside the user output event, expect scopes to be defined by dot
-                            notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
-                            in path.
         :param override: bool allow override existing model on the current ModelRunnerStep.
         """
         model_class, model_params = (
@@ -1520,11 +1741,21 @@
                 "model_artifact must be a string, ModelArtifact or LLMPromptArtifact"
             )
         root = self._extract_root_step()
+        shared_model_params = {}
         if isinstance(root, RootFlowStep):
-            shared_model_name = (
-                shared_model_name
-                or root.get_shared_model_name_by_artifact_uri(model_artifact_uri)
+            actual_shared_model_name, shared_model_class, shared_model_params = (
+                root.get_shared_model_by_artifact_uri(model_artifact_uri)
             )
+            if not actual_shared_model_name or (
+                shared_model_name and actual_shared_model_name != shared_model_name
+            ):
+                raise GraphError(
+                    f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
+                    f"model {shared_model_name} is not in the shared models."
+                )
+            elif not shared_model_name:
+                shared_model_name = actual_shared_model_name
+            model_params["shared_runnable_name"] = shared_model_name
             if not root.shared_models or (
                 root.shared_models
                 and shared_model_name
@@ -1534,13 +1765,27 @@
                     f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
                     f"model {shared_model_name} is not in the shared models."
                 )
-            if shared_model_name not in self._shared_proxy_mapping:
+            monitoring_data = self.class_args.get(
+                schemas.ModelRunnerStepData.MONITORING_DATA, {}
+            )
+            monitoring_data.setdefault(endpoint_name, {})[
+                schemas.MonitoringData.MODEL_CLASS
+            ] = (
+                shared_model_class
+                if isinstance(shared_model_class, str)
+                else shared_model_class.__class__.__name__
+            )
+            self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = (
+                monitoring_data
+            )
+
+            if shared_model_name and shared_model_name not in self._shared_proxy_mapping:
                 self._shared_proxy_mapping[shared_model_name] = {
                     endpoint_name: model_artifact.uri
                     if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
                    else model_artifact
                 }
-            else:
+            elif override and shared_model_name:
                 self._shared_proxy_mapping[shared_model_name].update(
                     {
                         endpoint_name: model_artifact.uri
@@ -1555,11 +1800,11 @@
             model_artifact=model_artifact,
             labels=labels,
             model_endpoint_creation_strategy=model_endpoint_creation_strategy,
+            inputs=shared_model_params.get("inputs"),
+            outputs=shared_model_params.get("outputs"),
+            input_path=shared_model_params.get("input_path"),
+            result_path=shared_model_params.get("result_path"),
             override=override,
-            inputs=inputs,
-            outputs=outputs,
-            input_path=input_path,
-            result_path=result_path,
             **model_params,
         )
 
@@ -1827,13 +2072,17 @@ class ModelRunnerStep(MonitoredStep):
         if not self._is_local_function(context):
             # skip init of non local functions
             return
-        model_selector = self.class_args.get("model_selector")
+        model_selector, model_selector_params = self.class_args.get(
+            "model_selector", (None, None)
+        )
         execution_mechanism_by_model_name = self.class_args.get(
             schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
         )
         models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
-        if isinstance(model_selector, str):
-            model_selector = get_class(model_selector, namespace)()
+        if model_selector:
+            model_selector = get_class(model_selector, namespace).from_dict(
+                model_selector_params, init_with_params=True
+            )
         model_objects = []
         for model, model_params in models.values():
             model_params[schemas.MonitoringData.INPUT_PATH] = (
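
Note: a hedged sketch of the reconstruction that `init_object` now performs from the stored tuple; in the real code the class is resolved with `get_class(model_selector, namespace)`, while here the hypothetical `TopKSelector` from the earlier note stands in for it:

    model_selector_name, model_selector_params = ("TopKSelector", {"top_k": 2})
    selector = TopKSelector.from_dict(model_selector_params, init_with_params=True)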
@@ -2589,6 +2838,10 @@ class RootFlowStep(FlowStep):
         model_class: Union[str, Model],
         execution_mechanism: Union[str, ParallelExecutionMechanisms],
         model_artifact: Union[str, ModelArtifact],
+        inputs: Optional[list[str]] = None,
+        outputs: Optional[list[str]] = None,
+        input_path: Optional[str] = None,
+        result_path: Optional[str] = None,
         override: bool = False,
         **model_parameters,
     ) -> None:
@@ -2618,6 +2871,19 @@
            It means that the runnable will not actually be run in parallel to anything else.
 
         :param model_artifact: model artifact or mlrun model artifact uri
+        :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs
+                       that been configured in the model artifact, please note that those inputs need
+                       to be equal in length and order to the inputs that model_class
+                       predict method expects
+        :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs
+                        that been configured in the model artifact, please note that those outputs need
+                        to be equal to the model_class
+                        predict method outputs (length, and order)
+        :param input_path: input path inside the user event, expect scopes to be defined by dot notation
+                           (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
+        :param result_path: result path inside the user output event, expect scopes to be defined by dot
+                            notation (e.g "outputs.my_model_outputs") expects list or dictionary type object
+                            in path.
         :param override: bool allow override existing model on the current ModelRunnerStep.
         :param model_parameters: Parameters for model instantiation
         """
@@ -2625,6 +2891,14 @@ class RootFlowStep(FlowStep):
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Cannot provide a model object as argument to `model_class` and also provide `model_parameters`."
             )
+        if type(model_class) is LLModel or (
+            isinstance(model_class, str) and model_class == LLModel.__name__
+        ):
+            if outputs:
+                warnings.warn(
+                    "LLModel with existing outputs detected, overriding to default"
+                )
+            outputs = UsageResponseKeys.fields()
 
         if execution_mechanism == ParallelExecutionMechanisms.shared_executor:
             raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2652,6 +2926,14 @@
                 "Inconsistent name for the added model."
             )
         model_parameters["name"] = name
+        model_parameters["inputs"] = inputs or model_parameters.get("inputs", [])
+        model_parameters["outputs"] = outputs or model_parameters.get("outputs", [])
+        model_parameters["input_path"] = input_path or model_parameters.get(
+            "input_path"
+        )
+        model_parameters["result_path"] = result_path or model_parameters.get(
+            "result_path"
+        )
 
         if name in self.shared_models and not override:
             raise mlrun.errors.MLRunInvalidArgumentError(
@@ -2666,7 +2948,9 @@
         self.shared_models[name] = (model_class, model_parameters)
         self.shared_models_mechanism[name] = execution_mechanism
 
-    def get_shared_model_name_by_artifact_uri(self, artifact_uri: str) -> Optional[str]:
+    def get_shared_model_by_artifact_uri(
+        self, artifact_uri: str
+    ) -> Optional[tuple[str, str, dict]]:
         """
         Get a shared model by its artifact URI.
         :param artifact_uri: The artifact URI of the model.
@@ -2674,7 +2958,7 @@
         """
         for model_name, (model_class, model_params) in self.shared_models.items():
             if model_params.get("artifact_uri") == artifact_uri:
-                return model_name
+                return model_name, model_class, model_params
         return None
 
     def config_pool_resource(
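
Note: callers that previously got back just a name now unpack a (name, class, parameters) triple, or `None` when the URI was never registered; continuing the hypothetical graph from the previous note:

    match = graph.get_shared_model_by_artifact_uri(
        "store://llm-prompts/my-project/travel-prompt"
    )
    if match:
        shared_name, shared_class, shared_params = match
        # shared_params carries the inputs/outputs/input_path/result_path from add_model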
@@ -2844,12 +3128,10 @@ def _add_graphviz_router(graph, step, source=None, **kwargs):
         graph.edge(step.fullname, route.fullname)
 
 
-def _add_graphviz_model_runner(graph, step, source=None):
+def _add_graphviz_model_runner(graph, step, source=None, is_monitored=False):
     if source:
         graph.node("_start", source.name, shape=source.shape, style="filled")
         graph.edge("_start", step.fullname)
-
-    is_monitored = step._extract_root_step().track_models
     m_cell = '<FONT POINT-SIZE="9">🄼</FONT>' if is_monitored else ""
 
     number_of_models = len(
@@ -2888,6 +3170,7 @@ def _add_graphviz_flow(
             allow_empty=True
         )
         graph.node("_start", source.name, shape=source.shape, style="filled")
+    is_monitored = step.track_models if isinstance(step, RootFlowStep) else False
     for start_step in start_steps:
         graph.edge("_start", start_step.fullname)
     for child in step.get_children():
@@ -2896,7 +3179,7 @@
             with graph.subgraph(name="cluster_" + child.fullname) as sg:
                 _add_graphviz_router(sg, child)
         elif kind == StepKinds.model_runner:
-            _add_graphviz_model_runner(graph, child)
+            _add_graphviz_model_runner(graph, child, is_monitored=is_monitored)
         else:
             graph.node(child.fullname, label=child.name, shape=child.get_shape())
             _add_edges(child.after or [], step, graph, child)