snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/__init__.py +4 -1
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +281 -21
  4. snowflake/cortex/_extract_answer.py +0 -1
  5. snowflake/cortex/_sentiment.py +0 -1
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +12 -85
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  16. snowflake/ml/_internal/telemetry.py +38 -2
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  20. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  21. snowflake/ml/data/data_connector.py +133 -0
  22. snowflake/ml/data/data_ingestor.py +28 -0
  23. snowflake/ml/data/data_source.py +23 -0
  24. snowflake/ml/dataset/dataset.py +39 -32
  25. snowflake/ml/dataset/dataset_reader.py +18 -118
  26. snowflake/ml/feature_store/access_manager.py +7 -1
  27. snowflake/ml/feature_store/entity.py +19 -2
  28. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  29. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  30. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  31. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  32. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  34. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  35. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  36. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  37. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  38. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  39. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  40. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  43. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  44. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  45. snowflake/ml/feature_store/feature_store.py +987 -264
  46. snowflake/ml/feature_store/feature_view.py +228 -13
  47. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  48. snowflake/ml/fileset/fileset.py +2 -2
  49. snowflake/ml/fileset/snowfs.py +4 -15
  50. snowflake/ml/fileset/stage_fs.py +24 -18
  51. snowflake/ml/lineage/__init__.py +3 -0
  52. snowflake/ml/lineage/lineage_node.py +139 -0
  53. snowflake/ml/model/_client/model/model_impl.py +47 -14
  54. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  55. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  56. snowflake/ml/model/_client/sql/model.py +1 -0
  57. snowflake/ml/model/_client/sql/model_version.py +45 -2
  58. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  59. snowflake/ml/model/_model_composer/model_composer.py +15 -17
  60. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
  61. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  62. snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
  63. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  64. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
  65. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
  66. snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
  67. snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
  68. snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
  69. snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
  70. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  71. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  72. snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
  73. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  74. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  75. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  76. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  77. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  78. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  79. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  80. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  81. snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
  82. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  83. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  84. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  85. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  86. snowflake/ml/model/_packager/model_packager.py +9 -4
  87. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  88. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  89. snowflake/ml/model/custom_model.py +22 -2
  90. snowflake/ml/model/model_signature.py +4 -4
  91. snowflake/ml/model/type_hints.py +77 -4
  92. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
  93. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  94. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  95. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  96. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  97. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  98. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  99. snowflake/ml/modeling/cluster/birch.py +4 -2
  100. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  101. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  102. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  103. snowflake/ml/modeling/cluster/k_means.py +4 -2
  104. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  105. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  106. snowflake/ml/modeling/cluster/optics.py +4 -2
  107. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  108. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  109. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  110. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  111. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  112. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  113. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  114. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  115. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  116. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  117. snowflake/ml/modeling/covariance/oas.py +4 -2
  118. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  119. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  120. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  121. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  122. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  123. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  124. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  125. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  126. snowflake/ml/modeling/decomposition/pca.py +4 -2
  127. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  128. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  129. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  130. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  131. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  132. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  133. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  134. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  135. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  136. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  137. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  138. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  139. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  140. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  141. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  142. snowflake/ml/modeling/manifold/isomap.py +4 -2
  143. snowflake/ml/modeling/manifold/mds.py +4 -2
  144. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  145. snowflake/ml/modeling/manifold/tsne.py +4 -2
  146. snowflake/ml/modeling/metrics/ranking.py +3 -0
  147. snowflake/ml/modeling/metrics/regression.py +3 -0
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  150. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  151. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  152. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  153. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  154. snowflake/ml/modeling/pipeline/pipeline.py +5 -4
  155. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  156. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  157. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  158. snowflake/ml/registry/_manager/model_manager.py +16 -3
  159. snowflake/ml/registry/registry.py +100 -13
  160. snowflake/ml/version.py +1 -1
  161. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
  162. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
  163. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  164. snowflake/ml/_internal/lineage/data_source.py +0 -10
  165. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0

snowflake/ml/_internal/container_services/image_registry/imagelib.py
@@ -13,6 +13,7 @@ This library only supports a limited set of features:
 It's recommended to use this library to copy previously tested images using sha256 to avoid surprises
 with respect to compatibility.
 """
+
 import dataclasses
 import hashlib
 import io
@@ -152,7 +153,8 @@ class BlobTransfer:
     src_image: ImageDescriptor
     dest_image: ImageDescriptor
     manifest: Manifest
-    image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient
+    src_image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient
+    dest_image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient

     def upload_all_blobs(self) -> None:
         blob_digests = self.manifest.get_blob_digests()
@@ -169,7 +171,7 @@ class BlobTransfer:
         """
         Check if the blob already exists in the destination registry.
         """
-        resp = self.image_registry_http_client.head(self.dest_image.blob_link(blob_digest), headers={})
+        resp = self.dest_image_registry_http_client.head(self.dest_image.blob_link(blob_digest), headers={})
         return resp.status_code != 200

     def _fetch_blob(self, blob_digest: str) -> Tuple[io.BytesIO, int]:
@@ -178,7 +180,7 @@ class BlobTransfer:
         """
         src_blob_link = self.src_image.blob_link(blob_digest)
         headers = {_CONTENT_LENGTH_HEADER: "0"}
-        resp = self.image_registry_http_client.get(src_blob_link, headers=headers)
+        resp = self.src_image_registry_http_client.get(src_blob_link, headers=headers)

         assert resp.status_code == 200, f"Blob GET failed with code {resp.status_code}"
         assert _CONTENT_LENGTH_HEADER in resp.headers, f"Blob does not contain {_CONTENT_LENGTH_HEADER}"
@@ -189,7 +191,7 @@ class BlobTransfer:
         """
         Obtain the upload URL from the destination registry.
         """
-        response = self.image_registry_http_client.post(self.dest_image.blob_upload_link())
+        response = self.dest_image_registry_http_client.post(self.dest_image.blob_upload_link())
         assert (
             response.status_code == 202
         ), f"Failed to get the upload URL to destination. Status {response.status_code}. {str(response.content)}"
@@ -216,14 +218,14 @@ class BlobTransfer:
             headers[_CONTENT_RANGE_HEADER] = f"{start_byte}-{end_byte}"
             headers[_CONTENT_LENGTH_HEADER] = str(chunk_length)

-            resp = self.image_registry_http_client.patch(next_loc, headers=headers, data=chunk)
+            resp = self.dest_image_registry_http_client.patch(next_loc, headers=headers, data=chunk)
             assert resp.status_code == 202, f"Blob PATCH failed with code {resp.status_code}"

             next_loc = resp.headers[_LOCATION_HEADER]
             start_byte += chunk_length

         # Finalize the upload
-        resp = self.image_registry_http_client.put(f"{next_loc}&digest={blob_digest}")
+        resp = self.dest_image_registry_http_client.put(f"{next_loc}&digest={blob_digest}")
         assert resp.status_code == 201, f"Blob PUT failed with code {resp.status_code}"

     def _transfer(self, blob_digest: str) -> None:
@@ -340,21 +342,32 @@ def copy_image(
     src_image: ImageDescriptor,
     dest_image: ImageDescriptor,
     arch: _Arch,
-    retryable_http: image_registry_http_client.ImageRegistryHttpClient,
+    src_retryable_http: image_registry_http_client.ImageRegistryHttpClient,
+    dest_retryable_http: image_registry_http_client.ImageRegistryHttpClient,
 ) -> None:
     logger.debug(f"Pulling image manifest for {src_image}")

     # 1. Get the manifest
-    manifest = get_manifest(src_image, arch, retryable_http)
+    manifest = get_manifest(src_image, arch, src_retryable_http)
     logger.debug(f"Manifest pulled for {src_image} with digest {manifest.manifest_digest}")

     # 2: Retrieve all blob digests from manifest; fetch blob based on blob digest, then upload blob.
-    blob_transfer = BlobTransfer(src_image, dest_image, manifest, image_registry_http_client=retryable_http)
+    blob_transfer = BlobTransfer(
+        src_image,
+        dest_image,
+        manifest,
+        src_image_registry_http_client=src_retryable_http,
+        dest_image_registry_http_client=dest_retryable_http,
+    )
     blob_transfer.upload_all_blobs()

     # 3. Upload the manifest
     logger.debug(f"All blobs copied successfully. Copying manifest for {src_image} to {dest_image}")
-    put_manifest(dest_image, manifest, retryable_http)
+    put_manifest(
+        dest_image,
+        manifest,
+        dest_retryable_http,
+    )

     logger.debug(f"Image {src_image} copied to {dest_image}")

snowflake/ml/_internal/container_services/image_registry/registry_client.py
@@ -201,6 +201,12 @@ class ImageRegistryClient:
         )
         # TODO[shchen]: Remove the imagelib, instead rely on the copy image system function later.
         imagelib.copy_image(
-            src_image=src_image, dest_image=dest_image, arch=arch, retryable_http=self.image_registry_http_client
+            src_image=src_image,
+            dest_image=dest_image,
+            arch=arch,
+            src_retryable_http=image_registry_http_client.ImageRegistryHttpClient(
+                repo_url=src_image.registry_name, no_cred=True
+            ),
+            dest_retryable_http=self.image_registry_http_client,
         )
         logger.info("Image copy completed successfully")

snowflake/ml/_internal/exceptions/dataset_errors.py
@@ -1,11 +1,11 @@
 # Error code from Snowflake Python Connector.
-ERRNO_OBJECT_ALREADY_EXISTS = "002002"
-ERRNO_OBJECT_NOT_EXIST = "002043"
-ERRNO_FILES_ALREADY_EXISTING = "001030"
-ERRNO_VERSION_ALREADY_EXISTS = "092917"
-ERRNO_DATASET_NOT_EXIST = "399019"
-ERRNO_DATASET_VERSION_NOT_EXIST = "399012"
-ERRNO_DATASET_VERSION_ALREADY_EXISTS = "399020"
+ERRNO_OBJECT_ALREADY_EXISTS = 2002
+ERRNO_OBJECT_NOT_EXIST = 2043
+ERRNO_FILES_ALREADY_EXISTING = 1030
+ERRNO_VERSION_ALREADY_EXISTS = 92917
+ERRNO_DATASET_NOT_EXIST = 399019
+ERRNO_DATASET_VERSION_NOT_EXIST = 399012
+ERRNO_DATASET_VERSION_ALREADY_EXISTS = 399020


 class DatasetError(Exception):

snowflake/ml/_internal/exceptions/fileset_errors.py
@@ -1,7 +1,7 @@
 # Error code from Snowflake Python Connector.
-ERRNO_FILE_EXIST_IN_STAGE = "001030"
-ERRNO_DOMAIN_NOT_EXIST = "002003"
-ERRNO_STAGE_NOT_EXIST = "391707"
+ERRNO_FILE_EXIST_IN_STAGE = 1030
+ERRNO_DOMAIN_NOT_EXIST = 2003
+ERRNO_STAGE_NOT_EXIST = 391707


 class FileSetError(Exception):

snowflake/ml/_internal/exceptions/sql_error_codes.py (new file)
@@ -0,0 +1,6 @@
+"""SQL Error Codes"""
+
+# SQL compilation error: Object ''{0}'' does not exist or not authorized.
+OBJECT_NOT_EXIST = 2001
+# SQL compilation error: Object ''{0}'' already exists.
+OBJECT_ALREADY_EXISTS = 2002
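
These module-level constants mirror the integer errno values reported by the Snowflake Python Connector, which is also why dataset_errors.py and fileset_errors.py above switch from zero-padded strings to plain ints. A minimal sketch of the intended comparison pattern, assuming an existing Snowpark session (the DESCRIBE statement and helper name are illustrative only):

    from snowflake.connector import errors
    from snowflake.ml._internal.exceptions import sql_error_codes

    def object_exists(session, fq_name: str) -> bool:
        # Hypothetical helper: probe the object and map the connector errno to a boolean.
        try:
            session.sql(f"DESCRIBE TABLE {fq_name}").collect()
            return True
        except errors.ProgrammingError as e:
            if e.errno == sql_error_codes.OBJECT_NOT_EXIST:
                return False
            raise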

snowflake/ml/_internal/lineage/lineage_utils.py
@@ -1,9 +1,9 @@
 import copy
 import functools
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, get_args

 from snowflake import snowpark
-from snowflake.ml._internal.lineage import data_source
+from snowflake.ml.data import data_source

 _DATA_SOURCES_ATTR = "_data_sources"

@@ -39,7 +39,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
     result: Optional[List[data_source.DataSource]] = None
     for arg in args:
         srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
-        if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
+        if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
             if result is None:
                 result = []
             result += srcs
@@ -49,7 +49,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
 def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
     """Helper method for attaching data sources to an object"""
     if data_sources:
-        assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
+        assert all(isinstance(ds, get_args(data_source.DataSource)) for ds in data_sources)
         setattr(obj, _DATA_SOURCES_ATTR, data_sources)

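
The switch to typing.get_args reflects that data_source.DataSource is now a Union type alias over the concrete source types (e.g. DataFrameInfo, DatasetInfo) rather than a single class, so it can no longer be handed to isinstance directly. A small self-contained sketch of the pattern, with a stand-in alias:

    from typing import Union, get_args

    # Stand-in for the DataSource union alias defined in snowflake.ml.data.data_source.
    SourceLike = Union[str, int]

    values = ["@my_stage/data.parquet", 42]
    # isinstance(v, SourceLike) raises TypeError for a typing.Union; get_args unpacks its members.
    assert all(isinstance(v, get_args(SourceLike)) for v in values)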

snowflake/ml/_internal/telemetry.py
@@ -10,6 +10,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     TypeVar,
@@ -92,6 +93,31 @@ def get_statement_params(
     )


+def add_statement_params_custom_tags(
+    statement_params: Optional[Dict[str, Any]], custom_tags: Mapping[str, Any]
+) -> Dict[str, Any]:
+    """
+    Add custom_tags to existing statement_params. Overwrite keys in custom_tags dict that already exist.
+    If existing statement_params are not provided, do nothing as the information cannot be effectively tracked.
+
+    Args:
+        statement_params: Existing statement_params dictionary.
+        custom_tags: Dictionary of existing k/v pairs to add as custom_tags
+
+    Returns:
+        new statement_params dictionary with all keys and an updated custom_tags field.
+    """
+    if not statement_params:
+        return {}
+    existing_custom_tags: Dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
+    existing_custom_tags.update(custom_tags)
+    # NOTE: This can be done with | operator after upgrade from py3.8
+    return {
+        **statement_params,
+        TelemetryField.KEY_CUSTOM_TAGS.value: existing_custom_tags,
+    }
+
+
 # TODO: we can merge this with get_statement_params after code clean up
 def get_statement_params_full_func_name(frame: Optional[types.FrameType], class_name: Optional[str] = None) -> str:
     """

@@ -251,6 +277,7 @@ def send_api_usage_telemetry(
         ]
     ] = None,
     sfqids_extractor: Optional[Callable[..., List[str]]] = None,
+    subproject_extractor: Optional[Callable[[Any], str]] = None,
     custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
     """
@@ -264,6 +291,7 @@ def send_api_usage_telemetry(
         conn_attr_name: Name of the SnowflakeConnection attribute in `self`.
         api_calls_extractor: Extract API calls from `self`.
         sfqids_extractor: Extract sfqids from `self`.
+        subproject_extractor: Extract subproject at runtime from `self`.
         custom_tags: Custom tags.

     Returns:
@@ -271,10 +299,14 @@ def send_api_usage_telemetry(

     Raises:
         TypeError: If `conn_attr_name` is provided but the conn attribute is not of type SnowflakeConnection.
+        ValueError: If both `subproject` and `subproject_extractor` are provided

     # noqa: DAR402
     """

+    if subproject is not None and subproject_extractor is not None:
+        raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
+
     def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
         @functools.wraps(func)
         def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
@@ -296,9 +328,13 @@ def send_api_usage_telemetry(
             if sfqids_extractor:
                 sfqids = sfqids_extractor(args[0])

+            subproject_name = subproject
+            if subproject_extractor is not None:
+                subproject_name = subproject_extractor(args[0])
+
             statement_params = get_function_usage_statement_params(
                 project=project,
-                subproject=subproject,
+                subproject=subproject_name,
                 function_category=TelemetryField.FUNC_CAT_USAGE.value,
                 function_name=_get_full_func_name(func),
                 function_parameters=params,
@@ -355,7 +391,7 @@ def send_api_usage_telemetry(
                 raise e.original_exception from e

             # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
-            telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
+            telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject_name)
             telemetry_args = dict(
                 func_name=_get_full_func_name(func),
                 function_category=TelemetryField.FUNC_CAT_USAGE.value,
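
A hedged sketch of the new subproject_extractor hook: rather than baking a fixed subproject into the decorator, it can be resolved from the decorated instance at call time (the class and attribute names below are illustrative only):

    from snowflake.ml._internal import telemetry

    class ExampleClient:
        def __init__(self, subproject: str) -> None:
            self._subproject = subproject

        @telemetry.send_api_usage_telemetry(
            project="MLOps",
            subproject_extractor=lambda self: self._subproject,  # evaluated per instance at runtime
        )
        def refresh(self) -> None:
            ...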

snowflake/ml/_internal/utils/identifier.py
@@ -165,6 +165,20 @@ def parse_schema_level_object_identifier(
     )


+def is_fully_qualified_name(name: str) -> bool:
+    """
+    Checks if a given name is a fully qualified name, which is in the format '<db>.<schema>.<object_name>'.
+
+    Args:
+        name: The name to be checked.
+
+    Returns:
+        bool: True if the name is fully qualified, False otherwise.
+    """
+    res = parse_schema_level_object_identifier(name)
+    return res[0] is not None and res[1] is not None and res[2] is not None and not res[3]
+
+
 def get_schema_level_object_identifier(
     db: Optional[str],
     schema: Optional[str],
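
A quick illustration of the helper's expected behavior under standard Snowflake identifier rules (the names are placeholders):

    from snowflake.ml._internal.utils import identifier

    assert identifier.is_fully_qualified_name("MY_DB.MY_SCHEMA.MY_MODEL")
    assert not identifier.is_fully_qualified_name("MY_SCHEMA.MY_MODEL")  # no database part
    assert not identifier.is_fully_qualified_name("MY_MODEL")            # bare object name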

snowflake/ml/_internal/utils/snowpark_dataframe_utils.py
@@ -1,22 +1,27 @@
 import logging
 import warnings
+from typing import List, Optional

 from snowflake import snowpark
+from snowflake.ml._internal.utils import sql_identifier
 from snowflake.snowpark import functions, types


-def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
+def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[List[str]] = None) -> snowpark.DataFrame:
     """Cast columns in the dataframe to types that are compatible with tensor.

     It assists FileSet.make() in performing implicit data casting.

     Args:
         df: A snowpark dataframe.
+        ignore_columns: Columns to exclude from casting. These columns will be propagated unchanged.

     Returns:
         A snowpark dataframe whose data type has been casted.
     """

+    ignore_cols_set = {sql_identifier.SqlIdentifier(c).identifier() for c in ignore_columns} if ignore_columns else {}
+
     fields = df.schema.fields
     selected_cols = []
     for field in fields:
@@ -40,7 +45,9 @@ def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
             dest = field.datatype
             selected_cols.append(functions.cast(functions.col(src), dest).alias(src))
         else:
-            if field.datatype in (types.DateType(), types.TimestampType(), types.TimeType()):
+            if field.column_identifier.name in ignore_cols_set:
+                pass
+            elif field.datatype in (types.DateType(), types.TimestampType(), types.TimeType()):
                 logging.warning(
                     "A Column with DATE or TIMESTAMP data type detected. "
                     "It might not be able to get converted to tensors. "
@@ -90,7 +97,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.DataFrame:
                     " is being automatically converted to DoubleType in the Snowpark DataFrame. "
                     "This automatic conversion may lead to potential precision loss and rounding errors. "
                     "If you wish to prevent this conversion, you should manually perform "
-                    "the necessary data type conversion."
+                    "the necessary data type conversion.",
+                    UserWarning,
+                    stacklevel=2,
                 )
             else:
                 # IntegerType default as NUMBER(38, 0), but
@@ -102,7 +111,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.DataFrame:
                     " is being automatically converted to LongType in the Snowpark DataFrame. "
                     "This automatic conversion may lead to potential precision loss and rounding errors. "
                     "If you wish to prevent this conversion, you should manually perform "
-                    "the necessary data type conversion."
+                    "the necessary data type conversion.",
+                    UserWarning,
+                    stacklevel=2,
                 )
             selected_cols.append(functions.cast(functions.col(src), dest_dtype).alias(src))
         # TODO: add more type handling or error message
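
A brief sketch of the new ignore_columns parameter: listed columns are passed through uncast, which is useful for label or timestamp columns whose original types should be preserved. The table and column names are placeholders, and an existing Snowpark session is assumed:

    from snowflake.ml._internal.utils import snowpark_dataframe_utils

    df = session.table("MY_DB.MY_SCHEMA.TRAINING_DATA")
    # Cast feature columns for tensor compatibility, but leave EVENT_TS untouched.
    casted_df = snowpark_dataframe_utils.cast_snowpark_dataframe(df, ignore_columns=["EVENT_TS"])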

snowflake/ml/data/_internal/arrow_ingestor.py (new file)
@@ -0,0 +1,228 @@
+import collections
+import logging
+import os
+import time
+from typing import Any, Deque, Dict, Iterator, List, Optional
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import pyarrow as pa
+import pyarrow.dataset as ds
+
+from snowflake import snowpark
+from snowflake.ml.data import data_ingestor, data_source
+from snowflake.ml.data._internal import ingestor_utils
+
+_EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
+
+# The row count for batches read from PyArrow Dataset. This number should be large enough so that
+# dataset.to_batches() would read in a very large portion of, if not entirely, a parquet file.
+_DEFAULT_DATASET_BATCH_SIZE = 1000000
+
+
+class _RecordBatchesBuffer:
+    """A queue that stores record batches and tracks the total num of rows in it."""
+
+    def __init__(self) -> None:
+        self.buffer: Deque[pa.RecordBatch] = collections.deque()
+        self.num_rows = 0
+
+    def append(self, rb: pa.RecordBatch) -> None:
+        self.buffer.append(rb)
+        self.num_rows += rb.num_rows
+
+    def appendleft(self, rb: pa.RecordBatch) -> None:
+        self.buffer.appendleft(rb)
+        self.num_rows += rb.num_rows
+
+    def popleft(self) -> pa.RecordBatch:
+        popped = self.buffer.popleft()
+        self.num_rows -= popped.num_rows
+        return popped
+
+
+class ArrowIngestor(data_ingestor.DataIngestor):
+    """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
+
+    def __init__(
+        self,
+        session: snowpark.Session,
+        data_sources: List[data_source.DataSource],
+        format: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Args:
+            session: The Snowpark Session to use.
+            data_sources: List of data sources to ingest.
+            format: Currently “parquet”, “ipc”/”arrow”/”feather”, “csv”, “json”, and “orc” are supported.
+                Will be inferred if not specified.
+            kwargs: Miscellaneous arguments passed to underlying PyArrow Dataset initializer.
+        """
+        self._session = session
+        self._data_sources = data_sources
+        self._format = format
+        self._kwargs = kwargs
+
+        self._schema: Optional[pa.Schema] = None
+
+    @property
+    def data_sources(self) -> List[data_source.DataSource]:
+        return self._data_sources
+
+    def to_batches(
+        self,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last_batch: bool = True,
+    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+        """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
+
+        As we are generating batches with the exactly same length, the last few rows in each file might get left as they
+        are not long enough to form a batch. These rows will be put into a temporary buffer and combine with the first
+        few rows of the next file to generate a new batch.
+
+        Args:
+            batch_size: Specifies the size of each batch that will be yield
+            shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
+                the order of files, and then shuflle the order of rows in each file.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last
+                batch will get dropped if its size is smaller than the given batch_size.
+
+        Yields:
+            A dict mapping column names to the corresponding data fetch from that column.
+        """
+        self._rb_buffer = _RecordBatchesBuffer()
+
+        # Extract schema if not already known
+        dataset = self._get_dataset(shuffle)
+        if self._schema is None:
+            self._schema = dataset.schema
+
+        for rb in _retryable_batches(dataset, batch_size=max(_DEFAULT_DATASET_BATCH_SIZE, batch_size)):
+            if shuffle:
+                rb = rb.take(np.random.permutation(rb.num_rows))
+            self._rb_buffer.append(rb)
+            while self._rb_buffer.num_rows >= batch_size:
+                yield self._get_batches_from_buffer(batch_size)
+
+        if self._rb_buffer.num_rows and not drop_last_batch:
+            yield self._get_batches_from_buffer(batch_size)
+
+    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
+        ds = self._get_dataset(shuffle=False)
+        table = ds.to_table() if limit is None else ds.head(num_rows=limit)
+        return table.to_pandas()
+
+    def _get_dataset(self, shuffle: bool) -> ds.Dataset:
+        format = self._format
+        sources = []
+        source_format = None
+        for source in self._data_sources:
+            if isinstance(source, str):
+                sources.append(source)
+                source_format = format or os.path.splitext(source)[-1]
+            elif isinstance(source, data_source.DatasetInfo):
+                if not self._kwargs.get("filesystem"):
+                    self._kwargs["filesystem"] = ingestor_utils.get_dataset_filesystem(self._session, source)
+                sources.extend(
+                    ingestor_utils.get_dataset_files(self._session, source, filesystem=self._kwargs["filesystem"])
+                )
+                source_format = "parquet"
+            elif isinstance(source, data_source.DataFrameInfo):
+                # FIXME: This currently loads all result batches into memory so that it
+                # can be passed into pyarrow.dataset as a list/tuple of pa.RecordBatches
+                # We may be able to optimize this by splitting the result batches into
+                # in-memory (first batch) and file URLs (subsequent batches) and creating a
+                # union dataset.
+                result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source)
+                sources.extend(b.to_arrow() for b in result_batches)
+                source_format = "arrow"
+            else:
+                raise RuntimeError(f"Unsupported data source type: {type(source)}")
+
+            # Make sure source types not mixed
+            if format and format != source_format:
+                raise RuntimeError(f"Unexpected data source format (expected {format}, found {source_format})")
+            format = source_format
+
+        # Re-shuffle input files on each iteration start
+        if shuffle:
+            np.random.shuffle(sources)
+        pa_dataset: ds.Dataset = ds.dataset(sources, format=format, **self._kwargs)
+        return pa_dataset
+
+    def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
+        """Generate new batches from the existing record batch buffer."""
+        cnt_rbs_num_rows = 0
+        candidates = []
+
+        # Keep popping record batches in buffer until there are enough rows for a batch.
+        while self._rb_buffer.num_rows and cnt_rbs_num_rows < batch_size:
+            candidate = self._rb_buffer.popleft()
+            cnt_rbs_num_rows += candidate.num_rows
+            candidates.append(candidate)
+
+        # When there are more rows than needed, slice the last popped batch to fit batch_size.
+        if cnt_rbs_num_rows > batch_size:
+            row_diff = cnt_rbs_num_rows - batch_size
+            slice_target = candidates[-1]
+            cut_off = slice_target.num_rows - row_diff
+            to_merge = slice_target.slice(length=cut_off)
+            left_over = slice_target.slice(offset=cut_off)
+            candidates[-1] = to_merge
+            self._rb_buffer.appendleft(left_over)
+
+        res = _merge_record_batches(candidates)
+        return _record_batch_to_arrays(res)
+
+
+def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
+    """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
+    if not record_batches:
+        return _EMPTY_RECORD_BATCH
+    if len(record_batches) == 1:
+        return record_batches[0]
+    record_batches = list(filter(lambda rb: rb.num_rows > 0, record_batches))
+    one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks()
+    batches = one_chunk_table.to_batches(max_chunksize=None)
+    return batches[0]
+
+
+def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
+    """Transform the record batch to a (string, numpy array) dict."""
+    batch_dict = {}
+    for column, column_schema in zip(rb, rb.schema):
+        # zero_copy_only=False because of nans. Ideally nans should have been imputed in feature engineering.
+        array = column.to_numpy(zero_copy_only=False)
+        batch_dict[column_schema.name] = array
+    return batch_dict
+
+
+def _retryable_batches(
+    dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
+) -> Iterator[pa.RecordBatch]:
+    """Make the Dataset to_batches retryable."""
+    retries = 0
+    current_batch_index = 0
+
+    while True:
+        try:
+            for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
+                if batch_index < current_batch_index:
+                    # Skip batches that have already been processed
+                    continue
+
+                yield batch
+                current_batch_index = batch_index + 1
+            # Exit the loop once all batches are processed
+            break
+
+        except Exception as e:
+            if retries < max_retries:
+                retries += 1
+                logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
+                time.sleep(delay)
+            else:
+                raise e
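
A hedged sketch of driving the new ingestor directly; in practice it is typically constructed for you via snowflake.ml.data.data_connector.DataConnector. The query text is a placeholder, an existing Snowpark session is assumed, and DataFrameInfo is assumed to accept the query via its sql field, as the helper functions suggest:

    from snowflake.ml.data import data_source
    from snowflake.ml.data._internal import arrow_ingestor

    src = data_source.DataFrameInfo(sql="SELECT * FROM MY_DB.MY_SCHEMA.TRAINING_DATA")  # assumed field name
    ingestor = arrow_ingestor.ArrowIngestor(session, [src])

    for batch in ingestor.to_batches(batch_size=1024, shuffle=True, drop_last_batch=True):
        # Each batch is a Dict[str, np.ndarray] keyed by column name.
        pass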

snowflake/ml/data/_internal/ingestor_utils.py (new file)
@@ -0,0 +1,58 @@
+from typing import List, Optional
+
+import fsspec
+
+from snowflake import snowpark
+from snowflake.connector import result_batch
+from snowflake.ml.data import data_source
+from snowflake.ml.fileset import snowfs
+
+_TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
+
+
+def get_dataframe_result_batches(
+    session: snowpark.Session, df_info: data_source.DataFrameInfo
+) -> List[result_batch.ResultBatch]:
+    cursor = session._conn._cursor
+
+    if df_info.query_id:
+        query_id = df_info.query_id
+    else:
+        query_id = session.sql(df_info.sql).collect_nowait().query_id
+
+    # TODO: Check if query result cache is still live
+    cursor.get_results_from_sfqid(sfqid=query_id)
+
+    # Prefetch hook should be set by `get_results_from_sfqid`
+    # This call blocks until the query results are ready
+    if cursor._prefetch_hook is None:
+        raise RuntimeError("Loading data from result query failed unexpectedly. Please contact Snowflake support.")
+    cursor._prefetch_hook()
+    batches = cursor.get_result_batches()
+    if batches is None:
+        raise ValueError(
+            "Failed to retrieve training data. Query status:" f" {session._conn._conn.get_query_status(query_id)}"
+        )
+    return batches
+
+
+def get_dataset_filesystem(
+    session: snowpark.Session, ds_info: Optional[data_source.DatasetInfo] = None
+) -> fsspec.AbstractFileSystem:
+    # We can't directly load the Dataset to avoid a circular dependency
+    # Dataset -> DatasetReader -> DataConnector -> DataIngestor -> (?) ingestor_utils -> Dataset
+    # TODO: Automatically pick appropriate fsspec implementation based on protocol in URL
+    return snowfs.SnowFileSystem(
+        snowpark_session=session,
+        cache_type="bytes",
+        block_size=2 * _TARGET_FILE_SIZE,
+    )
+
+
+def get_dataset_files(
+    session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
+) -> List[str]:
+    if filesystem is None:
+        filesystem = get_dataset_filesystem(session, ds_info)
+    assert bool(ds_info.url)  # Not null or empty
+    return sorted(filesystem.ls(ds_info.url))
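
A small sketch of using the filesystem helper on its own; the snow:// URL is illustrative (in practice it comes from DatasetInfo.url), and an existing Snowpark session is assumed:

    from snowflake.ml.data._internal import ingestor_utils

    fs = ingestor_utils.get_dataset_filesystem(session)
    # Illustrative dataset-version URL; real values are provided by DatasetInfo.url.
    files = sorted(fs.ls("snow://dataset/MY_DB.MY_SCHEMA.MY_DATASET/versions/v1"))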