mlrun 1.10.0rc9__py3-none-any.whl → 1.10.0rc11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (57)
  1. mlrun/artifacts/manager.py +1 -1
  2. mlrun/common/constants.py +12 -0
  3. mlrun/common/schemas/__init__.py +1 -0
  4. mlrun/common/schemas/model_monitoring/__init__.py +2 -0
  5. mlrun/common/schemas/model_monitoring/functions.py +2 -0
  6. mlrun/common/schemas/model_monitoring/model_endpoints.py +19 -1
  7. mlrun/common/schemas/serving.py +1 -0
  8. mlrun/common/schemas/workflow.py +8 -0
  9. mlrun/datastore/azure_blob.py +1 -1
  10. mlrun/datastore/base.py +4 -2
  11. mlrun/datastore/datastore.py +46 -14
  12. mlrun/datastore/google_cloud_storage.py +1 -1
  13. mlrun/datastore/s3.py +16 -5
  14. mlrun/datastore/sources.py +2 -2
  15. mlrun/datastore/targets.py +2 -2
  16. mlrun/db/__init__.py +0 -1
  17. mlrun/db/base.py +29 -0
  18. mlrun/db/httpdb.py +35 -0
  19. mlrun/db/nopdb.py +19 -0
  20. mlrun/execution.py +12 -0
  21. mlrun/frameworks/tf_keras/mlrun_interface.py +8 -19
  22. mlrun/frameworks/tf_keras/model_handler.py +21 -12
  23. mlrun/launcher/base.py +1 -0
  24. mlrun/launcher/client.py +1 -0
  25. mlrun/launcher/local.py +4 -0
  26. mlrun/model.py +15 -4
  27. mlrun/model_monitoring/applications/base.py +74 -56
  28. mlrun/model_monitoring/db/tsdb/base.py +52 -19
  29. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +179 -11
  30. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +26 -11
  31. mlrun/model_monitoring/helpers.py +48 -0
  32. mlrun/projects/__init__.py +1 -0
  33. mlrun/projects/pipelines.py +44 -1
  34. mlrun/projects/project.py +30 -0
  35. mlrun/runtimes/daskjob.py +2 -0
  36. mlrun/runtimes/kubejob.py +4 -0
  37. mlrun/runtimes/mpijob/abstract.py +2 -0
  38. mlrun/runtimes/mpijob/v1.py +2 -0
  39. mlrun/runtimes/nuclio/function.py +2 -0
  40. mlrun/runtimes/nuclio/serving.py +59 -0
  41. mlrun/runtimes/pod.py +3 -0
  42. mlrun/runtimes/remotesparkjob.py +2 -0
  43. mlrun/runtimes/sparkjob/spark3job.py +2 -0
  44. mlrun/serving/routers.py +17 -13
  45. mlrun/serving/server.py +97 -3
  46. mlrun/serving/states.py +146 -38
  47. mlrun/serving/system_steps.py +2 -1
  48. mlrun/serving/v2_serving.py +2 -2
  49. mlrun/utils/version/version.json +2 -2
  50. {mlrun-1.10.0rc9.dist-info → mlrun-1.10.0rc11.dist-info}/METADATA +13 -7
  51. {mlrun-1.10.0rc9.dist-info → mlrun-1.10.0rc11.dist-info}/RECORD +55 -57
  52. {mlrun-1.10.0rc9.dist-info → mlrun-1.10.0rc11.dist-info}/licenses/LICENSE +1 -1
  53. mlrun/db/sql_types.py +0 -160
  54. mlrun/utils/db.py +0 -71
  55. {mlrun-1.10.0rc9.dist-info → mlrun-1.10.0rc11.dist-info}/WHEEL +0 -0
  56. {mlrun-1.10.0rc9.dist-info → mlrun-1.10.0rc11.dist-info}/entry_points.txt +0 -0
  57. {mlrun-1.10.0rc9.dist-info → mlrun-1.10.0rc11.dist-info}/top_level.txt +0 -0
mlrun/model_monitoring/helpers.py CHANGED
@@ -589,3 +589,51 @@ def _get_monitoring_schedules_file_chief_path(
     return os.path.join(
         _get_monitoring_schedules_folder_path(project), f"{project}.json"
     )
+
+
+def get_start_end(
+    start: Union[datetime.datetime, None],
+    end: Union[datetime.datetime, None],
+    delta: Optional[datetime.timedelta] = None,
+) -> tuple[datetime.datetime, datetime.datetime]:
+    """
+    Static utility function for normalizing TSDB start and end times.
+
+    :param start: Either None or a datetime. None is handled as datetime.min (UTC) unless `delta`
+                  is provided.
+    :param end:   Either None or a datetime. None is handled as datetime.now (UTC).
+    :param delta: Optional timedelta defining a time span.
+                  - If both `start` and `end` are provided, `delta` is ignored.
+                  - If only one of `start` or `end` is provided, the other is
+                    calculated using `delta`.
+                  - If neither `start` nor `end` is provided, `end` defaults to now,
+                    and `start` is calculated as `end - delta`.
+    :return: start datetime, end datetime
+    """
+
+    if delta and start and end:
+        # If both start and end are provided, delta is ignored
+        pass
+    elif delta:
+        if start and not end:
+            end = start + delta
+        else:
+            end = end or mlrun.utils.datetime_now()
+            start = end - delta
+    else:
+        start = start or mlrun.utils.datetime_min()
+        end = end or mlrun.utils.datetime_now()
+
+    if not (
+        isinstance(start, datetime.datetime) and isinstance(end, datetime.datetime)
+    ):
+        raise mlrun.errors.MLRunInvalidArgumentError(
+            "Both start and end must be datetime objects"
+        )
+
+    if start > end:
+        raise mlrun.errors.MLRunInvalidArgumentError(
+            "The start time must be before the end time. Note that if end time is not provided, "
+            "the current time is used by default"
+        )
+
+    return start, end
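
The new helper resolves open-ended time ranges for TSDB queries. A minimal usage sketch, assuming the function is importable from mlrun.model_monitoring.helpers as the file list indicates:

import datetime

from mlrun.model_monitoring.helpers import get_start_end

end = datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc)

# Only `end` and `delta` given: start is computed as end - delta
start, end = get_start_end(None, end, delta=datetime.timedelta(hours=1))

# Only `start` and `delta` given: end is computed as start + delta
start, end = get_start_end(start, None, delta=datetime.timedelta(hours=1))

# Neither given and no delta: defaults to (datetime.min, now), both in UTC
start, end = get_start_end(None, None)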
mlrun/projects/__init__.py CHANGED
@@ -32,6 +32,7 @@ from .pipelines import (
     load_and_run_workflow,
     load_and_run,
     pipeline_context,
+    rerun_workflow,
 )  # noqa
 from .project import (
     MlrunProject,
mlrun/projects/pipelines.py CHANGED
@@ -21,6 +21,7 @@ import typing
 import uuid
 
 import mlrun
+import mlrun.common.constants as mlrun_constants
 import mlrun.common.runtimes.constants
 import mlrun.common.schemas
 import mlrun.common.schemas.function
@@ -1070,6 +1071,46 @@ def github_webhook(request):
     return {"msg": "pushed"}
 
 
+def rerun_workflow(
+    context: mlrun.execution.MLClientCtx, run_uid: str, project_name: str
+):
+    """
+    Re-run a workflow by retrying a previously failed KFP pipeline.
+
+    :param context:      MLRun context.
+    :param run_uid:      The run UID of the original workflow to retry.
+    :param project_name: The project name.
+    """
+
+    try:
+        # TODO in followups: handle start and running notifications
+
+        # Retry the pipeline - TODO: add submit-direct flag when created
+        db = mlrun.get_run_db()
+        new_pipeline_id = db.retry_pipeline(
+            run_uid, project_name, submit_mode=mlrun_constants.WorkflowSubmitMode.direct
+        )
+
+        # Store result for observability
+        context.set_label(
+            mlrun_constants.MLRunInternalLabels.workflow_id, new_pipeline_id
+        )
+        context.update_run()
+
+        context.log_result("workflow_id", new_pipeline_id)
+
+        # wait for pipeline completion so monitor will push terminal notifications
+        wait_for_pipeline_completion(
+            new_pipeline_id,
+            project=project_name,
+        )
+
+    # Temporary exception
+    except Exception as exc:
+        context.logger.error("Failed to rerun workflow", exc=err_to_str(exc))
+        raise
+
+
 def load_and_run(context, *args, **kwargs):
     """
     This function serves as an alias to `load_and_run_workflow`,
@@ -1192,7 +1233,9 @@ def load_and_run_workflow(
     context.logger.info(
         "Associating workflow-runner with workflow ID", run_id=run.run_id
     )
-    context.set_label("workflow-id", run.run_id)
+    context.set_label(mlrun_constants.MLRunInternalLabels.workflow_id, run.run_id)
+    context.update_run()
+
     context.log_result(key="workflow_id", value=run.run_id)
     context.log_result(key="engine", value=run._engine.engine, commit=True)
 
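The new rerun_workflow runner delegates to the retry_pipeline DB call added in this release (see mlrun/db/httpdb.py in the file list). A hedged sketch of the underlying call outside of a workflow-runner context; the run UID and project name below are placeholders:

import mlrun
import mlrun.common.constants as mlrun_constants

db = mlrun.get_run_db()
new_pipeline_id = db.retry_pipeline(
    "original-run-uid",  # placeholder: run UID of the failed workflow
    "my-project",        # placeholder: project name
    submit_mode=mlrun_constants.WorkflowSubmitMode.direct,
)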
mlrun/projects/project.py CHANGED
@@ -4968,6 +4968,36 @@ class MlrunProject(ModelObj):
             include_infra=include_infra,
         )
 
+    def get_monitoring_function_summary(
+        self,
+        name: str,
+        start: Optional[datetime.datetime] = None,
+        end: Optional[datetime.datetime] = None,
+        include_latest_metrics: bool = False,
+    ) -> mlrun.common.schemas.model_monitoring.FunctionSummary:
+        """Get a monitoring function summary for the specified project and function name.
+
+        :param name:  Name of the monitoring function to retrieve the summary for.
+        :param start: Start time for filtering the results (optional).
+        :param end:   End time for filtering the results (optional).
+        :param include_latest_metrics: Whether to include the latest metrics in the response (default is False).
+
+        :return: A FunctionSummary object containing information about the monitoring function.
+        """
+        if start is not None and end is not None:
+            if start.tzinfo is None or end.tzinfo is None:
+                raise mlrun.errors.MLRunInvalidArgumentTypeError(
+                    "Custom start and end times must contain the timezone."
+                )
+
+        db = mlrun.db.get_run_db(secrets=self._secrets)
+        return db.get_monitoring_function_summary(
+            project=self.metadata.name,
+            function_name=name,
+            start=start,
+            end=end,
+            include_latest_metrics=include_latest_metrics,
+        )
+
     def list_runs(
         self,
         name: Optional[str] = None,
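
A usage sketch of the new project method (project and function names are placeholders). Note that custom start/end times must be timezone-aware, otherwise MLRunInvalidArgumentTypeError is raised:

import datetime

import mlrun

project = mlrun.get_or_create_project("my-project")
summary = project.get_monitoring_function_summary(
    name="my-monitoring-app",
    start=datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1),
    end=datetime.datetime.now(datetime.timezone.utc),
    include_latest_metrics=True,
)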
mlrun/runtimes/daskjob.py CHANGED
@@ -92,6 +92,7 @@ class DaskSpec(KubeResourceSpec):
         preemption_mode=None,
         security_context=None,
         state_thresholds=None,
+        serving_spec=None,
     ):
         super().__init__(
             command=command,
@@ -121,6 +122,7 @@ class DaskSpec(KubeResourceSpec):
             preemption_mode=preemption_mode,
             security_context=security_context,
             state_thresholds=state_thresholds,
+            serving_spec=serving_spec,
         )
         self.args = args
 
mlrun/runtimes/kubejob.py CHANGED
@@ -207,3 +207,7 @@ class KubejobRuntime(KubeResource):
         raise NotImplementedError(
             f"Running a {self.kind} function from the client is not supported. Use .run() to submit the job to the API."
         )
+
+    @property
+    def serving_spec(self):
+        return self.spec.serving_spec
mlrun/runtimes/mpijob/abstract.py CHANGED
@@ -54,6 +54,7 @@ class MPIResourceSpec(KubeResourceSpec):
         preemption_mode=None,
         security_context=None,
         state_thresholds=None,
+        serving_spec=None,
     ):
         super().__init__(
             command=command,
@@ -83,6 +84,7 @@ class MPIResourceSpec(KubeResourceSpec):
             preemption_mode=preemption_mode,
             security_context=security_context,
             state_thresholds=state_thresholds,
+            serving_spec=serving_spec,
         )
         self.mpi_args = mpi_args or [
             "-x",
mlrun/runtimes/mpijob/v1.py CHANGED
@@ -49,6 +49,7 @@ class MPIV1ResourceSpec(MPIResourceSpec):
         preemption_mode=None,
         security_context=None,
         state_thresholds=None,
+        serving_spec=None,
     ):
         super().__init__(
             command=command,
@@ -79,6 +80,7 @@ class MPIV1ResourceSpec(MPIResourceSpec):
             preemption_mode=preemption_mode,
             security_context=security_context,
             state_thresholds=state_thresholds,
+            serving_spec=serving_spec,
         )
         self.clean_pod_policy = clean_pod_policy or MPIJobV1CleanPodPolicies.default()
 
mlrun/runtimes/nuclio/function.py CHANGED
@@ -154,6 +154,7 @@ class NuclioSpec(KubeResourceSpec):
         add_templated_ingress_host_mode=None,
         state_thresholds=None,
         disable_default_http_trigger=None,
+        serving_spec=None,
     ):
         super().__init__(
             command=command,
@@ -183,6 +184,7 @@ class NuclioSpec(KubeResourceSpec):
             preemption_mode=preemption_mode,
             security_context=security_context,
             state_thresholds=state_thresholds,
+            serving_spec=serving_spec,
         )
 
         self.base_spec = base_spec or {}
mlrun/runtimes/nuclio/serving.py CHANGED
@@ -42,6 +42,8 @@ from mlrun.serving.states import (
 )
 from mlrun.utils import get_caller_globals, logger, set_paths
 
+from .. import KubejobRuntime
+from ..pod import KubeResourceSpec
 from .function import NuclioSpec, RemoteRuntime, min_nuclio_versions
 
 serving_subkind = "serving_v2"
@@ -149,6 +151,7 @@ class ServingSpec(NuclioSpec):
         state_thresholds=None,
         disable_default_http_trigger=None,
         model_endpoint_creation_task_name=None,
+        serving_spec=None,
     ):
         super().__init__(
             command=command,
@@ -189,6 +192,7 @@ class ServingSpec(NuclioSpec):
             service_type=service_type,
             add_templated_ingress_host_mode=add_templated_ingress_host_mode,
             disable_default_http_trigger=disable_default_http_trigger,
+            serving_spec=serving_spec,
         )
 
         self.models = models or {}
@@ -296,6 +300,7 @@ class ServingRuntime(RemoteRuntime):
             self.spec.graph = step
         elif topology == StepKinds.flow:
             self.spec.graph = RootFlowStep(engine=engine or "async")
+            self.spec.graph.track_models = self.spec.track_models
         else:
             raise mlrun.errors.MLRunInvalidArgumentError(
                 f"unsupported topology {topology}, use 'router' or 'flow'"
@@ -331,6 +336,8 @@ class ServingRuntime(RemoteRuntime):
         """
         # Applying model monitoring configurations
         self.spec.track_models = enable_tracking
+        if self.spec.graph and isinstance(self.spec.graph, RootFlowStep):
+            self.spec.graph.track_models = enable_tracking
         if self._spec and self._spec.function_refs:
             logger.debug(
                 "Set tracking for children references", enable_tracking=enable_tracking
@@ -343,6 +350,16 @@ class ServingRuntime(RemoteRuntime):
                     name
                 ]._function.spec.track_models = enable_tracking
 
+                if self._spec.function_refs[
+                    name
+                ]._function.spec.graph and isinstance(
+                    self._spec.function_refs[name]._function.spec.graph,
+                    RootFlowStep,
+                ):
+                    self._spec.function_refs[
+                        name
+                    ]._function.spec.graph.track_models = enable_tracking
+
         if not 0 < sampling_percentage <= 100:
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "`sampling_percentage` must be greater than 0 and less or equal to 100."
@@ -703,6 +720,7 @@ class ServingRuntime(RemoteRuntime):
             "track_models": self.spec.track_models,
             "default_content_type": self.spec.default_content_type,
             "model_endpoint_creation_task_name": self.spec.model_endpoint_creation_task_name,
+            "filename": getattr(self.spec, "filename", None),
         }
 
         if self.spec.secret_sources:
@@ -711,6 +729,10 @@ class ServingRuntime(RemoteRuntime):
 
         return json.dumps(serving_spec)
 
+    @property
+    def serving_spec(self):
+        return self._get_serving_spec()
+
     def to_mock_server(
         self,
         namespace=None,
@@ -815,3 +837,40 @@ class ServingRuntime(RemoteRuntime):
             "Turn off the mock (mock=False) and make sure Nuclio is installed for real deployment to Nuclio"
         )
         self._mock_server = self.to_mock_server()
+
+    def to_job(self) -> KubejobRuntime:
+        """Convert this ServingRuntime to a KubejobRuntime, so that the graph can be run as a standalone job."""
+        if self.spec.function_refs:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                f"Cannot convert function '{self.metadata.name}' to a job because it has child functions"
+            )
+
+        spec = KubeResourceSpec(
+            image=self.spec.image,
+            mode=self.spec.mode,
+            volumes=self.spec.volumes,
+            volume_mounts=self.spec.volume_mounts,
+            env=self.spec.env,
+            resources=self.spec.resources,
+            default_handler="mlrun.serving.server.execute_graph",
+            pythonpath=self.spec.pythonpath,
+            entry_points=self.spec.entry_points,
+            description=self.spec.description,
+            workdir=self.spec.workdir,
+            image_pull_secret=self.spec.image_pull_secret,
+            node_name=self.spec.node_name,
+            node_selector=self.spec.node_selector,
+            affinity=self.spec.affinity,
+            disable_auto_mount=self.spec.disable_auto_mount,
+            priority_class_name=self.spec.priority_class_name,
+            tolerations=self.spec.tolerations,
+            preemption_mode=self.spec.preemption_mode,
+            security_context=self.spec.security_context,
+            state_thresholds=self.spec.state_thresholds,
+            serving_spec=self._get_serving_spec(),
+        )
+        job = KubejobRuntime(
+            spec=spec,
+            metadata=self.metadata,
+        )
+        return job
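
A sketch of the new to_job() flow under assumed names ("my-graph", graph.py, the "process" handler): the serving function must have no child functions, and the resulting KubejobRuntime runs the graph through the mlrun.serving.server.execute_graph handler set above:

import mlrun

fn = mlrun.code_to_function("my-graph", kind="serving", filename="graph.py")
graph = fn.set_topology("flow")
graph.to(name="step1", handler="process")

job = fn.to_job()
run = job.run(
    inputs={"data": "s3://bucket/input.parquet"},  # placeholder input dataset
    params={"batching": True},
)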
mlrun/runtimes/pod.py CHANGED
@@ -103,6 +103,7 @@ class KubeResourceSpec(FunctionSpec):
         "preemption_mode",
         "security_context",
         "state_thresholds",
+        "serving_spec",
     ]
     _default_fields_to_strip = FunctionSpec._default_fields_to_strip + [
         "volumes",
@@ -178,6 +179,7 @@ class KubeResourceSpec(FunctionSpec):
         preemption_mode=None,
         security_context=None,
         state_thresholds=None,
+        serving_spec=None,
     ):
         super().__init__(
             command=command,
@@ -223,6 +225,7 @@ class KubeResourceSpec(FunctionSpec):
             state_thresholds
             or mlrun.mlconf.function.spec.state_thresholds.default.to_dict()
         )
+        self.serving_spec = serving_spec
         # Termination grace period is internal for runtimes that have a pod termination hook hence it is not in the
         # _dict_fields and doesn't have a setter.
         self._termination_grace_period_seconds = None
mlrun/runtimes/remotesparkjob.py CHANGED
@@ -58,6 +58,7 @@ class RemoteSparkSpec(KubeResourceSpec):
         preemption_mode=None,
         security_context=None,
         state_thresholds=None,
+        serving_spec=None,
     ):
         super().__init__(
             command=command,
@@ -87,6 +88,7 @@ class RemoteSparkSpec(KubeResourceSpec):
             preemption_mode=preemption_mode,
             security_context=security_context,
             state_thresholds=state_thresholds,
+            serving_spec=serving_spec,
         )
         self.provider = provider
 
mlrun/runtimes/sparkjob/spark3job.py CHANGED
@@ -168,6 +168,7 @@ class Spark3JobSpec(KubeResourceSpec):
         executor_cores=None,
         security_context=None,
         state_thresholds=None,
+        serving_spec=None,
     ):
         super().__init__(
             command=command,
@@ -197,6 +198,7 @@ class Spark3JobSpec(KubeResourceSpec):
             preemption_mode=preemption_mode,
             security_context=security_context,
             state_thresholds=state_thresholds,
+            serving_spec=serving_spec,
         )
 
         self.driver_resources = driver_resources or {}
mlrun/serving/routers.py CHANGED
@@ -80,10 +80,16 @@ class BaseModelRouter(RouterToDict):
         self._input_path = input_path
         self._result_path = result_path
         self._background_task_check_timestamp = None
-        self._background_task_terminate = False
         self._background_task_current_state = None
         self.kwargs = kwargs
 
+    @property
+    def background_task_reached_terminal_state(self):
+        return (
+            self._background_task_current_state
+            and self._background_task_current_state != "running"
+        )
+
     def parse_event(self, event):
         parsed_event = {}
         try:
@@ -185,35 +191,33 @@ class BaseModelRouter(RouterToDict):
                     background_task.status.state
                     in mlrun.common.schemas.BackgroundTaskState.terminal_states()
                 ):
-                    logger.debug(
+                    logger.info(
                         f"Model endpoint creation task completed with state {background_task.status.state}"
                     )
-                    self._background_task_terminate = True
                 else:  # in progress
-                    logger.debug(
+                    logger.info(
                         f"Model endpoint creation task is still in progress with the current state: "
-                        f"{background_task.status.state}. Events will not be monitored for the next 15 seconds",
+                        f"{background_task.status.state}. Events will not be monitored for the next "
+                        f"{mlrun.mlconf.model_endpoint_monitoring.model_endpoint_creation_check_period} seconds",
                         name=self.name,
                         background_task_check_timestamp=self._background_task_check_timestamp.isoformat(),
                     )
                 return background_task.status.state
             else:
-                logger.debug(
-                    "Model endpoint creation task name not provided",
+                logger.error(
+                    "Model endpoint creation task name not provided. This function is not being monitored.",
                 )
         elif self.context.monitoring_mock:
-            self._background_task_terminate = (
-                True  # If mock monitoring we return success and terminate task check.
-            )
             return mlrun.common.schemas.BackgroundTaskState.succeeded
-        self._background_task_terminate = True  # If mock without monitoring we return failed and terminate task check.
         return mlrun.common.schemas.BackgroundTaskState.failed
 
     def _update_background_task_state(self, event):
-        if not self._background_task_terminate and (
+        if not self.background_task_reached_terminal_state and (
             self._background_task_check_timestamp is None
             or now_date() - self._background_task_check_timestamp
-            >= timedelta(seconds=15)
+            >= timedelta(
+                seconds=mlrun.mlconf.model_endpoint_monitoring.model_endpoint_creation_check_period
+            )
         ):
            self._background_task_current_state = self._get_background_task_status()
        if event.body:
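
The previously hard-coded 15-second wait between background-task checks is now driven by configuration. A sketch of overriding it, using the config key referenced in the diff above:

import mlrun

mlrun.mlconf.model_endpoint_monitoring.model_endpoint_creation_check_period = 30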
mlrun/serving/server.py CHANGED
@@ -21,8 +21,9 @@ import os
 import socket
 import traceback
 import uuid
-from typing import Optional, Union
+from typing import Any, Optional, Union
 
+import storey
 from nuclio import Context as NuclioContext
 from nuclio.request import Logger as NuclioLogger
 
@@ -38,9 +39,10 @@ from mlrun.secrets import SecretsStore
 
 from ..common.helpers import parse_versioned_object_uri
 from ..common.schemas.model_monitoring.constants import FileTargetKind
-from ..datastore import get_stream_pusher
+from ..datastore import DataItem, get_stream_pusher
 from ..datastore.store_resources import ResourceCache
 from ..errors import MLRunInvalidArgumentError
+from ..execution import MLClientCtx
 from ..model import ModelObj
 from ..utils import get_caller_globals
 from .states import (
@@ -322,7 +324,11 @@ class GraphServer(ModelObj):
 
     def _process_response(self, context, response, get_body):
         body = response.body
-        if isinstance(body, context.Response) or get_body:
+        if (
+            isinstance(context, MLClientCtx)
+            or isinstance(body, context.Response)
+            or get_body
+        ):
             return body
 
         if body and not isinstance(body, (str, bytes)):
@@ -535,6 +541,94 @@ def v2_serving_init(context, namespace=None):
     _set_callbacks(server, context)
 
 
+async def async_execute_graph(
+    context: MLClientCtx,
+    data: DataItem,
+    batching: bool,
+    batch_size: Optional[int],
+) -> list[Any]:
+    spec = mlrun.utils.get_serving_spec()
+
+    source_filename = spec.get("filename", None)
+    namespace = {}
+    if source_filename:
+        with open(source_filename) as f:
+            exec(f.read(), namespace)
+
+    server = GraphServer.from_dict(spec)
+
+    if config.log_level.lower() == "debug":
+        server.verbose = True
+    context.logger.info_with("Initializing states", namespace=namespace)
+    kwargs = {}
+    if hasattr(context, "is_mock"):
+        kwargs["is_mock"] = context.is_mock
+    server.init_states(
+        context=None,  # this context is expected to be a nuclio context, which we don't have in this flow
+        namespace=namespace,
+        **kwargs,
+    )
+    context.logger.info("Initializing graph steps")
+    server.init_object(namespace)
+
+    context.logger.info_with("Graph was initialized", verbose=server.verbose)
+
+    if server.verbose:
+        context.logger.info(server.to_yaml())
+
+    df = data.as_df()
+
+    responses = []
+
+    async def run(body):
+        event = storey.Event(id=index, body=body)
+        response = await server.run(event, context)
+        responses.append(response)
+
+    if batching and not batch_size:
+        batch_size = len(df)
+
+    batch = []
+    for index, row in df.iterrows():
+        data = row.to_dict()
+        if batching:
+            batch.append(data)
+            if len(batch) == batch_size:
+                await run(batch)
+                batch = []
+        else:
+            await run(data)
+
+    if batch:
+        await run(batch)
+
+    termination_result = server.wait_for_completion()
+    if asyncio.iscoroutine(termination_result):
+        await termination_result
+
+    return responses
+
+
+def execute_graph(
+    context: MLClientCtx,
+    data: DataItem,
+    batching: bool = False,
+    batch_size: Optional[int] = None,
+) -> list[Any]:
+    """
+    Execute a graph as a job, from start to finish.
+
+    :param context:    The job's execution client context.
+    :param data:       The input data to the job, to be pushed into the graph row by row, or in batches.
+    :param batching:   Whether to push one or more batches into the graph rather than row by row.
+    :param batch_size: The number of rows to push per batch. If not set and batching=True, the entire
+                       dataset will be pushed into the graph in one batch.
+
+    :return: A list of responses.
+    """
+    return asyncio.run(async_execute_graph(context, data, batching, batch_size))
+
+
 def _set_callbacks(server, context):
     if not server.graph.supports_termination() or not hasattr(context, "platform"):
         return
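
A hedged end-to-end sketch of the new job execution path (function name and input path are placeholders): execute_graph is the default handler set by ServingRuntime.to_job(), so the data input arrives as a DataItem and is pushed through the graph row by row or in batches:

import mlrun

project = mlrun.get_or_create_project("my-project")
serving_fn = project.get_function("my-graph")
job = serving_fn.to_job()
run = job.run(
    inputs={"data": "input.parquet"},              # becomes the DataItem `data`
    params={"batching": True, "batch_size": 100},  # forwarded to execute_graph
)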