mlrun 1.7.0rc7__py3-none-any.whl → 1.7.0rc9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlrun/__main__.py +2 -0
- mlrun/common/schemas/__init__.py +3 -0
- mlrun/common/schemas/api_gateway.py +8 -1
- mlrun/common/schemas/hub.py +7 -9
- mlrun/common/schemas/model_monitoring/constants.py +1 -1
- mlrun/common/schemas/pagination.py +26 -0
- mlrun/common/schemas/project.py +15 -10
- mlrun/config.py +28 -10
- mlrun/datastore/__init__.py +3 -7
- mlrun/datastore/datastore_profile.py +19 -1
- mlrun/datastore/snowflake_utils.py +43 -0
- mlrun/datastore/sources.py +9 -26
- mlrun/datastore/targets.py +131 -11
- mlrun/datastore/utils.py +10 -5
- mlrun/db/base.py +44 -0
- mlrun/db/httpdb.py +122 -21
- mlrun/db/nopdb.py +107 -0
- mlrun/feature_store/api.py +3 -2
- mlrun/feature_store/retrieval/spark_merger.py +27 -23
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +1 -1
- mlrun/frameworks/tf_keras/mlrun_interface.py +2 -2
- mlrun/kfpops.py +2 -5
- mlrun/launcher/base.py +1 -1
- mlrun/launcher/client.py +2 -2
- mlrun/model_monitoring/helpers.py +3 -1
- mlrun/projects/pipelines.py +1 -1
- mlrun/projects/project.py +32 -21
- mlrun/run.py +5 -1
- mlrun/runtimes/__init__.py +16 -0
- mlrun/runtimes/base.py +4 -1
- mlrun/runtimes/kubejob.py +26 -121
- mlrun/runtimes/nuclio/api_gateway.py +58 -8
- mlrun/runtimes/nuclio/application/application.py +79 -1
- mlrun/runtimes/nuclio/application/reverse_proxy.go +9 -1
- mlrun/runtimes/nuclio/function.py +11 -8
- mlrun/runtimes/nuclio/serving.py +2 -2
- mlrun/runtimes/pod.py +145 -0
- mlrun/runtimes/utils.py +0 -28
- mlrun/serving/remote.py +2 -3
- mlrun/serving/routers.py +4 -3
- mlrun/serving/server.py +1 -1
- mlrun/serving/states.py +6 -9
- mlrun/serving/v2_serving.py +4 -3
- mlrun/utils/http.py +1 -1
- mlrun/utils/retryer.py +1 -0
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/METADATA +15 -15
- {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/RECORD +52 -50
- {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/LICENSE +0 -0
- {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/WHEEL +0 -0
- {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/entry_points.txt +0 -0
- {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc9.dist-info}/top_level.txt +0 -0
mlrun/__main__.py
CHANGED
mlrun/common/schemas/__init__.py
CHANGED
@@ -21,6 +21,7 @@ from .api_gateway import (
     APIGatewayMetadata,
     APIGatewaysOutput,
     APIGatewaySpec,
+    APIGatewayState,
     APIGatewayStatus,
     APIGatewayUpstream,
 )
@@ -151,12 +152,14 @@ from .notification import (
     SetNotificationRequest,
 )
 from .object import ObjectKind, ObjectMetadata, ObjectSpec, ObjectStatus
+from .pagination import PaginationInfo
 from .pipeline import PipelinesFormat, PipelinesOutput, PipelinesPagination
 from .project import (
     IguazioProject,
     Project,
     ProjectDesiredState,
     ProjectMetadata,
+    ProjectOutput,
     ProjectOwner,
     ProjectsFormat,
     ProjectsOutput,
mlrun/common/schemas/api_gateway.py
CHANGED
@@ -36,6 +36,13 @@ class APIGatewayAuthenticationMode(mlrun.common.types.StrEnum):
     )


+class APIGatewayState(mlrun.common.types.StrEnum):
+    none = ""
+    ready = "ready"
+    error = "error"
+    waiting_for_provisioning = "waitingForProvisioning"
+
+
 class _APIGatewayBaseModel(pydantic.BaseModel):
     class Config:
         extra = pydantic.Extra.allow
@@ -72,7 +79,7 @@ class APIGatewaySpec(_APIGatewayBaseModel):

 class APIGatewayStatus(_APIGatewayBaseModel):
     name: Optional[str]
-    state: Optional[
+    state: Optional[APIGatewayState]


 class APIGateway(_APIGatewayBaseModel):
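A minimal sketch of how client code could branch on the new state enum, assuming both schemas are imported from mlrun.common.schemas as the updated __init__ above exports them (the status values are placeholders):

import mlrun.common.schemas as schemas

# state is parsed into the new enum instead of a bare string
status = schemas.APIGatewayStatus(name="my-gateway", state="waitingForProvisioning")
if status.state == schemas.APIGatewayState.ready:
    print("gateway is ready")
else:
    print(f"gateway not ready yet: {status.state}")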
mlrun/common/schemas/hub.py
CHANGED
@@ -59,28 +59,26 @@ class HubSource(BaseModel):
         return f"{self.spec.path}/{self.spec.object_type}/{self.spec.channel}/{relative_path}"

     def get_catalog_uri(self):
-        return self.get_full_uri(mlrun.
+        return self.get_full_uri(mlrun.mlconf.hub.catalog_filename)

     @classmethod
     def generate_default_source(cls):
-        if not mlrun.
+        if not mlrun.mlconf.hub.default_source.create:
             return None

         now = datetime.now(timezone.utc)
         hub_metadata = HubObjectMetadata(
-            name=mlrun.
-            description=mlrun.
+            name=mlrun.mlconf.hub.default_source.name,
+            description=mlrun.mlconf.hub.default_source.description,
             created=now,
             updated=now,
         )
         return cls(
             metadata=hub_metadata,
             spec=HubSourceSpec(
-                path=mlrun.
-                channel=mlrun.
-                object_type=HubSourceType(
-                    mlrun.config.config.hub.default_source.object_type
-                ),
+                path=mlrun.mlconf.hub.default_source.url,
+                channel=mlrun.mlconf.hub.default_source.channel,
+                object_type=HubSourceType(mlrun.mlconf.hub.default_source.object_type),
             ),
             status=ObjectStatus(state="created"),
         )
mlrun/common/schemas/model_monitoring/constants.py
CHANGED
@@ -151,7 +151,7 @@ class ProjectSecretKeys:
     ENDPOINT_STORE_CONNECTION = "MODEL_MONITORING_ENDPOINT_STORE_CONNECTION"
     ACCESS_KEY = "MODEL_MONITORING_ACCESS_KEY"
     PIPELINES_ACCESS_KEY = "MODEL_MONITORING_PIPELINES_ACCESS_KEY"
-
+    KAFKA_BROKERS = "KAFKA_BROKERS"
     STREAM_PATH = "STREAM_PATH"


mlrun/common/schemas/pagination.py
ADDED
@@ -0,0 +1,26 @@
+# Copyright 2023 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+
+import pydantic
+
+
+class PaginationInfo(pydantic.BaseModel):
+    class Config:
+        allow_population_by_field_name = True
+
+    page: typing.Optional[int]
+    page_size: typing.Optional[int] = pydantic.Field(alias="page-size")
+    page_token: typing.Optional[str] = pydantic.Field(alias="page-token")
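A short sketch of how the aliased fields behave; this is illustrative usage, not code from the diff:

from mlrun.common.schemas import PaginationInfo

# the aliases accept the HTTP-style parameter names...
info = PaginationInfo(**{"page": 2, "page-size": 50})
# ...while allow_population_by_field_name also permits the pythonic names
info = PaginationInfo(page=2, page_size=50, page_token=None)
print(info.dict(by_alias=True))  # {'page': 2, 'page-size': 50, 'page-token': None}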
mlrun/common/schemas/project.py
CHANGED
@@ -120,17 +120,22 @@ class IguazioProject(pydantic.BaseModel):
     data: dict


+# The format query param controls the project type used:
+# full - Project
+# name_only - str
+# summary - ProjectSummary
+# leader - currently only IguazioProject supported
+# The way pydantic handles typing.Union is that it takes the object and tries to coerce it to be the types of the
+# union by the definition order. Therefore we can't currently add generic dict for all leader formats, but we need
+# to add a specific classes for them. it's frustrating but couldn't find other workaround, see:
+# https://github.com/samuelcolvin/pydantic/issues/1423, https://github.com/samuelcolvin/pydantic/issues/619
+ProjectOutput = typing.TypeVar(
+    "ProjectOutput", Project, str, ProjectSummary, IguazioProject
+)
+
+
 class ProjectsOutput(pydantic.BaseModel):
-
-    # full - Project
-    # name_only - str
-    # summary - ProjectSummary
-    # leader - currently only IguazioProject supported
-    # The way pydantic handles typing.Union is that it takes the object and tries to coerce it to be the types of the
-    # union by the definition order. Therefore we can't currently add generic dict for all leader formats, but we need
-    # to add a specific classes for them. it's frustrating but couldn't find other workaround, see:
-    # https://github.com/samuelcolvin/pydantic/issues/1423, https://github.com/samuelcolvin/pydantic/issues/619
-    projects: list[typing.Union[Project, str, ProjectSummary, IguazioProject]]
+    projects: list[ProjectOutput]


 class ProjectSummariesOutput(pydantic.BaseModel):
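Since ProjectOutput is a TypeVar constrained to the same four types, the projects field validates exactly as before; a quick illustrative sketch of the name_only format:

from mlrun.common.schemas import ProjectsOutput

# with format=name_only the API returns plain project names
out = ProjectsOutput(projects=["fraud-demo", "churn-prediction"])
print(out.projects)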
mlrun/config.py
CHANGED
@@ -240,6 +240,7 @@ default_config = {
         "remote": "mlrun/mlrun",
         "dask": "mlrun/ml-base",
         "mpijob": "mlrun/mlrun",
+        "application": "python:3.9-slim",
     },
     # see enrich_function_preemption_spec for more info,
     # and mlrun.common.schemas.function.PreemptionModes for available options
@@ -481,10 +482,13 @@ default_config = {
         # if set to true, will log a warning for trying to use run db functionality while in nop db mode
         "verbose": True,
     },
-    "
-    "
-    "
-
+    "pagination": {
+        "default_page_size": 20,
+        "pagination_cache": {
+            "interval": 60,
+            "ttl": 3600,
+            "max_size": 10000,
+        },
     },
 },
 "model_endpoint_monitoring": {
@@ -548,6 +552,7 @@ default_config = {
         "nosql": "v3io:///projects/{project}/FeatureStore/{name}/{kind}",
         # "authority" is optional and generalizes [userinfo "@"] host [":" port]
         "redisnosql": "redis://{authority}/projects/{project}/FeatureStore/{name}/{kind}",
+        "dsnosql": "ds://{ds_profile_name}/projects/{project}/FeatureStore/{name}/{kind}",
     },
     "default_targets": "parquet,nosql",
     "default_job_image": "mlrun/mlrun",
@@ -1073,7 +1078,7 @@ class Config:
         target: str = "online",
         artifact_path: str = None,
         function_name: str = None,
-    ) -> str:
+    ) -> typing.Union[str, list[str]]:
         """Get the full path from the configuration based on the provided project and kind.

         :param project: Project name.
@@ -1089,7 +1094,8 @@ class Config:
                                relative artifact path will be taken from the global MLRun artifact path.
         :param function_name: Application name, None for model_monitoring_stream.

-        :return: Full configured path for the provided kind.
+        :return: Full configured path for the provided kind. Can be either a single path
+                 or a list of paths in the case of the online model monitoring stream path.
         """

         if target != "offline":
@@ -1111,10 +1117,22 @@ class Config:
                 if function_name is None
                 else f"{kind}-{function_name.lower()}",
             )
-            return mlrun.
-
-
-
+        elif kind == "stream":  # return list for mlrun<1.6.3 BC
+            return [
+                mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format(
+                    project=project,
+                    kind=kind,
+                ),  # old stream uri (pipelines) for BC ML-6043
+                mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format(
+                    project=project,
+                    kind=kind,
+                ),  # new stream uri (projects)
+            ]
+        else:
+            return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format(
+                project=project,
+                kind=kind,
+            )

         # Get the current offline path from the configuration
         file_path = mlrun.mlconf.model_endpoint_monitoring.offline_storage_path.format(
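Callers now have to handle both return shapes. A hedged sketch; the accessor name get_model_monitoring_file_target_path is not visible in this hunk and is an assumption:

import mlrun

# assumption: this is the Config method whose signature changed above
paths = mlrun.mlconf.get_model_monitoring_file_target_path(
    project="my-project", kind="stream"
)
# for kind="stream" the result is a list of two URIs (old pipelines path + new project path)
paths = paths if isinstance(paths, list) else [paths]
for path in paths:
    print(path)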
mlrun/datastore/__init__.py
CHANGED
@@ -107,13 +107,9 @@ def get_stream_pusher(stream_path: str, **kwargs):
     :param stream_path: path/url of stream
     """

-    if stream_path.startswith("kafka://") or "
-        topic,
-
-        )
-        return KafkaOutputStream(
-            topic, bootstrap_servers, kwargs.get("kafka_producer_options")
-        )
+    if stream_path.startswith("kafka://") or "kafka_brokers" in kwargs:
+        topic, brokers = parse_kafka_url(stream_path, kwargs.get("kafka_brokers"))
+        return KafkaOutputStream(topic, brokers, kwargs.get("kafka_producer_options"))
     elif stream_path.startswith("http://") or stream_path.startswith("https://"):
         return HTTPOutputStream(stream_path=stream_path)
     elif "://" not in stream_path:
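A usage sketch of the updated stream pusher entry point; the broker address and topic are placeholders, and a reachable Kafka cluster is assumed:

from mlrun.datastore import get_stream_pusher

# brokers can come from the kafka:// URL itself or from the kafka_brokers kwarg
pusher = get_stream_pusher("kafka://my-broker:9092/my-topic")
pusher.push({"id": 1, "value": 0.5})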
mlrun/datastore/datastore_profile.py
CHANGED
@@ -16,6 +16,7 @@ import ast
 import base64
 import json
 import typing
+import warnings
 from urllib.parse import ParseResult, urlparse, urlunparse

 import pydantic
@@ -68,6 +69,9 @@ class TemporaryClientDatastoreProfiles(metaclass=mlrun.utils.singleton.Singleton
     def get(self, key):
         return self._data.get(key, None)

+    def remove(self, key):
+        self._data.pop(key, None)
+

 class DatastoreProfileBasic(DatastoreProfile):
     type: str = pydantic.Field("basic")
@@ -80,12 +84,22 @@ class DatastoreProfileKafkaTarget(DatastoreProfile):
     type: str = pydantic.Field("kafka_target")
     _private_attributes = "kwargs_private"
     bootstrap_servers: str
+    brokers: str
     topic: str
     kwargs_public: typing.Optional[dict]
     kwargs_private: typing.Optional[dict]

+    def __pydantic_post_init__(self):
+        if self.bootstrap_servers:
+            warnings.warn(
+                "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
+                "use 'brokers' instead.",
+                # TODO: Remove this in 1.9.0
+                FutureWarning,
+            )
+
     def attributes(self):
-        attributes = {"
+        attributes = {"brokers": self.brokers or self.bootstrap_servers}
         if self.kwargs_public:
             attributes = merge(attributes, self.kwargs_public)
         if self.kwargs_private:
@@ -460,3 +474,7 @@ def register_temporary_client_datastore_profile(profile: DatastoreProfile):
     It's beneficial for testing purposes.
     """
     TemporaryClientDatastoreProfiles().add(profile)
+
+
+def remove_temporary_client_datastore_profile(profile_name: str):
+    TemporaryClientDatastoreProfiles().remove(profile_name)
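A sketch of registering and then removing a temporary client-side Kafka target profile with the new brokers field; the values are placeholders, and it assumes brokers alone satisfies the profile's validation in this release (passing bootstrap_servers instead now emits a FutureWarning):

from mlrun.datastore.datastore_profile import (
    DatastoreProfileKafkaTarget,
    register_temporary_client_datastore_profile,
    remove_temporary_client_datastore_profile,
)

profile = DatastoreProfileKafkaTarget(
    name="my-kafka-target",
    brokers="my-broker:9092",
    topic="my-topic",
)
register_temporary_client_datastore_profile(profile)
# ... use ds://my-kafka-target as a target path ...
remove_temporary_client_datastore_profile("my-kafka-target")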
mlrun/datastore/snowflake_utils.py
ADDED
@@ -0,0 +1,43 @@
+# Copyright 2024 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import mlrun
+
+
+def get_snowflake_password():
+    key = "SNOWFLAKE_PASSWORD"
+    snowflake_password = mlrun.get_secret_or_env(key)
+
+    if not snowflake_password:
+        raise mlrun.errors.MLRunInvalidArgumentError(
+            f"No password provided. Set password using the {key} "
+            "project secret or environment variable."
+        )
+
+    return snowflake_password
+
+
+def get_snowflake_spark_options(attributes):
+    return {
+        "format": "net.snowflake.spark.snowflake",
+        "sfURL": attributes.get("url"),
+        "sfUser": attributes.get("user"),
+        "sfPassword": get_snowflake_password(),
+        "sfDatabase": attributes.get("database"),
+        "sfSchema": attributes.get("schema"),
+        "sfWarehouse": attributes.get("warehouse"),
+        "application": "iguazio_platform",
+        "TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_LTZ",
+    }
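A sketch of calling the new helper directly; the connection values are placeholders and SNOWFLAKE_PASSWORD must already be set as an environment variable or project secret:

from mlrun.datastore.snowflake_utils import get_snowflake_spark_options

spark_options = get_snowflake_spark_options(
    {
        "url": "<account>.<region>.snowflakecomputing.com",
        "user": "analytics",
        "database": "ML",
        "schema": "PUBLIC",
        "warehouse": "COMPUTE_WH",
    }
)
# the resulting dict can be passed to spark.read.format(...).options(...)
print(spark_options["sfURL"])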
mlrun/datastore/sources.py
CHANGED
@@ -28,6 +28,7 @@ from nuclio.config import split_path

 import mlrun
 from mlrun.config import config
+from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.secrets import SecretsStore

 from ..model import DataSource
@@ -113,7 +114,11 @@ class BaseSourceDriver(DataSource):

     def to_spark_df(self, session, named_view=False, time_field=None, columns=None):
         if self.support_spark:
-
+            spark_options = self.get_spark_options()
+            spark_format = spark_options.pop("format", None)
+            df = load_spark_dataframe_with_options(
+                session, spark_options, format=spark_format
+            )
             if named_view:
                 df.createOrReplaceTempView(self.name)
             return self._filter_spark_df(df, time_field, columns)
@@ -673,32 +678,10 @@ class SnowflakeSource(BaseSourceDriver):
             **kwargs,
         )

-    def _get_password(self):
-        key = "SNOWFLAKE_PASSWORD"
-        snowflake_password = os.getenv(key) or os.getenv(
-            SecretsStore.k8s_env_variable_name_for_secret(key)
-        )
-
-        if not snowflake_password:
-            raise mlrun.errors.MLRunInvalidArgumentError(
-                "No password provided. Set password using the SNOWFLAKE_PASSWORD "
-                "project secret or environment variable."
-            )
-
-        return snowflake_password
-
     def get_spark_options(self):
-
-
-
-        "sfURL": self.attributes.get("url"),
-        "sfUser": self.attributes.get("user"),
-        "sfPassword": self._get_password(),
-        "sfDatabase": self.attributes.get("database"),
-        "sfSchema": self.attributes.get("schema"),
-        "sfWarehouse": self.attributes.get("warehouse"),
-        "application": "iguazio_platform",
-        }
+        spark_options = get_snowflake_spark_options(self.attributes)
+        spark_options["query"] = self.attributes.get("query")
+        return spark_options


 class CustomSource(BaseSourceDriver):
mlrun/datastore/targets.py
CHANGED
@@ -17,6 +17,7 @@ import os
 import random
 import sys
 import time
+import warnings
 from collections import Counter
 from copy import copy
 from typing import Any, Optional, Union
@@ -28,6 +29,7 @@ from mergedeep import merge

 import mlrun
 import mlrun.utils.helpers
 from mlrun.config import config
+from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.model import DataSource, DataTarget, DataTargetBase, TargetPathObject
 from mlrun.utils import logger, now_date
 from mlrun.utils.helpers import to_parquet
@@ -57,6 +59,7 @@ class TargetTypes:
     dataframe = "dataframe"
     custom = "custom"
     sql = "sql"
+    snowflake = "snowflake"

     @staticmethod
     def all():
@@ -71,6 +74,7 @@ class TargetTypes:
             TargetTypes.dataframe,
             TargetTypes.custom,
             TargetTypes.sql,
+            TargetTypes.snowflake,
         ]


@@ -78,11 +82,14 @@ def generate_target_run_id():
     return f"{round(time.time() * 1000)}_{random.randint(0, 999)}"


-def write_spark_dataframe_with_options(spark_options, df, mode):
+def write_spark_dataframe_with_options(spark_options, df, mode, write_format=None):
     non_hadoop_spark_options = spark_session_update_hadoop_options(
         df.sql_ctx.sparkSession, spark_options
     )
-
+    if write_format:
+        df.write.format(write_format).mode(mode).save(**non_hadoop_spark_options)
+    else:
+        df.write.mode(mode).save(**non_hadoop_spark_options)


 def default_target_names():
@@ -497,7 +504,10 @@ class BaseStoreTarget(DataTargetBase):
             options = self.get_spark_options(key_column, timestamp_key)
             options.update(kwargs)
             df = self.prepare_spark_df(df, key_column, timestamp_key, options)
-
+            write_format = options.pop("format", None)
+            write_spark_dataframe_with_options(
+                options, df, "overwrite", write_format=write_format
+            )
         elif hasattr(df, "dask"):
             dask_options = self.get_dask_options()
             store, path_in_store, target_path = self._get_store_and_path()
@@ -524,7 +534,12 @@ class BaseStoreTarget(DataTargetBase):
         store, path_in_store, target_path = self._get_store_and_path()
         target_path = generate_path_with_chunk(self, chunk_id, target_path)
         file_system = store.filesystem
-        if
+        if (
+            file_system.protocol == "file"
+            # fsspec 2023.10.0 changed protocol from "file" to ("file", "local")
+            or isinstance(file_system.protocol, (tuple, list))
+            and "file" in file_system.protocol
+        ):
             dir = os.path.dirname(target_path)
             if dir:
                 os.makedirs(dir, exist_ok=True)
@@ -1108,6 +1123,97 @@ class CSVTarget(BaseStoreTarget):
         return True


+class SnowflakeTarget(BaseStoreTarget):
+    """
+    :param attributes: A dictionary of attributes for Snowflake connection; will be overridden by database parameters
+     if they exist.
+    :param url: Snowflake hostname, in the format: <account_name>.<region>.snowflakecomputing.com
+    :param user: Snowflake user for login
+    :param db_schema: Database schema
+    :param database: Database name
+    :param warehouse: Snowflake warehouse name
+    :param table_name: Snowflake table name
+    """
+
+    support_spark = True
+    support_append = True
+    is_offline = True
+    kind = TargetTypes.snowflake
+
+    def __init__(
+        self,
+        name: str = "",
+        path=None,
+        attributes: dict[str, str] = None,
+        after_step=None,
+        columns=None,
+        partitioned: bool = False,
+        key_bucketing_number: Optional[int] = None,
+        partition_cols: Optional[list[str]] = None,
+        time_partitioning_granularity: Optional[str] = None,
+        max_events: Optional[int] = None,
+        flush_after_seconds: Optional[int] = None,
+        storage_options: dict[str, str] = None,
+        schema: dict[str, Any] = None,
+        credentials_prefix=None,
+        url: str = None,
+        user: str = None,
+        db_schema: str = None,
+        database: str = None,
+        warehouse: str = None,
+        table_name: str = None,
+    ):
+        attrs = {
+            "url": url,
+            "user": user,
+            "database": database,
+            "schema": db_schema,
+            "warehouse": warehouse,
+            "table": table_name,
+        }
+        extended_attrs = {
+            key: value for key, value in attrs.items() if value is not None
+        }
+        attributes = {} if not attributes else attributes
+        attributes.update(extended_attrs)
+        super().__init__(
+            name,
+            path,
+            attributes,
+            after_step,
+            list(schema.keys()) if schema else columns,
+            partitioned,
+            key_bucketing_number,
+            partition_cols,
+            time_partitioning_granularity,
+            max_events=max_events,
+            flush_after_seconds=flush_after_seconds,
+            storage_options=storage_options,
+            schema=schema,
+            credentials_prefix=credentials_prefix,
+        )
+
+    def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True):
+        spark_options = get_snowflake_spark_options(self.attributes)
+        spark_options["dbtable"] = self.attributes.get("table")
+        return spark_options
+
+    def purge(self):
+        pass
+
+    def as_df(
+        self,
+        columns=None,
+        df_module=None,
+        entities=None,
+        start_time=None,
+        end_time=None,
+        time_column=None,
+        **kwargs,
+    ):
+        raise NotImplementedError()
+
+
 class NoSqlBaseTarget(BaseStoreTarget):
     is_table = True
     is_online = True
@@ -1179,7 +1285,10 @@ class NoSqlBaseTarget(BaseStoreTarget):
             options = self.get_spark_options(key_column, timestamp_key)
             options.update(kwargs)
             df = self.prepare_spark_df(df)
-
+            write_format = options.pop("format", None)
+            write_spark_dataframe_with_options(
+                options, df, "overwrite", write_format=write_format
+            )
         else:
             # To prevent modification of the original dataframe and make sure
             # that the last event of a key is the one being persisted
@@ -1419,11 +1528,19 @@ class KafkaTarget(BaseStoreTarget):
         *args,
         bootstrap_servers=None,
         producer_options=None,
+        brokers=None,
         **kwargs,
     ):
         attrs = {}
+        if bootstrap_servers:
+            warnings.warn(
+                "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
+                "use 'brokers' instead.",
+                # TODO: Remove this in 1.9.0
+                FutureWarning,
+            )
         if bootstrap_servers is not None:
-            attrs["
+            attrs["brokers"] = brokers or bootstrap_servers
         if producer_options is not None:
             attrs["producer_options"] = producer_options

@@ -1445,14 +1562,16 @@ class KafkaTarget(BaseStoreTarget):
         if self.path and self.path.startswith("ds://"):
             datastore_profile = datastore_profile_read(self.path)
             attributes = datastore_profile.attributes()
-
+            brokers = attributes.pop(
+                "brokers", attributes.pop("bootstrap_servers", None)
+            )
             topic = datastore_profile.topic
         else:
             attributes = copy(self.attributes)
-
-
-                self.get_target_path(), bootstrap_servers
+            brokers = attributes.pop(
+                "brokers", attributes.pop("bootstrap_servers", None)
             )
+            topic, brokers = parse_kafka_url(self.get_target_path(), brokers)

         if not topic:
             raise mlrun.errors.MLRunInvalidArgumentError(
@@ -1466,7 +1585,7 @@ class KafkaTarget(BaseStoreTarget):
             class_name="storey.KafkaTarget",
             columns=column_list,
             topic=topic,
-
+            brokers=brokers,
             **attributes,
         )

@@ -1957,6 +2076,7 @@ kind_to_driver = {
     TargetTypes.tsdb: TSDBTarget,
     TargetTypes.custom: CustomTarget,
     TargetTypes.sql: SQLTarget,
+    TargetTypes.snowflake: SnowflakeTarget,
 }

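A sketch of declaring the new SnowflakeTarget (its constructor is shown in full above) next to a KafkaTarget whose brokers are taken from the target path; all connection values are placeholders:

from mlrun.datastore.targets import KafkaTarget, SnowflakeTarget

snowflake_target = SnowflakeTarget(
    name="snowflake",
    url="<account>.<region>.snowflakecomputing.com",
    user="analytics",
    database="ML",
    db_schema="PUBLIC",
    warehouse="COMPUTE_WH",
    table_name="FEATURES",
)

# brokers are parsed from the kafka:// path; passing bootstrap_servers still works but now warns
kafka_target = KafkaTarget(name="kafka", path="kafka://my-broker:9092/my-topic")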