mlrun 1.10.0rc1__py3-none-any.whl → 1.10.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun has been flagged as potentially problematic; consult the package registry's advisory page for details.

Files changed (59) hide show
  1. mlrun/__init__.py +2 -2
  2. mlrun/__main__.py +15 -4
  3. mlrun/artifacts/base.py +6 -6
  4. mlrun/artifacts/dataset.py +1 -1
  5. mlrun/artifacts/document.py +1 -1
  6. mlrun/artifacts/model.py +1 -1
  7. mlrun/artifacts/plots.py +2 -2
  8. mlrun/common/constants.py +7 -0
  9. mlrun/common/runtimes/constants.py +1 -1
  10. mlrun/common/schemas/__init__.py +1 -0
  11. mlrun/common/schemas/artifact.py +1 -1
  12. mlrun/common/schemas/pipeline.py +1 -1
  13. mlrun/common/schemas/project.py +1 -1
  14. mlrun/common/schemas/runs.py +1 -1
  15. mlrun/common/schemas/serving.py +17 -0
  16. mlrun/config.py +4 -4
  17. mlrun/datastore/datastore_profile.py +7 -57
  18. mlrun/datastore/sources.py +24 -16
  19. mlrun/datastore/store_resources.py +3 -3
  20. mlrun/datastore/targets.py +5 -5
  21. mlrun/datastore/utils.py +21 -6
  22. mlrun/db/base.py +7 -7
  23. mlrun/db/httpdb.py +88 -76
  24. mlrun/db/nopdb.py +1 -1
  25. mlrun/errors.py +29 -1
  26. mlrun/execution.py +9 -0
  27. mlrun/feature_store/common.py +5 -5
  28. mlrun/feature_store/feature_set.py +10 -6
  29. mlrun/feature_store/feature_vector.py +8 -6
  30. mlrun/launcher/base.py +1 -1
  31. mlrun/launcher/client.py +1 -1
  32. mlrun/lists.py +1 -1
  33. mlrun/model_monitoring/__init__.py +0 -1
  34. mlrun/model_monitoring/api.py +0 -44
  35. mlrun/model_monitoring/applications/evidently/base.py +57 -107
  36. mlrun/model_monitoring/controller.py +27 -14
  37. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +13 -5
  38. mlrun/model_monitoring/writer.py +1 -4
  39. mlrun/projects/operations.py +3 -3
  40. mlrun/projects/project.py +114 -52
  41. mlrun/render.py +5 -9
  42. mlrun/run.py +10 -10
  43. mlrun/runtimes/base.py +7 -7
  44. mlrun/runtimes/kubejob.py +2 -2
  45. mlrun/runtimes/nuclio/function.py +3 -3
  46. mlrun/runtimes/nuclio/serving.py +13 -23
  47. mlrun/runtimes/utils.py +25 -8
  48. mlrun/serving/__init__.py +5 -1
  49. mlrun/serving/server.py +39 -3
  50. mlrun/serving/states.py +176 -10
  51. mlrun/utils/helpers.py +10 -4
  52. mlrun/utils/version/version.json +2 -2
  53. {mlrun-1.10.0rc1.dist-info → mlrun-1.10.0rc3.dist-info}/METADATA +27 -15
  54. {mlrun-1.10.0rc1.dist-info → mlrun-1.10.0rc3.dist-info}/RECORD +58 -59
  55. {mlrun-1.10.0rc1.dist-info → mlrun-1.10.0rc3.dist-info}/WHEEL +1 -1
  56. mlrun/model_monitoring/tracking_policy.py +0 -124
  57. {mlrun-1.10.0rc1.dist-info → mlrun-1.10.0rc3.dist-info}/entry_points.txt +0 -0
  58. {mlrun-1.10.0rc1.dist-info → mlrun-1.10.0rc3.dist-info}/licenses/LICENSE +0 -0
  59. {mlrun-1.10.0rc1.dist-info → mlrun-1.10.0rc3.dist-info}/top_level.txt +0 -0
@@ -11,12 +11,12 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
14
+ import copy
15
15
  import json
16
16
  import os
17
17
  import warnings
18
18
  from copy import deepcopy
19
- from typing import TYPE_CHECKING, Optional, Union
19
+ from typing import Optional, Union
20
20
 
21
21
  import nuclio
22
22
  from nuclio import KafkaTrigger
