mlrun 1.7.0rc3__py3-none-any.whl → 1.7.0rc4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlrun might be problematic; consult the package registry listing for more details.

Files changed (69)
  1. mlrun/artifacts/manager.py +6 -1
  2. mlrun/common/constants.py +1 -0
  3. mlrun/common/model_monitoring/helpers.py +12 -6
  4. mlrun/common/schemas/__init__.py +1 -0
  5. mlrun/common/schemas/client_spec.py +1 -0
  6. mlrun/common/schemas/common.py +40 -0
  7. mlrun/common/schemas/model_monitoring/constants.py +4 -1
  8. mlrun/common/schemas/project.py +2 -0
  9. mlrun/config.py +20 -15
  10. mlrun/datastore/azure_blob.py +22 -9
  11. mlrun/datastore/base.py +15 -25
  12. mlrun/datastore/datastore.py +19 -8
  13. mlrun/datastore/datastore_profile.py +47 -5
  14. mlrun/datastore/google_cloud_storage.py +10 -6
  15. mlrun/datastore/hdfs.py +51 -0
  16. mlrun/datastore/redis.py +4 -0
  17. mlrun/datastore/s3.py +4 -0
  18. mlrun/datastore/sources.py +29 -43
  19. mlrun/datastore/targets.py +58 -48
  20. mlrun/datastore/utils.py +2 -49
  21. mlrun/datastore/v3io.py +4 -0
  22. mlrun/db/base.py +34 -0
  23. mlrun/db/httpdb.py +71 -42
  24. mlrun/execution.py +3 -3
  25. mlrun/feature_store/feature_vector.py +2 -2
  26. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +3 -3
  27. mlrun/frameworks/tf_keras/model_handler.py +7 -7
  28. mlrun/k8s_utils.py +10 -5
  29. mlrun/kfpops.py +19 -10
  30. mlrun/model.py +5 -0
  31. mlrun/model_monitoring/api.py +3 -3
  32. mlrun/model_monitoring/application.py +1 -1
  33. mlrun/model_monitoring/applications/__init__.py +13 -0
  34. mlrun/model_monitoring/applications/histogram_data_drift.py +218 -0
  35. mlrun/model_monitoring/batch.py +9 -111
  36. mlrun/model_monitoring/controller.py +73 -55
  37. mlrun/model_monitoring/controller_handler.py +13 -5
  38. mlrun/model_monitoring/features_drift_table.py +62 -53
  39. mlrun/model_monitoring/helpers.py +30 -21
  40. mlrun/model_monitoring/metrics/__init__.py +13 -0
  41. mlrun/model_monitoring/metrics/histogram_distance.py +127 -0
  42. mlrun/model_monitoring/stores/kv_model_endpoint_store.py +14 -14
  43. mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -1
  44. mlrun/package/packagers/pandas_packagers.py +3 -3
  45. mlrun/package/utils/_archiver.py +3 -1
  46. mlrun/platforms/iguazio.py +8 -65
  47. mlrun/projects/pipelines.py +21 -11
  48. mlrun/projects/project.py +121 -42
  49. mlrun/runtimes/base.py +21 -2
  50. mlrun/runtimes/kubejob.py +5 -3
  51. mlrun/runtimes/local.py +2 -2
  52. mlrun/runtimes/mpijob/abstract.py +6 -6
  53. mlrun/runtimes/nuclio/function.py +9 -9
  54. mlrun/runtimes/nuclio/serving.py +3 -3
  55. mlrun/runtimes/pod.py +3 -3
  56. mlrun/runtimes/sparkjob/spark3job.py +3 -3
  57. mlrun/serving/remote.py +4 -2
  58. mlrun/serving/server.py +2 -8
  59. mlrun/utils/async_http.py +3 -3
  60. mlrun/utils/helpers.py +27 -5
  61. mlrun/utils/http.py +3 -3
  62. mlrun/utils/notifications/notification_pusher.py +6 -6
  63. mlrun/utils/version/version.json +2 -2
  64. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc4.dist-info}/METADATA +13 -16
  65. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc4.dist-info}/RECORD +69 -63
  66. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc4.dist-info}/LICENSE +0 -0
  67. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc4.dist-info}/WHEEL +0 -0
  68. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc4.dist-info}/entry_points.txt +0 -0
  69. {mlrun-1.7.0rc3.dist-info → mlrun-1.7.0rc4.dist-info}/top_level.txt +0 -0
