mlrun 1.7.0rc38__py3-none-any.whl → 1.7.0rc41__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.


Files changed (59)
  1. mlrun/alerts/alert.py +30 -27
  2. mlrun/common/constants.py +3 -0
  3. mlrun/common/helpers.py +0 -1
  4. mlrun/common/schemas/alert.py +3 -0
  5. mlrun/common/schemas/model_monitoring/model_endpoints.py +0 -1
  6. mlrun/common/schemas/notification.py +1 -0
  7. mlrun/config.py +1 -1
  8. mlrun/data_types/to_pandas.py +9 -9
  9. mlrun/datastore/alibaba_oss.py +3 -2
  10. mlrun/datastore/azure_blob.py +7 -9
  11. mlrun/datastore/base.py +13 -1
  12. mlrun/datastore/dbfs_store.py +3 -7
  13. mlrun/datastore/filestore.py +1 -3
  14. mlrun/datastore/google_cloud_storage.py +84 -29
  15. mlrun/datastore/redis.py +1 -0
  16. mlrun/datastore/s3.py +3 -2
  17. mlrun/datastore/sources.py +54 -0
  18. mlrun/datastore/storeytargets.py +147 -0
  19. mlrun/datastore/targets.py +76 -122
  20. mlrun/datastore/v3io.py +1 -0
  21. mlrun/db/httpdb.py +6 -1
  22. mlrun/errors.py +8 -0
  23. mlrun/execution.py +7 -0
  24. mlrun/feature_store/api.py +5 -0
  25. mlrun/feature_store/retrieval/job.py +1 -0
  26. mlrun/model.py +24 -3
  27. mlrun/model_monitoring/api.py +10 -2
  28. mlrun/model_monitoring/applications/_application_steps.py +52 -34
  29. mlrun/model_monitoring/applications/context.py +206 -70
  30. mlrun/model_monitoring/applications/histogram_data_drift.py +15 -13
  31. mlrun/model_monitoring/controller.py +15 -12
  32. mlrun/model_monitoring/db/stores/sqldb/sql_store.py +17 -8
  33. mlrun/model_monitoring/db/stores/v3io_kv/kv_store.py +19 -9
  34. mlrun/model_monitoring/db/tsdb/tdengine/schemas.py +85 -47
  35. mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py +46 -10
  36. mlrun/model_monitoring/db/tsdb/v3io/v3io_connector.py +38 -24
  37. mlrun/model_monitoring/helpers.py +54 -18
  38. mlrun/model_monitoring/stream_processing.py +10 -29
  39. mlrun/projects/pipelines.py +19 -30
  40. mlrun/projects/project.py +86 -67
  41. mlrun/run.py +8 -6
  42. mlrun/runtimes/__init__.py +4 -0
  43. mlrun/runtimes/nuclio/api_gateway.py +18 -0
  44. mlrun/runtimes/nuclio/application/application.py +150 -59
  45. mlrun/runtimes/nuclio/function.py +5 -11
  46. mlrun/runtimes/nuclio/serving.py +2 -2
  47. mlrun/runtimes/utils.py +16 -0
  48. mlrun/serving/routers.py +1 -1
  49. mlrun/serving/server.py +19 -5
  50. mlrun/serving/states.py +8 -0
  51. mlrun/serving/v2_serving.py +34 -26
  52. mlrun/utils/helpers.py +33 -2
  53. mlrun/utils/version/version.json +2 -2
  54. {mlrun-1.7.0rc38.dist-info → mlrun-1.7.0rc41.dist-info}/METADATA +9 -12
  55. {mlrun-1.7.0rc38.dist-info → mlrun-1.7.0rc41.dist-info}/RECORD +59 -58
  56. {mlrun-1.7.0rc38.dist-info → mlrun-1.7.0rc41.dist-info}/WHEEL +1 -1
  57. {mlrun-1.7.0rc38.dist-info → mlrun-1.7.0rc41.dist-info}/LICENSE +0 -0
  58. {mlrun-1.7.0rc38.dist-info → mlrun-1.7.0rc41.dist-info}/entry_points.txt +0 -0
  59. {mlrun-1.7.0rc38.dist-info → mlrun-1.7.0rc41.dist-info}/top_level.txt +0 -0
mlrun/alerts/alert.py CHANGED
@@ -62,6 +62,7 @@ class AlertConfig(ModelObj):
 
         # create an alert on endpoint_id, which will be triggered to slack if there is a "data_drift_detected" event
         # 3 times in the next hour.
+
         from mlrun.alerts import AlertConfig
         import mlrun.common.schemas.alert as alert_objects
 
@@ -94,29 +95,29 @@ class AlertConfig(ModelObj):
         )
         project.store_alert_config(alert_data)
 
-        :param project: name of the project to associate the alert with
-        :param name: name of the alert
-        :param template: optional parameter that allows to create an alert based on a predefined template.
-                         you can pass either an AlertTemplate object or a string (the template name).
-                         if a template is used, many fields of the alert will be auto-generated based on the
-                         template. however, you still need to provide the following fields:
+        :param project: Name of the project to associate the alert with
+        :param name: Name of the alert
+        :param template: Optional parameter that allows creating an alert based on a predefined template.
+                         You can pass either an AlertTemplate object or a string (the template name).
+                         If a template is used, many fields of the alert will be auto-generated based on the
+                         template. However, you still need to provide the following fields:
                          `name`, `project`, `entity`, `notifications`
-        :param description: description of the alert
-        :param summary: summary of the alert, will be sent in the generated notifications
-        :param severity: severity of the alert
-        :param trigger: the events that will trigger this alert, may be a simple trigger based on events or
+        :param description: Description of the alert
+        :param summary: Summary of the alert, will be sent in the generated notifications
+        :param severity: Severity of the alert
+        :param trigger: The events that will trigger this alert, may be a simple trigger based on events or
                         complex trigger which is based on a prometheus alert
-        :param criteria: when the alert will be triggered based on the specified number of events within the
+        :param criteria: When the alert will be triggered based on the specified number of events within the
                          defined time period.
-        :param reset_policy: when to clear the alert. May be "manual" for manual reset of the alert, or
+        :param reset_policy: When to clear the alert. May be "manual" for manual reset of the alert, or
                              "auto" if the criteria contains a time period
-        :param notifications: list of notifications to invoke once the alert is triggered
-        :param entities: entities that the event relates to. The entity object will contain fields that uniquely
-                         identify a given entity in the system
-        :param id: internal id of the alert (user should not supply it)
-        :param state: state of the alert, may be active/inactive (user should not supply it)
-        :param created: when the alert is created (user should not supply it)
-        :param count: internal counter of the alert (user should not supply it)
+        :param notifications: List of notifications to invoke once the alert is triggered
+        :param entities: Entities that the event relates to. The entity object will contain fields that
+                         uniquely identify a given entity in the system
+        :param id: Internal id of the alert (user should not supply it)
+        :param state: State of the alert, may be active/inactive (user should not supply it)
+        :param created: When the alert is created (user should not supply it)
+        :param count: Internal counter of the alert (user should not supply it)
         """
         self.project = project
         self.name = name
@@ -137,8 +138,8 @@ class AlertConfig(ModelObj):
         self._apply_template(template)
 
     def validate_required_fields(self):
-        if not self.project or not self.name:
-            raise mlrun.errors.MLRunBadRequestError("Project and name must be provided")
+        if not self.name:
+            raise mlrun.errors.MLRunInvalidArgumentError("Alert name must be provided")
 
     def _serialize_field(
         self, struct: dict, field_name: str = None, strip: bool = False
@@ -237,9 +238,11 @@ class AlertConfig(ModelObj):
         db = mlrun.get_run_db()
         template = db.get_alert_template(template)
 
-        # Extract parameters from the template and apply them to the AlertConfig object
-        self.summary = template.summary
-        self.severity = template.severity
-        self.criteria = template.criteria
-        self.trigger = template.trigger
-        self.reset_policy = template.reset_policy
+        # Apply parameters from the template to the AlertConfig object only if they are not already specified by the
+        # user in the current configuration.
+        # User-provided parameters will take precedence over corresponding template values
+        self.summary = self.summary or template.summary
+        self.severity = self.severity or template.severity
+        self.criteria = self.criteria or template.criteria
+        self.trigger = self.trigger or template.trigger
+        self.reset_policy = self.reset_policy or template.reset_policy
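
Note: the template-application change above means explicitly provided fields now take precedence over template defaults. A minimal usage sketch of that behavior (the template name and field values below are placeholders, not taken from this diff):

import mlrun.common.schemas.alert as alert_objects
from mlrun.alerts import AlertConfig

# Placeholder template/field values, for illustration only.
alert = AlertConfig(
    project="my-project",
    name="drift-alert",
    template="my-template",  # assumed template name
    summary="Custom summary",  # explicitly set, so it is kept
    severity=alert_objects.AlertSeverity.HIGH,  # explicitly set, so it is kept
)
# Fields left unset (criteria, trigger, reset_policy) are filled in from the template.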
mlrun/common/constants.py CHANGED
@@ -65,6 +65,9 @@ class MLRunInternalLabels:
     task_name = f"{MLRUN_LABEL_PREFIX}task-name"
     resource_name = f"{MLRUN_LABEL_PREFIX}resource_name"
     created = f"{MLRUN_LABEL_PREFIX}created"
+    producer_type = f"{MLRUN_LABEL_PREFIX}producer-type"
+    app_name = f"{MLRUN_LABEL_PREFIX}app-name"
+    endpoint_id = f"{MLRUN_LABEL_PREFIX}endpoint-id"
     host = "host"
     job_type = "job-type"
     kind = "kind"
mlrun/common/helpers.py CHANGED
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
 
 
 def parse_versioned_object_uri(
mlrun/common/schemas/alert.py CHANGED
@@ -23,6 +23,7 @@ from mlrun.common.types import StrEnum
 
 class EventEntityKind(StrEnum):
     MODEL_ENDPOINT_RESULT = "model-endpoint-result"
+    MODEL_MONITORING_APPLICATION = "model-monitoring-application"
     JOB = "job"
 
 
@@ -43,6 +44,7 @@ class EventKind(StrEnum):
     SYSTEM_PERFORMANCE_SUSPECTED = "system_performance_suspected"
     MM_APP_ANOMALY_DETECTED = "mm_app_anomaly_detected"
     MM_APP_ANOMALY_SUSPECTED = "mm_app_anomaly_suspected"
+    MM_APP_FAILED = "mm_app_failed"
     FAILED = "failed"
 
 
@@ -57,6 +59,7 @@ _event_kind_entity_map = {
     EventKind.SYSTEM_PERFORMANCE_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
     EventKind.MM_APP_ANOMALY_DETECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
     EventKind.MM_APP_ANOMALY_SUSPECTED: [EventEntityKind.MODEL_ENDPOINT_RESULT],
+    EventKind.MM_APP_FAILED: [EventEntityKind.MODEL_MONITORING_APPLICATION],
    EventKind.FAILED: [EventEntityKind.JOB],
 }
 
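Note: the new event kind is mapped only to the new entity kind, so a failure alert can target a monitoring application rather than a model endpoint result. A hedged sketch using the alert schema objects (field names follow mlrun's alert examples; project and application names are placeholders):

import mlrun.common.schemas.alert as alert_objects

# Placeholder project/application identifiers, for illustration only.
entity = alert_objects.EventEntities(
    kind=alert_objects.EventEntityKind.MODEL_MONITORING_APPLICATION,
    project="my-project",
    ids=["my-monitoring-app"],
)
trigger = alert_objects.AlertTrigger(events=[alert_objects.EventKind.MM_APP_FAILED])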
mlrun/common/schemas/model_monitoring/model_endpoints.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, NamedTuple, Optional
 from pydantic import BaseModel, Field, validator
 from pydantic.main import Extra
 
-import mlrun.common.model_monitoring
 import mlrun.common.types
 
 from ..object import ObjectKind, ObjectSpec, ObjectStatus
mlrun/common/schemas/notification.py CHANGED
@@ -52,6 +52,7 @@ class NotificationLimits(enum.Enum):
 class Notification(pydantic.BaseModel):
     """
     Notification object schema
+
     :param kind: notification implementation kind - slack, webhook, etc.
     :param name: for logging and identification
     :param message: message content in the notification
mlrun/config.py CHANGED
@@ -863,7 +863,7 @@ class Config:
                 f"Unable to decode {attribute_path}"
             )
         parsed_attribute_value = json.loads(decoded_attribute_value)
-        if type(parsed_attribute_value) != expected_type:
+        if not isinstance(parsed_attribute_value, expected_type):
             raise mlrun.errors.MLRunInvalidArgumentTypeError(
                 f"Expected type {expected_type}, got {type(parsed_attribute_value)}"
             )
mlrun/data_types/to_pandas.py CHANGED
@@ -21,7 +21,7 @@ import semver
 
 def _toPandas(spark_df):
     """
-    Modified version of spark DataFrame.toPandas()
+    Modified version of spark DataFrame.toPandas() -
    https://github.com/apache/spark/blob/v3.2.3/python/pyspark/sql/pandas/conversion.py#L35
 
    The original code (which is only replaced in pyspark 3.5.0) fails with Pandas 2 installed, with the following error:
@@ -223,21 +223,21 @@ def _to_corrected_pandas_type(dt):
         TimestampType,
     )
 
-    if type(dt) == ByteType:
+    if isinstance(dt, ByteType):
         return np.int8
-    elif type(dt) == ShortType:
+    elif isinstance(dt, ShortType):
         return np.int16
-    elif type(dt) == IntegerType:
+    elif isinstance(dt, IntegerType):
         return np.int32
-    elif type(dt) == LongType:
+    elif isinstance(dt, LongType):
         return np.int64
-    elif type(dt) == FloatType:
+    elif isinstance(dt, FloatType):
         return np.float32
-    elif type(dt) == DoubleType:
+    elif isinstance(dt, DoubleType):
         return np.float64
-    elif type(dt) == BooleanType:
+    elif isinstance(dt, BooleanType):
         return bool
-    elif type(dt) == TimestampType:
+    elif isinstance(dt, TimestampType):
         return "datetime64[ns]"
     else:
         return None
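
Note: this change, like the one in mlrun/config.py above, replaces exact type comparison with isinstance, which also accepts subclasses. A standalone illustration of the difference (not mlrun code):

class Base:
    pass

class Child(Base):
    pass

obj = Child()
print(type(obj) == Base)      # False: exact type comparison ignores inheritance
print(isinstance(obj, Base))  # True: isinstance matches subclasses as well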
mlrun/datastore/alibaba_oss.py CHANGED
@@ -22,7 +22,7 @@ from fsspec.registry import get_filesystem_class
 
 import mlrun.errors
 
-from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, make_datastore_schema_sanitizer
 
 
 class OSSStore(DataStore):
@@ -53,7 +53,7 @@ class OSSStore(DataStore):
         except ImportError as exc:
             raise ImportError("ALIBABA ossfs not installed") from exc
         filesystem_class = get_filesystem_class(protocol=self.kind)
-        self._filesystem = makeDatastoreSchemaSanitizer(
+        self._filesystem = make_datastore_schema_sanitizer(
             filesystem_class,
             using_bucket=self.using_bucket,
             **self.get_storage_options(),
@@ -85,6 +85,7 @@ class OSSStore(DataStore):
         return oss.get_object(key).read()
 
     def put(self, key, data, append=False):
+        data, _ = self._prepare_put_data(data, append)
         bucket, key = self.get_bucket_and_key(key)
         oss = oss2.Bucket(self.auth, self.endpoint_url, bucket)
         oss.put_object(key, data)
mlrun/datastore/azure_blob.py CHANGED
@@ -22,7 +22,7 @@ from fsspec.registry import get_filesystem_class
 
 import mlrun.errors
 
-from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, make_datastore_schema_sanitizer
 
 # Azure blobs will be represented with the following URL: az://<container name>. The storage account is already
 # pointed to by the connection string, so the user is not expected to specify it in any way.
@@ -41,6 +41,9 @@ class AzureBlobStore(DataStore):
         self._service_client = None
         self._storage_options = None
 
+    def get_storage_options(self):
+        return self.storage_options
+
     @property
     def storage_options(self):
         if not self._storage_options:
@@ -75,7 +78,7 @@ class AzureBlobStore(DataStore):
         if not self._filesystem:
             # in order to support az and wasbs kinds
             filesystem_class = get_filesystem_class(protocol=self.kind)
-            self._filesystem = makeDatastoreSchemaSanitizer(
+            self._filesystem = make_datastore_schema_sanitizer(
                 filesystem_class,
                 using_bucket=self.using_bucket,
                 blocksize=self.max_blocksize,
@@ -186,12 +189,7 @@ class AzureBlobStore(DataStore):
                 "Append mode not supported for Azure blob datastore"
             )
         remote_path = self._convert_key_to_remote_path(key)
-        if isinstance(data, bytes):
-            mode = "wb"
-        elif isinstance(data, str):
-            mode = "w"
-        else:
-            raise TypeError("Data type unknown. Unable to put in Azure!")
+        data, mode = self._prepare_put_data(data, append)
         with self.filesystem.open(remote_path, mode) as f:
             f.write(data)
 
@@ -225,7 +223,7 @@ class AzureBlobStore(DataStore):
 
     def get_spark_options(self):
         res = {}
-        st = self.storage_options()
+        st = self.storage_options
         service = "blob"
         primary_url = None
         if st.get("connection_string"):
mlrun/datastore/base.py CHANGED
@@ -157,6 +157,18 @@ class DataStore:
     def put(self, key, data, append=False):
         pass
 
+    def _prepare_put_data(self, data, append=False):
+        mode = "a" if append else "w"
+        if isinstance(data, bytearray):
+            data = bytes(data)
+
+        if isinstance(data, bytes):
+            return data, f"{mode}b"
+        elif isinstance(data, str):
+            return data, mode
+        else:
+            raise TypeError(f"Unable to put a value of type {type(self).__name__}")
+
     def stat(self, key):
         pass
 
@@ -748,7 +760,7 @@ class HttpStore(DataStore):
 # As an example, it converts an S3 URL 's3://s3bucket/path' to just 's3bucket/path'.
 # Since 'ds' schemas are not inherently processed by fsspec, we have adapted the _strip_protocol()
 # method specifically to strip away the 'ds' schema as required.
-def makeDatastoreSchemaSanitizer(cls, using_bucket=False, *args, **kwargs):
+def make_datastore_schema_sanitizer(cls, using_bucket=False, *args, **kwargs):
     if not issubclass(cls, fsspec.AbstractFileSystem):
         raise ValueError("Class must be a subclass of fsspec.AbstractFileSystem")
 
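Note: the new DataStore._prepare_put_data helper centralizes the str/bytes/bytearray handling that the individual put() implementations previously duplicated. A standalone re-implementation of its mode/type logic, for illustration only:

def prepare_put_data(data, append=False):
    # Mirrors the helper's mode/type handling as shown in the diff above.
    mode = "a" if append else "w"
    if isinstance(data, bytearray):
        data = bytes(data)
    if isinstance(data, bytes):
        return data, f"{mode}b"
    if isinstance(data, str):
        return data, mode
    raise TypeError(f"Unable to put a value of type {type(data).__name__}")

print(prepare_put_data("text"))               # ('text', 'w')
print(prepare_put_data(b"raw", append=True))  # (b'raw', 'ab')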
mlrun/datastore/dbfs_store.py CHANGED
@@ -19,7 +19,7 @@ from fsspec.registry import get_filesystem_class
 
 import mlrun.errors
 
-from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, make_datastore_schema_sanitizer
 
 
 class DatabricksFileBugFixed(DatabricksFile):
@@ -89,7 +89,7 @@ class DBFSStore(DataStore):
         """return fsspec file system object, if supported"""
         filesystem_class = get_filesystem_class(protocol=self.kind)
         if not self._filesystem:
-            self._filesystem = makeDatastoreSchemaSanitizer(
+            self._filesystem = make_datastore_schema_sanitizer(
                 cls=filesystem_class,
                 using_bucket=False,
                 **self.get_storage_options(),
@@ -130,11 +130,7 @@ class DBFSStore(DataStore):
                 "Append mode not supported for Databricks file system"
             )
         # can not use append mode because it overrides data.
-        mode = "w"
-        if isinstance(data, bytes):
-            mode += "b"
-        elif not isinstance(data, str):
-            raise TypeError(f"Unknown data type {type(data)}")
+        data, mode = self._prepare_put_data(data, append)
         with self.filesystem.open(key, mode) as f:
             f.write(data)
 
mlrun/datastore/filestore.py CHANGED
@@ -66,9 +66,7 @@ class FileStore(DataStore):
         dir_to_create = path.dirname(self._join(key))
         if dir_to_create:
             self._ensure_directory(dir_to_create)
-        mode = "a" if append else "w"
-        if isinstance(data, bytes):
-            mode = mode + "b"
+        data, mode = self._prepare_put_data(data, append)
         with open(self._join(key), mode) as fp:
             fp.write(data)
             fp.close()
mlrun/datastore/google_cloud_storage.py CHANGED
@@ -12,44 +12,82 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import os
 from pathlib import Path
 
 from fsspec.registry import get_filesystem_class
+from google.auth.credentials import Credentials
+from google.cloud.storage import Client, transfer_manager
+from google.oauth2 import service_account
 
 import mlrun.errors
 from mlrun.utils import logger
 
-from .base import DataStore, FileStats, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, make_datastore_schema_sanitizer
 
 # Google storage objects will be represented with the following URL: gcs://<bucket name>/<path> or gs://...
 
 
 class GoogleCloudStorageStore(DataStore):
     using_bucket = True
+    workers = 8
+    chunk_size = 32 * 1024 * 1024
 
     def __init__(self, parent, schema, name, endpoint="", secrets: dict = None):
         super().__init__(parent, name, schema, endpoint, secrets=secrets)
+        self._storage_client = None
+        self._storage_options = None
+
+    @property
+    def storage_client(self):
+        if self._storage_client:
+            return self._storage_client
+
+        token = self._get_credentials().get("token")
+        access = "https://www.googleapis.com/auth/devstorage.full_control"
+        if isinstance(token, str):
+            if os.path.exists(token):
+                credentials = service_account.Credentials.from_service_account_file(
+                    token, scopes=[access]
+                )
+            else:
+                raise mlrun.errors.MLRunInvalidArgumentError(
+                    "gcsfs authentication file not found!"
+                )
+        elif isinstance(token, dict):
+            credentials = service_account.Credentials.from_service_account_info(
+                token, scopes=[access]
+            )
+        elif isinstance(token, Credentials):
+            credentials = token
+        else:
+            raise ValueError(f"Unsupported token type: {type(token)}")
+        self._storage_client = Client(credentials=credentials)
+        return self._storage_client
 
     @property
     def filesystem(self):
         """return fsspec file system object, if supported"""
-        if self._filesystem:
-            return self._filesystem
-        try:
-            import gcsfs  # noqa
-        except ImportError as exc:
-            raise ImportError(
-                "Google gcsfs not installed, run pip install gcsfs"
-            ) from exc
-        filesystem_class = get_filesystem_class(protocol=self.kind)
-        self._filesystem = makeDatastoreSchemaSanitizer(
-            filesystem_class,
-            using_bucket=self.using_bucket,
-            **self.get_storage_options(),
-        )
+        if not self._filesystem:
+            filesystem_class = get_filesystem_class(protocol=self.kind)
+            self._filesystem = make_datastore_schema_sanitizer(
+                filesystem_class,
+                using_bucket=self.using_bucket,
+                **self.storage_options,
+            )
         return self._filesystem
 
-    def get_storage_options(self):
+    @property
+    def storage_options(self):
+        if self._storage_options:
+            return self._storage_options
+        credentials = self._get_credentials()
+        # due to caching problem introduced in gcsfs 2024.3.1 (ML-7636)
+        credentials["use_listings_cache"] = False
+        self._storage_options = credentials
+        return self._storage_options
+
+    def _get_credentials(self):
         credentials = self._get_secret_or_env(
             "GCP_CREDENTIALS"
         ) or self._get_secret_or_env("GOOGLE_APPLICATION_CREDENTIALS")
@@ -71,6 +109,9 @@ class GoogleCloudStorageStore(DataStore):
             )
         return self._sanitize_storage_options(None)
 
+    def get_storage_options(self):
+        return self.storage_options
+
     def _make_path(self, key):
         key = key.strip("/")
         path = Path(self.endpoint, key).as_posix()
@@ -90,21 +131,34 @@ class GoogleCloudStorageStore(DataStore):
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "Append mode not supported for Google cloud storage datastore"
            )
-
-        if isinstance(data, bytes):
-            mode = "wb"
-        elif isinstance(data, str):
-            mode = "w"
-        else:
-            raise TypeError(
-                "Data type unknown. Unable to put in Google cloud storage!"
-            )
+        data, mode = self._prepare_put_data(data, append)
         with self.filesystem.open(path, mode) as f:
             f.write(data)
 
     def upload(self, key, src_path):
-        path = self._make_path(key)
-        self.filesystem.put_file(src_path, path, overwrite=True)
+        file_size = os.path.getsize(src_path)
+        united_path = self._make_path(key)
+
+        # Multiple upload limitation recommendations as described in
+        # https://cloud.google.com/storage/docs/multipart-uploads#storage-upload-object-chunks-python
+
+        if file_size <= self.chunk_size:
+            self.filesystem.put_file(src_path, united_path, overwrite=True)
+            return
+
+        bucket = self.storage_client.bucket(self.endpoint)
+        blob = bucket.blob(key.strip("/"))
+
+        try:
+            transfer_manager.upload_chunks_concurrently(
+                src_path, blob, chunk_size=self.chunk_size, max_workers=self.workers
+            )
+        except Exception as upload_chunks_concurrently_exception:
+            logger.warning(
+                f"gcs: failed to concurrently upload {src_path},"
+                f" exception: {upload_chunks_concurrently_exception}. Retrying with single part upload."
+            )
+            self.filesystem.put_file(src_path, united_path, overwrite=True)
 
     def stat(self, key):
         path = self._make_path(key)
@@ -133,12 +187,13 @@ class GoogleCloudStorageStore(DataStore):
 
     def rm(self, path, recursive=False, maxdepth=None):
         path = self._make_path(path)
+        # in order to raise an error in case of a connection error (ML-7056)
         self.filesystem.exists(path)
-        self.filesystem.rm(path=path, recursive=recursive, maxdepth=maxdepth)
+        super().rm(path, recursive=recursive, maxdepth=maxdepth)
 
     def get_spark_options(self):
         res = {}
-        st = self.get_storage_options()
+        st = self._get_credentials()
         if "token" in st:
             res = {"spark.hadoop.google.cloud.auth.service.account.enable": "true"}
             if isinstance(st["token"], str):
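
Note: the new upload() path switches to google-cloud-storage's transfer_manager for files larger than chunk_size (32 MiB), falling back to a single-part upload on failure. A simplified standalone sketch of the same decision (bucket and object names are placeholders, and credentials are assumed to already be configured):

import os
from google.cloud.storage import Client, transfer_manager

CHUNK_SIZE = 32 * 1024 * 1024  # same 32 MiB threshold as in the diff
MAX_WORKERS = 8

def upload_to_gcs(client: Client, bucket_name: str, key: str, src_path: str):
    blob = client.bucket(bucket_name).blob(key.strip("/"))
    if os.path.getsize(src_path) <= CHUNK_SIZE:
        # small file: plain single-part upload
        blob.upload_from_filename(src_path)
        return
    try:
        # large file: concurrent multipart (chunked) upload
        transfer_manager.upload_chunks_concurrently(
            src_path, blob, chunk_size=CHUNK_SIZE, max_workers=MAX_WORKERS
        )
    except Exception:
        # fall back to a single-part upload, as the datastore does
        blob.upload_from_filename(src_path)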
mlrun/datastore/redis.py CHANGED
@@ -126,6 +126,7 @@ class RedisStore(DataStore):
 
     def put(self, key, data, append=False):
         key = RedisStore.build_redis_key(key)
+        data, _ = self._prepare_put_data(data, append)
         if append:
             self.redis.append(key, data)
         else:
mlrun/datastore/s3.py CHANGED
@@ -20,7 +20,7 @@ from fsspec.registry import get_filesystem_class
 
 import mlrun.errors
 
-from .base import DataStore, FileStats, get_range, makeDatastoreSchemaSanitizer
+from .base import DataStore, FileStats, get_range, make_datastore_schema_sanitizer
 
 
 class S3Store(DataStore):
@@ -126,7 +126,7 @@ class S3Store(DataStore):
         except ImportError as exc:
             raise ImportError("AWS s3fs not installed") from exc
         filesystem_class = get_filesystem_class(protocol=self.kind)
-        self._filesystem = makeDatastoreSchemaSanitizer(
+        self._filesystem = make_datastore_schema_sanitizer(
             filesystem_class,
             using_bucket=self.using_bucket,
             **self.get_storage_options(),
@@ -183,6 +183,7 @@ class S3Store(DataStore):
         return obj.get()["Body"].read()
 
     def put(self, key, data, append=False):
+        data, _ = self._prepare_put_data(data, append)
         bucket, key = self.get_bucket_and_key(key)
         self.s3.Object(bucket, key).put(Body=data)
 
mlrun/datastore/sources.py CHANGED
@@ -32,6 +32,7 @@ from mlrun.config import config
 from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
 from mlrun.datastore.utils import transform_list_filters_to_tuple
 from mlrun.secrets import SecretsStore
+from mlrun.utils import logger
 
 from ..model import DataSource
 from ..platforms.iguazio import parse_path
@@ -1163,6 +1164,59 @@ class KafkaSource(OnlineSource):
             "to a Spark dataframe is not possible, as this operation is not supported by Spark"
         )
 
+    def create_topics(
+        self,
+        num_partitions: int = 4,
+        replication_factor: int = 1,
+        topics: list[str] = None,
+    ):
+        """
+        Create Kafka topics with the specified number of partitions and replication factor.
+
+        :param num_partitions:     number of partitions for the topics
+        :param replication_factor: replication factor for the topics
+        :param topics:             list of topic names to create, if None,
+                                   the topics will be taken from the source attributes
+        """
+        from kafka.admin import KafkaAdminClient, NewTopic
+
+        brokers = self.attributes.get("brokers")
+        if not brokers:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "brokers must be specified in the KafkaSource attributes"
+            )
+        topics = topics or self.attributes.get("topics")
+        if not topics:
+            raise mlrun.errors.MLRunInvalidArgumentError(
+                "topics must be specified in the KafkaSource attributes"
+            )
+        new_topics = [
+            NewTopic(topic, num_partitions, replication_factor) for topic in topics
+        ]
+        kafka_admin = KafkaAdminClient(
+            bootstrap_servers=brokers,
+            sasl_mechanism=self.attributes.get("sasl", {}).get("sasl_mechanism"),
+            sasl_plain_username=self.attributes.get("sasl", {}).get("username"),
+            sasl_plain_password=self.attributes.get("sasl", {}).get("password"),
+            sasl_kerberos_service_name=self.attributes.get("sasl", {}).get(
+                "sasl_kerberos_service_name", "kafka"
+            ),
+            sasl_kerberos_domain_name=self.attributes.get("sasl", {}).get(
+                "sasl_kerberos_domain_name"
+            ),
+            sasl_oauth_token_provider=self.attributes.get("sasl", {}).get("mechanism"),
+        )
+        try:
+            kafka_admin.create_topics(new_topics)
+        finally:
+            kafka_admin.close()
+        logger.info(
+            "Kafka topics created successfully",
+            topics=topics,
+            num_partitions=num_partitions,
+            replication_factor=replication_factor,
+        )
+
 
 class SQLSource(BaseSourceDriver):
     kind = "sqldb"
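
Note: a short usage sketch for the new KafkaSource.create_topics helper (broker address and topic names are placeholders; SASL settings, if any, are read from the source attributes as shown in the diff):

from mlrun.datastore.sources import KafkaSource

# Placeholder brokers/topics, for illustration only.
source = KafkaSource(brokers=["broker-host:9092"], topics=["model-monitoring-stream"])

# Creates the source's topics (or an explicit list passed via `topics=`)
# with the requested partition count and replication factor.
source.create_topics(num_partitions=4, replication_factor=1)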