snowflake-ml-python 1.5.1__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (207)
  1. snowflake/cortex/_complete.py +26 -5
  2. snowflake/cortex/_sentiment.py +7 -4
  3. snowflake/cortex/_sse_client.py +81 -0
  4. snowflake/cortex/_util.py +105 -8
  5. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  6. snowflake/ml/_internal/utils/temp_file_utils.py +5 -2
  7. snowflake/ml/dataset/dataset.py +15 -12
  8. snowflake/ml/dataset/dataset_factory.py +3 -4
  9. snowflake/ml/feature_store/access_manager.py +34 -30
  10. snowflake/ml/feature_store/feature_store.py +3 -3
  11. snowflake/ml/feature_store/feature_view.py +12 -11
  12. snowflake/ml/fileset/snowfs.py +2 -31
  13. snowflake/ml/model/_client/ops/model_ops.py +43 -0
  14. snowflake/ml/model/_client/sql/model_version.py +55 -3
  15. snowflake/ml/model/_model_composer/model_composer.py +7 -3
  16. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
  17. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  18. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -3
  19. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  20. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -27
  21. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  22. snowflake/ml/model/_signatures/core.py +13 -1
  23. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  24. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  25. snowflake/ml/model/model_signature.py +2 -0
  26. snowflake/ml/model/type_hints.py +1 -0
  27. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  28. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +196 -242
  29. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +161 -0
  30. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
  31. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
  32. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
  33. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -2
  34. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -2
  35. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -2
  36. snowflake/ml/modeling/cluster/birch.py +9 -2
  37. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -2
  38. snowflake/ml/modeling/cluster/dbscan.py +9 -2
  39. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -2
  40. snowflake/ml/modeling/cluster/k_means.py +9 -2
  41. snowflake/ml/modeling/cluster/mean_shift.py +9 -2
  42. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -2
  43. snowflake/ml/modeling/cluster/optics.py +9 -2
  44. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -2
  45. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -2
  46. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -2
  47. snowflake/ml/modeling/compose/column_transformer.py +9 -2
  48. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -2
  49. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -2
  50. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -2
  51. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -2
  52. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -2
  53. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -2
  54. snowflake/ml/modeling/covariance/min_cov_det.py +9 -2
  55. snowflake/ml/modeling/covariance/oas.py +9 -2
  56. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -2
  57. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -2
  58. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -2
  59. snowflake/ml/modeling/decomposition/fast_ica.py +9 -2
  60. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -2
  61. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -2
  62. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -2
  63. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -2
  64. snowflake/ml/modeling/decomposition/pca.py +9 -2
  65. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -2
  66. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -2
  67. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -2
  68. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -2
  69. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -2
  70. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -2
  71. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -2
  72. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -2
  73. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -2
  74. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -2
  75. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -2
  76. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -2
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -2
  78. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -2
  79. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -2
  80. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -2
  81. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -2
  82. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -2
  83. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -2
  84. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -2
  85. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -2
  86. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -2
  87. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -2
  88. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -2
  89. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -2
  90. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -2
  91. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -2
  92. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -2
  93. snowflake/ml/modeling/framework/base.py +3 -8
  94. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -2
  95. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -2
  96. snowflake/ml/modeling/impute/iterative_imputer.py +9 -2
  97. snowflake/ml/modeling/impute/knn_imputer.py +9 -2
  98. snowflake/ml/modeling/impute/missing_indicator.py +9 -2
  99. snowflake/ml/modeling/impute/simple_imputer.py +28 -5
  100. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -2
  101. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -2
  102. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -2
  103. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -2
  104. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -2
  105. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -2
  106. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -2
  107. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -2
  108. snowflake/ml/modeling/linear_model/ard_regression.py +9 -2
  109. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -2
  110. snowflake/ml/modeling/linear_model/elastic_net.py +9 -2
  111. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -2
  112. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -2
  113. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -2
  114. snowflake/ml/modeling/linear_model/lars.py +9 -2
  115. snowflake/ml/modeling/linear_model/lars_cv.py +9 -2
  116. snowflake/ml/modeling/linear_model/lasso.py +9 -2
  117. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -2
  118. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -2
  119. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -2
  120. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -2
  121. snowflake/ml/modeling/linear_model/linear_regression.py +9 -2
  122. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -2
  123. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -2
  124. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -2
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -2
  126. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -2
  127. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -2
  128. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -2
  129. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -2
  130. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -2
  131. snowflake/ml/modeling/linear_model/perceptron.py +9 -2
  132. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -2
  133. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -2
  134. snowflake/ml/modeling/linear_model/ridge.py +9 -2
  135. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -2
  136. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -2
  137. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -2
  138. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -2
  139. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -2
  140. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -2
  141. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -2
  142. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -2
  143. snowflake/ml/modeling/manifold/isomap.py +9 -2
  144. snowflake/ml/modeling/manifold/mds.py +9 -2
  145. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -2
  146. snowflake/ml/modeling/manifold/tsne.py +9 -2
  147. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -2
  148. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -2
  149. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  150. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  151. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -2
  152. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -2
  153. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -2
  154. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -2
  155. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -2
  156. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -2
  157. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -2
  158. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -2
  159. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -2
  160. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -2
  161. snowflake/ml/modeling/neighbors/kernel_density.py +9 -2
  162. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -2
  163. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -2
  164. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -2
  165. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -2
  166. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -2
  167. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -2
  168. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -2
  169. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -2
  170. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -2
  171. snowflake/ml/modeling/parameters/enable_anonymous_sproc.py +5 -0
  172. snowflake/ml/modeling/pipeline/pipeline.py +5 -0
  173. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  174. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  175. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  176. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  177. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  178. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  179. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
  180. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
  181. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -2
  182. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  183. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  184. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -2
  185. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -2
  186. snowflake/ml/modeling/svm/linear_svc.py +9 -2
  187. snowflake/ml/modeling/svm/linear_svr.py +9 -2
  188. snowflake/ml/modeling/svm/nu_svc.py +9 -2
  189. snowflake/ml/modeling/svm/nu_svr.py +9 -2
  190. snowflake/ml/modeling/svm/svc.py +9 -2
  191. snowflake/ml/modeling/svm/svr.py +9 -2
  192. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -2
  193. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -2
  194. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -2
  195. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -2
  196. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -2
  197. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -2
  198. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -2
  199. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -2
  200. snowflake/ml/registry/_manager/model_manager.py +59 -1
  201. snowflake/ml/registry/registry.py +10 -1
  202. snowflake/ml/version.py +1 -1
  203. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +32 -4
  204. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +207 -204
  205. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
  206. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
  207. {snowflake_ml_python-1.5.1.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py CHANGED
@@ -1,7 +1,12 @@
- from typing import Optional, Union
+ from typing import Iterator, Optional, Union
 
  from snowflake import snowpark
- from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
+ from snowflake.cortex._util import (
+     CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
+     call_rest_function,
+     call_sql_function,
+     process_rest_response,
+ )
  from snowflake.ml._internal import telemetry
 
 
@@ -10,19 +15,35 @@ from snowflake.ml._internal import telemetry
      project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
  )
  def Complete(
-     model: Union[str, snowpark.Column], prompt: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None
- ) -> Union[str, snowpark.Column]:
+     model: Union[str, snowpark.Column],
+     prompt: Union[str, snowpark.Column],
+     session: Optional[snowpark.Session] = None,
+     use_rest_api_experimental: bool = False,
+     stream: bool = False,
+ ) -> Union[str, Iterator[str], snowpark.Column]:
      """Complete calls into the LLM inference service to perform completion.
 
      Args:
          model: A Column of strings representing model types.
          prompt: A Column of prompts to send to the LLM.
          session: The snowpark session to use. Will be inferred by context if not specified.
+         use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
+             experimental and can be removed at any time.
+         stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
+             output as it is received. Each update is a string containing the new text content since the previous update.
+             The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
+
+     Raises:
+         ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
 
      Returns:
          A column of string responses.
      """
-
+     if stream is True and use_rest_api_experimental is False:
+         raise ValueError("If stream is set to True use_rest_api_experimental must also be set to True")
+     if use_rest_api_experimental:
+         response = call_rest_function("complete", model, prompt, session=session, stream=stream)
+         return process_rest_response(response)
      return _complete_impl("snowflake.cortex.complete", model, prompt, session=session)
snowflake/cortex/_sentiment.py CHANGED
@@ -11,7 +11,7 @@ from snowflake.ml._internal import telemetry
  )
  def Sentiment(
      text: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None
- ) -> Union[str, snowpark.Column]:
+ ) -> Union[float, snowpark.Column]:
      """Sentiment calls into the LLM inference service to perform sentiment analysis on the input text.
 
      Args:
@@ -21,11 +21,14 @@ def Sentiment(
      Returns:
          A column of floats. 1 represents positive sentiment, -1 represents negative sentiment.
      """
-
      return _sentiment_impl("snowflake.cortex.sentiment", text, session=session)
 
 
  def _sentiment_impl(
      function: str, text: Union[str, snowpark.Column], session: Optional[snowpark.Session] = None
- ) -> Union[str, snowpark.Column]:
-     return call_sql_function(function, session, text)
+ ) -> Union[float, snowpark.Column]:
+
+     output = call_sql_function(function, session, text)
+     if isinstance(output, snowpark.Column):
+         return output
+     return float(output)
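With this change, a scalar input now returns a float while a Column input still returns a Column; a quick sketch assuming an active session (the score value is illustrative):

from snowflake.cortex import Sentiment

score = Sentiment("I love this product!")  # now a float, e.g. 0.93
assert isinstance(score, float)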
snowflake/cortex/_sse_client.py ADDED
@@ -0,0 +1,81 @@
+ from typing import Iterator, cast
+
+ import requests
+
+
+ class Event:
+     def __init__(self, event: str = "message", data: str = "") -> None:
+         self.event = event
+         self.data = data
+
+     def __str__(self) -> str:
+         s = f"{self.event} event"
+         if self.data:
+             s += f", {len(self.data)} bytes"
+         else:
+             s += ", no data"
+         return s
+
+
+ class SSEClient:
+     def __init__(self, response: requests.Response) -> None:
+
+         self.response = response
+
+     def _read(self) -> Iterator[str]:
+
+         lines = b""
+         for chunk in self.response:
+             for line in chunk.splitlines(True):
+                 lines += line
+                 if lines.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
+                     yield cast(str, lines)
+                     lines = b""
+         if lines:
+             yield cast(str, lines)
+
+     def events(self) -> Iterator[Event]:
+         for raw_event in self._read():
+             event = Event()
+             # splitlines() only uses \r and \n
+             for line in raw_event.splitlines():
+
+                 line = cast(bytes, line).decode("utf-8")
+
+                 data = line.split(":", 1)
+                 field = data[0]
+
+                 if len(data) > 1:
+                     # "If value starts with a single U+0020 SPACE character,
+                     # remove it from value. .strip() would remove all white spaces"
+                     if data[1].startswith(" "):
+                         value = data[1][1:]
+                     else:
+                         value = data[1]
+                 else:
+                     value = ""
+
+                 # The data field may come over multiple lines and their values
+                 # are concatenated with each other.
+                 if field == "data":
+                     event.data += value + "\n"
+                 elif field == "event":
+                     event.event = value
+
+             if not event.data:
+                 continue
+
+             # If the data field ends with a newline, remove it.
+             if event.data.endswith("\n"):
+                 event.data = event.data[0:-1]  # Replace trailing newline - rstrip would remove multiple.
+
+             # Empty event names default to 'message'
+             event.event = event.event or "message"
+
+             if event.event != "message":  # ignore anything but "message" or default event
+                 continue
+
+             yield event
+
+     def close(self) -> None:
+         self.response.close()
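A rough sketch of how this client is driven; the endpoint URL is a placeholder (in the package the response actually comes from call_rest_function in _util.py):

import requests

from snowflake.cortex._sse_client import SSEClient

resp = requests.post("https://example.invalid/api/v2/cortex/inference/complete", stream=True)  # placeholder URL
client = SSEClient(resp)
for event in client.events():  # only "message" events are yielded
    print(event.data)          # accumulated data field of one server-sent event
client.close()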
snowflake/cortex/_util.py CHANGED
@@ -1,15 +1,34 @@
- from typing import Optional, Union, cast
+ import json
+ from typing import Iterator, Optional, Union, cast
+ from urllib.parse import urljoin, urlparse
+
+ import requests
 
  from snowflake import snowpark
+ from snowflake.cortex._sse_client import SSEClient
  from snowflake.snowpark import context, functions
 
  CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
 
 
+ class SSEParseException(Exception):
+     """This exception is raised when an invalid server sent event is received from the server."""
+
+     pass
+
+
+ class SnowflakeAuthenticationException(Exception):
+     """This exception is raised when the session object does not have session.connection.rest.token attribute."""
+
+     pass
+
+
  # Calls a sql function, handling both immediate (e.g. python types) and batch
  # (e.g. snowpark column and literal type modes).
  def call_sql_function(
-     function: str, session: Optional[snowpark.Session], *args: Union[str, snowpark.Column]
+     function: str,
+     session: Optional[snowpark.Session],
+     *args: Union[str, snowpark.Column],
  ) -> Union[str, snowpark.Column]:
      handle_as_column = False
      for arg in args:
@@ -17,21 +36,29 @@ def call_sql_function(
              handle_as_column = True
 
      if handle_as_column:
-         return cast(Union[str, snowpark.Column], call_sql_function_column(function, *args))
-     return cast(Union[str, snowpark.Column], call_sql_function_immediate(function, session, *args))
+         return cast(Union[str, snowpark.Column], _call_sql_function_column(function, *args))
+     return cast(
+         Union[str, snowpark.Column],
+         _call_sql_function_immediate(function, session, *args),
+     )
 
 
- def call_sql_function_column(function: str, *args: Union[str, snowpark.Column]) -> snowpark.Column:
+ def _call_sql_function_column(function: str, *args: Union[str, snowpark.Column]) -> snowpark.Column:
      return cast(snowpark.Column, functions.builtin(function)(*args))
 
 
- def call_sql_function_immediate(
-     function: str, session: Optional[snowpark.Session], *args: Union[str, snowpark.Column]
+ def _call_sql_function_immediate(
+     function: str,
+     session: Optional[snowpark.Session],
+     *args: Union[str, snowpark.Column],
  ) -> str:
      if session is None:
          session = context.get_active_session()
      if session is None:
-         raise Exception("No session available in the current context nor specified as an argument.")
+         raise SnowflakeAuthenticationException(
+             """Session required. Provide the session through a session=... argument or ensure an active session is
+             available in your environment."""
+         )
 
      lit_args = []
      for arg in args:
@@ -40,3 +67,73 @@ def call_sql_function_immediate(
      empty_df = session.create_dataframe([snowpark.Row()])
      df = empty_df.select(functions.builtin(function)(*lit_args))
      return cast(str, df.collect()[0][0])
+
+
+ def call_rest_function(
+     function: str,
+     model: Union[str, snowpark.Column],
+     prompt: Union[str, snowpark.Column],
+     session: Optional[snowpark.Session] = None,
+     stream: bool = False,
+ ) -> requests.Response:
+     if session is None:
+         session = context.get_active_session()
+     if session is None:
+         raise SnowflakeAuthenticationException(
+             """Session required. Provide the session through a session=... argument or ensure an active session is
+             available in your environment."""
+         )
+
+     if not hasattr(session.connection.rest, "token"):
+         raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")
+
+     if session.connection.rest.token is None or session.connection.rest.token == "":  # type: ignore[union-attr]
+         raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")
+
+     url = urljoin(session.connection.host, f"api/v2/cortex/inference/{function}")
+     if urlparse(url).scheme == "":
+         url = "https://" + url
+     headers = {
+         "Content-Type": "application/json",
+         "Authorization": f'Snowflake Token="{session.connection.rest.token}"',  # type: ignore[union-attr]
+         "Accept": "application/json, text/event-stream",
+     }
+
+     data = {
+         "model": model,
+         "messages": [{"content": prompt}],
+         "stream": stream,
+     }
+
+     response = requests.post(
+         url,
+         json=data,
+         headers=headers,
+         stream=stream,
+     )
+     response.raise_for_status()
+     return response
+
+
+ def process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
+     if not stream:
+         try:
+             message = response.json()["choices"][0]["message"]
+             output = str(message.get("content", ""))
+             return output
+         except (KeyError, IndexError) as e:
+             raise SSEParseException("Failed to parse streamed response.") from e
+     else:
+         return _return_gen(response)
+
+
+ def _return_gen(response: requests.Response) -> Iterator[str]:
+     client = SSEClient(response)
+     for event in client.events():
+         response_loaded = json.loads(event.data)
+         try:
+             delta = response_loaded["choices"][0]["delta"]
+             output = str(delta.get("content", ""))
+             yield output
+         except (KeyError, IndexError) as e:
+             raise SSEParseException("Failed to parse streamed response.") from e
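Putting the new helpers together, a hedged sketch of the REST round trip (assumes an active Snowpark session; the model name and prompt are illustrative):

from snowflake.cortex._util import call_rest_function, process_rest_response

response = call_rest_function("complete", "snowflake-arctic", "Hi there", stream=True)
for delta in process_rest_response(response, stream=True):  # iterator of text deltas
    print(delta, end="")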
snowflake/ml/_internal/lineage/lineage_utils.py CHANGED
@@ -1,21 +1,11 @@
  import copy
  import functools
- from typing import Any, Callable, List
+ from typing import Any, Callable, List, Optional
 
  from snowflake import snowpark
  from snowflake.ml._internal.lineage import data_source
 
- DATA_SOURCES_ATTR = "_data_sources"
-
-
- def _get_datasources(*args: Any) -> List[data_source.DataSource]:
-     """Helper method for extracting data sources attribute from DataFrames in an argument list"""
-     result = []
-     for arg in args:
-         srcs = getattr(arg, DATA_SOURCES_ATTR, None)
-         if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
-             result += srcs
-     return result
+ _DATA_SOURCES_ATTR = "_data_sources"
 
 
  def _wrap_func(
@@ -32,6 +22,37 @@ def _wrap_func(
      return wrapped
 
 
+ def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
+     @functools.wraps(fn)
+     def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
+         df = fn(*args, **kwargs)
+         data_sources = get_data_sources(*args, *kwargs.values())
+         if data_sources:
+             patch_dataframe(df, data_sources, inplace=True)
+         return df
+
+     return wrapped
+
+
+ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
+     """Helper method for extracting data sources attribute from DataFrames in an argument list"""
+     result: Optional[List[data_source.DataSource]] = None
+     for arg in args:
+         srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
+         if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
+             if result is None:
+                 result = []
+             result += srcs
+     return result
+
+
+ def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
+     """Helper method for attaching data sources to an object"""
+     if data_sources:
+         assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
+         setattr(obj, _DATA_SOURCES_ATTR, data_sources)
+
+
  def patch_dataframe(
      df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
  ) -> snowpark.DataFrame:
@@ -62,7 +83,7 @@ def patch_dataframe(
      ]
      if not inplace:
          df = copy.copy(df)
-     setattr(df, DATA_SOURCES_ATTR, data_sources)
+     set_data_sources(df, data_sources)
      for func in funcs:
          fn = getattr(df, func, None)
          if fn is not None:
@@ -70,18 +91,6 @@ def patch_dataframe(
      return df
 
 
- def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
-     @functools.wraps(fn)
-     def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
-         df = fn(*args, **kwargs)
-         data_sources = _get_datasources(*args) + _get_datasources(*kwargs.values())
-         if data_sources:
-             patch_dataframe(df, data_sources, inplace=True)
-         return df
-
-     return wrapped
-
-
  # Class-level monkey-patches
  for klass, func_list in {
      snowpark.DataFrame: [
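A sketch of the now-public helpers, assuming df_a, df_b, and result_df are Snowpark DataFrames obtained from an active session (names are illustrative):

from snowflake.ml._internal.lineage import lineage_utils

sources = lineage_utils.get_data_sources(df_a, df_b)  # None when no argument carries lineage
lineage_utils.set_data_sources(result_df, sources)    # silently skips falsy source lists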
snowflake/ml/_internal/utils/temp_file_utils.py CHANGED
@@ -8,14 +8,17 @@ from absl.logging import logging
  logger = logging.getLogger(__name__)
 
 
- def get_temp_file_path() -> str:
+ def get_temp_file_path(prefix: str = "") -> str:
      """Returns a new random temp file path.
 
+     Args:
+         prefix: A prefix to the temp file path, this can help add stored file information. Defaults to None.
+
      Returns:
          A new temp file path.
      """
      # TODO(snandamuri): Use in-memory filesystem for temp files.
-     local_file = tempfile.NamedTemporaryFile(delete=True)
+     local_file = tempfile.NamedTemporaryFile(prefix=prefix, delete=True)
      local_file_name = local_file.name
      local_file.close()
      return local_file_name
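Example of the new prefix argument (the resulting path is illustrative):

from snowflake.ml._internal.utils.temp_file_utils import get_temp_file_path

path = get_temp_file_path(prefix="model_artifact_")  # e.g. /tmp/model_artifact_x1y2z3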
snowflake/ml/dataset/dataset.py CHANGED
@@ -65,6 +65,20 @@ class DatasetVersion:
          comment: Optional[str] = self._get_property("comment")
          return comment
 
+     @property
+     def label_cols(self) -> List[str]:
+         metadata = self._get_metadata()
+         if metadata is None or metadata.label_cols is None:
+             return []
+         return metadata.label_cols
+
+     @property
+     def exclude_cols(self) -> List[str]:
+         metadata = self._get_metadata()
+         if metadata is None or metadata.exclude_cols is None:
+             return []
+         return metadata.exclude_cols
+
      def _get_property(self, property_name: str, default: Any = None) -> Any:
          if self._properties is None:
              sql_result = (
@@ -91,17 +105,6 @@ class DatasetVersion:
              warnings.warn(f"Metadata parsing failed with error: {e}", UserWarning, stacklevel=2)
          return self._metadata
 
-     def _get_exclude_cols(self) -> List[str]:
-         metadata = self._get_metadata()
-         if metadata is None:
-             return []
-         cols = []
-         if metadata.exclude_cols:
-             cols.extend(metadata.exclude_cols)
-         if metadata.label_cols:
-             cols.extend(metadata.label_cols)
-         return cols
-
      def url(self) -> str:
          """Returns the URL of the DatasetVersion contents in Snowflake.
 
@@ -168,7 +171,7 @@ class Dataset:
                      fully_qualified_name=self._fully_qualified_name,
                      version=v.name,
                      url=v.url(),
-                     exclude_cols=v._get_exclude_cols(),
+                     exclude_cols=(v.label_cols + v.exclude_cols),
                  )
              ],
          )
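The private _get_exclude_cols helper is replaced by two public properties; a sketch assuming v is a DatasetVersion (column names are illustrative):

v.label_cols    # e.g. ["LABEL"]; [] when metadata is missing
v.exclude_cols  # e.g. ["ID"]; [] when metadata is missing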
snowflake/ml/dataset/dataset_factory.py CHANGED
@@ -16,8 +16,7 @@ def create_from_dataframe(
      **version_kwargs: Any,
  ) -> dataset.Dataset:
      """
-     Create a new versioned Dataset from a DataFrame and returns
-     a DatasetReader for the newly created Dataset version.
+     Create a new versioned Dataset from a DataFrame.
 
      Args:
          session: The Snowpark Session instance to use.
@@ -39,7 +38,7 @@ def create_from_dataframe(
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
  def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.Dataset:
      """
-     Load a versioned Dataset into a DatasetReader.
+     Load a versioned Dataset.
 
      Args:
          session: The Snowpark Session instance to use.
@@ -47,7 +46,7 @@ def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.
          version: The dataset version name.
 
      Returns:
-         A DatasetReader object.
+         A Dataset object.
      """
      ds: dataset.Dataset = dataset.Dataset.load(session, name).select_version(version)
      return ds
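Behavior is unchanged here; only the docstrings were corrected to match the actual return type. A sketch assuming an active session (dataset name and version are illustrative):

from snowflake.ml.dataset import dataset_factory

ds = dataset_factory.load_dataset(session, "MY_DATASET", "v1")  # returns a Dataset, not a DatasetReader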
snowflake/ml/feature_store/access_manager.py CHANGED
@@ -42,6 +42,8 @@ class _SessionInfo:
  # Lists of permissions as tuples of (OBJECT_TYPE, [PRIVILEGES, ...])
  _PRE_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
      _FeatureStoreRole.PRODUCER: [
+         _Privilege("DATABASE", "{database}", ["USAGE"]),
+         _Privilege("SCHEMA", "{database}.{schema}", ["USAGE"]),
          _Privilege(
              "SCHEMA",
              "{database}.{schema}",
@@ -69,8 +71,7 @@ _PRE_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
          _Privilege("DYNAMIC TABLE", _ALL_OBJECTS, ["SELECT", "MONITOR"], "SCHEMA {database}.{schema}"),
          _Privilege("VIEW", _ALL_OBJECTS, ["SELECT", "REFERENCES"], "SCHEMA {database}.{schema}"),
          _Privilege("TABLE", _ALL_OBJECTS, ["SELECT", "REFERENCES"], "SCHEMA {database}.{schema}"),
-         # FIXME(dhung): FUTURE DATASETS not supported until 8.19
-         # _Privilege("DATASET", _ALL_OBJECTS, ["USAGE"], "SCHEMA {database}.{schema}"),
+         _Privilege("DATASET", _ALL_OBJECTS, ["USAGE"], "SCHEMA {database}.{schema}"),
          # User should decide whether they want to grant warehouse usage to CONSUMER
          # _Privilege("WAREHOUSE", "{warehouse}", ["USAGE"]),
      ],
@@ -128,8 +129,7 @@ def _grant_privileges(
  def _configure_pre_init_privileges(
      session: Session,
      session_info: _SessionInfo,
-     producer_role: str = "SNOWML_FEATURE_STORE_PRODUCER_RL",
-     consumer_role: str = "SNOWML_FEATURE_STORE_CONSUMER_RL",
+     roles_to_create: Dict[_FeatureStoreRole, str],
  ) -> None:
      """
      Configure Feature Store role privileges. Must be run with ACCOUNTADMIN
@@ -141,8 +141,7 @@ def _configure_pre_init_privileges(
      Args:
          session: Snowpark Session to interact with Snowflake backend.
          session_info: Session info like database and schema for the FeatureStore instance.
-         producer_role: Name of producer role to be configured.
-         consumer_role: Name of consumer role to be configured.
+         roles_to_create: Producer and optional consumer roles to create.
      """
 
      # Create schema if not already exists
@@ -159,29 +158,30 @@ def _configure_pre_init_privileges(
 
      # Pass schema ownership from admin to PRODUCER
      if schema_created:
+         # TODO: we are missing a test case for this code path
          session.sql(
-             f"GRANT OWNERSHIP ON SCHEMA {session_info.database}.{session_info.schema} TO ROLE {producer_role}"
+             f"GRANT OWNERSHIP ON SCHEMA {session_info.database}.{session_info.schema} "
+             f"TO ROLE {roles_to_create[_FeatureStoreRole.PRODUCER]}"
          ).collect()
 
      # Grant privileges to roles
-     _grant_privileges(session, producer_role, _PRE_INIT_PRIVILEGES[_FeatureStoreRole.PRODUCER], session_info)
-     _grant_privileges(session, consumer_role, _PRE_INIT_PRIVILEGES[_FeatureStoreRole.CONSUMER], session_info)
+     for role_type, role in roles_to_create.items():
+         _grant_privileges(session, role, _PRE_INIT_PRIVILEGES[role_type], session_info)
 
 
  def _configure_post_init_privileges(
      session: Session,
      session_info: _SessionInfo,
-     producer_role: str = "FS_PRODUCER",
-     consumer_role: str = "FS_CONSUMER",
+     roles_to_create: Dict[_FeatureStoreRole, str],
  ) -> None:
-     _grant_privileges(session, producer_role, _POST_INIT_PRIVILEGES[_FeatureStoreRole.PRODUCER], session_info)
-     _grant_privileges(session, consumer_role, _POST_INIT_PRIVILEGES[_FeatureStoreRole.CONSUMER], session_info)
+     for role_type, role in roles_to_create.items():
+         _grant_privileges(session, role, _POST_INIT_PRIVILEGES[role_type], session_info)
 
 
  def _configure_role_hierarchy(
      session: Session,
      producer_role: str,
-     consumer_role: str,
+     consumer_role: Optional[str],
  ) -> None:
      """
      Create Feature Store roles and configure role hierarchy. Must be run with
@@ -195,18 +195,17 @@ def _configure_role_hierarchy(
          producer_role: Name of producer role to be configured.
          consumer_role: Name of consumer role to be configured.
      """
+     # Create the necessary roles and build role hierarchy
      producer_role = SqlIdentifier(producer_role)
-     consumer_role = SqlIdentifier(consumer_role)
-
-     # Create the necessary roles
      session.sql(f"CREATE ROLE IF NOT EXISTS {producer_role}").collect()
-     session.sql(f"CREATE ROLE IF NOT EXISTS {consumer_role}").collect()
-
-     # Build role hierarchy
-     session.sql(f"GRANT ROLE {consumer_role} TO ROLE {producer_role}").collect()
      session.sql(f"GRANT ROLE {producer_role} TO ROLE SYSADMIN").collect()
      session.sql(f"GRANT ROLE {producer_role} TO ROLE {session.get_current_role()}").collect()
 
+     if consumer_role is not None:
+         consumer_role = SqlIdentifier(consumer_role)
+         session.sql(f"CREATE ROLE IF NOT EXISTS {consumer_role}").collect()
+         session.sql(f"GRANT ROLE {consumer_role} TO ROLE {producer_role}").collect()
+
 
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
  def setup_feature_store(
@@ -215,7 +214,7 @@ def setup_feature_store(
      schema: str,
      warehouse: str,
      producer_role: str = "FS_PRODUCER",
-     consumer_role: str = "FS_CONSUMER",
+     consumer_role: Optional[str] = None,
  ) -> FeatureStore:
      """
      Sets up a new Feature Store including role/privilege setup. Must be run with ACCOUNTADMIN
@@ -230,7 +229,7 @@ def setup_feature_store(
          schema: Schema to create the FeatureStore instance.
          warehouse: Default warehouse for Feature Store compute.
          producer_role: Name of producer role to be configured.
-         consumer_role: Name of consumer role to be configured.
+         consumer_role: Name of consumer role to be configured. If not specified, consumer role won't be created.
 
      Returns:
          Feature Store instance.
@@ -249,20 +248,25 @@ def setup_feature_store(
      )
 
      try:
+         roles_to_create = {_FeatureStoreRole.PRODUCER: producer_role}
+         if consumer_role is not None:
+             roles_to_create.update({_FeatureStoreRole.CONSUMER: consumer_role})
          _configure_role_hierarchy(session, producer_role=producer_role, consumer_role=consumer_role)
      except exceptions.SnowparkSQLException:
          # Error can be safely ignored if roles already exist and hierarchy is already built
-         for role in (producer_role, consumer_role):
+         for _, role in roles_to_create.items():
              # Ensure roles already exist
              if session.sql(f"SHOW ROLES LIKE '{role}' STARTS WITH '{role}'").count() == 0:
                  raise
-         # Ensure hierarchy already configured
-         consumer_grants = session.sql(f"SHOW GRANTS ON ROLE {consumer_role}").collect()
-         if not any(r["granted_to"] == "ROLE" and r["grantee_name"] == producer_role for r in consumer_grants):
-             raise
+
+         if consumer_role is not None:
+             # Ensure hierarchy already configured
+             consumer_grants = session.sql(f"SHOW GRANTS ON ROLE {consumer_role}").collect()
+             if not any(r["granted_to"] == "ROLE" and r["grantee_name"] == producer_role for r in consumer_grants):
+                 raise
 
      # Do any pre-FeatureStore.__init__() privilege setup
-     _configure_pre_init_privileges(session, session_info, producer_role, consumer_role)
+     _configure_pre_init_privileges(session, session_info, roles_to_create)
 
      # Use PRODUCER role to create and operate new Feature Store
      current_role = session.get_current_role()
@@ -274,6 +278,6 @@ def setup_feature_store(
      session.use_role(current_role)
 
      # Do any post-FeatureStore.__init__() privilege setup
-     _configure_post_init_privileges(session, session_info, producer_role, consumer_role)
+     _configure_post_init_privileges(session, session_info, roles_to_create)
 
      return fs
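Under the new default, no consumer role is created unless one is named; a hedged sketch (assumes an ACCOUNTADMIN session; database, schema, and warehouse names are placeholders):

from snowflake.ml.feature_store.access_manager import setup_feature_store

fs = setup_feature_store(
    session,
    database="FS_DB",
    schema="FS_SCHEMA",
    warehouse="FS_WH",
    producer_role="FS_PRODUCER",
    # consumer_role omitted: defaults to None, so only the producer role is configured
)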
snowflake/ml/feature_store/feature_store.py CHANGED
@@ -920,7 +920,7 @@ class FeatureStore:
          try:
              if output_type == "table":
                  table_name = f"{name}_{version}"
-                 result_df.write.mode("errorifexists").save_as_table(table_name)  # type: ignore[call-overload]
+                 result_df.write.mode("errorifexists").save_as_table(table_name)
                  ds_df = self._session.table(table_name)
                  return ds_df
              else:
@@ -1761,8 +1761,8 @@ class FeatureStore:
          self._session.sql(
              f"""
                  SELECT * FROM TABLE(
-                     INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL(
-                         TAG_NAME => '{_FEATURE_STORE_OBJECT_TAG}'
+                     {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL(
+                         TAG_NAME => '{self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)}'
                      )
                  ) LIMIT 1;
              """