mlrun 1.7.0rc37__py3-none-any.whl → 1.7.0rc39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mlrun/alerts/alert.py +34 -30
- mlrun/common/schemas/alert.py +3 -0
- mlrun/common/schemas/model_monitoring/constants.py +4 -0
- mlrun/common/schemas/notification.py +4 -3
- mlrun/datastore/alibaba_oss.py +2 -2
- mlrun/datastore/azure_blob.py +124 -31
- mlrun/datastore/base.py +1 -1
- mlrun/datastore/dbfs_store.py +2 -2
- mlrun/datastore/google_cloud_storage.py +83 -20
- mlrun/datastore/s3.py +2 -2
- mlrun/datastore/sources.py +54 -0
- mlrun/datastore/targets.py +9 -53
- mlrun/db/httpdb.py +6 -1
- mlrun/errors.py +8 -0
- mlrun/execution.py +7 -0
- mlrun/feature_store/api.py +5 -0
- mlrun/feature_store/common.py +6 -11
- mlrun/feature_store/retrieval/job.py +1 -0
- mlrun/model.py +29 -3
- mlrun/model_monitoring/api.py +9 -0
- mlrun/model_monitoring/applications/_application_steps.py +36 -0
- mlrun/model_monitoring/applications/histogram_data_drift.py +15 -13
- mlrun/model_monitoring/controller.py +15 -11
- mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +14 -11
- mlrun/model_monitoring/db/tsdb/base.py +121 -1
- mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +85 -47
- mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +100 -12
- mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +23 -1
- mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +214 -36
- mlrun/model_monitoring/helpers.py +16 -17
- mlrun/model_monitoring/stream_processing.py +68 -27
- mlrun/projects/operations.py +1 -1
- mlrun/projects/pipelines.py +19 -30
- mlrun/projects/project.py +76 -52
- mlrun/run.py +8 -6
- mlrun/runtimes/__init__.py +19 -8
- mlrun/runtimes/nuclio/api_gateway.py +9 -0
- mlrun/runtimes/nuclio/application/application.py +64 -9
- mlrun/runtimes/nuclio/function.py +1 -1
- mlrun/runtimes/pod.py +2 -2
- mlrun/runtimes/remotesparkjob.py +2 -5
- mlrun/runtimes/sparkjob/spark3job.py +7 -9
- mlrun/serving/v2_serving.py +1 -0
- mlrun/track/trackers/mlflow_tracker.py +5 -0
- mlrun/utils/helpers.py +21 -0
- mlrun/utils/version/version.json +2 -2
- {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/METADATA +14 -11
- {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/RECORD +52 -52
- {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/WHEEL +1 -1
- {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/LICENSE +0 -0
- {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/entry_points.txt +0 -0
- {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/top_level.txt +0 -0
mlrun/alerts/alert.py
CHANGED

@@ -29,6 +29,7 @@ class AlertConfig(ModelObj):
         "reset_policy",
         "state",
         "count",
+        "created",
     ]
     _fields_to_serialize = ModelObj._fields_to_serialize + [
         "entities",
@@ -55,12 +56,13 @@ class AlertConfig(ModelObj):
         created: str = None,
         count: int = None,
     ):
-        """
-        Alert config object
+        """Alert config object

        Example::
+
            # create an alert on endpoint_id, which will be triggered to slack if there is a "data_drift_detected" event
-            3 times in the next hour.
+            # 3 times in the next hour.
+
            from mlrun.alerts import AlertConfig
            import mlrun.common.schemas.alert as alert_objects

@@ -93,29 +95,29 @@ class AlertConfig(ModelObj):
            )
            project.store_alert_config(alert_data)

-        :param project:
-        :param name:
-        :param template:
-
-
-                              template.
+        :param project:       Name of the project to associate the alert with
+        :param name:          Name of the alert
+        :param template:      Optional parameter that allows creating an alert based on a predefined template.
+                              You can pass either an AlertTemplate object or a string (the template name).
+                              If a template is used, many fields of the alert will be auto-generated based on the
+                              template.However, you still need to provide the following fields:
                              `name`, `project`, `entity`, `notifications`
-        :param description:
-        :param summary:
-        :param severity:
-        :param trigger:
+        :param description:   Description of the alert
+        :param summary:       Summary of the alert, will be sent in the generated notifications
+        :param severity:      Severity of the alert
+        :param trigger:       The events that will trigger this alert, may be a simple trigger based on events or
                              complex trigger which is based on a prometheus alert
-        :param criteria:
+        :param criteria:      When the alert will be triggered based on the specified number of events within the
                              defined time period.
-        :param reset_policy:
+        :param reset_policy:  When to clear the alert. May be "manual" for manual reset of the alert, or
                              "auto" if the criteria contains a time period
-        :param notifications:
-        :param entities:
-                              identify a given entity in the system
-        :param id:
-        :param state:
-        :param created:
-        :param count:
+        :param notifications: List of notifications to invoke once the alert is triggered
+        :param entities:      Entities that the event relates to. The entity object will contain fields that
+                              uniquely identify a given entity in the system
+        :param id:            Internal id of the alert (user should not supply it)
+        :param state:         State of the alert, may be active/inactive (user should not supply it)
+        :param created:       When the alert is created (user should not supply it)
+        :param count:         Internal counter of the alert (user should not supply it)
        """
        self.project = project
        self.name = name
@@ -136,8 +138,8 @@ class AlertConfig(ModelObj):
            self._apply_template(template)

    def validate_required_fields(self):
-        if not self.
-            raise mlrun.errors.
+        if not self.name:
+            raise mlrun.errors.MLRunInvalidArgumentError("Alert name must be provided")

    def _serialize_field(
        self, struct: dict, field_name: str = None, strip: bool = False
@@ -236,9 +238,11 @@ class AlertConfig(ModelObj):
            db = mlrun.get_run_db()
            template = db.get_alert_template(template)

-        #
-
-
-        self.
-        self.
-        self.
+        # Apply parameters from the template to the AlertConfig object only if they are not already specified by the
+        # user in the current configuration.
+        # User-provided parameters will take precedence over corresponding template values
+        self.summary = self.summary or template.summary
+        self.severity = self.severity or template.severity
+        self.criteria = self.criteria or template.criteria
+        self.trigger = self.trigger or template.trigger
+        self.reset_policy = self.reset_policy or template.reset_policy
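The last hunk changes how `_apply_template` merges template values into the config: fields supplied by the user now take precedence over the template's. A minimal usage sketch (not part of the diff; the project name, endpoint id, webhook URL, and template name are assumed placeholders):

```python
import mlrun
import mlrun.common.schemas.alert as alert_objects
from mlrun.alerts import AlertConfig

project = mlrun.get_or_create_project("my-project")  # assumed project name

notification = mlrun.model.Notification(
    kind="slack",
    name="drift-to-slack",
    secret_params={"webhook": "https://hooks.slack.com/services/..."},  # placeholder
)

alert_config = AlertConfig(
    project="my-project",
    name="drift-alert",
    template="DataDriftDetected",  # assumed template name; fills summary/trigger/criteria/reset_policy
    severity=alert_objects.AlertSeverity.HIGH,  # user value, kept instead of the template's severity
    entities=alert_objects.EventEntities(
        kind=alert_objects.EventEntityKind.MODEL_ENDPOINT_RESULT,
        project="my-project",
        ids=["endpoint-id"],  # placeholder endpoint id
    ),
    notifications=[alert_objects.AlertNotification(notification=notification)],
)
project.store_alert_config(alert_config)
```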
mlrun/common/schemas/alert.py
CHANGED

@@ -23,6 +23,7 @@ from mlrun.common.types import StrEnum

 class EventEntityKind(StrEnum):
     MODEL_ENDPOINT_RESULT = "model-endpoint-result"
+    MODEL_MONITORING_APPLICATION = "model-monitoring-application"
     JOB = "job"


@@ -43,6 +44,7 @@ class EventKind(StrEnum):
     SYSTEM_PERFORMANCE_SUSPECTED = "system_performance_suspected"
     MM_APP_ANOMALY_DETECTED = "mm_app_anomaly_detected"
     MM_APP_ANOMALY_SUSPECTED = "mm_app_anomaly_suspected"
+    MM_APP_FAILED = "mm_app_failed"
     FAILED = "failed"


@@ -57,6 +59,7 @@ _event_kind_entity_map = {
     EventKind.SYSTEM_PERFORMANCE_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
     EventKind.MM_APP_ANOMALY_DETECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
     EventKind.MM_APP_ANOMALY_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.MM_APP_FAILED: [EventEntityKind.MODEL_MONITORING_APPLICATION],
     EventKind.FAILED: [EventEntityKind.JOB],
 }

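A small illustrative sketch of how the two new enum members pair up according to `_event_kind_entity_map` (values are taken directly from the hunks above):

```python
import mlrun.common.schemas.alert as alert_objects

# Per _event_kind_entity_map, the new event kind is only valid for the new entity kind
event_kind = alert_objects.EventKind.MM_APP_FAILED
entity_kind = alert_objects.EventEntityKind.MODEL_MONITORING_APPLICATION
print(event_kind.value, "->", entity_kind.value)  # mm_app_failed -> model-monitoring-application
```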
mlrun/common/schemas/model_monitoring/constants.py
CHANGED

@@ -53,9 +53,11 @@ class EventFieldType:
     PREDICTIONS = "predictions"
     NAMED_PREDICTIONS = "named_predictions"
     ERROR_COUNT = "error_count"
+    MODEL_ERROR = "model_error"
     ENTITIES = "entities"
     FIRST_REQUEST = "first_request"
     LAST_REQUEST = "last_request"
+    LAST_REQUEST_TIMESTAMP = "last_request_timestamp"
     METRIC = "metric"
     METRICS = "metrics"
     BATCH_INTERVALS_DICT = "batch_intervals_dict"
@@ -217,6 +219,7 @@ class FileTargetKind:
     APP_METRICS = "app_metrics"
     MONITORING_SCHEDULES = "monitoring_schedules"
     MONITORING_APPLICATION = "monitoring_application"
+    ERRORS = "errors"


 class ModelMonitoringMode(str, Enum):
@@ -240,6 +243,7 @@ class V3IOTSDBTables(MonitoringStrEnum):
     APP_RESULTS = "app-results"
     METRICS = "metrics"
     EVENTS = "events"
+    ERRORS = "errors"


 class TDEngineSuperTables(MonitoringStrEnum):
mlrun/common/schemas/notification.py
CHANGED

@@ -52,6 +52,7 @@ class NotificationLimits(enum.Enum):
 class Notification(pydantic.BaseModel):
     """
     Notification object schema
+
     :param kind: notification implementation kind - slack, webhook, etc.
     :param name: for logging and identification
     :param message: message content in the notification
@@ -71,9 +72,9 @@ class Notification(pydantic.BaseModel):

     kind: NotificationKind
     name: str
-    message: str
-    severity: NotificationSeverity
-    when: list[str]
+    message: typing.Optional[str] = None
+    severity: typing.Optional[NotificationSeverity] = None
+    when: typing.Optional[list[str]] = None
     condition: typing.Optional[str] = None
     params: typing.Optional[dict[str, typing.Any]] = None
     status: typing.Optional[NotificationStatus] = None
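With `message`, `severity`, and `when` now optional, a schema-level notification can be declared with just the required fields; a short sketch (field values are illustrative):

```python
from mlrun.common.schemas.notification import Notification, NotificationKind

# kind and name remain required; the relaxed fields default to None
minimal = Notification(kind=NotificationKind.slack, name="alerts-channel")
assert minimal.message is None
assert minimal.severity is None
assert minimal.when is None
```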
mlrun/datastore/alibaba_oss.py
CHANGED

@@ -22,7 +22,7 @@ from fsspec.registry import get_filesystem_class

 import mlrun.errors

-from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, make_datastore_schema_sanitizer


 class OSSStore(DataStore):
@@ -53,7 +53,7 @@ class OSSStore(DataStore):
         except ImportError as exc:
             raise ImportError("ALIBABA ossfs not installed") from exc
         filesystem_class = get_filesystem_class(protocol=self.kind)
-        self._filesystem = makeDatastoreSchemaSanitizer(
+        self._filesystem = make_datastore_schema_sanitizer(
             filesystem_class,
             using_bucket=self.using_bucket,
             **self.get_storage_options(),
mlrun/datastore/azure_blob.py
CHANGED

@@ -16,12 +16,13 @@ import time
 from pathlib import Path
 from urllib.parse import urlparse

+from azure.storage.blob import BlobServiceClient
 from azure.storage.blob._shared.base_client import parse_connection_str
 from fsspec.registry import get_filesystem_class

 import mlrun.errors

-from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, make_datastore_schema_sanitizer

 # Azure blobs will be represented with the following URL: az://<container name>. The storage account is already
 # pointed to by the connection string, so the user is not expected to specify it in any way.
@@ -29,47 +30,131 @@ from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer

 class AzureBlobStore(DataStore):
     using_bucket = True
+    max_concurrency = 100
+    max_blocksize = 1024 * 1024 * 4
+    max_single_put_size = (
+        1024 * 1024 * 8
+    )  # for service_client property only, does not affect filesystem

     def __init__(self, parent, schema, name, endpoint="", secrets: dict = None):
         super().__init__(parent, name, schema, endpoint, secrets=secrets)
+        self._service_client = None
+        self._storage_options = None
+
+    def get_storage_options(self):
+        return self.storage_options
+
+    @property
+    def storage_options(self):
+        if not self._storage_options:
+            res = dict(
+                account_name=self._get_secret_or_env("account_name")
+                or self._get_secret_or_env("AZURE_STORAGE_ACCOUNT_NAME"),
+                account_key=self._get_secret_or_env("account_key")
+                or self._get_secret_or_env("AZURE_STORAGE_ACCOUNT_KEY"),
+                connection_string=self._get_secret_or_env("connection_string")
+                or self._get_secret_or_env("AZURE_STORAGE_CONNECTION_STRING"),
+                tenant_id=self._get_secret_or_env("tenant_id")
+                or self._get_secret_or_env("AZURE_STORAGE_TENANT_ID"),
+                client_id=self._get_secret_or_env("client_id")
+                or self._get_secret_or_env("AZURE_STORAGE_CLIENT_ID"),
+                client_secret=self._get_secret_or_env("client_secret")
+                or self._get_secret_or_env("AZURE_STORAGE_CLIENT_SECRET"),
+                sas_token=self._get_secret_or_env("sas_token")
+                or self._get_secret_or_env("AZURE_STORAGE_SAS_TOKEN"),
+                credential=self._get_secret_or_env("credential"),
+            )
+            self._storage_options = self._sanitize_storage_options(res)
+        return self._storage_options

     @property
     def filesystem(self):
         """return fsspec file system object, if supported"""
-        if self._filesystem:
-            return self._filesystem
         try:
             import adlfs  # noqa
         except ImportError as exc:
             raise ImportError("Azure adlfs not installed") from exc
-
-
-
-        filesystem_class
-
-
-
+
+        if not self._filesystem:
+            # in order to support az and wasbs kinds
+            filesystem_class = get_filesystem_class(protocol=self.kind)
+            self._filesystem = make_datastore_schema_sanitizer(
+                filesystem_class,
+                using_bucket=self.using_bucket,
+                blocksize=self.max_blocksize,
+                **self.storage_options,
+            )
         return self._filesystem

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    @property
+    def service_client(self):
+        try:
+            import azure  # noqa
+        except ImportError as exc:
+            raise ImportError("Azure not installed") from exc
+
+        if not self._service_client:
+            self._do_connect()
+        return self._service_client
+
+    def _do_connect(self):
+        """
+
+        Creates a client for azure.
+        Raises MLRunInvalidArgumentError if none of the connection details are available
+        based on do_connect in AzureBlobFileSystem:
+        https://github.com/fsspec/adlfs/blob/2023.9.0/adlfs/spec.py#L422
+        """
+        from azure.identity import ClientSecretCredential
+
+        storage_options = self.storage_options
+        connection_string = storage_options.get("connection_string")
+        client_name = storage_options.get("account_name")
+        account_key = storage_options.get("account_key")
+        sas_token = storage_options.get("sas_token")
+        client_id = storage_options.get("client_id")
+        credential = storage_options.get("credential")
+
+        credential_from_client_id = None
+        if (
+            credential is None
+            and account_key is None
+            and sas_token is None
+            and client_id is not None
+        ):
+            credential_from_client_id = ClientSecretCredential(
+                tenant_id=storage_options.get("tenant_id"),
+                client_id=client_id,
+                client_secret=storage_options.get("client_secret"),
+            )
+        try:
+            if connection_string is not None:
+                self._service_client = BlobServiceClient.from_connection_string(
+                    conn_str=connection_string,
+                    max_block_size=self.max_blocksize,
+                    max_single_put_size=self.max_single_put_size,
+                )
+            elif client_name is not None:
+                account_url = f"https://{client_name}.blob.core.windows.net"
+                cred = credential_from_client_id or credential or account_key
+                if not cred and sas_token is not None:
+                    if not sas_token.startswith("?"):
+                        sas_token = f"?{sas_token}"
+                    account_url = account_url + sas_token
+                self._service_client = BlobServiceClient(
+                    account_url=account_url,
+                    credential=cred,
+                    max_block_size=self.max_blocksize,
+                    max_single_put_size=self.max_single_put_size,
+                )
+            else:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    "Must provide either a connection_string or account_name with credentials"
+                )
+        except Exception as e:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                f"unable to connect to account for {e}"
+            )

     def _convert_key_to_remote_path(self, key):
         key = key.strip("/")
@@ -82,7 +167,15 @@ class AzureBlobStore(DataStore):

     def upload(self, key, src_path):
         remote_path = self._convert_key_to_remote_path(key)
-
+        container, remote_path = remote_path.split("/", 1)
+        container_client = self.service_client.get_container_client(container=container)
+        with open(file=src_path, mode="rb") as data:
+            container_client.upload_blob(
+                name=remote_path,
+                data=data,
+                overwrite=True,
+                max_concurrency=self.max_concurrency,
+            )

     def get(self, key, size=None, offset=0):
         remote_path = self._convert_key_to_remote_path(key)
@@ -135,7 +228,7 @@ class AzureBlobStore(DataStore):

     def get_spark_options(self):
         res = {}
-        st = self.
+        st = self.storage_options
         service = "blob"
         primary_url = None
         if st.get("connection_string"):
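The upload path now goes through an `azure.storage.blob.BlobServiceClient` (with `overwrite=True` and bounded concurrency) instead of the fsspec filesystem. A usage sketch, assuming an `AZURE_STORAGE_CONNECTION_STRING` (or account name plus credentials) is available to MLRun; the container and paths are hypothetical:

```python
import mlrun

# Data item addressed as az://<container>/<path>; the container is split off the key
item = mlrun.get_dataitem("az://my-container/models/model.pkl")

# Internally: service_client.get_container_client("my-container").upload_blob(
#     name="models/model.pkl", data=<file>, overwrite=True, max_concurrency=100)
item.upload("/tmp/model.pkl")
```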
mlrun/datastore/base.py
CHANGED

@@ -748,7 +748,7 @@ class HttpStore(DataStore):
 # As an example, it converts an S3 URL 's3://s3bucket/path' to just 's3bucket/path'.
 # Since 'ds' schemas are not inherently processed by fsspec, we have adapted the _strip_protocol()
 # method specifically to strip away the 'ds' schema as required.
-def makeDatastoreSchemaSanitizer(cls, using_bucket=False, *args, **kwargs):
+def make_datastore_schema_sanitizer(cls, using_bucket=False, *args, **kwargs):
     if not issubclass(cls, fsspec.AbstractFileSystem):
         raise ValueError("Class must be a subclass of fsspec.AbstractFileSystem")

mlrun/datastore/dbfs_store.py
CHANGED

@@ -19,7 +19,7 @@ from fsspec.registry import get_filesystem_class

 import mlrun.errors

-from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, make_datastore_schema_sanitizer


 class DatabricksFileBugFixed(DatabricksFile):
@@ -89,7 +89,7 @@ class DBFSStore(DataStore):
         """return fsspec file system object, if supported"""
         filesystem_class = get_filesystem_class(protocol=self.kind)
         if not self._filesystem:
-            self._filesystem = makeDatastoreSchemaSanitizer(
+            self._filesystem = make_datastore_schema_sanitizer(
                 cls=filesystem_class,
                 using_bucket=False,
                 **self.get_storage_options(),
mlrun/datastore/google_cloud_storage.py
CHANGED

@@ -12,44 +12,82 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import os
 from pathlib import Path

 from fsspec.registry import get_filesystem_class
+from google.auth.credentials import Credentials
+from google.cloud.storage import Client, transfer_manager
+from google.oauth2 import service_account

 import mlrun.errors
 from mlrun.utils import logger

-from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, make_datastore_schema_sanitizer

 # Google storage objects will be represented with the following URL: gcs://<bucket name>/<path> or gs://...


 class GoogleCloudStorageStore(DataStore):
     using_bucket = True
+    workers = 8
+    chunk_size = 32 * 1024 * 1024

     def __init__(self, parent, schema, name, endpoint="", secrets: dict = None):
         super().__init__(parent, name, schema, endpoint, secrets=secrets)
+        self._storage_client = None
+        self._storage_options = None
+
+    @property
+    def storage_client(self):
+        if self._storage_client:
+            return self._storage_client
+
+        token = self._get_credentials().get("token")
+        access = "https://www.googleapis.com/auth/devstorage.full_control"
+        if isinstance(token, str):
+            if os.path.exists(token):
+                credentials = service_account.Credentials.from_service_account_file(
+                    token, scopes=[access]
+                )
+            else:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    "gcsfs authentication file not found!"
+                )
+        elif isinstance(token, dict):
+            credentials = service_account.Credentials.from_service_account_info(
+                token, scopes=[access]
+            )
+        elif isinstance(token, Credentials):
+            credentials = token
+        else:
+            raise ValueError(f"Unsupported token type: {type(token)}")
+        self._storage_client = Client(credentials=credentials)
+        return self._storage_client

     @property
     def filesystem(self):
         """return fsspec file system object, if supported"""
-        if self._filesystem:
-
-
-
-
-
-
-            ) from exc
-        filesystem_class = get_filesystem_class(protocol=self.kind)
-        self._filesystem = makeDatastoreSchemaSanitizer(
-            filesystem_class,
-            using_bucket=self.using_bucket,
-            **self.get_storage_options(),
-        )
+        if not self._filesystem:
+            filesystem_class = get_filesystem_class(protocol=self.kind)
+            self._filesystem = make_datastore_schema_sanitizer(
+                filesystem_class,
+                using_bucket=self.using_bucket,
+                **self.storage_options,
+            )
         return self._filesystem

-
+    @property
+    def storage_options(self):
+        if self._storage_options:
+            return self._storage_options
+        credentials = self._get_credentials()
+        # due to caching problem introduced in gcsfs 2024.3.1 (ML-7636)
+        credentials["use_listings_cache"] = False
+        self._storage_options = credentials
+        return self._storage_options
+
+    def _get_credentials(self):
         credentials = self._get_secret_or_env(
             "GCP_CREDENTIALS"
         ) or self._get_secret_or_env("GOOGLE_APPLICATION_CREDENTIALS")
@@ -71,6 +109,9 @@ class GoogleCloudStorageStore(DataStore):
         )
         return self._sanitize_storage_options(None)

+    def get_storage_options(self):
+        return self.storage_options
+
     def _make_path(self, key):
         key = key.strip("/")
         path = Path(self.endpoint, key).as_posix()
@@ -103,8 +144,29 @@ class GoogleCloudStorageStore(DataStore):
             f.write(data)

     def upload(self, key, src_path):
-
-        self.
+        file_size = os.path.getsize(src_path)
+        united_path = self._make_path(key)
+
+        # Multiple upload limitation recommendations as described in
+        # https://cloud.google.com/storage/docs/multipart-uploads#storage-upload-object-chunks-python
+
+        if file_size <= self.chunk_size:
+            self.filesystem.put_file(src_path, united_path, overwrite=True)
+            return
+
+        bucket = self.storage_client.bucket(self.endpoint)
+        blob = bucket.blob(key.strip("/"))
+
+        try:
+            transfer_manager.upload_chunks_concurrently(
+                src_path, blob, chunk_size=self.chunk_size, max_workers=self.workers
+            )
+        except Exception as upload_chunks_concurrently_exception:
+            logger.warning(
+                f"gcs: failed to concurrently upload {src_path},"
+                f" exception: {upload_chunks_concurrently_exception}. Retrying with single part upload."
+            )
+            self.filesystem.put_file(src_path, united_path, overwrite=True)

     def stat(self, key):
         path = self._make_path(key)
@@ -133,12 +195,13 @@ class GoogleCloudStorageStore(DataStore):

     def rm(self, path, recursive=False, maxdepth=None):
         path = self._make_path(path)
+        # in order to raise an error in case of a connection error (ML-7056)
         self.filesystem.exists(path)
-
+        super().rm(path, recursive=recursive, maxdepth=maxdepth)

     def get_spark_options(self):
         res = {}
-        st = self.
+        st = self._get_credentials()
         if "token" in st:
             res = {"spark.hadoop.google.cloud.auth.service.account.enable": "true"}
         if isinstance(st["token"], str):
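A sketch of the size-based branching the new `upload` introduces: files at or below `chunk_size` (32 MiB) still go through the fsspec `put_file` path, larger files use `google.cloud.storage.transfer_manager` with a single-part fallback on failure. The bucket, path, and credentials setup are assumed:

```python
import os

import mlrun

src = "/tmp/large_model.bin"  # assumed local file
item = mlrun.get_dataitem("gcs://my-bucket/models/large_model.bin")  # assumed bucket/path

chunk_size = 32 * 1024 * 1024  # mirrors GoogleCloudStorageStore.chunk_size
if os.path.getsize(src) <= chunk_size:
    print("will upload as a single part via fsspec put_file")
else:
    print("will upload in concurrent chunks via transfer_manager (8 workers)")

item.upload(src)  # the datastore applies the same decision internally
```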
mlrun/datastore/s3.py
CHANGED

@@ -20,7 +20,7 @@ from fsspec.registry import get_filesystem_class

 import mlrun.errors

-from .base import DataStore, FileStats, get_range, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, get_range, make_datastore_schema_sanitizer


 class S3Store(DataStore):
@@ -126,7 +126,7 @@ class S3Store(DataStore):
         except ImportError as exc:
             raise ImportError("AWS s3fs not installed") from exc
         filesystem_class = get_filesystem_class(protocol=self.kind)
-        self._filesystem = makeDatastoreSchemaSanitizer(
+        self._filesystem = make_datastore_schema_sanitizer(
             filesystem_class,
             using_bucket=self.using_bucket,
             **self.get_storage_options(),
mlrun/datastore/sources.py
CHANGED

@@ -32,6 +32,7 @@ from mlrun.config import config
 from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.datastore.utils import transform_list_filters_to_tuple
 from mlrun.secrets import SecretsStore
+from mlrun.utils import logger

 from ..model import DataSource
 from ..platforms.iguazio import parse_path
@@ -1163,6 +1164,59 @@ class KafkaSource(OnlineSource):
             "to a Spark dataframe is not possible, as this operation is not supported by Spark"
         )

+    def create_topics(
+        self,
+        num_partitions: int = 4,
+        replication_factor: int = 1,
+        topics: list[str] = None,
+    ):
+        """
+        Create Kafka topics with the specified number of partitions and replication factor.
+
+        :param num_partitions:     number of partitions for the topics
+        :param replication_factor: replication factor for the topics
+        :param topics:             list of topic names to create, if None,
+                                   the topics will be taken from the source attributes
+        """
+        from kafka.admin import KafkaAdminClient, NewTopic
+
+        brokers = self.attributes.get("brokers")
+        if not brokers:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "brokers must be specified in the KafkaSource attributes"
+            )
+        topics = topics or self.attributes.get("topics")
+        if not topics:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "topics must be specified in the KafkaSource attributes"
+            )
+        new_topics = [
+            NewTopic(topic, num_partitions, replication_factor) for topic in topics
+        ]
+        kafka_admin = KafkaAdminClient(
+            bootstrap_servers=brokers,
+            sasl_mechanism=self.attributes.get("sasl", {}).get("sasl_mechanism"),
+            sasl_plain_username=self.attributes.get("sasl", {}).get("username"),
+            sasl_plain_password=self.attributes.get("sasl", {}).get("password"),
+            sasl_kerberos_service_name=self.attributes.get("sasl", {}).get(
+                "sasl_kerberos_service_name", "kafka"
+            ),
+            sasl_kerberos_domain_name=self.attributes.get("sasl", {}).get(
+                "sasl_kerberos_domain_name"
+            ),
+            sasl_oauth_token_provider=self.attributes.get("sasl", {}).get("mechanism"),
+        )
+        try:
+            kafka_admin.create_topics(new_topics)
+        finally:
+            kafka_admin.close()
+        logger.info(
+            "Kafka topics created successfully",
+            topics=topics,
+            num_partitions=num_partitions,
+            replication_factor=replication_factor,
+        )
+

 class SQLSource(BaseSourceDriver):
     kind = "sqldb"
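A usage sketch for the new `KafkaSource.create_topics` helper (requires the kafka admin client package; broker address and topic name below are placeholders):

```python
from mlrun.datastore.sources import KafkaSource

source = KafkaSource(
    brokers=["localhost:9092"],       # placeholder broker
    topics=["model-serving-events"],  # placeholder topic
)

# Creates the configured topics; the topic list defaults to the source's own attributes
source.create_topics(num_partitions=4, replication_factor=1)
```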