mlrun 1.10.0rc11__py3-none-any.whl → 1.10.0rc12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. Click here for more details.

Files changed (54) hide show
  1. mlrun/__init__.py +2 -1
  2. mlrun/__main__.py +7 -1
  3. mlrun/artifacts/base.py +9 -3
  4. mlrun/artifacts/dataset.py +2 -1
  5. mlrun/artifacts/llm_prompt.py +1 -1
  6. mlrun/artifacts/model.py +2 -2
  7. mlrun/common/constants.py +1 -0
  8. mlrun/common/runtimes/constants.py +10 -1
  9. mlrun/config.py +19 -2
  10. mlrun/datastore/__init__.py +3 -1
  11. mlrun/datastore/alibaba_oss.py +1 -1
  12. mlrun/datastore/azure_blob.py +1 -1
  13. mlrun/datastore/base.py +6 -31
  14. mlrun/datastore/datastore.py +109 -33
  15. mlrun/datastore/datastore_profile.py +31 -0
  16. mlrun/datastore/dbfs_store.py +1 -1
  17. mlrun/datastore/google_cloud_storage.py +2 -2
  18. mlrun/datastore/model_provider/__init__.py +13 -0
  19. mlrun/datastore/model_provider/model_provider.py +82 -0
  20. mlrun/datastore/model_provider/openai_provider.py +120 -0
  21. mlrun/datastore/remote_client.py +54 -0
  22. mlrun/datastore/s3.py +1 -1
  23. mlrun/datastore/storeytargets.py +1 -1
  24. mlrun/datastore/utils.py +22 -0
  25. mlrun/datastore/v3io.py +1 -1
  26. mlrun/db/base.py +1 -1
  27. mlrun/db/httpdb.py +9 -4
  28. mlrun/db/nopdb.py +1 -1
  29. mlrun/execution.py +23 -7
  30. mlrun/launcher/base.py +23 -13
  31. mlrun/launcher/local.py +3 -1
  32. mlrun/launcher/remote.py +4 -2
  33. mlrun/model.py +65 -0
  34. mlrun/package/packagers_manager.py +2 -0
  35. mlrun/projects/operations.py +8 -1
  36. mlrun/projects/project.py +23 -5
  37. mlrun/run.py +17 -0
  38. mlrun/runtimes/__init__.py +6 -0
  39. mlrun/runtimes/base.py +24 -6
  40. mlrun/runtimes/daskjob.py +1 -0
  41. mlrun/runtimes/databricks_job/databricks_runtime.py +1 -0
  42. mlrun/runtimes/local.py +1 -6
  43. mlrun/serving/server.py +0 -2
  44. mlrun/serving/states.py +30 -5
  45. mlrun/serving/system_steps.py +22 -28
  46. mlrun/utils/helpers.py +13 -2
  47. mlrun/utils/notifications/notification_pusher.py +15 -0
  48. mlrun/utils/version/version.json +2 -2
  49. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/METADATA +2 -2
  50. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/RECORD +54 -50
  51. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/WHEEL +0 -0
  52. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/entry_points.txt +0 -0
  53. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/licenses/LICENSE +0 -0
  54. {mlrun-1.10.0rc11.dist-info → mlrun-1.10.0rc12.dist-info}/top_level.txt +0 -0
mlrun/model.py CHANGED
@@ -935,6 +935,41 @@ class HyperParamOptions(ModelObj):
935
935
  )
936
936
 
937
937
 
938
+ class RetryBackoff(ModelObj):
939
+ """Backoff strategy for retries."""
940
+
941
+ def __init__(self, base_delay: Optional[str] = None):
942
+ # The base_delay time string must conform to timelength python package standards and be at least
943
+ # mlrun.mlconf.function.spec.retry.backoff.min_base_delay (e.g. 1000s, 1 hour 30m, 1h etc.).
944
+ self.base_delay = (
945
+ base_delay or mlrun.mlconf.function.spec.retry.backoff.default_base_delay
946
+ )
947
+
948
+
949
+ class Retry(ModelObj):
950
+ """Retry configuration"""
951
+
952
+ def __init__(
953
+ self,
954
+ count: int = 0,
955
+ backoff: typing.Union[RetryBackoff, dict] = None,
956
+ ):
957
+ # Set to None if count is 0 to eliminate the retry configuration from the dictionary representation.
958
+ self.count = count or None
959
+ self.backoff = backoff
960
+
961
+ @property
962
+ def backoff(self) -> Optional[RetryBackoff]:
963
+ if not self.count:
964
+ # Retry is not configured, return None
965
+ return None
966
+ return self._backoff
967
+
968
+ @backoff.setter
969
+ def backoff(self, backoff):
970
+ self._backoff = self._verify_dict(backoff, "backoff", RetryBackoff)
971
+
972
+
938
973
  class RunSpec(ModelObj):
