snowflake-ml-python 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/identifier.py +78 -72
  10. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  11. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  12. snowflake/ml/dataset/dataset.py +1 -1
  13. snowflake/ml/model/_api.py +21 -14
  14. snowflake/ml/model/_client/model/model_impl.py +176 -0
  15. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  16. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  17. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  18. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  19. snowflake/ml/model/_client/sql/model.py +75 -0
  20. snowflake/ml/model/_client/sql/model_version.py +213 -0
  21. snowflake/ml/model/_client/sql/stage.py +40 -0
  22. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  23. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  24. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  25. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  26. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  27. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  28. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  30. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  31. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  32. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  33. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  34. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  35. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  36. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  37. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  38. snowflake/ml/model/model_signature.py +108 -53
  39. snowflake/ml/model/type_hints.py +1 -0
  40. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  41. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  42. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  43. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  44. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  45. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  46. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  47. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +108 -135
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +106 -135
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +106 -135
  51. snowflake/ml/modeling/cluster/birch.py +106 -135
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +106 -135
  53. snowflake/ml/modeling/cluster/dbscan.py +106 -135
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +106 -135
  55. snowflake/ml/modeling/cluster/k_means.py +105 -135
  56. snowflake/ml/modeling/cluster/mean_shift.py +106 -135
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +105 -135
  58. snowflake/ml/modeling/cluster/optics.py +106 -135
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +106 -135
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +106 -135
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +106 -135
  62. snowflake/ml/modeling/compose/column_transformer.py +106 -135
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +108 -135
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +106 -135
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +99 -128
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +106 -135
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +106 -135
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +104 -133
  69. snowflake/ml/modeling/covariance/min_cov_det.py +106 -135
  70. snowflake/ml/modeling/covariance/oas.py +99 -128
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +103 -132
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +106 -135
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +106 -135
  74. snowflake/ml/modeling/decomposition/fast_ica.py +106 -135
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +106 -135
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +106 -135
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +106 -135
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +106 -135
  79. snowflake/ml/modeling/decomposition/pca.py +106 -135
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +106 -135
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +106 -135
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +108 -135
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +108 -135
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +108 -135
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +108 -135
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +108 -135
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +108 -135
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +108 -135
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +108 -135
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +108 -135
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +108 -135
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +108 -135
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +108 -135
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +106 -135
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +108 -135
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +108 -135
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +108 -135
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +108 -135
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +108 -135
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +101 -128
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +99 -126
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +99 -126
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +99 -126
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +100 -127
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +99 -126
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +106 -135
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +95 -124
  108. snowflake/ml/modeling/framework/base.py +83 -1
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +108 -135
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +108 -135
  111. snowflake/ml/modeling/impute/iterative_imputer.py +106 -135
  112. snowflake/ml/modeling/impute/knn_imputer.py +106 -135
  113. snowflake/ml/modeling/impute/missing_indicator.py +106 -135
  114. snowflake/ml/modeling/impute/simple_imputer.py +9 -1
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +96 -125
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +106 -135
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +106 -135
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +105 -134
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +103 -132
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +108 -135
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +90 -118
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +90 -118
  123. snowflake/ml/modeling/linear_model/ard_regression.py +108 -135
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +108 -135
  125. snowflake/ml/modeling/linear_model/elastic_net.py +108 -135
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +108 -135
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +108 -135
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +108 -135
  129. snowflake/ml/modeling/linear_model/lars.py +108 -135
  130. snowflake/ml/modeling/linear_model/lars_cv.py +108 -135
  131. snowflake/ml/modeling/linear_model/lasso.py +108 -135
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +108 -135
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +108 -135
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +108 -135
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +108 -135
  136. snowflake/ml/modeling/linear_model/linear_regression.py +108 -135
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +108 -135
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +108 -135
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +108 -135
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +108 -135
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +108 -135
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +108 -135
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +108 -135
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +108 -135
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +107 -135
  146. snowflake/ml/modeling/linear_model/perceptron.py +107 -135
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +108 -135
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +108 -135
  149. snowflake/ml/modeling/linear_model/ridge.py +108 -135
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +108 -135
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +108 -135
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +108 -135
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +108 -135
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +106 -135
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +108 -135
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +108 -135
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +108 -135
  158. snowflake/ml/modeling/manifold/isomap.py +106 -135
  159. snowflake/ml/modeling/manifold/mds.py +106 -135
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +106 -135
  161. snowflake/ml/modeling/manifold/tsne.py +106 -135
  162. snowflake/ml/modeling/metrics/classification.py +196 -55
  163. snowflake/ml/modeling/metrics/correlation.py +4 -2
  164. snowflake/ml/modeling/metrics/covariance.py +7 -4
  165. snowflake/ml/modeling/metrics/ranking.py +32 -16
  166. snowflake/ml/modeling/metrics/regression.py +60 -32
  167. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +106 -135
  168. snowflake/ml/modeling/mixture/gaussian_mixture.py +106 -135
  169. snowflake/ml/modeling/model_selection/grid_search_cv.py +91 -148
  170. snowflake/ml/modeling/model_selection/randomized_search_cv.py +93 -154
  171. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +105 -132
  172. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +108 -135
  173. snowflake/ml/modeling/multiclass/output_code_classifier.py +108 -135
  174. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +108 -135
  175. snowflake/ml/modeling/naive_bayes/categorical_nb.py +108 -135
  176. snowflake/ml/modeling/naive_bayes/complement_nb.py +108 -135
  177. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +98 -125
  178. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +107 -134
  179. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +108 -135
  180. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +108 -135
  181. snowflake/ml/modeling/neighbors/kernel_density.py +106 -135
  182. snowflake/ml/modeling/neighbors/local_outlier_factor.py +106 -135
  183. snowflake/ml/modeling/neighbors/nearest_centroid.py +108 -135
  184. snowflake/ml/modeling/neighbors/nearest_neighbors.py +106 -135
  185. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +108 -135
  186. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +108 -135
  187. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +108 -135
  188. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +106 -135
  189. snowflake/ml/modeling/neural_network/mlp_classifier.py +108 -135
  190. snowflake/ml/modeling/neural_network/mlp_regressor.py +108 -135
  191. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  192. snowflake/ml/modeling/preprocessing/binarizer.py +25 -8
  193. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +9 -4
  194. snowflake/ml/modeling/preprocessing/label_encoder.py +31 -11
  195. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +27 -9
  196. snowflake/ml/modeling/preprocessing/min_max_scaler.py +42 -14
  197. snowflake/ml/modeling/preprocessing/normalizer.py +9 -4
  198. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +26 -10
  199. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +37 -13
  200. snowflake/ml/modeling/preprocessing/polynomial_features.py +106 -135
  201. snowflake/ml/modeling/preprocessing/robust_scaler.py +39 -13
  202. snowflake/ml/modeling/preprocessing/standard_scaler.py +36 -12
  203. snowflake/ml/modeling/semi_supervised/label_propagation.py +108 -135
  204. snowflake/ml/modeling/semi_supervised/label_spreading.py +108 -135
  205. snowflake/ml/modeling/svm/linear_svc.py +108 -135
  206. snowflake/ml/modeling/svm/linear_svr.py +108 -135
  207. snowflake/ml/modeling/svm/nu_svc.py +108 -135
  208. snowflake/ml/modeling/svm/nu_svr.py +108 -135
  209. snowflake/ml/modeling/svm/svc.py +108 -135
  210. snowflake/ml/modeling/svm/svr.py +108 -135
  211. snowflake/ml/modeling/tree/decision_tree_classifier.py +108 -135
  212. snowflake/ml/modeling/tree/decision_tree_regressor.py +108 -135
  213. snowflake/ml/modeling/tree/extra_tree_classifier.py +108 -135
  214. snowflake/ml/modeling/tree/extra_tree_regressor.py +108 -135
  215. snowflake/ml/modeling/xgboost/xgb_classifier.py +108 -136
  216. snowflake/ml/modeling/xgboost/xgb_regressor.py +108 -136
  217. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +108 -136
  218. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +108 -136
  219. snowflake/ml/registry/model_registry.py +2 -0
  220. snowflake/ml/registry/registry.py +215 -0
  221. snowflake/ml/version.py +1 -1
  222. {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +34 -1
  223. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  224. snowflake_ml_python-1.1.0.dist-info/RECORD +0 -331
  225. {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -23,7 +23,7 @@ quote_name_without_upper_casing = analyzer_utils.quote_name_without_upper_casing
23
23
 
24
24
 
25
25
  def _is_quoted(id: str) -> bool:
26
- """Checks if input is quoted.
26
+ """Checks if input *identifier* is quoted.
27
27
 
28
28
  NOTE: Snowflake treats all identifiers as UPPERCASE by default. That is 'Hello' would become 'HELLO'. To preserve
29
29
  case, one needs to use quoted identifiers, e.g. "Hello" (note the double quote). Callers must take care of that
@@ -40,23 +40,21 @@ def _is_quoted(id: str) -> bool:
40
40
  ValueError: If the id is invalid.
41
41
  """
42
42
  if not id:
43
- raise ValueError("Invalid id passed.")
44
- if len(id) < 2:
45
- return False
46
- if id[0] == '"' and id[-1] == '"':
43
+ raise ValueError(f"Invalid id {id} passed. ID is empty.")
44
+ if len(id) >= 2 and id[0] == '"' and id[-1] == '"':
47
45
  if len(id) == 2:
48
- raise ValueError("Invalid id passed.")
46
+ raise ValueError(f"Invalid id {id} passed. ID is empty.")
49
47
  if not QUOTED_IDENTIFIER_RE.match(id):
50
- raise ValueError("Invalid id passed.")
48
+ raise ValueError(f"Invalid id {id} passed. ID is quoted but does not match the quoted rule.")
51
49
  return True
52
- if not UNQUOTED_CASE_INSENSITIVE_RE.match(id):
53
- raise ValueError("Invalid id passed.")
54
- return False # To keep mypy happy
50
+ if not UNQUOTED_CASE_SENSITIVE_RE.match(id):
51
+ raise ValueError(f"Invalid id {id} passed. ID is unquoted but does not match the unquoted rule.")
52
+ return False
55
53
 
56
54
 
57
55
  def _get_unescaped_name(id: str) -> str:
58
56
  """Remove double quotes and unescape quotes between them from id if quoted.
59
- Uppercase if not quoted.
57
+ Return as it is otherwise
60
58
 
61
59
  NOTE: See note in :meth:`_is_quoted`.
62
60
 
@@ -67,7 +65,7 @@ def _get_unescaped_name(id: str) -> str:
67
65
  String with quotes removed if quoted; original string otherwise.
68
66
  """
69
67
  if not _is_quoted(id):
70
- return id.upper()
68
+ return id
71
69
  unquoted_id = id[1:-1]
72
70
  return unquoted_id.replace(DOUBLE_QUOTE + DOUBLE_QUOTE, DOUBLE_QUOTE)
73
71
 
@@ -88,9 +86,9 @@ def _get_escaped_name(id: str) -> str:
88
86
  return DOUBLE_QUOTE + escape_quotes + DOUBLE_QUOTE
89
87
 
90
88
 
91
- def get_inferred_name(id: str) -> str:
92
- """Double quote id when it is case-sensitive and can start with and
93
- contain any valid characters; unquote otherwise.
89
+ def get_inferred_name(name: str) -> str:
90
+ """Double quote name when it is case-sensitive and can start with and
91
+ contain any valid characters; otherwise, keep it as it is.
94
92
 
95
93
  Examples:
96
94
  COL1 -> COL1
@@ -100,42 +98,38 @@ def get_inferred_name(id: str) -> str:
100
98
  COL 1 -> "COL 1"
101
99
 
102
100
  Args:
103
- id: The string to be checked & treated.
101
+ name: The string to be checked & treated.
104
102
 
105
103
  Returns:
106
104
  Double quoted identifier if necessary; unquoted string otherwise.
107
105
  """
108
- if UNQUOTED_CASE_SENSITIVE_RE.match(id):
109
- return id
110
- escaped_id = get_escaped_names(id)
106
+ if UNQUOTED_CASE_SENSITIVE_RE.match(name):
107
+ return name
108
+ escaped_id = _get_escaped_name(name)
111
109
  assert isinstance(escaped_id, str)
112
110
  return escaped_id
113
111
 
114
112
 
115
- def concat_names(ids: List[str]) -> str:
116
- """Concatenates `ids` to form one valid id.
113
+ def concat_names(names: List[str]) -> str:
114
+ """Concatenates `names` to form one valid id.
117
115
 
118
- NOTE: See note in :meth:`_is_quoted`.
119
116
 
120
117
  Args:
121
- ids: List of identifiers to be concatenated.
118
+ names: List of identifiers to be concatenated.
122
119
 
123
120
  Returns:
124
121
  Concatenated identifier.
125
122
  """
126
- quotes_needed = False
127
123
  parts = []
128
- for id in ids:
129
- if _is_quoted(id):
130
- # If any part is quoted, the user cares about case.
131
- quotes_needed = True
132
- # Remove quotes before using it.
133
- id = _get_unescaped_name(id)
134
- parts.append(id)
124
+ for name in names:
125
+ if QUOTED_IDENTIFIER_RE.match(name):
126
+ # If any part is quoted identifier, we need to remove the quotes
127
+ unescaped_name: str = _get_unescaped_name(name)
128
+ parts.append(unescaped_name)
129
+ else:
130
+ parts.append(name)
135
131
  final_id = "".join(parts)
136
- if quotes_needed:
137
- return _get_escaped_name(final_id)
138
- return final_id
132
+ return get_inferred_name(final_id)
139
133
 
140
134
 
141
135
  def rename_to_valid_snowflake_identifier(name: str) -> str:
@@ -222,6 +216,14 @@ def get_unescaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[
222
216
  response pandas dataframe(i.e., in the response of snowpark_df.to_pandas()) using the rules defined here
223
217
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
224
218
 
219
+ This function will mimic the behavior of Snowpark's `to_pandas()` from Snowpark DataFrame.
220
+
221
+ Examples:
222
+ COL1 -> COL1
223
+ "Col" -> Col
224
+ \"""COL""\" -> "COL" (ignore '\')
225
+ "COL 1" -> COL 1
226
+
225
227
  Args:
226
228
  ids: User provided column name identifier(s).
227
229
 
@@ -243,27 +245,36 @@ def get_unescaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[
243
245
 
244
246
 
245
247
  @overload
246
- def get_escaped_names(ids: None) -> None:
248
+ def get_inferred_names(names: None) -> None:
247
249
  ...
248
250
 
249
251
 
250
252
  @overload
251
- def get_escaped_names(ids: str) -> str:
253
+ def get_inferred_names(names: str) -> str:
252
254
  ...
253
255
 
254
256
 
255
257
  @overload
256
- def get_escaped_names(ids: List[str]) -> List[str]:
258
+ def get_inferred_names(names: List[str]) -> List[str]:
257
259
  ...
258
260
 
259
261
 
260
- def get_escaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
261
- """Given a user provided identifier(s), this method will compute the equivalent column name identifier(s)
262
+ def get_inferred_names(names: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
263
+ """Given a user provided *string(s)*, this method will compute the equivalent column name identifier(s)
262
264
  in case of column name contains special characters, and maintains case-sensitivity
263
265
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
264
266
 
267
+ This function will mimic the behavior of Snowpark's `create_dataframe` from pandas DataFrame.
268
+
269
+ Examples:
270
+ COL1 -> COL1
271
+ 1COL -> "1COL"
272
+ Col -> "Col"
273
+ "COL" -> \"""COL""\" (ignore '\')
274
+ COL 1 -> "COL 1"
275
+
265
276
  Args:
266
- ids: User provided column name identifier(s).
277
+ names: User provided column name identifier(s).
267
278
 
268
279
  Returns:
269
280
  Double-quoted Identifiers for column names, to make sure that column names are case sensitive
@@ -272,12 +283,12 @@ def get_escaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[st
272
283
  ValueError: if input types is unsupported or column name identifiers are invalid.
273
284
  """
274
285
 
275
- if ids is None:
286
+ if names is None:
276
287
  return None
277
- elif type(ids) is list:
278
- return [_get_escaped_name(id) for id in ids]
279
- elif type(ids) is str:
280
- return _get_escaped_name(ids)
288
+ elif type(names) is list:
289
+ return [get_inferred_name(id) for id in names]
290
+ elif type(names) is str:
291
+ return get_inferred_name(names)
281
292
  else:
282
293
  raise ValueError("Unsupported type. Only string or list of string are supported for selecting columns.")
283
294
 
@@ -297,39 +308,34 @@ def remove_prefix(s: str, prefix: str) -> str:
297
308
  return s
298
309
 
299
310
 
300
- def resolve_identifier(id: str) -> str:
301
- """Following Snowflake identifier resolution strategies:
311
+ def resolve_identifier(name: str) -> str:
312
+ """Given a user provided *string*, resolve following Snowflake identifier resolution strategies:
302
313
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax#label-identifier-casing
303
314
 
304
- If identifier is unquoted, it will return upper case.
305
- Otherwise return exactly as it is.
306
-
307
- Args:
308
- id: identifier string
309
-
310
- Returns:
311
- Resolved identifier
312
- """
313
- if _is_quoted(id):
314
- if UNQUOTED_CASE_SENSITIVE_RE.match(id[1:-1]):
315
- return id[1:-1]
316
- else:
317
- return id
318
- else:
319
- return id.upper()
320
-
315
+ This function will mimic the behavior of the SQL parser.
321
316
 
322
- def strip_wrapping_quotes(id: str) -> str:
323
- """Remove wrapping quotes if the identifier is quoted.
324
- This is mainly used for keywords like `warehouse` which doesn't like wrapping quotes when being used.
317
+ Examples:
318
+ COL1 -> COL1
319
+ 1COL -> Raise Error
320
+ Col -> COL
321
+ "COL" -> COL
322
+ COL 1 -> Raise Error
325
323
 
326
324
  Args:
327
- id: identifier string
325
+ name: the string to be resolved.
326
+
327
+ Raises:
328
+ ValueError: if input would not be accepted by SQL parser.
328
329
 
329
330
  Returns:
330
- Identifier with wrapping quotes removed
331
+ Resolved identifier
331
332
  """
332
- if _is_quoted(id):
333
- return id[1:-1]
333
+ if QUOTED_IDENTIFIER_RE.match(name):
334
+ unescaped = _get_unescaped_name(name)
335
+ if UNQUOTED_CASE_SENSITIVE_RE.match(unescaped):
336
+ return unescaped
337
+ return name
338
+ elif UNQUOTED_CASE_INSENSITIVE_RE.match(name):
339
+ return name.upper()
334
340
  else:
335
- return id
341
+ raise ValueError(f"Invalid name {name} passed. ID is not quoted and cannot normalized.")
@@ -5,11 +5,23 @@ from requests import adapters
5
5
  from urllib3.util import retry
6
6
 
7
7
 
8
- def get_http_client() -> requests.Session:
9
- # Set up a retry policy for requests
8
+ def get_http_client(total_retries: int = 5, backoff_factor: float = 0.1) -> requests.Session:
9
+ """Construct retryable http client.
10
+
11
+ Args:
12
+ total_retries: Total number of retries to allow.
13
+ backoff_factor: A backoff factor to apply between attempts after the second try. Time to sleep is calculated by
14
+ {backoff factor} * (2 ** ({number of previous retries})). For example, with default retries of 5 and backoff
15
+ factor set to 0.1, each subsequent retry will sleep [0.2s, 0.4s, 0.8s, 1.6s, 3.2s] respectively.
16
+
17
+ Returns:
18
+ requests.Session object.
19
+
20
+ """
21
+
10
22
  retry_strategy = retry.Retry(
11
- total=3, # total number of retries
12
- backoff_factor=0.1, # 100ms initial delay
23
+ total=total_retries,
24
+ backoff_factor=backoff_factor,
13
25
  status_forcelist=[
14
26
  http.HTTPStatus.TOO_MANY_REQUESTS,
15
27
  http.HTTPStatus.INTERNAL_SERVER_ERROR,
@@ -0,0 +1,122 @@
1
+ import logging
2
+ from datetime import datetime
3
+ from typing import Any, Dict, Optional
4
+
5
+ from snowflake import snowpark
6
+ from snowflake.ml._internal import telemetry
7
+ from snowflake.ml._internal.utils import query_result_checker
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ _DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f %z"
12
+ _COMPUTE_POOL = "compute_pool"
13
+ _CREATED_ON = "created_on"
14
+ _INSTANCE_FAMILY = "instance_family"
15
+ _NAME = "name"
16
+ _TELEMETRY_PROJECT = "MLOps"
17
+ _TELEMETRY_SUBPROJECT = "SpcsDeployment"
18
+ _SERVICE_START = "SPCS_SERVICE_START"
19
+ _SERVICE_END = "SPCS_SERVICE_END"
20
+
21
+
22
+ def _desc_compute_pool(session: snowpark.Session, compute_pool_name: str) -> Dict[str, Any]:
23
+ sql = f"DESC COMPUTE POOL {compute_pool_name}"
24
+ result = (
25
+ query_result_checker.SqlResultValidator(
26
+ session=session,
27
+ query=sql,
28
+ )
29
+ .has_column(_INSTANCE_FAMILY)
30
+ .has_column(_NAME)
31
+ .has_dimensions(expected_rows=1)
32
+ .validate()
33
+ )
34
+ return result[0].as_dict()
35
+
36
+
37
+ def _desc_service(session: snowpark.Session, fully_qualified_name: str) -> Dict[str, Any]:
38
+ sql = f"DESC SERVICE {fully_qualified_name}"
39
+ result = (
40
+ query_result_checker.SqlResultValidator(
41
+ session=session,
42
+ query=sql,
43
+ )
44
+ .has_column(_COMPUTE_POOL)
45
+ .has_dimensions(expected_rows=1)
46
+ .validate()
47
+ )
48
+ return result[0].as_dict()
49
+
50
+
51
+ def _get_current_time() -> datetime:
52
+ """
53
+ This method exists to make it easier to mock datetime in test.
54
+
55
+ Returns:
56
+ current datetime
57
+ """
58
+ return datetime.now()
59
+
60
+
61
+ def _send_service_telemetry(
62
+ fully_qualified_name: Optional[str] = None,
63
+ compute_pool_name: Optional[str] = None,
64
+ service_details: Optional[Dict[str, Any]] = None,
65
+ compute_pool_details: Optional[Dict[str, Any]] = None,
66
+ duration_in_seconds: Optional[int] = None,
67
+ kwargs: Optional[Dict[str, Any]] = None,
68
+ ) -> None:
69
+ try:
70
+ telemetry.send_custom_usage(
71
+ project=_TELEMETRY_PROJECT,
72
+ subproject=_TELEMETRY_SUBPROJECT,
73
+ telemetry_type=telemetry.TelemetryField.TYPE_SNOWML_SPCS_USAGE.value,
74
+ data={
75
+ "service_name": fully_qualified_name,
76
+ "compute_pool_name": compute_pool_name,
77
+ "service_details": service_details,
78
+ "compute_pool_details": compute_pool_details,
79
+ "duration_in_seconds": duration_in_seconds,
80
+ },
81
+ kwargs=kwargs,
82
+ )
83
+ except Exception as e:
84
+ logger.error(f"Failed to send service telemetry: {e}")
85
+
86
+
87
+ def record_service_start(session: snowpark.Session, fully_qualified_name: str) -> None:
88
+ service_details = _desc_service(session, fully_qualified_name)
89
+ compute_pool_name = service_details[_COMPUTE_POOL]
90
+ compute_pool_details = _desc_compute_pool(session, compute_pool_name)
91
+
92
+ _send_service_telemetry(
93
+ fully_qualified_name=fully_qualified_name,
94
+ compute_pool_name=compute_pool_name,
95
+ service_details=service_details,
96
+ compute_pool_details=compute_pool_details,
97
+ kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: _SERVICE_START},
98
+ )
99
+
100
+ logger.info(f"Service {fully_qualified_name} created with compute pool {compute_pool_name}.")
101
+
102
+
103
+ def record_service_end(session: snowpark.Session, fully_qualified_name: str) -> None:
104
+ service_details = _desc_service(session, fully_qualified_name)
105
+ compute_pool_details = _desc_compute_pool(session, service_details[_COMPUTE_POOL])
106
+ compute_pool_name = service_details[_COMPUTE_POOL]
107
+
108
+ created_on_datetime: datetime = service_details[_CREATED_ON]
109
+ current_time: datetime = _get_current_time()
110
+ current_time = current_time.replace(tzinfo=created_on_datetime.tzinfo)
111
+ duration_in_seconds = int((current_time - created_on_datetime).total_seconds())
112
+
113
+ _send_service_telemetry(
114
+ fully_qualified_name=fully_qualified_name,
115
+ compute_pool_name=compute_pool_name,
116
+ service_details=service_details,
117
+ compute_pool_details=compute_pool_details,
118
+ duration_in_seconds=duration_in_seconds,
119
+ kwargs={telemetry.TelemetryField.KEY_CUSTOM_TAGS.value: _SERVICE_END},
120
+ )
121
+
122
+ logger.info(f"Service {fully_qualified_name} deleted from compute pool {compute_pool_name}")
@@ -140,7 +140,7 @@ Got {len(self.df.queries['queries'])}: {self.df.queries['queries']}
140
140
 
141
141
  @classmethod
142
142
  def from_json(cls, json_str: str, session: Session) -> "Dataset":
143
- json_dict = json.loads(json_str)
143
+ json_dict = json.loads(json_str, strict=False)
144
144
  json_dict["df"] = session.sql(json_dict.pop("df_query"))
145
145
 
146
146
  fs_meta_json = json_dict["feature_store_metadata"]
@@ -7,7 +7,6 @@ from snowflake.ml._internal.exceptions import (
7
7
  error_codes,
8
8
  exceptions as snowml_exceptions,
9
9
  )
10
- from snowflake.ml._internal.utils import identifier
11
10
  from snowflake.ml.model import (
12
11
  deploy_platforms,
13
12
  model_signature,
@@ -188,6 +187,10 @@ def save_model(
188
187
  Returns:
189
188
  Model
190
189
  """
190
+ if options is None:
191
+ options = {}
192
+ options["_legacy_save"] = True
193
+
191
194
  m = model_composer.ModelComposer(session=session, stage_path=stage_path)
192
195
  m.save(
193
196
  name=name,
@@ -481,6 +484,7 @@ def predict(
481
484
  # Get options
482
485
  INTERMEDIATE_OBJ_NAME = "tmp_result"
483
486
  sig = deployment["signature"]
487
+ identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
484
488
 
485
489
  # Validate and prepare input
486
490
  if not isinstance(X, SnowparkDataFrame):
@@ -491,7 +495,7 @@ def predict(
491
495
  else:
492
496
  keep_order = False
493
497
  output_with_input_features = True
494
- model_signature._validate_snowpark_data(X, sig.inputs)
498
+ identifier_rule = model_signature._validate_snowpark_data(X, sig.inputs)
495
499
  s_df = X
496
500
 
497
501
  if statement_params:
@@ -500,10 +504,14 @@ def predict(
500
504
  else:
501
505
  s_df._statement_params = statement_params # type: ignore[assignment]
502
506
 
507
+ original_cols = s_df.columns
508
+
503
509
  # Infer and get intermediate result
504
510
  input_cols = []
505
- for col_name in s_df.columns:
506
- literal_col_name = identifier.get_unescaped_names(col_name)
511
+ for input_feature in sig.inputs:
512
+ literal_col_name = input_feature.name
513
+ col_name = identifier_rule.get_identifier_from_feature(input_feature.name)
514
+
507
515
  input_cols.extend(
508
516
  [
509
517
  F.lit(literal_col_name),
@@ -511,29 +519,28 @@ def predict(
511
519
  ]
512
520
  )
513
521
 
514
- # TODO[shchen]: SNOW-870032, For SnowService, external function name cannot be double quoted, else it results in
515
- # external function no found.
516
522
  udf_name = deployment["name"]
517
- output_obj = F.call_udf(udf_name, F.object_construct(*input_cols))
518
-
519
- if output_with_input_features:
520
- df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj)
521
- else:
522
- df_res = s_df.select(output_obj.alias(INTERMEDIATE_OBJ_NAME))
523
+ output_obj = F.call_udf(udf_name, F.object_construct_keep_null(*input_cols))
524
+ df_res = s_df.with_column(INTERMEDIATE_OBJ_NAME, output_obj)
523
525
 
524
526
  if keep_order:
525
527
  df_res = df_res.order_by(
526
- F.col(INTERMEDIATE_OBJ_NAME)[infer_template._KEEP_ORDER_COL_NAME],
528
+ F.col(infer_template._KEEP_ORDER_COL_NAME),
527
529
  ascending=True,
528
530
  )
529
531
 
532
+ if not output_with_input_features:
533
+ df_res = df_res.drop(*original_cols)
534
+
530
535
  # Prepare the output
531
536
  output_cols = []
537
+ output_col_names = []
532
538
  for output_feature in sig.outputs:
533
539
  output_cols.append(F.col(INTERMEDIATE_OBJ_NAME)[output_feature.name].astype(output_feature.as_snowpark_type()))
540
+ output_col_names.append(identifier_rule.get_identifier_from_feature(output_feature.name))
534
541
 
535
542
  df_res = df_res.with_columns(
536
- [identifier.get_inferred_name(output_feature.name) for output_feature in sig.outputs],
543
+ output_col_names,
537
544
  output_cols,
538
545
  ).drop(INTERMEDIATE_OBJ_NAME)
539
546
 
@@ -0,0 +1,176 @@
1
+ from typing import List, Union
2
+
3
+ from snowflake.ml._internal import telemetry
4
+ from snowflake.ml._internal.utils import sql_identifier
5
+ from snowflake.ml.model._client.model import model_version_impl
6
+ from snowflake.ml.model._client.ops import model_ops
7
+
8
+ _TELEMETRY_PROJECT = "MLOps"
9
+ _TELEMETRY_SUBPROJECT = "ModelManagement"
10
+
11
+
12
+ class Model:
13
+ """Model Object containing multiple versions. Mapping to SQL's MODEL object."""
14
+
15
+ _model_ops: model_ops.ModelOperator
16
+ _model_name: sql_identifier.SqlIdentifier
17
+
18
+ def __init__(self) -> None:
19
+ raise RuntimeError("Model's initializer is not meant to be used. Use `get_model` from registry instead.")
20
+
21
+ @classmethod
22
+ def _ref(
23
+ cls,
24
+ model_ops: model_ops.ModelOperator,
25
+ *,
26
+ model_name: sql_identifier.SqlIdentifier,
27
+ ) -> "Model":
28
+ self: "Model" = object.__new__(cls)
29
+ self._model_ops = model_ops
30
+ self._model_name = model_name
31
+ return self
32
+
33
+ def __eq__(self, __value: object) -> bool:
34
+ if not isinstance(__value, Model):
35
+ return False
36
+ return self._model_ops == __value._model_ops and self._model_name == __value._model_name
37
+
38
+ @property
39
+ def name(self) -> str:
40
+ return self._model_name.identifier()
41
+
42
+ @property
43
+ def fully_qualified_name(self) -> str:
44
+ return self._model_ops._model_version_client.fully_qualified_model_name(self._model_name)
45
+
46
+ @property
47
+ @telemetry.send_api_usage_telemetry(
48
+ project=_TELEMETRY_PROJECT,
49
+ subproject=_TELEMETRY_SUBPROJECT,
50
+ )
51
+ def description(self) -> str:
52
+ statement_params = telemetry.get_statement_params(
53
+ project=_TELEMETRY_PROJECT,
54
+ subproject=_TELEMETRY_SUBPROJECT,
55
+ )
56
+ return self._model_ops.get_comment(
57
+ model_name=self._model_name,
58
+ statement_params=statement_params,
59
+ )
60
+
61
+ @description.setter
62
+ @telemetry.send_api_usage_telemetry(
63
+ project=_TELEMETRY_PROJECT,
64
+ subproject=_TELEMETRY_SUBPROJECT,
65
+ )
66
+ def description(self, description: str) -> None:
67
+ statement_params = telemetry.get_statement_params(
68
+ project=_TELEMETRY_PROJECT,
69
+ subproject=_TELEMETRY_SUBPROJECT,
70
+ )
71
+ return self._model_ops.set_comment(
72
+ comment=description,
73
+ model_name=self._model_name,
74
+ statement_params=statement_params,
75
+ )
76
+
77
+ @property
78
+ @telemetry.send_api_usage_telemetry(
79
+ project=_TELEMETRY_PROJECT,
80
+ subproject=_TELEMETRY_SUBPROJECT,
81
+ )
82
+ def default(self) -> model_version_impl.ModelVersion:
83
+ statement_params = telemetry.get_statement_params(
84
+ project=_TELEMETRY_PROJECT,
85
+ subproject=_TELEMETRY_SUBPROJECT,
86
+ class_name=self.__class__.__name__,
87
+ )
88
+ default_version_name = self._model_ops._model_version_client.get_default_version(
89
+ model_name=self._model_name, statement_params=statement_params
90
+ )
91
+ return self.version(default_version_name)
92
+
93
+ @default.setter
94
+ @telemetry.send_api_usage_telemetry(
95
+ project=_TELEMETRY_PROJECT,
96
+ subproject=_TELEMETRY_SUBPROJECT,
97
+ )
98
+ def default(self, version: Union[str, model_version_impl.ModelVersion]) -> None:
99
+ statement_params = telemetry.get_statement_params(
100
+ project=_TELEMETRY_PROJECT,
101
+ subproject=_TELEMETRY_SUBPROJECT,
102
+ class_name=self.__class__.__name__,
103
+ )
104
+ if isinstance(version, str):
105
+ version_name = sql_identifier.SqlIdentifier(version)
106
+ else:
107
+ version_name = version._version_name
108
+ self._model_ops._model_version_client.set_default_version(
109
+ model_name=self._model_name, version_name=version_name, statement_params=statement_params
110
+ )
111
+
112
+ @telemetry.send_api_usage_telemetry(
113
+ project=_TELEMETRY_PROJECT,
114
+ subproject=_TELEMETRY_SUBPROJECT,
115
+ )
116
+ def version(self, version_name: str) -> model_version_impl.ModelVersion:
117
+ """Get a model version object given a version name in the model.
118
+
119
+ Args:
120
+ version_name: The name of version
121
+
122
+ Raises:
123
+ ValueError: Raised when the version requested does not exist.
124
+
125
+ Returns:
126
+ The model version object.
127
+ """
128
+ statement_params = telemetry.get_statement_params(
129
+ project=_TELEMETRY_PROJECT,
130
+ subproject=_TELEMETRY_SUBPROJECT,
131
+ )
132
+ version_id = sql_identifier.SqlIdentifier(version_name)
133
+ if self._model_ops.validate_existence(
134
+ model_name=self._model_name,
135
+ version_name=version_id,
136
+ statement_params=statement_params,
137
+ ):
138
+ return model_version_impl.ModelVersion._ref(
139
+ self._model_ops,
140
+ model_name=self._model_name,
141
+ version_name=version_id,
142
+ )
143
+ else:
144
+ raise ValueError(
145
+ f"Unable to find version with name {version_id.identifier()} in model {self.fully_qualified_name}"
146
+ )
147
+
148
+ @telemetry.send_api_usage_telemetry(
149
+ project=_TELEMETRY_PROJECT,
150
+ subproject=_TELEMETRY_SUBPROJECT,
151
+ )
152
+ def list_versions(self) -> List[model_version_impl.ModelVersion]:
153
+ """List all versions in the model.
154
+
155
+ Returns:
156
+ A List of ModelVersion object representing all versions in the model.
157
+ """
158
+ statement_params = telemetry.get_statement_params(
159
+ project=_TELEMETRY_PROJECT,
160
+ subproject=_TELEMETRY_SUBPROJECT,
161
+ )
162
+ version_names = self._model_ops.list_models_or_versions(
163
+ model_name=self._model_name,
164
+ statement_params=statement_params,
165
+ )
166
+ return [
167
+ model_version_impl.ModelVersion._ref(
168
+ self._model_ops,
169
+ model_name=self._model_name,
170
+ version_name=version_name,
171
+ )
172
+ for version_name in version_names
173
+ ]
174
+
175
+ def delete_version(self, version_name: str) -> None:
176
+ raise NotImplementedError("Deleting version has not been supported yet.")