snowflake-ml-python 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +26 -5
- snowflake/cortex/_sentiment.py +7 -4
- snowflake/cortex/_sse_client.py +81 -0
- snowflake/cortex/_util.py +105 -8
- snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
- snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
- snowflake/ml/dataset/dataset.py +15 -12
- snowflake/ml/dataset/dataset_factory.py +3 -4
- snowflake/ml/feature_store/access_manager.py +34 -30
- snowflake/ml/feature_store/feature_store.py +3 -3
- snowflake/ml/feature_store/feature_view.py +12 -11
- snowflake/ml/fileset/snowfs.py +2 -31
- snowflake/ml/model/_client/ops/model_ops.py +43 -0
- snowflake/ml/model/_client/sql/model_version.py +55 -3
- snowflake/ml/model/_model_composer/model_composer.py +7 -3
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
- snowflake/ml/model/_signatures/builtins_handler.py +2 -1
- snowflake/ml/model/_signatures/core.py +13 -1
- snowflake/ml/model/_signatures/pandas_handler.py +2 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +196 -242
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +161 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -2
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -2
- snowflake/ml/modeling/cluster/birch.py +9 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -2
- snowflake/ml/modeling/cluster/dbscan.py +9 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -2
- snowflake/ml/modeling/cluster/k_means.py +9 -2
- snowflake/ml/modeling/cluster/mean_shift.py +9 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -2
- snowflake/ml/modeling/cluster/optics.py +9 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -2
- snowflake/ml/modeling/compose/column_transformer.py +9 -2
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -2
- snowflake/ml/modeling/covariance/oas.py +9 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -2
- snowflake/ml/modeling/decomposition/pca.py +9 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -2
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -2
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -2
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -2
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -2
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -2
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -2
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -2
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -2
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -2
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -2
- snowflake/ml/modeling/framework/base.py +3 -8
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -2
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -2
- snowflake/ml/modeling/impute/knn_imputer.py +9 -2
- snowflake/ml/modeling/impute/missing_indicator.py +9 -2
- snowflake/ml/modeling/impute/simple_imputer.py +28 -5
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -2
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -2
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -2
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -2
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -2
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -2
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -2
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/lars.py +9 -2
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -2
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -2
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -2
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -2
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -2
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -2
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -2
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/perceptron.py +9 -2
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/ridge.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -2
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -2
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -2
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -2
- snowflake/ml/modeling/manifold/isomap.py +9 -2
- snowflake/ml/modeling/manifold/mds.py +9 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -2
- snowflake/ml/modeling/manifold/tsne.py +9 -2
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -2
- snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -2
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -2
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -2
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -2
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -2
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -2
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -2
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -2
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -2
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -2
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -2
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -2
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -2
- snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +5 -0
- snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
- snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -2
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
- snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -2
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -2
- snowflake/ml/modeling/svm/linear_svc.py +9 -2
- snowflake/ml/modeling/svm/linear_svr.py +9 -2
- snowflake/ml/modeling/svm/nu_svc.py +9 -2
- snowflake/ml/modeling/svm/nu_svr.py +9 -2
- snowflake/ml/modeling/svm/svc.py +9 -2
- snowflake/ml/modeling/svm/svr.py +9 -2
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -2
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -2
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -2
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -2
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -2
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -2
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -2
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -2
- snowflake/ml/registry/_manager/model_manager.py +59 -1
- snowflake/ml/registry/registry.py +10 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +32 -4
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +207 -204
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
@@ -7,10 +7,6 @@ from dataclasses import asdict, dataclass
|
|
7
7
|
from enum import Enum
|
8
8
|
from typing import Any, Dict, List, Optional
|
9
9
|
|
10
|
-
from snowflake.ml._internal.exceptions import (
|
11
|
-
error_codes,
|
12
|
-
exceptions as snowml_exceptions,
|
13
|
-
)
|
14
10
|
from snowflake.ml._internal.utils.identifier import concat_names
|
15
11
|
from snowflake.ml._internal.utils.sql_identifier import (
|
16
12
|
SqlIdentifier,
|
@@ -34,6 +30,11 @@ _FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE"
|
|
34
30
|
_FEATURE_VIEW_VERSION_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.\-]*$")
|
35
31
|
_FEATURE_VIEW_VERSION_MAX_LENGTH = 128
|
36
32
|
|
33
|
+
_RESULT_SCAN_QUERY_PATTERN = re.compile(
|
34
|
+
r".*FROM\s*TABLE\s*\(\s*RESULT_SCAN\s*\(.*",
|
35
|
+
flags=re.DOTALL | re.IGNORECASE | re.X,
|
36
|
+
)
|
37
|
+
|
37
38
|
|
38
39
|
@dataclass(frozen=True)
|
39
40
|
class _FeatureViewMetadata:
|
@@ -54,13 +55,10 @@ class _FeatureViewMetadata:
|
|
54
55
|
class FeatureViewVersion(str):
|
55
56
|
def __new__(cls, version: str) -> FeatureViewVersion:
|
56
57
|
if not _FEATURE_VIEW_VERSION_RE.match(version) or len(version) > _FEATURE_VIEW_VERSION_MAX_LENGTH:
|
57
|
-
raise
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
"It must start with letter or digit, and followed by letter, digit, '_', '-' or '.'. "
|
62
|
-
f"The length limit is {_FEATURE_VIEW_VERSION_MAX_LENGTH}."
|
63
|
-
),
|
58
|
+
raise ValueError(
|
59
|
+
f"`{version}` is not a valid feature view version. "
|
60
|
+
"It must start with letter or digit, and followed by letter, digit, '_', '-' or '.'. "
|
61
|
+
f"The length limit is {_FEATURE_VIEW_VERSION_MAX_LENGTH}."
|
64
62
|
)
|
65
63
|
return super().__new__(cls, version)
|
66
64
|
|
@@ -352,6 +350,9 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
352
350
|
if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
|
353
351
|
raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
|
354
352
|
|
353
|
+
if re.match(_RESULT_SCAN_QUERY_PATTERN, self._query) is not None:
|
354
|
+
raise ValueError(f"feature_df should not be reading from RESULT_SCAN. Invalid query: {self._query}")
|
355
|
+
|
355
356
|
def _get_feature_names(self) -> List[SqlIdentifier]:
|
356
357
|
join_keys = [k for e in self._entities for k in e.join_keys]
|
357
358
|
ts_col = [self._timestamp_col] if self._timestamp_col is not None else []
|
snowflake/ml/fileset/snowfs.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1
1
|
import collections
|
2
2
|
import logging
|
3
3
|
import re
|
4
|
-
from typing import Any,
|
4
|
+
from typing import Any, Optional
|
5
5
|
|
6
6
|
import fsspec
|
7
|
-
import packaging.version as pkg_version
|
8
7
|
|
9
8
|
from snowflake import snowpark
|
10
9
|
from snowflake.connector import connection
|
@@ -12,7 +11,7 @@ from snowflake.ml._internal.exceptions import (
|
|
12
11
|
error_codes,
|
13
12
|
exceptions as snowml_exceptions,
|
14
13
|
)
|
15
|
-
from snowflake.ml._internal.utils import identifier
|
14
|
+
from snowflake.ml._internal.utils import identifier
|
16
15
|
from snowflake.ml.fileset import embedded_stage_fs, sfcfs
|
17
16
|
|
18
17
|
PROTOCOL_NAME = "snow"
|
@@ -28,10 +27,6 @@ _SNOWURL_PATTERN = re.compile(
|
|
28
27
|
r"(?P<path>versions/(?:(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)"
|
29
28
|
)
|
30
29
|
|
31
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
32
|
-
_BUG_VERSION_MIN = pkg_version.Version("8.17") # Inclusive minimum version with bugged behavior
|
33
|
-
_BUG_VERSION_MAX = pkg_version.Version("8.18") # Exclusive maximum version with bugged behavior
|
34
|
-
|
35
30
|
|
36
31
|
class SnowFileSystem(sfcfs.SFFileSystem):
|
37
32
|
"""A filesystem that allows user to access Snowflake embedded stage files with valid Snowflake locations.
|
@@ -54,21 +49,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
|
|
54
49
|
) -> None:
|
55
50
|
super().__init__(sf_connection=sf_connection, snowpark_session=snowpark_session, **kwargs)
|
56
51
|
|
57
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
58
|
-
if SnowFileSystem._IS_BUGGED_VERSION is None:
|
59
|
-
try:
|
60
|
-
sf_version = snowflake_env.get_current_snowflake_version(self._session)
|
61
|
-
SnowFileSystem._IS_BUGGED_VERSION = _BUG_VERSION_MIN <= sf_version < _BUG_VERSION_MAX
|
62
|
-
except Exception:
|
63
|
-
SnowFileSystem._IS_BUGGED_VERSION = False
|
64
|
-
|
65
|
-
def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
|
66
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
67
|
-
res: Dict[str, Any] = super().info(path, **kwargs)
|
68
|
-
if res.get("type") == "directory" and not res["name"].endswith("/"):
|
69
|
-
res["name"] += "/"
|
70
|
-
return res
|
71
|
-
|
72
52
|
def _get_stage_fs(
|
73
53
|
self, sf_file_path: _SFFileEntityPath # type: ignore[override]
|
74
54
|
) -> embedded_stage_fs.SFEmbeddedStageFileSystem:
|
@@ -100,12 +80,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
|
|
100
80
|
if stage_name.startswith(protocol):
|
101
81
|
stage_name = stage_name[len(protocol) :]
|
102
82
|
abs_path = stage_name + "/" + path
|
103
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
104
|
-
if self._IS_BUGGED_VERSION:
|
105
|
-
match = _SNOWURL_PATTERN.fullmatch(abs_path)
|
106
|
-
assert match is not None
|
107
|
-
if match.group("relpath"):
|
108
|
-
abs_path = abs_path.replace(match.group("relpath"), match.group("relpath").lstrip("/"))
|
109
83
|
return abs_path
|
110
84
|
|
111
85
|
@classmethod
|
@@ -144,9 +118,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
|
|
144
118
|
version = snowurl_match.group("version")
|
145
119
|
relative_path = snowurl_match.group("relpath") or ""
|
146
120
|
logging.debug(f"Parsed snow URL: {snowurl_match.groups()}")
|
147
|
-
# FIXME(dhung): Temporary fix for bug in GS version 8.17
|
148
|
-
if cls._IS_BUGGED_VERSION:
|
149
|
-
filepath = f"versions/{version}//{relative_path}"
|
150
121
|
return _SFFileEntityPath(
|
151
122
|
domain=domain, name=name, version=version, relative_path=relative_path, filepath=filepath
|
152
123
|
)
|
@@ -140,6 +140,49 @@ class ModelOperator:
|
|
140
140
|
statement_params=statement_params,
|
141
141
|
)
|
142
142
|
|
143
|
+
def create_from_model_version(
|
144
|
+
self,
|
145
|
+
*,
|
146
|
+
source_database_name: Optional[sql_identifier.SqlIdentifier],
|
147
|
+
source_schema_name: Optional[sql_identifier.SqlIdentifier],
|
148
|
+
source_model_name: sql_identifier.SqlIdentifier,
|
149
|
+
source_version_name: sql_identifier.SqlIdentifier,
|
150
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
151
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
152
|
+
model_name: sql_identifier.SqlIdentifier,
|
153
|
+
version_name: sql_identifier.SqlIdentifier,
|
154
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
155
|
+
) -> None:
|
156
|
+
if self.validate_existence(
|
157
|
+
database_name=database_name,
|
158
|
+
schema_name=schema_name,
|
159
|
+
model_name=model_name,
|
160
|
+
statement_params=statement_params,
|
161
|
+
):
|
162
|
+
return self._model_version_client.add_version_from_model_version(
|
163
|
+
source_database_name=source_database_name,
|
164
|
+
source_schema_name=source_schema_name,
|
165
|
+
source_model_name=source_model_name,
|
166
|
+
source_version_name=source_version_name,
|
167
|
+
database_name=database_name,
|
168
|
+
schema_name=schema_name,
|
169
|
+
model_name=model_name,
|
170
|
+
version_name=version_name,
|
171
|
+
statement_params=statement_params,
|
172
|
+
)
|
173
|
+
else:
|
174
|
+
return self._model_version_client.create_from_model_version(
|
175
|
+
source_database_name=source_database_name,
|
176
|
+
source_schema_name=source_schema_name,
|
177
|
+
source_model_name=source_model_name,
|
178
|
+
source_version_name=source_version_name,
|
179
|
+
database_name=database_name,
|
180
|
+
schema_name=schema_name,
|
181
|
+
model_name=model_name,
|
182
|
+
version_name=version_name,
|
183
|
+
statement_params=statement_params,
|
184
|
+
)
|
185
|
+
|
143
186
|
def show_models_or_versions(
|
144
187
|
self,
|
145
188
|
*,
|
@@ -44,6 +44,32 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
44
44
|
statement_params=statement_params,
|
45
45
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
46
46
|
|
47
|
+
def create_from_model_version(
|
48
|
+
self,
|
49
|
+
*,
|
50
|
+
source_database_name: Optional[sql_identifier.SqlIdentifier],
|
51
|
+
source_schema_name: Optional[sql_identifier.SqlIdentifier],
|
52
|
+
source_model_name: sql_identifier.SqlIdentifier,
|
53
|
+
source_version_name: sql_identifier.SqlIdentifier,
|
54
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
55
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
56
|
+
model_name: sql_identifier.SqlIdentifier,
|
57
|
+
version_name: sql_identifier.SqlIdentifier,
|
58
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
59
|
+
) -> None:
|
60
|
+
fq_source_model_name = self.fully_qualified_object_name(
|
61
|
+
source_database_name, source_schema_name, source_model_name
|
62
|
+
)
|
63
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
64
|
+
query_result_checker.SqlResultValidator(
|
65
|
+
self._session,
|
66
|
+
(
|
67
|
+
f"CREATE MODEL {fq_model_name} WITH VERSION {version_name} FROM MODEL {fq_source_model_name}"
|
68
|
+
f" VERSION {source_version_name}"
|
69
|
+
),
|
70
|
+
statement_params=statement_params,
|
71
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
72
|
+
|
47
73
|
# TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
|
48
74
|
def add_version_from_stage(
|
49
75
|
self,
|
@@ -64,6 +90,32 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
64
90
|
statement_params=statement_params,
|
65
91
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
66
92
|
|
93
|
+
def add_version_from_model_version(
|
94
|
+
self,
|
95
|
+
*,
|
96
|
+
source_database_name: Optional[sql_identifier.SqlIdentifier],
|
97
|
+
source_schema_name: Optional[sql_identifier.SqlIdentifier],
|
98
|
+
source_model_name: sql_identifier.SqlIdentifier,
|
99
|
+
source_version_name: sql_identifier.SqlIdentifier,
|
100
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
101
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
102
|
+
model_name: sql_identifier.SqlIdentifier,
|
103
|
+
version_name: sql_identifier.SqlIdentifier,
|
104
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
105
|
+
) -> None:
|
106
|
+
fq_source_model_name = self.fully_qualified_object_name(
|
107
|
+
source_database_name, source_schema_name, source_model_name
|
108
|
+
)
|
109
|
+
fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
|
110
|
+
query_result_checker.SqlResultValidator(
|
111
|
+
self._session,
|
112
|
+
(
|
113
|
+
f"ALTER MODEL {fq_model_name} ADD VERSION {version_name} FROM MODEL {fq_source_model_name}"
|
114
|
+
f" VERSION {source_version_name}"
|
115
|
+
),
|
116
|
+
statement_params=statement_params,
|
117
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
118
|
+
|
67
119
|
def set_default_version(
|
68
120
|
self,
|
69
121
|
*,
|
@@ -145,7 +197,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
145
197
|
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
146
198
|
options = {"parallel": 10}
|
147
199
|
cursor = self._session._conn._cursor
|
148
|
-
cursor._download(stage_location_url, str(target_path), options) # type: ignore[attr
|
200
|
+
cursor._download(stage_location_url, str(target_path), options) # type: ignore[union-attr]
|
149
201
|
cursor.fetchall()
|
150
202
|
else:
|
151
203
|
query_result_checker.SqlResultValidator(
|
@@ -220,7 +272,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
220
272
|
actual_schema_name.identifier(),
|
221
273
|
tmp_table_name,
|
222
274
|
)
|
223
|
-
input_df.write.save_as_table(
|
275
|
+
input_df.write.save_as_table(
|
224
276
|
table_name=INTERMEDIATE_TABLE_NAME,
|
225
277
|
mode="errorifexists",
|
226
278
|
table_type="temporary",
|
@@ -296,7 +348,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
|
|
296
348
|
actual_schema_name.identifier(),
|
297
349
|
tmp_table_name,
|
298
350
|
)
|
299
|
-
input_df.write.save_as_table(
|
351
|
+
input_df.write.save_as_table(
|
300
352
|
table_name=INTERMEDIATE_TABLE_NAME,
|
301
353
|
mode="errorifexists",
|
302
354
|
table_type="temporary",
|
@@ -136,7 +136,7 @@ class ModelComposer:
|
|
136
136
|
model_meta=self.packager.meta,
|
137
137
|
model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
|
138
138
|
options=options,
|
139
|
-
data_sources=self._get_data_sources(model),
|
139
|
+
data_sources=self._get_data_sources(model, sample_input_data),
|
140
140
|
)
|
141
141
|
|
142
142
|
file_utils.upload_directory_to_stage(
|
@@ -179,8 +179,12 @@ class ModelComposer:
|
|
179
179
|
mp.load(meta_only=meta_only, options=options)
|
180
180
|
return mp
|
181
181
|
|
182
|
-
def _get_data_sources(
|
183
|
-
|
182
|
+
def _get_data_sources(
|
183
|
+
self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
|
184
|
+
) -> Optional[List[data_source.DataSource]]:
|
185
|
+
data_sources = lineage_utils.get_data_sources(model)
|
186
|
+
if not data_sources and sample_input_data is not None:
|
187
|
+
data_sources = lineage_utils.get_data_sources(sample_input_data)
|
184
188
|
if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
|
185
189
|
return data_sources
|
186
190
|
return None
|
@@ -74,4 +74,6 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
74
74
|
class {function_name}:
|
75
75
|
@vectorized(input=pd.DataFrame)
|
76
76
|
def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
|
77
|
-
|
77
|
+
df.columns = input_cols
|
78
|
+
input_df = df.astype(dtype=dtype_map)
|
79
|
+
return runner(input_df[input_cols])
|
@@ -281,9 +281,7 @@ class ModelMetadata:
|
|
281
281
|
"cpu": model_runtime.ModelRuntime("cpu", self.env),
|
282
282
|
}
|
283
283
|
if self.env.cuda_version:
|
284
|
-
runtimes.update(
|
285
|
-
{"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True, server_availability_source="conda")}
|
286
|
-
)
|
284
|
+
runtimes.update({"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True)})
|
287
285
|
return runtimes
|
288
286
|
|
289
287
|
def save(self, model_dir_path: str) -> None:
|
@@ -1,11 +1,11 @@
|
|
1
1
|
import copy
|
2
2
|
import pathlib
|
3
3
|
import warnings
|
4
|
-
from typing import List,
|
4
|
+
from typing import List, Optional
|
5
5
|
|
6
6
|
from packaging import requirements
|
7
7
|
|
8
|
-
from snowflake.ml._internal import
|
8
|
+
from snowflake.ml._internal import env_utils, file_utils
|
9
9
|
from snowflake.ml.model._packager.model_env import model_env
|
10
10
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
11
11
|
from snowflake.ml.model._packager.model_runtime import (
|
@@ -37,7 +37,6 @@ class ModelRuntime:
|
|
37
37
|
env: model_env.ModelEnv,
|
38
38
|
imports: Optional[List[pathlib.PurePosixPath]] = None,
|
39
39
|
is_gpu: bool = False,
|
40
|
-
server_availability_source: Literal["snowflake", "conda"] = "snowflake",
|
41
40
|
loading_from_file: bool = False,
|
42
41
|
) -> None:
|
43
42
|
self.name = name
|
@@ -48,30 +47,7 @@ class ModelRuntime:
|
|
48
47
|
return
|
49
48
|
|
50
49
|
snowml_pkg_spec = f"{env_utils.SNOWPARK_ML_PKG_NAME}=={self.runtime_env.snowpark_ml_version}"
|
51
|
-
|
52
|
-
self.embed_local_ml_library = True
|
53
|
-
else:
|
54
|
-
if server_availability_source == "snowflake":
|
55
|
-
snowml_server_availability = (
|
56
|
-
len(
|
57
|
-
env_utils.get_matched_package_versions_in_information_schema_with_active_session(
|
58
|
-
reqs=[requirements.Requirement(snowml_pkg_spec)],
|
59
|
-
python_version=snowml_env.PYTHON_VERSION,
|
60
|
-
).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
|
61
|
-
)
|
62
|
-
>= 1
|
63
|
-
)
|
64
|
-
else:
|
65
|
-
snowml_server_availability = (
|
66
|
-
len(
|
67
|
-
env_utils.get_matched_package_versions_in_snowflake_conda_channel(
|
68
|
-
req=requirements.Requirement(snowml_pkg_spec),
|
69
|
-
python_version=snowml_env.PYTHON_VERSION,
|
70
|
-
)
|
71
|
-
)
|
72
|
-
>= 1
|
73
|
-
)
|
74
|
-
self.embed_local_ml_library = not snowml_server_availability
|
50
|
+
self.embed_local_ml_library = self.runtime_env._snowpark_ml_version.local
|
75
51
|
|
76
52
|
additional_package = (
|
77
53
|
_SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES if self.embed_local_ml_library else [snowml_pkg_spec]
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import datetime
|
1
2
|
from collections import abc
|
2
3
|
from typing import Literal, Sequence
|
3
4
|
|
@@ -24,7 +25,7 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBu
|
|
24
25
|
# String is a Sequence but we take them as an whole
|
25
26
|
if isinstance(element, abc.Sequence) and not isinstance(element, str):
|
26
27
|
can_handle = ListOfBuiltinHandler.can_handle(element)
|
27
|
-
elif not isinstance(element, (int, float, bool, str)):
|
28
|
+
elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
|
28
29
|
can_handle = False
|
29
30
|
break
|
30
31
|
return can_handle
|
@@ -53,6 +53,8 @@ class DataType(Enum):
|
|
53
53
|
STRING = ("string", spt.StringType, np.str_)
|
54
54
|
BYTES = ("bytes", spt.BinaryType, np.bytes_)
|
55
55
|
|
56
|
+
TIMESTAMP_NTZ = ("datetime64[ns]", spt.TimestampType, "datetime64[ns]")
|
57
|
+
|
56
58
|
def as_snowpark_type(self) -> spt.DataType:
|
57
59
|
"""Convert to corresponding Snowpark Type.
|
58
60
|
|
@@ -78,6 +80,13 @@ class DataType(Enum):
|
|
78
80
|
Corresponding DataType.
|
79
81
|
"""
|
80
82
|
np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}
|
83
|
+
|
84
|
+
# Add datetime types:
|
85
|
+
datetime_res = ["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"]
|
86
|
+
|
87
|
+
for res in datetime_res:
|
88
|
+
np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ
|
89
|
+
|
81
90
|
for potential_type in np_to_snowml_type_mapping.keys():
|
82
91
|
if np.can_cast(np_type, potential_type, casting="no"):
|
83
92
|
# This is used since the same dtype might represented in different ways.
|
@@ -247,9 +256,12 @@ class FeatureSpec(BaseFeatureSpec):
|
|
247
256
|
result_type = spt.ArrayType(result_type)
|
248
257
|
return result_type
|
249
258
|
|
250
|
-
def as_dtype(self) -> npt.DTypeLike:
|
259
|
+
def as_dtype(self) -> Union[npt.DTypeLike, str]:
|
251
260
|
"""Convert to corresponding local Type."""
|
252
261
|
if not self._shape:
|
262
|
+
# scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
|
263
|
+
if "datetime64" in self._dtype._value:
|
264
|
+
return self._dtype._value
|
253
265
|
return self._dtype._numpy_type
|
254
266
|
return np.object_
|
255
267
|
|
@@ -147,6 +147,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
|
|
147
147
|
specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
|
148
148
|
elif isinstance(data[df_col].iloc[0], bytes):
|
149
149
|
specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
|
150
|
+
elif isinstance(data[df_col].iloc[0], np.datetime64):
|
151
|
+
specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
|
150
152
|
else:
|
151
153
|
specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(df_col_dtype), name=ft_name))
|
152
154
|
return specs
|
@@ -107,6 +107,9 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
107
107
|
if not features:
|
108
108
|
features = pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input")
|
109
109
|
# Role will be no effect on the column index. That is to say, the feature name is the actual column name.
|
110
|
+
if keep_order:
|
111
|
+
df = df.reset_index(drop=True)
|
112
|
+
df[infer_template._KEEP_ORDER_COL_NAME] = df.index
|
110
113
|
sp_df = session.create_dataframe(df)
|
111
114
|
column_names = []
|
112
115
|
columns = []
|
@@ -122,7 +125,4 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
122
125
|
|
123
126
|
sp_df = sp_df.with_columns(column_names, columns)
|
124
127
|
|
125
|
-
if keep_order:
|
126
|
-
sp_df = sp_df.with_column(infer_template._KEEP_ORDER_COL_NAME, F.monotonically_increasing_id())
|
127
|
-
|
128
128
|
return sp_df
|
@@ -168,6 +168,8 @@ def _validate_numpy_array(
|
|
168
168
|
max_v <= np.finfo(feature_type._numpy_type).max # type: ignore[arg-type]
|
169
169
|
and min_v >= np.finfo(feature_type._numpy_type).min # type: ignore[arg-type]
|
170
170
|
)
|
171
|
+
elif feature_type in [core.DataType.TIMESTAMP_NTZ]:
|
172
|
+
return np.issubdtype(arr.dtype, np.datetime64)
|
171
173
|
else:
|
172
174
|
return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no")
|
173
175
|
|
snowflake/ml/model/type_hints.py
CHANGED
@@ -1,15 +1,19 @@
|
|
1
1
|
import inspect
|
2
2
|
import numbers
|
3
|
+
import os
|
3
4
|
from typing import Any, Callable, Dict, List, Set, Tuple
|
4
5
|
|
6
|
+
import cloudpickle as cp
|
5
7
|
import numpy as np
|
6
8
|
from numpy import typing as npt
|
7
|
-
from typing_extensions import TypeGuard
|
8
9
|
|
9
10
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
11
|
+
from snowflake.ml._internal.utils import temp_file_utils
|
12
|
+
from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
|
10
13
|
from snowflake.ml.modeling.framework._utils import to_native_format
|
11
14
|
from snowflake.ml.modeling.framework.base import BaseTransformer
|
12
15
|
from snowflake.snowpark import Session
|
16
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
13
17
|
|
14
18
|
|
15
19
|
def validate_sklearn_args(args: Dict[str, Tuple[Any, Any, bool]], klass: type) -> Dict[str, Any]:
|
@@ -97,6 +101,7 @@ def original_estimator_has_callable(attr: str) -> Callable[[Any], bool]:
|
|
97
101
|
Returns:
|
98
102
|
A function which checks for the existence of callable `attr` on the given object.
|
99
103
|
"""
|
104
|
+
from typing_extensions import TypeGuard
|
100
105
|
|
101
106
|
def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
|
102
107
|
"""Check for the existence of callable `attr` in self.
|
@@ -218,3 +223,55 @@ def handle_inference_result(
|
|
218
223
|
)
|
219
224
|
|
220
225
|
return transformed_numpy_array, output_cols
|
226
|
+
|
227
|
+
|
228
|
+
def create_temp_stage(session: Session) -> str:
|
229
|
+
"""Creates temporary stage.
|
230
|
+
|
231
|
+
Args:
|
232
|
+
session: Session
|
233
|
+
|
234
|
+
Returns:
|
235
|
+
Temp stage name.
|
236
|
+
"""
|
237
|
+
# Create temp stage to upload pickled model file.
|
238
|
+
transform_stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
|
239
|
+
stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
|
240
|
+
SqlResultValidator(session=session, query=stage_creation_query).has_dimensions(
|
241
|
+
expected_rows=1, expected_cols=1
|
242
|
+
).validate()
|
243
|
+
return transform_stage_name
|
244
|
+
|
245
|
+
|
246
|
+
def upload_model_to_stage(
|
247
|
+
stage_name: str, estimator: object, session: Session, statement_params: Dict[str, str]
|
248
|
+
) -> str:
|
249
|
+
"""Util method to pickle and upload the model to a temp Snowflake stage.
|
250
|
+
|
251
|
+
|
252
|
+
Args:
|
253
|
+
stage_name: Stage name to save model.
|
254
|
+
estimator: Estimator object to upload to stage (sklearn model object)
|
255
|
+
session: The snowpark session to use.
|
256
|
+
statement_params: Statement parameters for query telemetry.
|
257
|
+
|
258
|
+
Returns:
|
259
|
+
a tuple containing stage file paths for pickled input model for training and location to store trained
|
260
|
+
models(response from training sproc).
|
261
|
+
"""
|
262
|
+
# Create a temp file and dump the transform to that file.
|
263
|
+
local_transform_file_name = temp_file_utils.get_temp_file_path()
|
264
|
+
with open(local_transform_file_name, mode="w+b") as local_transform_file:
|
265
|
+
cp.dump(estimator, local_transform_file)
|
266
|
+
|
267
|
+
# Put locally serialized transform on stage.
|
268
|
+
session.file.put(
|
269
|
+
local_file_name=local_transform_file_name,
|
270
|
+
stage_location=stage_name,
|
271
|
+
auto_compress=False,
|
272
|
+
overwrite=True,
|
273
|
+
statement_params=statement_params,
|
274
|
+
)
|
275
|
+
|
276
|
+
temp_file_utils.cleanup_temp_files([local_transform_file_name])
|
277
|
+
return os.path.basename(local_transform_file_name)
|