@@ -27,7 +27,11 @@ from mlrun.datastore import get_kafka_brokers_from_dict, parse_kafka_url
27
27
  from mlrun.model import ObjectList
28
28
  from mlrun.runtimes.function_reference import FunctionReference
29
29
  from mlrun.secrets import SecretsStore
30
- from mlrun.serving.server import GraphServer, create_graph_server
30
+ from mlrun.serving.server import (
31
+ GraphServer,
32
+ add_system_steps_to_graph,
33
+ create_graph_server,
34
+ )
31
35
  from mlrun.serving.states import (
32
36
  RootFlowStep,
33
37
  RouterStep,
@@ -43,10 +47,6 @@ from .function import NuclioSpec, RemoteRuntime, min_nuclio_versions
43
47
 
44
48
  serving_subkind = "serving_v2"
45
49
 
46
- if TYPE_CHECKING:
47
- # remove this block in 1.9.0
48
- from mlrun.model_monitoring import TrackingPolicy
49
-
50
50
 
51
51
  def new_v2_model_server(
52
52
  name,
@@ -95,7 +95,6 @@ class ServingSpec(NuclioSpec):
95
95
  "default_class",
96
96
  "secret_sources",
97
97
  "track_models",
98
- "tracking_policy",
99
98
  ]
100
99
 
101
100
  def __init__(
@@ -132,7 +131,6 @@ class ServingSpec(NuclioSpec):
132
131
  graph_initializer=None,
133
132
  error_stream=None,
134
133
  track_models=None,
135
- tracking_policy=None,
136
134
  secret_sources=None,
137
135
  default_content_type=None,
138
136
  node_name=None,
@@ -207,7 +205,6 @@ class ServingSpec(NuclioSpec):
207
205
  self.graph_initializer = graph_initializer
208
206
  self.error_stream = error_stream
209
207
  self.track_models = track_models
210
- self.tracking_policy = tracking_policy
211
208
  self.secret_sources = secret_sources or []
212
209
  self.default_content_type = default_content_type
213
210
  self.model_endpoint_creation_task_name = model_endpoint_creation_task_name
@@ -314,7 +311,6 @@ class ServingRuntime(RemoteRuntime):
314
311
  batch: Optional[int] = None,
315
312
  sampling_percentage: float = 100,
316
313
  stream_args: Optional[dict] = None,
317
- tracking_policy: Optional[Union["TrackingPolicy", dict]] = None,
318
314
  enable_tracking: bool = True,
319
315
  ) -> None:
320
316
  """Apply on your serving function to monitor a deployed model, including real-time dashboards to detect drift
@@ -361,20 +357,12 @@ class ServingRuntime(RemoteRuntime):
361
357
  if batch:
362
358
  warnings.warn(
363
359
  "The `batch` size parameter was deprecated in version 1.8.0 and is no longer used. "
364
- "It will be removed in 1.10.",
365
- # TODO: Remove this in 1.10
360
+ "It will be removed in 1.11.",
361
+ # TODO: Remove this in 1.11
366
362
  FutureWarning,
367
363
  )
368
364
  if stream_args:
369
365
  self.spec.parameters["stream_args"] = stream_args
370
- if tracking_policy is not None:
371
- warnings.warn(
372
- "The `tracking_policy` argument is deprecated from version 1.7.0 "
373
- "and has no effect. It will be removed in 1.9.0.\n"
374
- "To set the desired model monitoring time window and schedule, use "
375
- "the `base_period` argument in `project.enable_model_monitoring()`.",
376
- FutureWarning,
377
- )
378
366
 
379
367
  def add_model(
380
368
  self,
@@ -719,7 +707,6 @@ class ServingRuntime(RemoteRuntime):
719
707
  "graph_initializer": self.spec.graph_initializer,
720
708
  "error_stream": self.spec.error_stream,
721
709
  "track_models": self.spec.track_models,
722
- "tracking_policy": None,
723
710
  "default_content_type": self.spec.default_content_type,
724
711
  "model_endpoint_creation_task_name": self.spec.model_endpoint_creation_task_name,
725
712
  }
@@ -761,10 +748,13 @@ class ServingRuntime(RemoteRuntime):
761
748
  set_paths(workdir)
762
749
  os.chdir(workdir)
763
750
 
751
+ system_graph = None
752
+ if isinstance(self.spec.graph, RootFlowStep):
753
+ system_graph = add_system_steps_to_graph(copy.deepcopy(self.spec.graph))
764
754
  server = create_graph_server(
765
755
  parameters=self.spec.parameters,
766
756
  load_mode=self.spec.load_mode,
767
- graph=self.spec.graph,
757
+ graph=system_graph or self.spec.graph,
768
758
  verbose=self.verbose,
769
759
  current_function=current_function,
770
760
  graph_initializer=self.spec.graph_initializer,
mlrun/runtimes/utils.py CHANGED
@@ -11,6 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
+ import enum
14
15
  import getpass
15
16
  import hashlib
16
17
  import json
@@ -28,7 +29,6 @@ import mlrun.common.constants as mlrun_constants
28
29
  import mlrun.common.schemas
29
30
  import mlrun.utils.regex
30
31
  from mlrun.artifacts import TableArtifact
31
- from mlrun.common.runtimes.constants import RunLabels
32
32
  from mlrun.config import config
33
33
  from mlrun.errors import err_to_str
34
34
  from mlrun.frameworks.parallel_coordinates import gen_pcp_plot
@@ -433,18 +433,35 @@ def enrich_function_from_dict(function, function_dict):
433
433
 
434
434
  def enrich_run_labels(
435
435
  labels: dict,
436
- labels_to_enrich: Optional[list[RunLabels]] = None,
436
+ labels_to_enrich: Optional[list[mlrun_constants.MLRunInternalLabels]] = None,
437
437
  ):
438
+ """
439
+ Enrich the run labels with the internal labels and the labels enrichment extension
440
+ :param labels: The run labels dict
441
+ :param labels_to_enrich: The label keys to enrich from MLRunInternalLabels.default_run_labels_to_enrich
442
+ :return: The enriched labels dict
443
+ """
444
+ # Merge the labels with the labels enrichment extension
438
445
  labels_enrichment = {
439
- RunLabels.owner: os.environ.get("V3IO_USERNAME") or getpass.getuser(),
440
- # TODO: remove this in 1.9.0
441
- RunLabels.v3io_user: os.environ.get("V3IO_USERNAME"),
446
+ mlrun_constants.MLRunInternalLabels.owner: os.environ.get("V3IO_USERNAME")
447
+ or getpass.getuser(),
448
+ # TODO: remove this in 1.10.0
449
+ mlrun_constants.MLRunInternalLabels.v3io_user: os.environ.get("V3IO_USERNAME"),
442
450
  }
443
- labels_to_enrich = labels_to_enrich or RunLabels.all()
451
+
452
+ # Resolve which label keys to enrich
453
+ if labels_to_enrich is None:
454
+ labels_to_enrich = (
455
+ mlrun_constants.MLRunInternalLabels.default_run_labels_to_enrich()
456
+ )
457
+
458
+ # Enrich labels
444
459
  for label in labels_to_enrich:
460
+ if isinstance(label, enum.Enum):
461
+ label = label.value
445
462
  enrichment = labels_enrichment.get(label)
446
- if label.value not in labels and enrichment:
447
- labels[label.value] = enrichment
463
+ if label not in labels and enrichment:
464
+ labels[label] = enrichment
448
465
  return labels
449
466
 
450
467
 
mlrun/serving/__init__.py CHANGED
@@ -30,7 +30,11 @@ __all__ = [
30
30
  ]
31
31
 
32
32
  from .routers import ModelRouter, VotingEnsemble # noqa
33
- from .server import GraphContext, GraphServer, create_graph_server # noqa
33
+ from .server import (
34
+ GraphContext,
35
+ GraphServer,
36
+ create_graph_server,
37
+ ) # noqa
34
38
  from .states import (
35
39
  ErrorStep,
36
40
  QueueStep,
mlrun/serving/server.py CHANGED
@@ -15,6 +15,7 @@
15
15
  __all__ = ["GraphServer", "create_graph_server", "GraphContext", "MockEvent"]
16
16
 
17
17
  import asyncio
18
+ import copy
18
19
  import json
19
20
  import os
20
21
  import socket
@@ -71,7 +72,7 @@ class _StreamContext:
71
72
  if (enabled or log_stream) and function_uri:
72
73
  self.enabled = True
73
74
  project, _, _, _ = parse_versioned_object_uri(
74
- function_uri, config.default_project
75
+ function_uri, config.active_project
75
76
  )
76
77
 
77
78
  stream_args = parameters.get("stream_args", {})
@@ -108,7 +109,6 @@ class GraphServer(ModelObj):
108
109
  graph_initializer=None,
109
110
  error_stream=None,
110
111
  track_models=None,
111
- tracking_policy=None,
112
112
  secret_sources=None,
113
113
  default_content_type=None,
114
114
  function_name=None,
@@ -129,7 +129,6 @@ class GraphServer(ModelObj):
129
129
  self.graph_initializer = graph_initializer
130
130
  self.error_stream = error_stream
131
131
  self.track_models = track_models
132
- self.tracking_policy = tracking_policy
133
132
  self._error_stream_object = None
134
133
  self.secret_sources = secret_sources
135
134
  self._secrets = SecretsStore.from_list(secret_sources)
@@ -330,12 +329,49 @@ class GraphServer(ModelObj):
330
329
  return self.graph.wait_for_completion()
331
330
 
332
331
 
332
+ def add_system_steps_to_graph(graph: RootFlowStep):
333
+ model_runner_raisers = {}
334
+ steps = list(graph.steps.values())
335
+ for step in steps:
336
+ if (
337
+ isinstance(step, mlrun.serving.states.ModelRunnerStep)
338
+ and step.raise_exception
339
+ ):
340
+ error_step = graph.add_step(
341
+ class_name="mlrun.serving.states.ModelRunnerErrorRaiser",
342
+ name=f"{step.name}_error_raise",
343
+ after=step.name,
344
+ full_event=True,
345
+ raise_exception=step.raise_exception,
346
+ models_names=list(step.class_args["models"].keys()),
347
+ )
348
+ if step.responder:
349
+ step.responder = False
350
+ error_step.respond()
351
+ model_runner_raisers[step.name] = error_step.name
352
+ error_step.on_error = step.on_error
353
+ if isinstance(step.after, list):
354
+ for i in range(len(step.after)):
355
+ if step.after[i] in model_runner_raisers:
356
+ step.after[i] = model_runner_raisers[step.after[i]]
357
+ else:
358
+ if step.after in model_runner_raisers:
359
+ step.after = model_runner_raisers[step.after]
360
+ return graph
361
+
362
+
333
363
  def v2_serving_init(context, namespace=None):
334
364
  """hook for nuclio init_context()"""
335
365
 
336
366
  context.logger.info("Initializing server from spec")
337
367
  spec = mlrun.utils.get_serving_spec()
338
368
  server = GraphServer.from_dict(spec)
369
+ if isinstance(server.graph, RootFlowStep):
370
+ server.graph = add_system_steps_to_graph(copy.deepcopy(server.graph))
371
+ context.logger.info_with(
372
+ "Server graph after adding system steps",
373
+ graph=str(server.graph.steps),
374
+ )
339
375
 
340
376
  if config.log_level.lower() == "debug":
341
377
  server.verbose = True
mlrun/serving/states.py CHANGED
@@ -30,6 +30,7 @@ from typing import Any, Optional, Union, cast
30
30
  import storey.utils
31
31
 
32
32
  import mlrun
33
+ import mlrun.artifacts
33
34
  import mlrun.common.schemas as schemas
34
35
  from mlrun.datastore.datastore_profile import (
35
36
  DatastoreProfileKafkaSource,
@@ -46,7 +47,7 @@ from ..datastore.utils import (
46
47
  get_kafka_brokers_from_dict,
47
48
  parse_kafka_url,
48
49
  )
49
- from ..errors import MLRunInvalidArgumentError, err_to_str
50
+ from ..errors import MLRunInvalidArgumentError, ModelRunnerError, err_to_str
50
51
  from ..model import ModelObj, ObjectDict
51
52
  from ..platforms.iguazio import parse_path
52
53
  from ..utils import get_class, get_function, is_explicit_ack_supported
@@ -402,6 +403,9 @@ class BaseStep(ModelObj):
402
403
  class_args=class_args,
403
404
  model_endpoint_creation_strategy=model_endpoint_creation_strategy,
404
405
  )
406
+
407
+ self.verify_model_runner_step(step)
408
+
405
409
  step = parent._steps.update(name, step)
406
410
  step.set_parent(parent)
407
411
  if not hasattr(self, "steps"):
@@ -446,6 +450,36 @@ class BaseStep(ModelObj):
446
450
  def supports_termination(self):
447
451
  return False
448
452
 
453
+ def verify_model_runner_step(self, step: "ModelRunnerStep"):
454
+ """
455
+ Verify ModelRunnerStep, can be part of Flow graph and models can not repeat in graph.
456
+ :param step: ModelRunnerStep to verify
457
+ """
458
+ if not isinstance(step, ModelRunnerStep):
459
+ return
460
+
461
+ root = self
462
+ while root.parent is not None:
463
+ root = root.parent
464
+
465
+ if not isinstance(root, RootFlowStep):
466
+ raise GraphError(
467
+ "ModelRunnerStep can be added to 'Flow' topology graph only"
468
+ )
469
+ step_model_endpoints_names = list(
470
+ step.class_args[schemas.ModelRunnerStepData.MODELS].keys()
471
+ )
472
+ # Get all model_endpoints names that are in both lists
473
+ common_endpoints_names = list(
474
+ set(root.model_endpoints_names) & set(step_model_endpoints_names)
475
+ )
476
+ if common_endpoints_names:
477
+ raise GraphError(
478
+ f"The graph already contains the model endpoints named - {common_endpoints_names}."
479
+ )
480
+ else:
481
+ root.extend_model_endpoints_names(step_model_endpoints_names)
482
+
449
483
 
450
484
  class TaskStep(BaseStep):
451
485
  """task execution step, runs a class or handler"""
@@ -988,14 +1022,18 @@ class ModelRunnerStep(TaskStep, StepToDict):
988
1022
 
989
1023
  :param model_selector: ModelSelector instance whose select() method will be used to select models to run on each
990
1024
  event. Optional. If not passed, all models will be run.
1025
+ :param raise_exception: If True, an error will be raised when model selection fails or if one of the models raised
1026
+ an error. If False, the error will appear in the output event.
991
1027
  """
992
1028
 
993
1029
  kind = "model_runner"
1030
+ _dict_fields = TaskStep._dict_fields + ["raise_exception"]
994
1031
 
995
1032
  def __init__(
996
1033
  self,
997
1034
  *args,
998
1035
  model_selector: Optional[Union[str, ModelSelector]] = None,
1036
+ raise_exception: bool = True,
999
1037
  **kwargs,
1000
1038
  ):
1001
1039
  super().__init__(
@@ -1004,27 +1042,96 @@ class ModelRunnerStep(TaskStep, StepToDict):
1004
1042
  class_args=dict(model_selector=model_selector),
1005
1043
  **kwargs,
1006
1044
  )
1045
+ self.raise_exception = raise_exception
1007
1046
 
1008
- def add_model(self, model: Union[str, Model], **model_parameters) -> None:
1047
+ def add_model(
1048
+ self,
1049
+ endpoint_name: str,
1050
+ model_class: str,
1051
+ model_artifact: Optional[Union[str, mlrun.artifacts.ModelArtifact]] = None,
1052
+ labels: Optional[Union[list[str], dict[str, str]]] = None,
1053
+ creation_strategy: Optional[
1054
+ schemas.ModelEndpointCreationStrategy
1055
+ ] = schemas.ModelEndpointCreationStrategy.INPLACE,
1056
+ inputs: Optional[list[str]] = None,
1057
+ outputs: Optional[list[str]] = None,
1058
+ input_path: Optional[str] = None,
1059
+ override: bool = False,
1060
+ **model_parameters,
1061
+ ) -> None:
1009
1062
  """
1010
1063
  Add a Model to this ModelRunner.
1011
1064
 
1012
- :param model: Model class name or object
1013
- :param model_parameters: Parameters for model instantiation
1065
+ :param endpoint_name: str, will identify the model in the ModelRunnerStep, and assign model endpoint name
1066
+ :param model_class: Model class name
1067
+ :param model_artifact: model artifact or mlrun model artifact uri
1068
+ :param labels: model endpoint labels, should be list of str or mapping of str:str
1069
+ :param creation_strategy: Strategy for creating or updating the model endpoint:
1070
+ * **overwrite**:
1071
+ 1. If model endpoints with the same name exist, delete the `latest` one.
1072
+ 2. Create a new model endpoint entry and set it as `latest`.
1073
+ * **inplace** (default):
1074
+ 1. If model endpoints with the same name exist, update the `latest` entry.
1075
+ 2. Otherwise, create a new entry.
1076
+ * **archive**:
1077
+ 1. If model endpoints with the same name exist, preserve them.
1078
+ 2. Create a new model endpoint with the same name and set it to `latest`.
1079
+ :param inputs: list of the model inputs (e.g. features) ,if provided will override the inputs that
1080
+ been configured in the model artifact, please note that those inputs need to be
1081
+ equal in length and order to the inputs that model_class predict method expects
1082
+ :param outputs: list of the model outputs (e.g. labels) ,if provided will override the outputs that
1083
+ been configured in the model artifact, please note that those outputs need to be
1084
+ equal to the model_class predict method outputs (length, and order)
1085
+ :param input_path: input path inside the user event, expect scopes to be defined by dot notation
1086
+ (e.g "inputs.my_model_inputs"). expects list or dictionary type object in path.
1087
+ :param override: bool allow override existing model on the current ModelRunnerStep.
1088
+ :param model_parameters: Parameters for model instantiation
1014
1089
  """
1015
- models = self.class_args.get("models", [])
1016
- models.append((model, model_parameters))
1017
- self.class_args["models"] = models
1090
+ # TODO allow model_class as Model object as part of ML-9924
1091
+ model_parameters = model_parameters or {}
1092
+ if model_parameters.get("name", endpoint_name) != endpoint_name:
1093
+ raise mlrun.errors.MLRunInvalidArgumentError(
1094
+ "Inconsistent name for model added to ModelRunnerStep."
1095
+ )
1096
+
1097
+ models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
1098
+ if endpoint_name in models and not override:
1099
+ raise mlrun.errors.MLRunInvalidArgumentError(
1100
+ f"Model with name {endpoint_name} already exists in this ModelRunnerStep."
1101
+ )
1102
+
1103
+ model_parameters["name"] = endpoint_name
1104
+ monitoring_data = self.class_args.get(
1105
+ schemas.ModelRunnerStepData.MONITORING_DATA, {}
1106
+ )
1107
+ models[endpoint_name] = (model_class, model_parameters)
1108
+ monitoring_data[endpoint_name] = {
1109
+ schemas.MonitoringData.INPUTS: inputs,
1110
+ schemas.MonitoringData.OUTPUTS: outputs,
1111
+ schemas.MonitoringData.INPUT_PATH: input_path,
1112
+ schemas.MonitoringData.CREATION_STRATEGY: creation_strategy,
1113
+ schemas.MonitoringData.LABELS: labels,
1114
+ schemas.MonitoringData.MODEL_PATH: model_artifact.uri
1115
+ if isinstance(model_artifact, mlrun.artifacts.Artifact)
1116
+ else model_artifact,
1117
+ }
1118
+ self.class_args[schemas.ModelRunnerStepData.MODELS] = models
1119
+ self.class_args[schemas.ModelRunnerStepData.MONITORING_DATA] = monitoring_data
1018
1120
 
1019
1121
  def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs):
