mlrun 1.10.0rc14__py3-none-any.whl → 1.10.0rc15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic.

Files changed (41)
  1. mlrun/artifacts/base.py +0 -31
  2. mlrun/artifacts/manager.py +0 -5
  3. mlrun/common/schemas/__init__.py +1 -0
  4. mlrun/common/schemas/model_monitoring/__init__.py +1 -0
  5. mlrun/common/schemas/model_monitoring/functions.py +1 -1
  6. mlrun/common/schemas/model_monitoring/model_endpoints.py +10 -0
  7. mlrun/config.py +1 -1
  8. mlrun/datastore/model_provider/model_provider.py +42 -14
  9. mlrun/datastore/model_provider/openai_provider.py +96 -15
  10. mlrun/db/base.py +14 -0
  11. mlrun/db/httpdb.py +42 -9
  12. mlrun/db/nopdb.py +8 -0
  13. mlrun/model_monitoring/__init__.py +1 -0
  14. mlrun/model_monitoring/applications/base.py +176 -20
  15. mlrun/model_monitoring/db/_schedules.py +84 -24
  16. mlrun/model_monitoring/db/tsdb/base.py +72 -1
  17. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +7 -1
  18. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +37 -0
  19. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +25 -0
  20. mlrun/model_monitoring/helpers.py +26 -4
  21. mlrun/projects/project.py +26 -6
  22. mlrun/runtimes/daskjob.py +6 -0
  23. mlrun/runtimes/mpijob/abstract.py +6 -0
  24. mlrun/runtimes/mpijob/v1.py +6 -0
  25. mlrun/runtimes/nuclio/application/application.py +2 -0
  26. mlrun/runtimes/nuclio/function.py +6 -0
  27. mlrun/runtimes/nuclio/serving.py +12 -11
  28. mlrun/runtimes/pod.py +21 -0
  29. mlrun/runtimes/remotesparkjob.py +6 -0
  30. mlrun/runtimes/sparkjob/spark3job.py +6 -0
  31. mlrun/serving/server.py +95 -26
  32. mlrun/serving/states.py +16 -0
  33. mlrun/utils/helpers.py +36 -12
  34. mlrun/utils/retryer.py +15 -2
  35. mlrun/utils/version/version.json +2 -2
  36. {mlrun-1.10.0rc14.dist-info → mlrun-1.10.0rc15.dist-info}/METADATA +2 -7
  37. {mlrun-1.10.0rc14.dist-info → mlrun-1.10.0rc15.dist-info}/RECORD +41 -41
  38. {mlrun-1.10.0rc14.dist-info → mlrun-1.10.0rc15.dist-info}/WHEEL +0 -0
  39. {mlrun-1.10.0rc14.dist-info → mlrun-1.10.0rc15.dist-info}/entry_points.txt +0 -0
  40. {mlrun-1.10.0rc14.dist-info → mlrun-1.10.0rc15.dist-info}/licenses/LICENSE +0 -0
  41. {mlrun-1.10.0rc14.dist-info → mlrun-1.10.0rc15.dist-info}/top_level.txt +0 -0
@@ -469,6 +469,7 @@ class TDEngineConnector(TSDBConnector):
  preform_agg_columns: Optional[list] = None,
  order_by: Optional[str] = None,
  desc: Optional[bool] = None,
+ partition_by: Optional[str] = None,
  ) -> pd.DataFrame:
  """
  Getting records from TSDB data collection.
@@ -496,6 +497,8 @@ class TDEngineConnector(TSDBConnector):
  if an empty list was provided The aggregation won't be performed.
  :param order_by: The column or alias to preform ordering on the query.
  :param desc: Whether or not to sort the results in descending order.
+ :param partition_by: The column to partition the results by. Note that if interval is provided,
+ `agg_funcs` must bg provided as well.

  :return: DataFrame with the provided attributes from the data collection.
  :raise: MLRunInvalidArgumentError if query the provided table failed.
@@ -517,6 +520,7 @@ class TDEngineConnector(TSDBConnector):
  preform_agg_funcs_columns=preform_agg_columns,
  order_by=order_by,
  desc=desc,
+ partition_by=partition_by,
  )
  logger.debug("Querying TDEngine", query=full_query)
  try:
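
For orientation, the new `partition_by` argument is exercised by the `get_drift_data` method added in the next hunk. A minimal sketch of such a call, assuming an already-constructed `TDEngineConnector` instance named `connector` and an illustrative 24-hour window (all other names are taken from this diff):

    import datetime

    import mlrun.common.schemas.model_monitoring as mm_schemas

    end = datetime.datetime.now(tz=datetime.timezone.utc)
    start = end - datetime.timedelta(hours=24)

    # _get_records is an internal connector method; per the docstring above, when an
    # interval is provided, agg_funcs must be provided as well.
    df = connector._get_records(
        table=connector.tables[mm_schemas.TDEngineSuperTables.APP_RESULTS].super_table,
        start=start,
        end=end,
        interval="1h",  # illustrative; get_drift_data derives it from the time range
        columns=[mm_schemas.ResultData.RESULT_STATUS],
        timestamp_column=mm_schemas.WriterEvent.END_INFER_TIME,
        agg_funcs=["max"],
        partition_by=mm_schemas.WriterEvent.ENDPOINT_ID,
    )
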
@@ -1205,6 +1209,39 @@ class TDEngineConnector(TSDBConnector):
  )
  )

+ def get_drift_data(
+ self,
+ start: datetime,
+ end: datetime,
+ ) -> mm_schemas.ModelEndpointDriftValues:
+ filter_query = self._generate_filter_query(
+ filter_column=mm_schemas.ResultData.RESULT_STATUS,
+ filter_values=[
+ mm_schemas.ResultStatusApp.potential_detection.value,
+ mm_schemas.ResultStatusApp.detected.value,
+ ],
+ )
+ table = self.tables[mm_schemas.TDEngineSuperTables.APP_RESULTS].super_table
+ start, end, interval = self._prepare_aligned_start_end(start, end)
+
+ # get per time-interval x endpoint_id combination the max result status
+ df = self._get_records(
+ table=table,
+ start=start,
+ end=end,
+ interval=interval,
+ columns=[mm_schemas.ResultData.RESULT_STATUS],
+ filter_query=filter_query,
+ timestamp_column=mm_schemas.WriterEvent.END_INFER_TIME,
+ agg_funcs=["max"],
+ partition_by=mm_schemas.WriterEvent.ENDPOINT_ID,
+ )
+ if df.empty:
+ return mm_schemas.ModelEndpointDriftValues(values=[])
+
+ df["_wstart"] = pd.to_datetime(df["_wstart"])
+ return self._df_to_drift_data(df)
+
  # Note: this function serves as a reference for checking the TSDB for the existence of a metric.
  #
  # def read_prediction_metric_for_endpoint_if_exists(
@@ -1450,3 +1450,28 @@ class V3IOTSDBConnector(TSDBConnector):
  return metric_objects

  return build_metric_objects()
+
+ def get_drift_data(
+ self,
+ start: datetime,
+ end: datetime,
+ ) -> mm_schemas.ModelEndpointDriftValues:
+ table = mm_schemas.V3IOTSDBTables.APP_RESULTS
+ start, end, interval = self._prepare_aligned_start_end(start, end)
+
+ # get per time-interval x endpoint_id combination the max result status
+ df = self._get_records(
+ table=table,
+ start=start,
+ end=end,
+ interval=interval,
+ sliding_window_step=interval,
+ columns=[mm_schemas.ResultData.RESULT_STATUS],
+ agg_funcs=["max"],
+ group_by=mm_schemas.WriterEvent.ENDPOINT_ID,
+ )
+ if df.empty:
+ return mm_schemas.ModelEndpointDriftValues(values=[])
+ df = df[df[f"max({mm_schemas.ResultData.RESULT_STATUS})"] >= 1]
+ df = df.reset_index(names="_wstart")
+ return self._df_to_drift_data(df)
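
Both connectors hand the per-interval frame to `_df_to_drift_data`, which is not part of this diff. A purely hypothetical sketch of that kind of reduction, based on the statuses filtered above and the suspected/detected counts described in `MlrunProject.get_drift_over_time` further down (the function name and return shape here are assumptions):

    import pandas as pd

    import mlrun.common.schemas.model_monitoring as mm_schemas

    def df_to_drift_counts(df: pd.DataFrame, status_col: str) -> list[tuple]:
        # Count, per window start, how many endpoints' max result status was
        # "potential_detection" (suspected) vs. "detected".
        suspected = mm_schemas.ResultStatusApp.potential_detection.value
        detected = mm_schemas.ResultStatusApp.detected.value
        values = []
        for wstart, group in df.groupby("_wstart"):
            values.append(
                (
                    wstart,
                    int((group[status_col] == suspected).sum()),
                    int((group[status_col] == detected).sum()),
                )
            )
        return values
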
@@ -549,6 +549,10 @@ def _get_monitoring_schedules_folder_path(project: str) -> str:
  )


+ def _get_monitoring_schedules_user_folder_path(out_path: str) -> str:
+ return os.path.join(out_path, mm_constants.FileTargetKind.MONITORING_SCHEDULES)
+
+
  def _get_monitoring_schedules_file_endpoint_path(
  *, project: str, endpoint_id: str
  ) -> str:
@@ -570,10 +574,7 @@ def get_monitoring_schedules_endpoint_data(
  )


- def get_monitoring_schedules_chief_data(
- *,
- project: str,
- ) -> "DataItem":
+ def get_monitoring_schedules_chief_data(*, project: str) -> "DataItem":
  """
  Get the model monitoring schedules' data item of the project's model endpoint.
  """
@@ -582,6 +583,19 @@ def get_monitoring_schedules_chief_data(
  )


+ def get_monitoring_schedules_user_application_data(
+ *, out_path: str, application: str
+ ) -> "DataItem":
+ """
+ Get the model monitoring schedules' data item of user application runs.
+ """
+ return mlrun.datastore.store_manager.object(
+ _get_monitoring_schedules_file_user_application_path(
+ out_path=out_path, application=application
+ )
+ )
+
+
  def _get_monitoring_schedules_file_chief_path(
  *,
  project: str,
@@ -591,6 +605,14 @@ def _get_monitoring_schedules_file_chief_path(
  )


+ def _get_monitoring_schedules_file_user_application_path(
+ *, out_path: str, application: str
+ ) -> str:
+ return os.path.join(
+ _get_monitoring_schedules_user_folder_path(out_path), f"{application}.json"
+ )
+
+
  def get_start_end(
  start: Union[datetime.datetime, None],
  end: Union[datetime.datetime, None],
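
A usage sketch of the new user-application schedule helpers; the `out_path` and `application` values are illustrative, and the resulting path is `<out_path>/<FileTargetKind.MONITORING_SCHEDULES>/<application>.json`:

    # Assuming the helpers are imported from their defining module in mlrun's
    # model monitoring package:
    path = _get_monitoring_schedules_file_user_application_path(
        out_path="v3io:///projects/my-project/monitoring",
        application="my-monitoring-app",
    )
    data_item = get_monitoring_schedules_user_application_data(
        out_path="v3io:///projects/my-project/monitoring",
        application="my-monitoring-app",
    )
    # data_item is a DataItem pointing at the application's schedules JSON file
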
mlrun/projects/project.py CHANGED
@@ -1042,12 +1042,7 @@ class ProjectSpec(ModelObj):
  artifact = artifact.to_dict()
  else: # artifact is a dict
  # imported/legacy artifacts don't have metadata,spec,status fields
- key_field = (
- "key"
- if _is_imported_artifact(artifact)
- or mlrun.utils.is_legacy_artifact(artifact)
- else "metadata.key"
- )
+ key_field = "key" if _is_imported_artifact(artifact) else "metadata.key"
  key = mlrun.utils.get_in(artifact, key_field, "")
  if not key:
  raise ValueError(f'artifacts "{key_field}" must be specified')
@@ -5557,6 +5552,31 @@ class MlrunProject(ModelObj):
  **kwargs,
  )

+ def get_drift_over_time(
+ self,
+ start: Optional[datetime.datetime] = None,
+ end: Optional[datetime.datetime] = None,
+ ) -> mlrun.common.schemas.model_monitoring.ModelEndpointDriftValues:
+ """
+ Get drift counts over time for the project.
+
+ This method returns a list of tuples, each representing a time-interval (in a granularity set by the
+ duration of the given time range) and the number of suspected drifts and detected drifts in that interval.
+ For a range of 6 hours or less, the granularity is 10 minute, for a range of 2 hours to 72 hours, the
+ granularity is 1 hour, and for a range of more than 72 hours, the granularity is 24 hours.
+
+ :param start: Start time of the range to retrieve drift counts from.
+ :param end: End time of the range to retrieve drift counts from.
+
+ :return: A ModelEndpointDriftValues object containing the drift counts over time.
+ """
+ db = mlrun.db.get_run_db(secrets=self._secrets)
+ return db.get_drift_over_time(
+ project=self.metadata.name,
+ start=start,
+ end=end,
+ )
+
  def _run_authenticated_git_action(
  self,
  action: Callable,
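
A usage sketch of the new project method; the project name and time range are illustrative:

    import datetime

    import mlrun

    project = mlrun.get_or_create_project("my-project")
    end = datetime.datetime.now(tz=datetime.timezone.utc)
    start = end - datetime.timedelta(hours=24)
    drift_values = project.get_drift_over_time(start=start, end=end)
    # drift_values is a ModelEndpointDriftValues; per the docstring above, a 24-hour
    # range yields 1-hour granularity per interval
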
mlrun/runtimes/daskjob.py CHANGED
@@ -93,6 +93,9 @@ class DaskSpec(KubeResourceSpec):
  security_context=None,
  state_thresholds=None,
  serving_spec=None,
+ graph=None,
+ parameters=None,
+ track_models=None,
  ):
  super().__init__(
  command=command,
@@ -123,6 +126,9 @@ class DaskSpec(KubeResourceSpec):
  security_context=security_context,
  state_thresholds=state_thresholds,
  serving_spec=serving_spec,
+ graph=graph,
+ parameters=parameters,
+ track_models=track_models,
  )
  self.args = args

@@ -55,6 +55,9 @@ class MPIResourceSpec(KubeResourceSpec):
  security_context=None,
  state_thresholds=None,
  serving_spec=None,
+ graph=None,
+ parameters=None,
+ track_models=None,
  ):
  super().__init__(
  command=command,
@@ -85,6 +88,9 @@ class MPIResourceSpec(KubeResourceSpec):
  security_context=security_context,
  state_thresholds=state_thresholds,
  serving_spec=serving_spec,
+ graph=graph,
+ parameters=parameters,
+ track_models=track_models,
  )
  self.mpi_args = mpi_args or [
  "-x",
@@ -50,6 +50,9 @@ class MPIV1ResourceSpec(MPIResourceSpec):
  security_context=None,
  state_thresholds=None,
  serving_spec=None,
+ graph=None,
+ parameters=None,
+ track_models=None,
  ):
  super().__init__(
  command=command,
@@ -81,6 +84,9 @@ class MPIV1ResourceSpec(MPIResourceSpec):
  security_context=security_context,
  state_thresholds=state_thresholds,
  serving_spec=serving_spec,
+ graph=graph,
+ parameters=parameters,
+ track_models=track_models,
  )
  self.clean_pod_policy = clean_pod_policy or MPIJobV1CleanPodPolicies.default()

@@ -400,8 +400,10 @@ class ApplicationRuntime(RemoteRuntime):
  # nuclio implementation detail - when providing the image and emptying out the source code and build source,
  # nuclio skips rebuilding the image and simply takes the prebuilt image
  self.spec.build.functionSourceCode = ""
+ self.spec.config.pop("spec.build.functionSourceCode", None)
  self.status.application_source = self.spec.build.source
  self.spec.build.source = ""
+ self.spec.config.pop("spec.build.source", None)

  # save the image in the status, so we won't repopulate the function source code
  self.status.container_image = image
@@ -155,6 +155,9 @@ class NuclioSpec(KubeResourceSpec):
  state_thresholds=None,
  disable_default_http_trigger=None,
  serving_spec=None,
+ graph=None,
+ parameters=None,
+ track_models=None,
  ):
  super().__init__(
  command=command,
@@ -185,6 +188,9 @@ class NuclioSpec(KubeResourceSpec):
  security_context=security_context,
  state_thresholds=state_thresholds,
  serving_spec=serving_spec,
+ graph=graph,
+ parameters=parameters,
+ track_models=track_models,
  )

  self.base_spec = base_spec or {}
@@ -720,6 +720,7 @@ class ServingRuntime(RemoteRuntime):
  "track_models": self.spec.track_models,
  "default_content_type": self.spec.default_content_type,
  "model_endpoint_creation_task_name": self.spec.model_endpoint_creation_task_name,
+ # TODO: find another way to pass this (needed for local run)
  "filename": getattr(self.spec, "filename", None),
  }

@@ -788,17 +789,13 @@ class ServingRuntime(RemoteRuntime):
  monitoring_mock=self.spec.track_models,
  )

- if (
- isinstance(self.spec.graph, RootFlowStep)
- and self.spec.graph.include_monitored_step()
- ):
- server.graph = add_system_steps_to_graph(
- server.project,
- server.graph,
- self.spec.track_models,
- server.context,
- self.spec,
- )
+ server.graph = add_system_steps_to_graph(
+ server.project,
+ server.graph,
+ self.spec.track_models,
+ server.context,
+ self.spec,
+ )

  if workdir:
  os.chdir(old_workdir)
@@ -858,6 +855,7 @@ class ServingRuntime(RemoteRuntime):
  description=self.spec.description,
  workdir=self.spec.workdir,
  image_pull_secret=self.spec.image_pull_secret,
+ build=self.spec.build,
  node_name=self.spec.node_name,
  node_selector=self.spec.node_selector,
  affinity=self.spec.affinity,
@@ -868,6 +866,9 @@ class ServingRuntime(RemoteRuntime):
  security_context=self.spec.security_context,
  state_thresholds=self.spec.state_thresholds,
  serving_spec=self._get_serving_spec(),
+ track_models=self.spec.track_models,
+ parameters=self.spec.parameters,
+ graph=self.spec.graph,
  )
  job = KubejobRuntime(
  spec=spec,
mlrun/runtimes/pod.py CHANGED
@@ -104,6 +104,9 @@ class KubeResourceSpec(FunctionSpec):
  "security_context",
  "state_thresholds",
  "serving_spec",
+ "track_models",
+ "parameters",
+ "graph",
  ]
  _default_fields_to_strip = FunctionSpec._default_fields_to_strip + [
  "volumes",
@@ -180,6 +183,9 @@ class KubeResourceSpec(FunctionSpec):
  security_context=None,
  state_thresholds=None,
  serving_spec=None,
+ track_models=None,
+ parameters=None,
+ graph=None,
  ):
  super().__init__(
  command=command,
@@ -226,6 +232,10 @@ class KubeResourceSpec(FunctionSpec):
  or mlrun.mlconf.function.spec.state_thresholds.default.to_dict()
  )
  self.serving_spec = serving_spec
+ self.track_models = track_models
+ self.parameters = parameters
+ self._graph = None
+ self.graph = graph
  # Termination grace period is internal for runtimes that have a pod termination hook hence it is not in the
  # _dict_fields and doesn't have a setter.
  self._termination_grace_period_seconds = None
@@ -303,6 +313,17 @@ class KubeResourceSpec(FunctionSpec):
  def termination_grace_period_seconds(self) -> typing.Optional[int]:
  return self._termination_grace_period_seconds

+ @property
+ def graph(self):
+ """states graph, holding the serving workflow/DAG topology"""
+ return self._graph
+
+ @graph.setter
+ def graph(self, graph):
+ from ..serving.states import graph_root_setter
+
+ graph_root_setter(self, graph)
+
  def _serialize_field(
  self, struct: dict, field_name: typing.Optional[str] = None, strip: bool = False
  ) -> typing.Any:
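
A small sketch of the new spec fields in isolation; constructing a bare `KubeResourceSpec` directly is only for illustration, as in practice these fields are populated when a `ServingRuntime` is converted to a job (see the ServingRuntime hunks above):

    from mlrun.runtimes.pod import KubeResourceSpec

    spec = KubeResourceSpec(track_models=True, parameters={"threshold": 0.5})
    print(spec.track_models)  # True
    print(spec.parameters)    # {'threshold': 0.5}
    print(spec.graph)         # None until a graph is assigned via the new property
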
@@ -59,6 +59,9 @@ class RemoteSparkSpec(KubeResourceSpec):
  security_context=None,
  state_thresholds=None,
  serving_spec=None,
+ graph=None,
+ parameters=None,
+ track_models=None,
  ):
  super().__init__(
  command=command,
@@ -89,6 +92,9 @@ class RemoteSparkSpec(KubeResourceSpec):
  security_context=security_context,
  state_thresholds=state_thresholds,
  serving_spec=serving_spec,
+ graph=graph,
+ parameters=parameters,
+ track_models=track_models,
  )
  self.provider = provider

@@ -169,6 +169,9 @@ class Spark3JobSpec(KubeResourceSpec):
  security_context=None,
  state_thresholds=None,
  serving_spec=None,
+ graph=None,
+ parameters=None,
+ track_models=None,
  ):
  super().__init__(
  command=command,
@@ -199,6 +202,9 @@ class Spark3JobSpec(KubeResourceSpec):
  security_context=security_context,
  state_thresholds=state_thresholds,
  serving_spec=serving_spec,
+ graph=graph,
+ parameters=parameters,
+ track_models=track_models,
  )

  self.driver_resources = driver_resources or {}
mlrun/serving/server.py CHANGED
@@ -15,6 +15,7 @@
  __all__ = ["GraphServer", "create_graph_server", "GraphContext", "MockEvent"]

  import asyncio
+ import base64
  import copy
  import json
  import os
@@ -384,6 +385,7 @@ def add_monitoring_general_steps(
  graph: RootFlowStep,
  context,
  serving_spec,
+ pause_until_background_task_completion: bool,
  ) -> tuple[RootFlowStep, FlowStep]:
  """
  Adding the monitoring flow connection steps, this steps allow the graph to reconstruct the serving event enrich it
@@ -392,18 +394,22 @@ def add_monitoring_general_steps(
  "background_task_status_step" --> "filter_none" --> "monitoring_pre_processor_step" --> "flatten_events"
  --> "sampling_step" --> "filter_none_sampling" --> "model_monitoring_stream"
  """
+ background_task_status_step = None
+ if pause_until_background_task_completion:
+ background_task_status_step = graph.add_step(
+ "mlrun.serving.system_steps.BackgroundTaskStatus",
+ "background_task_status_step",
+ model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
+ )
  monitor_flow_step = graph.add_step(
- "mlrun.serving.system_steps.BackgroundTaskStatus",
- "background_task_status_step",
- model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
- )
- graph.add_step(
  "storey.Filter",
  "filter_none",
  _fn="(event is not None)",
- after="background_task_status_step",
+ after="background_task_status_step" if background_task_status_step else None,
  model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
  )
+ if background_task_status_step:
+ monitor_flow_step = background_task_status_step
  graph.add_step(
  "mlrun.serving.system_steps.MonitoringPreProcessor",
  "monitoring_pre_processor_step",
@@ -466,14 +472,28 @@ def add_monitoring_general_steps(


  def add_system_steps_to_graph(
- project: str, graph: RootFlowStep, track_models: bool, context, serving_spec
+ project: str,
+ graph: RootFlowStep,
+ track_models: bool,
+ context,
+ serving_spec,
+ pause_until_background_task_completion: bool = True,
  ) -> RootFlowStep:
+ if not (isinstance(graph, RootFlowStep) and graph.include_monitored_step()):
+ return graph
  monitored_steps = graph.get_monitored_steps()
  graph = add_error_raiser_step(graph, monitored_steps)
  if track_models:
+ background_task_status_step = None
  graph, monitor_flow_step = add_monitoring_general_steps(
- project, graph, context, serving_spec
+ project,
+ graph,
+ context,
+ serving_spec,
+ pause_until_background_task_completion,
  )
+ if background_task_status_step:
+ monitor_flow_step = background_task_status_step
  # Connect each model runner to the monitoring step:
  for step_name, step in monitored_steps.items():
  if monitor_flow_step.after:
@@ -485,6 +505,10 @@ def add_system_steps_to_graph(
  monitor_flow_step.after = [
  step_name,
  ]
+ context.logger.info_with(
+ "Server graph after adding system steps",
+ graph=str(graph.steps),
+ )
  return graph


@@ -494,18 +518,13 @@ def v2_serving_init(context, namespace=None):
  context.logger.info("Initializing server from spec")
  spec = mlrun.utils.get_serving_spec()
  server = GraphServer.from_dict(spec)
- if isinstance(server.graph, RootFlowStep) and server.graph.include_monitored_step():
- server.graph = add_system_steps_to_graph(
- server.project,
- copy.deepcopy(server.graph),
- spec.get("track_models"),
- context,
- spec,
- )
- context.logger.info_with(
- "Server graph after adding system steps",
- graph=str(server.graph.steps),
- )
+ server.graph = add_system_steps_to_graph(
+ server.project,
+ copy.deepcopy(server.graph),
+ spec.get("track_models"),
+ context,
+ spec,
+ )

  if config.log_level.lower() == "debug":
  server.verbose = True
@@ -544,17 +563,57 @@ async def async_execute_graph(
  data: DataItem,
  batching: bool,
  batch_size: Optional[int],
+ read_as_lists: bool,
+ nest_under_inputs: bool,
  ) -> list[Any]:
  spec = mlrun.utils.get_serving_spec()

- source_filename = spec.get("filename", None)
  namespace = {}
- if source_filename:
- with open(source_filename) as f:
- exec(f.read(), namespace)
+ code = os.getenv("MLRUN_EXEC_CODE")
+ if code:
+ code = base64.b64decode(code).decode("utf-8")
+ exec(code, namespace)
+ else:
+ # TODO: find another way to get the local file path, or ensure that MLRUN_EXEC_CODE
+ # gets set in local flow and not just in the remote pod
+ source_filename = spec.get("filename", None)
+ if source_filename:
+ with open(source_filename) as f:
+ exec(f.read(), namespace)

  server = GraphServer.from_dict(spec)

+ if server.model_endpoint_creation_task_name:
+ context.logger.info(
+ f"Waiting for model endpoint creation task '{server.model_endpoint_creation_task_name}'..."
+ )
+ background_task = (
+ mlrun.get_run_db().wait_for_background_task_to_reach_terminal_state(
+ project=server.project,
+ name=server.model_endpoint_creation_task_name,
+ )
+ )
+ task_state = background_task.status.state
+ if task_state == mlrun.common.schemas.BackgroundTaskState.failed:
+ raise mlrun.errors.MLRunRuntimeError(
+ "Aborting job due to model endpoint creation background task failure"
+ )
+ elif task_state != mlrun.common.schemas.BackgroundTaskState.succeeded:
+ # this shouldn't happen, but we need to know if it does
+ raise mlrun.errors.MLRunRuntimeError(
+ "Aborting job because the model endpoint creation background task did not succeed "
+ f"(status='{task_state}')"
+ )
+
+ server.graph = add_system_steps_to_graph(
+ server.project,
+ copy.deepcopy(server.graph),
+ spec.get("track_models"),
+ context,
+ spec,
+ pause_until_background_task_completion=False, # we've already awaited it
+ )
+
  if config.log_level.lower() == "debug":
  server.verbose = True
  context.logger.info_with("Initializing states", namespace=namespace)
@@ -588,7 +647,9 @@

  batch = []
  for index, row in df.iterrows():
- data = row.to_dict()
+ data = row.to_list() if read_as_lists else row.to_dict()
+ if nest_under_inputs:
+ data = {"inputs": data}
  if batching:
  batch.append(data)
  if len(batch) == batch_size:
@@ -612,6 +673,8 @@ def execute_graph(
  data: DataItem,
  batching: bool = False,
  batch_size: Optional[int] = None,
+ read_as_lists: bool = False,
+ nest_under_inputs: bool = False,
  ) -> (list[Any], Any):
  """
  Execute graph as a job, from start to finish.
@@ -621,10 +684,16 @@ def execute_graph(
  :param batching: Whether to push one or more batches into the graph rather than row by row.
  :param batch_size: The number of rows to push per batch. If not set, and batching=True, the entire dataset will
  be pushed into the graph in one batch.
+ :param read_as_lists: Whether to read each row as a list instead of a dictionary.
+ :param nest_under_inputs: Whether to wrap each row with {"inputs": ...}.

  :return: A list of responses.
  """
- return asyncio.run(async_execute_graph(context, data, batching, batch_size))
+ return asyncio.run(
+ async_execute_graph(
+ context, data, batching, batch_size, read_as_lists, nest_under_inputs
+ )
+ )


  def _set_callbacks(server, context):
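
A usage sketch of `execute_graph` with the new flags, wrapped in an illustrative job handler (the handler name, batch size, and flag values are examples):

    from mlrun.serving.server import execute_graph

    def handler(context, data):
        # data is a DataItem holding the dataset to push through the graph
        responses = execute_graph(
            context,
            data,
            batching=True,
            batch_size=64,
            read_as_lists=True,      # push each row as a list instead of a dict
            nest_under_inputs=True,  # wrap each row as {"inputs": ...}
        )
        return responses
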
mlrun/serving/states.py CHANGED
@@ -1203,11 +1203,27 @@ class LLModel(Model):
  def predict(
  self, body: Any, messages: list[dict], model_configuration: dict
  ) -> Any:
+ if isinstance(
+ self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+ ) and isinstance(self.model_provider, ModelProvider):
+ body["result"] = self.model_provider.invoke(
+ messages=messages,
+ as_str=True,
+ **(model_configuration or {}),
+ )
  return body

  async def predict_async(
  self, body: Any, messages: list[dict], model_configuration: dict
  ) -> Any:
+ if isinstance(
+ self.invocation_artifact, mlrun.artifacts.LLMPromptArtifact
+ ) and isinstance(self.model_provider, ModelProvider):
+ body["result"] = await self.model_provider.async_invoke(
+ messages=messages,
+ as_str=True,
+ **(model_configuration or {}),
+ )
  return body

  def run(self, body: Any, path: str, origin_name: Optional[str] = None) -> Any: