mlrun 1.10.0rc6__py3-none-any.whl → 1.10.0rc8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. Click here for more details.

Files changed (52)
  1. mlrun/__init__.py +3 -1
  2. mlrun/__main__.py +47 -4
  3. mlrun/artifacts/base.py +0 -27
  4. mlrun/artifacts/dataset.py +0 -8
  5. mlrun/artifacts/model.py +0 -7
  6. mlrun/artifacts/plots.py +0 -13
  7. mlrun/common/schemas/background_task.py +5 -0
  8. mlrun/common/schemas/model_monitoring/__init__.py +2 -0
  9. mlrun/common/schemas/model_monitoring/constants.py +16 -0
  10. mlrun/common/schemas/project.py +4 -0
  11. mlrun/common/schemas/serving.py +2 -0
  12. mlrun/config.py +11 -22
  13. mlrun/datastore/utils.py +3 -1
  14. mlrun/db/base.py +0 -19
  15. mlrun/db/httpdb.py +73 -65
  16. mlrun/db/nopdb.py +0 -12
  17. mlrun/frameworks/tf_keras/__init__.py +4 -4
  18. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +23 -20
  19. mlrun/frameworks/tf_keras/model_handler.py +69 -9
  20. mlrun/frameworks/tf_keras/utils.py +12 -1
  21. mlrun/launcher/base.py +7 -0
  22. mlrun/launcher/client.py +2 -21
  23. mlrun/launcher/local.py +4 -0
  24. mlrun/model_monitoring/applications/_application_steps.py +23 -39
  25. mlrun/model_monitoring/applications/base.py +167 -32
  26. mlrun/model_monitoring/helpers.py +0 -3
  27. mlrun/projects/operations.py +11 -24
  28. mlrun/projects/pipelines.py +33 -3
  29. mlrun/projects/project.py +45 -89
  30. mlrun/run.py +37 -5
  31. mlrun/runtimes/daskjob.py +2 -0
  32. mlrun/runtimes/kubejob.py +5 -8
  33. mlrun/runtimes/mpijob/abstract.py +2 -0
  34. mlrun/runtimes/mpijob/v1.py +2 -0
  35. mlrun/runtimes/nuclio/function.py +2 -0
  36. mlrun/runtimes/nuclio/serving.py +60 -5
  37. mlrun/runtimes/pod.py +3 -0
  38. mlrun/runtimes/remotesparkjob.py +2 -0
  39. mlrun/runtimes/sparkjob/spark3job.py +2 -0
  40. mlrun/serving/__init__.py +2 -0
  41. mlrun/serving/server.py +253 -29
  42. mlrun/serving/states.py +215 -18
  43. mlrun/serving/system_steps.py +391 -0
  44. mlrun/serving/v2_serving.py +9 -8
  45. mlrun/utils/helpers.py +18 -4
  46. mlrun/utils/version/version.json +2 -2
  47. {mlrun-1.10.0rc6.dist-info → mlrun-1.10.0rc8.dist-info}/METADATA +9 -9
  48. {mlrun-1.10.0rc6.dist-info → mlrun-1.10.0rc8.dist-info}/RECORD +52 -51
  49. {mlrun-1.10.0rc6.dist-info → mlrun-1.10.0rc8.dist-info}/WHEEL +0 -0
  50. {mlrun-1.10.0rc6.dist-info → mlrun-1.10.0rc8.dist-info}/entry_points.txt +0 -0
  51. {mlrun-1.10.0rc6.dist-info → mlrun-1.10.0rc8.dist-info}/licenses/LICENSE +0 -0
  52. {mlrun-1.10.0rc6.dist-info → mlrun-1.10.0rc8.dist-info}/top_level.txt +0 -0
mlrun/db/httpdb.py CHANGED
@@ -46,6 +46,7 @@ import mlrun.utils
46
46
  from mlrun.alerts.alert import AlertConfig
47
47
  from mlrun.db.auth_utils import OAuthClientIDTokenProvider, StaticTokenProvider
48
48
  from mlrun.errors import MLRunInvalidArgumentError, err_to_str
49
+ from mlrun.secrets import get_secret_or_env
49
50
  from mlrun_pipelines.utils import compile_pipeline
50
51
 
51
52
  from ..artifacts import Artifact
@@ -156,9 +157,9 @@ class HTTPRunDB(RunDBInterface):
156
157
 
157
158
  if config.auth_with_client_id.enabled:
