mlrun 1.6.0rc13__py3-none-any.whl → 1.6.0rc15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; see the package registry advisory for more details.

Files changed (37)
  1. mlrun/__main__.py +7 -2
  2. mlrun/artifacts/__init__.py +7 -1
  3. mlrun/artifacts/base.py +38 -3
  4. mlrun/artifacts/dataset.py +1 -1
  5. mlrun/artifacts/manager.py +5 -5
  6. mlrun/artifacts/model.py +1 -1
  7. mlrun/common/schemas/__init__.py +8 -1
  8. mlrun/common/schemas/artifact.py +36 -1
  9. mlrun/config.py +11 -0
  10. mlrun/datastore/azure_blob.py +37 -79
  11. mlrun/datastore/datastore_profile.py +2 -1
  12. mlrun/datastore/store_resources.py +2 -3
  13. mlrun/datastore/targets.py +3 -3
  14. mlrun/db/base.py +8 -5
  15. mlrun/db/httpdb.py +151 -71
  16. mlrun/db/nopdb.py +6 -3
  17. mlrun/feature_store/feature_vector.py +1 -1
  18. mlrun/feature_store/steps.py +2 -2
  19. mlrun/frameworks/_common/model_handler.py +1 -1
  20. mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +0 -1
  21. mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +1 -1
  22. mlrun/frameworks/sklearn/metric.py +0 -1
  23. mlrun/frameworks/tf_keras/mlrun_interface.py +1 -2
  24. mlrun/model_monitoring/application.py +20 -27
  25. mlrun/projects/pipelines.py +5 -5
  26. mlrun/projects/project.py +3 -3
  27. mlrun/runtimes/constants.py +10 -0
  28. mlrun/runtimes/local.py +2 -3
  29. mlrun/utils/db.py +6 -5
  30. mlrun/utils/helpers.py +53 -9
  31. mlrun/utils/version/version.json +2 -2
  32. {mlrun-1.6.0rc13.dist-info → mlrun-1.6.0rc15.dist-info}/METADATA +26 -30
  33. {mlrun-1.6.0rc13.dist-info → mlrun-1.6.0rc15.dist-info}/RECORD +37 -37
  34. {mlrun-1.6.0rc13.dist-info → mlrun-1.6.0rc15.dist-info}/LICENSE +0 -0
  35. {mlrun-1.6.0rc13.dist-info → mlrun-1.6.0rc15.dist-info}/WHEEL +0 -0
  36. {mlrun-1.6.0rc13.dist-info → mlrun-1.6.0rc15.dist-info}/entry_points.txt +0 -0
  37. {mlrun-1.6.0rc13.dist-info → mlrun-1.6.0rc15.dist-info}/top_level.txt +0 -0
mlrun/db/httpdb.py CHANGED
@@ -18,6 +18,7 @@ import tempfile
18
18
  import time
19
19
  import traceback
20
20
  import typing
21
+ import warnings
21
22
  from datetime import datetime, timedelta
22
23
  from os import path, remove
23
24
  from typing import Dict, List, Optional, Union
@@ -280,9 +281,12 @@ class HTTPRunDB(RunDBInterface):
280
281
  retry_on_post=retry_on_post,
281
282
  )
282
283
 
283
- def _path_of(self, prefix, project, uid):
284
+ def _path_of(self, resource, project, uid=None):
284
285
  project = project or config.default_project
285
- return f"{prefix}/{project}/{uid}"
286
+ _path = f"projects/{project}/{resource}"
287
+ if uid:
288
+ _path += f"/{uid}"
289
+ return _path
286
290
 
287
291
  def _is_retry_on_post_allowed(self, method, path: str):
