mlrun 1.10.0rc2__py3-none-any.whl → 1.10.0rc4__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in the public registry to which they were published. It is provided for informational purposes only.

Potentially problematic release.


This version of mlrun might be problematic. Click here for more details.

Files changed (67) hide show
  1. mlrun/__init__.py +2 -2
  2. mlrun/__main__.py +2 -2
  3. mlrun/artifacts/__init__.py +1 -0
  4. mlrun/artifacts/base.py +20 -8
  5. mlrun/artifacts/dataset.py +1 -1
  6. mlrun/artifacts/document.py +1 -1
  7. mlrun/artifacts/helpers.py +40 -0
  8. mlrun/artifacts/llm_prompt.py +165 -0
  9. mlrun/artifacts/manager.py +13 -1
  10. mlrun/artifacts/model.py +92 -12
  11. mlrun/artifacts/plots.py +2 -2
  12. mlrun/common/formatters/artifact.py +1 -0
  13. mlrun/common/runtimes/constants.py +0 -21
  14. mlrun/common/schemas/artifact.py +12 -12
  15. mlrun/common/schemas/pipeline.py +0 -16
  16. mlrun/common/schemas/project.py +0 -17
  17. mlrun/common/schemas/runs.py +0 -17
  18. mlrun/config.py +3 -3
  19. mlrun/datastore/base.py +2 -2
  20. mlrun/datastore/datastore.py +1 -1
  21. mlrun/datastore/datastore_profile.py +3 -11
  22. mlrun/datastore/redis.py +2 -3
  23. mlrun/datastore/sources.py +0 -9
  24. mlrun/datastore/store_resources.py +3 -3
  25. mlrun/datastore/storeytargets.py +2 -5
  26. mlrun/datastore/targets.py +7 -57
  27. mlrun/datastore/utils.py +1 -11
  28. mlrun/db/base.py +7 -6
  29. mlrun/db/httpdb.py +72 -66
  30. mlrun/db/nopdb.py +1 -0
  31. mlrun/errors.py +22 -1
  32. mlrun/execution.py +87 -1
  33. mlrun/feature_store/common.py +5 -5
  34. mlrun/feature_store/feature_set.py +10 -6
  35. mlrun/feature_store/feature_vector.py +8 -6
  36. mlrun/launcher/base.py +1 -1
  37. mlrun/lists.py +1 -1
  38. mlrun/model.py +0 -5
  39. mlrun/model_monitoring/__init__.py +0 -1
  40. mlrun/model_monitoring/api.py +0 -44
  41. mlrun/model_monitoring/applications/evidently/base.py +3 -41
  42. mlrun/model_monitoring/controller.py +1 -1
  43. mlrun/model_monitoring/writer.py +1 -4
  44. mlrun/projects/operations.py +3 -3
  45. mlrun/projects/project.py +260 -23
  46. mlrun/run.py +9 -27
  47. mlrun/runtimes/base.py +6 -6
  48. mlrun/runtimes/kubejob.py +2 -2
  49. mlrun/runtimes/nuclio/function.py +3 -3
  50. mlrun/runtimes/nuclio/serving.py +13 -23
  51. mlrun/runtimes/remotesparkjob.py +6 -0
  52. mlrun/runtimes/sparkjob/spark3job.py +6 -0
  53. mlrun/serving/__init__.py +5 -1
  54. mlrun/serving/server.py +39 -3
  55. mlrun/serving/states.py +101 -4
  56. mlrun/serving/v2_serving.py +1 -1
  57. mlrun/utils/helpers.py +66 -9
  58. mlrun/utils/notifications/notification/slack.py +5 -1
  59. mlrun/utils/notifications/notification_pusher.py +2 -1
  60. mlrun/utils/version/version.json +2 -2
  61. {mlrun-1.10.0rc2.dist-info → mlrun-1.10.0rc4.dist-info}/METADATA +22 -10
  62. {mlrun-1.10.0rc2.dist-info → mlrun-1.10.0rc4.dist-info}/RECORD +66 -65
  63. {mlrun-1.10.0rc2.dist-info → mlrun-1.10.0rc4.dist-info}/WHEEL +1 -1
  64. mlrun/model_monitoring/tracking_policy.py +0 -124
  65. {mlrun-1.10.0rc2.dist-info → mlrun-1.10.0rc4.dist-info}/entry_points.txt +0 -0
  66. {mlrun-1.10.0rc2.dist-info → mlrun-1.10.0rc4.dist-info}/licenses/LICENSE +0 -0
  67. {mlrun-1.10.0rc2.dist-info → mlrun-1.10.0rc4.dist-info}/top_level.txt +0 -0