1020
1122
  model_selector = self.class_args.get("model_selector")
1021
- models = self.class_args.get("models")
1123
+ models = self.class_args.get(schemas.ModelRunnerStepData.MODELS, {})
1022
1124
  if isinstance(model_selector, str):
1023
1125
  model_selector = get_class(model_selector, namespace)()
1024
1126
  model_objects = []
1025
- for model, model_params in models:
1127
+ for model, model_params in models.values():
1026
1128
  if not isinstance(model, Model):
1129
+ # prevent model predict from raising error
1130
+ model_params["raise_exception"] = False
1027
1131
  model = get_class(model, namespace)(**model_params)
1132
+ else:
1133
+ # prevent model predict from raising error
1134
+ model._raise_exception = False
1028
1135
  model_objects.append(model)
1029
1136
  self._async_object = ModelRunner(
1030
1137
  model_selector=model_selector,
@@ -1032,6 +1139,29 @@ class ModelRunnerStep(TaskStep, StepToDict):
1032
1139
  )
1033
1140
 
1034
1141
 
1142
+ class ModelRunnerErrorRaiser(storey.MapClass):
1143
+ def __init__(self, raise_exception: bool, models_names: list[str], **kwargs):
1144
+ super().__init__(**kwargs)
1145
+ self._raise_exception = raise_exception
1146
+ self._models_names = models_names
1147
+
1148
+ def do(self, event):
1149
+ if self._raise_exception:
1150
+ errors = {}
1151
+ should_raise = False
1152
+ if len(self._models_names) == 1:
1153
+ should_raise = event.body.get("error") is not None
1154
+ errors[self._models_names[0]] = event.body.get("error")
1155
+ else:
1156
+ for model in event.body:
1157
+ errors[model] = event.body.get(model).get("error")
1158
+ if errors[model] is not None:
1159
+ should_raise = True
1160
+ if should_raise:
1161
+ raise ModelRunnerError(models_errors=errors)
1162
+ return event
1163
+
1164
+
1035
1165
  class QueueStep(BaseStep, StepToDict):