288
292
  """
@@ -477,28 +481,37 @@ class HTTPRunDB(RunDBInterface):
477
481
  if not body:
478
482
  return
479
483
 
480
- path = self._path_of("log", project, uid)
484
+ path = self._path_of("logs", project, uid)
481
485
  params = {"append": bool2str(append)}
482
486
  error = f"store log {project}/{uid}"
483
487
  self.api_call("POST", path, error, params, body)
484
488
 
485
- def get_log(self, uid, project="", offset=0, size=-1):
486
- """Retrieve a log.
489
+ def get_log(self, uid, project="", offset=0, size=None):
490
+ """Retrieve 1 MB data of log.
487
491
 
488
492
  :param uid: Log unique ID
489
493
  :param project: Project name for which the log belongs
490
494
  :param offset: Retrieve partial log, get up to ``size`` bytes starting at offset ``offset``
491
495
  from beginning of log
492
- :param size: See ``offset``. If set to ``-1`` (the default) will retrieve all data to end of log.
496
+ :param size: If set to ``-1`` will retrieve and print all data to end of the log by chunks of 1MB each.
493
497
  :returns: The following objects:
494
498
 
495
499
  - state - The state of the runtime object which generates this log, if it exists. In case no known state
496
500
  exists, this will be ``unknown``.
497
501
  - content - The actual log content.
502
+ * in case size = -1, return the state and the final offset
498
503
  """
504
+ if size is None:
505
+ size = int(mlrun.mlconf.httpdb.logs.pull_logs_default_size_limit)
506
+ elif size == -1:
507
+ logger.warning(
508
+ "Retrieving all logs. This may be inefficient and can result in a large log."
509
+ )
510
+ state, offset = self.watch_log(uid, project, watch=False, offset=offset)
511
+ return state, offset
499
512
 
500
513
  params = {"offset": offset, "size": size}
501
- path = self._path_of("log", project, uid)
514
+ path = self._path_of("logs", project, uid)
502
515
  error = f"get log {project}/{uid}"
503
516
  resp = self.api_call("GET", path, error, params=params)
504
517
  if resp.headers:
@@ -508,46 +521,50 @@ class HTTPRunDB(RunDBInterface):
508
521
  return "unknown", resp.content
509
522
 
510
523
  def watch_log(self, uid, project="", watch=True, offset=0):
511
- """Retrieve logs of a running process, and watch the progress of the execution until it completes. This
512
- method will print out the logs and continue to periodically poll for, and print, new logs as long as the
513
- state of the runtime which generates this log is either ``pending`` or ``running``.
524
+ """Retrieve logs of a running process by chunks of 1MB, and watch the progress of the execution until it
525
+ completes. This method will print out the logs and continue to periodically poll for, and print,
526
+ new logs as long as the state of the runtime which generates this log is either ``pending`` or ``running``.
514
527
 
515
528
  :param uid: The uid of the log object to watch.
516
529
  :param project: Project that the log belongs to.
517
530
  :param watch: If set to ``True`` will continue tracking the log as described above. Otherwise this function
518
531
  is practically equivalent to the :py:func:`~get_log` function.
519
532
  :param offset: Minimal offset in the log to watch.
520
- :returns: The final state of the log being watched.
533
+ :returns: The final state of the log being watched and the final offset.
521
534
  """
522
535
 
523
536
  state, text = self.get_log(uid, project, offset=offset)
524
537
  if text:
525
538
  print(text.decode(errors=mlrun.mlconf.httpdb.logs.decode.errors))
526
- if watch:
527
- nil_resp = 0
528
- while state in ["pending", "running"]:
529
- offset += len(text)
530
- # if we get 3 nil responses in a row, increase the sleep time to 10 seconds
531
- # TODO: refactor this to use a conditional backoff mechanism
532
- if nil_resp < 3:
533
- time.sleep(int(mlrun.mlconf.httpdb.logs.pull_logs_default_interval))
534
- else:
535
- time.sleep(
536
- int(
537
- mlrun.mlconf.httpdb.logs.pull_logs_backoff_no_logs_default_interval
538
- )
539
- )
540
- state, text = self.get_log(uid, project, offset=offset)
541
- if text:
542
- nil_resp = 0
543
- print(
544
- text.decode(errors=mlrun.mlconf.httpdb.logs.decode.errors),
545
- end="",
546
- )
547
- else:
548
- nil_resp += 1
549
- else:
539
+ nil_resp = 0
540
+ while True:
550
541
  offset += len(text)
542
+ # if we get 3 nil responses in a row, increase the sleep time to 10 seconds
543
+ # TODO: refactor this to use a conditional backoff mechanism
544
+ if nil_resp < 3:
545
+ time.sleep(int(mlrun.mlconf.httpdb.logs.pull_logs_default_interval))
546
+ else:
547
+ time.sleep(
548
+ int(
549
+ mlrun.mlconf.httpdb.logs.pull_logs_backoff_no_logs_default_interval
550
+ )
551
+ )
552
+ state, text = self.get_log(uid, project, offset=offset)
553
+ if text:
554
+ nil_resp = 0
555
+ print(
556
+ text.decode(errors=mlrun.mlconf.httpdb.logs.decode.errors),
557
+ end="",
558
+ )
559
+ else:
560
+ nil_resp += 1
561
+
562
+ if watch and state in ["pending", "running"]:
563
+ continue
564
+ else:
565
+ # the whole log was retrieved
566
+ if len(text) == 0:
567
+ break
551
568
 
552
569
  return state, offset
553
570
 
@@ -555,7 +572,7 @@ class HTTPRunDB(RunDBInterface):
555
572
  """Store run details in the DB. This method is usually called from within other :py:mod:`mlrun` flows