@@ -11,12 +11,12 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
14
+ import copy
15
15
  import json
16
16
  import os
17
17
  import warnings
18
18
  from copy import deepcopy
19
- from typing import TYPE_CHECKING, Optional, Union
19
+ from typing import Optional, Union
20
20
 
21
21
  import nuclio
22
22
  from nuclio import KafkaTrigger
@@ -27,7 +27,11 @@ from mlrun.datastore import get_kafka_brokers_from_dict, parse_kafka_url
27
27
  from mlrun.model import ObjectList
28
28
  from mlrun.runtimes.function_reference import FunctionReference
29
29
  from mlrun.secrets import SecretsStore
30
- from mlrun.serving.server import GraphServer, create_graph_server
30
+ from mlrun.serving.server import (
31
+ GraphServer,
32
+ add_system_steps_to_graph,
33
+ create_graph_server,
34
+ )
31
35
  from mlrun.serving.states import (
32
36
  RootFlowStep,
33
37
  RouterStep,
@@ -43,10 +47,6 @@ from .function import NuclioSpec, RemoteRuntime, min_nuclio_versions
43
47
 
44
48
  serving_subkind = "serving_v2"
45
49
 
46
- if TYPE_CHECKING:
47
- # remove this block in 1.9.0
48
- from mlrun.model_monitoring import TrackingPolicy
49
-
50
50
 
51
51
  def new_v2_model_server(
52
52
  name,
@@ -95,7 +95,6 @@ class ServingSpec(NuclioSpec):
95
95
  "default_class",
96
96
  "secret_sources",
97
97
  "track_models",
98
- "tracking_policy",
99
98
  ]
100
99
 
101
100
  def __init__(
@@ -132,7 +131,6 @@ class ServingSpec(NuclioSpec):
132
131
  graph_initializer=None,
133
132
  error_stream=None,
134
133
  track_models=None,
135
- tracking_policy=None,
136
134
  secret_sources=None,
137
135
  default_content_type=None,
138
136
  node_name=None,
@@ -207,7 +205,6 @@ class ServingSpec(NuclioSpec):
207
205
  self.graph_initializer = graph_initializer
208
206
  self.error_stream = error_stream
209
207
  self.track_models = track_models
210
- self.tracking_policy = tracking_policy
211
208
  self.secret_sources = secret_sources or []
212
209
  self.default_content_type = default_content_type
213
210
  self.model_endpoint_creation_task_name = model_endpoint_creation_task_name
@@ -314,7 +311,6 @@ class ServingRuntime(RemoteRuntime):
314
311
  batch: Optional[int] = None,
315
312
  sampling_percentage: float = 100,
316
313
  stream_args: Optional[dict] = None,
317
- tracking_policy: Optional[Union["TrackingPolicy", dict]] = None,
318
314
  enable_tracking: bool = True,
319
315
  ) -> None:
320
316
  """Apply on your serving function to monitor a deployed model, including real-time dashboards to detect drift
@@ -361,20 +357,12 @@ class ServingRuntime(RemoteRuntime):
361
357
  if batch:
362
358
  warnings.warn(
363
359
  "The `batch` size parameter was deprecated in version 1.8.0 and is no longer used. "
364
- "It will be removed in 1.10.",
365
- # TODO: Remove this in 1.10
360
+ "It will be removed in 1.11.",
361
+ # TODO: Remove this in 1.11
366
362
  FutureWarning,
367
363
  )
368
364
  if stream_args:
369
365
  self.spec.parameters["stream_args"] = stream_args
370
- if tracking_policy is not None:
371
- warnings.warn(
372
- "The `tracking_policy` argument is deprecated from version 1.7.0 "
373
- "and has no effect. It will be removed in 1.9.0.\n"
374
- "To set the desired model monitoring time window and schedule, use "
375
- "the `base_period` argument in `project.enable_model_monitoring()`.",
376
- FutureWarning,
377
- )
378
366
 
379
367
  def add_model(
380
368
  self,
@@ -719,7 +707,6 @@ class ServingRuntime(RemoteRuntime):
719
707
  "graph_initializer": self.spec.graph_initializer,
720
708
  "error_stream": self.spec.error_stream,
721
709
  "track_models": self.spec.track_models,
722
- "tracking_policy": None,
723
710
  "default_content_type": self.spec.default_content_type,
724
711
  "model_endpoint_creation_task_name": self.spec.model_endpoint_creation_task_name,
725
712
  }
@@ -761,10 +748,13 @@ class ServingRuntime(RemoteRuntime):
761
748
  set_paths(workdir)
762
749
  os.chdir(workdir)
763
750
 
751
+ system_graph = None
752
+ if isinstance(self.spec.graph, RootFlowStep):
753
+ system_graph = add_system_steps_to_graph(copy.deepcopy(self.spec.graph))
764
754
  server = create_graph_server(
765
755
  parameters=self.spec.parameters,
766
756
  load_mode=self.spec.load_mode,
767
- graph=self.spec.graph,
757
+ graph=system_graph or self.spec.graph,
768
758
  verbose=self.verbose,
769
759
  current_function=current_function,
770
760
  graph_initializer=self.spec.graph_initializer,
@@ -103,6 +103,12 @@ class RemoteSparkRuntime(KubejobRuntime):
103
103
 
104
104
  @classmethod
105
105
  def deploy_default_image(cls):
106
+ if not mlrun.get_current_project(silent=True):
107
+ raise mlrun.errors.MLRunMissingProjectError(
108
+ "An active project is required to run deploy_default_image(). "
109
+ "This can be set by calling get_or_create_project(), load_project(), or new_project()."
110
+ )
111
+
106
112
  sj = mlrun.new_function(
107
113
  kind="remote-spark", name="remote-spark-default-image-deploy-temp"
108
114
  )
@@ -804,6 +804,12 @@ class Spark3Runtime(KubejobRuntime):
804
804
 
805
805
  @classmethod
806
806
  def deploy_default_image(cls, with_gpu=False):
807
+ if not mlrun.get_current_project(silent=True):
808
+ raise mlrun.errors.MLRunMissingProjectError(
809
+ "An active project is required to run deploy_default_image(). "
810
+ "This can be set by calling get_or_create_project()."
811
+ )
812
+
807
813
  sj = mlrun.new_function(kind=cls.kind, name="spark-default-image-deploy-temp")
808
814
  sj.spec.build.image = cls._get_default_deployed_mlrun_image_name(with_gpu)
809
815
 
mlrun/serving/__init__.py CHANGED
@@ -30,7 +30,11 @@ __all__ = [
30
30
  ]
31
31
 
32
32
  from .routers import ModelRouter, VotingEnsemble # noqa
33
- from .server import GraphContext, GraphServer, create_graph_server # noqa
33
+ from .server import (
34
+ GraphContext,
35
+ GraphServer,
36
+ create_graph_server,
37
+ ) # noqa
34
38
  from .states import (
35
39
  ErrorStep,
36
40
  QueueStep,
mlrun/serving/server.py CHANGED
@@ -15,6 +15,7 @@
15
15
  __all__ = ["GraphServer", "create_graph_server", "GraphContext", "MockEvent"]
16
16
 
17
17
  import asyncio
18
+ import copy
18
19
  import json
19
20
  import os
20
21
  import socket
@@ -71,7 +72,7 @@ class _StreamContext:
71
72
  if (enabled or log_stream) and function_uri:
72
73
  self.enabled = True
73
74
  project, _, _, _ = parse_versioned_object_uri(
74
- function_uri, config.default_project
75
+ function_uri, config.active_project
75
76
  )
76
77
 
77
78
  stream_args = parameters.get("stream_args", {})
@@ -108,7 +109,6 @@ class GraphServer(ModelObj):
108
109
  graph_initializer=None,
109
110
  error_stream=None,
110
111
  track_models=None,
111
- tracking_policy=None,
112
112
  secret_sources=None,
113
113
  default_content_type=None,
114
114
  function_name=None,
@@ -129,7 +129,6 @@ class GraphServer(ModelObj):
129
129
  self.graph_initializer = graph_initializer
130
130
  self.error_stream = error_stream
131
131
  self.track_models = track_models
132
- self.tracking_policy = tracking_policy
133
132
  self._error_stream_object = None
134
133
  self.secret_sources = secret_sources
135
134
  self._secrets = SecretsStore.from_list(secret_sources)
@@ -330,12 +329,49 @@ class GraphServer(ModelObj):
330
329
  return self.graph.wait_for_completion()
331
330
 
332
331
 
332
+ def add_system_steps_to_graph(graph: RootFlowStep):
333
+ model_runner_raisers = {}
334
+ steps = list(graph.steps.values())
335
+ for step in steps:
336
+ if (
337
+ isinstance(step, mlrun.serving.states.ModelRunnerStep)
338
+ and step.raise_exception
339
+ ):
340
+ error_step = graph.add_step(
341
+ class_name="mlrun.serving.states.ModelRunnerErrorRaiser",
342
+ name=f"{step.name}_error_raise",
343
+ after=step.name,
344
+ full_event=True,
345
+ raise_exception=step.raise_exception,
346
+ models_names=list(step.class_args["models"].keys()),
347
+ )
348
+ if step.responder:
349
+ step.responder = False
350
+ error_step.respond()
351
+ model_runner_raisers[step.name] = error_step.name
352
+ error_step.on_error = step.on_error
353
+ if isinstance(step.after, list):
354
+ for i in range(len(step.after)):
355
+ if step.after[i] in model_runner_raisers:
356
+ step.after[i] = model_runner_raisers[step.after[i]]
357
+ else:
358
+ if step.after in model_runner_raisers:
359
+ step.after = model_runner_raisers[step.after]
360
+ return graph
361
+
362
+
333
363
  def v2_serving_init(context, namespace=None):
334
364
  """hook for nuclio init_context()"""
335
365
 
336
366
  context.logger.info("Initializing server from spec")
337
367
  spec = mlrun.utils.get_serving_spec()
338
368
  server = GraphServer.from_dict(spec)
369
+ if isinstance(server.graph, RootFlowStep):
370
+ server.graph = add_system_steps_to_graph(copy.deepcopy(server.graph))
371
+ context.logger.info_with(
372
+ "Server graph after adding system steps",
373
+ graph=str(server.graph.steps),
374
+ )
339
375
 
340
376
  if config.log_level.lower() == "debug":
341
377
  server.verbose = True
mlrun/serving/states.py CHANGED
@@ -32,12 +32,14 @@ import storey.utils
32
32
  import mlrun
33
33
  import mlrun.artifacts
34
34
  import mlrun.common.schemas as schemas
35
+ from mlrun.artifacts.model import ModelArtifact
35
36
  from mlrun.datastore.datastore_profile import (
36
37
  DatastoreProfileKafkaSource,
37
38
  DatastoreProfileKafkaTarget,
38
39
  DatastoreProfileV3io,
39
40
  datastore_profile_read,
40
41
  )
42
+ from mlrun.datastore.store_resources import get_store_resource
41
43
  from mlrun.datastore.storeytargets import KafkaStoreyTarget, StreamStoreyTarget
42
44
  from mlrun.utils import logger
43
45
 
@@ -47,7 +49,7 @@ from ..datastore.utils import (
47
49
  get_kafka_brokers_from_dict,
48
50
  parse_kafka_url,
49
51
  )
50
- from ..errors import MLRunInvalidArgumentError, err_to_str
52
+ from ..errors import MLRunInvalidArgumentError, ModelRunnerError, err_to_str
51
53
  from ..model import ModelObj, ObjectDict
52
54
  from ..platforms.iguazio import parse_path
53
55
  from ..utils import get_class, get_function, is_explicit_ack_supported
@@ -955,10 +957,33 @@ class RouterStep(TaskStep):
955
957
 
956
958
 
957
959
  class Model(storey.ParallelExecutionRunnable):
960
+ def __init__(
961
+ self,
962
+ name: str,
963
+ raise_exception: bool = True,
964
+ artifact_uri: Optional[str] = None,
965
+ **kwargs,
966
+ ):
967
+ super().__init__(name=name, raise_exception=raise_exception, **kwargs)
968
+ if artifact_uri is not None and not isinstance(artifact_uri, str):
969
+ raise MLRunInvalidArgumentError("artifact_uri argument must be a string")
970
+ self.artifact_uri = artifact_uri
971
+
958
972
  def load(self) -> None:
959
973
  """Override to load model if needed."""
960
974
  pass
961
975
 
976
+ def _get_artifact_object(self) -> Union[ModelArtifact, None]:
977
+ if self.artifact_uri:
978
+ if mlrun.datastore.is_store_uri(self.artifact_uri):
979
+ return get_store_resource(self.artifact_uri)
980
+ else:
981
+ raise ValueError(
982
+ "Could not get artifact, artifact_uri must be a valid artifact store URI"
983
+ )
984
+ else:
985
+ return None
986
+
962
987
  def init(self):
963
988
  self.load()
964
989
 
@@ -976,6 +1001,39 @@ class Model(storey.ParallelExecutionRunnable):
976
1001
  async def run_async(self, body: Any, path: str) -> Any:
977
1002
  return self.predict(body)
978
1003
 
1004
+ def get_local_model_path(self, suffix="") -> (str, dict):
1005
+ """get local model file(s) and extra data items by using artifact
1006
+ If the model file is stored in remote cloud storage, download it to the local file system
1007
+
1008
+ Examples
1009
+ --------
1010
+ ::
1011
+
1012
+ def load(self):
1013
+ model_file, extra_data = self.get_local_model_path(suffix=".pkl")
1014
+ self.model = load(open(model_file, "rb"))
1015
+ categories = extra_data["categories"].as_df()
1016
+
1017
+ Parameters
1018
+ ----------
1019
+ suffix : str
1020
+ optional, model file suffix (when the model_path is a directory)
1021
+
1022
+ Returns
1023
+ -------
1024
+ str
1025
+ (local) model file
1026
+ dict
1027
+ extra dataitems dictionary
1028
+ """
1029
+ artifact = self._get_artifact_object()
1030
+ if artifact:
1031
+ model_file, _, extra_dataitems = mlrun.artifacts.get_model(
1032
+ suffix=suffix, model_dir=artifact
1033
+ )
1034
+ return model_file, extra_dataitems
1035
+ return None, None
1036
+
979
1037
 
980
1038
  class ModelSelector:
981
1039
  """Used to select which models to run on each event."""
@@ -1022,14 +1080,18 @@ class ModelRunnerStep(TaskStep, StepToDict):
1022
1080
 
1023
1081
  :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
1024
1082
  event. Optional. If not passed, all models will be run.
1083
+ :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised
1084
+ an error. If False, the error will appear in the output event.
1025
1085
  """
1026
1086
 
1027
1087
  kind = "model_runner"
1088
+ _dict_fields = TaskStep._dict_fields + ["raise_exception"]
1028
1089
 
1029
1090
  def __init__(
1030
1091
  self,
1031
1092
  *args,
1032
1093
  model_selector: Optional[Union[str, ModelSelector]] = None,
1094
+ raise_exception: bool = True,
1033
1095
  **kwargs,
1034
1096
  ):
1035
1097
  super().__init__(
@@ -1038,6 +1100,7 @@ class ModelRunnerStep(TaskStep, StepToDict):
1038
1100
  class_args=dict(model_selector=model_selector),
1039
1101
  **kwargs,
1040
1102
  )
1103
+ self.raise_exception = raise_exception
1041
1104
 
1042
1105
  def add_model(
1043
1106
  self,
@@ -1084,6 +1147,14 @@ class ModelRunnerStep(TaskStep, StepToDict):
1084
1147
  """
1085
1148
  # TODO allow model_class as Model object as part of ML-9924
1086
1149
  model_parameters = model_parameters or {}
1150
+ model_artifact = (
1151
+ model_artifact.uri
1152
+ if isinstance(model_artifact, mlrun.artifacts.Artifact)
1153
+ else model_artifact
1154
+ )
1155
+ model_parameters["artifact_uri"] = model_parameters.get(
1156
+ "artifact_uri", model_artifact
1157
+ )
1087
1158
  if model_parameters.get("name", endpoint_name) != endpoint_name:
1088
1159
  raise mlrun.errors.MLRunInvalidArgumentError(
1089
1160
  "Inconsistent name for model added to ModelRunnerStep."
@@ -1106,9 +1177,7 @@ class ModelRunnerStep(TaskStep, StepToDict):
1106
1177
  schemas.MonitoringData.INPUT_PATH: input_path,
1107
1178
  schemas.MonitoringData.CREATION_STRATEGY: creation_strategy,
1108
1179
  schemas.MonitoringData.LABELS: labels,
1109
- schemas.MonitoringData.MODEL_PATH: model_artifact.uri
1110
- if isinstance(model_artifact, mlrun.artifacts.Artifact)
1111
- else model_artifact,
1180
+ schemas.MonitoringData.MODEL_PATH: model_artifact,
1112
1181
  }
1113
1182
  self.class_args[schemas.ModelRunnerStepData.MODELS] = models
1114
1183
  self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data
@@ -1121,7 +1190,12 @@ class ModelRunnerStep(TaskStep, StepToDict):
1121
1190
  model_objects = []
1122
1191
  for model, model_params in models.values():
1123
1192
  if not isinstance(model, Model):
1193
+ # prevent model predict from raising error
1194
+ model_params["raise_exception"] = False
1124
1195
  model = get_class(model, namespace)(**model_params)
1196
+ else:
1197
+ # prevent model predict from raising error
1198
+ model._raise_exception = False
1125
1199
  model_objects.append(model)
1126
1200
  self._async_object = ModelRunner(
1127
1201
  model_selector=model_selector,
@@ -1129,6 +1203,29 @@ class ModelRunnerStep(TaskStep, StepToDict):
1129
1203
  )
1130
1204
 
1131
1205
 
1206
+ class ModelRunnerErrorRaiser(storey.MapClass):
1207
+ def __init__(self, raise_exception: bool, models_names: list[str], **kwargs):
1208
+ super().__init__(**kwargs)
1209
+ self._raise_exception = raise_exception
1210
+ self._models_names = models_names
1211
+
1212
+ def do(self, event):
1213
+ if self._raise_exception:
1214
+ errors = {}
1215
+ should_raise = False
1216
+ if len(self._models_names) == 1:
1217
+ should_raise = event.body.get("error") is not None
1218
+ errors[self._models_names[0]] = event.body.get("error")
1219
+ else:
1220
+ for model in event.body:
1221
+ errors[model] = event.body.get(model).get("error")
1222
+ if errors[model] is not None:
1223
+ should_raise = True
1224
+ if should_raise:
1225
+ raise ModelRunnerError(models_errors=errors)
1226
+ return event
1227
+
1228
+
1132
1229
  class QueueStep(BaseStep, StepToDict):
1133
1230
  """queue step, implement an async queue or represent a stream"""
1134
1231
 
@@ -177,7 +177,7 @@ class V2ModelServer(StepToDict):
177
177
  """set real time metric (for model monitoring)"""
178
178
  self.metrics[name] = value
179
179
 
180
- def get_model(self, suffix=""):
180
+ def get_model(self, suffix="") -> (str, dict):
181
181
  """get the model file(s) and metadata from model store
182
182
 
183
183
  the method returns a path to the model file and the extra data (dict of dataitem objects)
mlrun/utils/helpers.py CHANGED
@@ -60,6 +60,7 @@ import mlrun_pipelines.common.constants
60
60
  import mlrun_pipelines.models
61
61
  import mlrun_pipelines.utils
62
62
  from mlrun.common.constants import MYSQL_MEDIUMBLOB_SIZE_BYTES
63
+ from mlrun.common.schemas import ArtifactCategories
63
64
  from mlrun.config import config
64
65
  from mlrun_pipelines.models import PipelineRun
65
66
 
@@ -96,6 +97,7 @@ class StorePrefix:
96
97
  Model = "models"
97
98
  Dataset = "datasets"
98
99
  Document = "documents"
100
+ LLMPrompt = "llm-prompts"
99
101
 
100
102
  @classmethod
101
103
  def is_artifact(cls, prefix):
@@ -107,6 +109,7 @@ class StorePrefix:
107
109
  "model": cls.Model,
108
110
  "dataset": cls.Dataset,
109
111
  "document": cls.Document,
112
+ "llm-prompt": cls.LLMPrompt,
110
113
  }
111
114
  return kind_map.get(kind, cls.Artifact)
112
115
 
@@ -119,6 +122,7 @@ class StorePrefix:
119
122
  cls.FeatureSet,
120
123
  cls.FeatureVector,
121
124
  cls.Document,
125
+ cls.LLMPrompt,
122
126
  ]