158
159
  self.token_provider = OAuthClientIDTokenProvider(
159
- token_endpoint=mlrun.get_secret_or_env("MLRUN_AUTH_TOKEN_ENDPOINT"),
160
- client_id=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_ID"),
161
- client_secret=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_SECRET"),
160
+ token_endpoint=get_secret_or_env("MLRUN_AUTH_TOKEN_ENDPOINT"),
161
+ client_id=get_secret_or_env("MLRUN_AUTH_CLIENT_ID"),
162
+ client_secret=get_secret_or_env("MLRUN_AUTH_CLIENT_SECRET"),
162
163
  timeout=config.auth_with_client_id.request_timeout,
163
164
  )
164
165
  else:
@@ -901,9 +902,6 @@ class HTTPRunDB(RunDBInterface):
901
902
  uid: Optional[Union[str, list[str]]] = None,
902
903
  project: Optional[str] = None,
903
904
  labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
904
- state: Optional[
905
- mlrun.common.runtimes.constants.RunStates
906
- ] = None, # Backward compatibility
907
905
  states: typing.Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
908
906
  sort: bool = True,
909
907
  iter: bool = False,
@@ -948,7 +946,6 @@ class HTTPRunDB(RunDBInterface):
948
946
  or just `"label"` for key existence.
949
947
  - A comma-separated string formatted as `"label1=value1,label2"` to match entities with
950
948
  the specified key-value pairs or key existence.
951
- :param state: Deprecated - List only runs whose state is specified (will be removed in 1.10.0)
952
949
  :param states: List only runs whose state is one of the provided states.
953
950
  :param sort: Whether to sort the result according to their start time. Otherwise, results will be
954
951
  returned by their internal order in the DB (order will not be guaranteed).
@@ -976,7 +973,6 @@ class HTTPRunDB(RunDBInterface):
976
973
  uid=uid,
977
974
  project=project,
978
975
  labels=labels,
979
- state=state,
980
976
  states=states,
981
977
  sort=sort,
982
978
  iter=iter,
@@ -2357,8 +2353,7 @@ class HTTPRunDB(RunDBInterface):
2357
2353
  ):