939
974
  """Run specification"""
940
975
 
@@ -971,6 +1006,7 @@ class RunSpec(ModelObj):
971
1006
  node_selector=None,
972
1007
  tolerations=None,
973
1008
  affinity=None,
1009
+ retry=None,
974
1010
  ):
975
1011
  # A dictionary of parsing configurations that will be read from the inputs the user set. The keys are the inputs
976
1012
  # keys (parameter names) and the values are the type hint given in the input keys after the colon.
@@ -1011,6 +1047,7 @@ class RunSpec(ModelObj):
1011
1047
  self.node_selector = node_selector or {}
1012
1048
  self.tolerations = tolerations or {}
1013
1049
  self.affinity = affinity or {}
1050
+ self.retry = retry or {}
1014
1051
 
1015
1052
  def _serialize_field(
1016
1053
  self, struct: dict, field_name: Optional[str] = None, strip: bool = False
@@ -1212,6 +1249,14 @@ class RunSpec(ModelObj):
1212
1249
  self._verify_dict(state_thresholds, "state_thresholds")
1213
1250
  self._state_thresholds = state_thresholds
1214
1251
 
1252
+ @property
1253
+ def retry(self) -> Retry:
1254
+ return self._retry
1255
+
1256
+ @retry.setter
1257
+ def retry(self, retry: typing.Union[Retry, dict]):
1258
+ self._retry = self._verify_dict(retry, "retry", Retry)
1259
+
1215
1260
  def extract_type_hints_from_inputs(self):
1216
1261
  """
1217
1262
  This method extracts the type hints from the input keys in the input dictionary.
@@ -1329,6 +1374,7 @@ class RunStatus(ModelObj):
1329
1374
  reason: Optional[str] = None,
1330
1375
  notifications: Optional[dict[str, Notification]] = None,
1331
1376
  artifact_uris: Optional[dict[str, str]] = None,
1377
+ retry_count: Optional[int] = None,
1332
1378
  ):
1333
1379
  self.state = state or "created"
1334
1380
  self.status_text = status_text
@@ -1346,6 +1392,7 @@ class RunStatus(ModelObj):
1346
1392
  self.notifications = notifications or {}
1347
1393
  # Artifact key -> URI mapping, since the full artifacts are not stored in the runs DB table
1348
1394
  self._artifact_uris = artifact_uris or {}
1395
+ self._retry_count = retry_count or None
1349
1396
 
1350
1397
  @classmethod
1351
1398
  def from_dict(
@@ -1399,6 +1446,21 @@ class RunStatus(ModelObj):
1399
1446
 
1400
1447
  self._artifact_uris = resolved_artifact_uris
1401
1448
 
1449
+ @property
1450
+ def retry_count(self) -> Optional[int]:
1451
+ """
1452
+ The number of retries that were made for this run.
1453
+ """
1454
+ return self._retry_count
1455
+
1456
+ @retry_count.setter
1457
+ def retry_count(self, retry_count: int):
1458
+ """
1459
+ Set the number of retries that were made for this run.
1460
+ :param retry_count: The number of retries.
1461
+ """
1462
+ self._retry_count = retry_count
1463
+
1402
1464
  def is_failed(self) -> Optional[bool]:
1403
1465
  """
1404
1466
  This method returns whether a run has failed.
@@ -2026,6 +2088,7 @@ def new_task(
2026
2088
  secrets=None,
2027
2089
  base=None,
2028
2090
  returns=None,
2091
+ retry=None,
2029
2092
  ) -> RunTemplate:
2030
2093
  """Creates a new task
2031
2094
 
@@ -2061,6 +2124,7 @@ def new_task(
2061
2124
  * A dictionary of configurations to use when logging. Further info per object type and
2062
2125
  artifact type can be given there. The artifact key must appear in the dictionary as
2063
2126
  "key": "the_key".
2127
+ :param retry: Retry configuration for the run, can be a dict or an instance of mlrun.model.Retry.
2064
2128
  """
2065
2129
 
2066
2130
  if base:
@@ -2086,6 +2150,7 @@ def new_task(
2086
2150
  run.spec.hyper_param_options.selector = (
2087
2151
  selector or run.spec.hyper_param_options.selector
2088
2152
  )
2153
+ run.spec.retry = retry or run.spec.retry
2089
2154
  return run
2090
2155
 
2091
2156
 
@@ -21,6 +21,7 @@ from typing import Any, Optional, Union
21
21
 
22
22
  import mlrun.errors
23
23
  from mlrun.artifacts import Artifact
24
+ from mlrun.artifacts.base import verify_target_path
24
25
  from mlrun.datastore import DataItem, get_store_resource, store_manager
25
26
  from mlrun.errors import MLRunInvalidArgumentError
26
27
  from mlrun.utils import logger
@@ -276,6 +277,7 @@ class PackagersManager:
276
277
  if data_item.get_artifact_type():
277
278
  # Get the artifact object in the data item:
278
279
  artifact, _ = store_manager.get_store_artifact(url=data_item.artifact_url)
280
+ verify_target_path(artifact)
279
281
  # Get the key from the artifact's metadata and instructions from the artifact's spec:
280
282
  artifact_key = artifact.metadata.key
281
283
  packaging_instructions = artifact.spec.unpackaging_instructions
@@ -20,7 +20,6 @@ import mlrun
20
20
  import mlrun.common.constants as mlrun_constants
21
21
  import mlrun.common.schemas.function
22
22
  import mlrun.common.schemas.workflow
23
- import mlrun_pipelines.common.models
24
23
  import mlrun_pipelines.models
25
24
  from mlrun.utils import hub_prefix
26
25
 
@@ -82,6 +81,7 @@ def run_function(
82
81
  builder_env: Optional[list] = None,
83
82
  reset_on_run: Optional[bool] = None,
84
83
  output_path: Optional[str] = None,
84
+ retry: Optional[Union[mlrun.model.Retry, dict]] = None,
85
85
  ) -> Union[mlrun.model.RunObject, mlrun_pipelines.models.PipelineNodeWrapper]:
86
86
  """Run a local or remote task as part of a local/kubeflow pipeline
87
87
 
@@ -177,6 +177,7 @@ def run_function(
177
177
  This ensures latest code changes are executed. This argument must be used in
178
178
  conjunction with the local=True argument.
179
179
  :param output_path: path to store artifacts, when running in a workflow this will be set automatically
180
+ :param retry: Retry configuration for the run, can be a dict or an instance of mlrun.model.Retry.
180
181
  :return: MLRun RunObject or PipelineNodeWrapper
181
182
  """
182
183
  if artifact_path:
@@ -197,6 +198,7 @@ def run_function(
197
198
  returns=returns,
198
199
  base=base_task,
199
200
  selector=selector,
201
+ retry=retry,
200
202
  )
201
203
  task.spec.verbose = task.spec.verbose or verbose
202
204
 
@@ -205,6 +207,11 @@ def run_function(
205
207
  raise mlrun.errors.MLRunInvalidArgumentError(
206
208
  "Scheduling jobs is not supported when running a workflow with the kfp engine."
207
209
  )
210
+ if retry:
211
+ raise mlrun.errors.MLRunInvalidArgumentError(
212
+ "Retrying jobs is not supported when running a workflow with the kfp engine. "
213
+ "Use KFP set_retry instead."
214
+ )
208
215
  return function.as_step(
209
216
  name=name, runspec=task, workdir=workdir, outputs=outputs, labels=labels
210
217
  )
mlrun/projects/project.py CHANGED
@@ -159,7 +159,8 @@ def new_project(
159
159
  parameters: Optional[dict] = None,
160
160
  default_function_node_selector: Optional[dict] = None,
161
161
  ) -> "MlrunProject":
162
- """Create a new MLRun project, optionally load it from a yaml/zip/git template
162
+ """Create a new MLRun project, optionally load it from a yaml/zip/git template.
163
+ The project will become the active project for the current session.
163
164
 
164
165
  A new project is created and returned, you can customize the project by placing a project_setup.py file
165
166
  in the project root dir, it will be executed upon project creation or loading.
@@ -326,7 +327,8 @@ def load_project(
326
327
  parameters: Optional[dict] = None,
327
328
  allow_cross_project: Optional[bool] = None,
328
329
  ) -> "MlrunProject":
329
- """Load an MLRun project from git or tar or dir
330
+ """Load an MLRun project from git or tar or dir. The project will become the active project for
331
+ the current session.
330
332
 
331
333
  MLRun looks for a project.yaml file with project definition and objects in the project root path
332
334
  and use it to initialize the project, in addition it runs the project_setup.py file (if it exists)
@@ -2688,8 +2690,8 @@ class MlrunProject(ModelObj):
2688
2690
  requirements_file: str = "",
2689
2691
  ) -> mlrun.runtimes.BaseRuntime:
2690
2692
  """
2691
- | Update or add a function object to the project.
2692
- | Function can be provided as an object (func) or a .py/.ipynb/.yaml URL.
2693
+ Update or add a function object to the project.
2694
+ Function can be provided as an object (func) or a .py/.ipynb/.yaml URL.
2693
2695
 
2694
2696
  | Creating a function from a single file is done by specifying ``func`` and disabling ``with_repo``.
2695
2697
  | Creating a function with project source (specify ``with_repo=True``):
@@ -2734,6 +2736,20 @@ class MlrunProject(ModelObj):
2734
2736
  # By providing a path to a pip requirements file
2735
2737
  proj.set_function("my.py", requirements="requirements.txt")
2736
2738
 
2739
+ One of the most important parameters is 'kind', used to specify the chosen runtime. The options are:
2740
+ - local: execute a local python or shell script
2741
+ - job: insert the code into a Kubernetes pod and execute it
2742
+ - nuclio: insert the code into a real-time serverless nuclio function
2743
+ - serving: insert code into orchestrated nuclio function(s) forming a DAG
2744
+ - dask: run the specified python code / script as Dask Distributed job
2745
+ - mpijob: run distributed Horovod jobs over the MPI job operator
2746
+ - spark: run distributed Spark job using Spark Kubernetes Operator
2747
+ - remote-spark: run distributed Spark job on remote Spark service
2748
+ - databricks: run code on Databricks cluster (python scripts, Spark etc.)
2749
+ - application: run a long living application (e.g. a web server, UI, etc.)
2750
+
2751
+ Learn more about :doc:`../../concepts/functions-overview`.
2752
+
2737
2753
  :param func: Function object or spec/code url, None refers to current Notebook
2738
2754
  :param name: Name of the function (under the project), can be specified with a tag to support
2739
2755
  Versions (e.g. myfunc:v1). If the `tag` parameter is provided, the tag in the name
@@ -3967,6 +3983,7 @@ class MlrunProject(ModelObj):
3967
3983
  builder_env: Optional[dict] = None,
3968
3984
  reset_on_run: Optional[bool] = None,
3969
3985
  output_path: Optional[str] = None,
3986
+ retry: Optional[Union[mlrun.model.Retry, dict]] = None,
3970
3987
  ) -> typing.Union[mlrun.model.RunObject, PipelineNodeWrapper]:
3971
3988
  """Run a local or remote task as part of a local/kubeflow pipeline
3972
3989
 
@@ -4029,7 +4046,7 @@ class MlrunProject(ModelObj):
4029
4046
  This ensures latest code changes are executed. This argument must be used in
4030
4047
  conjunction with the local=True argument.
4031
4048
  :param output_path: path to store artifacts, when running in a workflow this will be set automatically
4032
-
4049
+ :param retry: Retry configuration for the run, can be a dict or an instance of mlrun.model.Retry.
4033
4050
  :return: MLRun RunObject or PipelineNodeWrapper
4034
4051
  """
4035
4052
  if artifact_path:
@@ -4068,6 +4085,7 @@ class MlrunProject(ModelObj):
4068
4085
  returns=returns,
4069
4086
  builder_env=builder_env,
4070
4087
  reset_on_run=reset_on_run,
4088
+ retry=retry,
4071
4089
  )
4072
4090
 
4073
4091
  def build_function(
mlrun/run.py CHANGED
@@ -36,6 +36,7 @@ import mlrun.common.schemas
36
36
  import mlrun.errors
37
37
  import mlrun.utils.helpers
38
38
  import mlrun_pipelines.utils
39
+ from mlrun.datastore.model_provider.model_provider import ModelProvider
39
40
  from mlrun_pipelines.common.models import RunStatuses
40
41
  from mlrun_pipelines.common.ops import format_summary_from_kfp_run, show_kfp_run
41
42
 
@@ -1152,6 +1153,22 @@ def get_dataitem(url, secrets=None, db=None) -> "DataItem":
1152
1153
  return stores.object(url=url)
1153
1154
 
1154
1155
 
1156
+ def get_model_provider(
1157
+ url,
1158
+ secrets=None,
1159
+ db=None,
1160
+ default_invoke_kwargs: Optional[dict] = None,
1161
+ raise_missing_schema_exception=True,
1162
+ ) -> ModelProvider:
1163
+ """get mlrun dataitem object (from path/url)"""
1164
+ store_manager.set(secrets, db=db)
1165
+ return store_manager.model_provider_object(
1166
+ url=url,
1167
+ default_invoke_kwargs=default_invoke_kwargs,
1168
+ raise_missing_schema_exception=raise_missing_schema_exception,
1169
+ )
1170
+
1171
+
1155
1172
  def download_object(url, target, secrets=None):
1156
1173
  """download mlrun dataitem (from path/url to target path)"""
1157
1174
  stores = store_manager.set(secrets)
@@ -148,6 +148,12 @@ class RuntimeKinds:
148
148
  "",
149
149
  ]
150
150
 
151
+ @staticmethod
152
+ def retriable_runtimes():
153
+ return [
154
+ RuntimeKinds.job,
155
+ ]
156
+
151
157
  @staticmethod
152
158
  def nuclio_runtimes():
153
159
  return [
mlrun/runtimes/base.py CHANGED
@@ -33,6 +33,7 @@ import mlrun.launcher.factory
33
33
  import mlrun.utils.helpers
34
34
  import mlrun.utils.notifications
35
35
  import mlrun.utils.regex
36
+ from mlrun.common.runtimes.constants import RunStates
36
37
  from mlrun.model import (
37
38
  BaseMetadata,
38
39
  HyperParamOptions,
@@ -319,6 +320,7 @@ class BaseRuntime(ModelObj):
319
320
  state_thresholds: Optional[dict[str, int]] = None,
320
321
  reset_on_run: Optional[bool] = None,
321
322
  output_path: Optional[str] = "",
323
+ retry: Optional[Union[mlrun.model.Retry, dict]] = None,
322
324
  **launcher_kwargs,
323
325
  ) -> RunObject:
324
326
  """
@@ -377,6 +379,7 @@ class BaseRuntime(ModelObj):
377
379
  This ensures latest code changes are executed. This argument must be used in
378
380
  conjunction with the local=True argument.
379
381
  :param output_path: Default artifact output path.
382
+ :param retry: Retry configuration for the run, can be a dict or an instance of mlrun.model.Retry.
380
383
  :return: Run context object (RunObject) with run metadata, results and status
381
384
  """
382
385
  if artifact_path or out_path:
@@ -414,6 +417,7 @@ class BaseRuntime(ModelObj):
414
417
  returns=returns,
415
418
  state_thresholds=state_thresholds,
416
419
  reset_on_run=reset_on_run,
420
+ retry=retry,
417
421
  )
418
422
 
419
423
  def _get_db_run(
@@ -570,12 +574,27 @@ class BaseRuntime(ModelObj):
570
574
  updates = None
571
575
  last_state = get_in(resp, "status.state", "")
572
576
  kind = get_in(resp, "metadata.labels.kind", "")
573
- if last_state == "error" or err:
577
+ if last_state in RunStates.error_states() or err:
578
+ new_state = RunStates.error
579
+ status_text = None
580
+ max_retries = get_in(resp, "spec.retry.count", 0)
581
+ retry_count = get_in(resp, "status.retry_count", 0) or 0
582
+ attempts = retry_count + 1
583
+ if max_retries:
584
+ if retry_count < max_retries:
585
+ new_state = RunStates.pending_retry
586
+ status_text = f"Run failed attempt {attempts} of {max_retries + 1}"
587
+ elif retry_count >= max_retries:
588
+ status_text = f"Run failed after {attempts} attempts"
589
+
574
590
  updates = {
575
591
  "status.last_update": now_date().isoformat(),
576
- "status.state": "error",
592
+ "status.state": new_state,
577
593
  }
578
- update_in(resp, "status.state", "error")
594
+ update_in(resp, "status.state", new_state)
595
+ if status_text:
596
+ updates["status.status_text"] = status_text
597
+ update_in(resp, "status.status_text", status_text)
579
598
  if err:
580
599
  update_in(resp, "status.error", err_to_str(err))
581
600
  err = get_in(resp, "status.error")
@@ -584,9 +603,8 @@ class BaseRuntime(ModelObj):
584
603
 
585
604
  elif (
586
605
  not was_none
587
- and last_state != mlrun.common.runtimes.constants.RunStates.completed
588
- and last_state
589
- not in mlrun.common.runtimes.constants.RunStates.error_and_abortion_states()
606
+ and last_state != RunStates.completed
607
+ and last_state not in RunStates.error_and_abortion_states()
590
608
  ):
591
609
  try:
592
610
  runtime_cls = mlrun.runtimes.get_runtime_class(kind)
mlrun/runtimes/daskjob.py CHANGED
@@ -505,6 +505,7 @@ class DaskCluster(KubejobRuntime):
505
505
  state_thresholds: Optional[dict[str, int]] = None,
506
506
  reset_on_run: Optional[bool] = None,
507
507
  output_path: Optional[str] = "",
508
+ retry: Optional[Union[mlrun.model.Retry, dict]] = None,
508
509
  **launcher_kwargs,
509
510
  ) -> RunObject:
510
511
  if state_thresholds:
@@ -233,6 +233,7 @@ def run_mlrun_databricks_job(context,task_parameters: dict, **kwargs):
233
233
  state_thresholds: Optional[dict[str, int]] = None,
234
234
  reset_on_run: Optional[bool] = None,
235
235
  output_path: Optional[str] = "",
236
+ retry: Optional[Union[mlrun.model.Retry, dict]] = None,
236
237
  **launcher_kwargs,
237
238
  ) -> RunObject:
238
239
  if local:
mlrun/runtimes/local.py CHANGED
@@ -34,6 +34,7 @@ from nuclio import Event
34
34
 
35
35
  import mlrun
36
36
  import mlrun.common.constants as mlrun_constants
37
+ import mlrun.common.runtimes.constants
37
38
  from mlrun.lists import RunList
38
39
 
39
40
  from ..errors import err_to_str
@@ -315,15 +316,9 @@ class LocalRuntime(BaseRuntime, ParallelRunner):
315
316
  return context.to_dict()
316
317
 
317
318
  # if RunError was raised it means that the error was raised as part of running the function
318
- # ( meaning the state was already updated to error ) therefore we just re-raise the error
319
319
  except RunError as err:
320
320
  raise err
321
- # this exception handling is for the case where we fail on pre-loading or post-running the function
322
- # and the state was not updated to error yet, therefore we update the state to error and raise as RunError
323
321
  except Exception as exc:
324
- # set_state here is mainly for sanity, as we will raise RunError which is expected to be handled
325
- # by the caller and will set the state to error ( in `update_run_state` )
326
- context.set_state(error=err_to_str(exc), commit=True)
327
322
  logger.error(f"Run error, {traceback.format_exc()}")
328
323
  raise RunError(
329
324
  "Failed on pre-loading / post-running of the function"
mlrun/serving/server.py CHANGED
@@ -395,7 +395,6 @@ def add_monitoring_general_steps(
395
395
  monitor_flow_step = graph.add_step(
396
396
  "mlrun.serving.system_steps.BackgroundTaskStatus",
397
397
  "background_task_status_step",
398
- context=context,
399
398
  model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
400
399
  )
401
400
  graph.add_step(
@@ -410,7 +409,6 @@ def add_monitoring_general_steps(
410
409
  "monitoring_pre_processor_step",
411
410
  after="filter_none",
412
411
  full_event=True,
413
- context=context,
414
412
  model_endpoint_creation_strategy=mlrun.common.schemas.ModelEndpointCreationStrategy.SKIP,
415
413
  )
416
414
  # flatten the events
mlrun/serving/states.py CHANGED
@@ -35,6 +35,7 @@ from storey import ParallelExecutionMechanisms
35
35
  import mlrun
36
36
  import mlrun.artifacts
37
37
  import mlrun.common.schemas as schemas
38
+ from mlrun.artifacts.llm_prompt import LLMPromptArtifact
38
39
  from mlrun.artifacts.model import ModelArtifact
39
40
  from mlrun.datastore.datastore_profile import (
40
41
  DatastoreProfileKafkaSource,
@@ -42,6 +43,7 @@ from mlrun.datastore.datastore_profile import (
42
43
  DatastoreProfileV3io,
43
44
  datastore_profile_read,
44
45
  )
46
+ from mlrun.datastore.model_provider.model_provider import ModelProvider
45
47
  from mlrun.datastore.store_resources import get_store_resource
46
48
  from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
47
49
  from mlrun.utils import logger
@@ -1019,6 +1021,9 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1019
1021
  if artifact_uri is not None and not isinstance(artifact_uri, str):
1020
1022
  raise MLRunInvalidArgumentError("'artifact_uri' argument must be a string")
1021
1023
  self.artifact_uri = artifact_uri
1024
+ self.invocation_artifact: Optional[LLMPromptArtifact] = None
1025
+ self.model_artifact: Optional[ModelArtifact] = None
1026
+ self.model_provider: Optional[ModelProvider] = None
1022
1027
 
1023
1028
  def __init_subclass__(cls):
1024
1029
  super().__init_subclass__()
@@ -1030,12 +1035,27 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1030
1035
 
1031
1036
  def load(self) -> None:
1032
1037
  """Override to load model if needed."""
1033
- pass
1038
+ self._load_artifacts()
1039
+ if self.model_artifact:
1040
+ self.model_provider = mlrun.get_model_provider(
1041
+ url=self.model_artifact.model_url,
1042
+ default_invoke_kwargs=self.model_artifact.default_config,
1043
+ raise_missing_schema_exception=False,
1044
+ )
1045
+
1046
+ def _load_artifacts(self) -> None:
1047
+ artifact = self._get_artifact_object()
1048
+ if isinstance(artifact, LLMPromptArtifact):
1049
+ self.invocation_artifact = artifact
1050
+ self.model_artifact = self.invocation_artifact.model_artifact
1051
+ else:
1052
+ self.model_artifact = artifact
1034
1053
 
1035
- def _get_artifact_object(self) -> Union[ModelArtifact, None]:
1054
+ def _get_artifact_object(self) -> Union[ModelArtifact, LLMPromptArtifact, None]:
1036
1055
  if self.artifact_uri:
1037
1056
  if mlrun.datastore.is_store_uri(self.artifact_uri):
1038
- return get_store_resource(self.artifact_uri)
1057
+ artifact, _ = mlrun.store_manager.get_store_artifact(self.artifact_uri)
1058
+ return artifact
1039
1059
  else:
1040
1060
  raise ValueError(
1041
1061
  "Could not get artifact, 'artifact_uri' must be a valid artifact store URI"
@@ -1058,7 +1078,7 @@ class Model(storey.ParallelExecutionRunnable, ModelObj):
1058
1078
  return self.predict(body)
1059
1079
 
1060
1080
  async def run_async(self, body: Any, path: str) -> Any:
1061
- return self.predict(body)
1081
+ return await self.predict_async(body)
1062
1082
 
1063
1083
  def get_local_model_path(self, suffix="") -> (str, dict):
1064
1084
  """
@@ -1223,7 +1243,9 @@ class ModelRunnerStep(MonitoredStep):
1223
1243
  endpoint_name: str,
1224
1244
  model_class: Union[str, Model],
1225
1245
  execution_mechanism: Union[str, ParallelExecutionMechanisms],
1226
- model_artifact: Optional[Union[str, mlrun.artifacts.ModelArtifact]] = None,
1246
+ model_artifact: Optional[
1247
+ Union[str, mlrun.artifacts.ModelArtifact, mlrun.artifacts.LLMPromptArtifact]
1248
+ ] = None,
1227
1249
  labels: Optional[Union[list[str], dict[str, str]]] = None,
1228
1250
  creation_strategy: Optional[
1229
1251
  schemas.ModelEndpointCreationStrategy
@@ -1407,6 +1429,9 @@ class ModelRunnerStep(MonitoredStep):
1407
1429
  return monitoring_data
1408
1430
 
1409
1431
  def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
1432
+ if not self._is_local_function(context):
1433
+ # skip init of non local functions
1434
+ return
1410
1435
  model_selector = self.class_args.get("model_selector")
1411
1436
  execution_mechanism_by_model_name = self.class_args.get(
1412
1437
  schemas.ModelRunnerStepData.MODEL_TO_EXECUTION_MECHANISM
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import random
16
- from copy import copy, deepcopy
16
+ from copy import deepcopy
17
17
  from datetime import timedelta
18
18
  from typing import Any, Optional, Union
19
19
 
@@ -32,11 +32,12 @@ class MonitoringPreProcessor(storey.MapClass):
32
32
 
33
33
  def __init__(
34
34
  self,
35
- context,
36
35
  **kwargs,
37
36
  ):
38
37
  super().__init__(**kwargs)
39
- self.context = copy(context)
38
+ self.server: mlrun.serving.GraphServer = (
39
+ getattr(self.context, "server", None) if self.context else None
40
+ )
40
41
 
41
42
  def reconstruct_request_resp_fields(
42
43
  self, event, model: str, model_monitoring_data: dict
@@ -148,9 +149,8 @@ class MonitoringPreProcessor(storey.MapClass):
148
149
 
149
150
  def do(self, event):
150
151
  monitoring_event_list = []
151
- server: mlrun.serving.GraphServer = getattr(self.context, "server", None)
152
152
  model_runner_name = event._metadata.get("model_runner_name", "")
153
- step = server.graph.steps[model_runner_name] if server else {}
153
+ step = self.server.graph.steps[model_runner_name] if self.server else {}
154
154
  monitoring_data = step.monitoring_data
155
155
  logger.debug(
156
156
  "monitoring preprocessor started",
@@ -184,8 +184,8 @@ class MonitoringPreProcessor(storey.MapClass):
184
184
  mm_schemas.StreamProcessingEvent.LABELS: monitoring_data[
185
185
  model
186
186
  ].get(mlrun.common.schemas.MonitoringData.OUTPUTS),
187
- mm_schemas.StreamProcessingEvent.FUNCTION_URI: server.function_uri
188
- if server
187
+ mm_schemas.StreamProcessingEvent.FUNCTION_URI: self.server.function_uri
188
+ if self.server
189
189
  else None,
190
190
  mm_schemas.StreamProcessingEvent.REQUEST: request,
191
191
  mm_schemas.StreamProcessingEvent.RESPONSE: resp,
@@ -226,8 +226,8 @@ class MonitoringPreProcessor(storey.MapClass):
226
226
  mm_schemas.StreamProcessingEvent.LABELS: monitoring_data[model].get(
227
227
  mlrun.common.schemas.MonitoringData.OUTPUTS
228
228
  ),
229
- mm_schemas.StreamProcessingEvent.FUNCTION_URI: server.function_uri
230
- if server
229
+ mm_schemas.StreamProcessingEvent.FUNCTION_URI: self.server.function_uri
230
+ if self.server
231
231
  else None,
232
232
  mm_schemas.StreamProcessingEvent.REQUEST: request,
233
233
  mm_schemas.StreamProcessingEvent.RESPONSE: resp,
@@ -253,19 +253,17 @@ class BackgroundTaskStatus(storey.MapClass):
253
253
  creation failed or in progress
254
254
  """
255
255
 
256
- def __init__(self, context, **kwargs):
257
- self.context = copy(context)
258
- self.server: mlrun.serving.GraphServer = getattr(self.context, "server", None)
256
+ def __init__(self, **kwargs):
257
+ super().__init__(**kwargs)
258
+ self.server: mlrun.serving.GraphServer = (
259
+ getattr(self.context, "server", None) if self.context else None
260
+ )
259
261
  self._background_task_check_timestamp = None
260
262
  self._background_task_state = mlrun.common.schemas.BackgroundTaskState.running
261
- super().__init__(**kwargs)
262
263
 
263
264
  def do(self, event):
264
- if (self.context and self.context.is_mock) or self.context is None:
265
- return event
266
265
  if self.server is None:
267
266
  return None
268
-
269
267
  if (
270
268
  self._background_task_state
271
269
  == mlrun.common.schemas.BackgroundTaskState.running
@@ -283,19 +281,14 @@ class BackgroundTaskStatus(storey.MapClass):
283
281
  self._background_task_check_timestamp = mlrun.utils.now_date()
284
282
  self._log_background_task_state(background_task.status.state)
285
283
  self._background_task_state = background_task.status.state
286
- if (
287
- background_task.status.state
288
- == mlrun.common.schemas.BackgroundTaskState.succeeded
289
- ):
290
- return event
291
- else:
292
- return None
293
- elif (
284
+
285
+ if (
294
286
  self._background_task_state
295
- == mlrun.common.schemas.BackgroundTaskState.failed
287
+ == mlrun.common.schemas.BackgroundTaskState.succeeded
296
288
  ):
289
+ return event
290
+ else:
297
291
  return None
298
- return event
299
292
 
300
293
  def _log_background_task_state(
301
294
  self, background_task_state: mlrun.common.schemas.BackgroundTaskState
@@ -382,9 +375,10 @@ class SamplingStep(storey.MapClass):
382
375
 
383
376
 
384
377
  class MockStreamPusher(storey.MapClass):
385
- def __init__(self, context, output_stream=None, **kwargs):
378
+ def __init__(self, output_stream=None, **kwargs):
386
379
  super().__init__(**kwargs)
387
- self.output_stream = output_stream or context.stream.output_stream
380
+ stream = self.context.stream if self.context else None
381
+ self.output_stream = output_stream or stream.output_stream
388
382
 
389
383
  def do(self, event):
390
384
  self.output_stream.push(