mlrun 1.7.0rc5__py3-none-any.whl → 1.7.0rc7__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic. Click here for more details.

Files changed (75)
  1. mlrun/artifacts/base.py +2 -1
  2. mlrun/artifacts/plots.py +9 -5
  3. mlrun/common/constants.py +6 -0
  4. mlrun/common/schemas/__init__.py +2 -0
  5. mlrun/common/schemas/model_monitoring/__init__.py +4 -0
  6. mlrun/common/schemas/model_monitoring/constants.py +35 -18
  7. mlrun/common/schemas/project.py +1 -0
  8. mlrun/common/types.py +7 -1
  9. mlrun/config.py +19 -6
  10. mlrun/data_types/data_types.py +4 -0
  11. mlrun/datastore/alibaba_oss.py +130 -0
  12. mlrun/datastore/azure_blob.py +4 -5
  13. mlrun/datastore/base.py +22 -16
  14. mlrun/datastore/datastore.py +4 -0
  15. mlrun/datastore/google_cloud_storage.py +1 -1
  16. mlrun/datastore/sources.py +7 -7
  17. mlrun/db/base.py +14 -6
  18. mlrun/db/factory.py +1 -1
  19. mlrun/db/httpdb.py +61 -56
  20. mlrun/db/nopdb.py +3 -0
  21. mlrun/launcher/__init__.py +1 -1
  22. mlrun/launcher/base.py +1 -1
  23. mlrun/launcher/client.py +1 -1
  24. mlrun/launcher/factory.py +1 -1
  25. mlrun/launcher/local.py +1 -1
  26. mlrun/launcher/remote.py +1 -1
  27. mlrun/model.py +1 -0
  28. mlrun/model_monitoring/__init__.py +1 -1
  29. mlrun/model_monitoring/api.py +104 -301
  30. mlrun/model_monitoring/application.py +21 -21
  31. mlrun/model_monitoring/applications/histogram_data_drift.py +130 -40
  32. mlrun/model_monitoring/controller.py +26 -33
  33. mlrun/model_monitoring/db/__init__.py +16 -0
  34. mlrun/model_monitoring/{stores → db/stores}/__init__.py +43 -34
  35. mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
  36. mlrun/model_monitoring/{stores/model_endpoint_store.py → db/stores/base/store.py} +47 -6
  37. mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
  38. mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +49 -0
  39. mlrun/model_monitoring/{stores → db/stores/sqldb}/models/base.py +76 -3
  40. mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +68 -0
  41. mlrun/model_monitoring/{stores → db/stores/sqldb}/models/sqlite.py +13 -1
  42. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +662 -0
  43. mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
  44. mlrun/model_monitoring/{stores/kv_model_endpoint_store.py → db/stores/v3io_kv/kv_store.py} +134 -3
  45. mlrun/model_monitoring/features_drift_table.py +34 -22
  46. mlrun/model_monitoring/helpers.py +45 -6
  47. mlrun/model_monitoring/stream_processing.py +43 -9
  48. mlrun/model_monitoring/tracking_policy.py +7 -1
  49. mlrun/model_monitoring/writer.py +4 -36
  50. mlrun/projects/pipelines.py +13 -1
  51. mlrun/projects/project.py +279 -117
  52. mlrun/run.py +72 -74
  53. mlrun/runtimes/__init__.py +35 -0
  54. mlrun/runtimes/base.py +7 -1
  55. mlrun/runtimes/nuclio/api_gateway.py +188 -61
  56. mlrun/runtimes/nuclio/application/__init__.py +15 -0
  57. mlrun/runtimes/nuclio/application/application.py +283 -0
  58. mlrun/runtimes/nuclio/application/reverse_proxy.go +87 -0
  59. mlrun/runtimes/nuclio/function.py +53 -1
  60. mlrun/runtimes/nuclio/serving.py +28 -32
  61. mlrun/runtimes/pod.py +27 -1
  62. mlrun/serving/server.py +4 -6
  63. mlrun/serving/states.py +41 -33
  64. mlrun/utils/helpers.py +34 -0
  65. mlrun/utils/version/version.json +2 -2
  66. {mlrun-1.7.0rc5.dist-info → mlrun-1.7.0rc7.dist-info}/METADATA +14 -5
  67. {mlrun-1.7.0rc5.dist-info → mlrun-1.7.0rc7.dist-info}/RECORD +71 -64
  68. mlrun/model_monitoring/batch.py +0 -974
  69. mlrun/model_monitoring/stores/models/__init__.py +0 -27
  70. mlrun/model_monitoring/stores/models/mysql.py +0 -34
  71. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -382
  72. {mlrun-1.7.0rc5.dist-info → mlrun-1.7.0rc7.dist-info}/LICENSE +0 -0
  73. {mlrun-1.7.0rc5.dist-info → mlrun-1.7.0rc7.dist-info}/WHEEL +0 -0
  74. {mlrun-1.7.0rc5.dist-info → mlrun-1.7.0rc7.dist-info}/entry_points.txt +0 -0
  75. {mlrun-1.7.0rc5.dist-info → mlrun-1.7.0rc7.dist-info}/top_level.txt +0 -0
