mlrun 1.10.0rc16__py3-none-any.whl → 1.10.0rc42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (98)
  1. mlrun/__init__.py +22 -2
  2. mlrun/artifacts/document.py +6 -1
  3. mlrun/artifacts/llm_prompt.py +21 -15
  4. mlrun/artifacts/model.py +3 -3
  5. mlrun/common/constants.py +9 -0
  6. mlrun/common/formatters/artifact.py +1 -0
  7. mlrun/common/model_monitoring/helpers.py +86 -0
  8. mlrun/common/schemas/__init__.py +2 -0
  9. mlrun/common/schemas/auth.py +2 -0
  10. mlrun/common/schemas/function.py +10 -0
  11. mlrun/common/schemas/hub.py +30 -18
  12. mlrun/common/schemas/model_monitoring/__init__.py +2 -0
  13. mlrun/common/schemas/model_monitoring/constants.py +30 -6
  14. mlrun/common/schemas/model_monitoring/functions.py +13 -4
  15. mlrun/common/schemas/model_monitoring/model_endpoints.py +11 -0
  16. mlrun/common/schemas/pipeline.py +1 -1
  17. mlrun/common/schemas/serving.py +3 -0
  18. mlrun/common/schemas/workflow.py +1 -0
  19. mlrun/common/secrets.py +22 -1
  20. mlrun/config.py +32 -10
  21. mlrun/datastore/__init__.py +11 -3
  22. mlrun/datastore/azure_blob.py +162 -47
  23. mlrun/datastore/datastore.py +9 -4
  24. mlrun/datastore/datastore_profile.py +61 -5
  25. mlrun/datastore/model_provider/huggingface_provider.py +363 -0
  26. mlrun/datastore/model_provider/mock_model_provider.py +87 -0
  27. mlrun/datastore/model_provider/model_provider.py +211 -74
  28. mlrun/datastore/model_provider/openai_provider.py +243 -71
  29. mlrun/datastore/s3.py +24 -2
  30. mlrun/datastore/storeytargets.py +2 -3
  31. mlrun/datastore/utils.py +15 -3
  32. mlrun/db/base.py +27 -19
  33. mlrun/db/httpdb.py +57 -48
  34. mlrun/db/nopdb.py +25 -10
  35. mlrun/execution.py +55 -13
  36. mlrun/hub/__init__.py +15 -0
  37. mlrun/hub/module.py +181 -0
  38. mlrun/k8s_utils.py +105 -16
  39. mlrun/launcher/base.py +13 -6
  40. mlrun/launcher/local.py +2 -0
  41. mlrun/model.py +9 -3
  42. mlrun/model_monitoring/api.py +66 -27
  43. mlrun/model_monitoring/applications/__init__.py +1 -1
  44. mlrun/model_monitoring/applications/base.py +372 -136
  45. mlrun/model_monitoring/applications/context.py +2 -4
  46. mlrun/model_monitoring/applications/results.py +4 -7
  47. mlrun/model_monitoring/controller.py +239 -101
  48. mlrun/model_monitoring/db/_schedules.py +36 -13
  49. mlrun/model_monitoring/db/_stats.py +4 -3
  50. mlrun/model_monitoring/db/tsdb/base.py +29 -9
  51. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +4 -5
  52. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +154 -50
  53. mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +51 -0
  54. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +17 -4
  55. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +245 -51
  56. mlrun/model_monitoring/helpers.py +28 -5
  57. mlrun/model_monitoring/stream_processing.py +45 -14
  58. mlrun/model_monitoring/writer.py +220 -1
  59. mlrun/platforms/__init__.py +3 -2
  60. mlrun/platforms/iguazio.py +7 -3
  61. mlrun/projects/operations.py +6 -1
  62. mlrun/projects/pipelines.py +2 -2
  63. mlrun/projects/project.py +128 -45
  64. mlrun/run.py +94 -17
  65. mlrun/runtimes/__init__.py +18 -0
  66. mlrun/runtimes/base.py +14 -6
  67. mlrun/runtimes/daskjob.py +1 -0
  68. mlrun/runtimes/local.py +5 -2
  69. mlrun/runtimes/mounts.py +20 -2
  70. mlrun/runtimes/nuclio/__init__.py +1 -0
  71. mlrun/runtimes/nuclio/application/application.py +147 -17
  72. mlrun/runtimes/nuclio/function.py +70 -27
  73. mlrun/runtimes/nuclio/serving.py +85 -4
  74. mlrun/runtimes/pod.py +213 -21
  75. mlrun/runtimes/utils.py +49 -9
  76. mlrun/secrets.py +54 -13
  77. mlrun/serving/remote.py +79 -6
  78. mlrun/serving/routers.py +23 -41
  79. mlrun/serving/server.py +211 -40
  80. mlrun/serving/states.py +536 -156
  81. mlrun/serving/steps.py +62 -0
  82. mlrun/serving/system_steps.py +136 -81
  83. mlrun/serving/v2_serving.py +9 -10
  84. mlrun/utils/helpers.py +212 -82
  85. mlrun/utils/logger.py +3 -1
  86. mlrun/utils/notifications/notification/base.py +18 -0
  87. mlrun/utils/notifications/notification/git.py +2 -4
  88. mlrun/utils/notifications/notification/slack.py +2 -4
  89. mlrun/utils/notifications/notification/webhook.py +2 -5
  90. mlrun/utils/notifications/notification_pusher.py +1 -1
  91. mlrun/utils/version/version.json +2 -2
  92. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/METADATA +44 -45
  93. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/RECORD +97 -92
  94. mlrun/api/schemas/__init__.py +0 -259
  95. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/WHEEL +0 -0
  96. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/entry_points.txt +0 -0
  97. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/licenses/LICENSE +0 -0
  98. {mlrun-1.10.0rc16.dist-info → mlrun-1.10.0rc42.dist-info}/top_level.txt +0 -0
mlrun/serving/remote.py CHANGED
@@ -23,10 +23,14 @@ import storey
 from storey.flow import _ConcurrentJobExecution
 
 import mlrun
+import mlrun.common.schemas
 import mlrun.config
+import mlrun.platforms
+import mlrun.utils.async_http
 from mlrun.errors import err_to_str
-from mlrun.utils import logger
+from mlrun.utils import dict_to_json, logger
 
+from ..config import config
 from .utils import (
     _extract_input_data,
     _update_result_body,
@@ -73,7 +77,9 @@ class RemoteStep(storey.SendToHttp):
 
     :param url: http(s) url or function [project/]name to call
     :param subpath: path (which follows the url), use `$path` to use the event.path
-    :param method: HTTP method (GET, POST, ..), default to POST
+    :param method: The HTTP method to use for the request (e.g., "GET", "POST", "PUT", "DELETE").
+                   If not provided, the step will try to use `event.method` at runtime, and if that
+                   is also missing, it defaults to `"POST"`.
     :param headers: dictionary with http header values
     :param url_expression: an expression for getting the url from the event, e.g. "event['url']"
     :param body_expression: an expression for getting the request body from the event, e.g. "event['data']"
@@ -150,8 +156,8 @@ class RemoteStep(storey.SendToHttp):
     async def _process_event(self, event):
         # async implementation (with storey)
         body = self._get_event_or_body(event)
-        method, url, headers, body = self._generate_request(event, body)
-        kwargs = {}
+        method, url, headers, body, kwargs = self._generate_request(event, body)
+        kwargs = kwargs or {}
         if self.timeout:
             kwargs["timeout"] = aiohttp.ClientTimeout(total=self.timeout)
         try:
@@ -191,7 +197,7 @@ class RemoteStep(storey.SendToHttp):
             )
 
         body = _extract_input_data(self._input_path, event.body)
-        method, url, headers, body = self._generate_request(event, body)
+        method, url, headers, body, kwargs = self._generate_request(event, body)
         try:
             resp = self._session.request(
                 method,
@@ -200,6 +206,7 @@ class RemoteStep(storey.SendToHttp):
                 headers=headers,
                 data=body,
                 timeout=self.timeout,
+                **kwargs,
             )
         except requests.exceptions.ReadTimeout as err:
             raise requests.exceptions.ReadTimeout(
@@ -240,7 +247,7 @@ class RemoteStep(storey.SendToHttp):
             body = json.dumps(body)
             headers["Content-Type"] = "application/json"
 
-        return method, url, headers, body
+        return method, url, headers, body, {}
 
     def _get_data(self, data, headers):
         if (
@@ -454,3 +461,69 @@ class BatchHttpRequests(_ConcurrentJobExecution):
         ) and isinstance(data, (str, bytes)):
             data = json.loads(data)
         return data
+
+
+class MLRunAPIRemoteStep(RemoteStep):
+    def __init__(
+        self, method: str, path: str, fill_placeholders: Optional[bool] = None, **kwargs
+    ):
+        """
+        Graph step implementation for calling MLRun API endpoints
+
+        :param method: The HTTP method to use for the request (e.g., "GET", "POST", "PUT", "DELETE").
+                       If not provided, the step will try to use `event.method` at runtime, and if that
+                       is also missing, it defaults to `"POST"`.
+        :param path: API path (e.g. /api/projects)
+        :param fill_placeholders: if True, fill placeholders in the path using event fields (default to False)
+        :param kwargs: other arguments passed to RemoteStep
+        """
+        super().__init__(url="", method=method, **kwargs)
+        self.rundb = None
+        self.path = path
+        self.fill_placeholders = fill_placeholders
+
+    def _generate_request(self, event, body):
+        method = self.method or event.method or "POST"
+        kw = {
+            key: value
+            for key, value in (
+                ("params", body.get("params")),
+                ("json", body.get("json")),
+            )
+            if value is not None
+        }
+
+        headers = self.headers or {}
+        headers.update(body.get("headers", {}))
+
+        if self.rundb.user:
+            kw["auth"] = (self.rundb.user, self.rundb.password)
+        elif self.rundb.token_provider:
+            token = self.rundb.token_provider.get_token()
+            if token:
+                # Iguazio auth doesn't support passing token through bearer, so use cookie instead
+                if self.rundb.token_provider.is_iguazio_session():
+                    session_cookie = f'session=j:{{"sid": "{token}"}}'
+                    headers["cookie"] = session_cookie
+                else:
+                    if "Authorization" not in kw.setdefault("headers", {}):
+                        headers.update({"Authorization": "Bearer " + token})
+
+        if mlrun.common.schemas.HeaderNames.client_version not in headers:
+            headers.update(
+                {
+                    mlrun.common.schemas.HeaderNames.client_version: self.rundb.client_version,
+                    mlrun.common.schemas.HeaderNames.python_version: self.rundb.python_version,
+                    "User-Agent": f"{requests.utils.default_user_agent()} mlrun/{config.version}",
+                }
+            )
+
+        url = self.url.format(**body) if self.fill_placeholders else self.url
+        headers["Content-Type"] = "application/json"
+        return method, url, headers, dict_to_json(body), kw
+
+    def post_init(self, mode="sync", **kwargs):
+        super().post_init(mode=mode, **kwargs)
+        self.fill_placeholders = self.fill_placeholders or False
+        self.rundb = mlrun.get_run_db()
+        self.url = self.rundb.get_base_api_url(self.path)
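
Note: based on the hunk above, the new MLRunAPIRemoteStep can be wired into a serving graph like any other class step. The sketch below is a minimal, assumed usage: the graph.to(...) wiring is the standard mlrun serving API, while the "/projects/{project}/runs" path and the event-body keys ("params"/"json"/"headers", plus fields used to fill path placeholders) are inferred from _generate_request above, not from documentation.

import mlrun

# minimal sketch (assumed usage, inferred from the diff above)
function = mlrun.code_to_function("api-caller", kind="serving", image="mlrun/mlrun")
graph = function.set_topology("flow", engine="async")

# `path` placeholders such as {project} are filled from event-body fields
# when fill_placeholders=True (hypothetical path, for illustration only)
graph.to(
    "mlrun.serving.remote.MLRunAPIRemoteStep",
    "list_runs",
    method="GET",
    path="/projects/{project}/runs",
    fill_placeholders=True,
).respond()

# event bodies are dicts; optional "params"/"json"/"headers" keys are forwarded
# to the HTTP request the step issues against the MLRun API, e.g.:
# server.test(body={"project": "my-project", "params": {"iter": 0}})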
mlrun/serving/routers.py CHANGED
@@ -31,6 +31,9 @@ import mlrun.common.model_monitoring
 import mlrun.common.schemas.model_monitoring
 from mlrun.utils import logger, now_date
 
+from ..common.model_monitoring.helpers import (
+    get_model_endpoints_creation_task_status,
+)
 from .utils import RouterToDict, _extract_input_data, _update_result_body
 from .v2_serving import _ModelLogPusher
 
@@ -171,46 +174,6 @@ class BaseModelRouter(RouterToDict):
         """run tasks after processing the event"""
         return event
 
-    def _get_background_task_status(
-        self,
-    ) -> mlrun.common.schemas.BackgroundTaskState:
-        self._background_task_check_timestamp = now_date()
-        server: mlrun.serving.GraphServer = getattr(
-            self.context, "_server", None
-        ) or getattr(self.context, "server", None)
-        if not self.context.is_mock:
-            if server.model_endpoint_creation_task_name:
-                background_task = mlrun.get_run_db().get_project_background_task(
-                    server.project, server.model_endpoint_creation_task_name
-                )
-                logger.debug(
-                    "Checking model endpoint creation task status",
-                    task_name=server.model_endpoint_creation_task_name,
-                )
-                if (
-                    background_task.status.state
-                    in mlrun.common.schemas.BackgroundTaskState.terminal_states()
-                ):
-                    logger.info(
-                        f"Model endpoint creation task completed with state {background_task.status.state}"
-                    )
-                else:  # in progress
-                    logger.info(
-                        f"Model endpoint creation task is still in progress with the current state: "
-                        f"{background_task.status.state}. Events will not be monitored for the next "
-                        f"{mlrun.mlconf.model_endpoint_monitoring.model_endpoint_creation_check_period} seconds",
-                        name=self.name,
-                        background_task_check_timestamp=self._background_task_check_timestamp.isoformat(),
-                    )
-                return background_task.status.state
-            else:
-                logger.error(
-                    "Model endpoint creation task name not provided. This function is not being monitored.",
-                )
-        elif self.context.monitoring_mock:
-            return mlrun.common.schemas.BackgroundTaskState.succeeded
-        return mlrun.common.schemas.BackgroundTaskState.failed
-
     def _update_background_task_state(self, event):
         if not self.background_task_reached_terminal_state and (
             self._background_task_check_timestamp is None
@@ -219,7 +182,26 @@ class BaseModelRouter(RouterToDict):
                 seconds=mlrun.mlconf.model_endpoint_monitoring.model_endpoint_creation_check_period
             )
         ):
-            self._background_task_current_state = self._get_background_task_status()
+            server: mlrun.serving.GraphServer = getattr(
+                self.context, "_server", None
+            ) or getattr(self.context, "server", None)
+            if not self.context.is_mock:
+                (
+                    self._background_task_current_state,
+                    self._background_task_check_timestamp,
+                    _,
+                ) = get_model_endpoints_creation_task_status(server)
+            elif self.context.monitoring_mock:
+                self._background_task_current_state = (
+                    mlrun.common.schemas.BackgroundTaskState.succeeded
+                )
+                self._background_task_check_timestamp = mlrun.utils.now_date()
+            else:
+                self._background_task_current_state = (
+                    mlrun.common.schemas.BackgroundTaskState.failed
+                )
+                self._background_task_check_timestamp = mlrun.utils.now_date()
+
         if event.body:
             event.body["background_task_state"] = (
                 self._background_task_current_state
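
Note: the router no longer computes the background-task status itself; it unpacks a 3-tuple from the new shared helper. The helper's body lives in mlrun/common/model_monitoring/helpers.py (+86 lines in this diff) and is not shown here, so the following is only a sketch of its assumed contract, inferred from how _update_background_task_state consumes it:

import mlrun
import mlrun.common.schemas
from mlrun.utils import now_date

# assumed contract only -- the real implementation is in
# mlrun/common/model_monitoring/helpers.py and may differ
def get_model_endpoints_creation_task_status(server):
    state = mlrun.common.schemas.BackgroundTaskState.failed
    if server.model_endpoint_creation_task_name:
        task = mlrun.get_run_db().get_project_background_task(
            server.project, server.model_endpoint_creation_task_name
        )
        state = task.status.state
    # (current state, check timestamp, third value the router discards)
    return state, now_date(), None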
mlrun/serving/server.py CHANGED
@@ -17,21 +17,26 @@ __all__ = ["GraphServer", "create_graph_server", "GraphContext", "MockEvent"]
 import asyncio
 import base64
 import copy
+import importlib
 import json
 import os
 import socket
 import traceback
 import uuid
+from collections import defaultdict
+from datetime import datetime, timezone
 from typing import Any, Optional, Union
 
+import pandas as pd
 import storey
 from nuclio import Context as NuclioContext
 from nuclio.request import Logger as NuclioLogger
 
 import mlrun
-import mlrun.common.constants
 import mlrun.common.helpers
 import mlrun.common.schemas
+import mlrun.common.schemas.model_monitoring.constants as mm_constants
+import mlrun.datastore.datastore_profile as ds_profile
 import mlrun.model_monitoring
 import mlrun.utils
 from mlrun.config import config
@@ -40,12 +45,13 @@ from mlrun.secrets import SecretsStore
 
 from ..common.helpers import parse_versioned_object_uri
 from ..common.schemas.model_monitoring.constants import FileTargetKind
+from ..common.schemas.serving import MAX_BATCH_JOB_DURATION
 from ..datastore import DataItem, get_stream_pusher
 from ..datastore.store_resources import ResourceCache
 from ..errors import MLRunInvalidArgumentError
 from ..execution import MLClientCtx
 from ..model import ModelObj
-from ..utils import get_caller_globals
+from ..utils import get_caller_globals, get_relative_module_name_from_path
 from .states import (
     FlowStep,
     MonitoredStep,
@@ -77,7 +83,6 @@ class _StreamContext:
         self.hostname = socket.gethostname()
         self.function_uri = function_uri
         self.output_stream = None
-        stream_uri = None
         log_stream = parameters.get(FileTargetKind.LOG_STREAM, "")
 
         if (enabled or log_stream) and function_uri:
@@ -88,20 +93,16 @@ class _StreamContext:
 
             stream_args = parameters.get("stream_args", {})
 
-            if log_stream == DUMMY_STREAM:
-                # Dummy stream used for testing, see tests/serving/test_serving.py
-                stream_uri = DUMMY_STREAM
-            elif not stream_args.get("mock"):  # if not a mock: `context.is_mock = True`
-                stream_uri = mlrun.model_monitoring.get_stream_path(project=project)
-
             if log_stream:
-                # Update the stream path to the log stream value
-                stream_uri = log_stream.format(project=project)
-                self.output_stream = get_stream_pusher(stream_uri, **stream_args)
+                # Get the output stream from the log stream path
+                stream_path = log_stream.format(project=project)
+                self.output_stream = get_stream_pusher(stream_path, **stream_args)
             else:
                 # Get the output stream from the profile
                 self.output_stream = mlrun.model_monitoring.helpers.get_output_stream(
-                    project=project, mock=stream_args.get("mock", False)
+                    project=project,
+                    profile=parameters.get("stream_profile"),
+                    mock=stream_args.get("mock", False),
                 )
 
 
@@ -179,11 +180,12 @@ class GraphServer(ModelObj):
         self,
         context,
         namespace,
-        resource_cache: ResourceCache = None,
+        resource_cache: Optional[ResourceCache] = None,
         logger=None,
         is_mock=False,
         monitoring_mock=False,
-    ):
+        stream_profile: Optional[ds_profile.DatastoreProfile] = None,
+    ) -> None:
         """for internal use, initialize all steps (recursively)"""
 
         if self.secret_sources:
@@ -198,6 +200,20 @@ class GraphServer(ModelObj):
         context.monitoring_mock = monitoring_mock
         context.root = self.graph
 
+        if is_mock and monitoring_mock:
+            if stream_profile:
+                # Add the user-defined stream profile to the parameters
+                self.parameters["stream_profile"] = stream_profile
+            elif not (
+                self.parameters.get(FileTargetKind.LOG_STREAM)
+                or mlrun.get_secret_or_env(
+                    mm_constants.ProjectSecretKeys.STREAM_PROFILE_NAME
+                )
+            ):
+                # Set a dummy log stream for mocking purposes if there is no direct
+                # user-defined stream profile and no information in the environment
+                self.parameters[FileTargetKind.LOG_STREAM] = DUMMY_STREAM
+
         context.stream = _StreamContext(
             self.track_models, self.parameters, self.function_uri
         )
@@ -358,6 +374,7 @@ def add_error_raiser_step(
         raise_exception=monitored_step.raise_exception,
         models_names=list(monitored_step.class_args["models"].keys()),
         model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
+        function=monitored_step.function,
     )
     if monitored_step.responder:
         monitored_step.responder = False
@@ -400,6 +417,7 @@ def add_monitoring_general_steps(
         "mlrun.serving.system_steps.BackgroundTaskStatus",
         "background_task_status_step",
         model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
+        full_event=True,
     )
     monitor_flow_step = graph.add_step(
         "storey.Filter",
@@ -505,10 +523,6 @@ def add_system_steps_to_graph(
         monitor_flow_step.after = [
             step_name,
         ]
-    context.logger.info_with(
-        "Server graph after adding system steps",
-        graph=str(graph.steps),
-    )
     return graph
 
 
@@ -561,25 +575,51 @@ def v2_serving_init(context, namespace=None):
 async def async_execute_graph(
     context: MLClientCtx,
     data: DataItem,
+    timestamp_column: Optional[str],
     batching: bool,
     batch_size: Optional[int],
     read_as_lists: bool,
     nest_under_inputs: bool,
-) -> list[Any]:
+) -> None:
+    # Validate that data parameter is a DataItem and not passed via params
+    if not isinstance(data, DataItem):
+        raise MLRunInvalidArgumentError(
+            f"Parameter 'data' has type hint 'DataItem' but got {type(data).__name__} instead. "
+            f"Data files and artifacts must be passed via the 'inputs' parameter, not 'params'. "
+            f"The 'params' parameter is for simple configuration values (strings, numbers, booleans), "
+            f"while 'inputs' is for data files that need to be loaded. "
+            f"Example: run_function(..., inputs={{'data': 'path/to/data.csv'}}, params={{other_config: value}})"
+        )
+    run_call_count = 0
     spec = mlrun.utils.get_serving_spec()
-
-    namespace = {}
+    modname = None
     code = os.getenv("MLRUN_EXEC_CODE")
     if code:
         code = base64.b64decode(code).decode("utf-8")
-        exec(code, namespace)
+        with open("user_code.py", "w") as fp:
+            fp.write(code)
+        modname = "user_code"
     else:
         # TODO: find another way to get the local file path, or ensure that MLRUN_EXEC_CODE
         # gets set in local flow and not just in the remote pod
-        source_filename = spec.get("filename", None)
-        if source_filename:
-            with open(source_filename) as f:
-                exec(f.read(), namespace)
+        source_file_path = spec.get("filename", None)
+        if source_file_path:
+            source_file_path_object, working_dir_path_object = (
+                mlrun.utils.helpers.get_source_and_working_dir_paths(source_file_path)
+            )
+            if not source_file_path_object.is_relative_to(working_dir_path_object):
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Source file path '{source_file_path}' is not under the current working directory "
+                    f"(which is required when running with local=True)"
+                )
+            modname = get_relative_module_name_from_path(
                source_file_path_object, working_dir_path_object
            )
+
+    namespace = {}
+    if modname:
+        mod = importlib.import_module(modname)
+        namespace = mod.__dict__
 
     server = GraphServer.from_dict(spec)
@@ -605,10 +645,43 @@ async def async_execute_graph(
             f"(status='{task_state}')"
         )
 
+    df = data.as_df()
+
+    if df.empty:
+        context.logger.warn("Job terminated due to empty inputs (0 rows)")
+        return []
+
+    track_models = spec.get("track_models")
+
+    if track_models and timestamp_column:
+        context.logger.info(f"Sorting dataframe by {timestamp_column}")
+        df[timestamp_column] = pd.to_datetime(  # in case it's a string
+            df[timestamp_column]
+        )
+        df.sort_values(by=timestamp_column, inplace=True)
+        if len(df) > 1:
+            start_time = df[timestamp_column].iloc[0]
+            end_time = df[timestamp_column].iloc[-1]
+            time_range = end_time - start_time
+            start_time = start_time.isoformat()
+            end_time = end_time.isoformat()
+            # TODO: tie this to the controller's base period
+            if time_range > pd.Timedelta(MAX_BATCH_JOB_DURATION):
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Dataframe time range is too long: {time_range}. "
+                    "Please disable tracking or reduce the input dataset's time range below the defined limit "
+                    f"of {MAX_BATCH_JOB_DURATION}."
+                )
+        else:
+            start_time = end_time = df["timestamp"].iloc[0].isoformat()
+    else:
+        # end time will be set from clock time when the batch completes
+        start_time = datetime.now(tz=timezone.utc).isoformat()
+
     server.graph = add_system_steps_to_graph(
         server.project,
         copy.deepcopy(server.graph),
-        spec.get("track_models"),
+        track_models,
         context,
         spec,
         pause_until_background_task_completion=False,  # we've already awaited it
@@ -616,7 +689,6 @@ async def async_execute_graph(
 
     if config.log_level.lower() == "debug":
         server.verbose = True
-    context.logger.info_with("Initializing states", namespace=namespace)
     kwargs = {}
    if hasattr(context, "is_mock"):
        kwargs["is_mock"] = context.is_mock
@@ -633,19 +705,30 @@ async def async_execute_graph(
     if server.verbose:
         context.logger.info(server.to_yaml())
 
-    df = data.as_df()
-
-    responses = []
-
     async def run(body):
+        nonlocal run_call_count
         event = storey.Event(id=index, body=body)
-        response = await server.run(event, context)
-        responses.append(response)
+        if timestamp_column:
+            if batching:
+                # we use the first row in the batch to determine the timestamp for the whole batch
+                body = body[0]
+            if not isinstance(body, dict):
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"When timestamp_column=True, event body must be a dict – got {type(body).__name__} instead"
+                )
+            if timestamp_column not in body:
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Event body '{body}' did not contain timestamp column '{timestamp_column}'"
+                )
+            event._original_timestamp = body[timestamp_column]
+        run_call_count += 1
+        return await server.run(event, context)
 
     if batching and not batch_size:
         batch_size = len(df)
 
     batch = []
+    tasks = []
     for index, row in df.iterrows():
         data = row.to_list() if read_as_lists else row.to_dict()
         if nest_under_inputs:
@@ -653,24 +736,100 @@ async def async_execute_graph(
         if batching:
             batch.append(data)
             if len(batch) == batch_size:
-                await run(batch)
+                tasks.append(asyncio.create_task(run(batch)))
                 batch = []
         else:
-            await run(data)
+            tasks.append(asyncio.create_task(run(data)))
 
     if batch:
-        await run(batch)
+        tasks.append(asyncio.create_task(run(batch)))
+
+    responses = await asyncio.gather(*tasks)
 
     termination_result = server.wait_for_completion()
     if asyncio.iscoroutine(termination_result):
         await termination_result
 
-    return responses
+    model_endpoint_uids = spec.get("model_endpoint_uids", [])
+
+    # needed for output_stream to be created
+    server = GraphServer.from_dict(spec)
+    server.init_states(None, namespace)
+
+    batch_completion_time = datetime.now(tz=timezone.utc).isoformat()
+
+    if not timestamp_column:
+        end_time = batch_completion_time
+
+    mm_stream_record = dict(
+        kind="batch_complete",
+        project=context.project,
+        first_timestamp=start_time,
+        last_timestamp=end_time,
+        batch_completion_time=batch_completion_time,
+    )
+    output_stream = server.context.stream.output_stream
+    for mep_uid in spec.get("model_endpoint_uids", []):
+        mm_stream_record["endpoint_id"] = mep_uid
+        output_stream.push(mm_stream_record, partition_key=mep_uid)
+
+    context.logger.info(
+        f"Job completed processing {len(df)} rows",
+        timestamp_column=timestamp_column,
+        model_endpoint_uids=model_endpoint_uids,
+    )
+
+    # log the results as artifacts
+    num_of_meps_in_the_graph = len(server.graph.model_endpoints_names)
+    artifact_path = None
+    if (
+        "{{run.uid}}" not in context.artifact_path
+    ):  # TODO: delete when IG-22841 is resolved
+        artifact_path = "+/{{run.uid}}"  # will be concatenated to the context's path in extend_artifact_path
+    if num_of_meps_in_the_graph <= 1:
+        context.log_dataset(
+            "prediction", df=pd.DataFrame(responses), artifact_path=artifact_path
+        )
+    else:
+        # turn this list of samples into a dict of lists, one per model endpoint
+        grouped = defaultdict(list)
+        for sample in responses:
+            for model_name, features in sample.items():
+                grouped[model_name].append(features)
+        # create a dataframe per model endpoint and log it
+        for model_name, features in grouped.items():
+            context.log_dataset(
+                f"prediction_{model_name}",
+                df=pd.DataFrame(features),
+                artifact_path=artifact_path,
+            )
+    context.log_result("num_rows", run_call_count)
+
+
+def _is_inside_asyncio_loop():
+    try:
+        asyncio.get_running_loop()
+        return True
+    except RuntimeError:
+        return False
+
+
+# Workaround for running with local=True in Jupyter (ML-10620)
+def _workaround_asyncio_nesting():
+    try:
+        import nest_asyncio
+    except ImportError:
+        raise mlrun.errors.MLRunRuntimeError(
+            "Cannot execute graph from within an already running asyncio loop. "
+            "Attempt to import nest_asyncio as a workaround failed as well."
        )
+    nest_asyncio.apply()
 
 
 def execute_graph(
     context: MLClientCtx,
     data: DataItem,
+    timestamp_column: Optional[str] = None,
     batching: bool = False,
     batch_size: Optional[int] = None,
     read_as_lists: bool = False,
@@ -681,6 +840,9 @@ def execute_graph(
 
     :param context: The job's execution client context.
     :param data: The input data to the job, to be pushed into the graph row by row, or in batches.
+    :param timestamp_column: The name of the column that will be used as the timestamp for model monitoring purposes.
+        When timestamp_column is used in conjunction with batching, the first timestamp will be used for the entire
+        batch.
     :param batching: Whether to push one or more batches into the graph rather than row by row.
     :param batch_size: The number of rows to push per batch. If not set, and batching=True, the entire dataset will
         be pushed into the graph in one batch.
@@ -689,9 +851,18 @@ def execute_graph(
 
     :return: A list of responses.
     """
+    if _is_inside_asyncio_loop():
+        _workaround_asyncio_nesting()
+
     return asyncio.run(
         async_execute_graph(
-            context, data, batching, batch_size, read_as_lists, nest_under_inputs
+            context,
+            data,
+            timestamp_column,
+            batching,
+            batch_size,
+            read_as_lists,
+            nest_under_inputs,
         )
     )
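
Note: taken together, the server.py changes turn execute_graph into a batch-inference entry point with monitoring-aware timestamps. A call shaped like the sketch below would exercise the new paths; run_function with handler/inputs/params is the standard mlrun API, while the project name, function, and dataset URI are hypothetical. The new DataItem validation requires data to arrive via inputs, never params.

import mlrun

project = mlrun.get_or_create_project("batch-demo")  # hypothetical project
fn = project.get_function("my-serving-fn")  # hypothetical registered function

run = project.run_function(
    fn,
    handler="mlrun.serving.server.execute_graph",
    # data must be an input (loaded as a DataItem), not a param
    inputs={"data": "store://datasets/batch-demo/events"},  # hypothetical URI
    params={
        "timestamp_column": "timestamp",  # sorts rows and bounds the time range
        "batching": True,
        "batch_size": 64,  # first row in each batch stamps the whole batch
    },
)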