2358
2354
  """
2359
2355
  Retry a specific pipeline run using its run ID. This function sends an API request
2360
- to retry a pipeline run. If a project is specified, the run must belong to that
2361
- project; otherwise, all projects are queried.
2356
+ to retry a pipeline run.
2362
2357
 
2363
2358
  :param run_id: The unique ID of the pipeline run to retry.
2364
2359
  :param namespace: Kubernetes namespace where the pipeline is running. Optional.
@@ -2399,7 +2394,7 @@ class HTTPRunDB(RunDBInterface):
2399
2394
  namespace=namespace,
2400
2395
  response_code=resp_code,
2401
2396
  response_text=resp_text,
2402
- error=str(exc),
2397
+ error=err_to_str(exc),
2403
2398
  )
2404
2399
  if isinstance(exc, mlrun.errors.MLRunHTTPError):
2405
2400
  raise exc # Re-raise known HTTP errors
@@ -2415,6 +2410,72 @@ class HTTPRunDB(RunDBInterface):
2415
2410
  )
2416
2411
  return resp.json()
2417
2412
 
2413
+ def terminate_pipeline(
2414
+ self,
2415
+ run_id: str,
2416
+ project: str,
2417
+ namespace: Optional[str] = None,
2418
+ timeout: int = 30,
2419
+ ):
2420
+ """
2421
+ Terminate a specific pipeline run using its run ID. This function sends an API request
2422
+ to terminate a pipeline run.
2423
+
2424
+ :param run_id: The unique ID of the pipeline run to terminate.
2425
+ :param namespace: Kubernetes namespace where the pipeline is running. Optional.
2426
+ :param timeout: Timeout (in seconds) for the API call. Defaults to 30 seconds.
2427
+ :param project: Name of the MLRun project associated with the pipeline.
2428
+
2429
+ :raises ValueError: Raised if the API response is not successful or contains an
2430
+ error.
2431
+
2432
+ :return: JSON response containing details of the terminate pipeline run background task.
2433
+ """
2434
+
2435
+ params = {}
2436
+ if namespace:
2437
+ params["namespace"] = namespace
2438
+
2439
+ resp_text = ""
2440
+ resp_code = None
2441
+ try:
2442
+ resp = self.api_call(
2443
+ "POST",
2444
+ f"projects/{project}/pipelines/{run_id}/terminate",
2445
+ params=params,
2446
+ timeout=timeout,
2447
+ )
2448
+ resp_code = resp.status_code
2449
+ resp_text = resp.text
2450
+ if not resp.ok:
2451
+ raise mlrun.errors.MLRunHTTPError(
2452
+ f"Failed to retry pipeline run '{run_id}'. "
2453
+ f"HTTP {resp_code}: {resp_text}"
2454
+ )
2455
+ except Exception as exc:
2456
+ logger.error(
2457
+ "Failed to invoke terminate pipeline API",
2458
+ run_id=run_id,
2459
+ project=project,
2460
+ namespace=namespace,
2461
+ response_code=resp_code,
2462
+ response_text=resp_text,
2463
+ error=err_to_str(exc),
2464
+ )
2465
+ if isinstance(exc, mlrun.errors.MLRunHTTPError):
2466
+ raise exc # Re-raise known HTTP errors
2467
+ raise mlrun.errors.MLRunRuntimeError(
2468
+ f"Unexpected error while terminating pipeline run '{run_id}'."
2469
+ ) from exc
2470
+
2471
+ logger.info(
2472
+ "Successfully scheduled terminate pipeline run background task",
2473
+ run_id=run_id,
2474
+ project=project,
2475
+ namespace=namespace,
2476
+ )
2477
+ return resp.json()
2478
+
2418
2479
  @staticmethod
2419
2480
  def _resolve_reference(tag, uid):
2420
2481
  if uid and tag:
@@ -2569,44 +2630,6 @@ class HTTPRunDB(RunDBInterface):
2569
2630
  resp = self.api_call("GET", path, error_message, params=params, version="v2")
2570
2631
  return resp.json()
2571
2632
 
2572
- def list_entities(
2573
- self,
2574
- project: Optional[str] = None,
2575
- name: Optional[str] = None,
2576
- tag: Optional[str] = None,
2577
- labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
2578
- ) -> list[dict]:
2579
- """Retrieve a list of entities and their mapping to the containing feature-sets. This function is similar
2580
- to the :py:func:`~list_features` function, and uses the same logic. However, the entities are matched
2581
- against the name rather than the features.
2582
-
2583
- :param project: The project containing the entities.
2584
- :param name: The name of the entities to retrieve.
2585
- :param tag: The tag of the specific entity version to retrieve.
2586
- :param labels: Filter entities by label key-value pairs or key existence. This can be provided as:
2587
- - A dictionary in the format `{"label": "value"}` to match specific label key-value pairs,
2588
- or `{"label": None}` to check for key existence.
2589
- - A list of strings formatted as `"label=value"` to match specific label key-value pairs,
2590
- or just `"label"` for key existence.
2591
- - A comma-separated string formatted as `"label1=value1,label2"` to match entities with
2592
- the specified key-value pairs or key existence.
2593
- :returns: A list of entities.
2594
- """
2595
-
2596
- project = project or config.active_project
2597
- labels = self._parse_labels(labels)
2598
- params = {
2599
- "name": name,
2600
- "tag": tag,
2601
- "label": labels,
2602
- }
2603
-
2604
- path = f"projects/{project}/entities"
2605
-
2606
- error_message = f"Failed listing entities, project: {project}, query: {params}"
2607
- resp = self.api_call("GET", path, error_message, params=params)
2608
- return resp.json()["entities"]
2609
-
2610
2633
  def list_entities_v2(
2611
2634
  self,
2612
2635
  project: Optional[str] = None,
@@ -5263,9 +5286,6 @@ class HTTPRunDB(RunDBInterface):
5263
5286
  uid: Optional[Union[str, list[str]]] = None,
5264
5287
  project: Optional[str] = None,
5265
5288
  labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
5266
- state: Optional[
5267
- mlrun.common.runtimes.constants.RunStates
5268
- ] = None, # Backward compatibility
5269
5289
  states: typing.Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
5270
5290
  sort: bool = True,
5271
5291
  iter: bool = False,
@@ -5299,20 +5319,12 @@ class HTTPRunDB(RunDBInterface):
5299
5319
  "using the `with_notifications` flag."
5300
5320
  )
5301
5321
 
5302
- if state:
5303
- # TODO: Remove this in 1.10.0
5304
- warnings.warn(
5305
- "'state' is deprecated in 1.7.0 and will be removed in 1.10.0. Use 'states' instead.",
5306
- FutureWarning,
5307
- )
5308
-
5309
5322
  labels = self._parse_labels(labels)
5310
5323
 
5311
5324
  if (
5312
5325
  not name
5313
5326
  and not uid
5314
5327
  and not labels
5315
- and not state
5316
5328
  and not states
5317
5329
  and not start_time_from
5318
5330
  and not start_time_to
@@ -5333,11 +5345,7 @@ class HTTPRunDB(RunDBInterface):
5333
5345
  "name": name,
5334
5346
  "uid": uid,
5335
5347
  "label": labels,
5336
- "state": (
5337
- mlrun.utils.helpers.as_list(state)
5338
- if state is not None
5339
- else states or None
5340
- ),
5348
+ "states": states or None,
5341
5349
  "sort": bool2str(sort),
5342
5350
  "iter": bool2str(iter),
5343
5351
  "start_time_from": datetime_to_iso(start_time_from),
mlrun/db/nopdb.py CHANGED
@@ -126,9 +126,6 @@ class NopDB(RunDBInterface):
126
126
  uid: Optional[Union[str, list[str]]] = None,
127
127
  project: Optional[str] = None,
128
128
  labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
129
- state: Optional[
130
- mlrun.common.runtimes.constants.RunStates
131
- ] = None, # Backward compatibility
132
129
  states: Optional[list[mlrun.common.runtimes.constants.RunStates]] = None,
133
130
  sort: bool = True,
134
131
  iter: bool = False,
@@ -394,15 +391,6 @@ class NopDB(RunDBInterface):
394
391
  ) -> mlrun.common.schemas.FeaturesOutputV2:
395
392
  pass
396
393
 
397
- def list_entities(
398
- self,
399
- project: str,
400
- name: Optional[str] = None,
401
- tag: Optional[str] = None,
402
- labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
403
- ) -> mlrun.common.schemas.EntitiesOutput:
404
- pass
405
-
406
394
  def list_entities_v2(
407
395
  self,
408
396
  project: str,
mlrun/frameworks/tf_keras/__init__.py CHANGED
@@ -14,7 +14,7 @@
14
14
 
15
15
  from typing import Any, Optional, Union
16
16
 
17
- from tensorflow import keras
17
+ import tensorflow as tf
18
18
 
19
19
  import mlrun
20
20
  import mlrun.common.constants as mlrun_constants
@@ -27,11 +27,11 @@ from .utils import TFKerasTypes, TFKerasUtils
27
27
 
28
28
 
29
29
  def apply_mlrun(
30
- model: keras.Model = None,
30
+ model: tf.keras.Model = None,
31
31
  model_name: Optional[str] = None,
32
32
  tag: str = "",
33
33
  model_path: Optional[str] = None,
34
- model_format: str = TFKerasModelHandler.ModelFormats.SAVED_MODEL,
34
+ model_format: Optional[str] = None,
35
35
  save_traces: bool = False,
36
36
  modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
37
37
  custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
@@ -54,7 +54,7 @@ def apply_mlrun(
54
54
  :param model_path: The model's store object path. Mandatory for evaluation (to know which model to
55
55
  update). If model is not provided, it will be loaded from this path.
56
56
  :param model_format: The format to use for saving and loading the model. Should be passed as a
57
- member of the class 'ModelFormats'. Default: 'ModelFormats.SAVED_MODEL'.
57
+ member of the class 'ModelFormats'.
58
58
  :param save_traces: Whether or not to use functions saving (only available for the 'SavedModel'
59
59
  format) for loading the model later without the custom objects dictionary. Only
60
60
  from tensorflow version >= 2.4.0. Using this setting will increase the model
mlrun/frameworks/tf_keras/callbacks/logging_callback.py CHANGED
@@ -16,14 +16,14 @@ from typing import Callable, Optional, Union
16
16
 
17
17
  import numpy as np
18
18
  import tensorflow as tf
19
- from tensorflow import Tensor, Variable
19
+ from tensorflow import keras
20
20
  from tensorflow.python.keras.callbacks import Callback
21
21
 
22
22
  import mlrun
23
23
 
24
24
  from ..._common import LoggingMode
25
25
  from ..._dl_common.loggers import Logger
26
- from ..utils import TFKerasTypes
26
+ from ..utils import TFKerasTypes, is_keras_3
27
27
 
28
28
 
29
29
  class LoggingCallback(Callback):
@@ -70,7 +70,7 @@ class LoggingCallback(Callback):
70
70
  {
71
71
  "epochs": 7
72
72
  }
73
- :param auto_log: Whether or not to enable auto logging, trying to track common static and dynamic
73
+ :param auto_log: Whether to enable auto logging, trying to track common static and dynamic
74
74
  hyperparameters.
75
75
  """
