mlrun-1.3.2rc1-py3-none-any.whl → mlrun-1.3.2rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. The original page linked to further details about this release.

Files changed (93)
  1. mlrun/api/api/deps.py +14 -1
  2. mlrun/api/api/endpoints/frontend_spec.py +0 -2
  3. mlrun/api/api/endpoints/functions.py +15 -27
  4. mlrun/api/api/endpoints/grafana_proxy.py +435 -74
  5. mlrun/api/api/endpoints/healthz.py +5 -18
  6. mlrun/api/api/endpoints/model_endpoints.py +33 -37
  7. mlrun/api/api/utils.py +6 -13
  8. mlrun/api/crud/__init__.py +14 -16
  9. mlrun/api/crud/logs.py +5 -7
  10. mlrun/api/crud/model_monitoring/__init__.py +2 -2
  11. mlrun/api/crud/model_monitoring/model_endpoint_store.py +847 -0
  12. mlrun/api/crud/model_monitoring/model_endpoints.py +105 -328
  13. mlrun/api/crud/pipelines.py +2 -3
  14. mlrun/api/db/sqldb/models/models_mysql.py +52 -19
  15. mlrun/api/db/sqldb/models/models_sqlite.py +52 -19
  16. mlrun/api/db/sqldb/session.py +19 -26
  17. mlrun/api/schemas/__init__.py +2 -0
  18. mlrun/api/schemas/constants.py +0 -13
  19. mlrun/api/schemas/frontend_spec.py +0 -1
  20. mlrun/api/schemas/model_endpoints.py +38 -195
  21. mlrun/api/schemas/schedule.py +2 -2
  22. mlrun/api/utils/clients/log_collector.py +5 -0
  23. mlrun/builder.py +9 -41
  24. mlrun/config.py +1 -76
  25. mlrun/data_types/__init__.py +1 -6
  26. mlrun/data_types/data_types.py +1 -3
  27. mlrun/datastore/__init__.py +2 -9
  28. mlrun/datastore/sources.py +20 -25
  29. mlrun/datastore/store_resources.py +1 -1
  30. mlrun/datastore/targets.py +34 -67
  31. mlrun/datastore/utils.py +4 -26
  32. mlrun/db/base.py +2 -4
  33. mlrun/db/filedb.py +5 -13
  34. mlrun/db/httpdb.py +32 -64
  35. mlrun/db/sqldb.py +2 -4
  36. mlrun/errors.py +0 -5
  37. mlrun/execution.py +0 -2
  38. mlrun/feature_store/api.py +8 -24
  39. mlrun/feature_store/feature_set.py +6 -28
  40. mlrun/feature_store/feature_vector.py +0 -2
  41. mlrun/feature_store/ingestion.py +11 -8
  42. mlrun/feature_store/retrieval/base.py +43 -271
  43. mlrun/feature_store/retrieval/dask_merger.py +153 -55
  44. mlrun/feature_store/retrieval/job.py +3 -12
  45. mlrun/feature_store/retrieval/local_merger.py +130 -48
  46. mlrun/feature_store/retrieval/spark_merger.py +125 -126
  47. mlrun/features.py +2 -7
  48. mlrun/model_monitoring/constants.py +6 -48
  49. mlrun/model_monitoring/helpers.py +35 -118
  50. mlrun/model_monitoring/model_monitoring_batch.py +260 -293
  51. mlrun/model_monitoring/stream_processing_fs.py +253 -220
  52. mlrun/platforms/iguazio.py +0 -33
  53. mlrun/projects/project.py +72 -34
  54. mlrun/runtimes/base.py +0 -5
  55. mlrun/runtimes/daskjob.py +0 -2
  56. mlrun/runtimes/function.py +3 -29
  57. mlrun/runtimes/kubejob.py +15 -39
  58. mlrun/runtimes/local.py +45 -7
  59. mlrun/runtimes/mpijob/abstract.py +0 -2
  60. mlrun/runtimes/mpijob/v1.py +0 -2
  61. mlrun/runtimes/pod.py +0 -2
  62. mlrun/runtimes/remotesparkjob.py +0 -2
  63. mlrun/runtimes/serving.py +0 -6
  64. mlrun/runtimes/sparkjob/abstract.py +2 -39
  65. mlrun/runtimes/sparkjob/spark3job.py +0 -2
  66. mlrun/serving/__init__.py +1 -2
  67. mlrun/serving/routers.py +35 -35
  68. mlrun/serving/server.py +12 -22
  69. mlrun/serving/states.py +30 -162
  70. mlrun/serving/v2_serving.py +10 -13
  71. mlrun/utils/clones.py +1 -1
  72. mlrun/utils/model_monitoring.py +96 -122
  73. mlrun/utils/version/version.json +2 -2
  74. {mlrun-1.3.2rc1.dist-info → mlrun-1.3.2rc2.dist-info}/METADATA +27 -23
  75. {mlrun-1.3.2rc1.dist-info → mlrun-1.3.2rc2.dist-info}/RECORD +79 -92
  76. mlrun/api/crud/model_monitoring/grafana.py +0 -427
  77. mlrun/datastore/spark_udf.py +0 -40
  78. mlrun/model_monitoring/__init__.py +0 -44
  79. mlrun/model_monitoring/common.py +0 -112
  80. mlrun/model_monitoring/model_endpoint.py +0 -141
  81. mlrun/model_monitoring/stores/__init__.py +0 -106
  82. mlrun/model_monitoring/stores/kv_model_endpoint_store.py +0 -448
  83. mlrun/model_monitoring/stores/model_endpoint_store.py +0 -147
  84. mlrun/model_monitoring/stores/models/__init__.py +0 -23
  85. mlrun/model_monitoring/stores/models/base.py +0 -18
  86. mlrun/model_monitoring/stores/models/mysql.py +0 -100
  87. mlrun/model_monitoring/stores/models/sqlite.py +0 -98
  88. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -375
  89. mlrun/utils/db.py +0 -52
  90. {mlrun-1.3.2rc1.dist-info → mlrun-1.3.2rc2.dist-info}/LICENSE +0 -0
  91. {mlrun-1.3.2rc1.dist-info → mlrun-1.3.2rc2.dist-info}/WHEEL +0 -0
  92. {mlrun-1.3.2rc1.dist-info → mlrun-1.3.2rc2.dist-info}/entry_points.txt +0 -0
  93. {mlrun-1.3.2rc1.dist-info → mlrun-1.3.2rc2.dist-info}/top_level.txt +0 -0