@@ -11,8 +11,10 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- #
15
- from typing import Union
14
+
15
+ import functools
16
+ import sys
17
+ from typing import Callable, Union
16
18
 
17
19
  import numpy as np
18
20
  import plotly.graph_objects as go
@@ -27,7 +29,7 @@ DriftResultType = tuple[mlrun.common.schemas.model_monitoring.DriftStatus, float
27
29
  class FeaturesDriftTablePlot:
28
30
  """
29
31
  Class for producing a features drift table. The plot is a table with columns of all the statistics and metrics
30
- provided with two additional plot columns of the histograms and drift notification. The rows content will be drawn
32
+ provided with two additional plot columns of the histograms and drift status. The rows content will be drawn
31
33
  per feature.
32
34
 
33
35
  For example, if the statistics are 'mean', 'min', 'max' and one metric of 'tvd', for 3 features the table will be:
@@ -47,7 +49,7 @@ class FeaturesDriftTablePlot:
47
49
  70 # The width for the values of all the statistics and metrics columns.
48
50
  )
49
51
  _HISTOGRAMS_COLUMN_WIDTH = 180
50
- _NOTIFICATIONS_COLUMN_WIDTH = 20
52
+ _STATUS_COLUMN_WIDTH = 20
51
53
 
52
54
  # Table rows heights:
53
55
  _HEADER_ROW_HEIGHT = 25
@@ -56,9 +58,10 @@ class FeaturesDriftTablePlot:
56
58
  # Histograms configurations:
57
59
  _SAMPLE_SET_HISTOGRAM_COLOR = "rgb(0,112,192)" # Blue
58
60
  _INPUTS_HISTOGRAM_COLOR = "rgb(208,0,106)" # Magenta
61
+ _HISTOGRAM_OPACITY = 0.75
59
62
 
60
- # Notification configurations:
61
- _NOTIFICATION_COLORS = {
63
+ # Status configurations:
64
+ _STATUS_COLORS = {
62
65
  mlrun.common.schemas.model_monitoring.DriftStatus.NO_DRIFT: "rgb(0,176,80)", # Green
63
66
  mlrun.common.schemas.model_monitoring.DriftStatus.POSSIBLE_DRIFT: "rgb(255,192,0)", # Orange
64
67
  mlrun.common.schemas.model_monitoring.DriftStatus.DRIFT_DETECTED: "rgb(208,0,106)", # Magenta
@@ -78,9 +81,6 @@ class FeaturesDriftTablePlot:
78
81
  _BACKGROUND_COLOR = "rgb(255,255,255)" # White
79
82
  _SEPARATORS_COLOR = "rgb(240,240,240)" # Light grey
80
83
 
81
- # File name:
82
- _FILE_NAME = "table_plot.html"
83
-
84
84
  def __init__(self):
85
85
  """
86
86
  Initialize the plot producer for later calling the `produce` method.
@@ -198,7 +198,7 @@ class FeaturesDriftTablePlot:
198
198
  self._FEATURE_NAME_COLUMN_WIDTH,
199
199
  *self._value_columns_widths,
200
200
  self._HISTOGRAMS_COLUMN_WIDTH,
201
- self._NOTIFICATIONS_COLUMN_WIDTH,
201
+ self._STATUS_COLUMN_WIDTH,
202
202
  ],
203
203
  header_fill_color=self._BACKGROUND_COLOR,
204
204
  )
@@ -222,7 +222,7 @@ class FeaturesDriftTablePlot:
222
222
  [self._FEATURE_NAME_COLUMN_WIDTH]
223
223
  + [self._VALUE_COLUMN_WIDTH]
224
224
  * (2 * len(self._statistics_columns) + len(self._metrics_columns))
225
- + [self._HISTOGRAMS_COLUMN_WIDTH, self._NOTIFICATIONS_COLUMN_WIDTH]
225
+ + [self._HISTOGRAMS_COLUMN_WIDTH, self._STATUS_COLUMN_WIDTH]
226
226
  ),