76
76
  super().__init__()
@@ -385,18 +385,24 @@ class LoggingCallback(Callback):
385
385
  self._logger.log_context_parameters()
386
386
 
387
387
  # Add learning rate:
388
- learning_rate_key = "lr"
389
- learning_rate_key_chain = ["optimizer", "lr"]
390
- if learning_rate_key not in self._dynamic_hyperparameters_keys and hasattr(
391
- self.model, "optimizer"
392
- ):
393
- try:
394
- self._get_hyperparameter(key_chain=learning_rate_key_chain)
388
+ learning_rate_keys = [
389
+ "learning_rate",
390
+ "lr",
391
+ ] # "lr" is for backward compatibility in older keras versions.
392
+ if all(
393
+ learning_rate_key not in self._dynamic_hyperparameters_keys
394
+ for learning_rate_key in learning_rate_keys
395
+ ) and hasattr(self.model, "optimizer"):
396
+ for learning_rate_key in learning_rate_keys:
397
+ learning_rate_key_chain = ["optimizer", learning_rate_key]
398
+ try:
399
+ self._get_hyperparameter(key_chain=learning_rate_key_chain)
400
+ except (KeyError, IndexError, AttributeError, ValueError):
401
+ continue
395
402
  self._dynamic_hyperparameters_keys[learning_rate_key] = (
396
403
  learning_rate_key_chain
397
404
  )