@@ -13,8 +13,8 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import hashlib
16
- import json
17
16
  import typing
17
+ import warnings
18
18
  from datetime import datetime
19
19
 
20
20
  import numpy as np
@@ -22,13 +22,13 @@ import pandas as pd
22
22
 
23
23
  import mlrun.artifacts
24
24
  import mlrun.common.helpers
25
+ import mlrun.common.schemas.model_monitoring.constants as mm_consts
25
26
  import mlrun.feature_store
26
- from mlrun.common.schemas.model_monitoring import EventFieldType, ModelMonitoringMode
27
+ import mlrun.model_monitoring.application
28
+ import mlrun.serving
27
29
  from mlrun.data_types.infer import InferOptions, get_df_stats
28
30
  from mlrun.utils import datetime_now, logger
29
31
 
30
- from .batch import VirtualDrift
31
- from .features_drift_table import FeaturesDriftTablePlot
32
32
  from .helpers import update_model_endpoint_last_request
33
33
  from .model_endpoint import ModelEndpoint
34
34
 
@@ -48,7 +48,7 @@ def get_or_create_model_endpoint(
48
48
  sample_set_statistics: dict[str, typing.Any] = None,
49
49
  drift_threshold: float = None,
50
50
  possible_drift_threshold: float = None,
51
- monitoring_mode: ModelMonitoringMode = ModelMonitoringMode.disabled,
51
+ monitoring_mode: mm_consts.ModelMonitoringMode = mm_consts.ModelMonitoringMode.disabled,
52
52
  db_session=None,
53
53
  ) -> ModelEndpoint:
54
54
  """
@@ -128,20 +128,19 @@ def record_results(
128
128
  context: typing.Optional[mlrun.MLClientCtx] = None,
129
129
  infer_results_df: typing.Optional[pd.DataFrame] = None,
130
130
  sample_set_statistics: typing.Optional[dict[str, typing.Any]] = None,
131
- monitoring_mode: ModelMonitoringMode = ModelMonitoringMode.enabled,
131
+ monitoring_mode: mm_consts.ModelMonitoringMode = mm_consts.ModelMonitoringMode.enabled,
132
+ # Deprecated arguments:
132
133
  drift_threshold: typing.Optional[float] = None,
133
134
  possible_drift_threshold: typing.Optional[float] = None,
134
135
  trigger_monitoring_job: bool = False,
135
136
  artifacts_tag: str = "",
136
- default_batch_image="mlrun/mlrun",
137
+ default_batch_image: str = "mlrun/mlrun",
137
138
  ) -> ModelEndpoint:
138
139
  """
139
140
  Write a provided inference dataset to model endpoint parquet target. If not exist, generate a new model endpoint
140
141
  record and use the provided sample set statistics as feature stats that will be used later for the drift analysis.
141
- To manually trigger the monitoring batch job, set `trigger_monitoring_job=True` and then the batch
142
- job will immediately perform drift analysis between the sample set statistics stored in the model and the new
143
- input data (along with the outputs). The drift rule is the value per-feature mean of the TVD and Hellinger scores
144
- according to the provided thresholds.
142
+ To activate model monitoring, run `project.enable_model_monitoring()`. The model monitoring applications will be
143
+ triggered with the recorded data according to a periodic schedule.
145
144
 
146
145
  :param project: Project name.
147
146
  :param model_path: The model Store path.
