mlrun 1.10.0rc13__py3-none-any.whl → 1.10.0rc42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (107)
  1. mlrun/__init__.py +22 -2
  2. mlrun/artifacts/base.py +0 -31
  3. mlrun/artifacts/document.py +6 -1
  4. mlrun/artifacts/llm_prompt.py +123 -25
  5. mlrun/artifacts/manager.py +0 -5
  6. mlrun/artifacts/model.py +3 -3
  7. mlrun/common/constants.py +10 -1
  8. mlrun/common/formatters/artifact.py +1 -0
  9. mlrun/common/model_monitoring/helpers.py +86 -0
  10. mlrun/common/schemas/__init__.py +3 -0
  11. mlrun/common/schemas/auth.py +2 -0
  12. mlrun/common/schemas/function.py +10 -0
  13. mlrun/common/schemas/hub.py +30 -18
  14. mlrun/common/schemas/model_monitoring/__init__.py +3 -0
  15. mlrun/common/schemas/model_monitoring/constants.py +30 -6
  16. mlrun/common/schemas/model_monitoring/functions.py +14 -5
  17. mlrun/common/schemas/model_monitoring/model_endpoints.py +21 -0
  18. mlrun/common/schemas/pipeline.py +1 -1
  19. mlrun/common/schemas/serving.py +3 -0
  20. mlrun/common/schemas/workflow.py +3 -1
  21. mlrun/common/secrets.py +22 -1
  22. mlrun/config.py +33 -11
  23. mlrun/datastore/__init__.py +11 -3
  24. mlrun/datastore/azure_blob.py +162 -47
  25. mlrun/datastore/datastore.py +9 -4
  26. mlrun/datastore/datastore_profile.py +61 -5
  27. mlrun/datastore/model_provider/huggingface_provider.py +363 -0
  28. mlrun/datastore/model_provider/mock_model_provider.py +87 -0
  29. mlrun/datastore/model_provider/model_provider.py +230 -65
  30. mlrun/datastore/model_provider/openai_provider.py +295 -42
  31. mlrun/datastore/s3.py +24 -2
  32. mlrun/datastore/storeytargets.py +2 -3
  33. mlrun/datastore/utils.py +15 -3
  34. mlrun/db/base.py +47 -19
  35. mlrun/db/httpdb.py +120 -56
  36. mlrun/db/nopdb.py +38 -10
  37. mlrun/execution.py +70 -19
  38. mlrun/hub/__init__.py +15 -0
  39. mlrun/hub/module.py +181 -0
  40. mlrun/k8s_utils.py +105 -16
  41. mlrun/launcher/base.py +13 -6
  42. mlrun/launcher/local.py +15 -0
  43. mlrun/model.py +24 -3
  44. mlrun/model_monitoring/__init__.py +1 -0
  45. mlrun/model_monitoring/api.py +66 -27
  46. mlrun/model_monitoring/applications/__init__.py +1 -1
  47. mlrun/model_monitoring/applications/base.py +509 -117
  48. mlrun/model_monitoring/applications/context.py +2 -4
  49. mlrun/model_monitoring/applications/results.py +4 -7
  50. mlrun/model_monitoring/controller.py +239 -101
  51. mlrun/model_monitoring/db/_schedules.py +116 -33
  52. mlrun/model_monitoring/db/_stats.py +4 -3
  53. mlrun/model_monitoring/db/tsdb/base.py +100 -9
  54. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +11 -6
  55. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +191 -50
  56. mlrun/model_monitoring/db/tsdb/tdengine/writer_graph_steps.py +51 -0
  57. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +17 -4
  58. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +259 -40
  59. mlrun/model_monitoring/helpers.py +54 -9
  60. mlrun/model_monitoring/stream_processing.py +45 -14
  61. mlrun/model_monitoring/writer.py +220 -1
  62. mlrun/platforms/__init__.py +3 -2
  63. mlrun/platforms/iguazio.py +7 -3
  64. mlrun/projects/operations.py +6 -1
  65. mlrun/projects/pipelines.py +46 -26
  66. mlrun/projects/project.py +166 -58
  67. mlrun/run.py +94 -17
  68. mlrun/runtimes/__init__.py +18 -0
  69. mlrun/runtimes/base.py +14 -6
  70. mlrun/runtimes/daskjob.py +7 -0
  71. mlrun/runtimes/local.py +5 -2
  72. mlrun/runtimes/mounts.py +20 -2
  73. mlrun/runtimes/mpijob/abstract.py +6 -0
  74. mlrun/runtimes/mpijob/v1.py +6 -0
  75. mlrun/runtimes/nuclio/__init__.py +1 -0
  76. mlrun/runtimes/nuclio/application/application.py +149 -17
  77. mlrun/runtimes/nuclio/function.py +76 -27
  78. mlrun/runtimes/nuclio/serving.py +97 -15
  79. mlrun/runtimes/pod.py +234 -21
  80. mlrun/runtimes/remotesparkjob.py +6 -0
  81. mlrun/runtimes/sparkjob/spark3job.py +6 -0
  82. mlrun/runtimes/utils.py +49 -11
  83. mlrun/secrets.py +54 -13
  84. mlrun/serving/__init__.py +2 -0
  85. mlrun/serving/remote.py +79 -6
  86. mlrun/serving/routers.py +23 -41
  87. mlrun/serving/server.py +320 -80
  88. mlrun/serving/states.py +725 -157
  89. mlrun/serving/steps.py +62 -0
  90. mlrun/serving/system_steps.py +200 -119
  91. mlrun/serving/v2_serving.py +9 -10
  92. mlrun/utils/helpers.py +288 -88
  93. mlrun/utils/logger.py +3 -1
  94. mlrun/utils/notifications/notification/base.py +18 -0
  95. mlrun/utils/notifications/notification/git.py +2 -4
  96. mlrun/utils/notifications/notification/slack.py +2 -4
  97. mlrun/utils/notifications/notification/webhook.py +2 -5
  98. mlrun/utils/notifications/notification_pusher.py +1 -1
  99. mlrun/utils/retryer.py +15 -2
  100. mlrun/utils/version/version.json +2 -2
  101. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/METADATA +45 -51
  102. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/RECORD +106 -101
  103. mlrun/api/schemas/__init__.py +0 -259
  104. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/WHEEL +0 -0
  105. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/entry_points.txt +0 -0
  106. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/licenses/LICENSE +0 -0
  107. {mlrun-1.10.0rc13.dist-info → mlrun-1.10.0rc42.dist-info}/top_level.txt +0 -0
mlrun/serving/server.py CHANGED
@@ -15,22 +15,28 @@
 __all__ = ["GraphServer", "create_graph_server", "GraphContext", "MockEvent"]
 
 import asyncio
+import base64
 import copy
+import importlib
 import json
 import os
 import socket
 import traceback
 import uuid
+from collections import defaultdict
+from datetime import datetime, timezone
 from typing import Any, Optional, Union
 
+import pandas as pd
 import storey
 from nuclio import Context as NuclioContext
 from nuclio.request import Logger as NuclioLogger
 
 import mlrun
-import mlrun.common.constants
 import mlrun.common.helpers
 import mlrun.common.schemas
+import mlrun.common.schemas.model_monitoring.constants as mm_constants
+import mlrun.datastore.datastore_profile as ds_profile
 import mlrun.model_monitoring
 import mlrun.utils
 from mlrun.config import config
@@ -39,12 +45,13 @@ from mlrun.secrets import SecretsStore
 
 from ..common.helpers import parse_versioned_object_uri
 from ..common.schemas.model_monitoring.constants import FileTargetKind
+from ..common.schemas.serving import MAX_BATCH_JOB_DURATION
 from ..datastore import DataItem, get_stream_pusher
 from ..datastore.store_resources import ResourceCache
 from ..errors import MLRunInvalidArgumentError
 from ..execution import MLClientCtx
 from ..model import ModelObj
-from ..utils import get_caller_globals
+from ..utils import get_caller_globals, get_relative_module_name_from_path
 from .states import (
     FlowStep,
     MonitoredStep,
@@ -76,7 +83,6 @@ class _StreamContext:
         self.hostname = socket.gethostname()
         self.function_uri = function_uri
         self.output_stream = None
-        stream_uri = None
         log_stream = parameters.get(FileTargetKind.LOG_STREAM, "")
 
         if (enabled or log_stream) and function_uri:
@@ -87,20 +93,16 @@
 
             stream_args = parameters.get("stream_args", {})
 
-            if log_stream == DUMMY_STREAM:
-                # Dummy stream used for testing, see tests/serving/test_serving.py
-                stream_uri = DUMMY_STREAM
-            elif not stream_args.get("mock"):  # if not a mock: `context.is_mock = True`
-                stream_uri = mlrun.model_monitoring.get_stream_path(project=project)
-
             if log_stream:
-                # Update the stream path to the log stream value
-                stream_uri = log_stream.format(project=project)
-                self.output_stream = get_stream_pusher(stream_uri, **stream_args)
+                # Get the output stream from the log stream path
+                stream_path = log_stream.format(project=project)
+                self.output_stream = get_stream_pusher(stream_path, **stream_args)
             else:
                 # Get the output stream from the profile
                 self.output_stream = mlrun.model_monitoring.helpers.get_output_stream(
-                    project=project, mock=stream_args.get("mock", False)
+                    project=project,
+                    profile=parameters.get("stream_profile"),
+                    mock=stream_args.get("mock", False),
                 )
 
 
@@ -178,11 +180,12 @@ class GraphServer(ModelObj):
         self,
         context,
         namespace,
-        resource_cache: ResourceCache = None,
+        resource_cache: Optional[ResourceCache] = None,
         logger=None,
         is_mock=False,
         monitoring_mock=False,
-    ):
+        stream_profile: Optional[ds_profile.DatastoreProfile] = None,
+    ) -> None:
         """for internal use, initialize all steps (recursively)"""
 
         if self.secret_sources:
@@ -197,6 +200,20 @@
         context.monitoring_mock = monitoring_mock
         context.root = self.graph
 
+        if is_mock and monitoring_mock:
+            if stream_profile:
+                # Add the user-defined stream profile to the parameters
+                self.parameters["stream_profile"] = stream_profile
+            elif not (
+                self.parameters.get(FileTargetKind.LOG_STREAM)
+                or mlrun.get_secret_or_env(
+                    mm_constants.ProjectSecretKeys.STREAM_PROFILE_NAME
+                )
+            ):
+                # Set a dummy log stream for mocking purposes if there is no direct
+                # user-defined stream profile and no information in the environment
+                self.parameters[FileTargetKind.LOG_STREAM] = DUMMY_STREAM
+
         context.stream = _StreamContext(
             self.track_models, self.parameters, self.function_uri
         )
@@ -349,33 +366,34 @@ def add_error_raiser_step(
     monitored_steps_raisers = {}
     user_steps = list(graph.steps.values())
     for monitored_step in monitored_steps.values():
-        if monitored_step.raise_exception:
-            error_step = graph.add_step(
-                class_name="mlrun.serving.states.ModelRunnerErrorRaiser",
-                name=f"{monitored_step.name}_error_raise",
-                after=monitored_step.name,
-                full_event=True,
-                raise_exception=monitored_step.raise_exception,
-                models_names=list(monitored_step.class_args["models"].keys()),
-                model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
-            )
-            if monitored_step.responder:
-                monitored_step.responder = False
-                error_step.respond()
-            monitored_steps_raisers[monitored_step.name] = error_step.name
-            error_step.on_error = monitored_step.on_error
-    for step in user_steps:
-        if step.after:
-            if isinstance(step.after, list):
-                for i in range(len(step.after)):
-                    if step.after[i] in monitored_steps_raisers:
-                        step.after[i] = monitored_steps_raisers[step.after[i]]
-            else:
-                if (
-                    isinstance(step.after, str)
-                    and step.after in monitored_steps_raisers
-                ):
-                    step.after = monitored_steps_raisers[step.after]
+        error_step = graph.add_step(
+            class_name="mlrun.serving.states.ModelRunnerErrorRaiser",
+            name=f"{monitored_step.name}_error_raise",
+            after=monitored_step.name,
+            full_event=True,
+            raise_exception=monitored_step.raise_exception,
+            models_names=list(monitored_step.class_args["models"].keys()),
+            model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
+            function=monitored_step.function,
+        )
+        if monitored_step.responder:
+            monitored_step.responder = False
+            error_step.respond()
+        monitored_steps_raisers[monitored_step.name] = error_step.name
+        error_step.on_error = monitored_step.on_error
+    if monitored_steps_raisers:
+        for step in user_steps:
+            if step.after:
+                if isinstance(step.after, list):
+                    for i in range(len(step.after)):
+                        if step.after[i] in monitored_steps_raisers:
+                            step.after[i] = monitored_steps_raisers[step.after[i]]
+                else:
+                    if (
+                        isinstance(step.after, str)
+                        and step.after in monitored_steps_raisers
+                    ):
+                        step.after = monitored_steps_raisers[step.after]
     return graph
 
 
@@ -384,6 +402,7 @@ def add_monitoring_general_steps(
     graph: RootFlowStep,
     context,
    serving_spec,
+    pause_until_background_task_completion: bool,
 ) -> tuple[RootFlowStep, FlowStep]:
     """
     Adding the monitoring flow connection steps, this steps allow the graph to reconstruct the serving event enrich it
@@ -392,18 +411,23 @@
     "background_task_status_step" --> "filter_none" --> "monitoring_pre_processor_step" --> "flatten_events"
     --> "sampling_step" --> "filter_none_sampling" --> "model_monitoring_stream"
     """
+    background_task_status_step = None
+    if pause_until_background_task_completion:
+        background_task_status_step = graph.add_step(
+            "mlrun.serving.system_steps.BackgroundTaskStatus",
+            "background_task_status_step",
+            model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
+            full_event=True,
+        )
     monitor_flow_step = graph.add_step(
-        "mlrun.serving.system_steps.BackgroundTaskStatus",
-        "background_task_status_step",
-        model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
-    )
-    graph.add_step(
         "storey.Filter",
         "filter_none",
         _fn="(event is not None)",
-        after="background_task_status_step",
+        after="background_task_status_step" if background_task_status_step else None,
         model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
     )
+    if background_task_status_step:
+        monitor_flow_step = background_task_status_step
     graph.add_step(
         "mlrun.serving.system_steps.MonitoringPreProcessor",
        "monitoring_pre_processor_step",
@@ -466,14 +490,28 @@
 
 
 def add_system_steps_to_graph(
-    project: str, graph: RootFlowStep, track_models: bool, context, serving_spec
+    project: str,
+    graph: RootFlowStep,
+    track_models: bool,
+    context,
+    serving_spec,
+    pause_until_background_task_completion: bool = True,
 ) -> RootFlowStep:
+    if not (isinstance(graph, RootFlowStep) and graph.include_monitored_step()):
+        return graph
     monitored_steps = graph.get_monitored_steps()
     graph = add_error_raiser_step(graph, monitored_steps)
     if track_models:
+        background_task_status_step = None
        graph, monitor_flow_step = add_monitoring_general_steps(
-            project, graph, context, serving_spec
+            project,
+            graph,
+            context,
+            serving_spec,
+            pause_until_background_task_completion,
         )
+        if background_task_status_step:
+            monitor_flow_step = background_task_status_step
         # Connect each model runner to the monitoring step:
         for step_name, step in monitored_steps.items():
             if monitor_flow_step.after:
@@ -494,18 +532,13 @@ def v2_serving_init(context, namespace=None):
     context.logger.info("Initializing server from spec")
     spec = mlrun.utils.get_serving_spec()
     server = GraphServer.from_dict(spec)
-    if isinstance(server.graph, RootFlowStep) and server.graph.include_monitored_step():
-        server.graph = add_system_steps_to_graph(
-            server.project,
-            copy.deepcopy(server.graph),
-            spec.get("track_models"),
-            context,
-            spec,
-        )
-        context.logger.info_with(
-            "Server graph after adding system steps",
-            graph=str(server.graph.steps),
-        )
+    server.graph = add_system_steps_to_graph(
+        server.project,
+        copy.deepcopy(server.graph),
+        spec.get("track_models"),
+        context,
+        spec,
+    )
 
     if config.log_level.lower() == "debug":
         server.verbose = True
@@ -542,22 +575,120 @@
 async def async_execute_graph(
     context: MLClientCtx,
     data: DataItem,
+    timestamp_column: Optional[str],
     batching: bool,
     batch_size: Optional[int],
-) -> list[Any]:
+    read_as_lists: bool,
+    nest_under_inputs: bool,
+) -> None:
+    # Validate that data parameter is a DataItem and not passed via params
+    if not isinstance(data, DataItem):
+        raise MLRunInvalidArgumentError(
+            f"Parameter 'data' has type hint 'DataItem' but got {type(data).__name__} instead. "
+            f"Data files and artifacts must be passed via the 'inputs' parameter, not 'params'. "
+            f"The 'params' parameter is for simple configuration values (strings, numbers, booleans), "
+            f"while 'inputs' is for data files that need to be loaded. "
+            f"Example: run_function(..., inputs={{'data': 'path/to/data.csv'}}, params={{other_config: value}})"
+        )
+    run_call_count = 0
     spec = mlrun.utils.get_serving_spec()
+    modname = None
+    code = os.getenv("MLRUN_EXEC_CODE")
+    if code:
+        code = base64.b64decode(code).decode("utf-8")
+        with open("user_code.py", "w") as fp:
+            fp.write(code)
+        modname = "user_code"
+    else:
+        # TODO: find another way to get the local file path, or ensure that MLRUN_EXEC_CODE
+        # gets set in local flow and not just in the remote pod
+        source_file_path = spec.get("filename", None)
+        if source_file_path:
+            source_file_path_object, working_dir_path_object = (
+                mlrun.utils.helpers.get_source_and_working_dir_paths(source_file_path)
+            )
+            if not source_file_path_object.is_relative_to(working_dir_path_object):
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Source file path '{source_file_path}' is not under the current working directory "
+                    f"(which is required when running with local=True)"
+                )
+            modname = get_relative_module_name_from_path(
+                source_file_path_object, working_dir_path_object
+            )
 
-    source_filename = spec.get("filename", None)
     namespace = {}
-    if source_filename:
-        with open(source_filename) as f:
-            exec(f.read(), namespace)
+    if modname:
+        mod = importlib.import_module(modname)
+        namespace = mod.__dict__
 
     server = GraphServer.from_dict(spec)
 
+    if server.model_endpoint_creation_task_name:
+        context.logger.info(
+            f"Waiting for model endpoint creation task '{server.model_endpoint_creation_task_name}'..."
+        )
+        background_task = (
+            mlrun.get_run_db().wait_for_background_task_to_reach_terminal_state(
+                project=server.project,
+                name=server.model_endpoint_creation_task_name,
+            )
+        )
+        task_state = background_task.status.state
+        if task_state == mlrun.common.schemas.BackgroundTaskState.failed:
+            raise mlrun.errors.MLRunRuntimeError(
+                "Aborting job due to model endpoint creation background task failure"
+            )
+        elif task_state != mlrun.common.schemas.BackgroundTaskState.succeeded:
+            # this shouldn't happen, but we need to know if it does
+            raise mlrun.errors.MLRunRuntimeError(
+                "Aborting job because the model endpoint creation background task did not succeed "
+                f"(status='{task_state}')"
+            )
+
+    df = data.as_df()
+
+    if df.empty:
+        context.logger.warn("Job terminated due to empty inputs (0 rows)")
+        return []
+
+    track_models = spec.get("track_models")
+
+    if track_models and timestamp_column:
+        context.logger.info(f"Sorting dataframe by {timestamp_column}")
+        df[timestamp_column] = pd.to_datetime(  # in case it's a string
+            df[timestamp_column]
+        )
+        df.sort_values(by=timestamp_column, inplace=True)
+        if len(df) > 1:
+            start_time = df[timestamp_column].iloc[0]
+            end_time = df[timestamp_column].iloc[-1]
+            time_range = end_time - start_time
+            start_time = start_time.isoformat()
+            end_time = end_time.isoformat()
+            # TODO: tie this to the controller's base period
+            if time_range > pd.Timedelta(MAX_BATCH_JOB_DURATION):
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Dataframe time range is too long: {time_range}. "
+                    "Please disable tracking or reduce the input dataset's time range below the defined limit "
+                    f"of {MAX_BATCH_JOB_DURATION}."
+                )
+        else:
+            start_time = end_time = df["timestamp"].iloc[0].isoformat()
+    else:
+        # end time will be set from clock time when the batch completes
+        start_time = datetime.now(tz=timezone.utc).isoformat()
+
+    server.graph = add_system_steps_to_graph(
+        server.project,
+        copy.deepcopy(server.graph),
+        track_models,
+        context,
+        spec,
+        pause_until_background_task_completion=False,  # we've already awaited it
+    )
+
     if config.log_level.lower() == "debug":
         server.verbose = True
-    context.logger.info_with("Initializing states", namespace=namespace)
     kwargs = {}
     if hasattr(context, "is_mock"):
         kwargs["is_mock"] = context.is_mock
@@ -574,57 +705,166 @@ async def async_execute_graph(
     if server.verbose:
         context.logger.info(server.to_yaml())
 
-    df = data.as_df()
-
-    responses = []
-
     async def run(body):
+        nonlocal run_call_count
         event = storey.Event(id=index, body=body)
-        response = await server.run(event, context)
-        responses.append(response)
+        if timestamp_column:
+            if batching:
+                # we use the first row in the batch to determine the timestamp for the whole batch
+                body = body[0]
+            if not isinstance(body, dict):
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"When timestamp_column=True, event body must be a dict – got {type(body).__name__} instead"
+                )
+            if timestamp_column not in body:
+                raise mlrun.errors.MLRunRuntimeError(
+                    f"Event body '{body}' did not contain timestamp column '{timestamp_column}'"
+                )
+            event._original_timestamp = body[timestamp_column]
+        run_call_count += 1
+        return await server.run(event, context)
 
     if batching and not batch_size:
         batch_size = len(df)
 
     batch = []
+    tasks = []
     for index, row in df.iterrows():
-        data = row.to_dict()
+        data = row.to_list() if read_as_lists else row.to_dict()
+        if nest_under_inputs:
+            data = {"inputs": data}
         if batching:
             batch.append(data)
             if len(batch) == batch_size:
-                await run(batch)
+                tasks.append(asyncio.create_task(run(batch)))
                 batch = []
         else:
-            await run(data)
+            tasks.append(asyncio.create_task(run(data)))
 
     if batch:
-        await run(batch)
+        tasks.append(asyncio.create_task(run(batch)))
+
+    responses = await asyncio.gather(*tasks)
 
     termination_result = server.wait_for_completion()
     if asyncio.iscoroutine(termination_result):
         await termination_result
 
-    return responses
+    model_endpoint_uids = spec.get("model_endpoint_uids", [])
+
+    # needed for output_stream to be created
+    server = GraphServer.from_dict(spec)
+    server.init_states(None, namespace)
+
+    batch_completion_time = datetime.now(tz=timezone.utc).isoformat()
+
+    if not timestamp_column:
+        end_time = batch_completion_time
+
+    mm_stream_record = dict(
+        kind="batch_complete",
+        project=context.project,
+        first_timestamp=start_time,
+        last_timestamp=end_time,
+        batch_completion_time=batch_completion_time,
+    )
+    output_stream = server.context.stream.output_stream
+    for mep_uid in spec.get("model_endpoint_uids", []):
+        mm_stream_record["endpoint_id"] = mep_uid
+        output_stream.push(mm_stream_record, partition_key=mep_uid)
+
+    context.logger.info(
+        f"Job completed processing {len(df)} rows",
+        timestamp_column=timestamp_column,
+        model_endpoint_uids=model_endpoint_uids,
+    )
+
+    # log the results as artifacts
+    num_of_meps_in_the_graph = len(server.graph.model_endpoints_names)
+    artifact_path = None
+    if (
+        "{{run.uid}}" not in context.artifact_path
+    ):  # TODO: delete when IG-22841 is resolved
+        artifact_path = "+/{{run.uid}}"  # will be concatenated to the context's path in extend_artifact_path
+    if num_of_meps_in_the_graph <= 1:
+        context.log_dataset(
+            "prediction", df=pd.DataFrame(responses), artifact_path=artifact_path
+        )
+    else:
+        # turn this list of samples into a dict of lists, one per model endpoint
+        grouped = defaultdict(list)
+        for sample in responses:
+            for model_name, features in sample.items():
+                grouped[model_name].append(features)
+        # create a dataframe per model endpoint and log it
+        for model_name, features in grouped.items():
+            context.log_dataset(
+                f"prediction_{model_name}",
+                df=pd.DataFrame(features),
+                artifact_path=artifact_path,
+            )
+    context.log_result("num_rows", run_call_count)
+
+
+def _is_inside_asyncio_loop():
+    try:
+        asyncio.get_running_loop()
+        return True
+    except RuntimeError:
+        return False
+
+
+# Workaround for running with local=True in Jupyter (ML-10620)
+def _workaround_asyncio_nesting():
+    try:
+        import nest_asyncio
+    except ImportError:
+        raise mlrun.errors.MLRunRuntimeError(
+            "Cannot execute graph from within an already running asyncio loop. "
+            "Attempt to import nest_asyncio as a workaround failed as well."
+        )
+    nest_asyncio.apply()
 
 
 def execute_graph(
     context: MLClientCtx,
     data: DataItem,
+    timestamp_column: Optional[str] = None,
     batching: bool = False,
     batch_size: Optional[int] = None,
+    read_as_lists: bool = False,
+    nest_under_inputs: bool = False,
 ) -> (list[Any], Any):
     """
     Execute graph as a job, from start to finish.
 
     :param context: The job's execution client context.
     :param data: The input data to the job, to be pushed into the graph row by row, or in batches.
+    :param timestamp_column: The name of the column that will be used as the timestamp for model monitoring purposes.
+        when timestamp_column is used in conjunction with batching, the first timestamp will be used for the entire
+        batch.
     :param batching: Whether to push one or more batches into the graph rather than row by row.
     :param batch_size: The number of rows to push per batch. If not set, and batching=True, the entire dataset will
         be pushed into the graph in one batch.
+    :param read_as_lists: Whether to read each row as a list instead of a dictionary.
+    :param nest_under_inputs: Whether to wrap each row with {"inputs": ...}.
 
     :return: A list of responses.
     """
-    return asyncio.run(async_execute_graph(context, data, batching, batch_size))
+    if _is_inside_asyncio_loop():
+        _workaround_asyncio_nesting()
+
+    return asyncio.run(
        async_execute_graph(
+            context,
+            data,
+            timestamp_column,
+            batching,
+            batch_size,
+            read_as_lists,
+            nest_under_inputs,
+        )
+    )
 
 
 def _set_callbacks(server, context):
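
For orientation, the expanded execute_graph signature above (timestamp_column, batching, batch_size, read_as_lists, nest_under_inputs) is passed as job params, while the dataset goes through inputs. The sketch below is illustrative only; the project name, function name, and data path are assumptions, not part of this diff.

# Illustrative sketch (assumptions, not part of the diff): triggering a batch
# graph-execution job whose handler is mlrun.serving.server.execute_graph.
import mlrun

project = mlrun.get_or_create_project("graph-batch-demo")  # hypothetical project

run = project.run_function(
    "graph-job",  # hypothetical job function wrapping execute_graph
    inputs={"data": "path/to/data.parquet"},  # data must be an input, not a param
    params={
        "timestamp_column": "timestamp",  # rows sorted by this column when tracking is on
        "batching": True,
        "batch_size": 256,  # omit to push the whole dataset as one batch
        "read_as_lists": False,  # keep each row as a dict
        "nest_under_inputs": True,  # wrap each row as {"inputs": ...}
    },
)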