snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +101 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +24 -18
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/model/model_version_impl.py +5 -3
- snowflake/ml/model/_client/ops/model_ops.py +13 -8
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +3 -3
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +16 -178
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
- snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +41 -22
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
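The largest addition in this release is the new snowflake/ml/jobs package (decorators.py, job.py, manager.py, and the supporting _utils modules listed above), which appears to be a preview API for submitting Python workloads to Snowpark Container Services compute pools. A hedged usage sketch follows; the decorator name, its parameters, and the job-handle methods are assumptions inferred from the module layout, since the diff bodies are not shown here:

```python
# Hypothetical sketch of the new snowflake.ml.jobs preview API. The package and
# file names come from the file list above; the exact decorator signature and
# the job-handle methods below are assumptions, not confirmed by this diff.
from snowflake.ml import jobs

@jobs.remote("MY_COMPUTE_POOL", stage_name="payload_stage")  # assumed signature
def train_model(data_table: str) -> None:
    ...  # body is packaged (payload_utils.py) and run remotely on the pool

job = train_model("MY_DB.MY_SCHEMA.TRAINING_DATA")  # assumed: returns a handle (jobs/job.py)
print(job.get_logs())  # assumed log-retrieval helper on the handle
```

The six deleted modules listed above (the +0 -N entries) are shown in full below.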
--- a/snowflake/ml/_internal/utils/retryable_http.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import http
-
-import requests
-from requests import adapters
-from urllib3.util import retry
-
-
-def get_http_client(total_retries: int = 5, backoff_factor: float = 0.1) -> requests.Session:
-    """Construct retryable http client.
-
-    Args:
-        total_retries: Total number of retries to allow.
-        backoff_factor: A backoff factor to apply between attempts after the second try. Time to sleep is calculated by
-            {backoff factor} * (2 ** ({number of previous retries})). For example, with default retries of 5 and backoff
-            factor set to 0.1, each subsequent retry will sleep [0.2s, 0.4s, 0.8s, 1.6s, 3.2s] respectively.
-
-    Returns:
-        requests.Session object.
-
-    """
-
-    retry_strategy = retry.Retry(
-        total=total_retries,
-        backoff_factor=backoff_factor,
-        status_forcelist=[
-            http.HTTPStatus.TOO_MANY_REQUESTS,
-            http.HTTPStatus.INTERNAL_SERVER_ERROR,
-            http.HTTPStatus.BAD_GATEWAY,
-            http.HTTPStatus.SERVICE_UNAVAILABLE,
-            http.HTTPStatus.GATEWAY_TIMEOUT,
-        ],  # retry on these status codes
-    )
-
-    adapter = adapters.HTTPAdapter(max_retries=retry_strategy)
-    req_session = requests.Session()
-    req_session.mount("https://", adapter)
-    req_session.mount("http://", adapter)
-
-    return req_session
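The sleep schedule quoted in the removed docstring is easy to verify. A minimal standalone sketch (not part of the package) that reproduces the documented formula and its example values:

```python
# Minimal sketch reproducing the backoff formula quoted in the removed
# docstring: sleep = backoff_factor * (2 ** number_of_previous_retries).
def backoff_schedule(total_retries: int = 5, backoff_factor: float = 0.1) -> list:
    return [backoff_factor * (2 ** n) for n in range(1, total_retries + 1)]

print(backoff_schedule())  # [0.2, 0.4, 0.8, 1.6, 3.2], matching the docstring
```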
--- a/snowflake/ml/fileset/parquet_parser.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import collections
-import logging
-import time
-from typing import Any, Deque, Dict, Iterator, List
-
-import fsspec
-import numpy as np
-import numpy.typing as npt
-import pyarrow as pa
-import pyarrow.dataset as ds
-
-_EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
-
-# The row count for batches read from PyArrow Dataset. This number should be large enough so that
-# dataset.to_batches() would read in a very large portion of, if not entirely, a parquet file.
-_DEFAULT_DATASET_BATCH_SIZE = 1000000
-
-
-class _RecordBatchesBuffer:
-    """A queue that stores record batches and tracks the total num of rows in it."""
-
-    def __init__(self) -> None:
-        self.buffer: Deque[pa.RecordBatch] = collections.deque()
-        self.num_rows = 0
-
-    def append(self, rb: pa.RecordBatch) -> None:
-        self.buffer.append(rb)
-        self.num_rows += rb.num_rows
-
-    def appendleft(self, rb: pa.RecordBatch) -> None:
-        self.buffer.appendleft(rb)
-        self.num_rows += rb.num_rows
-
-    def popleft(self) -> pa.RecordBatch:
-        popped = self.buffer.popleft()
-        self.num_rows -= popped.num_rows
-        return popped
-
-
-class ParquetParser:
-    """Read and parse the given parquet files and yield batched numpy array in dict.
-
-    Args:
-        file_paths: A list of parquet file URIs to read and parse.
-        filesystem: A fsspec/pyarrow file system that is used to open given file URIs.
-        batch_size: Specifies the size of each batch that will be yield
-        shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
-            the order of files, and then shuflle the order of rows in each file.
-        drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last batch will
-            get dropped if its size is smaller than the given batch_size.
-
-    Returns:
-        A PyTorch iterable datapipe that yields batched numpy array in dict. The keys will be the column names in
-        the parquet files, and the value will be the column value as a list.
-    """
-
-    def __init__(
-        self,
-        file_paths: List[str],
-        filesystem: fsspec.AbstractFileSystem,
-        batch_size: int,
-        shuffle: bool = True,
-        drop_last_batch: bool = True,
-    ) -> None:
-        self._file_paths = file_paths
-        self._fs = filesystem
-        self._batch_size = batch_size
-        self._dataset_batch_size = max(_DEFAULT_DATASET_BATCH_SIZE, self._batch_size)
-        self._shuffle = shuffle
-        self._drop_last_batch = drop_last_batch
-
-    def __iter__(self) -> Iterator[Dict[str, npt.NDArray[Any]]]:
-        """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
-
-        As we are generating batches with the exactly same length, the last few rows in each file might get left as they
-        are not long enough to form a batch. These rows will be put into a temporary buffer and combine with the first
-        few rows of the next file to generate a new batch.
-
-        Yields:
-            A dict mapping column names to the corresponding data fetch from that column.
-        """
-        self._rb_buffer = _RecordBatchesBuffer()
-        files = list(self._file_paths)
-        if self._shuffle:
-            np.random.shuffle(files)
-        pa_dataset: ds.Dataset = ds.dataset(files, format="parquet", filesystem=self._fs)
-
-        for rb in _retryable_batches(pa_dataset, batch_size=self._dataset_batch_size):
-            if self._shuffle:
-                rb = rb.take(np.random.permutation(rb.num_rows))
-            self._rb_buffer.append(rb)
-            while self._rb_buffer.num_rows >= self._batch_size:
-                yield self._get_batches_from_buffer()
-
-        if self._rb_buffer.num_rows and not self._drop_last_batch:
-            yield self._get_batches_from_buffer()
-
-    def _get_batches_from_buffer(self) -> Dict[str, npt.NDArray[Any]]:
-        """Generate new batches from the existing record batch buffer."""
-        cnt_rbs_num_rows = 0
-        candidates = []
-
-        # Keep popping record batches in buffer until there are enough rows for a batch.
-        while self._rb_buffer.num_rows and cnt_rbs_num_rows < self._batch_size:
-            candidate = self._rb_buffer.popleft()
-            cnt_rbs_num_rows += candidate.num_rows
-            candidates.append(candidate)
-
-        # When there are more rows than needed, slice the last popped batch to fit batch_size.
-        if cnt_rbs_num_rows > self._batch_size:
-            row_diff = cnt_rbs_num_rows - self._batch_size
-            slice_target = candidates[-1]
-            cut_off = slice_target.num_rows - row_diff
-            to_merge = slice_target.slice(length=cut_off)
-            left_over = slice_target.slice(offset=cut_off)
-            candidates[-1] = to_merge
-            self._rb_buffer.appendleft(left_over)
-
-        res = _merge_record_batches(candidates)
-        return _record_batch_to_arrays(res)
-
-
-def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
-    """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
-    if not record_batches:
-        return _EMPTY_RECORD_BATCH
-    if len(record_batches) == 1:
-        return record_batches[0]
-    record_batches = list(filter(lambda rb: rb.num_rows > 0, record_batches))
-    one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks()
-    batches = one_chunk_table.to_batches(max_chunksize=None)
-    return batches[0]
-
-
-def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
-    """Transform the record batch to a (string, numpy array) dict."""
-    batch_dict = {}
-    for column, column_schema in zip(rb, rb.schema):
-        # zero_copy_only=False because of nans. Ideally nans should have been imputed in feature engineering.
-        array = column.to_numpy(zero_copy_only=False)
-        batch_dict[column_schema.name] = array
-    return batch_dict
-
-
-def _retryable_batches(
-    dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
-) -> Iterator[pa.RecordBatch]:
-    """Make the Dataset to_batches retryable."""
-    retries = 0
-    current_batch_index = 0
-
-    while True:
-        try:
-            for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
-                if batch_index < current_batch_index:
-                    # Skip batches that have already been processed
-                    continue
-
-                yield batch
-                current_batch_index = batch_index + 1
-            # Exit the loop once all batches are processed
-            break
-
-        except Exception as e:
-            if retries < max_retries:
-                retries += 1
-                logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
-                time.sleep(delay)
-            else:
-                raise e
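The core idea in the removed ParquetParser is _get_batches_from_buffer: rows left over at the end of one file are buffered and merged with rows from the next file, so every yielded batch has exactly batch_size rows. A simplified standalone sketch of that idea (it pools everything into one table instead of streaming through a deque, which the real implementation avoids):

```python
# Simplified sketch of the exact-size batching idea above (assumes pyarrow is
# installed; the streaming deque buffer is replaced by one pooled table).
import pyarrow as pa

def fixed_size_batches(record_batches, batch_size, drop_last_batch=True):
    table = pa.Table.from_batches(record_batches).combine_chunks()
    for start in range(0, table.num_rows, batch_size):
        chunk = table.slice(start, batch_size)
        if chunk.num_rows < batch_size and drop_last_batch:
            return  # drop the short trailing batch
        yield {name: col.to_numpy() for name, col in zip(chunk.column_names, chunk.columns)}

rbs = [pa.RecordBatch.from_pydict({"x": list(range(6))}),     # "file 1": 6 rows
       pa.RecordBatch.from_pydict({"x": list(range(6, 9))})]  # "file 2": 3 rows
for batch in fixed_size_batches(rbs, batch_size=4, drop_last_batch=False):
    print(batch["x"])  # [0 1 2 3], then [4 5 6 7] (spanning both inputs), then [8]
```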
--- a/snowflake/ml/fileset/tf_dataset.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from typing import Any, Dict, Generator, List
-
-import fsspec
-import numpy.typing as npt
-import pyarrow as pa
-import pyarrow.parquet as pq
-import tensorflow as tf
-
-from snowflake.ml._internal.exceptions import (
-    error_codes,
-    exceptions as snowml_exceptions,
-)
-from snowflake.ml.fileset import parquet_parser
-
-
-def read_and_parse_parquet(
-    files: List[str],
-    filesystem: fsspec.AbstractFileSystem,
-    batch_size: int,
-    shuffle: bool,
-    drop_last_batch: bool,
-) -> tf.data.Dataset:
-    """Creates a tf.data.Dataset that reads given parquet files into batched Tensors.
-
-    Args:
-        files: A list of input parquet file URIs to read and parse. The parquet files should
-            have the same schema.
-        filesystem: A fsspec/pyarrow file system that is used to open given file URIs.
-        batch_size: Specifies the size of each batch that will be yield. It is preferred to
-            set it to your training batch size, and avoid using dataset.{batch(),rebatch()} later.
-        shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
-            the order of files, and then shuflle the order of rows in each file. It is preferred
-            to shuffle the data this way than dataset.unbatch().shuffle().rebatch().
-        drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last batch will
-            get dropped if its size is smaller than the given batch_size.
-
-    Returns:
-        A tf.data.Dataset generates batched Tensors in a dict. The keys will be the column names in
-        the parquet files.
-
-    Raises:
-        SnowflakeMLException: if `files` is empty.
-
-    Example:
-    >>> from snowflake.ml.fileset import sfcfs, tf_dataset
-    >>> conn = snowflake.connector.connect(**connection_parameters)
-    >>> fs = sfcfs.SFFileSystem(conn)
-    >>> files = fs.ls(dir_path)
-    >>> ds = tf_dataset.parse_and_read_parquet(files, fs, batch_size = 2)
-    >>> for batch in ds:
-    >>>     print(batch)
-    ----
-    {'_COL_1': <tf.Tensor: shape=(2,), dtype=float32, numpy=[32.5000, 6.0000]>,
-    '_COL_2': <tf.Tensor: shape=(2,), dtype=float32, numpy=[-73.9542, -73.9875]>}
-    """
-    if not files:
-        raise snowml_exceptions.SnowflakeMLException(
-            error_code=error_codes.SNOWML_READ_FAILED,
-            original_exception=ValueError("At least one file is needed to create a TF dataset."),
-        )
-
-    def generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]:
-        yield from parquet_parser.ParquetParser(list(files), filesystem, batch_size, shuffle, drop_last_batch)
-
-    return tf.data.Dataset.from_generator(generator, output_signature=_derive_signature(files[0], filesystem))
-
-
-def _arrow_type_to_tensor_spec(field: pa.Field) -> tf.TensorSpec:
-    try:
-        dtype = tf.dtypes.as_dtype(field.type.to_pandas_dtype())
-    except TypeError:
-        raise snowml_exceptions.SnowflakeMLException(
-            error_code=error_codes.INVALID_DATA_TYPE,
-            original_exception=TypeError(f"Column {field.name} has unsupportd type {field.type}."),
-        )
-    # First dimension is batch dimension.
-    return tf.TensorSpec(shape=(None,), dtype=dtype)
-
-
-def _derive_signature(file: str, filesystem: fsspec.AbstractFileSystem) -> Dict[str, tf.TensorSpec]:
-    """Derives the signature of the TF dataset from one parquet file."""
-    # TODO(zpeng): pq.read_schema does not support `filesystem` until pyarrow>=10.
-    # switch to pq.read_schema when we depend on that.
-    schema = pq.read_table(file, filesystem=filesystem).schema
-    # Signature:
-    # The dataset yields dicts. Keys are column names; values are 1-D tensors (
-    # the first dimension is batch dimension).
-    return {field.name: _arrow_type_to_tensor_spec(field) for field in schema}
--- a/snowflake/ml/fileset/torch_datapipe.py
+++ /dev/null
@@ -1,57 +0,0 @@
-from typing import Any, Dict, Iterator
-
-import fsspec
-import numpy.typing as npt
-from torchdata.datapipes.iter import IterDataPipe
-
-from snowflake.ml.fileset import parquet_parser
-
-
-class ReadAndParseParquet(IterDataPipe):
-    """Read and parse the parquet files yield batched numpy array in dict.
-
-    Args:
-        input_datapipe: A datapipe of input parquet file URIs to read and parse.
-            Note that the datapipe must be finite.
-        filesystem: A fsspec/pyarrow file system that is used to open given file URIs.
-        batch_size: Specifies the size of each batch that will be yield
-        shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
-            the order of files, and then shuflle the order of rows in each file.
-        drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last batch will
-            get dropped if its size is smaller than the given batch_size.
-
-    Returns:
-        A PyTorch iterable datapipe that yields batched numpy array in dict. The keys will be the column names in
-        the parquet files.
-
-    Example:
-    >>> from snowflake.ml.fileset import sfcfs, torch_datapipe
-    >>> from torchdata.datapipes.iter import FSSpecFileLister
-    >>> conn = snowflake.connector.connect(**connection_parameters)
-    >>> fs = sfcfs.SFFileSystem(conn)
-    >>> filedp = FSSpecFileLister(root=dir_path, masks="*.parquet", mode="rb", sf_connection=conn)
-    >>> parquet_dp = torch_datapipe.ReadAndParseParquet(file_dp, fs, batch_size = 2)
-    >>> for batch in parquet_dp:
-    >>>     print(batch)
-    ----
-    {'_COL_1': [32.5000, 6.0000], '_COL_2': [-73.9542, -73.9875]}
-    """
-
-    def __init__(
-        self,
-        input_datapipe: IterDataPipe[str],
-        filesystem: fsspec.AbstractFileSystem,
-        batch_size: int,
-        shuffle: bool,
-        drop_last_batch: bool,
-    ) -> None:
-        self._input_datapipe = input_datapipe
-        self._fs = filesystem
-        self._batch_size = batch_size
-        self._shuffle = shuffle
-        self._drop_last_batch = drop_last_batch
-
-    def __iter__(self) -> Iterator[Dict[str, npt.NDArray[Any]]]:
-        yield from parquet_parser.ParquetParser(
-            list(self._input_datapipe), self._fs, self._batch_size, self._shuffle, self._drop_last_batch
-        )
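All three fileset modules above (parquet_parser, tf_dataset, torch_datapipe) are removed in this release, while the file list shows snowflake/ml/data gaining functionality (data_connector.py +53 -11, torch_utils.py +18 -5), which points at snowflake.ml.data.DataConnector as the replacement path. A hedged migration sketch; the method names are taken from the public DataConnector API in recent snowflake-ml-python releases and are not shown in this diff:

```python
# Hedged migration sketch: DataConnector method names come from recent public
# releases, not from this diff. `session` is assumed to be an active
# snowflake.snowpark.Session.
from snowflake.ml.data import DataConnector

connector = DataConnector.from_dataframe(session.table("TRAINING_DATA"))

tf_ds = connector.to_tf_dataset(batch_size=2, shuffle=True)         # tf.data.Dataset
torch_dp = connector.to_torch_datapipe(batch_size=2, shuffle=True)  # IterDataPipe
```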
--- a/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py
+++ /dev/null
@@ -1,151 +0,0 @@
-from typing import Any, List, Optional
-
-from snowflake.ml.modeling._internal.snowpark_implementations.snowpark_handlers import (
-    SnowparkTransformHandlers,
-)
-from snowflake.snowpark import DataFrame, Session
-
-
-class MLRuntimeTransformHandlers:
-    def __init__(
-        self,
-        dataset: DataFrame,
-        estimator: object,
-        class_name: str,
-        subproject: str,
-        autogenerated: Optional[bool] = False,
-    ) -> None:
-        """
-        Args:
-            dataset: The dataset to run transform functions on.
-            estimator: The estimator used to run transforms.
-            class_name: class name to be used in telemetry.
-            subproject: subproject to be used in telemetry.
-            autogenerated: Whether the class was autogenerated from a template.
-
-        Raises:
-            ModuleNotFoundError: The mlruntimes_client module is not available.
-        """
-        try:
-            from snowflake.ml.runtime import MLRuntimeClient
-        except ModuleNotFoundError as e:
-            # This is an internal exception, not a user-facing one. The snowflake.ml.runtime module should
-            # always be present when this class is instantiated.
-            raise e
-
-        self.client = MLRuntimeClient()
-        self.dataset = dataset
-        self.estimator = estimator
-        self._class_name = class_name
-        self._subproject = subproject
-        self._autogenerated = autogenerated
-
-    def batch_inference(
-        self,
-        inference_method: str,
-        input_cols: List[str],
-        expected_output_cols: List[str],
-        session: Session,
-        dependencies: List[str],
-        drop_input_cols: Optional[bool] = False,
-        expected_output_cols_type: Optional[str] = "",
-        *args: Any,
-        **kwargs: Any,
-    ) -> DataFrame:
-        """Run batch inference on the given dataset.
-        Temporary workaround - pushdown implementation is not currently ready for batch_inference.
-        We use a SnowparkTransformHandlers until we have a way to use the runtime client.
-
-        Args:
-            inference_method: the name of the method used by `estimator` to run inference.
-            input_cols: List of feature columns for inference.
-            session: An active Snowpark Session.
-            dependencies: List of dependencies for the transformer.
-            expected_output_cols: column names (in order) of the output dataset.
-            drop_input_cols: Boolean to determine whether to drop the input columns from the output dataset.
-            expected_output_cols_type: Expected type of the output columns.
-            args: additional positional arguments.
-            kwargs: additional keyword args.
-
-        Returns:
-            A new dataset of the same type as the input dataset.
-
-        """
-
-        mlrs_inference_methods = ["predict", "predict_proba", "predict_log_proba"]
-
-        if inference_method in mlrs_inference_methods:
-            result_df = self.client.inference(
-                estimator=self.estimator,
-                dataset=self.dataset,
-                inference_method=inference_method,
-                input_cols=input_cols,
-                output_cols=expected_output_cols,
-                drop_input_cols=drop_input_cols,
-            )
-
-        else:
-            handler = SnowparkTransformHandlers(
-                dataset=self.dataset,
-                estimator=self.estimator,
-                class_name=self._class_name,
-                subproject=self._subproject,
-                autogenerated=self._autogenerated,
-            )
-            result_df = handler.batch_inference(
-                inference_method,
-                input_cols,
-                expected_output_cols,
-                session,
-                dependencies,
-                drop_input_cols,
-                expected_output_cols_type,
-                *args,
-                **kwargs,
-            )
-
-        assert isinstance(result_df, DataFrame)  # mypy - The MLRS return types are annotated as `object`.
-        return result_df
-
-    def score(
-        self,
-        input_cols: List[str],
-        label_cols: List[str],
-        session: Session,
-        dependencies: List[str],
-        score_sproc_imports: List[str],
-        sample_weight_col: Optional[str] = None,
-        *args: Any,
-        **kwargs: Any,
-    ) -> float:
-        """Score the given test dataset.
-
-        Args:
-            session: An active Snowpark Session.
-            dependencies: score function dependencies.
-            score_sproc_imports: imports for score stored procedure.
-            input_cols: List of feature columns for inference.
-            label_cols: List of label columns for scoring.
-            sample_weight_col: A column assigning relative weights to each row for scoring.
-            args: additional positional arguments.
-            kwargs: additional keyword args.
-
-
-        Returns:
-            An accuracy score for the model on the given test data.
-
-        Raises:
-            TypeError: The ML Runtimes client returned a non-float result
-        """
-        output_score = self.client.score(
-            estimator=self.estimator,
-            dataset=self.dataset,
-            input_cols=input_cols,
-            label_cols=label_cols,
-            sample_weight_col=sample_weight_col,
-        )
-        if not isinstance(output_score, float):
-            raise TypeError(
-                f"The ML Runtimes Client returned a non-float value {output_score} of type {type(output_score)}"
-            )
-        return output_score
--- a/snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from typing import List, Optional
-
-from snowflake.snowpark import DataFrame, Session
-
-
-class MLRuntimeModelTrainer:
-    """ML model training using the ml runties client."""
-
-    def __init__(
-        self,
-        estimator: object,
-        dataset: DataFrame,
-        session: Session,
-        input_cols: List[str],
-        label_cols: Optional[List[str]],
-        sample_weight_col: Optional[str],
-        autogenerated: bool = False,
-        subproject: str = "",
-    ) -> None:
-        """
-        Initializes the MLRuntimeModelTrainer with a model, a Snowpark DataFrame, feature, and label column names.
-
-        Args:
-            estimator: SKLearn compatible estimator or transformer object.
-            dataset: The dataset used for training the model.
-            session: Snowflake session object to be used for training.
-            input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training.
-            label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn.
-            sample_weight_col: The column name representing the weight of training examples.
-            autogenerated: A boolean denoting if the trainer is being used by autogenerated code or not.
-            subproject: subproject name to be used in telemetry.
-
-        Raises:
-            ModuleNotFoundError: The mlruntimes_client module is not available.
-        """
-
-        try:
-            from snowflake.ml.runtime import MLRuntimeClient
-        except ModuleNotFoundError as e:
-            # This is an internal exception, not a user-facing one. The snowflake.ml.runtime module should
-            # always be present when this class is instantiated.
-            raise e
-
-        self.client = MLRuntimeClient()
-
-        self.estimator = estimator
-        self.dataset = dataset
-        self.session = session
-        self.input_cols = input_cols
-        self.label_cols = label_cols
-        self.sample_weight_col = sample_weight_col
-        self._autogenerated = autogenerated
-        self._subproject = subproject
-        self._class_name = estimator.__class__.__name__
-
-    def train(self) -> object:
-        """
-        Trains the model by pushing down the compute into SPCS ML Runtime
-        """
-        return self.client.train(
-            estimator=self.estimator,
-            dataset=self.dataset,
-            input_cols=self.input_cols,
-            label_cols=self.label_cols,
-            sample_weight_col=self.sample_weight_col,
-        )
{snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt: file without changes
{snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt: file without changes