556
573
  and not called directly by the user."""
557
574
 
558
- path = self._path_of("run", project, uid)
575
+ path = self._path_of("runs", project, uid)
559
576
  params = {"iter": iter}
560
577
  error = f"store run {project}/{uid}"
561
578
  body = _as_json(struct)
@@ -564,7 +581,7 @@ class HTTPRunDB(RunDBInterface):
564
581
  def update_run(self, updates: dict, uid, project="", iter=0, timeout=45):
565
582
  """Update the details of a stored run in the DB."""
566
583
 
567
- path = self._path_of("run", project, uid)
584
+ path = self._path_of("runs", project, uid)
568
585
  params = {"iter": iter}
569
586
  error = f"update run {project}/{uid}"
570
587
  body = _as_json(updates)
@@ -605,7 +622,7 @@ class HTTPRunDB(RunDBInterface):
605
622
  :param iter: Iteration within a specific execution.
606
623
  """
607
624
 
608
- path = self._path_of("run", project, uid)
625
+ path = self._path_of("runs", project, uid)
609
626
  params = {"iter": iter}
610
627
  error = f"get run {project}/{uid}"
611
628
  resp = self.api_call("GET", path, error, params=params)
@@ -619,7 +636,7 @@ class HTTPRunDB(RunDBInterface):
619
636
  :param iter: Iteration within a specific task.
620
637
  """
621
638
 
622
- path = self._path_of("run", project, uid)
639
+ path = self._path_of("runs", project, uid)
623
640
  params = {"iter": iter}
624
641
  error = f"del run {project}/{uid}"
625
642
  self.api_call("DELETE", path, error, params=params)
@@ -711,7 +728,6 @@ class HTTPRunDB(RunDBInterface):
711
728
  params = {
712
729
  "name": name,
713
730
  "uid": uid,
714
- "project": project,
715
731
  "label": labels or [],
716
732
  "state": state,
717
733
  "sort": bool2str(sort),
@@ -735,7 +751,8 @@ class HTTPRunDB(RunDBInterface):
735
751
  )
736
752
  )
737
753
  error = "list runs"
738
- resp = self.api_call("GET", "runs", error, params=params)
754
+ _path = self._path_of("runs", project)
755
+ resp = self.api_call("GET", _path, error, params=params)
739
756
  return RunList(resp.json()["runs"])
740
757
 
741
758
  def del_runs(self, name=None, project=None, labels=None, state=None, days_ago=0):
@@ -761,56 +778,113 @@ class HTTPRunDB(RunDBInterface):
761
778
  "days_ago": str(days_ago),
762
779
  }
763
780
  error = "del runs"
764
- self.api_call("DELETE", "runs", error, params=params)
781
+ _path = self._path_of("runs", project)
782
+ self.api_call("DELETE", _path, error, params=params)
765
783
 
766
- def store_artifact(self, key, artifact, uid, iter=None, tag=None, project=""):
784
+ def store_artifact(
785
+ self,
786
+ key,
787
+ artifact,
788
+ # TODO: deprecated, remove in 1.8.0
789
+ uid=None,
790
+ iter=None,
791
+ tag=None,
792
+ project="",
793
+ tree=None,
794
+ ):
767
795
  """Store an artifact in the DB.
768
796
 
