mlrun 1.7.0rc37__py3-none-any.whl → 1.7.0rc39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (52)
  1. mlrun/alerts/alert.py +34 -30
  2. mlrun/common/schemas/alert.py +3 -0
  3. mlrun/common/schemas/model_monitoring/constants.py +4 -0
  4. mlrun/common/schemas/notification.py +4 -3
  5. mlrun/datastore/alibaba_oss.py +2 -2
  6. mlrun/datastore/azure_blob.py +124 -31
  7. mlrun/datastore/base.py +1 -1
  8. mlrun/datastore/dbfs_store.py +2 -2
  9. mlrun/datastore/google_cloud_storage.py +83 -20
  10. mlrun/datastore/s3.py +2 -2
  11. mlrun/datastore/sources.py +54 -0
  12. mlrun/datastore/targets.py +9 -53
  13. mlrun/db/httpdb.py +6 -1
  14. mlrun/errors.py +8 -0
  15. mlrun/execution.py +7 -0
  16. mlrun/feature_store/api.py +5 -0
  17. mlrun/feature_store/common.py +6 -11
  18. mlrun/feature_store/retrieval/job.py +1 -0
  19. mlrun/model.py +29 -3
  20. mlrun/model_monitoring/api.py +9 -0
  21. mlrun/model_monitoring/applications/_application_steps.py +36 -0
  22. mlrun/model_monitoring/applications/histogram_data_drift.py +15 -13
  23. mlrun/model_monitoring/controller.py +15 -11
  24. mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +14 -11
  25. mlrun/model_monitoring/db/tsdb/base.py +121 -1
  26. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +85 -47
  27. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +100 -12
  28. mlrun/model_monitoring/db/tsdb/v3io/stream_graph_steps.py +23 -1
  29. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +214 -36
  30. mlrun/model_monitoring/helpers.py +16 -17
  31. mlrun/model_monitoring/stream_processing.py +68 -27
  32. mlrun/projects/operations.py +1 -1
  33. mlrun/projects/pipelines.py +19 -30
  34. mlrun/projects/project.py +76 -52
  35. mlrun/run.py +8 -6
  36. mlrun/runtimes/__init__.py +19 -8
  37. mlrun/runtimes/nuclio/api_gateway.py +9 -0
  38. mlrun/runtimes/nuclio/application/application.py +64 -9
  39. mlrun/runtimes/nuclio/function.py +1 -1
  40. mlrun/runtimes/pod.py +2 -2
  41. mlrun/runtimes/remotesparkjob.py +2 -5
  42. mlrun/runtimes/sparkjob/spark3job.py +7 -9
  43. mlrun/serving/v2_serving.py +1 -0
  44. mlrun/track/trackers/mlflow_tracker.py +5 -0
  45. mlrun/utils/helpers.py +21 -0
  46. mlrun/utils/version/version.json +2 -2
  47. {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/METADATA +14 -11
  48. {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/RECORD +52 -52
  49. {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/WHEEL +1 -1
  50. {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/LICENSE +0 -0
  51. {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/entry_points.txt +0 -0
  52. {mlrun-1.7.0rc37.dist-info → mlrun-1.7.0rc39.dist-info}/top_level.txt +0 -0
mlrun/alerts/alert.py CHANGED
@@ -29,6 +29,7 @@ class AlertConfig(ModelObj):
 "reset_policy",
 "state",
 "count",
+ "created",
 ]
 _fields_to_serialize = ModelObj._fields_to_serialize + [
 "entities",
@@ -55,12 +56,13 @@ class AlertConfig(ModelObj):
 created: str = None,
 count: int = None,
 ):
- """
- Alert config object
+ """Alert config object

 Example::
+
 # create an alert on endpoint_id, which will be triggered to slack if there is a "data_drift_detected" event
- 3 times in the next hour.
+ # 3 times in the next hour.
+
 from mlrun.alerts import AlertConfig
 import mlrun.common.schemas.alert as alert_objects

@@ -93,29 +95,29 @@ class AlertConfig(ModelObj):
 )
 project.store_alert_config(alert_data)

- :param project: name of the project to associate the alert with
- :param name: name of the alert
- :param template: optional parameter that allows to create an alert based on a predefined template.
- you can pass either an AlertTemplate object or a string (the template name).
- if a template is used, many fields of the alert will be auto-generated based on the
- template. however, you still need to provide the following fields:
+ :param project: Name of the project to associate the alert with
+ :param name: Name of the alert
+ :param template: Optional parameter that allows creating an alert based on a predefined template.
+ You can pass either an AlertTemplate object or a string (the template name).
+ If a template is used, many fields of the alert will be auto-generated based on the
+ template. However, you still need to provide the following fields:
 `name`, `project`, `entity`, `notifications`
- :param description: description of the alert
- :param summary: summary of the alert, will be sent in the generated notifications
- :param severity: severity of the alert
- :param trigger: the events that will trigger this alert, may be a simple trigger based on events or
+ :param description: Description of the alert
+ :param summary: Summary of the alert, will be sent in the generated notifications
+ :param severity: Severity of the alert
+ :param trigger: The events that will trigger this alert, may be a simple trigger based on events or
 complex trigger which is based on a prometheus alert
- :param criteria: when the alert will be triggered based on the specified number of events within the
+ :param criteria: When the alert will be triggered based on the specified number of events within the
 defined time period.
- :param reset_policy: when to clear the alert. May be "manual" for manual reset of the alert, or
+ :param reset_policy: When to clear the alert. May be "manual" for manual reset of the alert, or
 "auto" if the criteria contains a time period
- :param notifications: list of notifications to invoke once the alert is triggered
- :param entities: entities that the event relates to. The entity object will contain fields that uniquely
- identify a given entity in the system
- :param id: internal id of the alert (user should not supply it)
- :param state: state of the alert, may be active/inactive (user should not supply it)
- :param created: when the alert is created (user should not supply it)
- :param count: internal counter of the alert (user should not supply it)
+ :param notifications: List of notifications to invoke once the alert is triggered
+ :param entities: Entities that the event relates to. The entity object will contain fields that
+ uniquely identify a given entity in the system
+ :param id: Internal id of the alert (user should not supply it)
+ :param state: State of the alert, may be active/inactive (user should not supply it)
+ :param created: When the alert is created (user should not supply it)
+ :param count: Internal counter of the alert (user should not supply it)
 """
 self.project = project
 self.name = name
@@ -136,8 +138,8 @@ class AlertConfig(ModelObj):
 self._apply_template(template)

 def validate_required_fields(self):
- if not self.project or not self.name:
- raise mlrun.errors.MLRunBadRequestError("Project and name must be provided")
+ if not self.name:
+ raise mlrun.errors.MLRunInvalidArgumentError("Alert name must be provided")

 def _serialize_field(
 self, struct: dict, field_name: str = None, strip: bool = False
@@ -236,9 +238,11 @@ class AlertConfig(ModelObj):
 db = mlrun.get_run_db()
 template = db.get_alert_template(template)

- # Extract parameters from the template and apply them to the AlertConfig object
- self.summary = template.summary
- self.severity = template.severity
- self.criteria = template.criteria
- self.trigger = template.trigger
- self.reset_policy = template.reset_policy
+ # Apply parameters from the template to the AlertConfig object only if they are not already specified by the
+ # user in the current configuration.
+ # User-provided parameters will take precedence over corresponding template values
+ self.summary = self.summary or template.summary
+ self.severity = self.severity or template.severity
+ self.criteria = self.criteria or template.criteria
+ self.trigger = self.trigger or template.trigger
+ self.reset_policy = self.reset_policy or template.reset_policy
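The reworked _apply_template above fills in only the fields the user left unset, so an explicitly passed value such as severity now survives a template. A minimal sketch of that behavior, assuming the AlertConfig signature from the docstring above and the AlertSeverity/EventEntities names from mlrun's alert schemas; the project, template name, and endpoint id are hypothetical placeholders:

    import mlrun.common.schemas.alert as alert_objects
    from mlrun.alerts import AlertConfig

    alert = AlertConfig(
        project="my-project",  # hypothetical project name
        name="drift-alert",
        template="drift-detection-template",  # hypothetical template name; unset fields come from it
        severity=alert_objects.AlertSeverity.HIGH,  # user value, no longer overwritten by the template
        entities=alert_objects.EventEntities(
            kind=alert_objects.EventEntityKind.MODEL_ENDPOINT_RESULT,
            project="my-project",
            ids=["endpoint-id"],  # hypothetical endpoint id
        ),
    )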
mlrun/common/schemas/alert.py CHANGED
@@ -23,6 +23,7 @@ from mlrun.common.types import StrEnum

 class EventEntityKind(StrEnum):
 MODEL_ENDPOINT_RESULT = "model-endpoint-result"
+ MODEL_MONITORING_APPLICATION = "model-monitoring-application"
 JOB = "job"


@@ -43,6 +44,7 @@ class EventKind(StrEnum):
 SYSTEM_PERFORMANCE_SUSPECTED = "system_performance_suspected"
 MM_APP_ANOMALY_DETECTED = "mm_app_anomaly_detected"
 MM_APP_ANOMALY_SUSPECTED = "mm_app_anomaly_suspected"
+ MM_APP_FAILED = "mm_app_failed"
 FAILED = "failed"


@@ -57,6 +59,7 @@ _event_kind_entity_map = {
 EventKind.SYSTEM_PERFORMANCE_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
 EventKind.MM_APP_ANOMALY_DETECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
 EventKind.MM_APP_ANOMALY_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+ EventKind.MM_APP_FAILED: [EventEntityKind.MODEL_MONITORING_APPLICATION],
 EventKind.FAILED: [EventEntityKind.JOB],
 }

mlrun/common/schemas/model_monitoring/constants.py CHANGED
@@ -53,9 +53,11 @@ class EventFieldType:
 PREDICTIONS = "predictions"
 NAMED_PREDICTIONS = "named_predictions"
 ERROR_COUNT = "error_count"
+ MODEL_ERROR = "model_error"
 ENTITIES = "entities"
 FIRST_REQUEST = "first_request"
 LAST_REQUEST = "last_request"
+ LAST_REQUEST_TIMESTAMP = "last_request_timestamp"
 METRIC = "metric"
 METRICS = "metrics"
 BATCH_INTERVALS_DICT = "batch_intervals_dict"
@@ -217,6 +219,7 @@ class FileTargetKind:
 APP_METRICS = "app_metrics"
 MONITORING_SCHEDULES = "monitoring_schedules"
 MONITORING_APPLICATION = "monitoring_application"
+ ERRORS = "errors"


 class ModelMonitoringMode(str, Enum):
@@ -240,6 +243,7 @@ class V3IOTSDBTables(MonitoringStrEnum):
 APP_RESULTS = "app-results"
 METRICS = "metrics"
 EVENTS = "events"
+ ERRORS = "errors"


 class TDEngineSuperTables(MonitoringStrEnum):
mlrun/common/schemas/notification.py CHANGED
@@ -52,6 +52,7 @@ class NotificationLimits(enum.Enum):
 class Notification(pydantic.BaseModel):
 """
 Notification object schema
+
 :param kind: notification implementation kind - slack, webhook, etc.
 :param name: for logging and identification
 :param message: message content in the notification
@@ -71,9 +72,9 @@ class Notification(pydantic.BaseModel):

 kind: NotificationKind
 name: str
- message: str
- severity: NotificationSeverity
- when: list[str]
+ message: typing.Optional[str] = None
+ severity: typing.Optional[NotificationSeverity] = None
+ when: typing.Optional[list[str]] = None
 condition: typing.Optional[str] = None
 params: typing.Optional[dict[str, typing.Any]] = None
 status: typing.Optional[NotificationStatus] = None
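Since message, severity, and when now default to None, a Notification can validate with only kind and name. A small sketch, assuming "slack" is a valid NotificationKind value as the docstring above suggests:

    from mlrun.common.schemas.notification import Notification

    # message, severity, and when may now be omitted; they default to None
    notification = Notification(kind="slack", name="alert-notification")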
mlrun/datastore/alibaba_oss.py CHANGED
@@ -22,7 +22,7 @@ from fsspec.registry import get_filesystem_class

 import mlrun.errors

- from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+ from .base import DataStore, FileStats, make_datastore_schema_sanitizer


 class OSSStore(DataStore):
@@ -53,7 +53,7 @@ class OSSStore(DataStore):
 except ImportError as exc:
 raise ImportError("ALIBABA ossfs not installed") from exc
 filesystem_class = get_filesystem_class(protocol=self.kind)
- self._filesystem = makeDatastoreSchemaSanitizer(
+ self._filesystem = make_datastore_schema_sanitizer(
 filesystem_class,
 using_bucket=self.using_bucket,
 **self.get_storage_options(),
mlrun/datastore/azure_blob.py CHANGED
@@ -16,12 +16,13 @@ import time
 from pathlib import Path
 from urllib.parse import urlparse

+ from azure.storage.blob import BlobServiceClient
 from azure.storage.blob._shared.base_client import parse_connection_str
 from fsspec.registry import get_filesystem_class

 import mlrun.errors

- from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+ from .base import DataStore, FileStats, make_datastore_schema_sanitizer

 # Azure blobs will be represented with the following URL: az://<container name>. The storage account is already
 # pointed to by the connection string, so the user is not expected to specify it in any way.
@@ -29,47 +30,131 @@ from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer

 class AzureBlobStore(DataStore):
 using_bucket = True
+ max_concurrency = 100
+ max_blocksize = 1024 * 1024 * 4
+ max_single_put_size = (
+ 1024 * 1024 * 8
+ ) # for service_client property only, does not affect filesystem

 def __init__(self, parent, schema, name, endpoint="", secrets: dict = None):
 super().__init__(parent, name, schema, endpoint, secrets=secrets)
+ self._service_client = None
+ self._storage_options = None
+
+ def get_storage_options(self):
+ return self.storage_options
+
+ @property
+ def storage_options(self):
+ if not self._storage_options:
+ res = dict(
+ account_name=self._get_secret_or_env("account_name")
+ or self._get_secret_or_env("AZURE_STORAGE_ACCOUNT_NAME"),
+ account_key=self._get_secret_or_env("account_key")
+ or self._get_secret_or_env("AZURE_STORAGE_ACCOUNT_KEY"),
+ connection_string=self._get_secret_or_env("connection_string")
+ or self._get_secret_or_env("AZURE_STORAGE_CONNECTION_STRING"),
+ tenant_id=self._get_secret_or_env("tenant_id")
+ or self._get_secret_or_env("AZURE_STORAGE_TENANT_ID"),
+ client_id=self._get_secret_or_env("client_id")
+ or self._get_secret_or_env("AZURE_STORAGE_CLIENT_ID"),
+ client_secret=self._get_secret_or_env("client_secret")
+ or self._get_secret_or_env("AZURE_STORAGE_CLIENT_SECRET"),
+ sas_token=self._get_secret_or_env("sas_token")
+ or self._get_secret_or_env("AZURE_STORAGE_SAS_TOKEN"),
+ credential=self._get_secret_or_env("credential"),
+ )
+ self._storage_options = self._sanitize_storage_options(res)
+ return self._storage_options

 @property
 def filesystem(self):
 """return fsspec file system object, if supported"""
- if self._filesystem:
- return self._filesystem
 try:
 import adlfs # noqa
 except ImportError as exc:
 raise ImportError("Azure adlfs not installed") from exc
- # in order to support az and wasbs kinds.
- filesystem_class = get_filesystem_class(protocol=self.kind)
- self._filesystem = makeDatastoreSchemaSanitizer(
- filesystem_class,
- using_bucket=self.using_bucket,
- **self.get_storage_options(),
- )
+
+ if not self._filesystem:
+ # in order to support az and wasbs kinds
+ filesystem_class = get_filesystem_class(protocol=self.kind)
+ self._filesystem = make_datastore_schema_sanitizer(
+ filesystem_class,
+ using_bucket=self.using_bucket,
+ blocksize=self.max_blocksize,
+ **self.storage_options,
+ )
 return self._filesystem

- def get_storage_options(self):
- res = dict(
- account_name=self._get_secret_or_env("account_name")
- or self._get_secret_or_env("AZURE_STORAGE_ACCOUNT_NAME"),
- account_key=self._get_secret_or_env("account_key")
- or self._get_secret_or_env("AZURE_STORAGE_KEY"),
- connection_string=self._get_secret_or_env("connection_string")
- or self._get_secret_or_env("AZURE_STORAGE_CONNECTION_STRING"),
- tenant_id=self._get_secret_or_env("tenant_id")
- or self._get_secret_or_env("AZURE_STORAGE_TENANT_ID"),
- client_id=self._get_secret_or_env("client_id")
- or self._get_secret_or_env("AZURE_STORAGE_CLIENT_ID"),
- client_secret=self._get_secret_or_env("client_secret")
- or self._get_secret_or_env("AZURE_STORAGE_CLIENT_SECRET"),
- sas_token=self._get_secret_or_env("sas_token")
- or self._get_secret_or_env("AZURE_STORAGE_SAS_TOKEN"),
- credential=self._get_secret_or_env("credential"),
- )
- return self._sanitize_storage_options(res)
+ @property
+ def service_client(self):
+ try:
+ import azure # noqa
+ except ImportError as exc:
+ raise ImportError("Azure not installed") from exc
+
+ if not self._service_client:
+ self._do_connect()
+ return self._service_client
+
+ def _do_connect(self):
+ """
+
+ Creates a client for azure.
+ Raises MLRunInvalidArgumentError if none of the connection details are available
+ based on do_connect in AzureBlobFileSystem:
+ https://github.com/fsspec/adlfs/blob/2023.9.0/adlfs/spec.py#L422
+ """
+ from azure.identity import ClientSecretCredential
+
+ storage_options = self.storage_options
+ connection_string = storage_options.get("connection_string")
+ client_name = storage_options.get("account_name")
+ account_key = storage_options.get("account_key")
+ sas_token = storage_options.get("sas_token")
+ client_id = storage_options.get("client_id")
+ credential = storage_options.get("credential")
+
+ credential_from_client_id = None
+ if (
+ credential is None
+ and account_key is None
+ and sas_token is None
+ and client_id is not None
+ ):
+ credential_from_client_id = ClientSecretCredential(
+ tenant_id=storage_options.get("tenant_id"),
+ client_id=client_id,
+ client_secret=storage_options.get("client_secret"),
+ )
+ try:
+ if connection_string is not None:
+ self._service_client = BlobServiceClient.from_connection_string(
+ conn_str=connection_string,
+ max_block_size=self.max_blocksize,
+ max_single_put_size=self.max_single_put_size,
+ )
+ elif client_name is not None:
+ account_url = f"https://{client_name}.blob.core.windows.net"
+ cred = credential_from_client_id or credential or account_key
+ if not cred and sas_token is not None:
+ if not sas_token.startswith("?"):
+ sas_token = f"?{sas_token}"
+ account_url = account_url + sas_token
+ self._service_client = BlobServiceClient(
+ account_url=account_url,
+ credential=cred,
+ max_block_size=self.max_blocksize,
+ max_single_put_size=self.max_single_put_size,
+ )
+ else:
+ raise mlrun.errors.MLRunInvalidArgumentError(
+ "Must provide either a connection_string or account_name with credentials"
+ )
+ except Exception as e:
+ raise mlrun.errors.MLRunInvalidArgumentError(
+ f"unable to connect to account for {e}"
+ )

 def _convert_key_to_remote_path(self, key):
 key = key.strip("/")
@@ -82,7 +167,15 @@ class AzureBlobStore(DataStore):

 def upload(self, key, src_path):
 remote_path = self._convert_key_to_remote_path(key)
- self.filesystem.put_file(src_path, remote_path, overwrite=True)
+ container, remote_path = remote_path.split("/", 1)
+ container_client = self.service_client.get_container_client(container=container)
+ with open(file=src_path, mode="rb") as data:
+ container_client.upload_blob(
+ name=remote_path,
+ data=data,
+ overwrite=True,
+ max_concurrency=self.max_concurrency,
+ )

 def get(self, key, size=None, offset=0):
 remote_path = self._convert_key_to_remote_path(key)
@@ -135,7 +228,7 @@ class AzureBlobStore(DataStore):

 def get_spark_options(self):
 res = {}
- st = self.get_storage_options()
+ st = self.storage_options
 service = "blob"
 primary_url = None
 if st.get("connection_string"):
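The upload path above now goes through a cached BlobServiceClient (upload_blob with max_concurrency=100) instead of fsspec's put_file, while credentials are still resolved from secrets or the AZURE_STORAGE_* environment variables. A usage sketch, assuming mlrun.get_dataitem and DataItem.upload; the container, blob path, and connection string are placeholders:

    import os
    import mlrun

    # resolved by AzureBlobStore.storage_options (secret or env var)
    os.environ["AZURE_STORAGE_CONNECTION_STRING"] = "<connection-string>"  # placeholder

    item = mlrun.get_dataitem("az://my-container/models/model.pkl")  # hypothetical target
    item.upload("model.pkl")  # local file; routed through BlobServiceClient.upload_blob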
mlrun/datastore/base.py CHANGED
@@ -748,7 +748,7 @@ class HttpStore(DataStore):
 # As an example, it converts an S3 URL 's3://s3bucket/path' to just 's3bucket/path'.
 # Since 'ds' schemas are not inherently processed by fsspec, we have adapted the _strip_protocol()
 # method specifically to strip away the 'ds' schema as required.
- def makeDatastoreSchemaSanitizer(cls, using_bucket=False, *args, **kwargs):
+ def make_datastore_schema_sanitizer(cls, using_bucket=False, *args, **kwargs):
 if not issubclass(cls, fsspec.AbstractFileSystem):
 raise ValueError("Class must be a subclass of fsspec.AbstractFileSystem")

mlrun/datastore/dbfs_store.py CHANGED
@@ -19,7 +19,7 @@ from fsspec.registry import get_filesystem_class

 import mlrun.errors

- from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+ from .base import DataStore, FileStats, make_datastore_schema_sanitizer


 class DatabricksFileBugFixed(DatabricksFile):
@@ -89,7 +89,7 @@ class DBFSStore(DataStore):
 """return fsspec file system object, if supported"""
 filesystem_class = get_filesystem_class(protocol=self.kind)
 if not self._filesystem:
- self._filesystem = makeDatastoreSchemaSanitizer(
+ self._filesystem = make_datastore_schema_sanitizer(
 cls=filesystem_class,
 using_bucket=False,
 **self.get_storage_options(),
mlrun/datastore/google_cloud_storage.py CHANGED
@@ -12,44 +12,82 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+ import os
 from pathlib import Path

 from fsspec.registry import get_filesystem_class
+ from google.auth.credentials import Credentials
+ from google.cloud.storage import Client, transfer_manager
+ from google.oauth2 import service_account

 import mlrun.errors
 from mlrun.utils import logger

- from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+ from .base import DataStore, FileStats, make_datastore_schema_sanitizer

 # Google storage objects will be represented with the following URL: gcs://<bucket name>/<path> or gs://...


 class GoogleCloudStorageStore(DataStore):
 using_bucket = True
+ workers = 8
+ chunk_size = 32 * 1024 * 1024

 def __init__(self, parent, schema, name, endpoint="", secrets: dict = None):
 super().__init__(parent, name, schema, endpoint, secrets=secrets)
+ self._storage_client = None
+ self._storage_options = None
+
+ @property
+ def storage_client(self):
+ if self._storage_client:
+ return self._storage_client
+
+ token = self._get_credentials().get("token")
+ access = "https://www.googleapis.com/auth/devstorage.full_control"
+ if isinstance(token, str):
+ if os.path.exists(token):
+ credentials = service_account.Credentials.from_service_account_file(
+ token, scopes=[access]
+ )
+ else:
+ raise mlrun.errors.MLRunInvalidArgumentError(
+ "gcsfs authentication file not found!"
+ )
+ elif isinstance(token, dict):
+ credentials = service_account.Credentials.from_service_account_info(
+ token, scopes=[access]
+ )
+ elif isinstance(token, Credentials):
+ credentials = token
+ else:
+ raise ValueError(f"Unsupported token type: {type(token)}")
+ self._storage_client = Client(credentials=credentials)
+ return self._storage_client

 @property
 def filesystem(self):
 """return fsspec file system object, if supported"""
- if self._filesystem:
- return self._filesystem
- try:
- import gcsfs # noqa
- except ImportError as exc:
- raise ImportError(
- "Google gcsfs not installed, run pip install gcsfs"
- ) from exc
- filesystem_class = get_filesystem_class(protocol=self.kind)
- self._filesystem = makeDatastoreSchemaSanitizer(
- filesystem_class,
- using_bucket=self.using_bucket,
- **self.get_storage_options(),
- )
+ if not self._filesystem:
+ filesystem_class = get_filesystem_class(protocol=self.kind)
+ self._filesystem = make_datastore_schema_sanitizer(
+ filesystem_class,
+ using_bucket=self.using_bucket,
+ **self.storage_options,
+ )
 return self._filesystem

- def get_storage_options(self):
+ @property
+ def storage_options(self):
+ if self._storage_options:
+ return self._storage_options
+ credentials = self._get_credentials()
+ # due to caching problem introduced in gcsfs 2024.3.1 (ML-7636)
+ credentials["use_listings_cache"] = False
+ self._storage_options = credentials
+ return self._storage_options
+
+ def _get_credentials(self):
 credentials = self._get_secret_or_env(
 "GCP_CREDENTIALS"
 ) or self._get_secret_or_env("GOOGLE_APPLICATION_CREDENTIALS")
@@ -71,6 +109,9 @@ class GoogleCloudStorageStore(DataStore):
 )
 return self._sanitize_storage_options(None)

+ def get_storage_options(self):
+ return self.storage_options
+
 def _make_path(self, key):
 key = key.strip("/")
 path = Path(self.endpoint, key).as_posix()
@@ -103,8 +144,29 @@ class GoogleCloudStorageStore(DataStore):
 f.write(data)

 def upload(self, key, src_path):
- path = self._make_path(key)
- self.filesystem.put_file(src_path, path, overwrite=True)
+ file_size = os.path.getsize(src_path)
+ united_path = self._make_path(key)
+
+ # Multiple upload limitation recommendations as described in
+ # https://cloud.google.com/storage/docs/multipart-uploads#storage-upload-object-chunks-python
+
+ if file_size <= self.chunk_size:
+ self.filesystem.put_file(src_path, united_path, overwrite=True)
+ return
+
+ bucket = self.storage_client.bucket(self.endpoint)
+ blob = bucket.blob(key.strip("/"))
+
+ try:
+ transfer_manager.upload_chunks_concurrently(
+ src_path, blob, chunk_size=self.chunk_size, max_workers=self.workers
+ )
+ except Exception as upload_chunks_concurrently_exception:
+ logger.warning(
+ f"gcs: failed to concurrently upload {src_path},"
+ f" exception: {upload_chunks_concurrently_exception}. Retrying with single part upload."
+ )
+ self.filesystem.put_file(src_path, united_path, overwrite=True)

 def stat(self, key):
 path = self._make_path(key)
@@ -133,12 +195,13 @@ class GoogleCloudStorageStore(DataStore):

 def rm(self, path, recursive=False, maxdepth=None):
 path = self._make_path(path)
+ # in order to raise an error in case of a connection error (ML-7056)
 self.filesystem.exists(path)
- self.filesystem.rm(path=path, recursive=recursive, maxdepth=maxdepth)
+ super().rm(path, recursive=recursive, maxdepth=maxdepth)

 def get_spark_options(self):
 res = {}
- st = self.get_storage_options()
+ st = self._get_credentials()
 if "token" in st:
 res = {"spark.hadoop.google.cloud.auth.service.account.enable": "true"}
 if isinstance(st["token"], str):
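With the new GCS upload logic, files at or below chunk_size (32 MiB) still go through fsspec's put_file, while larger files are split across transfer_manager.upload_chunks_concurrently with 8 workers and fall back to a single-part upload on failure. A usage sketch, assuming mlrun.get_dataitem and DataItem.upload; the credentials path and bucket are placeholders:

    import os
    import mlrun

    # read by GoogleCloudStorageStore._get_credentials
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/service-account.json"  # placeholder

    item = mlrun.get_dataitem("gcs://my-bucket/models/large_model.bin")  # hypothetical target
    item.upload("large_model.bin")  # files over 32 MiB use the concurrent chunked path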
mlrun/datastore/s3.py CHANGED
@@ -20,7 +20,7 @@ from fsspec.registry import get_filesystem_class

 import mlrun.errors

- from .base import DataStore, FileStats, get_range, makeDatastoreSchemaSanitizer
+ from .base import DataStore, FileStats, get_range, make_datastore_schema_sanitizer


 class S3Store(DataStore):
@@ -126,7 +126,7 @@ class S3Store(DataStore):
 except ImportError as exc:
 raise ImportError("AWS s3fs not installed") from exc
 filesystem_class = get_filesystem_class(protocol=self.kind)
- self._filesystem = makeDatastoreSchemaSanitizer(
+ self._filesystem = make_datastore_schema_sanitizer(
 filesystem_class,
 using_bucket=self.using_bucket,
 **self.get_storage_options(),
mlrun/datastore/sources.py CHANGED
@@ -32,6 +32,7 @@ from mlrun.config import config
 from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.datastore.utils import transform_list_filters_to_tuple
 from mlrun.secrets import SecretsStore
+ from mlrun.utils import logger

 from ..model import DataSource
 from ..platforms.iguazio import parse_path
@@ -1163,6 +1164,59 @@ class KafkaSource(OnlineSource):
 "to a Spark dataframe is not possible, as this operation is not supported by Spark"
 )

+ def create_topics(
+ self,
+ num_partitions: int = 4,
+ replication_factor: int = 1,
+ topics: list[str] = None,
+ ):
+ """
+ Create Kafka topics with the specified number of partitions and replication factor.
+
+ :param num_partitions: number of partitions for the topics
+ :param replication_factor: replication factor for the topics
+ :param topics: list of topic names to create, if None,
+ the topics will be taken from the source attributes
+ """
+ from kafka.admin import KafkaAdminClient, NewTopic
+
+ brokers = self.attributes.get("brokers")
+ if not brokers:
+ raise mlrun.errors.MLRunInvalidArgumentError(
+ "brokers must be specified in the KafkaSource attributes"
+ )
+ topics = topics or self.attributes.get("topics")
+ if not topics:
+ raise mlrun.errors.MLRunInvalidArgumentError(
+ "topics must be specified in the KafkaSource attributes"
+ )
+ new_topics = [
+ NewTopic(topic, num_partitions, replication_factor) for topic in topics
+ ]
+ kafka_admin = KafkaAdminClient(
+ bootstrap_servers=brokers,
+ sasl_mechanism=self.attributes.get("sasl", {}).get("sasl_mechanism"),
+ sasl_plain_username=self.attributes.get("sasl", {}).get("username"),
+ sasl_plain_password=self.attributes.get("sasl", {}).get("password"),
+ sasl_kerberos_service_name=self.attributes.get("sasl", {}).get(
+ "sasl_kerberos_service_name", "kafka"
+ ),
+ sasl_kerberos_domain_name=self.attributes.get("sasl", {}).get(
+ "sasl_kerberos_domain_name"
+ ),
+ sasl_oauth_token_provider=self.attributes.get("sasl", {}).get("mechanism"),
+ )
+ try:
+ kafka_admin.create_topics(new_topics)
+ finally:
+ kafka_admin.close()
+ logger.info(
+ "Kafka topics created successfully",
+ topics=topics,
+ num_partitions=num_partitions,
+ replication_factor=replication_factor,
+ )
+

 class SQLSource(BaseSourceDriver):
 kind = "sqldb"