mlrun 1.7.0rc6__py3-none-any.whl → 1.7.0rc9__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of mlrun may be problematic.
- mlrun/__main__.py +2 -0
- mlrun/common/constants.py +6 -0
- mlrun/common/schemas/__init__.py +5 -0
- mlrun/common/schemas/api_gateway.py +8 -1
- mlrun/common/schemas/hub.py +7 -9
- mlrun/common/schemas/model_monitoring/__init__.py +4 -0
- mlrun/common/schemas/model_monitoring/constants.py +36 -19
- mlrun/{model_monitoring/stores/models/__init__.py → common/schemas/pagination.py} +9 -10
- mlrun/common/schemas/project.py +16 -10
- mlrun/common/types.py +7 -1
- mlrun/config.py +35 -10
- mlrun/data_types/data_types.py +4 -0
- mlrun/datastore/__init__.py +3 -7
- mlrun/datastore/alibaba_oss.py +130 -0
- mlrun/datastore/azure_blob.py +4 -5
- mlrun/datastore/base.py +22 -16
- mlrun/datastore/datastore.py +4 -0
- mlrun/datastore/datastore_profile.py +19 -1
- mlrun/datastore/google_cloud_storage.py +1 -1
- mlrun/datastore/snowflake_utils.py +43 -0
- mlrun/datastore/sources.py +11 -29
- mlrun/datastore/targets.py +131 -11
- mlrun/datastore/utils.py +10 -5
- mlrun/db/base.py +58 -6
- mlrun/db/httpdb.py +183 -77
- mlrun/db/nopdb.py +110 -0
- mlrun/feature_store/api.py +3 -2
- mlrun/feature_store/retrieval/spark_merger.py +27 -23
- mlrun/frameworks/tf_keras/callbacks/logging_callback.py +1 -1
- mlrun/frameworks/tf_keras/mlrun_interface.py +2 -2
- mlrun/kfpops.py +2 -5
- mlrun/launcher/base.py +1 -1
- mlrun/launcher/client.py +2 -2
- mlrun/model.py +1 -0
- mlrun/model_monitoring/__init__.py +1 -1
- mlrun/model_monitoring/api.py +104 -295
- mlrun/model_monitoring/controller.py +25 -25
- mlrun/model_monitoring/db/__init__.py +16 -0
- mlrun/model_monitoring/{stores → db/stores}/__init__.py +43 -34
- mlrun/model_monitoring/db/stores/base/__init__.py +15 -0
- mlrun/model_monitoring/{stores/model_endpoint_store.py → db/stores/base/store.py} +47 -6
- mlrun/model_monitoring/db/stores/sqldb/__init__.py +13 -0
- mlrun/model_monitoring/db/stores/sqldb/models/__init__.py +49 -0
- mlrun/model_monitoring/{stores → db/stores/sqldb}/models/base.py +76 -3
- mlrun/model_monitoring/db/stores/sqldb/models/mysql.py +68 -0
- mlrun/model_monitoring/{stores → db/stores/sqldb}/models/sqlite.py +13 -1
- mlrun/model_monitoring/db/stores/sqldb/sql_store.py +662 -0
- mlrun/model_monitoring/db/stores/v3io_kv/__init__.py +13 -0
- mlrun/model_monitoring/{stores/kv_model_endpoint_store.py → db/stores/v3io_kv/kv_store.py} +134 -3
- mlrun/model_monitoring/helpers.py +3 -3
- mlrun/model_monitoring/stream_processing.py +41 -9
- mlrun/model_monitoring/tracking_policy.py +7 -1
- mlrun/model_monitoring/writer.py +4 -36
- mlrun/projects/pipelines.py +14 -2
- mlrun/projects/project.py +141 -122
- mlrun/run.py +8 -2
- mlrun/runtimes/__init__.py +16 -0
- mlrun/runtimes/base.py +10 -1
- mlrun/runtimes/kubejob.py +26 -121
- mlrun/runtimes/nuclio/api_gateway.py +243 -66
- mlrun/runtimes/nuclio/application/application.py +79 -1
- mlrun/runtimes/nuclio/application/reverse_proxy.go +9 -1
- mlrun/runtimes/nuclio/function.py +14 -8
- mlrun/runtimes/nuclio/serving.py +30 -34
- mlrun/runtimes/pod.py +171 -0
- mlrun/runtimes/utils.py +0 -28
- mlrun/serving/remote.py +2 -3
- mlrun/serving/routers.py +4 -3
- mlrun/serving/server.py +5 -7
- mlrun/serving/states.py +40 -23
- mlrun/serving/v2_serving.py +4 -3
- mlrun/utils/helpers.py +34 -0
- mlrun/utils/http.py +1 -1
- mlrun/utils/retryer.py +1 -0
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc9.dist-info}/METADATA +25 -16
- {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc9.dist-info}/RECORD +81 -75
- mlrun/model_monitoring/batch.py +0 -933
- mlrun/model_monitoring/stores/models/mysql.py +0 -34
- mlrun/model_monitoring/stores/sql_model_endpoint_store.py +0 -382
- {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc9.dist-info}/LICENSE +0 -0
- {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc9.dist-info}/WHEEL +0 -0
- {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc9.dist-info}/entry_points.txt +0 -0
- {mlrun-1.7.0rc6.dist-info → mlrun-1.7.0rc9.dist-info}/top_level.txt +0 -0
mlrun/datastore/base.py
CHANGED
@@ -27,6 +27,7 @@ import requests
 import urllib3
 from deprecated import deprecated
 
+import mlrun.config
 import mlrun.errors
 from mlrun.errors import err_to_str
 from mlrun.utils import StorePrefix, is_ipython, logger
@@ -34,10 +35,6 @@ from mlrun.utils import StorePrefix, is_ipython, logger
 from .store_resources import is_store_uri, parse_store_uri
 from .utils import filter_df_start_end_time, select_columns_from_df
 
-verify_ssl = False
-if not verify_ssl:
-    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
-
 
 class FileStats:
     def __init__(self, size, modified, content_type=None):
@@ -633,17 +630,6 @@ def basic_auth_header(user, password):
     return {"Authorization": authstr}
 
 
-def http_get(url, headers=None, auth=None):
-    try:
-        response = requests.get(url, headers=headers, auth=auth, verify=verify_ssl)
-    except OSError as exc:
-        raise OSError(f"error: cannot connect to {url}: {err_to_str(exc)}")
-
-    mlrun.errors.raise_for_status(response)
-
-    return response.content
-
-
 class HttpStore(DataStore):
     def __init__(self, parent, schema, name, endpoint="", secrets: dict = None):
         super().__init__(parent, name, schema, endpoint, secrets)
@@ -671,7 +657,7 @@ class HttpStore(DataStore):
         raise ValueError("unimplemented")
 
     def get(self, key, size=None, offset=0):
-        data = http_get(self.url + self._join(key), self._headers, self.auth)
+        data = self._http_get(self.url + self._join(key), self._headers, self.auth)
        if offset:
            data = data[offset:]
        if size:
@@ -691,6 +677,26 @@ class HttpStore(DataStore):
             f"schema as it is not secure and is not recommended."
         )
 
+    def _http_get(
+        self,
+        url,
+        headers=None,
+        auth=None,
+    ):
+        # import here to prevent import cycle
+        from mlrun.config import config as mlconf
+
+        verify_ssl = mlconf.httpdb.http.verify
+        try:
+            if not verify_ssl:
+                urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+            response = requests.get(url, headers=headers, auth=auth, verify=verify_ssl)
+        except OSError as exc:
+            raise OSError(f"error: cannot connect to {url}: {err_to_str(exc)}")
+
+        mlrun.errors.raise_for_status(response)
+        return response.content
+
 
 # This wrapper class is designed to extract the 'ds' schema and profile name from URL-formatted paths.
 # Within fsspec, the AbstractFileSystem::_strip_protocol() internal method is used to handle complete URL paths.
mlrun/datastore/datastore_profile.py
CHANGED
@@ -16,6 +16,7 @@ import ast
 import base64
 import json
 import typing
+import warnings
 from urllib.parse import ParseResult, urlparse, urlunparse
 
 import pydantic
@@ -68,6 +69,9 @@ class TemporaryClientDatastoreProfiles(metaclass=mlrun.utils.singleton.Singleton
     def get(self, key):
         return self._data.get(key, None)
 
+    def remove(self, key):
+        self._data.pop(key, None)
+
 
 class DatastoreProfileBasic(DatastoreProfile):
     type: str = pydantic.Field("basic")
@@ -80,12 +84,22 @@ class DatastoreProfileKafkaTarget(DatastoreProfile):
     type: str = pydantic.Field("kafka_target")
     _private_attributes = "kwargs_private"
     bootstrap_servers: str
+    brokers: str
     topic: str
     kwargs_public: typing.Optional[dict]
     kwargs_private: typing.Optional[dict]
 
+    def __pydantic_post_init__(self):
+        if self.bootstrap_servers:
+            warnings.warn(
+                "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
+                "use 'brokers' instead.",
+                # TODO: Remove this in 1.9.0
+                FutureWarning,
+            )
+
     def attributes(self):
-        attributes = {"bootstrap_servers": self.bootstrap_servers}
+        attributes = {"brokers": self.brokers or self.bootstrap_servers}
         if self.kwargs_public:
             attributes = merge(attributes, self.kwargs_public)
         if self.kwargs_private:
@@ -460,3 +474,7 @@ def register_temporary_client_datastore_profile(profile: DatastoreProfile):
     It's beneficial for testing purposes.
     """
     TemporaryClientDatastoreProfiles().add(profile)
+
+
+def remove_temporary_client_datastore_profile(profile_name: str):
+    TemporaryClientDatastoreProfiles().remove(profile_name)
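Taken together, the datastore_profile changes add a `remove` counterpart to the temporary client profile registry and start the `bootstrap_servers` to `brokers` rename on the Kafka profile. A rough usage sketch (the profile name and broker addresses are placeholders, and it assumes the released schema accepts `brokers` without `bootstrap_servers`):

```python
from mlrun.datastore.datastore_profile import (
    DatastoreProfileKafkaTarget,
    register_temporary_client_datastore_profile,
    remove_temporary_client_datastore_profile,
)

profile = DatastoreProfileKafkaTarget(
    name="my-kafka",                        # hypothetical profile name
    brokers="broker-1:9092,broker-2:9092",  # new field; bootstrap_servers now warns
    topic="model-events",
)
register_temporary_client_datastore_profile(profile)
# ... use the profile via a ds://my-kafka path in sources/targets ...
remove_temporary_client_datastore_profile("my-kafka")  # new in this release
```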
mlrun/datastore/google_cloud_storage.py
CHANGED
@@ -132,7 +132,7 @@ class GoogleCloudStorageStore(DataStore):
         self.filesystem.rm(path=path, recursive=recursive, maxdepth=maxdepth)
 
     def get_spark_options(self):
-        res =
+        res = {}
         st = self.get_storage_options()
         if "token" in st:
             res = {"spark.hadoop.google.cloud.auth.service.account.enable": "true"}
mlrun/datastore/snowflake_utils.py
ADDED
@@ -0,0 +1,43 @@
+# Copyright 2024 Iguazio
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import mlrun
+
+
+def get_snowflake_password():
+    key = "SNOWFLAKE_PASSWORD"
+    snowflake_password = mlrun.get_secret_or_env(key)
+
+    if not snowflake_password:
+        raise mlrun.errors.MLRunInvalidArgumentError(
+            f"No password provided. Set password using the {key} "
+            "project secret or environment variable."
+        )
+
+    return snowflake_password
+
+
+def get_snowflake_spark_options(attributes):
+    return {
+        "format": "net.snowflake.spark.snowflake",
+        "sfURL": attributes.get("url"),
+        "sfUser": attributes.get("user"),
+        "sfPassword": get_snowflake_password(),
+        "sfDatabase": attributes.get("database"),
+        "sfSchema": attributes.get("schema"),
+        "sfWarehouse": attributes.get("warehouse"),
+        "application": "iguazio_platform",
+        "TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_LTZ",
+    }
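The new helper centralizes the Snowflake Spark connector options that were previously built inside `SnowflakeSource`. A rough sketch of how it is consumed (connection values are placeholders):

```python
from mlrun.datastore.snowflake_utils import get_snowflake_spark_options

attributes = {
    "url": "myaccount.eu-west-1.snowflakecomputing.com",  # placeholder account
    "user": "MLRUN_USER",
    "database": "ANALYTICS",
    "schema": "PUBLIC",
    "warehouse": "COMPUTE_WH",
}
# Raises MLRunInvalidArgumentError unless SNOWFLAKE_PASSWORD is available as a
# project secret or environment variable (resolved via mlrun.get_secret_or_env).
spark_options = get_snowflake_spark_options(attributes)
# Sources then add spark_options["query"]; the new SnowflakeTarget adds spark_options["dbtable"].
```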
mlrun/datastore/sources.py
CHANGED
@@ -28,6 +28,7 @@ from nuclio.config import split_path
 
 import mlrun
 from mlrun.config import config
+from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.secrets import SecretsStore
 
 from ..model import DataSource
@@ -113,7 +114,11 @@ class BaseSourceDriver(DataSource):
 
     def to_spark_df(self, session, named_view=False, time_field=None, columns=None):
         if self.support_spark:
-
+            spark_options = self.get_spark_options()
+            spark_format = spark_options.pop("format", None)
+            df = load_spark_dataframe_with_options(
+                session, spark_options, format=spark_format
+            )
             if named_view:
                 df.createOrReplaceTempView(self.name)
             return self._filter_spark_df(df, time_field, columns)
@@ -673,32 +678,10 @@ class SnowflakeSource(BaseSourceDriver):
             **kwargs,
         )
 
-    def _get_password(self):
-        key = "SNOWFLAKE_PASSWORD"
-        snowflake_password = os.getenv(key) or os.getenv(
-            SecretsStore.k8s_env_variable_name_for_secret(key)
-        )
-
-        if not snowflake_password:
-            raise mlrun.errors.MLRunInvalidArgumentError(
-                "No password provided. Set password using the SNOWFLAKE_PASSWORD "
-                "project secret or environment variable."
-            )
-
-        return snowflake_password
-
     def get_spark_options(self):
-
-
-
-            "sfURL": self.attributes.get("url"),
-            "sfUser": self.attributes.get("user"),
-            "sfPassword": self._get_password(),
-            "sfDatabase": self.attributes.get("database"),
-            "sfSchema": self.attributes.get("schema"),
-            "sfWarehouse": self.attributes.get("warehouse"),
-            "application": "iguazio_platform",
-        }
+        spark_options = get_snowflake_spark_options(self.attributes)
+        spark_options["query"] = self.attributes.get("query")
+        return spark_options
 
 
 class CustomSource(BaseSourceDriver):
@@ -854,12 +837,11 @@ class StreamSource(OnlineSource):
         super().__init__(name, attributes=attrs, **kwargs)
 
     def add_nuclio_trigger(self, function):
-        store,
+        store, _, url = mlrun.store_manager.get_or_create_store(self.path)
         if store.kind != "v3io":
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Only profiles that reference the v3io datastore can be used with StreamSource"
             )
-        path = "v3io:/" + path
         storage_options = store.get_storage_options()
         access_key = storage_options.get("v3io_access_key")
         endpoint, stream_path = parse_path(url)
@@ -883,7 +865,7 @@ class StreamSource(OnlineSource):
             kwargs["worker_allocation_mode"] = "static"
 
         function.add_v3io_stream_trigger(
-
+            url,
             self.name,
             self.attributes["group"],
             self.attributes["seek_to"],
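In sources.py, `SnowflakeSource.get_spark_options` now delegates to the shared helper and only adds the query, while `BaseSourceDriver.to_spark_df` pops the `format` key before loading. A hedged sketch (keyword names follow the documented SnowflakeSource constructor and may differ slightly between versions; connection values are placeholders):

```python
from mlrun.datastore.sources import SnowflakeSource

source = SnowflakeSource(
    "snowflake_source",
    query="SELECT * FROM MY_TABLE LIMIT 1000",  # placeholder query
    url="myaccount.eu-west-1.snowflakecomputing.com",
    user="MLRUN_USER",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
)
opts = source.get_spark_options()
# opts carries "format": "net.snowflake.spark.snowflake" plus the sf* options and
# the query; to_spark_df pops "format" and passes the rest to the Spark reader.
```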
mlrun/datastore/targets.py
CHANGED
@@ -17,6 +17,7 @@ import os
 import random
 import sys
 import time
+import warnings
 from collections import Counter
 from copy import copy
 from typing import Any, Optional, Union
@@ -28,6 +29,7 @@ from mergedeep import merge
 import mlrun
 import mlrun.utils.helpers
 from mlrun.config import config
+from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.model import DataSource, DataTarget, DataTargetBase, TargetPathObject
 from mlrun.utils import logger, now_date
 from mlrun.utils.helpers import to_parquet
@@ -57,6 +59,7 @@ class TargetTypes:
     dataframe = "dataframe"
     custom = "custom"
     sql = "sql"
+    snowflake = "snowflake"
 
     @staticmethod
     def all():
@@ -71,6 +74,7 @@ class TargetTypes:
             TargetTypes.dataframe,
             TargetTypes.custom,
             TargetTypes.sql,
+            TargetTypes.snowflake,
         ]
 
 
@@ -78,11 +82,14 @@ def generate_target_run_id():
     return f"{round(time.time() * 1000)}_{random.randint(0, 999)}"
 
 
-def write_spark_dataframe_with_options(spark_options, df, mode):
+def write_spark_dataframe_with_options(spark_options, df, mode, write_format=None):
     non_hadoop_spark_options = spark_session_update_hadoop_options(
         df.sql_ctx.sparkSession, spark_options
     )
-
+    if write_format:
+        df.write.format(write_format).mode(mode).save(**non_hadoop_spark_options)
+    else:
+        df.write.mode(mode).save(**non_hadoop_spark_options)
 
 
 def default_target_names():
@@ -497,7 +504,10 @@ class BaseStoreTarget(DataTargetBase):
             options = self.get_spark_options(key_column, timestamp_key)
             options.update(kwargs)
             df = self.prepare_spark_df(df, key_column, timestamp_key, options)
-
+            write_format = options.pop("format", None)
+            write_spark_dataframe_with_options(
+                options, df, "overwrite", write_format=write_format
+            )
         elif hasattr(df, "dask"):
             dask_options = self.get_dask_options()
             store, path_in_store, target_path = self._get_store_and_path()
@@ -524,7 +534,12 @@ class BaseStoreTarget(DataTargetBase):
         store, path_in_store, target_path = self._get_store_and_path()
         target_path = generate_path_with_chunk(self, chunk_id, target_path)
         file_system = store.filesystem
-        if
+        if (
+            file_system.protocol == "file"
+            # fsspec 2023.10.0 changed protocol from "file" to ("file", "local")
+            or isinstance(file_system.protocol, (tuple, list))
+            and "file" in file_system.protocol
+        ):
             dir = os.path.dirname(target_path)
             if dir:
                 os.makedirs(dir, exist_ok=True)
@@ -1108,6 +1123,97 @@ class CSVTarget(BaseStoreTarget):
         return True
 
 
+class SnowflakeTarget(BaseStoreTarget):
+    """
+    :param attributes: A dictionary of attributes for Snowflake connection; will be overridden by database parameters
+        if they exist.
+    :param url: Snowflake hostname, in the format: <account_name>.<region>.snowflakecomputing.com
+    :param user: Snowflake user for login
+    :param db_schema: Database schema
+    :param database: Database name
+    :param warehouse: Snowflake warehouse name
+    :param table_name: Snowflake table name
+    """
+
+    support_spark = True
+    support_append = True
+    is_offline = True
+    kind = TargetTypes.snowflake
+
+    def __init__(
+        self,
+        name: str = "",
+        path=None,
+        attributes: dict[str, str] = None,
+        after_step=None,
+        columns=None,
+        partitioned: bool = False,
+        key_bucketing_number: Optional[int] = None,
+        partition_cols: Optional[list[str]] = None,
+        time_partitioning_granularity: Optional[str] = None,
+        max_events: Optional[int] = None,
+        flush_after_seconds: Optional[int] = None,
+        storage_options: dict[str, str] = None,
+        schema: dict[str, Any] = None,
+        credentials_prefix=None,
+        url: str = None,
+        user: str = None,
+        db_schema: str = None,
+        database: str = None,
+        warehouse: str = None,
+        table_name: str = None,
+    ):
+        attrs = {
+            "url": url,
+            "user": user,
+            "database": database,
+            "schema": db_schema,
+            "warehouse": warehouse,
+            "table": table_name,
+        }
+        extended_attrs = {
+            key: value for key, value in attrs.items() if value is not None
+        }
+        attributes = {} if not attributes else attributes
+        attributes.update(extended_attrs)
+        super().__init__(
+            name,
+            path,
+            attributes,
+            after_step,
+            list(schema.keys()) if schema else columns,
+            partitioned,
+            key_bucketing_number,
+            partition_cols,
+            time_partitioning_granularity,
+            max_events=max_events,
+            flush_after_seconds=flush_after_seconds,
+            storage_options=storage_options,
+            schema=schema,
+            credentials_prefix=credentials_prefix,
+        )
+
+    def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True):
+        spark_options = get_snowflake_spark_options(self.attributes)
+        spark_options["dbtable"] = self.attributes.get("table")
+        return spark_options
+
+    def purge(self):
+        pass
+
+    def as_df(
+        self,
+        columns=None,
+        df_module=None,
+        entities=None,
+        start_time=None,
+        end_time=None,
+        time_column=None,
+        **kwargs,
+    ):
+        raise NotImplementedError()
+
+
 class NoSqlBaseTarget(BaseStoreTarget):
     is_table = True
     is_online = True
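The new `SnowflakeTarget` only accepts connection details; the credential still comes from the `SNOWFLAKE_PASSWORD` project secret or environment variable. A construction sketch with placeholder values:

```python
from mlrun.datastore.targets import SnowflakeTarget

target = SnowflakeTarget(
    "snowflake_target",
    url="myaccount.eu-west-1.snowflakecomputing.com",  # placeholder account
    user="MLRUN_USER",
    database="ANALYTICS",
    db_schema="PUBLIC",
    warehouse="COMPUTE_WH",
    table_name="MY_FEATURE_TABLE",
)
# Spark writes go through write_spark_dataframe_with_options with
# format "net.snowflake.spark.snowflake" and dbtable taken from table_name;
# reading the target back with as_df() raises NotImplementedError.
```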
@@ -1179,7 +1285,10 @@ class NoSqlBaseTarget(BaseStoreTarget):
             options = self.get_spark_options(key_column, timestamp_key)
             options.update(kwargs)
             df = self.prepare_spark_df(df)
-
+            write_format = options.pop("format", None)
+            write_spark_dataframe_with_options(
+                options, df, "overwrite", write_format=write_format
+            )
         else:
             # To prevent modification of the original dataframe and make sure
             # that the last event of a key is the one being persisted
@@ -1419,11 +1528,19 @@ class KafkaTarget(BaseStoreTarget):
         *args,
         bootstrap_servers=None,
         producer_options=None,
+        brokers=None,
         **kwargs,
     ):
         attrs = {}
+        if bootstrap_servers:
+            warnings.warn(
+                "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
+                "use 'brokers' instead.",
+                # TODO: Remove this in 1.9.0
+                FutureWarning,
+            )
         if bootstrap_servers is not None:
-            attrs["bootstrap_servers"] = bootstrap_servers
+            attrs["brokers"] = brokers or bootstrap_servers
         if producer_options is not None:
             attrs["producer_options"] = producer_options
 
@@ -1445,14 +1562,16 @@ class KafkaTarget(BaseStoreTarget):
         if self.path and self.path.startswith("ds://"):
             datastore_profile = datastore_profile_read(self.path)
             attributes = datastore_profile.attributes()
-
+            brokers = attributes.pop(
+                "brokers", attributes.pop("bootstrap_servers", None)
+            )
             topic = datastore_profile.topic
         else:
             attributes = copy(self.attributes)
-
-
-                self.get_target_path(), bootstrap_servers
+            brokers = attributes.pop(
+                "brokers", attributes.pop("bootstrap_servers", None)
             )
+            topic, brokers = parse_kafka_url(self.get_target_path(), brokers)
 
         if not topic:
             raise mlrun.errors.MLRunInvalidArgumentError(
@@ -1466,7 +1585,7 @@ class KafkaTarget(BaseStoreTarget):
             class_name="storey.KafkaTarget",
             columns=column_list,
             topic=topic,
-
+            brokers=brokers,
             **attributes,
         )
@@ -1957,6 +2076,7 @@ kind_to_driver = {
     TargetTypes.tsdb: TSDBTarget,
     TargetTypes.custom: CustomTarget,
     TargetTypes.sql: SQLTarget,
+    TargetTypes.snowflake: SnowflakeTarget,
 }
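On the Kafka side, `KafkaTarget` gains a `brokers` keyword and the writer step resolves brokers from either the profile/attributes or the target path. A rough sketch (broker addresses and topic are placeholders; `bootstrap_servers` still works but now emits a FutureWarning):

```python
from mlrun.datastore.targets import KafkaTarget

# Preferred: topic and broker encoded in the path, extra brokers via the new keyword.
target = KafkaTarget(
    "kafka_target",
    path="kafka://broker-1:9092/model-events",
    brokers="broker-2:9092",
)

# Deprecated spelling, kept until 1.9.0:
legacy = KafkaTarget(
    "kafka_target",
    path="kafka://model-events",
    bootstrap_servers="broker-1:9092",  # emits FutureWarning, mapped to attrs["brokers"]
)
```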
mlrun/datastore/utils.py
CHANGED
@@ -23,24 +23,29 @@ import semver
 import mlrun.datastore
 
 
-def parse_kafka_url(url: str, bootstrap_servers: list = None) -> tuple[str, list]:
+def parse_kafka_url(
+    url: str, brokers: typing.Union[list, str] = None
+) -> tuple[str, list]:
     """Generating Kafka topic and adjusting a list of bootstrap servers.
 
     :param url: URL path to parse using urllib.parse.urlparse.
-    :param
+    :param brokers: List of kafka brokers.
 
     :return: A tuple of:
         [0] = Kafka topic value
        [1] = List of bootstrap servers
     """
-
+    brokers = brokers or []
+
+    if isinstance(brokers, str):
+        brokers = brokers.split(",")
 
     # Parse the provided URL into six components according to the general structure of a URL
     url = urlparse(url)
 
     # Add the network location to the bootstrap servers list
     if url.netloc:
-
+        brokers = [url.netloc] + brokers
 
     # Get the topic value from the parsed url
     query_dict = parse_qs(url.query)
@@ -49,7 +54,7 @@ def parse_kafka_url(url: str, bootstrap_servers: list = None) -> tuple[str, list
     else:
         topic = url.path
         topic = topic.lstrip("/")
-    return topic, bootstrap_servers
+    return topic, brokers
 
 
 def upload_tarball(source_dir, target, secrets=None):
|