@@ -31,12 +31,11 @@ import mlrun
31
31
  import mlrun.api.schemas
32
32
  import mlrun.data_types.infer
33
33
  import mlrun.feature_store as fstore
34
- import mlrun.model_monitoring
35
- import mlrun.model_monitoring.stores
36
34
  import mlrun.run
37
35
  import mlrun.utils.helpers
38
36
  import mlrun.utils.model_monitoring
39
37
  import mlrun.utils.v3io_clients
38
+ from mlrun.model_monitoring.constants import EventFieldType
40
39
  from mlrun.utils import logger
41
40
 
42
41
 
@@ -462,7 +461,6 @@ def calculate_inputs_statistics(
462
461
 
463
462
  :returns: The calculated statistics of the inputs data.
464
463
  """
465
-
466
464
  # Use `DFDataInfer` to calculate the statistics over the inputs:
467
465
  inputs_statistics = mlrun.data_types.infer.DFDataInfer.get_stats(
468
466
  df=inputs,
@@ -495,6 +493,8 @@ class BatchProcessor:
495
493
  self,
496
494
  context: mlrun.run.MLClientCtx,
497
495
  project: str,
496
+ model_monitoring_access_key: str,
497
+ v3io_access_key: str,
498
498
  ):
499
499
 
500
500
  """
@@ -502,16 +502,60 @@ class BatchProcessor:
502
502
 
503
503
  :param context: An MLRun context.
504
504
  :param project: Project name.
505
+ :param model_monitoring_access_key: Access key to apply the model monitoring process.
506
+ :param v3io_access_key: Token key for v3io.
505
507
  """
506
508
  self.context = context
507
509
  self.project = project
508
510
 
511
+ self.v3io_access_key = v3io_access_key
512
+ self.model_monitoring_access_key = (
513
+ model_monitoring_access_key or v3io_access_key
514
+ )
515
+
509
516
  # Initialize virtual drift object
510
517
  self.virtual_drift = VirtualDrift(inf_capping=10)
511
518
 
519
+ # Define the required paths for the project objects.
520
+ # Note that the kv table, tsdb, and the input stream paths are located at the default location
521
+ # while the parquet path is located at the user-space location
522
+ template = mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default
523
+ kv_path = template.format(project=self.project, kind="endpoints")
524
+ (
525
+ _,
526
+ self.kv_container,
527
+ self.kv_path,
528
+ ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(kv_path)
529
+ tsdb_path = template.format(project=project, kind="events")
530
+ (
531
+ _,
532
+ self.tsdb_container,
533
+ self.tsdb_path,
534
+ ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(tsdb_path)
535
+ stream_path = template.format(project=self.project, kind="log_stream")
536
+ (
537
+ _,
538
+ self.stream_container,
539
+ self.stream_path,
540
+ ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(stream_path)
541
+ self.parquet_path = (
542
+ mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format(
543
+ project=project, kind="parquet"
544
+ )
545
+ )
546
+
512
547
  logger.info(
513
548
  "Initializing BatchProcessor",
514
549
  project=project,
550
+ model_monitoring_access_key_initalized=bool(model_monitoring_access_key),
551
+ v3io_access_key_initialized=bool(v3io_access_key),
552
+ parquet_path=self.parquet_path,
553
+ kv_container=self.kv_container,
554
+ kv_path=self.kv_path,
555
+ tsdb_container=self.tsdb_container,
556
+ tsdb_path=self.tsdb_path,
557
+ stream_container=self.stream_container,
558
+ stream_path=self.stream_path,
515
559
  )
516
560
 
517
561
  # Get drift thresholds from the model monitoring configuration
@@ -523,54 +567,7 @@ class BatchProcessor:
523
567
  )
524
568
 
525
569
  # Get a runtime database
526
-
527
- self.db = mlrun.model_monitoring.stores.get_model_endpoint_store(
528
- project=project
529
- )
530
-
531
- if not mlrun.mlconf.is_ce_mode():
532
- # TODO: Once there is a time series DB alternative in a non-CE deployment, we need to update this if
533
- # statement to be applied only for V3IO TSDB
534
- self._initialize_v3io_configurations()
535
-
536
- # If an error occurs, it will be raised using the following argument
537
- self.exception = None
538
-
539
- # Get the batch interval range
540
- self.batch_dict = context.parameters[
541
- mlrun.model_monitoring.EventFieldType.BATCH_INTERVALS_DICT
542
- ]
543
-
544
- # TODO: This will be removed in 1.5.0 once the job params can be parsed with different types
545
- # Convert batch dict string into a dictionary
546
- if isinstance(self.batch_dict, str):
547
- self._parse_batch_dict_str()
548
-
549
- def _initialize_v3io_configurations(self):
550
- self.v3io_access_key = os.environ.get("V3IO_ACCESS_KEY")
551
- self.model_monitoring_access_key = (
552
- os.environ.get("MODEL_MONITORING_ACCESS_KEY") or self.v3io_access_key
553
- )
554
-
555
- # Define the required paths for the project objects
556
- tsdb_path = mlrun.mlconf.get_model_monitoring_file_target_path(
557
- project=self.project, kind=mlrun.model_monitoring.FileTargetKind.EVENTS
558
- )
559
- (
560
- _,
561
- self.tsdb_container,
562
- self.tsdb_path,
563
- ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(tsdb_path)
564
- # stream_path = template.format(project=self.project, kind="log_stream")
565
- stream_path = mlrun.mlconf.get_model_monitoring_file_target_path(
566
- project=self.project,
567
- kind=mlrun.model_monitoring.FileTargetKind.LOG_STREAM,
568
- )
569
- (
570
- _,
571
- self.stream_container,
572
- self.stream_path,
573
- ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(stream_path)
570
+ self.db = mlrun.get_run_db()
574
571
 
575
572
  # Get the frames clients based on the v3io configuration
576
573
  # it will be used later for writing the results into the tsdb
@@ -583,26 +580,33 @@ class BatchProcessor:
583
580
  token=self.v3io_access_key,
584
581
  )
585
582
 
583
+ # If an error occurs, it will be raised using the following argument
584
+ self.exception = None
585
+
586
+ # Get the batch interval range
587
+ self.batch_dict = context.parameters[EventFieldType.BATCH_INTERVALS_DICT]
588
+
589
+ # TODO: This will be removed in 1.2.0 once the job params can be parsed with different types
590
+ # Convert batch dict string into a dictionary
591
+ if isinstance(self.batch_dict, str):
592
+ self._parse_batch_dict_str()
593
+
586
594
  def post_init(self):
587
595
  """
588
596
  Preprocess of the batch processing.
589
597
  """
590
598
 
591
- if not mlrun.mlconf.is_ce_mode():
592
- # Create v3io stream based on the input stream
593
- response = self.v3io.create_stream(
594
- container=self.stream_container,
595
- path=self.stream_path,
596
- shard_count=1,
597
- raise_for_status=v3io.dataplane.RaiseForStatus.never,
598
- access_key=self.v3io_access_key,
599
- )
599
+ # create v3io stream based on the input stream
600
+ response = self.v3io.create_stream(
601
+ container=self.stream_container,
602
+ path=self.stream_path,
603
+ shard_count=1,
604
+ raise_for_status=v3io.dataplane.RaiseForStatus.never,
605
+ access_key=self.v3io_access_key,
606
+ )
600
607
 
601
- if not (
602
- response.status_code == 400 and "ResourceInUse" in str(response.body)
603
- ):
604
- response.raise_for_status([409, 204, 403])
605
- pass
608
+ if not (response.status_code == 400 and "ResourceInUse" in str(response.body)):
609
+ response.raise_for_status([409, 204, 403])
606
610
 
607
611
  def run(self):
608
612
  """
@@ -610,202 +614,231 @@ class BatchProcessor:
610
614
  """
611
615
  # Get model endpoints (each deployed project has at least 1 serving model):
612
616
  try:
613
- endpoints = self.db.list_model_endpoints()
617
+ endpoints = self.db.list_model_endpoints(self.project)
614
618
  except Exception as e:
615
619
  logger.error("Failed to list endpoints", exc=e)
616
620
  return
617
621
 
618
- for endpoint in endpoints:
622
+ active_endpoints = set()
623
+ for endpoint in endpoints.endpoints:
619
624
  if (
620
- endpoint[mlrun.model_monitoring.EventFieldType.ACTIVE]
621
- and endpoint[mlrun.model_monitoring.EventFieldType.MONITORING_MODE]
622
- == mlrun.model_monitoring.ModelMonitoringMode.enabled.value
625
+ endpoint.spec.active
626
+ and endpoint.spec.monitoring_mode
627
+ == mlrun.api.schemas.ModelMonitoringMode.enabled.value
623
628
  ):
629
+ active_endpoints.add(endpoint.metadata.uid)
630
+
631
+ # perform drift analysis for each model endpoint
632
+ for endpoint_id in active_endpoints:
633
+ try:
634
+
635
+ # Get model endpoint object:
636
+ endpoint = self.db.get_model_endpoint(
637
+ project=self.project, endpoint_id=endpoint_id
638
+ )
639
+
624
640
  # Skip router endpoint:
625
641
  if (
626
- int(endpoint[mlrun.model_monitoring.EventFieldType.ENDPOINT_TYPE])
627
- == mlrun.model_monitoring.EndpointType.ROUTER
642
+ endpoint.status.endpoint_type
643
+ == mlrun.utils.model_monitoring.EndpointType.ROUTER
628
644
  ):
629
- # Router endpoint has no feature stats
630
- logger.info(
631
- f"{endpoint[mlrun.model_monitoring.EventFieldType.UID]} is router skipping"
632
- )
645
+ # endpoint.status.feature_stats is None
646
+ logger.info(f"{endpoint_id} is router skipping")
633
647
  continue
634
- self.update_drift_metrics(endpoint=endpoint)
635
648
 
636
- def update_drift_metrics(self, endpoint: dict):
637
- try:
638
- # Convert feature set into dataframe and get the latest dataset
639
- (
640
- _,
641
- serving_function_name,
642
- _,
643
- _,
644
- ) = mlrun.utils.helpers.parse_versioned_object_uri(
645
- endpoint[mlrun.model_monitoring.EventFieldType.FUNCTION_URI]
646
- )
649
+ # convert feature set into dataframe and get the latest dataset
650
+ (
651
+ _,
652
+ serving_function_name,
653
+ _,
654
+ _,
655
+ ) = mlrun.utils.helpers.parse_versioned_object_uri(
656
+ endpoint.spec.function_uri
657
+ )
647
658
 
648
- model_name = endpoint[mlrun.model_monitoring.EventFieldType.MODEL].replace(
649
- ":", "-"
650
- )
659
+ model_name = endpoint.spec.model.replace(":", "-")
651
660
 
652
- m_fs = fstore.get_feature_set(
653
- f"store://feature-sets/{self.project}/monitoring-{serving_function_name}-{model_name}"
654
- )
661
+ m_fs = fstore.get_feature_set(
662
+ f"store://feature-sets/{self.project}/monitoring-{serving_function_name}-{model_name}"
663
+ )
655
664
 
656
- # Getting batch interval start time and end time
657
- start_time, end_time = self._get_interval_range()
665
+ # Getting batch interval start time and end time
666
+ start_time, end_time = self.get_interval_range()
658
667
 
659
- try:
660
- df = m_fs.to_dataframe(
661
- start_time=start_time,
662
- end_time=end_time,
663
- time_column=mlrun.model_monitoring.EventFieldType.TIMESTAMP,
664
- )
668
+ try:
669
+ df = m_fs.to_dataframe(
670
+ start_time=start_time,
671
+ end_time=end_time,
672
+ time_column="timestamp",
673
+ )
674
+
675
+ if len(df) == 0:
676
+ logger.warn(
677
+ "Not enough model events since the beginning of the batch interval",
678
+ parquet_target=m_fs.status.targets[0].path,
679
+ endpoint=endpoint_id,
680
+ min_rqeuired_events=mlrun.mlconf.model_endpoint_monitoring.parquet_batching_max_events,
681
+ start_time=str(
682
+ datetime.datetime.now() - datetime.timedelta(hours=1)
683
+ ),
684
+ end_time=str(datetime.datetime.now()),
685
+ )
686
+ continue
665
687
 
666
- if len(df) == 0:
688
+ # TODO: The below warn will be removed once the state of the Feature Store target is updated
689
+ # as expected. In that case, the existence of the file will be checked before trying to get
690
+ # the offline data from the feature set.
691
+ # Continue if not enough events provided since the deployment of the model endpoint
692
+ except FileNotFoundError:
667
693
  logger.warn(
668
- "Not enough model events since the beginning of the batch interval",
694
+ "Parquet not found, probably due to not enough model events",
669
695
  parquet_target=m_fs.status.targets[0].path,
670
- endpoint=endpoint[mlrun.model_monitoring.EventFieldType.UID],
696
+ endpoint=endpoint_id,
671
697
  min_rqeuired_events=mlrun.mlconf.model_endpoint_monitoring.parquet_batching_max_events,
672
- start_time=str(
673
- datetime.datetime.now() - datetime.timedelta(hours=1)
674
- ),
675
- end_time=str(datetime.datetime.now()),
676
698
  )
677
- return
678
-
679
- # TODO: The below warn will be removed once the state of the Feature Store target is updated
680
- # as expected. In that case, the existence of the file will be checked before trying to get
681
- # the offline data from the feature set.
682
- # Continue if not enough events provided since the deployment of the model endpoint
683
- except FileNotFoundError:
684
- logger.warn(
685
- "Parquet not found, probably due to not enough model events",
686
- parquet_target=m_fs.status.targets[0].path,
687
- endpoint=endpoint[mlrun.model_monitoring.EventFieldType.UID],
688
- min_rqeuired_events=mlrun.mlconf.model_endpoint_monitoring.parquet_batching_max_events,
689
- )
690
- return
691
-
692
- # Get feature names from monitoring feature set
693
- feature_names = [
694
- feature_name["name"] for feature_name in m_fs.spec.features.to_dict()
695
- ]
696
-
697
- # Create DataFrame based on the input features
698
- stats_columns = [
699
- mlrun.model_monitoring.EventFieldType.TIMESTAMP,
700
- *feature_names,
701
- ]
699
+ continue
702
700
 
703
- # Add label names if provided
704
- if endpoint[mlrun.model_monitoring.EventFieldType.LABEL_NAMES]:
705
- labels = endpoint[mlrun.model_monitoring.EventFieldType.LABEL_NAMES]
706
- if isinstance(labels, str):
707
- labels = json.loads(labels)
708
- stats_columns.extend(labels)
709
- named_features_df = df[stats_columns].copy()
710
-
711
- # Infer feature set stats and schema
712
- fstore.api._infer_from_static_df(
713
- named_features_df,
714
- m_fs,
715
- options=mlrun.data_types.infer.InferOptions.all_stats(),
716
- )
701
+ # Get feature names from monitoring feature set
702
+ feature_names = [
703
+ feature_name["name"]
704
+ for feature_name in m_fs.spec.features.to_dict()
705
+ ]
706
+
707
+ # Create DataFrame based on the input features
708
+ stats_columns = [
709
+ "timestamp",
710
+ *feature_names,
711
+ ]
712
+
713
+ # Add label names if provided
714
+ if endpoint.spec.label_names:
715
+ stats_columns.extend(endpoint.spec.label_names)
716
+
717
+ named_features_df = df[stats_columns].copy()
718
+
719
+ # Infer feature set stats and schema
720
+ fstore.api._infer_from_static_df(
721
+ named_features_df,
722
+ m_fs,
723
+ options=mlrun.data_types.infer.InferOptions.all_stats(),
724
+ )
717
725
 
718
- # Save feature set to apply changes
719
- m_fs.save()
726
+ # Save feature set to apply changes
727
+ m_fs.save()
720
728
 
721
- # Get the timestamp of the latest request:
722
- timestamp = df[mlrun.model_monitoring.EventFieldType.TIMESTAMP].iloc[-1]
729
+ # Get the timestamp of the latest request:
730
+ timestamp = df["timestamp"].iloc[-1]
723
731
 
724
- # Get the feature stats from the model endpoint for reference data
725
- feature_stats = json.loads(
726
- endpoint[mlrun.model_monitoring.EventFieldType.FEATURE_STATS]
727
- )
732
+ # Get the current stats:
733
+ current_stats = calculate_inputs_statistics(
734
+ sample_set_statistics=endpoint.status.feature_stats,
735
+ inputs=named_features_df,
736
+ )
728
737
 
729
- # Get the current stats:
730
- current_stats = calculate_inputs_statistics(
731
- sample_set_statistics=feature_stats,
732
- inputs=named_features_df,
733
- )
738
+ # Compute the drift based on the histogram of the current stats and the histogram of the original
739
+ # feature stats that can be found in the model endpoint object:
740
+ drift_result = self.virtual_drift.compute_drift_from_histograms(
741
+ feature_stats=endpoint.status.feature_stats,
742
+ current_stats=current_stats,
743
+ )
744
+ logger.info("Drift result", drift_result=drift_result)
734
745
 
735
- # Compute the drift based on the histogram of the current stats and the histogram of the original
736
- # feature stats that can be found in the model endpoint object:
737
- drift_result = self.virtual_drift.compute_drift_from_histograms(
738
- feature_stats=feature_stats,
739
- current_stats=current_stats,
740
- )
741
- logger.info("Drift result", drift_result=drift_result)
742
-
743
- # Get drift thresholds from the model configuration:
744
- monitor_configuration = (
745
- json.loads(
746
- endpoint[
747
- mlrun.model_monitoring.EventFieldType.MONITOR_CONFIGURATION
748
- ]
746
+ # Get drift thresholds from the model configuration:
747
+ monitor_configuration = endpoint.spec.monitor_configuration or {}
748
+ possible_drift = monitor_configuration.get(
749
+ "possible_drift", self.default_possible_drift_threshold
750
+ )
751
+ drift_detected = monitor_configuration.get(
752
+ "drift_detected", self.default_drift_detected_threshold
749
753
  )
750
- or {}
751
- )
752
- possible_drift = monitor_configuration.get(
753
- "possible_drift", self.default_possible_drift_threshold
754
- )
755
- drift_detected = monitor_configuration.get(
756
- "drift_detected", self.default_drift_detected_threshold
757
- )
758
754
 
759
- # Check for possible drift based on the results of the statistical metrics defined above:
760
- drift_status, drift_measure = self.virtual_drift.check_for_drift(
761
- metrics_results_dictionary=drift_result,
762
- possible_drift_threshold=possible_drift,
763
- drift_detected_threshold=drift_detected,
764
- )
765
- logger.info(
766
- "Drift status",
767
- endpoint_id=endpoint[mlrun.model_monitoring.EventFieldType.UID],
768
- drift_status=drift_status.value,
769
- drift_measure=drift_measure,
770
- )
755
+ # Check for possible drift based on the results of the statistical metrics defined above:
756
+ drift_status, drift_measure = self.virtual_drift.check_for_drift(
757
+ metrics_results_dictionary=drift_result,
758
+ possible_drift_threshold=possible_drift,
759
+ drift_detected_threshold=drift_detected,
760
+ )
761
+ logger.info(
762
+ "Drift status",
763
+ endpoint_id=endpoint_id,
764
+ drift_status=drift_status.value,
765
+ drift_measure=drift_measure,
766
+ )
771
767
 
772
- attributes = {
773
- "current_stats": json.dumps(current_stats),
774
- "drift_measures": json.dumps(drift_result),
775
- "drift_status": drift_status.value,
776
- }
768
+ # If drift was detected, add the results to the input stream
769
+ if (
770
+ drift_status == DriftStatus.POSSIBLE_DRIFT
771
+ or drift_status == DriftStatus.DRIFT_DETECTED
772
+ ):
773
+ self.v3io.stream.put_records(
774
+ container=self.stream_container,
775
+ stream_path=self.stream_path,
776
+ records=[
777
+ {
778
+ "data": json.dumps(
779
+ {
780
+ "endpoint_id": endpoint_id,
781
+ "drift_status": drift_status.value,
782
+ "drift_measure": drift_measure,
783
+ "drift_per_feature": {**drift_result},
784
+ }
785
+ )
786
+ }
787
+ ],
788
+ )
777
789
 
778
- self.db.update_model_endpoint(
779
- endpoint_id=endpoint[mlrun.model_monitoring.EventFieldType.UID],
780
- attributes=attributes,
781
- )
790
+ attributes = {
791
+ "current_stats": json.dumps(current_stats),
792
+ "drift_measures": json.dumps(drift_result),
793
+ "drift_status": drift_status.value,
794
+ }
782
795
 
783
- if not mlrun.mlconf.is_ce_mode():
784
- # Update drift results in TSDB
785
- self._update_drift_in_input_stream(
786
- endpoint_id=endpoint[mlrun.model_monitoring.EventFieldType.UID],
787
- drift_status=drift_status,
788
- drift_measure=drift_measure,
789
- drift_result=drift_result,
790
- timestamp=timestamp,
791
- )
792
- logger.info(
793
- "Done updating drift measures",
794
- endpoint_id=endpoint[mlrun.model_monitoring.EventFieldType.UID],
796
+ self.db.patch_model_endpoint(
797
+ project=self.project,
798
+ endpoint_id=endpoint_id,
799
+ attributes=attributes,
795
800
  )
796
801
 
797
- except Exception as e:
798
- logger.error(
799
- f"Exception for endpoint {endpoint[mlrun.model_monitoring.EventFieldType.UID]}"
800
- )
801
- self.exception = e
802
+ # Update the results in tsdb:
803
+ tsdb_drift_measures = {
804
+ "endpoint_id": endpoint_id,
805
+ "timestamp": pd.to_datetime(
806
+ timestamp,
807
+ format=EventFieldType.TIME_FORMAT,
808
+ ),
809
+ "record_type": "drift_measures",
810
+ "tvd_mean": drift_result["tvd_mean"],
811
+ "kld_mean": drift_result["kld_mean"],
812
+ "hellinger_mean": drift_result["hellinger_mean"],
813
+ }
814
+
815
+ try:
816
+ self.frames.write(
817
+ backend="tsdb",
818
+ table=self.tsdb_path,
819
+ dfs=pd.DataFrame.from_dict([tsdb_drift_measures]),
820
+ index_cols=["timestamp", "endpoint_id", "record_type"],
821
+ )
822
+ except v3io_frames.errors.Error as err:
823
+ logger.warn(
824
+ "Could not write drift measures to TSDB",
825
+ err=err,
826
+ tsdb_path=self.tsdb_path,
827
+ endpoint=endpoint_id,
828
+ )
829
+
830
+ logger.info("Done updating drift measures", endpoint_id=endpoint_id)
802
831
 
803
- def _get_interval_range(self) -> Tuple[datetime.datetime, datetime.datetime]:
832
+ except Exception as e:
833
+ logger.error(f"Exception for endpoint {endpoint_id}")
834
+ self.exception = e
835
+
836
+ def get_interval_range(self) -> Tuple[datetime.datetime, datetime.datetime]:
804
837
  """Getting batch interval time range"""
805
838
  minutes, hours, days = (
806
- self.batch_dict[mlrun.model_monitoring.EventFieldType.MINUTES],
807
- self.batch_dict[mlrun.model_monitoring.EventFieldType.HOURS],
808
- self.batch_dict[mlrun.model_monitoring.EventFieldType.DAYS],
839
+ self.batch_dict[EventFieldType.MINUTES],
840
+ self.batch_dict[EventFieldType.HOURS],
841
+ self.batch_dict[EventFieldType.DAYS],
809
842
  )
810
843
  start_time = datetime.datetime.now() - datetime.timedelta(
811
844
  minutes=minutes, hours=hours, days=days
@@ -825,79 +858,13 @@ class BatchProcessor:
825
858
  pair_list = pair.split(":")
826
859
  self.batch_dict[pair_list[0]] = float(pair_list[1])
827
860
 
828
- def _update_drift_in_input_stream(
829
- self,
830
- endpoint_id: str,
831
- drift_status: DriftStatus,
832
- drift_measure: float,
833
- drift_result: Dict[str, Dict[str, Any]],
834
- timestamp: pd._libs.tslibs.timestamps.Timestamp,
835
- ):
836
- """Update drift results in input stream.
837
-
838
- :param endpoint_id: The unique id of the model endpoint.
839
- :param drift_status: Drift status result. Possible values can be found under DriftStatus enum class.
840
- :param drift_measure: The drift result (float) based on the mean of the Total Variance Distance and the
841
- Hellinger distance.
842
- :param drift_result: A dictionary that includes the drift results for each feature.
843
- :param timestamp: Pandas Timestamp value.
844
-
845
- """
846
-
847
- if (
848
- drift_status == DriftStatus.POSSIBLE_DRIFT
849
- or drift_status == DriftStatus.DRIFT_DETECTED
850
- ):
851
- self.v3io.stream.put_records(
852
- container=self.stream_container,
853
- stream_path=self.stream_path,
854
- records=[
855
- {
856
- "data": json.dumps(
857
- {
858
- "endpoint_id": endpoint_id,
859
- "drift_status": drift_status.value,
860
- "drift_measure": drift_measure,
861
- "drift_per_feature": {**drift_result},
862
- }
863
- )
864
- }
865
- ],
866
- )
867
-
868
- # Update the results in tsdb:
869
- tsdb_drift_measures = {
870
- "endpoint_id": endpoint_id,
871
- "timestamp": pd.to_datetime(
872
- timestamp,
873
- format=mlrun.model_monitoring.EventFieldType.TIME_FORMAT,
874
- ),
875
- "record_type": "drift_measures",
876
- "tvd_mean": drift_result["tvd_mean"],
877
- "kld_mean": drift_result["kld_mean"],
878
- "hellinger_mean": drift_result["hellinger_mean"],
879
- }
880
-
881
- try:
882
- self.frames.write(
883
- backend="tsdb",
884
- table=self.tsdb_path,
885
- dfs=pd.DataFrame.from_dict([tsdb_drift_measures]),
886
- index_cols=["timestamp", "endpoint_id", "record_type"],
887
- )
888
- except v3io_frames.errors.Error as err:
889
- logger.warn(
890
- "Could not write drift measures to TSDB",
891
- err=err,
892
- tsdb_path=self.tsdb_path,
893
- endpoint=endpoint_id,
894
- )
895
-
896
861
 
897
862
  def handler(context: mlrun.run.MLClientCtx):
898
863
  batch_processor = BatchProcessor(
899
864
  context=context,
900
865
  project=context.project,
866
+ model_monitoring_access_key=os.environ.get("MODEL_MONITORING_ACCESS_KEY"),
867
+ v3io_access_key=os.environ.get("V3IO_ACCESS_KEY"),
901
868
  )
902
869
  batch_processor.post_init()
903
870
  batch_processor.run()