123
127
 
124
128
 
@@ -131,7 +135,16 @@ def get_artifact_target(item: dict, project=None):
131
135
  kind = item.get("kind")
132
136
  uid = item["metadata"].get("uid")
133
137
 
134
- if kind in {"dataset", "model", "artifact"} and db_key:
138
+ if (
139
+ kind
140
+ in {
141
+ ArtifactCategories.dataset,
142
+ ArtifactCategories.model,
143
+ ArtifactCategories.llm_prompt,
144
+ "artifact",
145
+ }
146
+ and db_key
147
+ ):
135
148
  target = (
136
149
  f"{DB_SCHEMA}://{StorePrefix.kind_to_prefix(kind)}/{project_str}/{db_key}"
137
150
  )
@@ -876,13 +889,18 @@ def enrich_image_url(
876
889
  client_version: Optional[str] = None,
877
890
  client_python_version: Optional[str] = None,
878
891
  ) -> str:
892
+ image_url = image_url.strip()
893
+
894
+ # Add python version tag if needed
895
+ if image_url == "python" and client_python_version:
896
+ image_url = f"python:{client_python_version}"
897
+
879
898
  client_version = _convert_python_package_version_to_image_tag(client_version)
880
899
  server_version = _convert_python_package_version_to_image_tag(
881
900
  mlrun.utils.version.Version().get()["version"]
882
901
  )
