mlrun 1.7.0rc22__py3-none-any.whl → 1.7.0rc28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of mlrun might be problematic.
- mlrun/__main__.py +10 -8
- mlrun/alerts/alert.py +13 -1
- mlrun/artifacts/manager.py +5 -0
- mlrun/common/constants.py +2 -2
- mlrun/common/formatters/__init__.py +1 -0
- mlrun/common/formatters/artifact.py +26 -3
- mlrun/common/formatters/base.py +9 -9
- mlrun/common/formatters/run.py +26 -0
- mlrun/common/helpers.py +11 -0
- mlrun/common/schemas/__init__.py +4 -0
- mlrun/common/schemas/alert.py +5 -9
- mlrun/common/schemas/api_gateway.py +64 -16
- mlrun/common/schemas/artifact.py +11 -0
- mlrun/common/schemas/constants.py +3 -0
- mlrun/common/schemas/feature_store.py +58 -28
- mlrun/common/schemas/model_monitoring/constants.py +21 -12
- mlrun/common/schemas/model_monitoring/model_endpoints.py +0 -12
- mlrun/common/schemas/pipeline.py +16 -0
- mlrun/common/schemas/project.py +17 -0
- mlrun/common/schemas/runs.py +17 -0
- mlrun/common/schemas/schedule.py +1 -1
- mlrun/common/types.py +5 -0
- mlrun/config.py +10 -25
- mlrun/datastore/azure_blob.py +2 -1
- mlrun/datastore/datastore.py +3 -3
- mlrun/datastore/google_cloud_storage.py +6 -2
- mlrun/datastore/snowflake_utils.py +3 -1
- mlrun/datastore/sources.py +26 -11
- mlrun/datastore/store_resources.py +2 -0
- mlrun/datastore/targets.py +68 -16
- mlrun/db/base.py +64 -2
- mlrun/db/httpdb.py +129 -41
- mlrun/db/nopdb.py +44 -3
- mlrun/errors.py +5 -3
- mlrun/execution.py +18 -10
- mlrun/feature_store/retrieval/spark_merger.py +2 -1
- mlrun/frameworks/__init__.py +0 -6
- mlrun/model.py +23 -0
- mlrun/model_monitoring/api.py +6 -52
- mlrun/model_monitoring/applications/histogram_data_drift.py +1 -1
- mlrun/model_monitoring/db/stores/__init__.py +37 -24
- mlrun/model_monitoring/db/stores/base/store.py +40 -1
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +42 -87
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +27 -35
- mlrun/model_monitoring/db/tsdb/__init__.py +15 -15
- mlrun/model_monitoring/db/tsdb/base.py +1 -1
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +6 -4
- mlrun/model_monitoring/helpers.py +17 -9
- mlrun/model_monitoring/stream_processing.py +9 -11
- mlrun/model_monitoring/writer.py +11 -11
- mlrun/package/__init__.py +1 -13
- mlrun/package/packagers/__init__.py +1 -6
- mlrun/projects/pipelines.py +10 -9
- mlrun/projects/project.py +95 -81
- mlrun/render.py +10 -5
- mlrun/run.py +13 -8
- mlrun/runtimes/base.py +11 -4
- mlrun/runtimes/daskjob.py +7 -1
- mlrun/runtimes/local.py +16 -3
- mlrun/runtimes/nuclio/application/application.py +0 -2
- mlrun/runtimes/nuclio/function.py +20 -0
- mlrun/runtimes/nuclio/serving.py +9 -6
- mlrun/runtimes/pod.py +5 -29
- mlrun/serving/routers.py +75 -59
- mlrun/serving/server.py +11 -0
- mlrun/serving/states.py +29 -0
- mlrun/serving/v2_serving.py +62 -39
- mlrun/utils/helpers.py +39 -1
- mlrun/utils/logger.py +36 -2
- mlrun/utils/notifications/notification/base.py +43 -7
- mlrun/utils/notifications/notification/git.py +21 -0
- mlrun/utils/notifications/notification/slack.py +9 -14
- mlrun/utils/notifications/notification/webhook.py +41 -1
- mlrun/utils/notifications/notification_pusher.py +3 -9
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.7.0rc22.dist-info → mlrun-1.7.0rc28.dist-info}/METADATA +12 -7
- {mlrun-1.7.0rc22.dist-info → mlrun-1.7.0rc28.dist-info}/RECORD +81 -80
- {mlrun-1.7.0rc22.dist-info → mlrun-1.7.0rc28.dist-info}/WHEEL +1 -1
- {mlrun-1.7.0rc22.dist-info → mlrun-1.7.0rc28.dist-info}/LICENSE +0 -0
- {mlrun-1.7.0rc22.dist-info → mlrun-1.7.0rc28.dist-info}/entry_points.txt +0 -0
- {mlrun-1.7.0rc22.dist-info → mlrun-1.7.0rc28.dist-info}/top_level.txt +0 -0
mlrun/common/schemas/model_monitoring/constants.py
CHANGED

@@ -78,8 +78,6 @@ class EventFieldType:
     FEATURE_SET_URI = "monitoring_feature_set_uri"
     ALGORITHM = "algorithm"
     VALUE = "value"
-    DRIFT_DETECTED_THRESHOLD = "drift_detected_threshold"
-    POSSIBLE_DRIFT_THRESHOLD = "possible_drift_threshold"
     SAMPLE_PARQUET_PATH = "sample_parquet_path"
     TIME = "time"
     TABLE_COLUMN = "table_column"
@@ -158,19 +156,36 @@ class EventKeyMetrics:
     REAL_TIME = "real_time"


-class ModelEndpointTarget:
+class ModelEndpointTarget(MonitoringStrEnum):
     V3IO_NOSQL = "v3io-nosql"
     SQL = "sql"


+class StreamKind(MonitoringStrEnum):
+    V3IO_STREAM = "v3io_stream"
+    KAFKA = "kafka"
+
+
+class TSDBTarget(MonitoringStrEnum):
+    V3IO_TSDB = "v3io-tsdb"
+    TDEngine = "tdengine"
+    PROMETHEUS = "prometheus"
+
+
 class ProjectSecretKeys:
     ENDPOINT_STORE_CONNECTION = "MODEL_MONITORING_ENDPOINT_STORE_CONNECTION"
     ACCESS_KEY = "MODEL_MONITORING_ACCESS_KEY"
-    PIPELINES_ACCESS_KEY = "MODEL_MONITORING_PIPELINES_ACCESS_KEY"
-    KAFKA_BROKERS = "KAFKA_BROKERS"
     STREAM_PATH = "STREAM_PATH"
     TSDB_CONNECTION = "TSDB_CONNECTION"

+    @classmethod
+    def mandatory_secrets(cls):
+        return [
+            cls.ENDPOINT_STORE_CONNECTION,
+            cls.STREAM_PATH,
+            cls.TSDB_CONNECTION,
+        ]
+

 class ModelMonitoringStoreKinds:
     ENDPOINTS = "endpoints"
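Both new classes subclass MonitoringStrEnum, a str-based enum used across mlrun's monitoring constants. A rough sketch of why that matters (the list() helper here is an assumption modeled on mlrun's enum helpers; only the member names come from the diff above):

import enum


class MonitoringStrEnum(str, enum.Enum):
    # assumed helper: mlrun's monitoring enums expose a similar list() classmethod
    @classmethod
    def list(cls):
        return [member.value for member in cls]


class TSDBTarget(MonitoringStrEnum):
    V3IO_TSDB = "v3io-tsdb"
    TDEngine = "tdengine"
    PROMETHEUS = "prometheus"


print(TSDBTarget.list())  # ['v3io-tsdb', 'tdengine', 'prometheus']
print(TSDBTarget("tdengine") is TSDBTarget.TDEngine)  # True, raw strings round-trip

The new ProjectSecretKeys.mandatory_secrets() classmethod similarly gives callers one place to ask which monitoring secrets must be configured, instead of hard-coding the key names.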
@@ -318,7 +333,7 @@ class ResultKindApp(Enum):
     concept_drift = 1
     model_performance = 2
     system_performance = 3
-
+    mm_app_anomaly = 4


 class ResultStatusApp(IntEnum):
@@ -344,12 +359,6 @@ class ControllerPolicy:
     BASE_PERIOD = "base_period"


-class TSDBTarget:
-    V3IO_TSDB = "v3io-tsdb"
-    TDEngine = "tdengine"
-    PROMETHEUS = "prometheus"
-
-
 class HistogramDataDriftApplicationConstants:
     NAME = "histogram-data-drift"
     GENERAL_RESULT_NAME = "general_drift"
mlrun/common/schemas/model_monitoring/model_endpoints.py
CHANGED

@@ -103,18 +103,6 @@ class ModelEndpointSpec(ObjectSpec):
             json_parse_values=json_parse_values,
         )

-    @validator("monitor_configuration")
-    @classmethod
-    def set_name(cls, monitor_configuration):
-        return monitor_configuration or {
-            EventFieldType.DRIFT_DETECTED_THRESHOLD: (
-                mlrun.mlconf.model_endpoint_monitoring.drift_thresholds.default.drift_detected
-            ),
-            EventFieldType.POSSIBLE_DRIFT_THRESHOLD: (
-                mlrun.mlconf.model_endpoint_monitoring.drift_thresholds.default.possible_drift
-            ),
-        }
-
     @validator("model_uri")
     @classmethod
     def validate_model_uri(cls, model_uri):
mlrun/common/schemas/pipeline.py
CHANGED
@@ -15,6 +15,22 @@
 import typing

 import pydantic
+from deprecated import deprecated
+
+import mlrun.common.types
+
+
+@deprecated(
+    version="1.7.0",
+    reason="mlrun.common.schemas.PipelinesFormat is deprecated and will be removed in 1.9.0. "
+    "Use mlrun.common.formatters.PipelineFormat instead.",
+    category=FutureWarning,
+)
+class PipelinesFormat(mlrun.common.types.StrEnum):
+    full = "full"
+    metadata_only = "metadata_only"
+    summary = "summary"
+    name_only = "name_only"


 class PipelinesPagination(str):
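The shim keeps the old import path working while pointing users at mlrun.common.formatters.PipelineFormat; the same decorator is applied to ProjectsFormat and RunsFormat below. A self-contained sketch of the pattern with hypothetical OldFormat/NewFormat names (exactly when the deprecated package emits the FutureWarning, e.g. on instantiation, depends on its version):

import enum

from deprecated import deprecated  # the `Deprecated` PyPI package


class StrEnum(str, enum.Enum):
    """Stand-in for mlrun.common.types.StrEnum (an assumption for this sketch)."""


@deprecated(
    version="1.7.0",
    reason="OldFormat is deprecated and will be removed in 1.9.0. Use NewFormat instead.",
    category=FutureWarning,
)
class OldFormat(StrEnum):
    full = "full"
    name_only = "name_only"

Existing code importing the old name keeps working, which leaves room to migrate before the 1.9.0 removal.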
mlrun/common/schemas/project.py
CHANGED
@@ -16,6 +16,7 @@ import datetime
 import typing

 import pydantic
+from deprecated import deprecated

 import mlrun.common.types

@@ -23,6 +24,22 @@ from .common import ImageBuilder
 from .object import ObjectKind, ObjectStatus


+@deprecated(
+    version="1.7.0",
+    reason="mlrun.common.schemas.ProjectsFormat is deprecated and will be removed in 1.9.0. "
+    "Use mlrun.common.formatters.ProjectFormat instead.",
+    category=FutureWarning,
+)
+class ProjectsFormat(mlrun.common.types.StrEnum):
+    full = "full"
+    name_only = "name_only"
+    # minimal format removes large fields from the response (e.g. functions, workflows, artifacts)
+    # and is used for faster response times (in the UI)
+    minimal = "minimal"
+    # internal - allowed only in follower mode, only for the leader for upgrade purposes
+    leader = "leader"
+
+
 class ProjectMetadata(pydantic.BaseModel):
     name: str
     created: typing.Optional[datetime.datetime] = None
mlrun/common/schemas/runs.py
CHANGED
@@ -15,9 +15,26 @@
 import typing

 import pydantic
+from deprecated import deprecated
+
+import mlrun.common.types


 class RunIdentifier(pydantic.BaseModel):
     kind: typing.Literal["run"] = "run"
     uid: typing.Optional[str]
     iter: typing.Optional[int]
+
+
+@deprecated(
+    version="1.7.0",
+    reason="mlrun.common.schemas.RunsFormat is deprecated and will be removed in 1.9.0. "
+    "Use mlrun.common.formatters.RunFormat instead.",
+    category=FutureWarning,
+)
+class RunsFormat(mlrun.common.types.StrEnum):
+    # No enrichment, data is pulled as-is from the database.
+    standard = "standard"
+
+    # Performs run enrichment, including the run's artifacts. Only available for the `get` run API.
+    full = "full"
mlrun/common/schemas/schedule.py
CHANGED
@@ -96,7 +96,7 @@ class ScheduleUpdate(BaseModel):
     scheduled_object: Optional[Any]
     cron_trigger: Optional[Union[str, ScheduleCronTrigger]]
     desired_state: Optional[str]
-    labels: Optional[dict] =
+    labels: Optional[dict] = None
     concurrency_limit: Optional[int]
     credentials: Credentials = Credentials()

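One plausible reason for defaulting labels to None rather than a dict literal in an update schema: None lets the server distinguish "labels not sent" from "labels explicitly set to empty" when applying a partial update. A minimal illustration (hypothetical model name, not mlrun's code):

import typing

import pydantic


class ScheduleUpdateSketch(pydantic.BaseModel):
    labels: typing.Optional[dict] = None


print(ScheduleUpdateSketch().labels)           # None -> field was not provided
print(ScheduleUpdateSketch(labels={}).labels)  # {}   -> caller explicitly cleared labels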
mlrun/common/types.py
CHANGED
mlrun/config.py
CHANGED
@@ -504,13 +504,12 @@ default_config = {
     "model_endpoint_monitoring": {
         "serving_stream_args": {"shard_count": 1, "retention_period_hours": 24},
         "application_stream_args": {"shard_count": 1, "retention_period_hours": 24},
-        "drift_thresholds": {"default": {"possible_drift": 0.5, "drift_detected": 0.7}},
         # Store prefixes are used to handle model monitoring storing policies based on project and kind, such as events,
         # stream, and endpoints.
         "store_prefixes": {
             "default": "v3io:///users/pipelines/{project}/model-endpoints/{kind}",
             "user_space": "v3io:///projects/{project}/model-endpoints/{kind}",
-            "stream": "",
+            "stream": "",  # TODO: Delete in 1.9.0
             "monitoring_application": "v3io:///users/pipelines/{project}/monitoring-apps/",
         },
         # Offline storage path can be either relative or a full path. This path is used for general offline data
@@ -523,11 +522,12 @@ default_config = {
         "parquet_batching_max_events": 10_000,
         "parquet_batching_timeout_secs": timedelta(minutes=1).total_seconds(),
         # See mlrun.model_monitoring.db.stores.ObjectStoreFactory for available options
-        "store_type": "v3io-nosql",
+        "store_type": "v3io-nosql",  # TODO: Delete in 1.9.0
         "endpoint_store_connection": "",
         # See mlrun.model_monitoring.db.tsdb.ObjectTSDBFactory for available options
-        "tsdb_connector_type": "v3io-tsdb",
         "tsdb_connection": "",
+        # See mlrun.common.schemas.model_monitoring.constants.StreamKind for available options
+        "stream_connection": "",
     },
     "secret_stores": {
         # Use only in testing scenarios (such as integration tests) to avoid using k8s for secrets (will use in-memory
@@ -660,7 +660,9 @@ default_config = {
         "failed_runs_grace_period": 3600,
         "verbose": True,
         # the number of workers which will be used to trigger the start log collection
-        "concurrent_start_logs_workers":
+        "concurrent_start_logs_workers": 50,
+        # the number of runs for which to start logs on api startup
+        "start_logs_startup_run_limit": 150,
         # the time in hours in which to start log collection from.
         # after upgrade, we might have runs which completed in the mean time or still in non-terminal state and
         # we want to collect their logs in the new log collection method (sidecar)
@@ -707,7 +709,9 @@ default_config = {
         "mode": "enabled",
         # maximum number of alerts we allow to be configured.
         # user will get an error when exceeding this
-        "max_allowed":
+        "max_allowed": 10000,
+        # maximum allowed value for count in criteria field inside AlertConfig
+        "max_criteria_count": 100,
     },
     "auth_with_client_id": {
         "enabled": False,
@@ -938,24 +942,6 @@ class Config:
                 f"is not allowed for iguazio version: {igz_version} < 3.5.1"
             )

-    def resolve_kfp_url(self, namespace=None):
-        if config.kfp_url:
-            return config.kfp_url
-        igz_version = self.get_parsed_igz_version()
-        # TODO: When Iguazio 3.4 will deprecate we can remove this line
-        if igz_version and igz_version <= semver.VersionInfo.parse("3.6.0-b1"):
-            if namespace is None:
-                if not config.namespace:
-                    raise mlrun.errors.MLRunNotFoundError(
-                        "For KubeFlow Pipelines to function, a namespace must be configured"
-                    )
-                namespace = config.namespace
-            # When instead of host we provided namespace we tackled this issue
-            # https://github.com/canonical/bundle-kubeflow/issues/412
-            # TODO: When we'll move to kfp 1.4.0 (server side) it should be resolved
-            return f"http://ml-pipeline.{namespace}.svc.cluster.local:8888"
-        return None
-
     def resolve_chief_api_url(self) -> str:
         if self.httpdb.clusterization.chief.url:
             return self.httpdb.clusterization.chief.url
@@ -1136,7 +1122,6 @@ class Config:
         if store_prefix_dict.get(kind):
             # Target exist in store prefix and has a valid string value
             return store_prefix_dict[kind].format(project=project, **kwargs)
-
         if (
             function_name
             and function_name
mlrun/datastore/azure_blob.py
CHANGED
mlrun/datastore/datastore.py
CHANGED
@@ -21,7 +21,7 @@ from mlrun.datastore.datastore_profile import datastore_profile_read
 from mlrun.errors import err_to_str
 from mlrun.utils.helpers import get_local_file_schema

-from ..utils import DB_SCHEMA, run_keys
+from ..utils import DB_SCHEMA, RunKeys
 from .base import DataItem, DataStore, HttpStore
 from .filestore import FileStore
 from .inmem import InMemoryStore
@@ -133,7 +133,7 @@ class StoreManager:
         return self._db

     def from_dict(self, struct: dict):
-        stor_list = struct.get(run_keys.data_stores)
+        stor_list = struct.get(RunKeys.data_stores)
         if stor_list and isinstance(stor_list, list):
             for stor in stor_list:
                 schema, endpoint, parsed_url = parse_url(stor.get("url"))
@@ -145,7 +145,7 @@ class StoreManager:
             self._stores[stor["name"]] = new_stor

     def to_dict(self, struct):
-        struct[run_keys.data_stores] = [
+        struct[RunKeys.data_stores] = [
             stor.to_dict() for stor in self._stores.values() if stor.from_spec
         ]
mlrun/datastore/google_cloud_storage.py
CHANGED

@@ -55,8 +55,12 @@ class GoogleCloudStorageStore(DataStore):
         ) or self._get_secret_or_env("GOOGLE_APPLICATION_CREDENTIALS")
         if credentials:
             try:
-                # Try to handle credentials as a json connection string
-                token = json.loads(credentials)
+                # Try to handle credentials as a json connection string or do nothing if already a dict
+                token = (
+                    credentials
+                    if isinstance(credentials, dict)
+                    else json.loads(credentials)
+                )
             except json.JSONDecodeError:
                 # If it's not json, handle it as a filename
                 token = credentials
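The widened branch accepts credentials that already arrive as a parsed dict, in addition to a JSON string or a path to a credentials file. The same decision logic in isolation (hypothetical helper name):

import json


def normalize_gcs_credentials(credentials):
    try:
        # a dict passes through untouched; a string is first tried as JSON
        return credentials if isinstance(credentials, dict) else json.loads(credentials)
    except json.JSONDecodeError:
        # not JSON, so treat the string as a path to a credentials file
        return credentials


print(normalize_gcs_credentials({"type": "service_account"}))    # dict in, dict out
print(normalize_gcs_credentials('{"type": "service_account"}'))  # parsed to a dict
print(normalize_gcs_credentials("/path/to/credentials.json"))    # kept as a file path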
mlrun/datastore/snowflake_utils.py
CHANGED

@@ -30,13 +30,15 @@ def get_snowflake_password():


 def get_snowflake_spark_options(attributes):
+    if not attributes:
+        return {}
     return {
         "format": "net.snowflake.spark.snowflake",
         "sfURL": attributes.get("url"),
         "sfUser": attributes.get("user"),
         "sfPassword": get_snowflake_password(),
         "sfDatabase": attributes.get("database"),
-        "sfSchema": attributes.get("schema"),
+        "sfSchema": attributes.get("db_schema"),
         "sfWarehouse": attributes.get("warehouse"),
         "application": "iguazio_platform",
         "TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_LTZ",
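With the early return, callers holding no Snowflake attributes now get an empty options dict instead of an AttributeError from None.get(...):

from mlrun.datastore.snowflake_utils import get_snowflake_spark_options

print(get_snowflake_spark_options(None))  # {} (returns before any attribute lookups)
print(get_snowflake_spark_options({}))    # {}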
mlrun/datastore/sources.py
CHANGED
@@ -747,7 +747,7 @@ class SnowflakeSource(BaseSourceDriver):
         url="...",
         user="...",
         database="...",
-        schema="...",
+        db_schema="...",
         warehouse="...",
     )
@@ -762,7 +762,8 @@ class SnowflakeSource(BaseSourceDriver):
     :parameter url: URL of the snowflake cluster
     :parameter user: snowflake user
     :parameter database: snowflake database
-    :parameter schema: snowflake schema
+    :parameter schema: snowflake schema - deprecated, use db_schema
+    :parameter db_schema: snowflake schema
     :parameter warehouse: snowflake warehouse
     """
@@ -774,6 +775,7 @@ class SnowflakeSource(BaseSourceDriver):
         self,
         name: str = "",
         key_field: str = None,
+        attributes: dict[str, object] = None,
         time_field: str = None,
         schedule: str = None,
         start_time=None,
@@ -783,21 +785,34 @@ class SnowflakeSource(BaseSourceDriver):
         user: str = None,
         database: str = None,
         schema: str = None,
+        db_schema: str = None,
         warehouse: str = None,
         **kwargs,
     ):
-        ...
-        }
+        # TODO: Remove in 1.9.0
+        if schema:
+            warnings.warn(
+                "schema is deprecated in 1.7.0, and will be removed in 1.9.0, please use db_schema"
+            )
+        db_schema = db_schema or schema  # TODO: Remove in 1.9.0
+
+        attributes = attributes or {}
+        if url:
+            attributes["url"] = url
+        if user:
+            attributes["user"] = user
+        if database:
+            attributes["database"] = database
+        if db_schema:
+            attributes["db_schema"] = db_schema
+        if warehouse:
+            attributes["warehouse"] = warehouse
+        if query:
+            attributes["query"] = query

         super().__init__(
             name,
-            attributes=
+            attributes=attributes,
             key_field=key_field,
             time_field=time_field,
             schedule=schedule,
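Put together, a source can now be built either from explicit keyword arguments or from a prebuilt attributes dict; passing the deprecated schema keyword still works but warns until 1.9.0. A sketch with placeholder values:

from mlrun.datastore.sources import SnowflakeSource

source = SnowflakeSource(
    "snowflake_source",
    query="SELECT * FROM my_table",         # illustrative query
    url="<account>.snowflakecomputing.com",
    user="<user>",
    database="<database>",
    db_schema="<schema>",                   # replaces the deprecated `schema` kwarg
    warehouse="<warehouse>",
)
print(source.attributes["db_schema"])       # <schema>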
mlrun/datastore/targets.py
CHANGED
@@ -29,7 +29,10 @@ from mergedeep import merge
 import mlrun
 import mlrun.utils.helpers
 from mlrun.config import config
-from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
+from mlrun.datastore.snowflake_utils import (
+    get_snowflake_password,
+    get_snowflake_spark_options,
+)
 from mlrun.datastore.utils import transform_list_filters_to_tuple
 from mlrun.model import DataSource, DataTarget, DataTargetBase, TargetPathObject
 from mlrun.utils import logger, now_date
@@ -696,6 +699,7 @@ class BaseStoreTarget(DataTargetBase):
             self.kind, self.name, self.get_target_templated_path()
         )
         target = self._target
+        target.attributes = self.attributes
         target.run_id = self.run_id
         target.status = status or target.status or "created"
         target.updated = now_date().isoformat()
@@ -727,8 +731,18 @@ class BaseStoreTarget(DataTargetBase):
         raise NotImplementedError()

     def purge(self):
+        """
+        Delete the files of the target.
+
+        Do not use this function directly from the sdk. Use FeatureSet.purge_targets.
+        """
         store, path_in_store, target_path = self._get_store_and_path()
-        store.rm(path_in_store, recursive=True)
+        if path_in_store not in ["", "/"]:
+            store.rm(path_in_store, recursive=True)
+        else:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "Unable to delete target. Please Use purge_targets from FeatureSet object."
+            )

     def as_df(
         self,
@@ -764,6 +778,10 @@ class BaseStoreTarget(DataTargetBase):
     def get_dask_options(self):
         raise NotImplementedError()

+    @property
+    def source_spark_attributes(self) -> dict:
+        return {}
+

 class ParquetTarget(BaseStoreTarget):
     """Parquet target storage driver, used to materialize feature set/vector data into parquet files.
@@ -1197,19 +1215,20 @@ class SnowflakeTarget(BaseStoreTarget):
         warehouse: str = None,
         table_name: str = None,
     ):
-        ...
+        attributes = attributes or {}
+        if url:
+            attributes["url"] = url
+        if user:
+            attributes["user"] = user
+        if database:
+            attributes["database"] = database
+        if db_schema:
+            attributes["db_schema"] = db_schema
+        if warehouse:
+            attributes["warehouse"] = warehouse
+        if table_name:
+            attributes["table"] = table_name
+
         super().__init__(
             name,
             path,
@@ -1233,7 +1252,31 @@ class SnowflakeTarget(BaseStoreTarget):
         return spark_options

     def purge(self):
-        ...
+        import snowflake.connector
+
+        missing = [
+            key
+            for key in ["database", "db_schema", "table", "url", "user", "warehouse"]
+            if self.attributes.get(key) is None
+        ]
+        if missing:
+            raise mlrun.errors.MLRunRuntimeError(
+                f"Can't purge Snowflake target, "
+                f"some attributes are missing: {', '.join(missing)}"
+            )
+        account = self.attributes["url"].replace(".snowflakecomputing.com", "")
+
+        with snowflake.connector.connect(
+            account=account,
+            user=self.attributes["user"],
+            password=get_snowflake_password(),
+            warehouse=self.attributes["warehouse"],
+        ) as snowflake_connector:
+            drop_statement = (
+                f"DROP TABLE IF EXISTS {self.attributes['database']}.{self.attributes['db_schema']}"
+                f".{self.attributes['table']}"
+            )
+            snowflake_connector.execute_string(drop_statement)

     def as_df(
         self,
@@ -1248,6 +1291,15 @@ class SnowflakeTarget(BaseStoreTarget):
     ):
         raise NotImplementedError()

+    @property
+    def source_spark_attributes(self) -> dict:
+        keys = ["url", "user", "database", "db_schema", "warehouse"]
+        attributes = self.attributes or {}
+        snowflake_dict = {key: attributes.get(key) for key in keys}
+        table = attributes.get("table")
+        snowflake_dict["query"] = f"SELECT * from {table}" if table else None
+        return snowflake_dict
+

 class NoSqlBaseTarget(BaseStoreTarget):
     is_table = True