mlrun 1.6.0rc35__py3-none-any.whl → 1.7.0rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of mlrun might be problematic.
- mlrun/__main__.py +3 -3
- mlrun/api/schemas/__init__.py +1 -1
- mlrun/artifacts/base.py +11 -6
- mlrun/artifacts/dataset.py +2 -2
- mlrun/artifacts/model.py +30 -24
- mlrun/artifacts/plots.py +2 -2
- mlrun/common/db/sql_session.py +5 -3
- mlrun/common/helpers.py +1 -2
- mlrun/common/schemas/artifact.py +3 -3
- mlrun/common/schemas/auth.py +3 -3
- mlrun/common/schemas/background_task.py +1 -1
- mlrun/common/schemas/client_spec.py +1 -1
- mlrun/common/schemas/feature_store.py +16 -16
- mlrun/common/schemas/frontend_spec.py +7 -7
- mlrun/common/schemas/function.py +1 -1
- mlrun/common/schemas/hub.py +4 -9
- mlrun/common/schemas/memory_reports.py +2 -2
- mlrun/common/schemas/model_monitoring/grafana.py +4 -4
- mlrun/common/schemas/model_monitoring/model_endpoints.py +14 -15
- mlrun/common/schemas/notification.py +4 -4
- mlrun/common/schemas/object.py +2 -2
- mlrun/common/schemas/pipeline.py +1 -1
- mlrun/common/schemas/project.py +3 -3
- mlrun/common/schemas/runtime_resource.py +8 -12
- mlrun/common/schemas/schedule.py +3 -3
- mlrun/common/schemas/tag.py +1 -2
- mlrun/common/schemas/workflow.py +2 -2
- mlrun/config.py +8 -4
- mlrun/data_types/to_pandas.py +1 -3
- mlrun/datastore/base.py +0 -28
- mlrun/datastore/datastore_profile.py +9 -9
- mlrun/datastore/filestore.py +0 -1
- mlrun/datastore/google_cloud_storage.py +1 -1
- mlrun/datastore/sources.py +7 -11
- mlrun/datastore/spark_utils.py +1 -2
- mlrun/datastore/targets.py +31 -31
- mlrun/datastore/utils.py +4 -6
- mlrun/datastore/v3io.py +70 -46
- mlrun/db/base.py +22 -23
- mlrun/db/httpdb.py +34 -34
- mlrun/db/nopdb.py +19 -19
- mlrun/errors.py +1 -1
- mlrun/execution.py +4 -4
- mlrun/feature_store/api.py +20 -21
- mlrun/feature_store/common.py +1 -1
- mlrun/feature_store/feature_set.py +28 -32
- mlrun/feature_store/feature_vector.py +24 -27
- mlrun/feature_store/retrieval/base.py +7 -7
- mlrun/feature_store/retrieval/conversion.py +2 -4
- mlrun/feature_store/steps.py +7 -15
- mlrun/features.py +5 -7
- mlrun/frameworks/_common/artifacts_library.py +9 -9
- mlrun/frameworks/_common/mlrun_interface.py +5 -5
- mlrun/frameworks/_common/model_handler.py +48 -48
- mlrun/frameworks/_common/plan.py +2 -3
- mlrun/frameworks/_common/producer.py +3 -4
- mlrun/frameworks/_common/utils.py +5 -5
- mlrun/frameworks/_dl_common/loggers/logger.py +6 -7
- mlrun/frameworks/_dl_common/loggers/mlrun_logger.py +9 -9
- mlrun/frameworks/_dl_common/loggers/tensorboard_logger.py +16 -35
- mlrun/frameworks/_ml_common/artifacts_library.py +1 -2
- mlrun/frameworks/_ml_common/loggers/logger.py +3 -4
- mlrun/frameworks/_ml_common/loggers/mlrun_logger.py +4 -5
- mlrun/frameworks/_ml_common/model_handler.py +24 -24
- mlrun/frameworks/_ml_common/pkl_model_server.py +2 -2
- mlrun/frameworks/_ml_common/plan.py +1 -1
- mlrun/frameworks/_ml_common/plans/calibration_curve_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/confusion_matrix_plan.py +2 -3
- mlrun/frameworks/_ml_common/plans/dataset_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/feature_importance_plan.py +3 -3
- mlrun/frameworks/_ml_common/plans/roc_curve_plan.py +4 -4
- mlrun/frameworks/_ml_common/utils.py +4 -4
- mlrun/frameworks/auto_mlrun/auto_mlrun.py +7 -7
- mlrun/frameworks/huggingface/model_server.py +4 -4
- mlrun/frameworks/lgbm/__init__.py +32 -32
- mlrun/frameworks/lgbm/callbacks/logging_callback.py +4 -5
- mlrun/frameworks/lgbm/callbacks/mlrun_logging_callback.py +4 -5
- mlrun/frameworks/lgbm/mlrun_interfaces/booster_mlrun_interface.py +1 -3
- mlrun/frameworks/lgbm/mlrun_interfaces/mlrun_interface.py +6 -6
- mlrun/frameworks/lgbm/model_handler.py +9 -9
- mlrun/frameworks/lgbm/model_server.py +6 -6
- mlrun/frameworks/lgbm/utils.py +5 -5
- mlrun/frameworks/onnx/dataset.py +8 -8
- mlrun/frameworks/onnx/mlrun_interface.py +3 -3
- mlrun/frameworks/onnx/model_handler.py +6 -6
- mlrun/frameworks/onnx/model_server.py +7 -7
- mlrun/frameworks/parallel_coordinates.py +2 -2
- mlrun/frameworks/pytorch/__init__.py +16 -16
- mlrun/frameworks/pytorch/callbacks/callback.py +4 -5
- mlrun/frameworks/pytorch/callbacks/logging_callback.py +17 -17
- mlrun/frameworks/pytorch/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/pytorch/callbacks/tensorboard_logging_callback.py +23 -29
- mlrun/frameworks/pytorch/callbacks_handler.py +38 -38
- mlrun/frameworks/pytorch/mlrun_interface.py +20 -20
- mlrun/frameworks/pytorch/model_handler.py +17 -17
- mlrun/frameworks/pytorch/model_server.py +7 -7
- mlrun/frameworks/sklearn/__init__.py +12 -12
- mlrun/frameworks/sklearn/estimator.py +4 -4
- mlrun/frameworks/sklearn/metrics_library.py +14 -14
- mlrun/frameworks/sklearn/mlrun_interface.py +3 -6
- mlrun/frameworks/sklearn/model_handler.py +2 -2
- mlrun/frameworks/tf_keras/__init__.py +5 -5
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +14 -14
- mlrun/frameworks/tf_keras/callbacks/mlrun_logging_callback.py +11 -11
- mlrun/frameworks/tf_keras/callbacks/tensorboard_logging_callback.py +19 -23
- mlrun/frameworks/tf_keras/mlrun_interface.py +7 -9
- mlrun/frameworks/tf_keras/model_handler.py +14 -14
- mlrun/frameworks/tf_keras/model_server.py +6 -6
- mlrun/frameworks/xgboost/__init__.py +12 -12
- mlrun/frameworks/xgboost/model_handler.py +6 -6
- mlrun/k8s_utils.py +4 -5
- mlrun/kfpops.py +2 -2
- mlrun/launcher/base.py +10 -10
- mlrun/launcher/local.py +8 -8
- mlrun/launcher/remote.py +7 -7
- mlrun/lists.py +3 -4
- mlrun/model.py +205 -55
- mlrun/model_monitoring/api.py +21 -24
- mlrun/model_monitoring/application.py +4 -4
- mlrun/model_monitoring/batch.py +17 -17
- mlrun/model_monitoring/controller.py +2 -1
- mlrun/model_monitoring/features_drift_table.py +44 -31
- mlrun/model_monitoring/prometheus.py +1 -4
- mlrun/model_monitoring/stores/kv_model_endpoint_store.py +11 -13
- mlrun/model_monitoring/stores/model_endpoint_store.py +9 -11
- mlrun/model_monitoring/stores/models/__init__.py +2 -2
- mlrun/model_monitoring/stores/sql_model_endpoint_store.py +11 -13
- mlrun/model_monitoring/stream_processing.py +16 -34
- mlrun/model_monitoring/tracking_policy.py +2 -1
- mlrun/package/__init__.py +6 -6
- mlrun/package/context_handler.py +5 -5
- mlrun/package/packager.py +7 -7
- mlrun/package/packagers/default_packager.py +6 -6
- mlrun/package/packagers/numpy_packagers.py +15 -15
- mlrun/package/packagers/pandas_packagers.py +5 -5
- mlrun/package/packagers/python_standard_library_packagers.py +10 -10
- mlrun/package/packagers_manager.py +18 -23
- mlrun/package/utils/_formatter.py +4 -4
- mlrun/package/utils/_pickler.py +2 -2
- mlrun/package/utils/_supported_format.py +4 -4
- mlrun/package/utils/log_hint_utils.py +2 -2
- mlrun/package/utils/type_hint_utils.py +4 -9
- mlrun/platforms/other.py +1 -2
- mlrun/projects/operations.py +5 -5
- mlrun/projects/pipelines.py +9 -9
- mlrun/projects/project.py +58 -46
- mlrun/render.py +1 -1
- mlrun/run.py +9 -9
- mlrun/runtimes/__init__.py +7 -4
- mlrun/runtimes/base.py +20 -23
- mlrun/runtimes/constants.py +5 -5
- mlrun/runtimes/daskjob.py +8 -8
- mlrun/runtimes/databricks_job/databricks_cancel_task.py +1 -1
- mlrun/runtimes/databricks_job/databricks_runtime.py +7 -7
- mlrun/runtimes/function_reference.py +1 -1
- mlrun/runtimes/local.py +1 -1
- mlrun/runtimes/mpijob/abstract.py +1 -2
- mlrun/runtimes/nuclio/__init__.py +20 -0
- mlrun/runtimes/{function.py → nuclio/function.py} +15 -16
- mlrun/runtimes/{nuclio.py → nuclio/nuclio.py} +6 -6
- mlrun/runtimes/{serving.py → nuclio/serving.py} +13 -12
- mlrun/runtimes/pod.py +95 -48
- mlrun/runtimes/remotesparkjob.py +1 -1
- mlrun/runtimes/sparkjob/spark3job.py +50 -33
- mlrun/runtimes/utils.py +1 -2
- mlrun/secrets.py +3 -3
- mlrun/serving/remote.py +0 -4
- mlrun/serving/routers.py +6 -6
- mlrun/serving/server.py +4 -4
- mlrun/serving/states.py +29 -0
- mlrun/serving/utils.py +3 -3
- mlrun/serving/v1_serving.py +6 -7
- mlrun/serving/v2_serving.py +50 -8
- mlrun/track/tracker_manager.py +3 -3
- mlrun/track/trackers/mlflow_tracker.py +1 -2
- mlrun/utils/async_http.py +5 -7
- mlrun/utils/azure_vault.py +1 -1
- mlrun/utils/clones.py +1 -2
- mlrun/utils/condition_evaluator.py +3 -3
- mlrun/utils/db.py +3 -3
- mlrun/utils/helpers.py +37 -119
- mlrun/utils/http.py +1 -4
- mlrun/utils/logger.py +49 -14
- mlrun/utils/notifications/notification/__init__.py +3 -3
- mlrun/utils/notifications/notification/base.py +2 -2
- mlrun/utils/notifications/notification/ipython.py +1 -1
- mlrun/utils/notifications/notification_pusher.py +8 -14
- mlrun/utils/retryer.py +207 -0
- mlrun/utils/singleton.py +1 -1
- mlrun/utils/v3io_clients.py +2 -3
- mlrun/utils/version/version.json +2 -2
- mlrun/utils/version/version.py +2 -6
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/METADATA +9 -9
- mlrun-1.7.0rc2.dist-info/RECORD +315 -0
- mlrun-1.6.0rc35.dist-info/RECORD +0 -313
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/LICENSE +0 -0
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/WHEEL +0 -0
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/entry_points.txt +0 -0
- {mlrun-1.6.0rc35.dist-info → mlrun-1.7.0rc2.dist-info}/top_level.txt +0 -0
mlrun/model_monitoring/api.py
CHANGED
@@ -45,7 +45,7 @@ def get_or_create_model_endpoint(
-    sample_set_statistics: …
+    sample_set_statistics: dict[str, typing.Any] = None,
@@ -82,7 +82,7 @@ def get_or_create_model_endpoint(
-            f"{project}_{model_endpoint_name}".encode( …
+            f"{project}_{model_endpoint_name}".encode()
@@ -239,7 +239,7 @@ def record_results(
-    sample_set_statistics: …
+    sample_set_statistics: dict[str, typing.Any] = None,
@@ -307,7 +307,7 @@ def get_drift_thresholds_if_not_none(
-) -> …
+) -> tuple[float, float]:
@@ -386,7 +386,7 @@ def _generate_model_endpoint(
-    sample_set_statistics: …
+    sample_set_statistics: dict[str, typing.Any],
@@ -452,8 +452,8 @@ def _generate_model_endpoint(
-    model_endpoints_ids: …
-    batch_intervals_dict: …
+    model_endpoints_ids: list[str] = None,
+    batch_intervals_dict: dict[str, float] = None,
@@ -476,9 +476,7 @@ def trigger_drift_batch_job(
-    batch_function_dict: typing. …
-        str, typing.Any
-    ] = db_session.deploy_monitoring_batch_job(
+    batch_function_dict: dict[str, typing.Any] = db_session.deploy_monitoring_batch_job(
@@ -495,8 +493,8 @@ def trigger_drift_batch_job(
-    model_endpoints_ids: …
-    batch_intervals_dict: …
+    model_endpoints_ids: list[str],
+    batch_intervals_dict: dict[str, float] = None,
@@ -519,9 +517,9 @@ def _generate_job_params(
-    sample_set_columns: typing.Optional[ …
-    sample_set_drop_columns: typing.Optional[ …
-    sample_set_label_columns: typing.Optional[ …
+    sample_set_columns: typing.Optional[list] = None,
+    sample_set_drop_columns: typing.Optional[list] = None,
+    sample_set_label_columns: typing.Optional[list] = None,
@@ -576,10 +574,10 @@ def get_sample_set_statistics(
-    feature_columns: typing.Union[str, …
-    label_columns: typing.Union[str, …
-    drop_columns: typing.Union[str, …
-) -> …
+    feature_columns: typing.Union[str, list[str]] = None,
+    label_columns: typing.Union[str, list[str]] = None,
+    drop_columns: typing.Union[str, list[str], int, list[int]] = None,
+) -> tuple[pd.DataFrame, list[str]]:
@@ -670,7 +668,7 @@ def perform_drift_analysis(
-):
+) -> None:
@@ -696,7 +694,7 @@ def perform_drift_analysis(
-    inputs_statistics.pop( …
+    inputs_statistics.pop(EventFieldType.TIMESTAMP, None)
@@ -708,7 +706,6 @@ def perform_drift_analysis(
-        features=list(inputs_statistics.keys()),
@@ -746,7 +743,7 @@ def perform_drift_analysis(
-    metrics_per_feature: …
+    metrics_per_feature: dict[str, float],
@@ -789,7 +786,7 @@ def _get_drift_result(
-) -> …
+) -> tuple[bool, float]:
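The recurring change in these hunks is the move from typing-module generics to the built-in generic annotations available since Python 3.9 (PEP 585): dict[...], list[...], and tuple[...] in signatures and return types. A minimal, self-contained sketch of that annotation style, using hypothetical names rather than mlrun's own functions:

import typing

def summarize_statistics(
    sample_set_statistics: dict[str, typing.Any] = None,
    feature_columns: typing.Union[str, list[str]] = None,
) -> tuple[int, list[str]]:
    # Count the features and collect their names (illustrative only).
    stats = sample_set_statistics or {}
    columns = [feature_columns] if isinstance(feature_columns, str) else list(feature_columns or [])
    return len(stats), list(stats.keys()) + columns

print(summarize_statistics({"f1": {"mean": 0.2}}, "f2"))  # (1, ['f1', 'f2'])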
mlrun/model_monitoring/application.py
CHANGED
@@ -16,7 +16,7 @@ import dataclasses
-from typing import Any, Optional, …
+from typing import Any, Optional, Union
@@ -108,7 +108,7 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
-    ) -> …
+    ) -> tuple[list[ModelMonitoringApplicationResult], dict]:
@@ -165,7 +165,7 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
-    ) -> …
+    ) -> tuple[
@@ -272,7 +272,7 @@ class PushToMonitoringWriter(StepToDict):
-    def do(self, event: …
+    def do(self, event: tuple[list[ModelMonitoringApplicationResult], dict]) -> None:
mlrun/model_monitoring/batch.py
CHANGED
@@ -19,7 +19,7 @@ import datetime
-from typing import Any, ClassVar, …
+from typing import Any, ClassVar, Optional, Union
@@ -38,7 +38,7 @@ import mlrun.utils.v3io_clients
-DriftResultType = …
+DriftResultType = tuple[mlrun.common.schemas.model_monitoring.DriftStatus, float]
@@ -157,7 +157,7 @@ class VirtualDrift:
-        feature_weights: Optional[ …
+        feature_weights: Optional[list[float]] = None,
@@ -179,7 +179,7 @@ class VirtualDrift:
-        self.metrics: …
+        self.metrics: dict[str, type[HistogramDistanceMetric]] = {
@@ -189,7 +189,7 @@ class VirtualDrift:
-    def dict_to_histogram(histogram_dict: …
+    def dict_to_histogram(histogram_dict: dict[str, dict[str, Any]]) -> pd.DataFrame:
@@ -212,9 +212,9 @@ class VirtualDrift:
-        base_histogram: …
-        latest_histogram: …
-    ) -> …
+        base_histogram: dict[str, dict[str, Any]],
+        latest_histogram: dict[str, dict[str, Any]],
+    ) -> dict[str, dict[str, Any]]:
@@ -243,9 +243,9 @@ class VirtualDrift:
-        feature_stats: …
-        current_stats: …
-    ) -> …
+        feature_stats: dict[str, dict[str, Any]],
+        current_stats: dict[str, dict[str, Any]],
+    ) -> dict[str, dict[str, Any]]:
@@ -335,10 +335,10 @@ class VirtualDrift:
-        metrics_results_dictionary: …
-    ) -> …
+        metrics_results_dictionary: dict[str, Union[float, dict]],
+    ) -> dict[str, DriftResultType]:
@@ -389,7 +389,7 @@ class VirtualDrift:
-        metrics_results_dictionary: …
+        metrics_results_dictionary: dict[str, Union[float, dict]],
@@ -880,7 +880,7 @@ class BatchProcessor:
-    def _get_interval_range(self) -> …
+    def _get_interval_range(self) -> tuple[datetime.datetime, datetime.datetime]:
@@ -912,7 +912,7 @@ class BatchProcessor:
-        drift_result: …
+        drift_result: dict[str, dict[str, Any]],
@@ -978,7 +978,7 @@ class BatchProcessor:
-        drift_result: …
+        drift_result: dict[str, dict[str, Any]],
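DriftResultType is now spelled with the built-in tuple, pairing a DriftStatus with the drift measure, and check_for_drift_per_feature returns a mapping of that type per feature. A rough sketch of that shape, using a stand-in enum (the real DriftStatus lives in mlrun.common.schemas.model_monitoring and its member names may differ):

import enum

class DriftStatus(enum.Enum):
    # Stand-in members for illustration; not necessarily mlrun's names.
    NO_DRIFT = "NO_DRIFT"
    POSSIBLE_DRIFT = "POSSIBLE_DRIFT"
    DRIFT_DETECTED = "DRIFT_DETECTED"

DriftResultType = tuple[DriftStatus, float]

# The per-feature result of check_for_drift_per_feature() has this shape:
drift_per_feature: dict[str, DriftResultType] = {
    "feature_a": (DriftStatus.NO_DRIFT, 0.12),
    "feature_b": (DriftStatus.DRIFT_DETECTED, 0.81),
}

for name, (status, measure) in drift_per_feature.items():
    print(f"{name}: {status.name} ({measure:.2f})")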
mlrun/model_monitoring/controller.py
CHANGED
@@ -17,7 +17,8 @@ import datetime
-from …
+from collections.abc import Iterator
+from typing import Any, NamedTuple, Optional, Union, cast
mlrun/model_monitoring/features_drift_table.py
CHANGED
@@ -12,7 +12,7 @@
-from typing import …
+from typing import Union
@@ -21,7 +21,7 @@ from plotly.subplots import make_subplots
-DriftResultType = …
+DriftResultType = tuple[mlrun.common.schemas.model_monitoring.DriftStatus, float]
@@ -93,17 +93,14 @@ class FeaturesDriftTablePlot:
-        features: List[str],
-        metrics: …
-        drift_results: …
+        metrics: dict[str, Union[dict, float]],
+        drift_results: dict[str, DriftResultType],
-        :param features: List of all the features names to include in the table. These names expected to be
-                         in the statistics and metrics dictionaries.
@@ -113,7 +110,7 @@ class FeaturesDriftTablePlot:
-            features= …
+            features=list(inputs_statistics.keys()),
@@ -165,7 +162,7 @@ class FeaturesDriftTablePlot:
-    def _plot_headers_tables(self) -> …
+    def _plot_headers_tables(self) -> tuple[go.Table, go.Table]:
@@ -232,7 +229,7 @@ class FeaturesDriftTablePlot:
-    def _separate_feature_name(self, feature_name: str) -> …
+    def _separate_feature_name(self, feature_name: str) -> list[str]:
@@ -293,15 +290,22 @@ class FeaturesDriftTablePlot:
+        html_feature_name = "<br>".join(self._separate_feature_name(feature_name))
-        cells_values = [f"<b>{ …
+        cells_values = [f"<b>{html_feature_name}</b>"]
+            try:
+                cells_values.append(input_statistics[column])
+            except KeyError:
+                raise ValueError(
+                    f"The `input_statistics['{feature_name}']` dictionary "
+                    f"does not include the expected key '{column}'. "
+                    "Please check the current data."
+                )
@@ -329,8 +333,8 @@ class FeaturesDriftTablePlot:
-        self, sample_hist: …
-    ) -> …
+        self, sample_hist: tuple[list, list], input_hist: tuple[list, list]
+    ) -> tuple[go.Scatter, go.Scatter]:
@@ -375,7 +379,7 @@ class FeaturesDriftTablePlot:
-    def _calculate_row_height(self, features: …
+    def _calculate_row_height(self, features: list[str]) -> int:
@@ -450,11 +454,11 @@ class FeaturesDriftTablePlot:
-        features: …
-        metrics: …
-        drift_results: …
+        features: list[str],
+        metrics: dict[str, Union[dict, float]],
+        drift_results: dict[str, DriftResultType],
@@ -517,18 +521,27 @@ class FeaturesDriftTablePlot:
+            try:
+                # Add the feature values:
+                main_figure.add_trace(
+                    self._plot_feature_row_table(
+                        feature_name=feature,
+                        sample_statistics=sample_set_statistics[feature],
+                        input_statistics=inputs_statistics[feature],
+                        metrics=metrics[feature],
+                        row_height=row_height,
+                    ),
+                    row=row,
+                    col=1,
+                )
+            except KeyError:
+                raise ValueError(
+                    "`sample_set_statistics` does not contain the expected "
+                    f"key '{feature}' from `inputs_statistics`. Please verify "
+                    "the data integrity.\n"
+                    f"{sample_set_statistics.keys() = }\n"
+                    f"{inputs_statistics.keys() = }\n"
+                )
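Two behavioral notes from the hunks above: produce() no longer accepts a features argument (the feature list is derived from inputs_statistics.keys()), and missing keys between inputs_statistics and sample_set_statistics now surface as ValueError rather than a bare KeyError. A small stand-alone sketch of the key-consistency requirement this implies (not mlrun code, toy data only):

sample_set_statistics = {"f1": {"mean": 0.1}, "f2": {"mean": 0.4}}
inputs_statistics = {"f1": {"mean": 0.2}, "f3": {"mean": 0.9}}

# Every feature present in the inputs must also exist in the sample set statistics.
missing = [f for f in inputs_statistics if f not in sample_set_statistics]
if missing:
    print(f"sample_set_statistics is missing expected keys: {missing}")  # ['f3']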
mlrun/model_monitoring/prometheus.py
CHANGED
@@ -12,7 +12,6 @@
-import typing
@@ -134,9 +133,7 @@ def write_predictions_and_latency_metrics(
-def write_income_features(
-    project: str, endpoint_id: str, features: typing.Dict[str, float]
-):
+def write_income_features(project: str, endpoint_id: str, features: dict[str, float]):
mlrun/model_monitoring/stores/kv_model_endpoint_store.py
CHANGED
@@ -50,7 +50,7 @@ class KVModelEndpointStore(ModelEndpointStore):
-    def write_model_endpoint(self, endpoint: …
+    def write_model_endpoint(self, endpoint: dict[str, typing.Any]):
@@ -72,7 +72,7 @@ class KVModelEndpointStore(ModelEndpointStore):
-        self, endpoint_id: str, attributes: …
+        self, endpoint_id: str, attributes: dict[str, typing.Any]
@@ -114,7 +114,7 @@ class KVModelEndpointStore(ModelEndpointStore):
-    ) -> …
+    ) -> dict[str, typing.Any]:
@@ -167,10 +167,10 @@ class KVModelEndpointStore(ModelEndpointStore):
-        labels: …
-        uids: …
-    ) -> …
+        labels: list[str] = None,
+        uids: list = None,
+    ) -> list[dict[str, typing.Any]]:
@@ -239,9 +239,7 @@ class KVModelEndpointStore(ModelEndpointStore):
-    def delete_model_endpoints_resources(
-        self, endpoints: typing.List[typing.Dict[str, typing.Any]]
-    ):
+    def delete_model_endpoints_resources(self, endpoints: list[dict[str, typing.Any]]):
@@ -310,11 +308,11 @@ class KVModelEndpointStore(ModelEndpointStore):
-        metrics: …
-    ) -> …
+        metrics: list[str],
+    ) -> dict[str, list[tuple[str, float]]]:
@@ -396,7 +394,7 @@ class KVModelEndpointStore(ModelEndpointStore):
-    def _generate_tsdb_paths(self) -> …
+    def _generate_tsdb_paths(self) -> tuple[str, str]:
@@ -455,7 +453,7 @@ class KVModelEndpointStore(ModelEndpointStore):
-        labels: …
+        labels: list[str] = None,
mlrun/model_monitoring/stores/model_endpoint_store.py
CHANGED
@@ -31,7 +31,7 @@ class ModelEndpointStore(ABC):
-    def write_model_endpoint(self, endpoint: …
+    def write_model_endpoint(self, endpoint: dict[str, typing.Any]):
@@ -41,7 +41,7 @@ class ModelEndpointStore(ABC):
-        self, endpoint_id: str, attributes: …
+        self, endpoint_id: str, attributes: dict[str, typing.Any]
@@ -63,9 +63,7 @@ class ModelEndpointStore(ABC):
-    def delete_model_endpoints_resources(
-        self, endpoints: typing.List[typing.Dict[str, typing.Any]]
-    ):
+    def delete_model_endpoints_resources(self, endpoints: list[dict[str, typing.Any]]):
@@ -78,7 +76,7 @@ class ModelEndpointStore(ABC):
-    ) -> …
+    ) -> dict[str, typing.Any]:
@@ -93,10 +91,10 @@ class ModelEndpointStore(ABC):
-        labels: …
-        uids: …
-    ) -> …
+        labels: list[str] = None,
+        uids: list = None,
+    ) -> list[dict[str, typing.Any]]:
@@ -118,11 +116,11 @@ class ModelEndpointStore(ABC):
-        metrics: …
-    ) -> …
+        metrics: list[str],
+    ) -> dict[str, list[tuple[str, float]]]:
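The new return annotation of get_endpoint_real_time_metrics, dict[str, list[tuple[str, float]]], maps each metric name to a list of (timestamp, value) pairs. An illustrative value of that shape, with made-up timestamps and numbers but the metric names mentioned in the docstring:

metrics_result: dict[str, list[tuple[str, float]]] = {
    "predictions_per_second": [("2024-01-01T00:00:00Z", 12.0), ("2024-01-01T00:01:00Z", 15.5)],
    "latency_avg_5m": [("2024-01-01T00:00:00Z", 3.2)],
}

for metric, points in metrics_result.items():
    latest_timestamp, latest_value = points[-1]
    print(f"{metric}: {latest_value} at {latest_timestamp}")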
mlrun/model_monitoring/stores/models/__init__.py
CHANGED
@@ -12,7 +12,7 @@
-from typing import Optional, …
+from typing import Optional, Union
@@ -20,7 +20,7 @@ from .sqlite import ModelEndpointsTable as SQLiteModelEndpointsTable
-) -> Union[ …
+) -> Union[type[MySQLModelEndpointsTable], type[SQLiteModelEndpointsTable]]:
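get_model_endpoints_table now declares that it returns one of the two table classes. A sketch of the dispatch it performs, with stand-in classes and an assumed SQLite fallback (the real classes come from .mysql and .sqlite):

from typing import Optional, Union

class MySQLModelEndpointsTable: ...
class SQLiteModelEndpointsTable: ...

def get_model_endpoints_table(
    connection_string: Optional[str] = None,
) -> Union[type[MySQLModelEndpointsTable], type[SQLiteModelEndpointsTable]]:
    # A MySQL connection string selects the MySQL-backed table; otherwise fall back to SQLite (assumed).
    if connection_string and "mysql:" in connection_string:
        return MySQLModelEndpointsTable
    return SQLiteModelEndpointsTable

print(get_model_endpoints_table("mysql://user@host/db").__name__)  # MySQLModelEndpointsTable
print(get_model_endpoints_table().__name__)                        # SQLiteModelEndpointsTable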