769
797
  :param key: Identifying key of the artifact.
770
798
  :param artifact: The actual artifact to store.
771
- :param uid: A unique ID for this specific version of the artifact.
799
+ :param uid: A unique ID for this specific version of the artifact
800
+ (deprecated, artifact uid is generated in the backend use `tree` instead)
772
801
  :param iter: The task iteration which generated this artifact. If ``iter`` is not ``None`` the iteration will
773
802
  be added to the key provided to generate a unique key for the artifact of the specific iteration.
774
803
  :param tag: Tag of the artifact.
775
804
  :param project: Project that the artifact belongs to.
805
+ :param tree: The tree (producer id) which generated this artifact.
776
806
  """
807
+ if uid:
808
+ warnings.warn(
809
+ "'uid' is deprecated in 1.6.0 and will be removed in 1.8.0, use 'tree' instead.",
810
+ # TODO: Remove this in 1.8.0
811
+ FutureWarning,
812
+ )
777
813
 
778
- endpoint_path = f"projects/{project}/artifacts/{uid}/{key}"
779
- params = {
780
- "tag": tag,
781
- }
814
+ # we do this because previously the 'uid' name was used for the 'tree' parameter
815
+ tree = tree or uid
816
+
817
+ endpoint_path = f"projects/{project}/artifacts/{key}"
818
+
819
+ error = f"store artifact {project}/{key}"
820
+
821
+ params = {}
782
822
  if iter:
783
823
  params["iter"] = str(iter)
784
-
785
- error = f"store artifact {project}/{uid}/{key}"
824
+ if tag:
825
+ params["tag"] = tag
826
+ if tree:
827
+ params["tree"] = tree
786
828
 
787
829
  body = _as_json(artifact)
788
- self.api_call("POST", endpoint_path, error, params=params, body=body)
830
+ self.api_call(
831
+ "PUT", endpoint_path, error, body=body, params=params, version="v2"
832
+ )
789
833
 
790
- def read_artifact(self, key, tag=None, iter=None, project=""):
791
- """Read an artifact, identified by its key, tag and iteration."""
834
+ def read_artifact(
835
+ self,
836
+ key,
837
+ tag=None,
838
+ iter=None,
839
+ project="",
840
+ tree=None,
841
+ uid=None,
842
+ ):
843
+ """Read an artifact, identified by its key, tag, tree and iteration.
844
+
845
+ :param key: Identifying key of the artifact.
846
+ :param tag: Tag of the artifact.
847
+ :param iter: The iteration which generated this artifact (where ``iter=0`` means the root iteration).
848
+ :param project: Project that the artifact belongs to.
849
+ :param tree: The tree which generated this artifact.
850
+ :param uid: A unique ID for this specific version of the artifact (the uid that was generated in the backend)
851
+ """
792
852
 
793
853
  project = project or config.default_project
794
854
  tag = tag or "latest"
795
- endpoint_path = f"projects/{project}/artifacts/{key}?tag={tag}"
855
+ endpoint_path = f"projects/{project}/artifacts/{key}"
796
856
  error = f"read artifact {project}/{key}"
797
857
  # explicitly set artifacts format to 'full' since old servers may default to 'legacy'
798
- params = {"format": mlrun.common.schemas.ArtifactsFormat.full.value}
858
+ params = {
859
+ "format": mlrun.common.schemas.ArtifactsFormat.full.value,
860
+ "tag": tag,
861
+ "tree": tree,
862
+ "uid": uid,
863
+ }
799
864
  if iter:
800
865
  params["iter"] = str(iter)
801
- resp = self.api_call("GET", endpoint_path, error, params=params)
802
- return resp.json()["data"]
866
+ resp = self.api_call("GET", endpoint_path, error, params=params, version="v2")
867
+ return resp.json()
868
+
869
+ def del_artifact(self, key, tag=None, project="", tree=None, uid=None):
870
+ """Delete an artifact.
803
871
 
