snowflake-ml-python 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff compares two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in their respective public registries.
Files changed (207)
  1. snowflake/cortex/_complete.py +26 -5
  2. snowflake/cortex/_sentiment.py +7 -4
  3. snowflake/cortex/_sse_client.py +81 -0
  4. snowflake/cortex/_util.py +105 -8
  5. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  6. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  7. snowflake/ml/dataset/dataset.py +15 -12
  8. snowflake/ml/dataset/dataset_factory.py +3 -4
  9. snowflake/ml/feature_store/access_manager.py +34 -30
  10. snowflake/ml/feature_store/feature_store.py +3 -3
  11. snowflake/ml/feature_store/feature_view.py +12 -11
  12. snowflake/ml/fileset/snowfs.py +2 -31
  13. snowflake/ml/model/_client/ops/model_ops.py +43 -0
  14. snowflake/ml/model/_client/sql/model_version.py +55 -3
  15. snowflake/ml/model/_model_composer/model_composer.py +7 -3
  16. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
  17. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  18. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  19. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  20. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  21. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  22. snowflake/ml/model/_signatures/core.py +13 -1
  23. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  24. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  25. snowflake/ml/model/model_signature.py +2 -0
  26. snowflake/ml/model/type_hints.py +1 -0
  27. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +196 -242
  29. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +161 -0
  30. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
  31. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
  32. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
  33. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -2
  34. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -2
  35. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -2
  36. snowflake/ml/modeling/cluster/birch.py +9 -2
  37. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -2
  38. snowflake/ml/modeling/cluster/dbscan.py +9 -2
  39. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -2
  40. snowflake/ml/modeling/cluster/k_means.py +9 -2
  41. snowflake/ml/modeling/cluster/mean_shift.py +9 -2
  42. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -2
  43. snowflake/ml/modeling/cluster/optics.py +9 -2
  44. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -2
  45. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -2
  46. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -2
  47. snowflake/ml/modeling/compose/column_transformer.py +9 -2
  48. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -2
  49. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -2
  50. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -2
  51. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -2
  52. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -2
  53. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -2
  54. snowflake/ml/modeling/covariance/min_cov_det.py +9 -2
  55. snowflake/ml/modeling/covariance/oas.py +9 -2
  56. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -2
  57. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -2
  58. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -2
  59. snowflake/ml/modeling/decomposition/fast_ica.py +9 -2
  60. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -2
  61. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -2
  62. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -2
  63. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -2
  64. snowflake/ml/modeling/decomposition/pca.py +9 -2
  65. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -2
  66. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -2
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -2
  68. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -2
  69. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -2
  70. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -2
  71. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -2
  72. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -2
  73. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -2
  74. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -2
  75. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -2
  76. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -2
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -2
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -2
  79. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -2
  80. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -2
  81. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -2
  82. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -2
  83. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -2
  84. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -2
  85. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -2
  86. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -2
  87. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -2
  88. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -2
  89. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -2
  90. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -2
  91. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -2
  92. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -2
  93. snowflake/ml/modeling/framework/base.py +3 -8
  94. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -2
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -2
  96. snowflake/ml/modeling/impute/iterative_imputer.py +9 -2
  97. snowflake/ml/modeling/impute/knn_imputer.py +9 -2
  98. snowflake/ml/modeling/impute/missing_indicator.py +9 -2
  99. snowflake/ml/modeling/impute/simple_imputer.py +28 -5
  100. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -2
  101. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -2
  102. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -2
  103. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -2
  104. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -2
  105. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -2
  106. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -2
  107. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -2
  108. snowflake/ml/modeling/linear_model/ard_regression.py +9 -2
  109. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -2
  110. snowflake/ml/modeling/linear_model/elastic_net.py +9 -2
  111. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -2
  112. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -2
  113. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -2
  114. snowflake/ml/modeling/linear_model/lars.py +9 -2
  115. snowflake/ml/modeling/linear_model/lars_cv.py +9 -2
  116. snowflake/ml/modeling/linear_model/lasso.py +9 -2
  117. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -2
  118. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -2
  119. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -2
  120. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -2
  121. snowflake/ml/modeling/linear_model/linear_regression.py +9 -2
  122. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -2
  123. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -2
  124. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -2
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -2
  126. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -2
  127. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -2
  128. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -2
  129. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -2
  130. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -2
  131. snowflake/ml/modeling/linear_model/perceptron.py +9 -2
  132. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -2
  133. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -2
  134. snowflake/ml/modeling/linear_model/ridge.py +9 -2
  135. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -2
  136. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -2
  137. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -2
  138. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -2
  139. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -2
  140. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -2
  141. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -2
  142. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -2
  143. snowflake/ml/modeling/manifold/isomap.py +9 -2
  144. snowflake/ml/modeling/manifold/mds.py +9 -2
  145. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -2
  146. snowflake/ml/modeling/manifold/tsne.py +9 -2
  147. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -2
  148. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -2
  149. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  150. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  151. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -2
  152. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -2
  153. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -2
  154. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -2
  155. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -2
  156. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -2
  157. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -2
  158. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -2
  159. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -2
  160. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -2
  161. snowflake/ml/modeling/neighbors/kernel_density.py +9 -2
  162. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -2
  163. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -2
  164. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -2
  165. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -2
  166. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -2
  167. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -2
  168. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -2
  169. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -2
  170. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -2
  171. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  172. snowflake/ml/modeling/pipeline/pipeline.py +5 -0
  173. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  174. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  175. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  176. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  177. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  178. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  179. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
  180. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -2
  182. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  183. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  184. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -2
  185. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -2
  186. snowflake/ml/modeling/svm/linear_svc.py +9 -2
  187. snowflake/ml/modeling/svm/linear_svr.py +9 -2
  188. snowflake/ml/modeling/svm/nu_svc.py +9 -2
  189. snowflake/ml/modeling/svm/nu_svr.py +9 -2
  190. snowflake/ml/modeling/svm/svc.py +9 -2
  191. snowflake/ml/modeling/svm/svr.py +9 -2
  192. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -2
  193. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -2
  194. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -2
  195. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -2
  196. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -2
  197. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -2
  198. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -2
  199. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -2
  200. snowflake/ml/registry/_manager/model_manager.py +59 -1
  201. snowflake/ml/registry/registry.py +10 -1
  202. snowflake/ml/version.py +1 -1
  203. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +32 -4
  204. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +207 -204
  205. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
  206. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
  207. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0

snowflake/ml/feature_store/feature_view.py

@@ -7,10 +7,6 @@ from dataclasses import asdict, dataclass
  from enum import Enum
  from typing import Any, Dict, List, Optional

- from snowflake.ml._internal.exceptions import (
-     error_codes,
-     exceptions as snowml_exceptions,
- )
  from snowflake.ml._internal.utils.identifier import concat_names
  from snowflake.ml._internal.utils.sql_identifier import (
      SqlIdentifier,
@@ -34,6 +30,11 @@ _FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE"
  _FEATURE_VIEW_VERSION_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9_.\-]*$")
  _FEATURE_VIEW_VERSION_MAX_LENGTH = 128

+ _RESULT_SCAN_QUERY_PATTERN = re.compile(
+     r".*FROM\s*TABLE\s*\(\s*RESULT_SCAN\s*\(.*",
+     flags=re.DOTALL | re.IGNORECASE | re.X,
+ )
+

  @dataclass(frozen=True)
  class _FeatureViewMetadata:
@@ -54,13 +55,10 @@ class _FeatureViewMetadata:
  class FeatureViewVersion(str):
      def __new__(cls, version: str) -> FeatureViewVersion:
          if not _FEATURE_VIEW_VERSION_RE.match(version) or len(version) > _FEATURE_VIEW_VERSION_MAX_LENGTH:
-             raise snowml_exceptions.SnowflakeMLException(
-                 error_code=error_codes.INVALID_ARGUMENT,
-                 original_exception=ValueError(
-                     f"`{version}` is not a valid feature view version. "
-                     "It must start with letter or digit, and followed by letter, digit, '_', '-' or '.'. "
-                     f"The length limit is {_FEATURE_VIEW_VERSION_MAX_LENGTH}."
-                 ),
+             raise ValueError(
+                 f"`{version}` is not a valid feature view version. "
+                 "It must start with letter or digit, and followed by letter, digit, '_', '-' or '.'. "
+                 f"The length limit is {_FEATURE_VIEW_VERSION_MAX_LENGTH}."
              )
          return super().__new__(cls, version)

@@ -352,6 +350,9 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
          if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
              raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")

+         if re.match(_RESULT_SCAN_QUERY_PATTERN, self._query) is not None:
+             raise ValueError(f"feature_df should not be reading from RESULT_SCAN. Invalid query: {self._query}")
+
      def _get_feature_names(self) -> List[SqlIdentifier]:
          join_keys = [k for e in self._entities for k in e.join_keys]
          ts_col = [self._timestamp_col] if self._timestamp_col is not None else []
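
Note: the RESULT_SCAN guard added above can be exercised on its own. The snippet below copies the regex from the hunk and runs it against two illustrative queries (the query strings are invented for this sketch, not taken from the package):

import re

# Regex copied from the hunk above; it flags queries that read from RESULT_SCAN.
_RESULT_SCAN_QUERY_PATTERN = re.compile(
    r".*FROM\s*TABLE\s*\(\s*RESULT_SCAN\s*\(.*",
    flags=re.DOTALL | re.IGNORECASE | re.X,
)

plain_query = "SELECT ID, TS, FEATURE_1 FROM MY_DB.MY_SCHEMA.SOURCE_TABLE"  # hypothetical
result_scan_query = "SELECT * FROM TABLE(RESULT_SCAN(LAST_QUERY_ID()))"     # hypothetical

print(re.match(_RESULT_SCAN_QUERY_PATTERN, plain_query) is None)        # True: accepted
print(re.match(_RESULT_SCAN_QUERY_PATTERN, result_scan_query) is None)  # False: rejected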

snowflake/ml/fileset/snowfs.py

@@ -1,10 +1,9 @@
  import collections
  import logging
  import re
- from typing import Any, Dict, Optional
+ from typing import Any, Optional

  import fsspec
- import packaging.version as pkg_version

  from snowflake import snowpark
  from snowflake.connector import connection
@@ -12,7 +11,7 @@ from snowflake.ml._internal.exceptions import (
      error_codes,
      exceptions as snowml_exceptions,
  )
- from snowflake.ml._internal.utils import identifier, snowflake_env
+ from snowflake.ml._internal.utils import identifier
  from snowflake.ml.fileset import embedded_stage_fs, sfcfs

  PROTOCOL_NAME = "snow"
@@ -28,10 +27,6 @@ _SNOWURL_PATTERN = re.compile(
      r"(?P<path>versions/(?:(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)"
  )

- # FIXME(dhung): Temporary fix for bug in GS version 8.17
- _BUG_VERSION_MIN = pkg_version.Version("8.17")  # Inclusive minimum version with bugged behavior
- _BUG_VERSION_MAX = pkg_version.Version("8.18")  # Exclusive maximum version with bugged behavior
-

  class SnowFileSystem(sfcfs.SFFileSystem):
      """A filesystem that allows user to access Snowflake embedded stage files with valid Snowflake locations.
@@ -54,21 +49,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
      ) -> None:
          super().__init__(sf_connection=sf_connection, snowpark_session=snowpark_session, **kwargs)

-         # FIXME(dhung): Temporary fix for bug in GS version 8.17
-         if SnowFileSystem._IS_BUGGED_VERSION is None:
-             try:
-                 sf_version = snowflake_env.get_current_snowflake_version(self._session)
-                 SnowFileSystem._IS_BUGGED_VERSION = _BUG_VERSION_MIN <= sf_version < _BUG_VERSION_MAX
-             except Exception:
-                 SnowFileSystem._IS_BUGGED_VERSION = False
-
-     def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
-         # FIXME(dhung): Temporary fix for bug in GS version 8.17
-         res: Dict[str, Any] = super().info(path, **kwargs)
-         if res.get("type") == "directory" and not res["name"].endswith("/"):
-             res["name"] += "/"
-         return res
-
      def _get_stage_fs(
          self, sf_file_path: _SFFileEntityPath  # type: ignore[override]
      ) -> embedded_stage_fs.SFEmbeddedStageFileSystem:
@@ -100,12 +80,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
          if stage_name.startswith(protocol):
              stage_name = stage_name[len(protocol) :]
          abs_path = stage_name + "/" + path
-         # FIXME(dhung): Temporary fix for bug in GS version 8.17
-         if self._IS_BUGGED_VERSION:
-             match = _SNOWURL_PATTERN.fullmatch(abs_path)
-             assert match is not None
-             if match.group("relpath"):
-                 abs_path = abs_path.replace(match.group("relpath"), match.group("relpath").lstrip("/"))
          return abs_path

      @classmethod
@@ -144,9 +118,6 @@ class SnowFileSystem(sfcfs.SFFileSystem):
          version = snowurl_match.group("version")
          relative_path = snowurl_match.group("relpath") or ""
          logging.debug(f"Parsed snow URL: {snowurl_match.groups()}")
-         # FIXME(dhung): Temporary fix for bug in GS version 8.17
-         if cls._IS_BUGGED_VERSION:
-             filepath = f"versions/{version}//{relative_path}"
          return _SFFileEntityPath(
              domain=domain, name=name, version=version, relative_path=relative_path, filepath=filepath
          )

snowflake/ml/model/_client/ops/model_ops.py

@@ -140,6 +140,49 @@ class ModelOperator:
              statement_params=statement_params,
          )

+     def create_from_model_version(
+         self,
+         *,
+         source_database_name: Optional[sql_identifier.SqlIdentifier],
+         source_schema_name: Optional[sql_identifier.SqlIdentifier],
+         source_model_name: sql_identifier.SqlIdentifier,
+         source_version_name: sql_identifier.SqlIdentifier,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         model_name: sql_identifier.SqlIdentifier,
+         version_name: sql_identifier.SqlIdentifier,
+         statement_params: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         if self.validate_existence(
+             database_name=database_name,
+             schema_name=schema_name,
+             model_name=model_name,
+             statement_params=statement_params,
+         ):
+             return self._model_version_client.add_version_from_model_version(
+                 source_database_name=source_database_name,
+                 source_schema_name=source_schema_name,
+                 source_model_name=source_model_name,
+                 source_version_name=source_version_name,
+                 database_name=database_name,
+                 schema_name=schema_name,
+                 model_name=model_name,
+                 version_name=version_name,
+                 statement_params=statement_params,
+             )
+         else:
+             return self._model_version_client.create_from_model_version(
+                 source_database_name=source_database_name,
+                 source_schema_name=source_schema_name,
+                 source_model_name=source_model_name,
+                 source_version_name=source_version_name,
+                 database_name=database_name,
+                 schema_name=schema_name,
+                 model_name=model_name,
+                 version_name=version_name,
+                 statement_params=statement_params,
+             )
+
      def show_models_or_versions(
          self,
          *,

snowflake/ml/model/_client/sql/model_version.py

@@ -44,6 +44,32 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
              statement_params=statement_params,
          ).has_dimensions(expected_rows=1, expected_cols=1).validate()

+     def create_from_model_version(
+         self,
+         *,
+         source_database_name: Optional[sql_identifier.SqlIdentifier],
+         source_schema_name: Optional[sql_identifier.SqlIdentifier],
+         source_model_name: sql_identifier.SqlIdentifier,
+         source_version_name: sql_identifier.SqlIdentifier,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         model_name: sql_identifier.SqlIdentifier,
+         version_name: sql_identifier.SqlIdentifier,
+         statement_params: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         fq_source_model_name = self.fully_qualified_object_name(
+             source_database_name, source_schema_name, source_model_name
+         )
+         fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
+         query_result_checker.SqlResultValidator(
+             self._session,
+             (
+                 f"CREATE MODEL {fq_model_name} WITH VERSION {version_name} FROM MODEL {fq_source_model_name}"
+                 f" VERSION {source_version_name}"
+             ),
+             statement_params=statement_params,
+         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
      # TODO(SNOW-987381): Merge with above when we have `create or alter module m [with] version v1 ...`
      def add_version_from_stage(
          self,
@@ -64,6 +90,32 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
              statement_params=statement_params,
          ).has_dimensions(expected_rows=1, expected_cols=1).validate()

+     def add_version_from_model_version(
+         self,
+         *,
+         source_database_name: Optional[sql_identifier.SqlIdentifier],
+         source_schema_name: Optional[sql_identifier.SqlIdentifier],
+         source_model_name: sql_identifier.SqlIdentifier,
+         source_version_name: sql_identifier.SqlIdentifier,
+         database_name: Optional[sql_identifier.SqlIdentifier],
+         schema_name: Optional[sql_identifier.SqlIdentifier],
+         model_name: sql_identifier.SqlIdentifier,
+         version_name: sql_identifier.SqlIdentifier,
+         statement_params: Optional[Dict[str, Any]] = None,
+     ) -> None:
+         fq_source_model_name = self.fully_qualified_object_name(
+             source_database_name, source_schema_name, source_model_name
+         )
+         fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
+         query_result_checker.SqlResultValidator(
+             self._session,
+             (
+                 f"ALTER MODEL {fq_model_name} ADD VERSION {version_name} FROM MODEL {fq_source_model_name}"
+                 f" VERSION {source_version_name}"
+             ),
+             statement_params=statement_params,
+         ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
      def set_default_version(
          self,
          *,
@@ -145,7 +197,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
          if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
              options = {"parallel": 10}
              cursor = self._session._conn._cursor
-             cursor._download(stage_location_url, str(target_path), options)  # type: ignore[attr-defined]
+             cursor._download(stage_location_url, str(target_path), options)  # type: ignore[union-attr]
              cursor.fetchall()
          else:
              query_result_checker.SqlResultValidator(
@@ -220,7 +272,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
              actual_schema_name.identifier(),
              tmp_table_name,
          )
-         input_df.write.save_as_table(  # type: ignore[call-overload]
+         input_df.write.save_as_table(
              table_name=INTERMEDIATE_TABLE_NAME,
              mode="errorifexists",
              table_type="temporary",
@@ -296,7 +348,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
              actual_schema_name.identifier(),
              tmp_table_name,
          )
-         input_df.write.save_as_table(  # type: ignore[call-overload]
+         input_df.write.save_as_table(
              table_name=INTERMEDIATE_TABLE_NAME,
              mode="errorifexists",
              table_type="temporary",
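
Note: for reference, the SQL emitted by the two new methods above has the following shape. The identifiers below are placeholders for this sketch, not values taken from the diff:

# Hypothetical fully qualified names and version identifiers, for illustration only.
fq_model_name = "MY_DB.MY_SCHEMA.TARGET_MODEL"
fq_source_model_name = "MY_DB.MY_SCHEMA.SOURCE_MODEL"
version_name = "V2"
source_version_name = "V1"

create_sql = (
    f"CREATE MODEL {fq_model_name} WITH VERSION {version_name} FROM MODEL {fq_source_model_name}"
    f" VERSION {source_version_name}"
)
alter_sql = (
    f"ALTER MODEL {fq_model_name} ADD VERSION {version_name} FROM MODEL {fq_source_model_name}"
    f" VERSION {source_version_name}"
)
print(create_sql)
print(alter_sql)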

snowflake/ml/model/_model_composer/model_composer.py

@@ -136,7 +136,7 @@ class ModelComposer:
              model_meta=self.packager.meta,
              model_file_rel_path=pathlib.PurePosixPath(self.model_file_rel_path),
              options=options,
-             data_sources=self._get_data_sources(model),
+             data_sources=self._get_data_sources(model, sample_input_data),
          )

          file_utils.upload_directory_to_stage(
@@ -179,8 +179,12 @@ class ModelComposer:
          mp.load(meta_only=meta_only, options=options)
          return mp

-     def _get_data_sources(self, model: model_types.SupportedModelType) -> Optional[List[data_source.DataSource]]:
-         data_sources = getattr(model, lineage_utils.DATA_SOURCES_ATTR, None)
+     def _get_data_sources(
+         self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
+     ) -> Optional[List[data_source.DataSource]]:
+         data_sources = lineage_utils.get_data_sources(model)
+         if not data_sources and sample_input_data is not None:
+             data_sources = lineage_utils.get_data_sources(sample_input_data)
          if isinstance(data_sources, list) and all(isinstance(item, data_source.DataSource) for item in data_sources):
              return data_sources
          return None

snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template

@@ -74,4 +74,6 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
  class {function_name}:
      @vectorized(input=pd.DataFrame)
      def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
-         return runner(df)
+         df.columns = input_cols
+         input_df = df.astype(dtype=dtype_map)
+         return runner(input_df[input_cols])
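
Note: the template now restores the signature's column names and dtypes on each partition before invoking the runner. A standalone pandas illustration of that step (column names and dtype map are invented for the example):

import pandas as pd

input_cols = ["FEATURE_A", "FEATURE_B"]                     # hypothetical feature names
dtype_map = {"FEATURE_A": "float64", "FEATURE_B": "int64"}  # hypothetical dtypes from the signature

# Partitions arrive with positional column labels and loosely typed values.
df = pd.DataFrame([[1, "2"], [3, "4"]])
df.columns = input_cols
input_df = df.astype(dtype=dtype_map)
print(input_df.dtypes)  # FEATURE_A: float64, FEATURE_B: int64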

snowflake/ml/model/_packager/model_meta/_core_requirements.py

@@ -6,6 +6,6 @@ REQUIREMENTS = [
      "packaging>=20.9,<24",
      "pandas>=1.0.0,<3",
      "pyyaml>=6.0,<7",
-     "snowflake-snowpark-python>=1.11.1,<2,!=1.12.0",
+     "snowflake-snowpark-python>=1.15.0,<2",
      "typing-extensions>=4.1.0,<5"
  ]

snowflake/ml/model/_packager/model_meta/model_meta.py

@@ -281,9 +281,7 @@ class ModelMetadata:
              "cpu": model_runtime.ModelRuntime("cpu", self.env),
          }
          if self.env.cuda_version:
-             runtimes.update(
-                 {"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True, server_availability_source="conda")}
-             )
+             runtimes.update({"gpu": model_runtime.ModelRuntime("gpu", self.env, is_gpu=True)})
          return runtimes

      def save(self, model_dir_path: str) -> None:

snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -5,6 +5,6 @@ REQUIREMENTS = [
      "packaging>=20.9,<24",
      "pandas>=1.0.0,<3",
      "pyyaml>=6.0,<7",
-     "snowflake-snowpark-python>=1.11.1,<2,!=1.12.0",
+     "snowflake-snowpark-python>=1.15.0,<2",
      "typing-extensions>=4.1.0,<5"
  ]

snowflake/ml/model/_packager/model_runtime/model_runtime.py

@@ -1,11 +1,11 @@
  import copy
  import pathlib
  import warnings
- from typing import List, Literal, Optional
+ from typing import List, Optional

  from packaging import requirements

- from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
+ from snowflake.ml._internal import env_utils, file_utils
  from snowflake.ml.model._packager.model_env import model_env
  from snowflake.ml.model._packager.model_meta import model_meta_schema
  from snowflake.ml.model._packager.model_runtime import (
@@ -37,7 +37,6 @@ class ModelRuntime:
          env: model_env.ModelEnv,
          imports: Optional[List[pathlib.PurePosixPath]] = None,
          is_gpu: bool = False,
-         server_availability_source: Literal["snowflake", "conda"] = "snowflake",
          loading_from_file: bool = False,
      ) -> None:
          self.name = name
@@ -48,30 +47,7 @@ class ModelRuntime:
              return

          snowml_pkg_spec = f"{env_utils.SNOWPARK_ML_PKG_NAME}=={self.runtime_env.snowpark_ml_version}"
-         if self.runtime_env._snowpark_ml_version.local:
-             self.embed_local_ml_library = True
-         else:
-             if server_availability_source == "snowflake":
-                 snowml_server_availability = (
-                     len(
-                         env_utils.get_matched_package_versions_in_information_schema_with_active_session(
-                             reqs=[requirements.Requirement(snowml_pkg_spec)],
-                             python_version=snowml_env.PYTHON_VERSION,
-                         ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
-                     )
-                     >= 1
-                 )
-             else:
-                 snowml_server_availability = (
-                     len(
-                         env_utils.get_matched_package_versions_in_snowflake_conda_channel(
-                             req=requirements.Requirement(snowml_pkg_spec),
-                             python_version=snowml_env.PYTHON_VERSION,
-                         )
-                     )
-                     >= 1
-                 )
-         self.embed_local_ml_library = not snowml_server_availability
+         self.embed_local_ml_library = self.runtime_env._snowpark_ml_version.local

          additional_package = (
              _SNOWML_INFERENCE_ALTERNATIVE_DEPENDENCIES if self.embed_local_ml_library else [snowml_pkg_spec]

snowflake/ml/model/_signatures/builtins_handler.py

@@ -1,3 +1,4 @@
+ import datetime
  from collections import abc
  from typing import Literal, Sequence

@@ -24,7 +25,7 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBu
              # String is a Sequence but we take them as an whole
              if isinstance(element, abc.Sequence) and not isinstance(element, str):
                  can_handle = ListOfBuiltinHandler.can_handle(element)
-             elif not isinstance(element, (int, float, bool, str)):
+             elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
                  can_handle = False
                  break
          return can_handle
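
Note: with the widened isinstance check above, plain Python datetime values inside lists are now treated as supported builtins. A minimal check of the same condition:

import datetime

sample = [datetime.datetime(2024, 1, 1), datetime.datetime(2024, 6, 30)]  # illustrative values
print(all(isinstance(x, (int, float, bool, str, datetime.datetime)) for x in sample))  # True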

snowflake/ml/model/_signatures/core.py

@@ -53,6 +53,8 @@ class DataType(Enum):
      STRING = ("string", spt.StringType, np.str_)
      BYTES = ("bytes", spt.BinaryType, np.bytes_)

+     TIMESTAMP_NTZ = ("datetime64[ns]", spt.TimestampType, "datetime64[ns]")
+
      def as_snowpark_type(self) -> spt.DataType:
          """Convert to corresponding Snowpark Type.

@@ -78,6 +80,13 @@
              Corresponding DataType.
          """
          np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}
+
+         # Add datetime types:
+         datetime_res = ["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"]
+
+         for res in datetime_res:
+             np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ
+
          for potential_type in np_to_snowml_type_mapping.keys():
              if np.can_cast(np_type, potential_type, casting="no"):
                  # This is used since the same dtype might represented in different ways.
@@ -247,9 +256,12 @@ class FeatureSpec(BaseFeatureSpec):
              result_type = spt.ArrayType(result_type)
          return result_type

-     def as_dtype(self) -> npt.DTypeLike:
+     def as_dtype(self) -> Union[npt.DTypeLike, str]:
          """Convert to corresponding local Type."""
          if not self._shape:
+             # scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
+             if "datetime64" in self._dtype._value:
+                 return self._dtype._value
              return self._dtype._numpy_type
          return np.object_

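
Note: the mapping added above sends every numpy datetime64 resolution to the single TIMESTAMP_NTZ type. The sketch below mimics that loop with plain numpy, using a string stand-in for DataType.TIMESTAMP_NTZ so it runs without the package:

import numpy as np

datetime_res = ["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"]
np_to_type = {f"datetime64[{res}]": "TIMESTAMP_NTZ" for res in datetime_res}

for dtype_str, mapped in np_to_type.items():
    assert np.dtype(dtype_str).kind == "M"  # every entry is a numpy datetime64 dtype
    print(dtype_str, "->", mapped)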

snowflake/ml/model/_signatures/pandas_handler.py

@@ -147,6 +147,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
                  specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
              elif isinstance(data[df_col].iloc[0], bytes):
                  specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
+             elif isinstance(data[df_col].iloc[0], np.datetime64):
+                 specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
              else:
                  specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(df_col_dtype), name=ft_name))
          return specs

snowflake/ml/model/_signatures/snowpark_handler.py

@@ -107,6 +107,9 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
          if not features:
              features = pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input")
          # Role will be no effect on the column index. That is to say, the feature name is the actual column name.
+         if keep_order:
+             df = df.reset_index(drop=True)
+             df[infer_template._KEEP_ORDER_COL_NAME] = df.index
          sp_df = session.create_dataframe(df)
          column_names = []
          columns = []
@@ -122,7 +125,4 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D

          sp_df = sp_df.with_columns(column_names, columns)

-         if keep_order:
-             sp_df = sp_df.with_column(infer_template._KEEP_ORDER_COL_NAME, F.monotonically_increasing_id())
-
          return sp_df
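
Note: order preservation now happens on the pandas side before the DataFrame is uploaded, instead of appending a monotonically increasing id column on the Snowpark side. A local sketch of the new tagging step (the column name is inlined here as a plain string; the real constant comes from infer_template):

import pandas as pd

_KEEP_ORDER_COL_NAME = "_ID"  # assumed placeholder for infer_template._KEEP_ORDER_COL_NAME

df = pd.DataFrame({"A": [10, 20, 30]}, index=[7, 3, 5])
df = df.reset_index(drop=True)       # normalize the index to 0..n-1
df[_KEEP_ORDER_COL_NAME] = df.index  # record the original row order explicitly
print(df)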

snowflake/ml/model/model_signature.py

@@ -168,6 +168,8 @@ def _validate_numpy_array(
              max_v <= np.finfo(feature_type._numpy_type).max  # type: ignore[arg-type]
              and min_v >= np.finfo(feature_type._numpy_type).min  # type: ignore[arg-type]
          )
+     elif feature_type in [core.DataType.TIMESTAMP_NTZ]:
+         return np.issubdtype(arr.dtype, np.datetime64)
      else:
          return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no")


snowflake/ml/model/type_hints.py

@@ -54,6 +54,7 @@ _SupportedNumpyDtype = Union[
      "np.bool_",
      "np.str_",
      "np.bytes_",
+     "np.datetime64",
  ]
  _SupportedNumpyArray = npt.NDArray[_SupportedNumpyDtype]
  _SupportedBuiltinsList = Sequence[_SupportedBuiltins]

snowflake/ml/modeling/_internal/estimator_utils.py

@@ -1,15 +1,19 @@
  import inspect
  import numbers
+ import os
  from typing import Any, Callable, Dict, List, Set, Tuple

+ import cloudpickle as cp
  import numpy as np
  from numpy import typing as npt
- from typing_extensions import TypeGuard

  from snowflake.ml._internal.exceptions import error_codes, exceptions
+ from snowflake.ml._internal.utils import temp_file_utils
+ from snowflake.ml._internal.utils.query_result_checker import SqlResultValidator
  from snowflake.ml.modeling.framework._utils import to_native_format
  from snowflake.ml.modeling.framework.base import BaseTransformer
  from snowflake.snowpark import Session
+ from snowflake.snowpark._internal import utils as snowpark_utils


  def validate_sklearn_args(args: Dict[str, Tuple[Any, Any, bool]], klass: type) -> Dict[str, Any]:
@@ -97,6 +101,7 @@ def original_estimator_has_callable(attr: str) -> Callable[[Any], bool]:
      Returns:
          A function which checks for the existence of callable `attr` on the given object.
      """
+     from typing_extensions import TypeGuard

      def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
          """Check for the existence of callable `attr` in self.
@@ -218,3 +223,55 @@ def handle_inference_result(
          )

      return transformed_numpy_array, output_cols
+
+
+ def create_temp_stage(session: Session) -> str:
+     """Creates temporary stage.
+
+     Args:
+         session: Session
+
+     Returns:
+         Temp stage name.
+     """
+     # Create temp stage to upload pickled model file.
+     transform_stage_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
+     stage_creation_query = f"CREATE OR REPLACE TEMPORARY STAGE {transform_stage_name};"
+     SqlResultValidator(session=session, query=stage_creation_query).has_dimensions(
+         expected_rows=1, expected_cols=1
+     ).validate()
+     return transform_stage_name
+
+
+ def upload_model_to_stage(
+     stage_name: str, estimator: object, session: Session, statement_params: Dict[str, str]
+ ) -> str:
+     """Util method to pickle and upload the model to a temp Snowflake stage.
+
+
+     Args:
+         stage_name: Stage name to save model.
+         estimator: Estimator object to upload to stage (sklearn model object)
+         session: The snowpark session to use.
+         statement_params: Statement parameters for query telemetry.
+
+     Returns:
+         a tuple containing stage file paths for pickled input model for training and location to store trained
+         models(response from training sproc).
+     """
+     # Create a temp file and dump the transform to that file.
+     local_transform_file_name = temp_file_utils.get_temp_file_path()
+     with open(local_transform_file_name, mode="w+b") as local_transform_file:
+         cp.dump(estimator, local_transform_file)
+
+     # Put locally serialized transform on stage.
+     session.file.put(
+         local_file_name=local_transform_file_name,
+         stage_location=stage_name,
+         auto_compress=False,
+         overwrite=True,
+         statement_params=statement_params,
+     )
+
+     temp_file_utils.cleanup_temp_files([local_transform_file_name])
+     return os.path.basename(local_transform_file_name)
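
Note: together, the two helpers above let a caller stage a pickled estimator in two calls. A usage sketch, assuming a connection_parameters dict with valid Snowflake credentials (placeholders below) and any picklable estimator:

from sklearn.linear_model import LinearRegression
from snowflake.snowpark import Session

from snowflake.ml.modeling._internal import estimator_utils

connection_parameters = {"account": "...", "user": "...", "password": "..."}  # placeholder credentials
session = Session.builder.configs(connection_parameters).create()

estimator = LinearRegression()
stage_name = estimator_utils.create_temp_stage(session)
file_name = estimator_utils.upload_model_to_stage(
    stage_name=stage_name,
    estimator=estimator,
    session=session,
    statement_params={},  # callers normally pass telemetry statement parameters here
)
print(stage_name, file_name)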