883
- image_url = image_url.strip()
884
902
  mlrun_version = config.images_tag or client_version or server_version
885
- tag = mlrun_version
903
+ tag = mlrun_version or ""
886
904
 
887
905
  # TODO: Remove condition when mlrun/mlrun-kfp image is also supported
888
906
  if "mlrun-kfp" not in image_url:
@@ -2093,22 +2111,60 @@ def join_urls(base_url: Optional[str], path: Optional[str]) -> str:
2093
2111
 
2094
2112
  class Workflow:
2095
2113
  @staticmethod
2096
- def get_workflow_steps(workflow_id: str, project: str) -> list:
2114
+ def get_workflow_steps(
2115
+ db: "mlrun.db.RunDBInterface", workflow_id: str, project: str
2116
+ ) -> list:
2097
2117
  steps = []
2098
- db = mlrun.get_run_db()
2099
2118
 
2100
2119
  def _add_run_step(_step: mlrun_pipelines.models.PipelineStep):
2120
+ # on kfp 1.8 argo sets the pod hostname differently than what we have with kfp 2.5
2121
+ # therefore, the heuristic needs to change. what we do here is first trying against 1.8 conventions
2122
+ # and if we can't find it then falling back to 2.5
2101
2123
  try:
2102
- _run = db.list_runs(
2124
+ # runner_pod = x-y-N
2125
+ _runs = db.list_runs(
2103
2126
  project=project,
2104
2127
  labels=f"{mlrun_constants.MLRunInternalLabels.runner_pod}={_step.node_name}",
2105
- )[0]
2128
+ )
2129
+ if not _runs:
2130
+ try:
2131
+ # x-y-N -> x-y, N
2132
+ node_name_initials, node_name_generated_id = (
2133
+ _step.node_name.rsplit("-", 1)
2134
+ )
2135
+
2136
+ except ValueError:
2137
+ # defensive programming, if the node name is not in the expected format
2138
+ node_name_initials = _step.node_name
2139
+ node_name_generated_id = ""
2140
+
2141
+ # compile the expected runner pod hostname as per kfp >= 2.4
2142
+ # x-y, Z, N -> runner_pod = x-y-Z-N
2143
+ runner_pod_value = "-".join(
2144
+ [
2145
+ node_name_initials,
2146
+ _step.display_name,
2147
+ node_name_generated_id,
2148
+ ]
2149
+ ).rstrip("-")
2150
+ logger.debug(
2151
+ "No run found for step, trying with different node name",
2152
+ step_node_name=runner_pod_value,
2153
+ )
2154
+ _runs = db.list_runs(
2155
+ project=project,
2156
+ labels=f"{mlrun_constants.MLRunInternalLabels.runner_pod}={runner_pod_value}",
2157
+ )
2158
+
2159
+ _run = _runs[0]
2106
2160
  except IndexError:
2161
+ logger.warning("No run found for step", step=_step.to_dict())
2107
2162
  _run = {
2108
2163
  "metadata": {
2109
2164
  "name": _step.display_name,
2110
2165
  "project": project,
2111
2166
  },
2167
+ "status": {},
2112
2168
  }
2113
2169
  _run["step_kind"] = _step.step_type
2114
2170
  if _step.skipped:
@@ -2226,8 +2282,9 @@ class Workflow:
2226
2282
  namespace=mlrun.mlconf.namespace,
2227
2283
  )
2228
2284
 
2229
- # arbitrary timeout of 5 seconds, the workflow should be done by now
2230
- kfp_run = kfp_client.wait_for_run_completion(workflow_id, 5)
2285
+ # arbitrary timeout of 30 seconds, the workflow should be done by now, however sometimes kfp takes a few
2286
+ # seconds to update the workflow status
2287
+ kfp_run = kfp_client.wait_for_run_completion(workflow_id, 30)
2231
2288
  if not kfp_run:
2232
2289
  return None
2233
2290
 
@@ -16,6 +16,7 @@ import typing
16
16
 
17
17
  import aiohttp
18
18
 
19
+ import mlrun.common.runtimes.constants as runtimes_constants
19
20
  import mlrun.common.schemas
20
21
  import mlrun.lists