398
- except (KeyError, IndexError, ValueError):
399
- pass
405
+ break
400
406
 
401
407
  def _get_hyperparameter(
402
408
  self,
@@ -427,7 +433,7 @@ class LoggingCallback(Callback):
427
433
  value = value[key]
428
434
  else:
429
435
  value = getattr(value, key)
430
- except KeyError or IndexError as KeyChainError:
436
+ except KeyError or IndexError or AttributeError as KeyChainError:
431
437
  raise KeyChainError(
432
438
  f"Error during getting a hyperparameter value with the key chain {key_chain}. "
433
439
  f"The {value.__class__} in it does not have the following key/index from the key provided: "
@@ -435,7 +441,9 @@ class LoggingCallback(Callback):
435
441
  )
436
442
 
437
443
  # Parse the value:
438
- if isinstance(value, Tensor) or isinstance(value, Variable):
444
+ if isinstance(value, (tf.Tensor, tf.Variable)) or (
445
+ is_keras_3() and isinstance(value, (keras.KerasTensor, keras.Variable))
446
+ ):
439
447
  if int(tf.size(value)) == 1:
440
448
  value = float(value)
441
449
  else:
@@ -451,12 +459,7 @@ class LoggingCallback(Callback):
451
459
  f"The parameter with the following key chain: {key_chain} is a numpy.ndarray with {value.size} "
452
460
  f"elements. numpy arrays are trackable only if they have 1 element."
453
461
  )
