mlrun 1.7.0rc7__py3-none-any.whl → 1.7.0rc11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of mlrun might be problematic.

Files changed (91)
  1. mlrun/__init__.py +1 -0
  2. mlrun/__main__.py +2 -0
  3. mlrun/artifacts/model.py +29 -25
  4. mlrun/common/schemas/__init__.py +4 -0
  5. mlrun/common/schemas/alert.py +122 -0
  6. mlrun/common/schemas/api_gateway.py +8 -1
  7. mlrun/common/schemas/auth.py +4 -0
  8. mlrun/common/schemas/client_spec.py +1 -0
  9. mlrun/common/schemas/hub.py +7 -9
  10. mlrun/common/schemas/model_monitoring/constants.py +4 -2
  11. mlrun/{datastore/helpers.py → common/schemas/pagination.py} +11 -3
  12. mlrun/common/schemas/project.py +15 -10
  13. mlrun/config.py +35 -13
  14. mlrun/datastore/__init__.py +3 -7
  15. mlrun/datastore/base.py +6 -5
  16. mlrun/datastore/datastore_profile.py +19 -1
  17. mlrun/datastore/snowflake_utils.py +43 -0
  18. mlrun/datastore/sources.py +18 -30
  19. mlrun/datastore/targets.py +140 -12
  20. mlrun/datastore/utils.py +10 -5
  21. mlrun/datastore/v3io.py +27 -50
  22. mlrun/db/base.py +88 -2
  23. mlrun/db/httpdb.py +314 -41
  24. mlrun/db/nopdb.py +142 -0
  25. mlrun/execution.py +21 -14
  26. mlrun/feature_store/api.py +9 -5
  27. mlrun/feature_store/feature_set.py +39 -23
  28. mlrun/feature_store/feature_vector.py +2 -1
  29. mlrun/feature_store/retrieval/spark_merger.py +27 -23
  30. mlrun/feature_store/steps.py +30 -19
  31. mlrun/features.py +4 -13
  32. mlrun/frameworks/auto_mlrun/auto_mlrun.py +2 -2
  33. mlrun/frameworks/lgbm/__init__.py +1 -1
  34. mlrun/frameworks/lgbm/callbacks/callback.py +2 -4
  35. mlrun/frameworks/lgbm/model_handler.py +1 -1
  36. mlrun/frameworks/pytorch/__init__.py +2 -2
  37. mlrun/frameworks/sklearn/__init__.py +1 -1
  38. mlrun/frameworks/tf_keras/__init__.py +1 -1
  39. mlrun/frameworks/tf_keras/callbacks/logging_callback.py +1 -1
  40. mlrun/frameworks/tf_keras/mlrun_interface.py +2 -2
  41. mlrun/frameworks/xgboost/__init__.py +1 -1
  42. mlrun/kfpops.py +2 -5
  43. mlrun/launcher/base.py +1 -1
  44. mlrun/launcher/client.py +2 -2
  45. mlrun/model.py +2 -2
  46. mlrun/model_monitoring/application.py +11 -2
  47. mlrun/model_monitoring/applications/histogram_data_drift.py +3 -3
  48. mlrun/model_monitoring/controller.py +2 -3
  49. mlrun/model_monitoring/helpers.py +3 -1
  50. mlrun/model_monitoring/stream_processing.py +0 -1
  51. mlrun/model_monitoring/writer.py +32 -0
  52. mlrun/package/packagers_manager.py +1 -0
  53. mlrun/platforms/__init__.py +1 -1
  54. mlrun/platforms/other.py +1 -1
  55. mlrun/projects/operations.py +11 -4
  56. mlrun/projects/pipelines.py +1 -1
  57. mlrun/projects/project.py +180 -73
  58. mlrun/run.py +77 -41
  59. mlrun/runtimes/__init__.py +16 -0
  60. mlrun/runtimes/base.py +4 -1
  61. mlrun/runtimes/kubejob.py +26 -121
  62. mlrun/runtimes/mpijob/abstract.py +8 -8
  63. mlrun/runtimes/nuclio/api_gateway.py +58 -8
  64. mlrun/runtimes/nuclio/application/application.py +79 -1
  65. mlrun/runtimes/nuclio/application/reverse_proxy.go +9 -1
  66. mlrun/runtimes/nuclio/function.py +20 -13
  67. mlrun/runtimes/nuclio/serving.py +11 -10
  68. mlrun/runtimes/pod.py +148 -3
  69. mlrun/runtimes/utils.py +0 -28
  70. mlrun/secrets.py +6 -2
  71. mlrun/serving/remote.py +2 -3
  72. mlrun/serving/routers.py +7 -4
  73. mlrun/serving/server.py +1 -1
  74. mlrun/serving/states.py +14 -38
  75. mlrun/serving/v2_serving.py +8 -7
  76. mlrun/utils/helpers.py +1 -1
  77. mlrun/utils/http.py +1 -1
  78. mlrun/utils/notifications/notification/base.py +12 -0
  79. mlrun/utils/notifications/notification/console.py +2 -0
  80. mlrun/utils/notifications/notification/git.py +3 -1
  81. mlrun/utils/notifications/notification/ipython.py +2 -0
  82. mlrun/utils/notifications/notification/slack.py +41 -13
  83. mlrun/utils/notifications/notification/webhook.py +11 -1
  84. mlrun/utils/retryer.py +3 -2
  85. mlrun/utils/version/version.json +2 -2
  86. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc11.dist-info}/METADATA +15 -15
  87. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc11.dist-info}/RECORD +91 -89
  88. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc11.dist-info}/LICENSE +0 -0
  89. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc11.dist-info}/WHEEL +0 -0
  90. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc11.dist-info}/entry_points.txt +0 -0
  91. {mlrun-1.7.0rc7.dist-info → mlrun-1.7.0rc11.dist-info}/top_level.txt +0 -0
mlrun/datastore/snowflake_utils.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright 2024 Iguazio
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #   http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ #
+
+ import mlrun
+
+
+ def get_snowflake_password():
+     key = "SNOWFLAKE_PASSWORD"
+     snowflake_password = mlrun.get_secret_or_env(key)
+
+     if not snowflake_password:
+         raise mlrun.errors.MLRunInvalidArgumentError(
+             f"No password provided. Set password using the {key} "
+             "project secret or environment variable."
+         )
+
+     return snowflake_password
+
+
+ def get_snowflake_spark_options(attributes):
+     return {
+         "format": "net.snowflake.spark.snowflake",
+         "sfURL": attributes.get("url"),
+         "sfUser": attributes.get("user"),
+         "sfPassword": get_snowflake_password(),
+         "sfDatabase": attributes.get("database"),
+         "sfSchema": attributes.get("schema"),
+         "sfWarehouse": attributes.get("warehouse"),
+         "application": "iguazio_platform",
+         "TIMESTAMP_TYPE_MAPPING": "TIMESTAMP_LTZ",
+     }
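These helpers centralize the Snowflake credential lookup and Spark connector options that were previously inlined in SnowflakeSource (see the sources.py hunks below). A minimal usage sketch; the connection values are placeholders, and a SparkSession with the Snowflake Spark connector on its classpath is assumed:

from pyspark.sql import SparkSession

from mlrun.datastore.snowflake_utils import get_snowflake_spark_options

# assumes SNOWFLAKE_PASSWORD is set as a project secret or environment variable
attributes = {
    "url": "<account>.<region>.snowflakecomputing.com",  # placeholder
    "user": "analyst",  # placeholder
    "database": "ANALYTICS",
    "schema": "PUBLIC",
    "warehouse": "COMPUTE_WH",
}
options = get_snowflake_spark_options(attributes)
spark_format = options.pop("format")  # "net.snowflake.spark.snowflake"

spark = SparkSession.builder.getOrCreate()
df = (
    spark.read.format(spark_format)
    .options(**options)
    .option("query", "select current_date")
    .load()
)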
mlrun/datastore/sources.py CHANGED
@@ -28,6 +28,7 @@ from nuclio.config import split_path
 
  import mlrun
  from mlrun.config import config
+ from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
  from mlrun.secrets import SecretsStore
 
  from ..model import DataSource
@@ -113,7 +114,11 @@ class BaseSourceDriver(DataSource):
 
      def to_spark_df(self, session, named_view=False, time_field=None, columns=None):
          if self.support_spark:
-             df = load_spark_dataframe_with_options(session, self.get_spark_options())
+             spark_options = self.get_spark_options()
+             spark_format = spark_options.pop("format", None)
+             df = load_spark_dataframe_with_options(
+                 session, spark_options, format=spark_format
+             )
              if named_view:
                  df.createOrReplaceTempView(self.name)
              return self._filter_spark_df(df, time_field, columns)
@@ -401,12 +406,17 @@ class BigQuerySource(BaseSourceDriver):
 
          # use sql query
          query_string = "SELECT * FROM `the-psf.pypi.downloads20210328` LIMIT 5000"
-         source = BigQuerySource("bq1", query=query_string,
-                                 gcp_project="my_project",
-                                 materialization_dataset="dataviews")
+         source = BigQuerySource(
+             "bq1",
+             query=query_string,
+             gcp_project="my_project",
+             materialization_dataset="dataviews",
+         )
 
          # read a table
-         source = BigQuerySource("bq2", table="the-psf.pypi.downloads20210328", gcp_project="my_project")
+         source = BigQuerySource(
+             "bq2", table="the-psf.pypi.downloads20210328", gcp_project="my_project"
+         )
 
 
      :parameter name: source name
@@ -673,32 +683,10 @@ class SnowflakeSource(BaseSourceDriver):
              **kwargs,
          )
 
-     def _get_password(self):
-         key = "SNOWFLAKE_PASSWORD"
-         snowflake_password = os.getenv(key) or os.getenv(
-             SecretsStore.k8s_env_variable_name_for_secret(key)
-         )
-
-         if not snowflake_password:
-             raise mlrun.errors.MLRunInvalidArgumentError(
-                 "No password provided. Set password using the SNOWFLAKE_PASSWORD "
-                 "project secret or environment variable."
-             )
-
-         return snowflake_password
-
      def get_spark_options(self):
-         return {
-             "format": "net.snowflake.spark.snowflake",
-             "query": self.attributes.get("query"),
-             "sfURL": self.attributes.get("url"),
-             "sfUser": self.attributes.get("user"),
-             "sfPassword": self._get_password(),
-             "sfDatabase": self.attributes.get("database"),
-             "sfSchema": self.attributes.get("schema"),
-             "sfWarehouse": self.attributes.get("warehouse"),
-             "application": "iguazio_platform",
-         }
+         spark_options = get_snowflake_spark_options(self.attributes)
+         spark_options["query"] = self.attributes.get("query")
+         return spark_options
 
 
  class CustomSource(BaseSourceDriver):
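The to_spark_df change pops the "format" key out of the source's Spark options and passes it to load_spark_dataframe_with_options separately, so connector formats such as Snowflake's are applied explicitly. That helper is not part of this diff; the following is only a rough sketch of what the call plausibly resolves to on the Spark side, not the actual mlrun implementation:

from pyspark.sql import DataFrame, SparkSession


def load_spark_dataframe_with_options(
    session: SparkSession, spark_options: dict, format: str = None
) -> DataFrame:
    # sketch only: use the explicit source format when one was supplied,
    # otherwise fall back to the session's default data source
    reader = session.read
    if format:
        reader = reader.format(format)
    return reader.load(**spark_options)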
mlrun/datastore/targets.py CHANGED
@@ -17,6 +17,7 @@ import os
  import random
  import sys
  import time
+ import warnings
  from collections import Counter
  from copy import copy
  from typing import Any, Optional, Union
@@ -28,6 +29,7 @@ from mergedeep import merge
  import mlrun
  import mlrun.utils.helpers
  from mlrun.config import config
+ from mlrun.datastore.snowflake_utils import get_snowflake_spark_options
  from mlrun.model import DataSource, DataTarget, DataTargetBase, TargetPathObject
  from mlrun.utils import logger, now_date
  from mlrun.utils.helpers import to_parquet
@@ -57,6 +59,7 @@ class TargetTypes:
      dataframe = "dataframe"
      custom = "custom"
      sql = "sql"
+     snowflake = "snowflake"
 
      @staticmethod
      def all():
@@ -71,6 +74,7 @@ class TargetTypes:
              TargetTypes.dataframe,
              TargetTypes.custom,
              TargetTypes.sql,
+             TargetTypes.snowflake,
          ]
 
 
@@ -78,11 +82,14 @@ def generate_target_run_id():
      return f"{round(time.time() * 1000)}_{random.randint(0, 999)}"
 
 
- def write_spark_dataframe_with_options(spark_options, df, mode):
+ def write_spark_dataframe_with_options(spark_options, df, mode, write_format=None):
      non_hadoop_spark_options = spark_session_update_hadoop_options(
          df.sql_ctx.sparkSession, spark_options
      )
-     df.write.mode(mode).save(**non_hadoop_spark_options)
+     if write_format:
+         df.write.format(write_format).mode(mode).save(**non_hadoop_spark_options)
+     else:
+         df.write.mode(mode).save(**non_hadoop_spark_options)
 
 
  def default_target_names():
@@ -497,7 +504,10 @@ class BaseStoreTarget(DataTargetBase):
              options = self.get_spark_options(key_column, timestamp_key)
              options.update(kwargs)
              df = self.prepare_spark_df(df, key_column, timestamp_key, options)
-             write_spark_dataframe_with_options(options, df, "overwrite")
+             write_format = options.pop("format", None)
+             write_spark_dataframe_with_options(
+                 options, df, "overwrite", write_format=write_format
+             )
          elif hasattr(df, "dask"):
              dask_options = self.get_dask_options()
              store, path_in_store, target_path = self._get_store_and_path()
@@ -524,7 +534,12 @@
          store, path_in_store, target_path = self._get_store_and_path()
          target_path = generate_path_with_chunk(self, chunk_id, target_path)
          file_system = store.filesystem
-         if file_system.protocol == "file":
+         if (
+             file_system.protocol == "file"
+             # fsspec 2023.10.0 changed protocol from "file" to ("file", "local")
+             or isinstance(file_system.protocol, (tuple, list))
+             and "file" in file_system.protocol
+         ):
              dir = os.path.dirname(target_path)
              if dir:
                  os.makedirs(dir, exist_ok=True)
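The widened protocol check accounts for newer fsspec releases where the local filesystem reports a tuple protocol, as the inline comment notes. A quick illustration of that behavior (output depends on the installed fsspec version):

import fsspec

fs = fsspec.filesystem("file")
print(fs.protocol)  # "file" on older fsspec, ("file", "local") on fsspec >= 2023.10.0

is_local = fs.protocol == "file" or (
    isinstance(fs.protocol, (tuple, list)) and "file" in fs.protocol
)
print(is_local)  # True for the local filesystem either way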
@@ -1108,6 +1123,97 @@ class CSVTarget(BaseStoreTarget):
          return True
 
 
+ class SnowflakeTarget(BaseStoreTarget):
+     """
+     :param attributes: A dictionary of attributes for Snowflake connection; will be overridden by database parameters
+         if they exist.
+     :param url: Snowflake hostname, in the format: <account_name>.<region>.snowflakecomputing.com
+     :param user: Snowflake user for login
+     :param db_schema: Database schema
+     :param database: Database name
+     :param warehouse: Snowflake warehouse name
+     :param table_name: Snowflake table name
+     """
+
+     support_spark = True
+     support_append = True
+     is_offline = True
+     kind = TargetTypes.snowflake
+
+     def __init__(
+         self,
+         name: str = "",
+         path=None,
+         attributes: dict[str, str] = None,
+         after_step=None,
+         columns=None,
+         partitioned: bool = False,
+         key_bucketing_number: Optional[int] = None,
+         partition_cols: Optional[list[str]] = None,
+         time_partitioning_granularity: Optional[str] = None,
+         max_events: Optional[int] = None,
+         flush_after_seconds: Optional[int] = None,
+         storage_options: dict[str, str] = None,
+         schema: dict[str, Any] = None,
+         credentials_prefix=None,
+         url: str = None,
+         user: str = None,
+         db_schema: str = None,
+         database: str = None,
+         warehouse: str = None,
+         table_name: str = None,
+     ):
+         attrs = {
+             "url": url,
+             "user": user,
+             "database": database,
+             "schema": db_schema,
+             "warehouse": warehouse,
+             "table": table_name,
+         }
+         extended_attrs = {
+             key: value for key, value in attrs.items() if value is not None
+         }
+         attributes = {} if not attributes else attributes
+         attributes.update(extended_attrs)
+         super().__init__(
+             name,
+             path,
+             attributes,
+             after_step,
+             list(schema.keys()) if schema else columns,
+             partitioned,
+             key_bucketing_number,
+             partition_cols,
+             time_partitioning_granularity,
+             max_events=max_events,
+             flush_after_seconds=flush_after_seconds,
+             storage_options=storage_options,
+             schema=schema,
+             credentials_prefix=credentials_prefix,
+         )
+
+     def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True):
+         spark_options = get_snowflake_spark_options(self.attributes)
+         spark_options["dbtable"] = self.attributes.get("table")
+         return spark_options
+
+     def purge(self):
+         pass
+
+     def as_df(
+         self,
+         columns=None,
+         df_module=None,
+         entities=None,
+         start_time=None,
+         end_time=None,
+         time_column=None,
+         **kwargs,
+     ):
+         raise NotImplementedError()
+
+
  class NoSqlBaseTarget(BaseStoreTarget):
      is_table = True
      is_online = True
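A minimal sketch of constructing the new SnowflakeTarget; all connection values are placeholders, and SNOWFLAKE_PASSWORD is assumed to be available as a project secret or environment variable when the Spark options are resolved:

from mlrun.datastore.targets import SnowflakeTarget

target = SnowflakeTarget(
    name="snowflake",
    url="<account>.<region>.snowflakecomputing.com",  # placeholder
    user="analyst",  # placeholder
    database="ANALYTICS",
    db_schema="PUBLIC",
    warehouse="COMPUTE_WH",
    table_name="FEATURES",
)

# resolves the shared connector options plus the destination table;
# the Spark writer then pops "format" and calls df.write.format(...).save(...)
print(target.get_spark_options())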
@@ -1179,7 +1285,10 @@ class NoSqlBaseTarget(BaseStoreTarget):
              options = self.get_spark_options(key_column, timestamp_key)
              options.update(kwargs)
              df = self.prepare_spark_df(df)
-             write_spark_dataframe_with_options(options, df, "overwrite")
+             write_format = options.pop("format", None)
+             write_spark_dataframe_with_options(
+                 options, df, "overwrite", write_format=write_format
+             )
          else:
              # To prevent modification of the original dataframe and make sure
              # that the last event of a key is the one being persisted
@@ -1419,11 +1528,27 @@ class KafkaTarget(BaseStoreTarget):
          *args,
          bootstrap_servers=None,
          producer_options=None,
+         brokers=None,
          **kwargs,
      ):
          attrs = {}
-         if bootstrap_servers is not None:
-             attrs["bootstrap_servers"] = bootstrap_servers
+
+         # TODO: Remove this in 1.9.0
+         if bootstrap_servers:
+             if brokers:
+                 raise mlrun.errors.MLRunInvalidArgumentError(
+                     "KafkaTarget cannot be created with both the 'brokers' parameter and the deprecated "
+                     "'bootstrap_servers' parameter. Please use 'brokers' only."
+                 )
+             warnings.warn(
+                 "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
+                 "use 'brokers' instead.",
+                 FutureWarning,
+             )
+             brokers = bootstrap_servers
+
+         if brokers:
+             attrs["brokers"] = brokers
          if producer_options is not None:
              attrs["producer_options"] = producer_options
 
@@ -1445,14 +1570,16 @@ class KafkaTarget(BaseStoreTarget):
          if self.path and self.path.startswith("ds://"):
              datastore_profile = datastore_profile_read(self.path)
              attributes = datastore_profile.attributes()
-             bootstrap_servers = attributes.pop("bootstrap_servers", None)
+             brokers = attributes.pop(
+                 "brokers", attributes.pop("bootstrap_servers", None)
+             )
              topic = datastore_profile.topic
          else:
              attributes = copy(self.attributes)
-             bootstrap_servers = attributes.pop("bootstrap_servers", None)
-             topic, bootstrap_servers = parse_kafka_url(
-                 self.get_target_path(), bootstrap_servers
+             brokers = attributes.pop(
+                 "brokers", attributes.pop("bootstrap_servers", None)
              )
+             topic, brokers = parse_kafka_url(self.get_target_path(), brokers)
 
          if not topic:
              raise mlrun.errors.MLRunInvalidArgumentError(
@@ -1466,7 +1593,7 @@
              class_name="storey.KafkaTarget",
              columns=column_list,
              topic=topic,
-             bootstrap_servers=bootstrap_servers,
+             brokers=brokers,
              **attributes,
          )
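For callers of KafkaTarget, the rename amounts to swapping a keyword. A hedged before/after sketch; the name, path, topic, and broker addresses are placeholders, and keyword pass-through to the base target is assumed:

from mlrun.datastore.targets import KafkaTarget

# deprecated since 1.7.0: still accepted, but emits a FutureWarning
old_style = KafkaTarget(
    name="kafka-target",
    path="kafka://broker-1:9092/my-topic",
    bootstrap_servers=["broker-2:9092"],
)

# preferred form going forward
new_style = KafkaTarget(
    name="kafka-target",
    path="kafka://broker-1:9092/my-topic",
    brokers=["broker-2:9092"],
)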
 
@@ -1957,6 +2084,7 @@ kind_to_driver = {
      TargetTypes.tsdb: TSDBTarget,
      TargetTypes.custom: CustomTarget,
      TargetTypes.sql: SQLTarget,
+     TargetTypes.snowflake: SnowflakeTarget,
  }
 
 
mlrun/datastore/utils.py CHANGED
@@ -23,24 +23,29 @@ import semver
  import mlrun.datastore
 
 
- def parse_kafka_url(url: str, bootstrap_servers: list = None) -> tuple[str, list]:
+ def parse_kafka_url(
+     url: str, brokers: typing.Union[list, str] = None
+ ) -> tuple[str, list]:
      """Generating Kafka topic and adjusting a list of bootstrap servers.
 
      :param url: URL path to parse using urllib.parse.urlparse.
-     :param bootstrap_servers: List of bootstrap servers for the kafka brokers.
+     :param brokers: List of kafka brokers.
 
      :return: A tuple of:
          [0] = Kafka topic value
          [1] = List of bootstrap servers
      """
-     bootstrap_servers = bootstrap_servers or []
+     brokers = brokers or []
+
+     if isinstance(brokers, str):
+         brokers = brokers.split(",")
 
      # Parse the provided URL into six components according to the general structure of a URL
      url = urlparse(url)
 
      # Add the network location to the bootstrap servers list
      if url.netloc:
-         bootstrap_servers = [url.netloc] + bootstrap_servers
+         brokers = [url.netloc] + brokers
 
      # Get the topic value from the parsed url
      query_dict = parse_qs(url.query)
@@ -49,7 +54,7 @@ def parse_kafka_url(url: str, bootstrap_servers: list = None) -> tuple[str, list
      else:
          topic = url.path
      topic = topic.lstrip("/")
-     return topic, bootstrap_servers
+     return topic, brokers
 
 
  def upload_tarball(source_dir, target, secrets=None):
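With the reworked helper, brokers may be given as a list or as a comma-separated string, and any broker in the URL's netloc is prepended. For example (host names are placeholders):

from mlrun.datastore.utils import parse_kafka_url

topic, brokers = parse_kafka_url(
    "kafka://broker-1:9092/my-topic", "broker-2:9092,broker-3:9092"
)
print(topic)    # "my-topic"
print(brokers)  # ["broker-1:9092", "broker-2:9092", "broker-3:9092"]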
mlrun/datastore/v3io.py CHANGED
@@ -12,8 +12,6 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import mmap
- import os
  import time
  from datetime import datetime
 
@@ -22,7 +20,6 @@ import v3io
  from v3io.dataplane.response import HttpResponseError
 
  import mlrun
- from mlrun.datastore.helpers import ONE_GB, ONE_MB
 
  from ..platforms.iguazio import parse_path, split_path
  from .base import (
@@ -32,6 +29,7 @@ from .base import (
  )
 
  V3IO_LOCAL_ROOT = "v3io"
+ V3IO_DEFAULT_UPLOAD_CHUNK_SIZE = 1024 * 1024 * 100
 
 
  class V3ioStore(DataStore):
@@ -98,46 +96,28 @@ class V3ioStore(DataStore):
          )
          return self._sanitize_storage_options(res)
 
-     def _upload(self, key: str, src_path: str, max_chunk_size: int = ONE_GB):
+     def _upload(
+         self,
+         key: str,
+         src_path: str,
+         max_chunk_size: int = V3IO_DEFAULT_UPLOAD_CHUNK_SIZE,
+     ):
          """helper function for upload method, allows for controlling max_chunk_size in testing"""
          container, path = split_path(self._join(key))
-         file_size = os.path.getsize(src_path)  # in bytes
-         if file_size <= ONE_MB:
-             with open(src_path, "rb") as source_file:
-                 data = source_file.read()
-             self._do_object_request(
-                 self.object.put,
-                 container=container,
-                 path=path,
-                 body=data,
-                 append=False,
-             )
-             return
-         # chunk must be a multiple of the ALLOCATIONGRANULARITY
-         # https://docs.python.org/3/library/mmap.html
-         if residue := max_chunk_size % mmap.ALLOCATIONGRANULARITY:
-             # round down to the nearest multiple of ALLOCATIONGRANULARITY
-             max_chunk_size -= residue
-
          with open(src_path, "rb") as file_obj:
-             file_offset = 0
-             while file_offset < file_size:
-                 chunk_size = min(file_size - file_offset, max_chunk_size)
-                 with mmap.mmap(
-                     file_obj.fileno(),
-                     length=chunk_size,
-                     access=mmap.ACCESS_READ,
-                     offset=file_offset,
-                 ) as mmap_obj:
-                     append = file_offset != 0
-                     self._do_object_request(
-                         self.object.put,
-                         container=container,
-                         path=path,
-                         body=mmap_obj,
-                         append=append,
-                     )
-                 file_offset += chunk_size
+             append = False
+             while True:
+                 data = memoryview(file_obj.read(max_chunk_size))
+                 if not data:
+                     break
+                 self._do_object_request(
+                     self.object.put,
+                     container=container,
+                     path=path,
+                     body=data,
+                     append=append,
+                 )
+                 append = True
 
      def upload(self, key, src_path):
          return self._upload(key, src_path)
@@ -152,19 +132,16 @@
              num_bytes=size,
          ).body
 
-     def _put(self, key, data, append=False, max_chunk_size: int = ONE_GB):
+     def _put(
+         self,
+         key,
+         data,
+         append=False,
+         max_chunk_size: int = V3IO_DEFAULT_UPLOAD_CHUNK_SIZE,
+     ):
          """helper function for put method, allows for controlling max_chunk_size in testing"""
          container, path = split_path(self._join(key))
          buffer_size = len(data)  # in bytes
-         if buffer_size <= ONE_MB:
-             self._do_object_request(
-                 self.object.put,
-                 container=container,
-                 path=path,
-                 body=data,
-                 append=append,
-             )
-             return
          buffer_offset = 0
          try:
              data = memoryview(data)