1036
1166
  """queue step, implement an async queue or represent a stream"""
1037
1167
 
@@ -1256,6 +1386,8 @@ class FlowStep(BaseStep):
1256
1386
  class_args=class_args,
1257
1387
  )
1258
1388
 
1389
+ self.verify_model_runner_step(step)
1390
+
1259
1391
  after_list = after if isinstance(after, list) else [after]
1260
1392
  for after in after_list:
1261
1393
  self.insert_step(name, step, after, before)
@@ -1676,7 +1808,41 @@ class RootFlowStep(FlowStep):
1676
1808
  """root flow step"""
1677
1809
 
1678
1810
  kind = "root"
1679
- _dict_fields = ["steps", "engine", "final_step", "on_error"]
1811
+ _dict_fields = [
1812
+ "steps",
1813
+ "engine",
1814
+ "final_step",
1815
+ "on_error",
1816
+ "model_endpoints_names",
1817
+ ]
1818
+
1819
+ def __init__(
1820
+ self,
1821
+ name=None,
1822
+ steps=None,
1823
+ after: Optional[list] = None,
1824
+ engine=None,
1825
+ final_step=None,
1826
+ ):
1827
+ super().__init__(
1828
+ name,
1829
+ steps,
1830
+ after,
1831
+ engine,
1832
+ final_step,
1833
+ )
1834
+ self._models = []
1835
+
1836
+ @property
1837
+ def model_endpoints_names(self) -> list[str]:
1838
+ return self._models
1839
+
1840
+ @model_endpoints_names.setter
1841
+ def model_endpoints_names(self, models: list[str]):
1842
+ self._models = models
1843
+
1844
+ def extend_model_endpoints_names(self, model_endpoints_names: list):
1845
+ self._models.extend(model_endpoints_names)
1680
1846
 