804
- def del_artifact(self, key, tag=None, project=""):
805
- """Delete an artifact."""
872
+ :param key: Identifying key of the artifact.
873
+ :param tag: Tag of the artifact.
874
+ :param project: Project that the artifact belongs to.
875
+ :param tree: The tree which generated this artifact.
876
+ :param uid: A unique ID for this specific version of the artifact (the uid that was generated in the backend)
877
+ """
806
878
 
807
879
  endpoint_path = f"projects/{project}/artifacts/{key}"
808
880
  params = {
809
881
  "key": key,
810
882
  "tag": tag,
883
+ "tree": tree,
884
+ "uid": uid,
811
885
  }
812
886
  error = f"del artifact {project}/{key}"
813
- self.api_call("DELETE", endpoint_path, error, params=params)
887
+ self.api_call("DELETE", endpoint_path, error, params=params, version="v2")
814
888
 
815
889
  def list_artifacts(
816
890
  self,
@@ -824,6 +898,7 @@ class HTTPRunDB(RunDBInterface):
824
898
  best_iteration: bool = False,
825
899
  kind: str = None,
826
900
  category: Union[str, mlrun.common.schemas.ArtifactCategories] = None,
901
+ tree: str = None,
827
902
  ) -> ArtifactList:
828
903
  """List artifacts filtered by various parameters.
829
904
 
@@ -852,6 +927,7 @@ class HTTPRunDB(RunDBInterface):
852
927
  from that iteration. If using ``best_iter``, the ``iter`` parameter must not be used.
853
928
  :param kind: Return artifacts of the requested kind.
854
929
  :param category: Return artifacts of the requested category.
930
+ :param tree: Return artifacts of the requested tree.
855
931
  """
856
932
 
857
933
  project = project or config.default_project
@@ -868,16 +944,19 @@ class HTTPRunDB(RunDBInterface):
868
944
  "best-iteration": best_iteration,
869
945
  "kind": kind,
870
946
  "category": category,
947
+ "tree": tree,
871
948
  "format": mlrun.common.schemas.ArtifactsFormat.full.value,
872
949
  }
873
950
  error = "list artifacts"
874
951
  endpoint_path = f"projects/{project}/artifacts"
875
- resp = self.api_call("GET", endpoint_path, error, params=params)
952
+ resp = self.api_call("GET", endpoint_path, error, params=params, version="v2")
876
953
  values = ArtifactList(resp.json()["artifacts"])
877
954
  values.tag = tag
878
955
  return values
879
956
 
880
- def del_artifacts(self, name=None, project=None, tag=None, labels=None, days_ago=0):
957
+ def del_artifacts(
958
+ self, name=None, project=None, tag=None, labels=None, days_ago=0, tree=None
959
+ ):
881
960
  """Delete artifacts referenced by the parameters.
882
961
 
883
962
  :param name: Name of artifacts to delete. Note that this is a like query, and is case-insensitive. See
@@ -891,12 +970,13 @@ class HTTPRunDB(RunDBInterface):
891
970
  params = {
892
971
  "name": name,
893
972
  "tag": tag,
973
+ "tree": tree,
894
974
  "label": labels or [],
895
975
  "days_ago": str(days_ago),
896
976
  }
897
977
  error = "del artifacts"
898
978
  endpoint_path = f"projects/{project}/artifacts"
899
- self.api_call("DELETE", endpoint_path, error, params=params)
979
+ self.api_call("DELETE", endpoint_path, error, params=params, version="v2")
900
980
 
901
981
  def list_artifact_tags(
902
982
  self,
@@ -3414,9 +3494,9 @@ class HTTPRunDB(RunDBInterface):
3414
3494
  self, name: str, project: str
3415
3495
  ) -> Optional[mlrun.common.schemas.DatastoreProfile]:
3416
3496
  project = project or config.default_project
3417
- path = self._path_of("projects", project, "datastore-profiles") + f"/{name}"
3497
+ _path = self._path_of("datastore-profiles", project, name)
3418
3498
 
3419
- res = self.api_call(method="GET", path=path)
3499
+ res = self.api_call(method="GET", path=_path)
3420
3500
  if res:
3421
3501
  public_wrapper = res.json()