454
- elif not (
455
- isinstance(value, float)
456
- or isinstance(value, int)
457
- or isinstance(value, str)
458
- or isinstance(value, bool)
459
- ):
462
+ elif not (isinstance(value, (float, int, str, bool))):
460
463
  raise mlrun.errors.MLRunInvalidArgumentError(
461
464
  f"The parameter with the following key chain: {key_chain} is of type '{type(value)}'. The only "
462
465
  f"trackable types are: float, int, str and bool."
mlrun/frameworks/tf_keras/model_handler.py CHANGED
@@ -29,7 +29,7 @@ from mlrun.features import Feature
29
29
  from .._common import without_mlrun_interface
30
30
  from .._dl_common import DLModelHandler
31
31
  from .mlrun_interface import TFKerasMLRunInterface
32
- from .utils import TFKerasUtils
32
+ from .utils import TFKerasUtils, is_keras_3
33
33
 
34
34
 
35
35
  class TFKerasModelHandler(DLModelHandler):
@@ -40,8 +40,8 @@ class TFKerasModelHandler(DLModelHandler):
40
40
  # Framework name:
41
41
  FRAMEWORK_NAME = "tensorflow.keras"
42
42
 
43
- # Declare a type of an input sample:
44
- IOSample = Union[tf.Tensor, tf.TensorSpec, np.ndarray]
43
+ # Declare a type of input sample (only from keras v3 there is a KerasTensor type):
44
+ IOSample = Union[tf.Tensor, tf.TensorSpec, "keras.KerasTensor", np.ndarray]
45
45
 
46
46
  class ModelFormats:
47
47
  """
@@ -49,9 +49,19 @@ class TFKerasModelHandler(DLModelHandler):
49
49
  """
50
50
 
51
51
  SAVED_MODEL = "SavedModel"
52
+ KERAS = "keras"
52
53
  H5 = "h5"
53
54
  JSON_ARCHITECTURE_H5_WEIGHTS = "json_h5"
54
55
 
56
+ @classmethod
57
+ def default(cls) -> str:
58
+ """
59
+ Get the default model format to use for saving and loading the model based on the keras version.
60
+
61
+ :return: The default model format to use.
62
+ """
63
+ return cls.KERAS if is_keras_3() else cls.SAVED_MODEL
64
+
55
65
  class _LabelKeys:
56
66
  """
57
67
  Required labels keys to log with the model.
@@ -65,7 +75,7 @@ class TFKerasModelHandler(DLModelHandler):
65
75
  model: keras.Model = None,
66
76
  model_path: Optional[str] = None,
67
77
  model_name: Optional[str] = None,
68
- model_format: str = ModelFormats.SAVED_MODEL,
78
+ model_format: Optional[str] = None,
69
79
  context: mlrun.MLClientCtx = None,
70
80
  modules_map: Optional[
71
81
  Union[dict[str, Union[None, str, list[str]]], str]
@@ -98,7 +108,7 @@ class TFKerasModelHandler(DLModelHandler):
98
108
  * If given a loaded model object and the model name is None, the name will be
99
109
  set to the model's object name / class.
100
110
  :param model_format: The format to use for saving and loading the model. Should be passed as a
101
- member of the class 'ModelFormats'. Default: 'ModelFormats.SAVED_MODEL'.
111
+ member of the class 'ModelFormats'.
102
112
  :param context: MLRun context to work with for logging the model.
103
113
  :param modules_map: A dictionary of all the modules required for loading the model. Each key
104
114
  is a path to a module and its value is the object name to import from it. All
@@ -144,8 +154,11 @@ class TFKerasModelHandler(DLModelHandler):
144
154
  * 'save_traces' parameter was miss-used.
145
155
  """
146
156
  # Validate given format:
157
+ if not model_format:
158
+ model_format = TFKerasModelHandler.ModelFormats.default()
147
159
  if model_format not in [
148
160
  TFKerasModelHandler.ModelFormats.SAVED_MODEL,
161
+ TFKerasModelHandler.ModelFormats.KERAS,
149
162
  TFKerasModelHandler.ModelFormats.H5,
150
163
  TFKerasModelHandler.ModelFormats.JSON_ARCHITECTURE_H5_WEIGHTS,
151
164
  ]:
@@ -153,6 +166,22 @@ class TFKerasModelHandler(DLModelHandler):
153
166
  f"Unrecognized model format: '{model_format}'. Please use one of the class members of "
154
167
  "'TFKerasModelHandler.ModelFormats'"
155
168
  )
169
+ if not is_keras_3():
170
+ if model_format == TFKerasModelHandler.ModelFormats.KERAS:
171
+ raise mlrun.errors.MLRunInvalidArgumentError(
172
+ "The 'keras' model format is only supported in Keras 3.0.0 and above. "
173
+ f"Current version is {keras.__version__}."
174
+ )
175
+ else:
176
+ if (
177
+ model_format == TFKerasModelHandler.ModelFormats.SAVED_MODEL
178
+ or model_format
179
+ == TFKerasModelHandler.ModelFormats.JSON_ARCHITECTURE_H5_WEIGHTS
180
+ ):
181
+ raise mlrun.errors.MLRunInvalidArgumentError(
182
+ f"The '{model_format}' model format is not supported in Keras 3.0.0 and above. "
183
+ f"Current version is {keras.__version__}."
184
+ )
156
185
 
157
186
  # Validate 'save_traces':
158
187
  if save_traces:
@@ -239,11 +268,19 @@ class TFKerasModelHandler(DLModelHandler):
239
268
  self._model_file = f"{self._model_name}.h5"
240
269
  self._model.save(self._model_file)
241
270
 