21
22
  import mlrun.utils.helpers
@@ -177,7 +178,10 @@ class SlackNotification(NotificationBase):
177
178
  # Only show the URL if the run is not a function (serving or mlrun function)
178
179
  kind = run.get("step_kind")
179
180
  state = run["status"].get("state", "")
180
- if state != "skipped" and (url and not kind or kind == "run"):
181
+
182
+ if state != runtimes_constants.RunStates.skipped and (
183
+ url and not kind or kind == "run"
184
+ ):
181
185
  line = f'<{url}|*{meta.get("name")}*>'
182
186
  else:
183
187
  line = meta.get("name")
@@ -287,7 +287,8 @@ class NotificationPusher(_NotificationPusherBase):
287
287
  )
288
288
  project = run.metadata.project
289
289
  workflow_id = run.status.results.get("workflow_id", None)
290
- runs.extend(Workflow.get_workflow_steps(workflow_id, project))
290
+ db = mlrun.get_run_db()
291
+ runs.extend(Workflow.get_workflow_steps(db, workflow_id, project))
291
292
 
292
293
  message = (
293
294
  self.messages.get(run.state(), "").format(resource=resource)
@@ -1,4 +1,4 @@
1
1
  {
2
- "git_commit": "104d0f2f30c2896ede9efe0344f8cbaed8b5616c",
3
- "version": "1.10.0-rc2"
2
+ "git_commit": "aca543927ff594b8db166e423cb47001dfdf7bcc",
3
+ "version": "1.10.0-rc4"
4
4
  }