snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -1
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +281 -21
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +12 -85
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +38 -2
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
- snowflake/ml/data/_internal/ingestor_utils.py +58 -0
- snowflake/ml/data/data_connector.py +133 -0
- snowflake/ml/data/data_ingestor.py +28 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/dataset/dataset.py +39 -32
- snowflake/ml/dataset/dataset_reader.py +18 -118
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
- snowflake/ml/feature_store/examples/example_helper.py +240 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
- snowflake/ml/feature_store/feature_store.py +987 -264
- snowflake/ml/feature_store/feature_view.py +228 -13
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +24 -18
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +45 -2
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_model_composer/model_composer.py +15 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
- snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
- snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +77 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
- snowflake/ml/modeling/cluster/birch.py +4 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
- snowflake/ml/modeling/cluster/dbscan.py +4 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
- snowflake/ml/modeling/cluster/k_means.py +4 -2
- snowflake/ml/modeling/cluster/mean_shift.py +4 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
- snowflake/ml/modeling/cluster/optics.py +4 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
- snowflake/ml/modeling/compose/column_transformer.py +4 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
- snowflake/ml/modeling/covariance/oas.py +4 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/pca.py +4 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
- snowflake/ml/modeling/impute/knn_imputer.py +4 -2
- snowflake/ml/modeling/impute/missing_indicator.py +4 -2
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
- snowflake/ml/modeling/manifold/isomap.py +4 -2
- snowflake/ml/modeling/manifold/mds.py +4 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
- snowflake/ml/modeling/manifold/tsne.py +4 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
- snowflake/ml/modeling/pipeline/pipeline.py +5 -4
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
- snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/container_services/image_registry/imagelib.py

@@ -13,6 +13,7 @@ This library only supports a limited set of features:
 It's recommended to use this library to copy previously tested images using sha256 to avoid surprises
 with respect to compatibility.
 """
+
 import dataclasses
 import hashlib
 import io
@@ -152,7 +153,8 @@ class BlobTransfer:
     src_image: ImageDescriptor
     dest_image: ImageDescriptor
     manifest: Manifest
-
+    src_image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient
+    dest_image_registry_http_client: image_registry_http_client.ImageRegistryHttpClient
 
     def upload_all_blobs(self) -> None:
         blob_digests = self.manifest.get_blob_digests()
@@ -169,7 +171,7 @@ class BlobTransfer:
         """
         Check if the blob already exists in the destination registry.
         """
-        resp = self.
+        resp = self.dest_image_registry_http_client.head(self.dest_image.blob_link(blob_digest), headers={})
         return resp.status_code != 200
 
     def _fetch_blob(self, blob_digest: str) -> Tuple[io.BytesIO, int]:
@@ -178,7 +180,7 @@ class BlobTransfer:
         """
         src_blob_link = self.src_image.blob_link(blob_digest)
         headers = {_CONTENT_LENGTH_HEADER: "0"}
-        resp = self.
+        resp = self.src_image_registry_http_client.get(src_blob_link, headers=headers)
 
         assert resp.status_code == 200, f"Blob GET failed with code {resp.status_code}"
         assert _CONTENT_LENGTH_HEADER in resp.headers, f"Blob does not contain {_CONTENT_LENGTH_HEADER}"
@@ -189,7 +191,7 @@ class BlobTransfer:
         """
         Obtain the upload URL from the destination registry.
         """
-        response = self.
+        response = self.dest_image_registry_http_client.post(self.dest_image.blob_upload_link())
         assert (
             response.status_code == 202
         ), f"Failed to get the upload URL to destination. Status {response.status_code}. {str(response.content)}"
@@ -216,14 +218,14 @@ class BlobTransfer:
             headers[_CONTENT_RANGE_HEADER] = f"{start_byte}-{end_byte}"
             headers[_CONTENT_LENGTH_HEADER] = str(chunk_length)
 
-            resp = self.
+            resp = self.dest_image_registry_http_client.patch(next_loc, headers=headers, data=chunk)
             assert resp.status_code == 202, f"Blob PATCH failed with code {resp.status_code}"
 
             next_loc = resp.headers[_LOCATION_HEADER]
             start_byte += chunk_length
 
         # Finalize the upload
-        resp = self.
+        resp = self.dest_image_registry_http_client.put(f"{next_loc}&digest={blob_digest}")
         assert resp.status_code == 201, f"Blob PUT failed with code {resp.status_code}"
 
     def _transfer(self, blob_digest: str) -> None:
@@ -340,21 +342,32 @@ def copy_image(
     src_image: ImageDescriptor,
     dest_image: ImageDescriptor,
     arch: _Arch,
-
+    src_retryable_http: image_registry_http_client.ImageRegistryHttpClient,
+    dest_retryable_http: image_registry_http_client.ImageRegistryHttpClient,
 ) -> None:
     logger.debug(f"Pulling image manifest for {src_image}")
 
     # 1. Get the manifest
-    manifest = get_manifest(src_image, arch,
+    manifest = get_manifest(src_image, arch, src_retryable_http)
     logger.debug(f"Manifest pulled for {src_image} with digest {manifest.manifest_digest}")
 
     # 2: Retrieve all blob digests from manifest; fetch blob based on blob digest, then upload blob.
-    blob_transfer = BlobTransfer(
+    blob_transfer = BlobTransfer(
+        src_image,
+        dest_image,
+        manifest,
+        src_image_registry_http_client=src_retryable_http,
+        dest_image_registry_http_client=dest_retryable_http,
+    )
     blob_transfer.upload_all_blobs()
 
     # 3. Upload the manifest
     logger.debug(f"All blobs copied successfully. Copying manifest for {src_image} to {dest_image}")
-    put_manifest(
+    put_manifest(
+        dest_image,
+        manifest,
+        dest_retryable_http,
+    )
 
     logger.debug(f"Image {src_image} copied to {dest_image}")
 
snowflake/ml/_internal/container_services/image_registry/registry_client.py

@@ -201,6 +201,12 @@ class ImageRegistryClient:
         )
         # TODO[shchen]: Remove the imagelib, instead rely on the copy image system function later.
         imagelib.copy_image(
-            src_image=src_image,
+            src_image=src_image,
+            dest_image=dest_image,
+            arch=arch,
+            src_retryable_http=image_registry_http_client.ImageRegistryHttpClient(
+                repo_url=src_image.registry_name, no_cred=True
+            ),
+            dest_retryable_http=self.image_registry_http_client,
         )
         logger.info("Image copy completed successfully")
snowflake/ml/_internal/exceptions/dataset_errors.py

@@ -1,11 +1,11 @@
 # Error code from Snowflake Python Connector.
-ERRNO_OBJECT_ALREADY_EXISTS =
-ERRNO_OBJECT_NOT_EXIST =
-ERRNO_FILES_ALREADY_EXISTING =
-ERRNO_VERSION_ALREADY_EXISTS =
-ERRNO_DATASET_NOT_EXIST =
-ERRNO_DATASET_VERSION_NOT_EXIST =
-ERRNO_DATASET_VERSION_ALREADY_EXISTS =
+ERRNO_OBJECT_ALREADY_EXISTS = 2002
+ERRNO_OBJECT_NOT_EXIST = 2043
+ERRNO_FILES_ALREADY_EXISTING = 1030
+ERRNO_VERSION_ALREADY_EXISTS = 92917
+ERRNO_DATASET_NOT_EXIST = 399019
+ERRNO_DATASET_VERSION_NOT_EXIST = 399012
+ERRNO_DATASET_VERSION_ALREADY_EXISTS = 399020
 
 
 class DatasetError(Exception):
snowflake/ml/_internal/exceptions/fileset_errors.py

@@ -1,7 +1,7 @@
 # Error code from Snowflake Python Connector.
-ERRNO_FILE_EXIST_IN_STAGE =
-ERRNO_DOMAIN_NOT_EXIST =
-ERRNO_STAGE_NOT_EXIST =
+ERRNO_FILE_EXIST_IN_STAGE = 1030
+ERRNO_DOMAIN_NOT_EXIST = 2003
+ERRNO_STAGE_NOT_EXIST = 391707
 
 
 class FileSetError(Exception):
snowflake/ml/_internal/lineage/lineage_utils.py

@@ -1,9 +1,9 @@
 import copy
 import functools
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, get_args
 
 from snowflake import snowpark
-from snowflake.ml.
+from snowflake.ml.data import data_source
 
 _DATA_SOURCES_ATTR = "_data_sources"
 
@@ -39,7 +39,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
     result: Optional[List[data_source.DataSource]] = None
     for arg in args:
         srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
-        if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
+        if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
             if result is None:
                 result = []
             result += srcs
@@ -49,7 +49,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
 def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
     """Helper method for attaching data sources to an object"""
     if data_sources:
-        assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
+        assert all(isinstance(ds, get_args(data_source.DataSource)) for ds in data_sources)
         setattr(obj, _DATA_SOURCES_ATTR, data_sources)
 
 
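These two hunks switch the isinstance checks to get_args(data_source.DataSource), which implies DataSource is now a typing alias (a Union of concrete source types, plus plain path strings judging from the arrow_ingestor hunk further below) rather than a single base class; isinstance cannot take a Union alias directly, so its member types are unpacked first. A minimal sketch of the pattern; the alias shape and member classes below are assumptions, not taken from the diff:

from dataclasses import dataclass
from typing import Union, get_args

@dataclass
class DataFrameInfo:
    sql: str

@dataclass
class DatasetInfo:
    url: str

# Assumed shape of data_source.DataSource in 1.6.0: a Union alias, not a base class.
DataSource = Union[DataFrameInfo, DatasetInfo, str]

srcs = [DatasetInfo(url="snow://dataset/MY_DS/versions/v1"), "stage/path/file.parquet"]
# isinstance(s, DataSource) would raise TypeError; unpacking the Union works:
assert all(isinstance(s, get_args(DataSource)) for s in srcs)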
snowflake/ml/_internal/telemetry.py

@@ -10,6 +10,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     TypeVar,
@@ -92,6 +93,31 @@ def get_statement_params(
     )
 
 
+def add_statement_params_custom_tags(
+    statement_params: Optional[Dict[str, Any]], custom_tags: Mapping[str, Any]
+) -> Dict[str, Any]:
+    """
+    Add custom_tags to existing statement_params. Overwrite keys in custom_tags dict that already exist.
+    If existing statement_params are not provided, do nothing as the information cannot be effectively tracked.
+
+    Args:
+        statement_params: Existing statement_params dictionary.
+        custom_tags: Dictionary of existing k/v pairs to add as custom_tags
+
+    Returns:
+        new statement_params dictionary with all keys and an updated custom_tags field.
+    """
+    if not statement_params:
+        return {}
+    existing_custom_tags: Dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
+    existing_custom_tags.update(custom_tags)
+    # NOTE: This can be done with | operator after upgrade from py3.8
+    return {
+        **statement_params,
+        TelemetryField.KEY_CUSTOM_TAGS.value: existing_custom_tags,
+    }
+
+
 # TODO: we can merge this with get_statement_params after code clean up
 def get_statement_params_full_func_name(frame: Optional[types.FrameType], class_name: Optional[str] = None) -> str:
     """
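The new add_statement_params_custom_tags helper merges extra tags into the custom_tags entry of an existing statement_params dict and deliberately returns an empty dict when no statement_params are given. A small behavioral sketch; the literal "custom_tags" key below stands in for TelemetryField.KEY_CUSTOM_TAGS.value so the example stays self-contained, and that literal is an assumption:

from typing import Any, Dict, Mapping, Optional

CUSTOM_TAGS_KEY = "custom_tags"  # assumed value of TelemetryField.KEY_CUSTOM_TAGS.value

def add_statement_params_custom_tags(
    statement_params: Optional[Dict[str, Any]], custom_tags: Mapping[str, Any]
) -> Dict[str, Any]:
    # Mirrors the hunk above: a no-op (empty dict) when there is nothing to extend.
    if not statement_params:
        return {}
    existing = statement_params.pop(CUSTOM_TAGS_KEY, {})
    existing.update(custom_tags)
    return {**statement_params, CUSTOM_TAGS_KEY: existing}

params = {"project": "MLOps", CUSTOM_TAGS_KEY: {"team": "feature_store"}}
merged = add_statement_params_custom_tags(params, {"entity": "trip"})
assert merged[CUSTOM_TAGS_KEY] == {"team": "feature_store", "entity": "trip"}
assert add_statement_params_custom_tags(None, {"entity": "trip"}) == {}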
@@ -251,6 +277,7 @@ def send_api_usage_telemetry(
         ]
     ] = None,
     sfqids_extractor: Optional[Callable[..., List[str]]] = None,
+    subproject_extractor: Optional[Callable[[Any], str]] = None,
     custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
     """
@@ -264,6 +291,7 @@ def send_api_usage_telemetry(
         conn_attr_name: Name of the SnowflakeConnection attribute in `self`.
         api_calls_extractor: Extract API calls from `self`.
         sfqids_extractor: Extract sfqids from `self`.
+        subproject_extractor: Extract subproject at runtime from `self`.
         custom_tags: Custom tags.
 
     Returns:
@@ -271,10 +299,14 @@ def send_api_usage_telemetry(
 
     Raises:
         TypeError: If `conn_attr_name` is provided but the conn attribute is not of type SnowflakeConnection.
+        ValueError: If both `subproject` and `subproject_extractor` are provided
 
     # noqa: DAR402
     """
 
+    if subproject is not None and subproject_extractor is not None:
+        raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
+
     def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
         @functools.wraps(func)
         def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
@@ -296,9 +328,13 @@ def send_api_usage_telemetry(
             if sfqids_extractor:
                 sfqids = sfqids_extractor(args[0])
 
+            subproject_name = subproject
+            if subproject_extractor is not None:
+                subproject_name = subproject_extractor(args[0])
+
             statement_params = get_function_usage_statement_params(
                 project=project,
-                subproject=
+                subproject=subproject_name,
                 function_category=TelemetryField.FUNC_CAT_USAGE.value,
                 function_name=_get_full_func_name(func),
                 function_parameters=params,
@@ -355,7 +391,7 @@ def send_api_usage_telemetry(
                 raise e.original_exception from e
 
             # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
-            telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=
+            telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject_name)
             telemetry_args = dict(
                 func_name=_get_full_func_name(func),
                 function_category=TelemetryField.FUNC_CAT_USAGE.value,
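Taken together, these hunks let the telemetry decorator resolve the subproject from the decorated object at call time via subproject_extractor, and passing both subproject and subproject_extractor is now rejected with ValueError. A hedged usage sketch; only the decorator keyword arguments come from the diff, the client class and project name are illustrative:

from snowflake.ml._internal import telemetry

_PROJECT = "ModelRegistry"  # illustrative project name

class _Client:
    def __init__(self, subproject: str) -> None:
        self._subproject = subproject

    @telemetry.send_api_usage_telemetry(
        project=_PROJECT,
        # Resolved per call from `self`, instead of a fixed string:
        subproject_extractor=lambda self: self._subproject,
    )
    def log_model(self, name: str) -> None:
        ...

# Supplying both a static subproject and an extractor raises at decoration time:
#   telemetry.send_api_usage_telemetry(project=_PROJECT, subproject="A",
#                                      subproject_extractor=lambda self: "B")
#   -> ValueError: Specifying both subproject and subproject_extractor is not allowed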
snowflake/ml/_internal/utils/identifier.py

@@ -165,6 +165,20 @@ def parse_schema_level_object_identifier(
     )
 
 
+def is_fully_qualified_name(name: str) -> bool:
+    """
+    Checks if a given name is a fully qualified name, which is in the format '<db>.<schema>.<object_name>'.
+
+    Args:
+        name: The name to be checked.
+
+    Returns:
+        bool: True if the name is fully qualified, False otherwise.
+    """
+    res = parse_schema_level_object_identifier(name)
+    return res[0] is not None and res[1] is not None and res[2] is not None and not res[3]
+
+
 def get_schema_level_object_identifier(
     db: Optional[str],
     schema: Optional[str],
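is_fully_qualified_name simply reuses parse_schema_level_object_identifier and reports whether the database, schema, and object parts are all present with nothing trailing. Expected behavior, assuming the package is installed; the identifiers are arbitrary examples:

from snowflake.ml._internal.utils import identifier

identifier.is_fully_qualified_name("MY_DB.MY_SCHEMA.MY_MODEL")  # expected: True
identifier.is_fully_qualified_name("MY_SCHEMA.MY_MODEL")        # expected: False (no database part)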
snowflake/ml/_internal/utils/snowpark_dataframe_utils.py

@@ -1,22 +1,27 @@
 import logging
 import warnings
+from typing import List, Optional
 
 from snowflake import snowpark
+from snowflake.ml._internal.utils import sql_identifier
 from snowflake.snowpark import functions, types
 
 
-def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
+def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[List[str]] = None) -> snowpark.DataFrame:
     """Cast columns in the dataframe to types that are compatible with tensor.
 
     It assists FileSet.make() in performing implicit data casting.
 
     Args:
         df: A snowpark dataframe.
+        ignore_columns: Columns to exclude from casting. These columns will be propagated unchanged.
 
     Returns:
         A snowpark dataframe whose data type has been casted.
     """
 
+    ignore_cols_set = {sql_identifier.SqlIdentifier(c).identifier() for c in ignore_columns} if ignore_columns else {}
+
     fields = df.schema.fields
     selected_cols = []
     for field in fields:
@@ -40,7 +45,9 @@ def cast_snowpark_dataframe(df: snowpark.DataFrame) -> snowpark.DataFrame:
             dest = field.datatype
             selected_cols.append(functions.cast(functions.col(src), dest).alias(src))
         else:
-            if field.
+            if field.column_identifier.name in ignore_cols_set:
+                pass
+            elif field.datatype in (types.DateType(), types.TimestampType(), types.TimeType()):
                 logging.warning(
                     "A Column with DATE or TIMESTAMP data type detected. "
                     "It might not be able to get converted to tensors. "
@@ -90,7 +97,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.Dat
                 " is being automatically converted to DoubleType in the Snowpark DataFrame. "
                 "This automatic conversion may lead to potential precision loss and rounding errors. "
                 "If you wish to prevent this conversion, you should manually perform "
-                "the necessary data type conversion."
+                "the necessary data type conversion.",
+                UserWarning,
+                stacklevel=2,
             )
         else:
             # IntegerType default as NUMBER(38, 0), but
@@ -102,7 +111,9 @@ def cast_snowpark_dataframe_column_types(df: snowpark.DataFrame) -> snowpark.Dat
                 " is being automatically converted to LongType in the Snowpark DataFrame. "
                 "This automatic conversion may lead to potential precision loss and rounding errors. "
                 "If you wish to prevent this conversion, you should manually perform "
-                "the necessary data type conversion."
+                "the necessary data type conversion.",
+                UserWarning,
+                stacklevel=2,
             )
         selected_cols.append(functions.cast(functions.col(src), dest_dtype).alias(src))
         # TODO: add more type handling or error message
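With the new ignore_columns argument, callers can keep selected columns out of the tensor-compatibility casting, and the conversion warnings now carry an explicit UserWarning category and stacklevel. A hedged usage sketch, assuming an open Snowpark session; the connection parameters, table, and column names are illustrative:

from snowflake.snowpark import Session
from snowflake.ml._internal.utils import snowpark_dataframe_utils

# `connection_parameters` is assumed to be your Snowflake connection dict.
session = Session.builder.configs(connection_parameters).create()

df = session.table("TRIPS")  # illustrative table with a TIMESTAMP column STARTED_AT
casted = snowpark_dataframe_utils.cast_snowpark_dataframe(df, ignore_columns=["STARTED_AT"])
# STARTED_AT is propagated unchanged; the remaining columns are cast as before.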
snowflake/ml/data/_internal/arrow_ingestor.py (new file)

@@ -0,0 +1,228 @@
+import collections
+import logging
+import os
+import time
+from typing import Any, Deque, Dict, Iterator, List, Optional
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import pyarrow as pa
+import pyarrow.dataset as ds
+
+from snowflake import snowpark
+from snowflake.ml.data import data_ingestor, data_source
+from snowflake.ml.data._internal import ingestor_utils
+
+_EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
+
+# The row count for batches read from PyArrow Dataset. This number should be large enough so that
+# dataset.to_batches() would read in a very large portion of, if not entirely, a parquet file.
+_DEFAULT_DATASET_BATCH_SIZE = 1000000
+
+
+class _RecordBatchesBuffer:
+    """A queue that stores record batches and tracks the total num of rows in it."""
+
+    def __init__(self) -> None:
+        self.buffer: Deque[pa.RecordBatch] = collections.deque()
+        self.num_rows = 0
+
+    def append(self, rb: pa.RecordBatch) -> None:
+        self.buffer.append(rb)
+        self.num_rows += rb.num_rows
+
+    def appendleft(self, rb: pa.RecordBatch) -> None:
+        self.buffer.appendleft(rb)
+        self.num_rows += rb.num_rows
+
+    def popleft(self) -> pa.RecordBatch:
+        popped = self.buffer.popleft()
+        self.num_rows -= popped.num_rows
+        return popped
+
+
+class ArrowIngestor(data_ingestor.DataIngestor):
+    """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
+
+    def __init__(
+        self,
+        session: snowpark.Session,
+        data_sources: List[data_source.DataSource],
+        format: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Args:
+            session: The Snowpark Session to use.
+            data_sources: List of data sources to ingest.
+            format: Currently “parquet”, “ipc”/”arrow”/”feather”, “csv”, “json”, and “orc” are supported.
+                Will be inferred if not specified.
+            kwargs: Miscellaneous arguments passed to underlying PyArrow Dataset initializer.
+        """
+        self._session = session
+        self._data_sources = data_sources
+        self._format = format
+        self._kwargs = kwargs
+
+        self._schema: Optional[pa.Schema] = None
+
+    @property
+    def data_sources(self) -> List[data_source.DataSource]:
+        return self._data_sources
+
+    def to_batches(
+        self,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last_batch: bool = True,
+    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+        """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
+
+        As we are generating batches with the exactly same length, the last few rows in each file might get left as they
+        are not long enough to form a batch. These rows will be put into a temporary buffer and combine with the first
+        few rows of the next file to generate a new batch.
+
+        Args:
+            batch_size: Specifies the size of each batch that will be yield
+            shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
+                the order of files, and then shuflle the order of rows in each file.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last
+                batch will get dropped if its size is smaller than the given batch_size.
+
+        Yields:
+            A dict mapping column names to the corresponding data fetch from that column.
+        """
+        self._rb_buffer = _RecordBatchesBuffer()
+
+        # Extract schema if not already known
+        dataset = self._get_dataset(shuffle)
+        if self._schema is None:
+            self._schema = dataset.schema
+
+        for rb in _retryable_batches(dataset, batch_size=max(_DEFAULT_DATASET_BATCH_SIZE, batch_size)):
+            if shuffle:
+                rb = rb.take(np.random.permutation(rb.num_rows))
+            self._rb_buffer.append(rb)
+            while self._rb_buffer.num_rows >= batch_size:
+                yield self._get_batches_from_buffer(batch_size)
+
+        if self._rb_buffer.num_rows and not drop_last_batch:
+            yield self._get_batches_from_buffer(batch_size)
+
+    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
+        ds = self._get_dataset(shuffle=False)
+        table = ds.to_table() if limit is None else ds.head(num_rows=limit)
+        return table.to_pandas()
+
+    def _get_dataset(self, shuffle: bool) -> ds.Dataset:
+        format = self._format
+        sources = []
+        source_format = None
+        for source in self._data_sources:
+            if isinstance(source, str):
+                sources.append(source)
+                source_format = format or os.path.splitext(source)[-1]
+            elif isinstance(source, data_source.DatasetInfo):
+                if not self._kwargs.get("filesystem"):
+                    self._kwargs["filesystem"] = ingestor_utils.get_dataset_filesystem(self._session, source)
+                sources.extend(
+                    ingestor_utils.get_dataset_files(self._session, source, filesystem=self._kwargs["filesystem"])
+                )
+                source_format = "parquet"
+            elif isinstance(source, data_source.DataFrameInfo):
+                # FIXME: This currently loads all result batches into memory so that it
+                # can be passed into pyarrow.dataset as a list/tuple of pa.RecordBatches
+                # We may be able to optimize this by splitting the result batches into
+                # in-memory (first batch) and file URLs (subsequent batches) and creating a
+                # union dataset.
+                result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source)
+                sources.extend(b.to_arrow() for b in result_batches)
+                source_format = "arrow"
+            else:
+                raise RuntimeError(f"Unsupported data source type: {type(source)}")
+
+            # Make sure source types not mixed
+            if format and format != source_format:
+                raise RuntimeError(f"Unexpected data source format (expected {format}, found {source_format})")
+            format = source_format
+
+        # Re-shuffle input files on each iteration start
+        if shuffle:
+            np.random.shuffle(sources)
+        pa_dataset: ds.Dataset = ds.dataset(sources, format=format, **self._kwargs)
+        return pa_dataset
+
+    def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
+        """Generate new batches from the existing record batch buffer."""
+        cnt_rbs_num_rows = 0
+        candidates = []
+
+        # Keep popping record batches in buffer until there are enough rows for a batch.
+        while self._rb_buffer.num_rows and cnt_rbs_num_rows < batch_size:
+            candidate = self._rb_buffer.popleft()
+            cnt_rbs_num_rows += candidate.num_rows
+            candidates.append(candidate)
+
+        # When there are more rows than needed, slice the last popped batch to fit batch_size.
+        if cnt_rbs_num_rows > batch_size:
+            row_diff = cnt_rbs_num_rows - batch_size
+            slice_target = candidates[-1]
+            cut_off = slice_target.num_rows - row_diff
+            to_merge = slice_target.slice(length=cut_off)
+            left_over = slice_target.slice(offset=cut_off)
+            candidates[-1] = to_merge
+            self._rb_buffer.appendleft(left_over)
+
+        res = _merge_record_batches(candidates)
+        return _record_batch_to_arrays(res)
+
+
+def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
+    """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
+    if not record_batches:
+        return _EMPTY_RECORD_BATCH
+    if len(record_batches) == 1:
+        return record_batches[0]
+    record_batches = list(filter(lambda rb: rb.num_rows > 0, record_batches))
+    one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks()
+    batches = one_chunk_table.to_batches(max_chunksize=None)
+    return batches[0]
+
+
+def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
+    """Transform the record batch to a (string, numpy array) dict."""
+    batch_dict = {}
+    for column, column_schema in zip(rb, rb.schema):
+        # zero_copy_only=False because of nans. Ideally nans should have been imputed in feature engineering.
+        array = column.to_numpy(zero_copy_only=False)
+        batch_dict[column_schema.name] = array
+    return batch_dict
+
+
+def _retryable_batches(
+    dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
+) -> Iterator[pa.RecordBatch]:
+    """Make the Dataset to_batches retryable."""
+    retries = 0
+    current_batch_index = 0
+
+    while True:
+        try:
+            for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
+                if batch_index < current_batch_index:
+                    # Skip batches that have already been processed
+                    continue
+
+                yield batch
+                current_batch_index = batch_index + 1
+            # Exit the loop once all batches are processed
+            break
+
+        except Exception as e:
+            if retries < max_retries:
+                retries += 1
+                logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
+                time.sleep(delay)
+            else:
+                raise e
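The new ArrowIngestor turns Snowpark query results or Dataset files into a PyArrow Dataset and yields fixed-size numpy batches, carrying leftover rows across files in a small record-batch buffer and retrying dataset.to_batches() on transient errors. A hedged sketch of driving it directly; in practice it is presumably consumed through the new snowflake/ml/data/data_connector.py, and the DataFrameInfo constructor arguments shown here are assumptions:

from snowflake.snowpark import Session
from snowflake.ml.data import data_source
from snowflake.ml.data._internal import arrow_ingestor

# `connection_parameters` is assumed to be your Snowflake connection dict.
session = Session.builder.configs(connection_parameters).create()

# Assumption: DataFrameInfo wraps the SQL text of the query to ingest.
src = data_source.DataFrameInfo(sql="SELECT * FROM MY_TRAINING_TABLE")
ingestor = arrow_ingestor.ArrowIngestor(session, [src])

for batch in ingestor.to_batches(batch_size=1024, shuffle=True, drop_last_batch=True):
    # Each batch maps column names to numpy arrays of exactly 1024 rows.
    print({name: arr.shape for name, arr in batch.items()})
    break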
snowflake/ml/data/_internal/ingestor_utils.py (new file)

@@ -0,0 +1,58 @@
+from typing import List, Optional
+
+import fsspec
+
+from snowflake import snowpark
+from snowflake.connector import result_batch
+from snowflake.ml.data import data_source
+from snowflake.ml.fileset import snowfs
+
+_TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
+
+
+def get_dataframe_result_batches(
+    session: snowpark.Session, df_info: data_source.DataFrameInfo
+) -> List[result_batch.ResultBatch]:
+    cursor = session._conn._cursor
+
+    if df_info.query_id:
+        query_id = df_info.query_id
+    else:
+        query_id = session.sql(df_info.sql).collect_nowait().query_id
+
+    # TODO: Check if query result cache is still live
+    cursor.get_results_from_sfqid(sfqid=query_id)
+
+    # Prefetch hook should be set by `get_results_from_sfqid`
+    # This call blocks until the query results are ready
+    if cursor._prefetch_hook is None:
+        raise RuntimeError("Loading data from result query failed unexpectedly. Please contact Snowflake support.")
+    cursor._prefetch_hook()
+    batches = cursor.get_result_batches()
+    if batches is None:
+        raise ValueError(
+            "Failed to retrieve training data. Query status:" f" {session._conn._conn.get_query_status(query_id)}"
+        )
+    return batches
+
+
+def get_dataset_filesystem(
+    session: snowpark.Session, ds_info: Optional[data_source.DatasetInfo] = None
+) -> fsspec.AbstractFileSystem:
+    # We can't directly load the Dataset to avoid a circular dependency
+    # Dataset -> DatasetReader -> DataConnector -> DataIngestor -> (?) ingestor_utils -> Dataset
+    # TODO: Automatically pick appropriate fsspec implementation based on protocol in URL
+    return snowfs.SnowFileSystem(
+        snowpark_session=session,
+        cache_type="bytes",
+        block_size=2 * _TARGET_FILE_SIZE,
+    )
+
+
+def get_dataset_files(
+    session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
+) -> List[str]:
+    if filesystem is None:
+        filesystem = get_dataset_filesystem(session, ds_info)
+    assert bool(ds_info.url)  # Not null or empty
+    return sorted(filesystem.ls(ds_info.url))
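ingestor_utils resolves a DataFrameInfo into connector ResultBatches (blocking on the cursor's prefetch hook until results are ready) and a DatasetInfo into the sorted list of stage files behind it, using the internal SnowFileSystem with 64 MiB byte-range caching. A hedged sketch of the Dataset path; the DatasetInfo instance and the snow:// URL format in the comment are assumptions:

from snowflake.ml.data._internal import ingestor_utils

# `session` is an open snowflake.snowpark.Session; `ds_info` is a
# data_source.DatasetInfo describing a Dataset version (its url field is assumed
# to look like "snow://dataset/MY_DB.MY_SCHEMA.MY_DATASET/versions/v1").
fs = ingestor_utils.get_dataset_filesystem(session, ds_info)
files = ingestor_utils.get_dataset_files(session, ds_info, filesystem=fs)
# `files` lists the parquet files that ArrowIngestor feeds into pyarrow.dataset().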