271
+ # ModelFormats.keras - Save as a keras file:
272
+ elif self._model_format == self.ModelFormats.KERAS:
273
+ self._model_file = f"{self._model_name}.keras"
274
+ self._model.save(self._model_file)
275
+
242
276
  # ModelFormats.SAVED_MODEL - Save as a SavedModel directory and zip its file:
243
277
  elif self._model_format == TFKerasModelHandler.ModelFormats.SAVED_MODEL:
244
278
  # Save it in a SavedModel format directory:
279
+ # Note: Using keras>=3.0.0 can save in this format via `model.export` but then it won't be able to load it
280
+ # back, only for inference. So, we use the `save` method instead for keras 2 and validate the user won't use
281
+ # keras 3 and this model format.
245
282
  if self._save_traces is True:
246
- # Save traces can only be used in versions >= 2.4, so only if its true we use it in the call:
283
+ # Save traces can only be used in versions >= 2.4, so only if it's true, we use it in the call:
247
284
  self._model.save(self._model_name, save_traces=self._save_traces)
248
285
  else:
249
286
  self._model.save(self._model_name)
@@ -303,6 +340,12 @@ class TFKerasModelHandler(DLModelHandler):
303
340
  self._model_file, custom_objects=self._custom_objects
304
341
  )
305
342
 
343
+ # ModelFormats.KERAS - Load from a keras file:
344
+ elif self._model_format == TFKerasModelHandler.ModelFormats.KERAS:
345
+ self._model = keras.models.load_model(
346
+ self._model_file, custom_objects=self._custom_objects
347
+ )
348
+
306
349
  # ModelFormats.SAVED_MODEL - Load from a SavedModel directory:
307
350
  elif self._model_format == TFKerasModelHandler.ModelFormats.SAVED_MODEL:
308
351
  self._model = keras.models.load_model(
@@ -434,7 +477,10 @@ class TFKerasModelHandler(DLModelHandler):
434
477
  )
435
478
 
436
479
  # Read the inputs:
437
- input_signature = [input_layer.type_spec for input_layer in self._model.inputs]
480
+ input_signature = [
481
+ getattr(input_layer, "type_spec", input_layer)
482
+ for input_layer in self._model.inputs
483
+ ]
438
484
 
439
485
  # Set the inputs:
440
486
  self.set_inputs(from_sample=input_signature)
@@ -453,7 +499,8 @@ class TFKerasModelHandler(DLModelHandler):
453
499
 
454
500
  # Read the outputs:
455
501
  output_signature = [
456
- output_layer.type_spec for output_layer in self._model.outputs
502
+ getattr(output_layer, "type_spec", output_layer)
503
+ for output_layer in self._model.outputs
457
504
  ]
458
505
 
459
506
  # Set the outputs:
@@ -509,6 +556,17 @@ class TFKerasModelHandler(DLModelHandler):
509
556
  f"'{self._model_path}'"
510
557
  )
511
558
 
559
+ # ModelFormats.KERAS - Get the keras model file:
560
+ elif self._model_format == TFKerasModelHandler.ModelFormats.KERAS:
561
+ self._model_file = os.path.join(
562
+ self._model_path, f"{self._model_name}.keras"
563
+ )
564
+ if not os.path.exists(self._model_file):
565
+ raise mlrun.errors.MLRunNotFoundError(
566
+ f"The model file '{self._model_name}.keras' was not found within the given 'model_path': "
567
+ f"'{self._model_path}'"
568
+ )
569
+
512
570
  # ModelFormats.SAVED_MODEL - Get the zip file and extract it, or simply locate the directory:
513
571
  elif self._model_format == TFKerasModelHandler.ModelFormats.SAVED_MODEL:
514
572
  self._model_file = os.path.join(self._model_path, f"{self._model_name}.zip")
@@ -559,7 +617,9 @@ class TFKerasModelHandler(DLModelHandler):
559
617
  # Supported types:
560
618
  if isinstance(sample, np.ndarray):
561
619
  return super()._read_sample(sample=sample)