1681
1847
 
1682
1848
  classes_map = {
mlrun/utils/helpers.py CHANGED
@@ -876,13 +876,18 @@ def enrich_image_url(
876
876
  client_version: Optional[str] = None,
877
877
  client_python_version: Optional[str] = None,
878
878
  ) -> str:
879
+ image_url = image_url.strip()
880
+
881
+ # Add python version tag if needed
882
+ if image_url == "python" and client_python_version:
883
+ image_url = f"python:{client_python_version}"
884
+
879
885
  client_version = _convert_python_package_version_to_image_tag(client_version)
880
886
  server_version = _convert_python_package_version_to_image_tag(
881
887
  mlrun.utils.version.Version().get()["version"]
882
888
  )
883
- image_url = image_url.strip()
884
889
  mlrun_version = config.images_tag or client_version or server_version
885
- tag = mlrun_version
890
+ tag = mlrun_version or ""
886
891
 
887
892
  # TODO: Remove condition when mlrun/mlrun-kfp image is also supported
888
893
  if "mlrun-kfp" not in image_url:
@@ -2226,8 +2231,9 @@ class Workflow:
2226
2231
  namespace=mlrun.mlconf.namespace,
2227
2232
  )
2228
2233
 
2229
- # arbitrary timeout of 5 seconds, the workflow should be done by now
2230
- kfp_run = kfp_client.wait_for_run_completion(workflow_id, 5)
2234
+ # arbitrary timeout of 60 seconds, the workflow should be done by now, however sometimes kfp takes a few
2235
+ # seconds to update the workflow status
2236
+ kfp_run = kfp_client.wait_for_run_completion(workflow_id, 60)
2231
2237
  if not kfp_run:
2232
2238
  return None
2233
2239
 
@@ -1,4 +1,4 @@
1
1
  {
2
- "git_commit": "4d532f1a4e8dd427fd5f870863084fa564680bd6",
3
- "version": "1.10.0-rc1"
2
+ "git_commit": "210c516a3ed5c2f2c7223f31fdfd9e99b73d56b6",
3
+ "version": "1.10.0-rc3"
4
4
  }