3422
3502
  datastore = DatastoreProfile2Json.create_from_json(
@@ -3427,17 +3507,17 @@ class HTTPRunDB(RunDBInterface):
3427
3507
 
3428
3508
  def delete_datastore_profile(self, name: str, project: str):
3429
3509
  project = project or config.default_project
3430
- path = self._path_of("projects", project, "datastore-profiles") + f"/{name}"
3431
- self.api_call(method="DELETE", path=path)
3510
+ _path = self._path_of("datastore-profiles", project, name)
3511
+ self.api_call(method="DELETE", path=_path)
3432
3512
  return None
3433
3513
 
3434
3514
  def list_datastore_profiles(
3435
3515
  self, project: str
3436
3516
  ) -> List[mlrun.common.schemas.DatastoreProfile]:
3437
3517
  project = project or config.default_project
3438
- path = self._path_of("projects", project, "datastore-profiles")
3518
+ _path = self._path_of("datastore-profiles", project)
3439
3519
 
3440
- res = self.api_call(method="GET", path=path)
3520
+ res = self.api_call(method="GET", path=_path)
3441
3521
  if res:
3442
3522
  public_wrapper = res.json()
3443
3523
  datastores = [
@@ -3455,9 +3535,9 @@ class HTTPRunDB(RunDBInterface):
3455
3535
  :returns: None
3456
3536
  """
3457
3537
  project = project or config.default_project
3458
- path = self._path_of("projects", project, "datastore-profiles")
3538
+ _path = self._path_of("datastore-profiles", project)
3459
3539
 
3460
- self.api_call(method="PUT", path=path, json=profile.dict())
3540
+ self.api_call(method="PUT", path=_path, json=profile.dict())
3461
3541
 
3462
3542
 
3463
3543
  def _as_json(obj):
mlrun/db/nopdb.py CHANGED
@@ -104,10 +104,12 @@ class NopDB(RunDBInterface):
104
104
  def del_runs(self, name="", project="", labels=None, state="", days_ago=0):
105
105
  pass
106
106
 
107
- def store_artifact(self, key, artifact, uid, iter=None, tag="", project=""):
107
+ def store_artifact(
108
+ self, key, artifact, uid=None, iter=None, tag="", project="", tree=None
109
+ ):
108
110
  pass
109
111
 
110
- def read_artifact(self, key, tag="", iter=None, project=""):
112
+ def read_artifact(self, key, tag="", iter=None, project="", tree=None, uid=None):
111
113
  pass
112
114
 
113
115
  def list_artifacts(
@@ -122,10 +124,11 @@ class NopDB(RunDBInterface):
122
124
  best_iteration: bool = False,
123
125
  kind: str = None,
124
126
  category: Union[str, mlrun.common.schemas.ArtifactCategories] = None,
127
+ tree: str = None,
125
128
  ):
126
129
  pass
127
130
 
128
- def del_artifact(self, key, tag="", project=""):
131
+ def del_artifact(self, key, tag="", project="", tree=None, uid=None):
129
132
  pass
130
133
 
131
134
  def del_artifacts(self, name="", project="", tag="", labels=None):
@@ -938,7 +938,7 @@ class OnlineVectorService:
938
938
  for name in data.keys():
939
939
  v = data[name]
940
940
  if v is None or (
941
- type(v) == float and (np.isinf(v) or np.isnan(v))
941
+ isinstance(v, float) and (np.isinf(v) or np.isnan(v))
942
942
  ):
943
943
  data[name] = self._impute_values.get(name, v)
944
944
  if not self.vector.spec.with_indexes:
@@ -254,7 +254,7 @@ class MapValues(StepToDict, MLRunStep):
254
254
  source_column_names = df.columns
255
255
  for column, column_map in self.mapping.items():
256
256
  new_column_name = self._get_feature_name(column)
257
- if not self.get_ranges_key() in column_map:
257
+ if self.get_ranges_key() not in column_map:
258
258
  if column not in source_column_names:
259
259
  continue
260
260
  mapping_expr = create_map([lit(x) for x in chain(*column_map.items())])
@@ -330,7 +330,7 @@ class MapValues(StepToDict, MLRunStep):
330
330
  def validate_args(cls, feature_set, **kwargs):
331
331
  mapping = kwargs.get("mapping", [])
332
332
  for column, column_map in mapping.items():
333
- if not cls.get_ranges_key() in column_map:
333
+ if cls.get_ranges_key() not in column_map:
334
334
  types = set(
335
335
  type(val)
336
336
  for val in column_map.values()
@@ -25,7 +25,7 @@ from typing import Any, Dict, Generic, List, Type, Union
25
25
  import numpy as np
26
26
 
27
27
  import mlrun
28
- from mlrun.artifacts import Artifact, ModelArtifact
28
+ from mlrun.artifacts import Artifact
29
29
  from mlrun.execution import MLClientCtx
30
30
  from mlrun.features import Feature
31
31
 
@@ -17,7 +17,6 @@ from abc import ABC
17
17
  import lightgbm as lgb
18
18
 
19
19
  from ..._common import MLRunInterface
20
- from ..._ml_common import MLModelHandler
21
20
  from ..utils import LGBMTypes
22
21
 
23
22
 
@@ -48,7 +48,7 @@ class _MLRunSummaryWriter(SummaryWriter):
48
48
  :param run_name: Not used in this SummaryWriter.
49
49
  """
50
50
  torch._C._log_api_usage_once("tensorboard.logging.add_hparams")
51
- if type(hparam_dict) is not dict or type(metric_dict) is not dict:
51
+ if not isinstance(hparam_dict, dict) or not isinstance(metric_dict, dict):
52
52
  raise TypeError("hparam_dict and metric_dict should be dictionary.")
53
53
  exp, ssi, sei = hparams(hparam_dict, metric_dict)
54
54
  self._get_file_writer().add_summary(exp)
@@ -15,7 +15,6 @@
15
15
  import importlib
16
16
  import json
17
17
  import sys
18
- from types import ModuleType
19
18
  from typing import Callable, Union
20
19
 
21
20
  import mlrun.errors
@@ -15,8 +15,7 @@
15
15
  import importlib
16
16
  import os
17
17
  from abc import ABC
18
- from types import ModuleType
19
- from typing import List, Set, Tuple, Union
18
+ from typing import List, Tuple, Union
20
19
 
21
20
  import tensorflow as tf
22
21
  from tensorflow import keras
@@ -11,11 +11,10 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- #
15
14
 
16
15
  import dataclasses
17
16
  import json
18
- from typing import Any, Dict, List, Tuple, Union
17
+ from typing import Any, Optional, Tuple, Union
19
18
 
20
19
  import numpy as np
21
20
  import pandas as pd
@@ -45,7 +44,7 @@ class ModelMonitoringApplicationResult:
45
44
  :param result_kind: (ResultKindApp) Kind of application result.
46
45
  :param result_status: (ResultStatusApp) Status of the application result.
47
46
  :param result_extra_data: (dict) Extra data associated with the application result.
48
-
47
+ :param _current_stats: (dict) Current statistics of the data.
49
48
  """
50
49
 
51
50
  application_name: str
@@ -57,7 +56,7 @@ class ModelMonitoringApplicationResult:
57
56
  result_kind: mm_constant.ResultKindApp
58
57
  result_status: mm_constant.ResultStatusApp
59
58
  result_extra_data: dict = dataclasses.field(default_factory=dict)
60
- _current_stats: dict = None
59
+ _current_stats: dict = dataclasses.field(default_factory=dict)
61
60
 
62
61
  def to_dict(self):
63
62
  """
@@ -103,7 +102,7 @@ class ModelMonitoringApplication(StepToDict):
103
102
  latest_request: pd.Timestamp,
104
103
  endpoint_id: str,
105
104
  output_stream_uri: str,
106
- ) -> typing.Union[ModelMonitoringApplicationResult, typing.List[ModelMonitoringApplicationResult]
105
+ ) -> Union[ModelMonitoringApplicationResult, list[ModelMonitoringApplicationResult]
107
106
  ]:
108
107
  self.context.log_artifact(TableArtifact("sample_df_stats", df=sample_df_stats))
109
108
  return ModelMonitoringApplicationResult(
@@ -121,23 +120,23 @@ class ModelMonitoringApplication(StepToDict):
121
120
 
122
121
  kind = "monitoring_application"
123
122
 
124
- def do(self, event: Dict[str, Any]):
123
+ def do(self, event: dict[str, Any]) -> list[ModelMonitoringApplicationResult]:
125
124
  """
126
125
  Process the monitoring event and return application results.
127
126
 
128
127
  :param event: (dict) The monitoring event to process.
129
- :returns: (List[ModelMonitoringApplicationResult]) The application results.
128
+ :returns: (list[ModelMonitoringApplicationResult]) The application results.
130
129
  """
131
130
  resolved_event = self._resolve_event(event)
132
131
  if not (
133
132
  hasattr(self, "context") and isinstance(self.context, mlrun.MLClientCtx)
134
133
  ):
135
134
  self._lazy_init(app_name=resolved_event[0])
136
- # Run application and get the result in the required `ModelMonitoringApplicationResult` format
137
- result = self.run_application(*resolved_event)
138
- # Add current stats to the result as provided in the event
139
- result._current_stats = event[mm_constant.ApplicationEvent.CURRENT_STATS]
140
- return result
135
+ results = self.run_application(*resolved_event)
136
+ results = results if isinstance(results, list) else [results]
137
+ for result in results:
138
+ result._current_stats = event[mm_constant.ApplicationEvent.CURRENT_STATS]
139
+ return results
141
140
 
142
141
  def _lazy_init(self, app_name: str):
143
142
  self.context = self._create_context_for_logging(app_name=app_name)
@@ -154,7 +153,7 @@ class ModelMonitoringApplication(StepToDict):
154
153
  endpoint_id: str,
155
154
  output_stream_uri: str,
156
155
  ) -> Union[
157
- ModelMonitoringApplicationResult, List[ModelMonitoringApplicationResult]
156
+ ModelMonitoringApplicationResult, list[ModelMonitoringApplicationResult]
158
157
  ]:
159
158
  """
160
159
  Implement this method with your custom monitoring logic.
@@ -170,13 +169,13 @@ class ModelMonitoringApplication(StepToDict):
170
169
  :param output_stream_uri: (str) URI of the output stream for results
171
170
 
172
171
  :returns: (ModelMonitoringApplicationResult) or
173
- (List[ModelMonitoringApplicationResult]) of the application results.
172
+ (list[ModelMonitoringApplicationResult]) of the application results.
174
173
  """
175
174
  raise NotImplementedError
176
175
 
177
176
  @staticmethod
178
177
  def _resolve_event(
179
- event: Dict[str, Any],
178
+ event: dict[str, Any],
180
179
  ) -> Tuple[
181
180
  str,
182
181
  pd.DataFrame,
@@ -235,7 +234,7 @@ class ModelMonitoringApplication(StepToDict):
235
234
  return context
236
235
 
237
236
  @staticmethod
238
- def _dict_to_histogram(histogram_dict: Dict[str, Dict[str, Any]]) -> pd.DataFrame:
237
+ def _dict_to_histogram(histogram_dict: dict[str, dict[str, Any]]) -> pd.DataFrame:
239
238
  """
240
239
  Convert histogram dictionary to pandas DataFrame with feature histograms as columns
241
240
 
@@ -262,10 +261,10 @@ class PushToMonitoringWriter(StepToDict):
262
261
 
263
262
  def __init__(
264
263
  self,
265
- project: str = None,
266
- writer_application_name: str = None,
267
- stream_uri: str = None,
268
- name: str = None,
264
+ project: Optional[str] = None,
265
+ writer_application_name: Optional[str] = None,
266
+ stream_uri: Optional[str] = None,
267
+ name: Optional[str] = None,
269
268
  ):
270
269
  """
271
270
  Class for pushing application results to the monitoring writer stream.
@@ -284,19 +283,13 @@ class PushToMonitoringWriter(StepToDict):
284
283
  self.output_stream = None
285
284
  self.name = name or "PushToMonitoringWriter"
286
285
 
287
- def do(
288
- self,
289
- event: Union[
290
- ModelMonitoringApplicationResult, List[ModelMonitoringApplicationResult]
291
- ],
292
- ):
286
+ def do(self, event: list[ModelMonitoringApplicationResult]) -> None:
293
287
  """
294
288
  Push application results to the monitoring writer stream.
295
289
 
296
290
  :param event: Monitoring result(s) to push.
297
291
  """
298
292
  self._lazy_init()
299
- event = event if isinstance(event, List) else [event]
300
293
  for result in event:
301
294
  data = result.to_dict()
302
295
  logger.info(f"Pushing data = {data} \n to stream = {self.stream_uri}")