562
- elif isinstance(sample, tf.TensorSpec):
620
+ elif isinstance(sample, tf.TensorSpec) or (
621
+ is_keras_3() and isinstance(sample, keras.KerasTensor)
622
+ ):
563
623
  return Feature(
564
624
  name=sample.name,
565
625
  value_type=TFKerasUtils.convert_tf_dtype_to_value_type(
mlrun/frameworks/tf_keras/utils.py CHANGED
@@ -11,8 +11,8 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
14
  import tensorflow as tf
15
+ from packaging import version
16
16
  from tensorflow import keras
17
17
 
18
18
  import mlrun
@@ -117,3 +117,14 @@ class TFKerasUtils(DLUtils):
117
117
  raise mlrun.errors.MLRunInvalidArgumentError(
118
118
  f"MLRun value type is not supporting the given tensorflow data type: '{tf_dtype}'."
119
119
  )
120
+
121
+
122
+ def is_keras_3() -> bool:
123
+ """
124
+ Check if the current Keras version is 3.x.
125
+
126
+ :return: True if Keras version is 3.x, False otherwise.
127
+ """
128
+ return hasattr(keras, "__version__") and version.parse(
129
+ keras.__version__
130
+ ) >= version.parse("3.0.0")
mlrun/launcher/base.py CHANGED
@@ -82,6 +82,7 @@ class BaseLauncher(abc.ABC):
82
82
  runtime: "mlrun.runtimes.base.BaseRuntime",
83
83
  project_name: Optional[str] = "",
84
84
  full: bool = True,
85
+ client_version: str = "",
85
86
  ):
86
87
  pass
87
88
 
@@ -147,6 +148,12 @@ class BaseLauncher(abc.ABC):
147
148
  self._validate_run_params(run.spec.parameters)
148
149
  self._validate_output_path(runtime, run)
149
150
 
151
+ for image in [
152
+ runtime.spec.image,
153
+ getattr(runtime.spec.build, "base_image", None),
154
+ ]:
155
+ mlrun.utils.helpers.warn_on_deprecated_image(image)
156
+
150
157
  @staticmethod
151
158
  def _validate_output_path(
152
159
  runtime: "mlrun.runtimes.BaseRuntime",
mlrun/launcher/client.py CHANGED
@@ -12,7 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  import abc
15
- import warnings
16
15
  from typing import Optional
17
16
 
18
17
  import IPython.display
@@ -37,6 +36,7 @@ class ClientBaseLauncher(launcher.BaseLauncher, abc.ABC):
37
36
  runtime: "mlrun.runtimes.base.BaseRuntime",
38
37
  project_name: Optional[str] = "",
39
38
  full: bool = True,
39
+ client_version: str = "",
40
40
  ):
41
41
  runtime.try_auto_mount_based_on_config()
42
42
  runtime._fill_credentials()
@@ -62,26 +62,7 @@ class ClientBaseLauncher(launcher.BaseLauncher, abc.ABC):
62
62
  ):
63
63
  image = mlrun.mlconf.function_defaults.image_by_kind.to_dict()[runtime.kind]
64
64
 
65
- # Warn if user explicitly set the deprecated mlrun/ml-base image
66
- if image and "mlrun/ml-base" in image:
67
- client_version = mlrun.utils.version.Version().get()["version"]
68
- auto_replaced = mlrun.utils.validate_component_version_compatibility(
69
- "mlrun-client", "1.10.0", mlrun_client_version=client_version
70
- )
71
- message = (
72
- "'mlrun/ml-base' image is deprecated in 1.10.0 and will be removed in 1.12.0, "
73
- "use 'mlrun/mlrun' instead."
74
- )
75
- if auto_replaced:
76
- message += (
77
- " Since your client version is >= 1.10.0, the image will be automatically "
78
- "replaced with mlrun/mlrun."
79
- )
80
- warnings.warn(
81
- message,
82
- # TODO: Remove this in 1.12.0
83
- FutureWarning,
84
- )
65
+ mlrun.utils.helpers.warn_on_deprecated_image(image)
85
66
 
86
67
  # TODO: need a better way to decide whether a function requires a build
87
68
  if require_build and image and not runtime.spec.build.base_image:
mlrun/launcher/local.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  import os
15
15
  import pathlib
16
+ from os import environ
16
17
  from typing import Callable, Optional, Union
17
18
 
18
19
  import mlrun.common.constants as mlrun_constants
@@ -251,6 +252,9 @@ class ClientLocalLauncher(launcher.ClientBaseLauncher):
251
252
  # copy the code/base-spec to the local function (for the UI and code logging)
252
253
  fn.spec.description = runtime.spec.description
253
254
  fn.spec.build = runtime.spec.build
255
+ serving_spec = getattr(runtime.spec, "serving_spec", None)
256
+ if serving_spec:
257
+ environ["SERVING_SPEC_ENV"] = serving_spec
254
258
 
255
259
  run.spec.handler = handler
256
260
  run.spec.reset_on_run = reset_on_run