@@ -160,17 +159,47 @@ def record_results(
160
159
  the current model endpoint.
161
160
  :param monitoring_mode: If enabled, apply model monitoring features on the provided endpoint id. Enabled
162
161
  by default.
163
- :param drift_threshold: The threshold of which to mark drifts.
164
- :param possible_drift_threshold: The threshold of which to mark possible drifts.
165
- :param trigger_monitoring_job: If true, run the batch drift job. If not exists, the monitoring batch function
166
- will be registered through MLRun API with the provided image.
167
- :param artifacts_tag: Tag to use for all the artifacts resulted from the function. Will be relevant
168
- only if the monitoring batch job has been triggered.
169
-
170
- :param default_batch_image: The image that will be used when registering the model monitoring batch job.
162
+ :param drift_threshold: (deprecated) The threshold of which to mark drifts.
163
+ :param possible_drift_threshold: (deprecated) The threshold of which to mark possible drifts.
164
+ :param trigger_monitoring_job: (deprecated) If true, run the batch drift job. If not exists, the monitoring
165
+ batch function will be registered through MLRun API with the provided image.
166
+ :param artifacts_tag: (deprecated) Tag to use for all the artifacts resulted from the function.
167
+ Will be relevant only if the monitoring batch job has been triggered.
168
+ :param default_batch_image: (deprecated) The image that will be used when registering the model monitoring
169
+ batch job.
171
170
 
172
171
  :return: A ModelEndpoint object
173
172
  """
173
+
174
+ if drift_threshold is not None or possible_drift_threshold is not None:
175
+ warnings.warn(
176
+ "Custom drift threshold arguments are deprecated since version "
177
+ "1.7.0 and have no effect. They will be removed in version 1.9.0.\n"
178
+ "To enable the default histogram data drift application, run:\n"
179
+ "`project.enable_model_monitoring()`.",
180
+ FutureWarning,
181
+ )
182
+ if trigger_monitoring_job is not False:
183
+ warnings.warn(
184
+ "`trigger_monitoring_job` argument is deprecated since version "
185
+ "1.7.0 and has no effect. It will be removed in version 1.9.0.\n"
186
+ "To enable the default histogram data drift application, run:\n"
187
+ "`project.enable_model_monitoring()`.",
188
+ FutureWarning,
189
+ )
190
+ if artifacts_tag != "":
191
+ warnings.warn(
192
+ "`artifacts_tag` argument is deprecated since version "
193
+ "1.7.0 and has no effect. It will be removed in version 1.9.0.",
194
+ FutureWarning,
195
+ )
196
+ if default_batch_image != "mlrun/mlrun":
197
+ warnings.warn(
198
+ "`default_batch_image` argument is deprecated since version "
199
+ "1.7.0 and has no effect. It will be removed in version 1.9.0.",
200
+ FutureWarning,
201
+ )
202
+
174
203
  db = mlrun.get_run_db()
175
204
 
176
205
  model_endpoint = get_or_create_model_endpoint(
@@ -181,8 +210,6 @@ def record_results(
181
210
  function_name=function_name,
182
211
  context=context,
183
212
  sample_set_statistics=sample_set_statistics,
184
- drift_threshold=drift_threshold,
185
- possible_drift_threshold=possible_drift_threshold,
186
213
  monitoring_mode=monitoring_mode,
187
214
  db_session=db,
188
215
  )
@@ -206,33 +233,6 @@ def record_results(
206
233
  db=db,
207
234
  )
208
235
 
209
- if trigger_monitoring_job:
210
- # Run the monitoring batch drift job
211
- trigger_drift_batch_job(
212
- project=project,
213
- default_batch_image=default_batch_image,
214
- model_endpoints_ids=[model_endpoint.metadata.uid],
215
- db_session=db,
216
- )
217
-
218
- # Getting drift thresholds if not provided
219
- drift_threshold, possible_drift_threshold = get_drift_thresholds_if_not_none(
220
- model_endpoint=model_endpoint,
221
- drift_threshold=drift_threshold,
222
- possible_drift_threshold=possible_drift_threshold,
223
- )
224
-
225
- perform_drift_analysis(
226
- project=project,
227
- context=context,
228
- sample_set_statistics=model_endpoint.status.feature_stats,
229
- drift_threshold=drift_threshold,
230
- possible_drift_threshold=possible_drift_threshold,
231
- artifacts_tag=artifacts_tag,
232
- endpoint_id=model_endpoint.metadata.uid,
233
- db_session=db,
234
- )
235
-
236
236
  return model_endpoint
237
237
 
238
238
 
@@ -282,7 +282,7 @@ def _model_endpoint_validations(
282
282
  # drift and possible drift thresholds
283
283
  if drift_threshold:
284
284
  current_drift_threshold = model_endpoint.spec.monitor_configuration.get(
285
- EventFieldType.DRIFT_DETECTED_THRESHOLD,
285
+ mm_consts.EventFieldType.DRIFT_DETECTED_THRESHOLD,
286
286
  mlrun.mlconf.model_endpoint_monitoring.drift_thresholds.default.drift_detected,
287
287
  )
288
288
  if current_drift_threshold != drift_threshold:
@@ -293,7 +293,7 @@ def _model_endpoint_validations(
293
293
 
294
294
  if possible_drift_threshold:
295
295
  current_possible_drift_threshold = model_endpoint.spec.monitor_configuration.get(
296
- EventFieldType.POSSIBLE_DRIFT_THRESHOLD,
296
+ mm_consts.EventFieldType.POSSIBLE_DRIFT_THRESHOLD,
297
297
  mlrun.mlconf.model_endpoint_monitoring.drift_thresholds.default.possible_drift,
298
298
  )
299
299
  if current_possible_drift_threshold != possible_drift_threshold:
@@ -303,40 +303,6 @@ def _model_endpoint_validations(
303
303
  )
304
304
 
305
305
 
306
- def get_drift_thresholds_if_not_none(
307
- model_endpoint: ModelEndpoint,
308
- drift_threshold: float = None,
309
- possible_drift_threshold: float = None,
310
- ) -> tuple[float, float]:
311
- """
312
- Get drift and possible drift thresholds. If one of the thresholds is missing, will try to retrieve
313
- it from the `ModelEndpoint` object. If not defined under the `ModelEndpoint` as well, will retrieve it from
314
- the default mlrun configuration.
315
-
316
- :param model_endpoint: `ModelEndpoint` object.
317
- :param drift_threshold: The threshold of which to mark drifts.
318
- :param possible_drift_threshold: The threshold of which to mark possible drifts.
319
-
320
- :return: A Tuple of:
321
- [0] drift threshold as a float
322
- [1] possible drift threshold as a float
323
- """
324
- if not drift_threshold:
325
- # Getting drift threshold value from either model endpoint or monitoring default configurations
326
- drift_threshold = model_endpoint.spec.monitor_configuration.get(
327
- EventFieldType.DRIFT_DETECTED_THRESHOLD,
328
- mlrun.mlconf.model_endpoint_monitoring.drift_thresholds.default.drift_detected,
329
- )
330
- if not possible_drift_threshold:
331
- # Getting possible drift threshold value from either model endpoint or monitoring default configurations
332
- possible_drift_threshold = model_endpoint.spec.monitor_configuration.get(
333
- EventFieldType.POSSIBLE_DRIFT_THRESHOLD,
334
- mlrun.mlconf.model_endpoint_monitoring.drift_thresholds.default.possible_drift,
335
- )
336
-
337
- return drift_threshold, possible_drift_threshold
338
-
339
-
340
306
  def write_monitoring_df(
341
307
  endpoint_id: str,
342
308
  infer_results_df: pd.DataFrame,
@@ -366,14 +332,14 @@ def write_monitoring_df(
366
332
  )
367
333
 
368
334
  # Modify the DataFrame to the required structure that will be used later by the monitoring batch job
369
- if EventFieldType.TIMESTAMP not in infer_results_df.columns:
335
+ if mm_consts.EventFieldType.TIMESTAMP not in infer_results_df.columns:
370
336
  # Initialize timestamp column with the current time
371
- infer_results_df[EventFieldType.TIMESTAMP] = infer_datetime
337
+ infer_results_df[mm_consts.EventFieldType.TIMESTAMP] = infer_datetime
372
338
 
373
339
  # `endpoint_id` is the monitoring feature set entity and therefore it should be defined as the df index before
374
340
  # the ingest process
375
- infer_results_df[EventFieldType.ENDPOINT_ID] = endpoint_id
376
- infer_results_df.set_index(EventFieldType.ENDPOINT_ID, inplace=True)
341
+ infer_results_df[mm_consts.EventFieldType.ENDPOINT_ID] = endpoint_id
342
+ infer_results_df.set_index(mm_consts.EventFieldType.ENDPOINT_ID, inplace=True)
377
343
 
378
344
  monitoring_feature_set.ingest(source=infer_results_df, overwrite=False)
379
345
 
@@ -389,7 +355,7 @@ def _generate_model_endpoint(
389
355
  sample_set_statistics: dict[str, typing.Any],
390
356
  drift_threshold: float,
391
357
  possible_drift_threshold: float,
392
- monitoring_mode: ModelMonitoringMode = ModelMonitoringMode.disabled,
358
+ monitoring_mode: mm_consts.ModelMonitoringMode = mm_consts.ModelMonitoringMode.disabled,
393
359
  ) -> ModelEndpoint:
394
360
  """
395
361
  Write a new model endpoint record.
@@ -428,11 +394,11 @@ def _generate_model_endpoint(
428
394
  model_endpoint.spec.model_class = "drift-analysis"
429
395
  if drift_threshold:
430
396
  model_endpoint.spec.monitor_configuration[
431
- EventFieldType.DRIFT_DETECTED_THRESHOLD
397
+ mm_consts.EventFieldType.DRIFT_DETECTED_THRESHOLD
432
398
  ] = drift_threshold
433
399
  if possible_drift_threshold:
434
400
  model_endpoint.spec.monitor_configuration[
435
- EventFieldType.POSSIBLE_DRIFT_THRESHOLD
401
+ mm_consts.EventFieldType.POSSIBLE_DRIFT_THRESHOLD
436
402
  ] = possible_drift_threshold
437
403
 
438
404
  model_endpoint.spec.monitoring_mode = monitoring_mode
@@ -449,71 +415,6 @@ def _generate_model_endpoint(
449
415
  return db_session.get_model_endpoint(project=project, endpoint_id=endpoint_id)
450
416
 
451
417
 
452
- def trigger_drift_batch_job(
453
- project: str,
454
- default_batch_image="mlrun/mlrun",
455
- model_endpoints_ids: list[str] = None,
456
- batch_intervals_dict: dict[str, float] = None,
457
- db_session=None,
458
- ):
459
- """
460
- Run model monitoring drift analysis job. If not exists, the monitoring batch function will be registered through
461
- MLRun API with the provided image.
462
-
463
- :param project: Project name.
464
- :param default_batch_image: The image that will be used when registering the model monitoring batch job.
465
- :param model_endpoints_ids: List of model endpoints to include in the current run.
466
- :param batch_intervals_dict: Batch interval range (days, hours, minutes). By default, the batch interval is
467
- configured to run through the last hour.
468
- :param db_session: A runtime session that manages the current dialog with the database.
469
-
470
- """
471
- if not model_endpoints_ids:
472
- raise mlrun.errors.MLRunNotFoundError(
473
- "No model endpoints provided",
474
- )
475
- if not db_session:
476
- db_session = mlrun.get_run_db()
477
-
478
- # Register the monitoring batch job (do nothing if already exist) and get the job function as a dictionary
479
- batch_function_dict: dict[str, typing.Any] = db_session.deploy_monitoring_batch_job(
480
- project=project,
481
- default_batch_image=default_batch_image,
482
- )
483
-
484
- # Prepare current run params
485
- job_params = _generate_job_params(
486
- model_endpoints_ids=model_endpoints_ids,
487
- batch_intervals_dict=batch_intervals_dict,
488
- )
489
-
490
- # Generate runtime and trigger the job function
491
- batch_function = mlrun.new_function(runtime=batch_function_dict)
492
- batch_function.run(name="model-monitoring-batch", params=job_params, watch=True)
493
-
494
-
495
- def _generate_job_params(
496
- model_endpoints_ids: list[str],
497
- batch_intervals_dict: dict[str, float] = None,
498
- ):
499
- """
500
- Generate the required params for the model monitoring batch job function.
501
-
502
- :param model_endpoints_ids: List of model endpoints to include in the current run.
503
- :param batch_intervals_dict: Batch interval range (days, hours, minutes). By default, the batch interval is
504
- configured to run through the last hour.
505
-
506
- """
507
- if not batch_intervals_dict:
508
- # Generate default batch intervals dict
509
- batch_intervals_dict = {"minutes": 0, "hours": 1, "days": 0}
510
-
511
- return {
512
- "model_endpoints": model_endpoints_ids,
513
- "batch_intervals_dict": batch_intervals_dict,
514
- }
515
-
516
-
517
418
  def get_sample_set_statistics(
518
419
  sample_set: DatasetType = None,
519
420
  model_artifact_feature_stats: dict = None,
@@ -659,151 +560,6 @@ def read_dataset_as_dataframe(
659
560
  return dataset, label_columns
660
561
 
661
562
 
662
- def perform_drift_analysis(
663
- project: str,
664
- endpoint_id: str,
665
- context: mlrun.MLClientCtx,
666
- sample_set_statistics: dict,
667
- drift_threshold: float,
668
- possible_drift_threshold: float,
669
- artifacts_tag: str = "",
670
- db_session=None,
671
- ) -> None:
672
- """
673
- Calculate drift per feature and produce the drift table artifact for logging post prediction. Note that most of
674
- the calculations were already made through the monitoring batch job.
675
-
676
- :param project: Project name.
677
- :param endpoint_id: Model endpoint unique ID.
678
- :param context: MLRun context. Will log the artifacts.
679
- :param sample_set_statistics: The statistics of the sample set logged along a model.
680
- :param drift_threshold: The threshold of which to mark drifts.
681
- :param possible_drift_threshold: The threshold of which to mark possible drifts.
682
- :param artifacts_tag: Tag to use for all the artifacts resulted from the function.
683
- :param db_session: A runtime session that manages the current dialog with the database.
684
-
685
- """
686
- if not db_session:
687
- db_session = mlrun.get_run_db()
688
-
689
- model_endpoint = db_session.get_model_endpoint(
690
- project=project, endpoint_id=endpoint_id
691
- )
692
-
693
- # Get the drift metrics results along with the feature statistics from the latest batch
694
- metrics = model_endpoint.status.drift_measures
695
- inputs_statistics = model_endpoint.status.current_stats
696
-
697
- inputs_statistics.pop(EventFieldType.TIMESTAMP, None)
698
-
699
- # Calculate drift for each feature
700
- virtual_drift = VirtualDrift()
701
- drift_results = virtual_drift.check_for_drift_per_feature(
702
- metrics_results_dictionary=metrics,
703
- possible_drift_threshold=possible_drift_threshold,
704
- drift_detected_threshold=drift_threshold,
705
- )
706
-
707
- # Drift table plot
708
- html_plot = FeaturesDriftTablePlot().produce(
709
- sample_set_statistics=sample_set_statistics,
710
- inputs_statistics=inputs_statistics,
711
- metrics=metrics,
712
- drift_results=drift_results,
713
- )
714
-
715
- # Prepare drift result per feature dictionary
716
- metrics_per_feature = {
717
- feature: _get_drift_result(
718
- tvd=metric_dictionary["tvd"],
719
- hellinger=metric_dictionary["hellinger"],
720
- threshold=drift_threshold,
721
- )[1]
722
- for feature, metric_dictionary in metrics.items()
723
- if isinstance(metric_dictionary, dict)
724
- }
725
-
726
- # Calculate the final analysis result as well
727
- drift_status, drift_metric = _get_drift_result(
728
- tvd=metrics["tvd_mean"],
729
- hellinger=metrics["hellinger_mean"],
730
- threshold=drift_threshold,
731
- )
732
- # Log the different artifacts
733
- _log_drift_artifacts(
734
- context=context,
735
- html_plot=html_plot,
736
- metrics_per_feature=metrics_per_feature,
737
- drift_status=drift_status,
738
- drift_metric=drift_metric,
739
- artifacts_tag=artifacts_tag,
740
- )
741
-
742
-
743
- def _log_drift_artifacts(
744
- context: mlrun.MLClientCtx,
745
- html_plot: str,
746
- metrics_per_feature: dict[str, float],
747
- drift_status: bool,
748
- drift_metric: float,
749
- artifacts_tag: str,
750
- ):
751
- """
752
- Log the following artifacts/results:
753
- 1 - Drift table plot which includes a detailed drift analysis per feature
754
- 2 - Drift result per feature in a JSON format
755
- 3 - Results of the total drift analysis
756
-
757
- :param context: MLRun context. Will log the artifacts.
758
- :param html_plot: Body of the html file of the plot.
759
- :param metrics_per_feature: Dictionary in which the key is a feature name and the value is the drift numerical
760
- result.
761
- :param drift_status: Boolean value that represents the final drift analysis result.
762
- :param drift_metric: The final drift numerical result.
763
- :param artifacts_tag: Tag to use for all the artifacts resulted from the function.
764
-
765
- """
766
- context.log_artifact(
767
- mlrun.artifacts.Artifact(
768
- body=html_plot.encode("utf-8"), format="html", key="drift_table_plot"
769
- ),
770
- tag=artifacts_tag,
771
- )
772
- context.log_artifact(
773
- mlrun.artifacts.Artifact(
774
- body=json.dumps(metrics_per_feature),
775
- format="json",
776
- key="features_drift_results",
777
- ),
778
- tag=artifacts_tag,
779
- )
780
- context.log_results(
781
- results={"drift_status": drift_status, "drift_metric": drift_metric}
782
- )
783
-
784
-
785
- def _get_drift_result(
786
- tvd: float,
787
- hellinger: float,
788
- threshold: float,
789
- ) -> tuple[bool, float]:
790
- """
791
- Calculate the drift result by the following equation: (tvd + hellinger) / 2
792
-
793
- :param tvd: The feature's TVD value.
794
- :param hellinger: The feature's Hellinger value.
795
- :param threshold: The threshold from which the value is considered a drift.
796
-
797
- :returns: A tuple of:
798
- [0] = Boolean value as the drift status.
799
- [1] = The result.
800
- """
801
- result = (tvd + hellinger) / 2
802
- if result >= threshold:
803
- return True, result
804
- return False, result
805
-
806
-
807
563
  def log_result(
808
564
  context: mlrun.MLClientCtx,
809
565
  result_set_name: str,
@@ -826,3 +582,50 @@ def log_result(
826
582
  key="batch_id",
827
583
  value=batch_id,
828
584
  )
585
+
586
+
587
+ def _create_model_monitoring_function_base(
588
+ *,
589
+ project: str,
590
+ func: typing.Union[str, None] = None,
591
+ application_class: typing.Union[
592
+ str, mlrun.model_monitoring.application.ModelMonitoringApplicationBase, None
593
+ ] = None,
594
+ name: typing.Optional[str] = None,
595
+ image: typing.Optional[str] = None,
596
+ tag: typing.Optional[str] = None,
597
+ requirements: typing.Union[str, list[str], None] = None,
598
+ requirements_file: str = "",
599
+ **application_kwargs,
600
+ ) -> mlrun.runtimes.ServingRuntime:
601
+ """
602
+ Note: this is an internal API only.
603
+ This function does not set the labels or mounts v3io.
604
+ """
605
+ if func is None:
606
+ func = ""
607
+ func_obj = typing.cast(
608
+ mlrun.runtimes.ServingRuntime,
609
+ mlrun.code_to_function(
610
+ filename=func,
611
+ name=name,
612
+ project=project,
613
+ tag=tag,
614
+ kind=mlrun.run.RuntimeKinds.serving,
615
+ image=image,
616
+ requirements=requirements,
617
+ requirements_file=requirements_file,
618
+ ),
619
+ )
620
+ graph = func_obj.set_topology(mlrun.serving.states.StepKinds.flow)
621
+ if isinstance(application_class, str):
622
+ first_step = graph.to(class_name=application_class, **application_kwargs)
623
+ else:
624
+ first_step = graph.to(class_name=application_class)
625
+ first_step.to(
626
+ class_name="mlrun.model_monitoring.application.PushToMonitoringWriter",
627
+ name="PushToMonitoringWriter",
628
+ project=project,
629
+ writer_application_name=mm_consts.MonitoringFunctionNames.WRITER,
630
+ ).respond()
631
+ return func_obj
@@ -16,13 +16,13 @@ import dataclasses
16
16
  import json
17
17
  import re
18
18
  from abc import ABC, abstractmethod
19
- from typing import Any, Optional, Union
19
+ from typing import Any, Optional, Union, cast
20
20
 
21
21
  import numpy as np
22
22
  import pandas as pd
23
23
 
24
24
  import mlrun.common.helpers
25
- import mlrun.common.schemas.model_monitoring
25
+ import mlrun.common.model_monitoring.helpers
26
26
  import mlrun.common.schemas.model_monitoring.constants as mm_constant
27
27
  import mlrun.utils.v3io_clients
28
28
  from mlrun.datastore import get_stream_pusher
@@ -84,8 +84,8 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
84
84
  class MyApp(ApplicationBase):
85
85
  def do_tracking(
86
86
  self,
87
- sample_df_stats: pd.DataFrame,
88
- feature_stats: pd.DataFrame,
87
+ sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
88
+ feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
89
89
  start_infer_time: pd.Timestamp,
90
90
  end_infer_time: pd.Timestamp,
91
91
  schedule_time: pd.Timestamp,
@@ -93,7 +93,7 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
93
93
  endpoint_id: str,
94
94
  output_stream_uri: str,
95
95
  ) -> ModelMonitoringApplicationResult:
96
- self.context.log_artifact(TableArtifact("sample_df_stats", df=sample_df_stats))
96
+ self.context.log_artifact(TableArtifact("sample_df_stats", df=self.dict_to_histogram(sample_df_stats)))
97
97
  return ModelMonitoringApplicationResult(
98
98
  name="data_drift_test",
99
99
  value=0.5,
@@ -126,14 +126,16 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
126
126
  return results, event
127
127
 
128
128
  def _lazy_init(self, app_name: str):
129
- self.context = self._create_context_for_logging(app_name=app_name)
129
+ self.context = cast(
130
+ mlrun.MLClientCtx, self._create_context_for_logging(app_name=app_name)
131
+ )
130
132
 
131
133
  @abstractmethod
132
134
  def do_tracking(
133
135
  self,
134
136
  application_name: str,
135
- sample_df_stats: pd.DataFrame,
136
- feature_stats: pd.DataFrame,
137
+ sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
138
+ feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
137
139
  sample_df: pd.DataFrame,
138
140
  start_infer_time: pd.Timestamp,
139
141
  end_infer_time: pd.Timestamp,
@@ -147,8 +149,8 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
147
149
  Implement this method with your custom monitoring logic.
148
150
 
149
151
  :param application_name: (str) the app name
150
- :param sample_df_stats: (pd.DataFrame) The new sample distribution DataFrame.
151
- :param feature_stats: (pd.DataFrame) The train sample distribution DataFrame.
152
+ :param sample_df_stats: (FeatureStats) The new sample distribution dictionary.
153
+ :param feature_stats: (FeatureStats) The train sample distribution dictionary.
152
154
  :param sample_df: (pd.DataFrame) The new sample DataFrame.
153
155
  :param start_infer_time: (pd.Timestamp) Start time of the monitoring schedule.
154
156
  :param end_infer_time: (pd.Timestamp) End time of the monitoring schedule.
@@ -167,8 +169,8 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
167
169
  event: dict[str, Any],
168
170
  ) -> tuple[
169
171
  str,
170
- pd.DataFrame,
171
- pd.DataFrame,
172
+ mlrun.common.model_monitoring.helpers.FeatureStats,
173
+ mlrun.common.model_monitoring.helpers.FeatureStats,
172
174
  pd.DataFrame,
173
175
  pd.Timestamp,
174
176
  pd.Timestamp,
@@ -184,8 +186,8 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
184
186
 
185
187
  :return: A tuple of:
186
188
  [0] = (str) application name
187
- [1] = (pd.DataFrame) current input statistics
188
- [2] = (pd.DataFrame) train statistics
189
+ [1] = (dict) current input statistics
190
+ [2] = (dict) train statistics
189
191
  [3] = (pd.DataFrame) current input data
190
192
  [4] = (pd.Timestamp) start time of the monitoring schedule
191
193
  [5] = (pd.Timestamp) end time of the monitoring schedule
@@ -197,12 +199,8 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
197
199
  end_time = pd.Timestamp(event[mm_constant.ApplicationEvent.END_INFER_TIME])
198
200
  return (
199
201
  event[mm_constant.ApplicationEvent.APPLICATION_NAME],
200
- cls._dict_to_histogram(
201
- json.loads(event[mm_constant.ApplicationEvent.CURRENT_STATS])
202
- ),
203
- cls._dict_to_histogram(
204
- json.loads(event[mm_constant.ApplicationEvent.FEATURE_STATS])
205
- ),
202
+ json.loads(event[mm_constant.ApplicationEvent.CURRENT_STATS]),
203
+ json.loads(event[mm_constant.ApplicationEvent.FEATURE_STATS]),
206
204
  ParquetTarget(
207
205
  path=event[mm_constant.ApplicationEvent.SAMPLE_PARQUET_PATH]
208
206
  ).as_df(start_time=start_time, end_time=end_time, time_column="timestamp"),
@@ -223,7 +221,9 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
223
221
  return context
224
222
 
225
223
  @staticmethod
226
- def _dict_to_histogram(histogram_dict: dict[str, dict[str, Any]]) -> pd.DataFrame:
224
+ def dict_to_histogram(
225
+ histogram_dict: mlrun.common.model_monitoring.helpers.FeatureStats,
226
+ ) -> pd.DataFrame:
227
227
  """
228
228
  Convert histogram dictionary to pandas DataFrame with feature histograms as columns
229
229