227
227
  header_fill_color=self._BACKGROUND_COLOR,
228
228
  )
@@ -332,25 +332,25 @@ class FeaturesDriftTablePlot:
332
332
 
333
333
  return feature_row_table
334
334
 
335
- def _plot_histogram_scatters(
336
- self, sample_hist: tuple[list, list], input_hist: tuple[list, list]
337
- ) -> tuple[go.Scatter, go.Scatter]:
335
+ def _plot_histogram_bars(
336
+ self,
337
+ figure_add_trace: Callable,
338
+ sample_hist: tuple[list, list],
339
+ input_hist: tuple[list, list],
340
+ showlegend: bool = False,
341
+ ) -> None:
338
342
  """
339
- Plot the feature's histograms to include in the "histograms" column. Both histograms are returned to later be
340
- added in the same figure, so they will be on top of each other and not separated. Both histograms are rescaled
343
+ Plot the feature's histograms to include in the "histograms" column. Both histograms are rescaled
341
344
  to be from 0.0 to 1.0, so they will be drawn in the same scale regardless the amount of elements they were
342
345
  calculated upon.
343
346
 
344
- :param sample_hist: The sample set histogram data.
345
- :param input_hist: The input histogram data.
347
+ :param figure_add_trace: The figure's method that get the histogram and adds it to the figure.
348
+ :param sample_hist: The sample set histogram data.
349
+ :param input_hist: The input histogram data.
350
+ :param showlegend: Show the legend for each histogram or not.
346
351
 
347
- :return: A tuple with both histograms - `Scatter` traces:
348
- [0] - Sample set histogram.
349
- [1] - Input histogram.
352
+ :return: None
350
353
  """
351
- # Initialize a list to collect the scatters:
352
- scatters = []
353
-
354
354
  # Plot the histograms:
355
355
  for name, color, histogram in zip(
356
356
  ["sample", "input"],
@@ -361,23 +361,29 @@ class FeaturesDriftTablePlot:
361
361
  counts, bins = histogram
362
362
  # Rescale the counts to be in percentages (between 0.0 to 1.0):
363
363
  counts = np.array(counts) / sum(counts)
364
+ hovertext = [""] * len(counts)
364
365
  # Convert to NumPy for vectorization:
365
366
  bins = np.array(bins)
367
+ if bins[0] == -sys.float_info.max:
368
+ bins[0] = bins[1] - (bins[2] - bins[1])
369
+ hovertext[0] = f"(-∞, {bins[1]})"
370
+ if bins[-1] == sys.float_info.max:
371
+ bins[-1] = bins[-2] + (bins[-2] - bins[-3])
372
+ hovertext[-1] = f"({bins[-2]}, ∞)"
366
373
  # Center the bins (leave the first one):
367
374
  bins = 0.5 * (bins[:-1] + bins[1:])
368
375
  # Plot the histogram as a line with filled background below it:
369
- histogram_scatter = go.Scatter(
376
+ histogram_bar = go.Bar(
370
377
  x=bins,
371
378
  y=counts,
372
- fill="tozeroy",
373
379
  name=name,
374
- line_shape="spline", # Make the line rounder.
375
- line={"color": color},
380
+ marker_color=color,
381
+ opacity=self._HISTOGRAM_OPACITY,
376
382
  legendgroup=name,
383
+ hovertext=hovertext,
384
+ showlegend=showlegend,
377
385
  )
378
- scatters.append(histogram_scatter)
379
-
380
- return scatters[0], scatters[1]
386
+ figure_add_trace(histogram_bar)
381
387
 
382
388
  def _calculate_row_height(self, features: list[str]) -> int:
383
389
  """
@@ -399,7 +405,7 @@ class FeaturesDriftTablePlot:
399
405
  self._FEATURE_ROW_HEIGHT, 1.5 * self._FONT_SIZE * feature_name_seperations
400
406
  )
401
407
 
402
- def _plot_notification_circle(
408
+ def _plot_status_circle(
403
409
  self,
404
410
  figure: go.Figure,
405
411
  row: int,
@@ -407,8 +413,8 @@ class FeaturesDriftTablePlot:
407
413
  drift_result: DriftResultType,
408
414
  ):
409
415
  """
410
- Plot the drift notification - a little circle with color as configured in the class. The color will beb chosen
411
- according to the drift status given.
416
+ Plot the drift status - a little circle with color as configured in the
417
+ class. The color will be chosen according to the drift status given.
412
418
 
413
419
  :param figure: The figure (feature row cell) to draw the circle in.
414
420
  :param row: The row number.
@@ -420,12 +426,12 @@ class FeaturesDriftTablePlot:
420
426
  # row 3) times the plot columns (2 columns has axes in each row) + 2 (to get to the column of the notification):
421
427
  axis_number = (row - 3) * 2 + 2
422
428
  figure["layout"][f"xaxis{axis_number}"].update(
423
- range=[0, self._NOTIFICATIONS_COLUMN_WIDTH]
429
+ range=[0, self._STATUS_COLUMN_WIDTH]
424
430
  )
425
431
  figure["layout"][f"yaxis{axis_number}"].update(range=[0, row_height])
426
432
 
427
433
  # Get the color:
428
- notification_color = self._NOTIFICATION_COLORS[drift_result[0]]
434
+ notification_color = self._STATUS_COLORS[drift_result[0]]
429
435
  half_transparent_notification_color = notification_color.replace(
430
436
  "rgb", "rgba"
431
437
  ).replace(")", ",0.5)")
@@ -434,8 +440,8 @@ class FeaturesDriftTablePlot:
434
440
  # size of the text as well):
435
441
  y0 = 36 + (row_height - self._FEATURE_ROW_HEIGHT)
436
442
  y1 = y0 + self._FONT_SIZE
437
- x0 = (self._NOTIFICATIONS_COLUMN_WIDTH / 2) - ((y1 - y0) / 2)
438
- x1 = (self._NOTIFICATIONS_COLUMN_WIDTH / 2) + ((y1 - y0) / 2)
443
+ x0 = (self._STATUS_COLUMN_WIDTH / 2) - ((y1 - y0) / 2)
444
+ x1 = (self._STATUS_COLUMN_WIDTH / 2) + ((y1 - y0) / 2)
439
445
 
440
446
  # Draw the circle on top of the figure:
441
447
  figure.add_shape(
@@ -486,7 +492,7 @@ class FeaturesDriftTablePlot:
486
492
  self._FEATURE_NAME_COLUMN_WIDTH
487
493
  + sum(self._value_columns_widths)
488
494
  + self._HISTOGRAMS_COLUMN_WIDTH
489
- + self._NOTIFICATIONS_COLUMN_WIDTH
495
+ + self._STATUS_COLUMN_WIDTH
490
496
  )
491
497
  height = 2 * self._HEADER_ROW_HEIGHT + len(features) * row_height
492
498
 
@@ -507,7 +513,7 @@ class FeaturesDriftTablePlot:
507
513
  (self._FEATURE_NAME_COLUMN_WIDTH + sum(self._value_columns_widths))
508
514
  / width,
509
515
  self._HISTOGRAMS_COLUMN_WIDTH / width,
510
- self._NOTIFICATIONS_COLUMN_WIDTH / width,
516
+ self._STATUS_COLUMN_WIDTH / width,
511
517
  ],
512
518
  horizontal_spacing=0,
513
519
  vertical_spacing=0,
@@ -518,9 +524,11 @@ class FeaturesDriftTablePlot:
518
524
  main_figure.add_trace(header_trace, row=1, col=1)
519
525
  main_figure.add_trace(sub_header_trace, row=2, col=1)
520
526
 
521
- # Start going over the features and plot each row, histogram and notification:
522
- row = 3 # We are currently at row 3 counting the headers.
523
- for feature in features:
527
+ # Start going over the features and plot each row, histogram and status
528
+ for row, feature in enumerate(
529
+ features,
530
+ start=3, # starting from row 3 after the headers
531
+ ):
524
532
  try:
525
533
  # Add the feature values:
526
534
  main_figure.add_trace(
@@ -543,23 +551,22 @@ class FeaturesDriftTablePlot:
543
551
  f"{inputs_statistics.keys() = }\n"
544
552
  )
545
553
  # Add the histograms (both traces are added to the same subplot figure):
546
- sample_hist, input_hist = self._plot_histogram_scatters(
554
+ self._plot_histogram_bars(
555
+ figure_add_trace=functools.partial(
556
+ main_figure.add_trace, row=row, col=2
557
+ ),
547
558
  sample_hist=sample_set_statistics[feature]["hist"],
548
559
  input_hist=inputs_statistics[feature]["hist"],
560
+ # Only the first row should have its legend visible
561
+ showlegend=(row == 3),
549
562
  )
550
- if row != 3: # Only the first row should have its legend visible:
551
- sample_hist.showlegend = False
552
- input_hist.showlegend = False
553
- main_figure.add_trace(sample_hist, row=row, col=2)
554
- main_figure.add_trace(input_hist, row=row, col=2)
555
- # Add the notification (a circle with color according to the drift alert):
556
- self._plot_notification_circle(
563
+ # Add the status (a circle with color according to the drift status)
564
+ self._plot_status_circle(
557
565
  figure=main_figure,
558
566
  row=row,
559
567
  row_height=row_height,
560
568
  drift_result=drift_results[feature],
561
569
  )
562
- row += 1
563
570
 
564
571
  # Configure the layout and axes for height and widths:
565
572
  main_figure.update_layout(
@@ -576,9 +583,11 @@ class FeaturesDriftTablePlot:
576
583
  "yanchor": "top",
577
584
  "y": 1.0 - (self._HEADER_ROW_HEIGHT / height) + 0.002,
578
585
  "xanchor": "right",
579
- "x": 1.0 - (self._NOTIFICATIONS_COLUMN_WIDTH / width) - 0.01,
586
+ "x": 1.0 - (self._STATUS_COLUMN_WIDTH / width) - 0.01,
580
587
  "bgcolor": "rgba(0,0,0,0)",
581
588
  },
589
+ barmode="overlay",
590
+ bargap=0,
582
591
  )
583
592
  main_figure.update_xaxes(
584
593
  showticklabels=False,
@@ -20,15 +20,14 @@ import mlrun.common.model_monitoring.helpers
20
20
  import mlrun.common.schemas
21
21
  from mlrun.common.schemas.model_monitoring import (
22
22
  EventFieldType,
23
- MonitoringFunctionNames,
24
23
  )
25
- from mlrun.errors import MLRunValueError
26
24
  from mlrun.model_monitoring.model_endpoint import ModelEndpoint
27
25
  from mlrun.utils import logger
28
26
 
29
27
  if typing.TYPE_CHECKING:
30
28
  from mlrun.db.base import RunDBInterface
31
29
  from mlrun.projects import MlrunProject
30
+ import mlrun.common.schemas.model_monitoring.constants as mm_constants
32
31
 
33
32
 
34
33
  class _BatchDict(typing.TypedDict):
@@ -41,29 +40,32 @@ class _MLRunNoRunsFoundError(Exception):
41
40
  pass
42
41
 
43
42
 
44
- def get_stream_path(project: str = None, application_name: str = None):
43
+ def get_stream_path(
44
+ project: str = None,
45
+ function_name: str = mm_constants.MonitoringFunctionNames.STREAM,
46
+ ):
45
47
  """
46
48
  Get stream path from the project secret. If wasn't set, take it from the system configurations
47
49
 
48
50
  :param project: Project name.
49
- :param application_name: Application name, None for model_monitoring_stream.
51
+ :param function_name: Application name. Default is model_monitoring_stream.
50
52
 
51
53
  :return: Monitoring stream path to the relevant application.
52
54
  """
53
55
 
54
56
  stream_uri = mlrun.get_secret_or_env(
55
57
  mlrun.common.schemas.model_monitoring.ProjectSecretKeys.STREAM_PATH
56
- if application_name is None
58
+ if function_name is mm_constants.MonitoringFunctionNames.STREAM
57
59
  else ""
58
60
  ) or mlrun.mlconf.get_model_monitoring_file_target_path(
59
61
  project=project,
60
62
  kind=mlrun.common.schemas.model_monitoring.FileTargetKind.STREAM,
61
63
  target="online",
62
- application_name=application_name,
64
+ function_name=function_name,
63
65
  )
64
66
 
65
67
  return mlrun.common.model_monitoring.helpers.parse_monitoring_stream_path(
66
- stream_uri=stream_uri, project=project, application_name=application_name
68
+ stream_uri=stream_uri, project=project, function_name=function_name
67
69
  )
68
70
 
69
71
 
@@ -125,24 +127,31 @@ def _get_monitoring_time_window_from_controller_run(
125
127
  project: str, db: "RunDBInterface"
126
128
  ) -> datetime.timedelta:
127
129
  """
128
- Get timedelta for the controller to run.
130
+ Get the base period form the controller.
129
131
 
130
132
  :param project: Project name.
131
133
  :param db: DB interface.
132
134
 
133
135
  :return: Timedelta for the controller to run.
136
+ :raise: MLRunNotFoundError if the controller isn't deployed yet
134
137
  """
135
- run_name = MonitoringFunctionNames.APPLICATION_CONTROLLER
136
- runs = db.list_runs(project=project, name=run_name, sort=True)
137
- if not runs:
138
- raise _MLRunNoRunsFoundError(f"No {run_name} runs were found")
139
- last_run = runs[0]
140
- try:
141
- batch_dict = last_run["spec"]["parameters"]["batch_intervals_dict"]
142
- except KeyError:
143
- raise MLRunValueError(
144
- f"Could not find `batch_intervals_dict` in {run_name} run"
145
- )
138
+
139
+ controller = db.get_function(
140
+ name=mm_constants.MonitoringFunctionNames.APPLICATION_CONTROLLER,
141
+ project=project,
142
+ )
143
+ if isinstance(controller, dict):
144
+ controller = mlrun.runtimes.RemoteRuntime.from_dict(controller)
145
+ elif not hasattr(controller, "to_dict"):
146
+ raise mlrun.errors.MLRunNotFoundError()
147
+ base_period = controller.spec.config["spec.triggers.cron_interval"]["attributes"][
148
+ "interval"
149
+ ]
150
+ batch_dict = {
151
+ mm_constants.EventFieldType.MINUTES: int(base_period[:-1]),
152
+ mm_constants.EventFieldType.HOURS: 0,
153
+ mm_constants.EventFieldType.DAYS: 0,
154
+ }
146
155
  return batch_dict2timedelta(batch_dict)
147
156
 
148
157
 
@@ -177,9 +186,9 @@ def update_model_endpoint_last_request(
177
186
  else:
178
187
  try:
179
188
  time_window = _get_monitoring_time_window_from_controller_run(project, db)
180
- except _MLRunNoRunsFoundError:
189
+ except mlrun.errors.MLRunNotFoundError:
181
190
  logger.debug(
182
- "Not bumping model endpoint last request time - no controller runs were found"
191
+ "Not bumping model endpoint last request time - the monitoring controller isn't deployed yet"
183
192
  )
184
193
  return
185
194
 
@@ -0,0 +1,13 @@
1
+ # Copyright 2024 Iguazio
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
@@ -0,0 +1,127 @@
1
+ # Copyright 2024 Iguazio
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+ import dataclasses
17
+ from typing import ClassVar, Optional
18
+
19
+ import numpy as np
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class HistogramDistanceMetric(abc.ABC):
24
+ """
25
+ An abstract base class for distance metrics between histograms.
26
+
27
+ :args distrib_t: array of distribution t (usually the latest dataset distribution)
28
+ :args distrib_u: array of distribution u (usually the sample dataset distribution)
29
+
30
+ Each distribution must contain nonnegative floats that sum up to 1.0.
31
+ """
32
+
33
+ distrib_t: np.ndarray
34
+ distrib_u: np.ndarray
35
+
36
+ NAME: ClassVar[str]
37
+
38
+ # noinspection PyMethodOverriding
39
+ def __init_subclass__(cls, *, metric_name: str, **kwargs) -> None:
40
+ super().__init_subclass__(**kwargs)
41
+ cls.NAME = metric_name
42
+
43
+ @abc.abstractmethod
44
+ def compute(self) -> float:
45
+ raise NotImplementedError
46
+
47
+
48
+ class TotalVarianceDistance(HistogramDistanceMetric, metric_name="tvd"):
49
+ """
50
+ Provides a symmetric drift distance between two periods t and u
51
+ Z - vector of random variables
52
+ Pt - Probability distribution over time span t
53
+ """
54
+
55
+ def compute(self) -> float:
56
+ """
57
+ Calculate Total Variance distance.
58
+
59
+ :returns: Total Variance Distance.
60
+ """
61
+ return np.sum(np.abs(self.distrib_t - self.distrib_u)) / 2
62
+
63
+
64
+ class HellingerDistance(HistogramDistanceMetric, metric_name="hellinger"):
65
+ """
66
+ Hellinger distance is an f divergence measure, similar to the Kullback-Leibler (KL) divergence.
67
+ It used to quantify the difference between two probability distributions.
68
+ However, unlike KL Divergence the Hellinger divergence is symmetric and bounded over a probability space.
69
+ The output range of Hellinger distance is [0,1]. The closer to 0, the more similar the two distributions.
70
+ """
71
+
72
+ def compute(self) -> float:
73
+ """
74
+ Calculate Hellinger Distance
75
+
76
+ :returns: Hellinger Distance
77
+ """
78
+ return np.sqrt(
79
+ max(
80
+ 1 - np.sum(np.sqrt(self.distrib_u * self.distrib_t)),
81
+ 0, # numerical errors may produce small negative numbers, e.g. -1e-16.
82
+ # However, Cauchy-Schwarz inequality assures this number is in the range [0, 1]
83
+ )
84
+ )
85
+
86
+
87
+ class KullbackLeiblerDivergence(HistogramDistanceMetric, metric_name="kld"):
88
+ """
89
+ KL Divergence (or relative entropy) is a measure of how one probability distribution differs from another.
90
+ It is an asymmetric measure (thus it's not a metric) and it doesn't satisfy the triangle inequality.
91
+ KL Divergence of 0, indicates two identical distributions.
92
+ """
93
+
94
+ @staticmethod
95
+ def _calc_kl_div(
96
+ actual_dist: np.ndarray, expected_dist: np.ndarray, zero_scaling: float
97
+ ) -> float:
98
+ """Return the asymmetric KL divergence"""
99
+ # We take 0*log(0) == 0 for this calculation
100
+ mask = actual_dist != 0
101
+ actual_dist = actual_dist[mask]
102
+ expected_dist = expected_dist[mask]
103
+ with np.errstate(over="ignore"):
104
+ # Ignore overflow warnings when dividing by small numbers,
105
+ # resulting in inf:
106
+ # RuntimeWarning: overflow encountered in true_divide
107
+ relative_prob = actual_dist / np.where(
108
+ expected_dist != 0, expected_dist, zero_scaling
109
+ )
110
+ return np.sum(actual_dist * np.log(relative_prob))
111
+
112
+ def compute(
113
+ self, capping: Optional[float] = None, zero_scaling: float = 1e-4
114
+ ) -> float:
115
+ """
116
+ :param capping: A bounded value for the KL Divergence. For infinite distance, the result is replaced with
117
+ the capping value which indicates a huge differences between the distributions.
118
+ :param zero_scaling: Will be used to replace 0 values for executing the logarithmic operation.
119
+
120
+ :returns: symmetric KL Divergence
121
+ """
122
+ t_u = self._calc_kl_div(self.distrib_t, self.distrib_u, zero_scaling)
123
+ u_t = self._calc_kl_div(self.distrib_u, self.distrib_t, zero_scaling)
124
+ result = t_u + u_t
125
+ if capping and result == float("inf"):
126
+ return capping
127
+ return result
@@ -302,7 +302,7 @@ class KVModelEndpointStore(ModelEndpointStore):
302
302
  )
303
303
  # Final cleanup of tsdb path
304
304
  tsdb_path.replace("://u", ":///u")
305
- store, _ = mlrun.store_manager.get_or_create_store(tsdb_path)
305
+ store, _, _ = mlrun.store_manager.get_or_create_store(tsdb_path)
306
306
  store.rm(tsdb_path, recursive=True)
307
307
 
308
308
  def get_endpoint_real_time_metrics(
@@ -538,24 +538,24 @@ class KVModelEndpointStore(ModelEndpointStore):
538
538
  and endpoint[mlrun.common.schemas.model_monitoring.EventFieldType.METRICS]
539
539
  == "null"
540
540
  ):
541
- endpoint[
542
- mlrun.common.schemas.model_monitoring.EventFieldType.METRICS
543
- ] = json.dumps(
544
- {
545
- mlrun.common.schemas.model_monitoring.EventKeyMetrics.GENERIC: {
546
- mlrun.common.schemas.model_monitoring.EventLiveStats.LATENCY_AVG_1H: 0,
547
- mlrun.common.schemas.model_monitoring.EventLiveStats.PREDICTIONS_PER_SECOND: 0,
541
+ endpoint[mlrun.common.schemas.model_monitoring.EventFieldType.METRICS] = (
542
+ json.dumps(
543
+ {
544
+ mlrun.common.schemas.model_monitoring.EventKeyMetrics.GENERIC: {
545
+ mlrun.common.schemas.model_monitoring.EventLiveStats.LATENCY_AVG_1H: 0,
546
+ mlrun.common.schemas.model_monitoring.EventLiveStats.PREDICTIONS_PER_SECOND: 0,
547
+ }
548
548
  }
549
- }
549
+ )
550
550
  )
551
551
  # Validate key `uid` instead of `endpoint_id`
552
552
  # For backwards compatibility reasons, we replace the `endpoint_id` with `uid` which is the updated key name
553
553
  if mlrun.common.schemas.model_monitoring.EventFieldType.ENDPOINT_ID in endpoint:
554
- endpoint[
555
- mlrun.common.schemas.model_monitoring.EventFieldType.UID
556
- ] = endpoint[
557
- mlrun.common.schemas.model_monitoring.EventFieldType.ENDPOINT_ID
558
- ]
554
+ endpoint[mlrun.common.schemas.model_monitoring.EventFieldType.UID] = (
555
+ endpoint[
556
+ mlrun.common.schemas.model_monitoring.EventFieldType.ENDPOINT_ID
557
+ ]
558
+ )
559
559
 
560
560
  @staticmethod
561
561
  def _encode_field(field: typing.Union[str, bytes]) -> bytes:
@@ -31,7 +31,6 @@ from .models import get_model_endpoints_table
31
31
 
32
32
 
33
33
  class SQLModelEndpointStore(ModelEndpointStore):
34
-
35
34
  """
36
35
  Handles the DB operations when the DB target is from type SQL. For the SQL operations, we use SQLAlchemy, a Python
37
36
  SQL toolkit that handles the communication with the database. When using SQL for storing the model endpoints
@@ -838,9 +838,9 @@ class PandasDataFramePackager(DefaultPackager):
838
838
  """
839
839
  if isinstance(obj, dict):
840
840
  for key, value in obj.items():
841
- obj[
842
- PandasDataFramePackager._prepare_result(obj=key)
843
- ] = PandasDataFramePackager._prepare_result(obj=value)
841
+ obj[PandasDataFramePackager._prepare_result(obj=key)] = (
842
+ PandasDataFramePackager._prepare_result(obj=value)
843
+ )
844
844
  elif isinstance(obj, list):
845
845
  for i, value in enumerate(obj):
846
846
  obj[i] = PandasDataFramePackager._prepare_result(obj=value)
@@ -179,7 +179,9 @@ class _TarArchiver(_Archiver):
179
179
 
180
180
  # Extract:
181
181
  with tarfile.open(archive_path, f"r:{cls._MODE_STRING}") as tar_file:
182
- tar_file.extractall(directory_path)
182
+ # use 'data' to ensure no security risks are imposed by the archive files
183
+ # see: https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractall
184
+ tar_file.extractall(directory_path, filter="data")
183
185
 
184
186
  return str(directory_path)
185
187