mlrun 1.10.0rc7__py3-none-any.whl → 1.10.0rc8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.

Potentially problematic release.

Files changed (34)
  1. mlrun/__init__.py +3 -1
  2. mlrun/common/schemas/background_task.py +5 -0
  3. mlrun/common/schemas/model_monitoring/__init__.py +2 -0
  4. mlrun/common/schemas/model_monitoring/constants.py +16 -0
  5. mlrun/common/schemas/project.py +4 -0
  6. mlrun/common/schemas/serving.py +2 -0
  7. mlrun/config.py +11 -22
  8. mlrun/datastore/utils.py +3 -1
  9. mlrun/db/base.py +11 -10
  10. mlrun/db/httpdb.py +97 -25
  11. mlrun/db/nopdb.py +5 -4
  12. mlrun/frameworks/tf_keras/__init__.py +4 -4
  13. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +23 -20
  14. mlrun/frameworks/tf_keras/model_handler.py +69 -9
  15. mlrun/frameworks/tf_keras/utils.py +12 -1
  16. mlrun/launcher/base.py +6 -0
  17. mlrun/launcher/client.py +1 -21
  18. mlrun/projects/pipelines.py +33 -3
  19. mlrun/projects/project.py +13 -16
  20. mlrun/run.py +37 -5
  21. mlrun/runtimes/nuclio/serving.py +14 -5
  22. mlrun/serving/__init__.py +2 -0
  23. mlrun/serving/server.py +156 -26
  24. mlrun/serving/states.py +215 -18
  25. mlrun/serving/system_steps.py +391 -0
  26. mlrun/serving/v2_serving.py +9 -8
  27. mlrun/utils/helpers.py +18 -0
  28. mlrun/utils/version/version.json +2 -2
  29. {mlrun-1.10.0rc7.dist-info → mlrun-1.10.0rc8.dist-info}/METADATA +8 -8
  30. {mlrun-1.10.0rc7.dist-info → mlrun-1.10.0rc8.dist-info}/RECORD +34 -33
  31. {mlrun-1.10.0rc7.dist-info → mlrun-1.10.0rc8.dist-info}/WHEEL +0 -0
  32. {mlrun-1.10.0rc7.dist-info → mlrun-1.10.0rc8.dist-info}/entry_points.txt +0 -0
  33. {mlrun-1.10.0rc7.dist-info → mlrun-1.10.0rc8.dist-info}/licenses/LICENSE +0 -0
  34. {mlrun-1.10.0rc7.dist-info → mlrun-1.10.0rc8.dist-info}/top_level.txt +0 -0
mlrun/__init__.py CHANGED
@@ -61,6 +61,7 @@ from .run import (
     import_function,
     new_function,
     retry_pipeline,
+    terminate_pipeline,
     wait_for_pipeline_completion,
 )
 from .runtimes import mounts, new_model_server
@@ -217,5 +218,6 @@ def set_env_from_file(env_file: str, return_dict: bool = False) -> Optional[dict
     for key, value in env_vars.items():
         environ[key] = value
 
-    mlconf.reload()  # reload mlrun configuration
+    # reload mlrun configuration
+    mlconf.reload()
     return env_vars if return_dict else None
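
The newly exported terminate_pipeline makes pipeline termination available at the top level, next to retry_pipeline. A minimal usage sketch, assuming the module-level helper forwards to the HTTPRunDB.terminate_pipeline method shown later in this diff (the project name and run ID are placeholders):

    import mlrun

    # Schedules a termination background task on the API side and returns its
    # details as JSON; it does not kill the pipeline run synchronously.
    background_task = mlrun.terminate_pipeline(
        run_id="<pipeline-run-id>",
        project="my-project",
    )
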
mlrun/common/schemas/background_task.py CHANGED
@@ -22,6 +22,10 @@ import mlrun.common.types
 from .object import ObjectKind
 
 
+class BackGroundTaskLabel(mlrun.common.types.StrEnum):
+    pipeline = "pipeline"
+
+
 class BackgroundTaskState(mlrun.common.types.StrEnum):
     succeeded = "succeeded"
     failed = "failed"
@@ -37,6 +41,7 @@ class BackgroundTaskState(mlrun.common.types.StrEnum):
 
 class BackgroundTaskMetadata(pydantic.v1.BaseModel):
     name: str
+    id: typing.Optional[int]
     kind: typing.Optional[str]
     project: typing.Optional[str]
     created: typing.Optional[datetime.datetime]
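
BackgroundTaskMetadata gains an optional id, and BackGroundTaskLabel tags background tasks that belong to pipelines. A standalone sketch of how the optional field behaves under pydantic v1 (the model is re-declared here for illustration, not the full class):

    import datetime
    import typing

    import pydantic.v1

    class BackgroundTaskMetadata(pydantic.v1.BaseModel):
        name: str
        id: typing.Optional[int]  # new in rc8; Optional implies a None default
        kind: typing.Optional[str]
        project: typing.Optional[str]
        created: typing.Optional[datetime.datetime]

    # Payloads from servers that predate the field still validate:
    meta = BackgroundTaskMetadata(name="terminate-pipeline-task")
    assert meta.id is None
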
mlrun/common/schemas/model_monitoring/__init__.py CHANGED
@@ -28,6 +28,7 @@ from .constants import (
     ModelEndpointCreationStrategy,
     ModelEndpointMonitoringMetricType,
     ModelEndpointSchema,
+    ModelMonitoringAppLabel,
     ModelMonitoringMode,
     MonitoringFunctionNames,
     PredictionsQueryConstants,
@@ -36,6 +37,7 @@ from .constants import (
     ResultKindApp,
     ResultStatusApp,
     SpecialApps,
+    StreamProcessingEvent,
     TDEngineSuperTables,
     TSDBTarget,
     V3IOTSDBTables,
mlrun/common/schemas/model_monitoring/constants.py CHANGED
@@ -142,6 +142,22 @@ class EventFieldType:
     EFFECTIVE_SAMPLE_COUNT = "effective_sample_count"
 
 
+class StreamProcessingEvent:
+    MODEL = "model"
+    MODEL_CLASS = "model_class"
+    MICROSEC = "microsec"
+    WHEN = "when"
+    ERROR = "error"
+    ENDPOINT_ID = "endpoint_id"
+    SAMPLING_PERCENTAGE = "sampling_percentage"
+    EFFECTIVE_SAMPLE_COUNT = "effective_sample_count"
+    LABELS = "labels"
+    FUNCTION_URI = "function_uri"
+    REQUEST = "request"
+    RESPONSE = "resp"
+    METRICS = "metrics"
+
+
 class FeatureSetFeatures(MonitoringStrEnum):
     LATENCY = EventFieldType.LATENCY
     METRICS = EventFieldType.METRICS
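
StreamProcessingEvent centralizes the field names used by the model-monitoring stream processing, replacing scattered string literals. A hedged sketch of reading such an event with the constants (the payload itself is illustrative):

    from mlrun.common.schemas.model_monitoring import StreamProcessingEvent

    # Illustrative event payload keyed by the new constants; note that the
    # response field is spelled "resp" on the wire.
    event = {
        StreamProcessingEvent.MODEL: "churn-model",
        StreamProcessingEvent.ENDPOINT_ID: "ep-123",
        StreamProcessingEvent.REQUEST: {"inputs": [[1.0, 2.0, 3.0]]},
        StreamProcessingEvent.RESPONSE: {"outputs": [0.7]},
    }

    endpoint_id = event[StreamProcessingEvent.ENDPOINT_ID]
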
mlrun/common/schemas/project.py CHANGED
@@ -148,6 +148,10 @@ class ProjectSummary(pydantic.v1.BaseModel):
     datasets_count: int = 0
     documents_count: int = 0
     llm_prompts_count: int = 0
+    running_model_monitoring_functions: int = 0
+    failed_model_monitoring_functions: int = 0
+    real_time_model_endpoint_count: int = 0
+    batch_model_endpoint_count: int = 0
 
 
 class IguazioProject(pydantic.v1.BaseModel):
mlrun/common/schemas/serving.py CHANGED
@@ -33,7 +33,9 @@ class MonitoringData(StrEnum):
     INPUTS = "inputs"
     OUTPUTS = "outputs"
     INPUT_PATH = "input_path"
+    RESULT_PATH = "result_path"
     CREATION_STRATEGY = "creation_strategy"
     LABELS = "labels"
     MODEL_PATH = "model_path"
     MODEL_ENDPOINT_UID = "model_endpoint_uid"
+    MODEL_CLASS = "model_class"
mlrun/config.py CHANGED
@@ -107,6 +107,8 @@ default_config = {
     "submit_timeout": "280",  # timeout when submitting a new k8s resource
     # runtimes cleanup interval in seconds
     "runtimes_cleanup_interval": "300",
+    "background_task_cleanup_interval": "86400",  # 24 hours in seconds
+    "background_task_max_age": "21600",  # 6 hours in seconds
     "monitoring": {
         "runs": {
             # runs monitoring interval in seconds
@@ -233,6 +235,7 @@ default_config = {
         "delete_function": "900",
         "model_endpoint_creation": "600",
         "model_endpoint_tsdb_leftovers": "900",
+        "terminate_pipeline": "300",
     },
     "runtimes": {
         "dask": "600",
@@ -638,6 +641,7 @@ default_config = {
         "offline_storage_path": "model-endpoints/{kind}",
         "parquet_batching_max_events": 10_000,
         "parquet_batching_timeout_secs": timedelta(minutes=1).total_seconds(),
+        "model_endpoint_creation_check_period": "15",
     },
     "secret_stores": {
         # Use only in testing scenarios (such as integration tests) to avoid using k8s for secrets (will use in-memory
@@ -896,11 +900,7 @@ class Config:
         return result
 
     def __setattr__(self, attr, value):
-        # in order for the dbpath setter to work
-        if attr == "dbpath":
-            super().__setattr__(attr, value)
-        else:
-            self._cfg[attr] = value
+        self._cfg[attr] = value
 
     def __dir__(self):
         return list(self._cfg) + dir(self.__class__)
@@ -1244,23 +1244,6 @@ class Config:
         # since the property will need to be url, which exists in other structs as well
         return config.ui.url or config.ui_url
 
-    @property
-    def dbpath(self):
-        return self._dbpath
-
-    @dbpath.setter
-    def dbpath(self, value):
-        self._dbpath = value
-        if value:
-            # importing here to avoid circular dependency
-            import mlrun.db
-
-            # It ensures that SSL verification is set before establishing a connection
-            _configure_ssl_verification(self.httpdb.http.verify)
-
-            # when dbpath is set we want to connect to it which will sync configuration from it to the client
-            mlrun.db.get_run_db(value, force_reconnect=True)
-
     def is_api_running_on_k8s(self):
         # determine if the API service is attached to K8s cluster
         # when there is a cluster the .namespace is set
@@ -1436,6 +1419,12 @@ def _do_populate(env=None, skip_errors=False):
     _configure_ssl_verification(config.httpdb.http.verify)
     _validate_config(config)
 
+    if config.dbpath:
+        from mlrun.db import get_run_db
+
+        # when dbpath is set we want to connect to it which will sync configuration from it to the client
+        get_run_db(config.dbpath, force_reconnect=True)
+
 
 def _validate_config(config):
     try:
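
The net effect of the config.py changes: dbpath is now an ordinary config entry, and connecting to the API (which syncs server-side configuration into the client) happens once in _do_populate, after SSL verification is configured, instead of inside a property setter on every assignment. A hedged sketch of the resulting behavior:

    import os

    # Placeholder URL; with MLRUN_DBPATH set, _do_populate connects via
    # get_run_db(config.dbpath, force_reconnect=True) during config reload.
    os.environ["MLRUN_DBPATH"] = "https://mlrun-api.example.com"

    import mlrun

    mlrun.mlconf.reload()
    print(mlrun.mlconf.dbpath)
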
mlrun/datastore/utils.py CHANGED
@@ -236,9 +236,11 @@ class KafkaParameters:
             "partitions": "",
             "sasl": "",
             "worker_allocation_mode": "",
-            "tls_enable": "",  # for Nuclio with Confluent Kafka (Sarama client)
+            # for Nuclio with Confluent Kafka
+            "tls_enable": "",
             "tls": "",
             "new_topic": "",
+            "nuclio_annotations": "",
         }
         self._reference_dicts = (
             self._custom_attributes,
mlrun/db/base.py CHANGED
@@ -439,31 +439,32 @@ class RunDBInterface(ABC):
     ) -> dict:
         pass
 
+    # TODO: remove in 1.10.0
+    @deprecated(
+        version="1.7.0",
+        reason="'list_features' will be removed in 1.10.0, use 'list_features_v2' instead",
+        category=FutureWarning,
+    )
     @abstractmethod
-    def list_features_v2(
+    def list_features(
         self,
         project: str,
         name: Optional[str] = None,
         tag: Optional[str] = None,
         entities: Optional[list[str]] = None,
         labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
-    ) -> mlrun.common.schemas.FeaturesOutputV2:
+    ) -> mlrun.common.schemas.FeaturesOutput:
         pass
 
-    # TODO: remove in 1.10.0
-    @deprecated(
-        version="1.7.0",
-        reason="'list_entities' will be removed in 1.10.0, use 'list_entities_v2' instead",
-        category=FutureWarning,
-    )
     @abstractmethod
-    def list_entities(
+    def list_features_v2(
         self,
         project: str,
         name: Optional[str] = None,
         tag: Optional[str] = None,
+        entities: Optional[list[str]] = None,
         labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
-    ) -> mlrun.common.schemas.EntitiesOutput:
+    ) -> mlrun.common.schemas.FeaturesOutputV2:
         pass
 
     @abstractmethod
mlrun/db/httpdb.py CHANGED
@@ -46,6 +46,7 @@ import mlrun.utils
 from mlrun.alerts.alert import AlertConfig
 from mlrun.db.auth_utils import OAuthClientIDTokenProvider, StaticTokenProvider
 from mlrun.errors import MLRunInvalidArgumentError, err_to_str
+from mlrun.secrets import get_secret_or_env
 from mlrun_pipelines.utils import compile_pipeline
 
 from ..artifacts import Artifact
@@ -156,9 +157,9 @@ class HTTPRunDB(RunDBInterface):
 
         if config.auth_with_client_id.enabled:
             self.token_provider = OAuthClientIDTokenProvider(
-                token_endpoint=mlrun.get_secret_or_env("MLRUN_AUTH_TOKEN_ENDPOINT"),
-                client_id=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_ID"),
-                client_secret=mlrun.get_secret_or_env("MLRUN_AUTH_CLIENT_SECRET"),
+                token_endpoint=get_secret_or_env("MLRUN_AUTH_TOKEN_ENDPOINT"),
+                client_id=get_secret_or_env("MLRUN_AUTH_CLIENT_ID"),
+                client_secret=get_secret_or_env("MLRUN_AUTH_CLIENT_SECRET"),
                 timeout=config.auth_with_client_id.request_timeout,
             )
         else:
@@ -2352,8 +2353,7 @@ class HTTPRunDB(RunDBInterface):
     ):
         """
         Retry a specific pipeline run using its run ID. This function sends an API request
-        to retry a pipeline run. If a project is specified, the run must belong to that
-        project; otherwise, all projects are queried.
+        to retry a pipeline run.
 
         :param run_id: The unique ID of the pipeline run to retry.
         :param namespace: Kubernetes namespace where the pipeline is running. Optional.
@@ -2394,7 +2394,7 @@ class HTTPRunDB(RunDBInterface):
                 namespace=namespace,
                 response_code=resp_code,
                 response_text=resp_text,
-                error=str(exc),
+                error=err_to_str(exc),
             )
             if isinstance(exc, mlrun.errors.MLRunHTTPError):
                 raise exc  # Re-raise known HTTP errors
@@ -2410,6 +2410,72 @@ class HTTPRunDB(RunDBInterface):
         )
         return resp.json()
 
+    def terminate_pipeline(
+        self,
+        run_id: str,
+        project: str,
+        namespace: Optional[str] = None,
+        timeout: int = 30,
+    ):
+        """
+        Terminate a specific pipeline run using its run ID. This function sends an API request
+        to terminate a pipeline run.
+
+        :param run_id: The unique ID of the pipeline run to terminate.
+        :param namespace: Kubernetes namespace where the pipeline is running. Optional.
+        :param timeout: Timeout (in seconds) for the API call. Defaults to 30 seconds.
+        :param project: Name of the MLRun project associated with the pipeline.
+
+        :raises ValueError: Raised if the API response is not successful or contains an
+            error.
+
+        :return: JSON response containing details of the terminate pipeline run background task.
+        """
+
+        params = {}
+        if namespace:
+            params["namespace"] = namespace
+
+        resp_text = ""
+        resp_code = None
+        try:
+            resp = self.api_call(
+                "POST",
+                f"projects/{project}/pipelines/{run_id}/terminate",
+                params=params,
+                timeout=timeout,
+            )
+            resp_code = resp.status_code
+            resp_text = resp.text
+            if not resp.ok:
+                raise mlrun.errors.MLRunHTTPError(
+                    f"Failed to terminate pipeline run '{run_id}'. "
+                    f"HTTP {resp_code}: {resp_text}"
+                )
+        except Exception as exc:
+            logger.error(
+                "Failed to invoke terminate pipeline API",
+                run_id=run_id,
+                project=project,
+                namespace=namespace,
+                response_code=resp_code,
+                response_text=resp_text,
+                error=err_to_str(exc),
+            )
+            if isinstance(exc, mlrun.errors.MLRunHTTPError):
+                raise exc  # Re-raise known HTTP errors
+            raise mlrun.errors.MLRunRuntimeError(
+                f"Unexpected error while terminating pipeline run '{run_id}'."
+            ) from exc
+
+        logger.info(
+            "Successfully scheduled terminate pipeline run background task",
+            run_id=run_id,
+            project=project,
+            namespace=namespace,
+        )
+        return resp.json()
+
     @staticmethod
     def _resolve_reference(tag, uid):
         if uid and tag:
@@ -2478,14 +2544,14 @@ class HTTPRunDB(RunDBInterface):
         resp = self.api_call("GET", path, error_message)
         return FeatureSet.from_dict(resp.json())
 
-    def list_features_v2(
+    def list_features(
         self,
         project: Optional[str] = None,
         name: Optional[str] = None,
         tag: Optional[str] = None,
         entities: Optional[list[str]] = None,
         labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
-    ) -> dict[str, list[dict]]:
+    ) -> list[dict]:
         """List feature-sets which contain specific features. This function may return multiple versions of the same
         feature-set if a specific tag is not requested. Note that the various filters of this function actually
         refer to the feature-set object containing the features, not to the features themselves.
@@ -2502,7 +2568,9 @@ class HTTPRunDB(RunDBInterface):
             or just `"label"` for key existence.
           - A comma-separated string formatted as `"label1=value1,label2"` to match entities with
             the specified key-value pairs or key existence.
-        :returns: A list of features, and a list of their corresponding feature sets.
+        :returns: A list of mappings from feature to a digest of the feature-set, which contains the feature-set
+            meta-data. Multiple entries may be returned for any specific feature due to multiple tags or versions
+            of the feature-set.
         """
 
         project = project or config.active_project
@@ -2517,31 +2585,34 @@ class HTTPRunDB(RunDBInterface):
         path = f"projects/{project}/features"
 
         error_message = f"Failed listing features, project: {project}, query: {params}"
-        resp = self.api_call("GET", path, error_message, params=params, version="v2")
-        return resp.json()
+        resp = self.api_call("GET", path, error_message, params=params)
+        return resp.json()["features"]
 
-    def list_entities(
+    def list_features_v2(
         self,
         project: Optional[str] = None,
         name: Optional[str] = None,
         tag: Optional[str] = None,
+        entities: Optional[list[str]] = None,
         labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
-    ) -> list[dict]:
-        """Retrieve a list of entities and their mapping to the containing feature-sets. This function is similar
-        to the :py:func:`~list_features` function, and uses the same logic. However, the entities are matched
-        against the name rather than the features.
+    ) -> dict[str, list[dict]]:
+        """List feature-sets which contain specific features. This function may return multiple versions of the same
+        feature-set if a specific tag is not requested. Note that the various filters of this function actually
+        refer to the feature-set object containing the features, not to the features themselves.
 
-        :param project: The project containing the entities.
-        :param name: The name of the entities to retrieve.
-        :param tag: The tag of the specific entity version to retrieve.
-        :param labels: Filter entities by label key-value pairs or key existence. This can be provided as:
+        :param project: Project which contains these features.
+        :param name: Name of the feature to look for. The name is used in a like query, and is not case-sensitive. For
+            example, looking for ``feat`` will return features which are named ``MyFeature`` as well as ``defeat``.
+        :param tag: Return feature-sets which contain the features looked for, and are tagged with the specific tag.
+        :param entities: Return only feature-sets which contain an entity whose name is contained in this list.
+        :param labels: Filter feature-sets by label key-value pairs or key existence. This can be provided as:
           - A dictionary in the format `{"label": "value"}` to match specific label key-value pairs,
             or `{"label": None}` to check for key existence.
          - A list of strings formatted as `"label=value"` to match specific label key-value pairs,
            or just `"label"` for key existence.
          - A comma-separated string formatted as `"label1=value1,label2"` to match entities with
            the specified key-value pairs or key existence.
-        :returns: A list of entities.
+        :returns: A list of features, and a list of their corresponding feature sets.
         """
 
         project = project or config.active_project
@@ -2549,14 +2620,15 @@ class HTTPRunDB(RunDBInterface):
         params = {
             "name": name,
             "tag": tag,
+            "entity": entities or [],
             "label": labels,
         }
 
-        path = f"projects/{project}/entities"
+        path = f"projects/{project}/features"
 
-        error_message = f"Failed listing entities, project: {project}, query: {params}"
-        resp = self.api_call("GET", path, error_message, params=params)
-        return resp.json()["entities"]
+        error_message = f"Failed listing features, project: {project}, query: {params}"
+        resp = self.api_call("GET", path, error_message, params=params, version="v2")
+        return resp.json()
 
     def list_entities_v2(
         self,
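
Taken together with the base.py changes, the old list_entities is gone from the client, and the feature-listing pair has been re-split: list_features is the deprecated v1 endpoint returning a plain list, while list_features_v2 hits the v2 endpoint, accepts an entities filter, and returns the full FeaturesOutputV2 payload as a dict. A hedged migration sketch (project and filter values are placeholders):

    import mlrun

    db = mlrun.get_run_db()

    # Deprecated v1 path, scheduled for removal in 1.10.0: the client
    # unwraps the "features" key and returns a list of feature digests.
    legacy = db.list_features(project="my-project", name="feat")

    # Preferred v2 path: supports the entities filter and returns the
    # FeaturesOutputV2 payload as a dict.
    v2 = db.list_features_v2(
        project="my-project", name="feat", entities=["customer_id"]
    )
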
mlrun/db/nopdb.py CHANGED
@@ -371,23 +371,24 @@ class NopDB(RunDBInterface):
     ) -> dict:
         pass
 
-    def list_features_v2(
+    def list_features(
         self,
         project: str,
         name: Optional[str] = None,
         tag: Optional[str] = None,
         entities: Optional[list[str]] = None,
         labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
-    ) -> mlrun.common.schemas.FeaturesOutputV2:
+    ) -> mlrun.common.schemas.FeaturesOutput:
         pass
 
-    def list_entities(
+    def list_features_v2(
         self,
         project: str,
         name: Optional[str] = None,
         tag: Optional[str] = None,
+        entities: Optional[list[str]] = None,
         labels: Optional[Union[str, dict[str, Optional[str]], list[str]]] = None,
-    ) -> mlrun.common.schemas.EntitiesOutput:
+    ) -> mlrun.common.schemas.FeaturesOutputV2:
         pass
 
     def list_entities_v2(
mlrun/frameworks/tf_keras/__init__.py CHANGED
@@ -14,7 +14,7 @@
 
 from typing import Any, Optional, Union
 
-from tensorflow import keras
+import tensorflow as tf
 
 import mlrun
 import mlrun.common.constants as mlrun_constants
@@ -27,11 +27,11 @@ from .utils import TFKerasTypes, TFKerasUtils
 
 
 def apply_mlrun(
-    model: keras.Model = None,
+    model: tf.keras.Model = None,
     model_name: Optional[str] = None,
     tag: str = "",
     model_path: Optional[str] = None,
-    model_format: str = TFKerasModelHandler.ModelFormats.SAVED_MODEL,
+    model_format: Optional[str] = None,
     save_traces: bool = False,
     modules_map: Optional[Union[dict[str, Union[None, str, list[str]]], str]] = None,
     custom_objects_map: Optional[Union[dict[str, Union[str, list[str]]], str]] = None,
@@ -54,7 +54,7 @@ def apply_mlrun(
     :param model_path: The model's store object path. Mandatory for evaluation (to know which model to
                        update). If model is not provided, it will be loaded from this path.
     :param model_format: The format to use for saving and loading the model. Should be passed as a
-                       member of the class 'ModelFormats'. Default: 'ModelFormats.SAVED_MODEL'.
+                       member of the class 'ModelFormats'.
     :param save_traces: Whether or not to use functions saving (only available for the 'SavedModel'
                        format) for loading the model later without the custom objects dictionary. Only
                        from tensorflow version >= 2.4.0. Using this setting will increase the model
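
Since model_format now defaults to None instead of ModelFormats.SAVED_MODEL, callers relying on the old default should pass the format explicitly (the new default resolution presumably lives in the updated model_handler.py, not shown here). A minimal sketch, assuming a compiled Keras model inside an MLRun job:

    import tensorflow as tf

    from mlrun.frameworks.tf_keras import apply_mlrun
    from mlrun.frameworks.tf_keras.model_handler import TFKerasModelHandler

    model = tf.keras.Sequential(
        [tf.keras.layers.Input(shape=(4,)), tf.keras.layers.Dense(1)]
    )
    model.compile(optimizer="adam", loss="mse")

    # Pass a ModelFormats member explicitly to keep the pre-rc8 behavior.
    apply_mlrun(
        model=model,
        model_name="my-model",
        model_format=TFKerasModelHandler.ModelFormats.SAVED_MODEL,
    )
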
mlrun/frameworks/tf_keras/callbacks/logging_callback.py CHANGED
@@ -16,14 +16,14 @@ from typing import Callable, Optional, Union
 
 import numpy as np
 import tensorflow as tf
-from tensorflow import Tensor, Variable
+from tensorflow import keras
 from tensorflow.python.keras.callbacks import Callback
 
 import mlrun
 
 from ..._common import LoggingMode
 from ..._dl_common.loggers import Logger
-from ..utils import TFKerasTypes
+from ..utils import TFKerasTypes, is_keras_3
 
 
 class LoggingCallback(Callback):
@@ -70,7 +70,7 @@ class LoggingCallback(Callback):
                {
                    "epochs": 7
                }
-        :param auto_log: Whether or not to enable auto logging, trying to track common static and dynamic
+        :param auto_log: Whether to enable auto logging, trying to track common static and dynamic
                          hyperparameters.
        """
        super().__init__()
@@ -385,18 +385,24 @@ class LoggingCallback(Callback):
             self._logger.log_context_parameters()
 
         # Add learning rate:
-        learning_rate_key = "lr"
-        learning_rate_key_chain = ["optimizer", "lr"]
-        if learning_rate_key not in self._dynamic_hyperparameters_keys and hasattr(
-            self.model, "optimizer"
-        ):
-            try:
-                self._get_hyperparameter(key_chain=learning_rate_key_chain)
+        learning_rate_keys = [
+            "learning_rate",
+            "lr",
+        ]  # "lr" is for backward compatibility in older keras versions.
+        if all(
+            learning_rate_key not in self._dynamic_hyperparameters_keys
+            for learning_rate_key in learning_rate_keys
+        ) and hasattr(self.model, "optimizer"):
+            for learning_rate_key in learning_rate_keys:
+                learning_rate_key_chain = ["optimizer", learning_rate_key]
+                try:
+                    self._get_hyperparameter(key_chain=learning_rate_key_chain)
+                except (KeyError, IndexError, AttributeError, ValueError):
+                    continue
                 self._dynamic_hyperparameters_keys[learning_rate_key] = (
                     learning_rate_key_chain
                 )
-            except (KeyError, IndexError, ValueError):
-                pass
+                break
 
     def _get_hyperparameter(
         self,
@@ -427,7 +433,7 @@ class LoggingCallback(Callback):
                     value = value[key]
                 else:
                     value = getattr(value, key)
-        except KeyError or IndexError as KeyChainError:
+        except KeyError or IndexError or AttributeError as KeyChainError:
            raise KeyChainError(
                f"Error during getting a hyperparameter value with the key chain {key_chain}. "
                f"The {value.__class__} in it does not have the following key/index from the key provided: "
@@ -435,7 +441,9 @@ class LoggingCallback(Callback):
            )
 
        # Parse the value:
-        if isinstance(value, Tensor) or isinstance(value, Variable):
+        if isinstance(value, (tf.Tensor, tf.Variable)) or (
+            is_keras_3() and isinstance(value, (keras.KerasTensor, keras.Variable))
+        ):
            if int(tf.size(value)) == 1:
                value = float(value)
            else:
@@ -451,12 +459,7 @@ class LoggingCallback(Callback):
                f"The parameter with the following key chain: {key_chain} is a numpy.ndarray with {value.size} "
                f"elements. numpy arrays are trackable only if they have 1 element."
            )
-        elif not (
-            isinstance(value, float)
-            or isinstance(value, int)
-            or isinstance(value, str)
-            or isinstance(value, bool)
-        ):
+        elif not (isinstance(value, (float, int, str, bool))):
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"The parameter with the following key chain: {key_chain} is of type '{type(value)}'. The only "
                f"trackable types are: float, int, str and bool."