mlrun 1.10.0rc13__py3-none-any.whl → 1.10.0rc15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (47)
  1. mlrun/artifacts/base.py +0 -31
  2. mlrun/artifacts/llm_prompt.py +106 -20
  3. mlrun/artifacts/manager.py +0 -5
  4. mlrun/common/constants.py +0 -1
  5. mlrun/common/schemas/__init__.py +1 -0
  6. mlrun/common/schemas/model_monitoring/__init__.py +1 -0
  7. mlrun/common/schemas/model_monitoring/functions.py +1 -1
  8. mlrun/common/schemas/model_monitoring/model_endpoints.py +10 -0
  9. mlrun/common/schemas/workflow.py +0 -1
  10. mlrun/config.py +1 -1
  11. mlrun/datastore/model_provider/model_provider.py +42 -14
  12. mlrun/datastore/model_provider/openai_provider.py +96 -15
  13. mlrun/db/base.py +14 -0
  14. mlrun/db/httpdb.py +42 -9
  15. mlrun/db/nopdb.py +8 -0
  16. mlrun/execution.py +16 -7
  17. mlrun/model.py +15 -0
  18. mlrun/model_monitoring/__init__.py +1 -0
  19. mlrun/model_monitoring/applications/base.py +176 -20
  20. mlrun/model_monitoring/db/_schedules.py +84 -24
  21. mlrun/model_monitoring/db/tsdb/base.py +72 -1
  22. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +7 -1
  23. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +37 -0
  24. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +25 -0
  25. mlrun/model_monitoring/helpers.py +26 -4
  26. mlrun/projects/project.py +38 -12
  27. mlrun/runtimes/daskjob.py +6 -0
  28. mlrun/runtimes/mpijob/abstract.py +6 -0
  29. mlrun/runtimes/mpijob/v1.py +6 -0
  30. mlrun/runtimes/nuclio/application/application.py +2 -0
  31. mlrun/runtimes/nuclio/function.py +6 -0
  32. mlrun/runtimes/nuclio/serving.py +12 -11
  33. mlrun/runtimes/pod.py +21 -0
  34. mlrun/runtimes/remotesparkjob.py +6 -0
  35. mlrun/runtimes/sparkjob/spark3job.py +6 -0
  36. mlrun/serving/__init__.py +2 -0
  37. mlrun/serving/server.py +95 -26
  38. mlrun/serving/states.py +130 -10
  39. mlrun/utils/helpers.py +36 -12
  40. mlrun/utils/retryer.py +15 -2
  41. mlrun/utils/version/version.json +2 -2
  42. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/METADATA +3 -8
  43. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/RECORD +47 -47
  44. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/WHEEL +0 -0
  45. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/entry_points.txt +0 -0
  46. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/licenses/LICENSE +0 -0
  47. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc15.dist-info}/top_level.txt +0 -0
@@ -720,6 +720,7 @@ class ServingRuntime(RemoteRuntime):
             "track_models": self.spec.track_models,
             "default_content_type": self.spec.default_content_type,
             "model_endpoint_creation_task_name": self.spec.model_endpoint_creation_task_name,
+            # TODO: find another way to pass this (needed for local run)
             "filename": getattr(self.spec, "filename", None),
         }

@@ -788,17 +789,13 @@ class ServingRuntime(RemoteRuntime):
             monitoring_mock=self.spec.track_models,
         )

-        if (
-            isinstance(self.spec.graph, RootFlowStep)
-            and self.spec.graph.include_monitored_step()
-        ):
-            server.graph = add_system_steps_to_graph(
-                server.project,
-                server.graph,
-                self.spec.track_models,
-                server.context,
-                self.spec,
-            )
+        server.graph = add_system_steps_to_graph(
+            server.project,
+            server.graph,
+            self.spec.track_models,
+            server.context,
+            self.spec,
+        )

         if workdir:
             os.chdir(old_workdir)
@@ -858,6 +855,7 @@ class ServingRuntime(RemoteRuntime):
             description=self.spec.description,
             workdir=self.spec.workdir,
             image_pull_secret=self.spec.image_pull_secret,
+            build=self.spec.build,
             node_name=self.spec.node_name,
             node_selector=self.spec.node_selector,
             affinity=self.spec.affinity,
@@ -868,6 +866,9 @@ class ServingRuntime(RemoteRuntime):
             security_context=self.spec.security_context,
             state_thresholds=self.spec.state_thresholds,
             serving_spec=self._get_serving_spec(),
+            track_models=self.spec.track_models,
+            parameters=self.spec.parameters,
+            graph=self.spec.graph,
         )
         job = KubejobRuntime(
             spec=spec,
mlrun/runtimes/pod.py CHANGED
@@ -104,6 +104,9 @@ class KubeResourceSpec(FunctionSpec):
         "security_context",
         "state_thresholds",
         "serving_spec",
+        "track_models",
+        "parameters",
+        "graph",
     ]
     _default_fields_to_strip = FunctionSpec._default_fields_to_strip + [
         "volumes",
@@ -180,6 +183,9 @@ class KubeResourceSpec(FunctionSpec):
         security_context=None,
         state_thresholds=None,
         serving_spec=None,
+        track_models=None,
+        parameters=None,
+        graph=None,
     ):
         super().__init__(
             command=command,
@@ -226,6 +232,10 @@ class KubeResourceSpec(FunctionSpec):
             or mlrun.mlconf.function.spec.state_thresholds.default.to_dict()
         )
         self.serving_spec = serving_spec
+        self.track_models = track_models
+        self.parameters = parameters
+        self._graph = None
+        self.graph = graph
         # Termination grace period is internal for runtimes that have a pod termination hook hence it is not in the
         # _dict_fields and doesn't have a setter.
         self._termination_grace_period_seconds = None
@@ -303,6 +313,17 @@ class KubeResourceSpec(FunctionSpec):
     def termination_grace_period_seconds(self) -> typing.Optional[int]:
         return self._termination_grace_period_seconds

+    @property
+    def graph(self):
+        """states graph, holding the serving workflow/DAG topology"""
+        return self._graph
+
+    @graph.setter
+    def graph(self, graph):
+        from ..serving.states import graph_root_setter
+
+        graph_root_setter(self, graph)
+
     def _serialize_field(
         self, struct: dict, field_name: typing.Optional[str] = None, strip: bool = False
     ) -> typing.Any:
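As a reading aid, here is a minimal sketch of how the new attributes on KubeResourceSpec might be exercised. It assumes the graph setter behaves like the serving-spec setter it delegates to (graph_root_setter); the step dict and parameter values are invented:

```python
from mlrun.runtimes.pod import KubeResourceSpec

spec = KubeResourceSpec()

# The new graph property delegates to mlrun.serving.states.graph_root_setter,
# so (assuming the same behavior as serving specs) assigning a step dict should
# build the corresponding root step object on the spec.
spec.graph = {"kind": "router"}

# track_models and parameters are stored as plain attributes (per the hunk above).
spec.track_models = True
spec.parameters = {"batch_size": 32}

print(type(spec.graph).__name__)  # expected: a router/root step class
```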
mlrun/runtimes/remotesparkjob.py CHANGED
@@ -59,6 +59,9 @@ class RemoteSparkSpec(KubeResourceSpec):
         security_context=None,
         state_thresholds=None,
         serving_spec=None,
+        graph=None,
+        parameters=None,
+        track_models=None,
     ):
         super().__init__(
             command=command,
@@ -89,6 +92,9 @@ class RemoteSparkSpec(KubeResourceSpec):
             security_context=security_context,
             state_thresholds=state_thresholds,
             serving_spec=serving_spec,
+            graph=graph,
+            parameters=parameters,
+            track_models=track_models,
         )
         self.provider = provider

mlrun/runtimes/sparkjob/spark3job.py CHANGED
@@ -169,6 +169,9 @@ class Spark3JobSpec(KubeResourceSpec):
         security_context=None,
         state_thresholds=None,
         serving_spec=None,
+        graph=None,
+        parameters=None,
+        track_models=None,
     ):
         super().__init__(
             command=command,
@@ -199,6 +202,9 @@ class Spark3JobSpec(KubeResourceSpec):
             security_context=security_context,
             state_thresholds=state_thresholds,
             serving_spec=serving_spec,
+            graph=graph,
+            parameters=parameters,
+            track_models=track_models,
         )

         self.driver_resources = driver_resources or {}
mlrun/serving/__init__.py CHANGED
@@ -28,6 +28,7 @@ __all__ = [
     "Model",
     "ModelSelector",
     "MonitoredStep",
+    "LLModel",
 ]

 from .routers import ModelRouter, VotingEnsemble  # noqa
@@ -47,6 +48,7 @@ from .states import (
     Model,
     ModelSelector,
     MonitoredStep,
+    LLModel,
 )  # noqa
 from .v1_serving import MLModelServer, new_v1_model_server  # noqa
 from .v2_serving import V2ModelServer  # noqa
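With LLModel now exported from mlrun.serving, it can be imported and subclassed directly. A hedged sketch (the subclass is illustrative; LLModel itself is defined in states.py, shown further below):

```python
from mlrun.serving import LLModel


class MyLLModel(LLModel):
    # LLModel.run/run_async call enrich_prompt to build the messages from an
    # LLMPromptArtifact and then invoke the configured model provider; override
    # predict/predict_async only if custom invocation logic is needed.
    def predict(self, body, messages, model_configuration):
        return super().predict(body, messages, model_configuration)
```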
mlrun/serving/server.py CHANGED
@@ -15,6 +15,7 @@
 __all__ = ["GraphServer", "create_graph_server", "GraphContext", "MockEvent"]

 import asyncio
+import base64
 import copy
 import json
 import os
@@ -384,6 +385,7 @@ def add_monitoring_general_steps(
     graph: RootFlowStep,
     context,
     serving_spec,
+    pause_until_background_task_completion: bool,
 ) -> tuple[RootFlowStep, FlowStep]:
     """
     Adding the monitoring flow connection steps, this steps allow the graph to reconstruct the serving event enrich it
@@ -392,18 +394,22 @@ def add_monitoring_general_steps(
     "background_task_status_step" --> "filter_none" --> "monitoring_pre_processor_step" --> "flatten_events"
     --> "sampling_step" --> "filter_none_sampling" --> "model_monitoring_stream"
     """
+    background_task_status_step = None
+    if pause_until_background_task_completion:
+        background_task_status_step = graph.add_step(
+            "mlrun.serving.system_steps.BackgroundTaskStatus",
+            "background_task_status_step",
+            model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
+        )
     monitor_flow_step = graph.add_step(
-        "mlrun.serving.system_steps.BackgroundTaskStatus",
-        "background_task_status_step",
-        model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
-    )
-    graph.add_step(
         "storey.Filter",
         "filter_none",
         _fn="(event is not None)",
-        after="background_task_status_step",
+        after="background_task_status_step" if background_task_status_step else None,
         model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
     )
+    if background_task_status_step:
+        monitor_flow_step = background_task_status_step
     graph.add_step(
         "mlrun.serving.system_steps.MonitoringPreProcessor",
         "monitoring_pre_processor_step",
@@ -466,14 +472,28 @@ def add_monitoring_general_steps(


 def add_system_steps_to_graph(
-    project: str, graph: RootFlowStep, track_models: bool, context, serving_spec
+    project: str,
+    graph: RootFlowStep,
+    track_models: bool,
+    context,
+    serving_spec,
+    pause_until_background_task_completion: bool = True,
 ) -> RootFlowStep:
+    if not (isinstance(graph, RootFlowStep) and graph.include_monitored_step()):
+        return graph
     monitored_steps = graph.get_monitored_steps()
     graph = add_error_raiser_step(graph, monitored_steps)
     if track_models:
+        background_task_status_step = None
         graph, monitor_flow_step = add_monitoring_general_steps(
-            project, graph, context, serving_spec
+            project,
+            graph,
+            context,
+            serving_spec,
+            pause_until_background_task_completion,
         )
+        if background_task_status_step:
+            monitor_flow_step = background_task_status_step
         # Connect each model runner to the monitoring step:
         for step_name, step in monitored_steps.items():
             if monitor_flow_step.after:
@@ -485,6 +505,10 @@ def add_system_steps_to_graph(
                 monitor_flow_step.after = [
                     step_name,
                 ]
+    context.logger.info_with(
+        "Server graph after adding system steps",
+        graph=str(graph.steps),
+    )
     return graph


@@ -494,18 +518,13 @@ def v2_serving_init(context, namespace=None):
     context.logger.info("Initializing server from spec")
     spec = mlrun.utils.get_serving_spec()
     server = GraphServer.from_dict(spec)
-    if isinstance(server.graph, RootFlowStep) and server.graph.include_monitored_step():
-        server.graph = add_system_steps_to_graph(
-            server.project,
-            copy.deepcopy(server.graph),
-            spec.get("track_models"),
-            context,
-            spec,
-        )
-        context.logger.info_with(
-            "Server graph after adding system steps",
-            graph=str(server.graph.steps),
-        )
+    server.graph = add_system_steps_to_graph(
+        server.project,
+        copy.deepcopy(server.graph),
+        spec.get("track_models"),
+        context,
+        spec,
+    )

     if config.log_level.lower() == "debug":
         server.verbose = True
@@ -544,17 +563,57 @@ async def async_execute_graph(
     data: DataItem,
     batching: bool,
     batch_size: Optional[int],
+    read_as_lists: bool,
+    nest_under_inputs: bool,
 ) -> list[Any]:
     spec = mlrun.utils.get_serving_spec()

-    source_filename = spec.get("filename", None)
     namespace = {}
-    if source_filename:
-        with open(source_filename) as f:
-            exec(f.read(), namespace)
+    code = os.getenv("MLRUN_EXEC_CODE")
+    if code:
+        code = base64.b64decode(code).decode("utf-8")
+        exec(code, namespace)
+    else:
+        # TODO: find another way to get the local file path, or ensure that MLRUN_EXEC_CODE
+        # gets set in local flow and not just in the remote pod
+        source_filename = spec.get("filename", None)
+        if source_filename:
+            with open(source_filename) as f:
+                exec(f.read(), namespace)

     server = GraphServer.from_dict(spec)

+    if server.model_endpoint_creation_task_name:
+        context.logger.info(
+            f"Waiting for model endpoint creation task '{server.model_endpoint_creation_task_name}'..."
+        )
+        background_task = (
+            mlrun.get_run_db().wait_for_background_task_to_reach_terminal_state(
+                project=server.project,
+                name=server.model_endpoint_creation_task_name,
+            )
+        )
+        task_state = background_task.status.state
+        if task_state == mlrun.common.schemas.BackgroundTaskState.failed:
+            raise mlrun.errors.MLRunRuntimeError(
+                "Aborting job due to model endpoint creation background task failure"
+            )
+        elif task_state != mlrun.common.schemas.BackgroundTaskState.succeeded:
+            # this shouldn't happen, but we need to know if it does
+            raise mlrun.errors.MLRunRuntimeError(
+                "Aborting job because the model endpoint creation background task did not succeed "
+                f"(status='{task_state}')"
+            )
+
+    server.graph = add_system_steps_to_graph(
+        server.project,
+        copy.deepcopy(server.graph),
+        spec.get("track_models"),
+        context,
+        spec,
+        pause_until_background_task_completion=False,  # we've already awaited it
+    )
+
     if config.log_level.lower() == "debug":
         server.verbose = True
     context.logger.info_with("Initializing states", namespace=namespace)
@@ -588,7 +647,9 @@ async def async_execute_graph(

     batch = []
     for index, row in df.iterrows():
-        data = row.to_dict()
+        data = row.to_list() if read_as_lists else row.to_dict()
+        if nest_under_inputs:
+            data = {"inputs": data}
         if batching:
             batch.append(data)
             if len(batch) == batch_size:
@@ -612,6 +673,8 @@ def execute_graph(
     data: DataItem,
     batching: bool = False,
     batch_size: Optional[int] = None,
+    read_as_lists: bool = False,
+    nest_under_inputs: bool = False,
 ) -> (list[Any], Any):
     """
     Execute graph as a job, from start to finish.
@@ -621,10 +684,16 @@ def execute_graph(
     :param batching: Whether to push one or more batches into the graph rather than row by row.
     :param batch_size: The number of rows to push per batch. If not set, and batching=True, the entire dataset will
                        be pushed into the graph in one batch.
+    :param read_as_lists: Whether to read each row as a list instead of a dictionary.
+    :param nest_under_inputs: Whether to wrap each row with {"inputs": ...}.

     :return: A list of responses.
     """
-    return asyncio.run(async_execute_graph(context, data, batching, batch_size))
+    return asyncio.run(
+        async_execute_graph(
+            context, data, batching, batch_size, read_as_lists, nest_under_inputs
+        )
+    )


 def _set_callbacks(server, context):
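A sketch of how the new execute_graph flags might be used from a job handler (the handler and dataset are illustrative; the function and its parameters come from the hunk above):

```python
from mlrun.serving.server import execute_graph


def handler(context, data):
    # data is a DataItem pointing at the input dataset. read_as_lists turns each
    # dataframe row into a plain list, and nest_under_inputs wraps it as
    # {"inputs": [...]}, which is the body shape v2-style model steps expect.
    responses = execute_graph(
        context,
        data,
        batching=True,
        batch_size=64,
        read_as_lists=True,
        nest_under_inputs=True,
    )
    return responses
```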
mlrun/serving/states.py CHANGED
@@ -1081,6 +1081,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         "raise_exception",
         "artifact_uri",
         "shared_runnable_name",
+        "shared_proxy_mapping",
     ]
     kind = "model"

@@ -1089,12 +1090,16 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         name: str,
         raise_exception: bool = True,
         artifact_uri: Optional[str] = None,
+        shared_proxy_mapping: Optional[dict] = None,
         **kwargs,
     ):
         super().__init__(name=name, raise_exception=raise_exception, **kwargs)
         if artifact_uri is not None and not isinstance(artifact_uri, str):
             raise MLRunInvalidArgumentError("'artifact_uri' argument must be a string")
         self.artifact_uri = artifact_uri
+        self.shared_proxy_mapping: dict[
+            str : Union[str, ModelArtifact, LLMPromptArtifact]
+        ] = shared_proxy_mapping
         self.invocation_artifact: Optional[LLMPromptArtifact] = None
         self.model_artifact: Optional[ModelArtifact] = None
         self.model_provider: Optional[ModelProvider] = None
@@ -1125,10 +1130,13 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         else:
             self.model_artifact = artifact

-    def _get_artifact_object(self) -> Union[ModelArtifact, LLMPromptArtifact, None]:
-        if self.artifact_uri:
-            if mlrun.datastore.is_store_uri(self.artifact_uri):
-                artifact, _ = mlrun.store_manager.get_store_artifact(self.artifact_uri)
+    def _get_artifact_object(
+        self, proxy_uri: Optional[str] = None
+    ) -> Union[ModelArtifact, LLMPromptArtifact, None]:
+        uri = proxy_uri or self.artifact_uri
+        if uri:
+            if mlrun.datastore.is_store_uri(uri):
+                artifact, _ = mlrun.store_manager.get_store_artifact(uri)
                 return artifact
             else:
                 raise ValueError(
@@ -1148,10 +1156,12 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         """Override to implement prediction logic if the logic requires asyncio."""
         return body

-    def run(self, body: Any, path: str) -> Any:
+    def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
         return self.predict(body)

-    async def run_async(self, body: Any, path: str) -> Any:
+    async def run_async(
+        self, body: Any, path: str, origin_name: Optional[str] = None
+    ) -> Any:
         return await self.predict_async(body)

     def get_local_model_path(self, suffix="") -> (str, dict):
@@ -1186,6 +1196,81 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
         return None, None


+class LLModel(Model):
+    def __init__(self, name: str, **kwargs):
+        super().__init__(name, **kwargs)
+
+    def predict(
+        self, body: Any, messages: list[dict], model_configuration: dict
+    ) -> Any:
+        if isinstance(
+            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+        ) and isinstance(self.model_provider, ModelProvider):
+            body["result"] = self.model_provider.invoke(
+                messages=messages,
+                as_str=True,
+                **(model_configuration or {}),
+            )
+        return body
+
+    async def predict_async(
+        self, body: Any, messages: list[dict], model_configuration: dict
+    ) -> Any:
+        if isinstance(
+            self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+        ) and isinstance(self.model_provider, ModelProvider):
+            body["result"] = await self.model_provider.async_invoke(
+                messages=messages,
+                as_str=True,
+                **(model_configuration or {}),
+            )
+        return body
+
+    def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any:
+        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        return self.predict(
+            body, messages=messages, model_configuration=model_configuration
+        )
+
+    async def run_async(
+        self, body: Any, path: str, origin_name: Optional[str] = None
+    ) -> Any:
+        messages, model_configuration = self.enrich_prompt(body, origin_name)
+        return await self.predict_async(
+            body, messages=messages, model_configuration=model_configuration
+        )
+
+    def enrich_prompt(
+        self, body: dict, origin_name: str
+    ) -> Union[tuple[list[dict], dict], tuple[None, None]]:
+        if origin_name and self.shared_proxy_mapping:
+            llm_prompt_artifact = self.shared_proxy_mapping.get(origin_name)
+            if isinstance(llm_prompt_artifact, str):
+                llm_prompt_artifact = self._get_artifact_object(llm_prompt_artifact)
+                self.shared_proxy_mapping[origin_name] = llm_prompt_artifact
+        else:
+            llm_prompt_artifact = (
+                self.invocation_artifact or self._get_artifact_object()
+            )
+        if not (
+            llm_prompt_artifact and isinstance(llm_prompt_artifact, LLMPromptArtifact)
+        ):
+            logger.warning(
+                "LLMModel must be provided with LLMPromptArtifact",
+                llm_prompt_artifact=llm_prompt_artifact,
+            )
+            return None, None
+        prompt_legend = llm_prompt_artifact.spec.prompt_legend
+        prompt_template = deepcopy(llm_prompt_artifact.read_prompt())
+        kwargs = {
+            place_holder: body.get(body_map["field"])
+            for place_holder, body_map in prompt_legend.items()
+        }
+        for d in prompt_template:
+            d["content"] = d["content"].format(**kwargs)
+        return prompt_template, llm_prompt_artifact.spec.model_configuration
+
+
 class ModelSelector:
     """Used to select which models to run on each event."""

@@ -1292,6 +1377,7 @@ class ModelRunnerStep(MonitoredStep):
     """

     kind = "model_runner"
+    _dict_fields = MonitoredStep._dict_fields + ["_shared_proxy_mapping"]

     def __init__(
         self,
@@ -1311,6 +1397,7 @@ class ModelRunnerStep(MonitoredStep):
         )
         self.raise_exception = raise_exception
         self.shape = "folder"
+        self._shared_proxy_mapping = {}

     def add_shared_model_proxy(
         self,
@@ -1360,9 +1447,9 @@ class ModelRunnerStep(MonitoredStep):
                             in path.
        :param override:    bool allow override existing model on the current ModelRunnerStep.
        """
-        model_class = Model(
-            name=endpoint_name,
-            shared_runnable_name=shared_model_name,
+        model_class, model_params = (
+            "mlrun.serving.Model",
+            {"name": endpoint_name, "shared_runnable_name": shared_model_name},
        )
        if isinstance(model_artifact, str):
            model_artifact_uri = model_artifact
@@ -1389,6 +1476,20 @@ class ModelRunnerStep(MonitoredStep):
                f"ModelRunnerStep can only add proxy models that were added to the root flow step, "
                f"model {shared_model_name} is not in the shared models."
            )
+        if shared_model_name not in self._shared_proxy_mapping:
+            self._shared_proxy_mapping[shared_model_name] = {
+                endpoint_name: model_artifact.uri
+                if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                else model_artifact
+            }
+        else:
+            self._shared_proxy_mapping[shared_model_name].update(
+                {
+                    endpoint_name: model_artifact.uri
+                    if isinstance(model_artifact, (ModelArtifact, LLMPromptArtifact))
+                    else model_artifact
+                }
+            )
        self.add_model(
            endpoint_name=endpoint_name,
            model_class=model_class,
@@ -1401,6 +1502,7 @@ class ModelRunnerStep(MonitoredStep):
            outputs=outputs,
            input_path=input_path,
            result_path=result_path,
+            **model_params,
        )

    def add_model(
@@ -1659,6 +1761,7 @@ class ModelRunnerStep(MonitoredStep):
            model_selector=model_selector,
            runnables=model_objects,
            execution_mechanism_by_runnable_name=execution_mechanism_by_model_name,
+            shared_proxy_mapping=self._shared_proxy_mapping or None,
            name=self.name,
            context=context,
        )
@@ -2494,7 +2597,24 @@ class RootFlowStep(FlowStep):
             max_threads=self.shared_max_threads,
             pool_factor=self.pool_factor,
         )
-
+        monitored_steps = self.get_monitored_steps().values()
+        for monitored_step in monitored_steps:
+            if isinstance(monitored_step, ModelRunnerStep):
+                for model, model_params in self.shared_models.values():
+                    if "shared_proxy_mapping" in model_params:
+                        model_params["shared_proxy_mapping"].update(
+                            deepcopy(
+                                monitored_step._shared_proxy_mapping.get(
+                                    model_params.get("name"), {}
+                                )
+                            )
+                        )
+                    else:
+                        model_params["shared_proxy_mapping"] = deepcopy(
+                            monitored_step._shared_proxy_mapping.get(
+                                model_params.get("name"), {}
+                            )
+                        )
         for model, model_params in self.shared_models.values():
             model = get_class(model, namespace).from_dict(
                 model_params, init_with_params=True
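The bookkeeping added above reduces to a nested mapping; an illustration of its shape (names and URIs are invented, only the structure follows the diff):

```python
# ModelRunnerStep._shared_proxy_mapping:
#   shared model name -> {proxy endpoint name -> artifact URI (or loaded artifact)}
shared_proxy_mapping = {
    "shared-llm": {
        "qa-endpoint": "store://artifacts/my-project/qa-prompt",
        "summarize-endpoint": "store://artifacts/my-project/summary-prompt",
    }
}
```

At init time, RootFlowStep copies each shared model's inner mapping into that model's params as shared_proxy_mapping, which LLModel.enrich_prompt then reads by the originating endpoint name, lazily resolving URI strings into artifact objects.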
mlrun/utils/helpers.py CHANGED
@@ -162,14 +162,6 @@ def get_artifact_target(item: dict, project=None):
     return item["spec"].get("target_path")


-# TODO: Remove once data migration v5 is obsolete
-def is_legacy_artifact(artifact):
-    if isinstance(artifact, dict):
-        return "metadata" not in artifact
-    else:
-        return not hasattr(artifact, "metadata")
-
-
 logger = create_logger(config.log_level, config.log_formatter, "mlrun", sys.stdout)
 missing = object()

@@ -1050,7 +1042,14 @@ def fill_function_hash(function_dict, tag=""):


 def retry_until_successful(
-    backoff: int, timeout: int, logger, verbose: bool, _function, *args, **kwargs
+    backoff: int,
+    timeout: int,
+    logger,
+    verbose: bool,
+    _function,
+    *args,
+    fatal_exceptions=(),
+    **kwargs,
 ):
     """
     Runs function with given *args and **kwargs.
@@ -1063,14 +1062,31 @@ def retry_until_successful(
     :param verbose: whether to log the failure on each retry
     :param _function: function to run
     :param args: functions args
+    :param fatal_exceptions: exception types that should not be retried
     :param kwargs: functions kwargs
     :return: function result
     """
-    return Retryer(backoff, timeout, logger, verbose, _function, *args, **kwargs).run()
+    return Retryer(
+        backoff,
+        timeout,
+        logger,
+        verbose,
+        _function,
+        *args,
+        fatal_exceptions=fatal_exceptions,
+        **kwargs,
+    ).run()


 async def retry_until_successful_async(
-    backoff: int, timeout: int, logger, verbose: bool, _function, *args, **kwargs
+    backoff: int,
+    timeout: int,
+    logger,
+    verbose: bool,
+    _function,
+    *args,
+    fatal_exceptions=(),
+    **kwargs,
 ):
     """
     Runs function with given *args and **kwargs.
@@ -1082,12 +1098,20 @@ async def retry_until_successful_async(
     :param logger: a logger so we can log the failures
     :param verbose: whether to log the failure on each retry
     :param _function: function to run
+    :param fatal_exceptions: exception types that should not be retried
     :param args: functions args
     :param kwargs: functions kwargs
     :return: function result
     """
     return await AsyncRetryer(
-        backoff, timeout, logger, verbose, _function, *args, **kwargs
+        backoff,
+        timeout,
+        logger,
+        verbose,
+        _function,
+        *args,
+        fatal_exceptions=fatal_exceptions,
+        **kwargs,
     ).run()

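Finally, a sketch of how the new fatal_exceptions argument might be used (the helper and its signature come from the hunks above; the flaky function and exception choice are illustrative):

```python
from mlrun.utils.helpers import logger, retry_until_successful


def fetch_status(url):
    # Illustrative flaky call: transient failures should be retried, but a
    # ValueError here is treated as permanent and not worth retrying.
    ...


# Retry every 3 seconds for up to 60 seconds, but give up immediately on
# ValueError instead of burning the whole timeout on a permanent error.
result = retry_until_successful(
    3,             # backoff: seconds between attempts
    60,            # timeout: total seconds
    logger,
    True,          # verbose: log each failed attempt
    fetch_status,
    "https://example.com/status",
    fatal_exceptions=(ValueError,),
)
```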