snowflake-ml-python 1.5.2__py3-none-any.whl → 1.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. snowflake/cortex/_complete.py +26 -5
  2. snowflake/cortex/_sse_client.py +81 -0
  3. snowflake/cortex/_util.py +105 -8
  4. snowflake/ml/_internal/lineage/lineage_utils.py +34 -25
  5. snowflake/ml/dataset/dataset.py +15 -12
  6. snowflake/ml/dataset/dataset_factory.py +3 -4
  7. snowflake/ml/feature_store/feature_store.py +2 -2
  8. snowflake/ml/model/_client/sql/model_version.py +2 -2
  9. snowflake/ml/model/_model_composer/model_composer.py +2 -2
  10. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -1
  11. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  12. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  13. snowflake/ml/model/_signatures/builtins_handler.py +2 -1
  14. snowflake/ml/model/_signatures/core.py +13 -1
  15. snowflake/ml/model/_signatures/pandas_handler.py +2 -0
  16. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  17. snowflake/ml/model/model_signature.py +2 -0
  18. snowflake/ml/model/type_hints.py +1 -0
  19. snowflake/ml/modeling/_internal/estimator_utils.py +58 -1
  20. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +156 -121
  21. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +2 -0
  22. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +38 -18
  23. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +82 -134
  24. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +21 -17
  25. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  26. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  27. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  28. snowflake/ml/modeling/cluster/birch.py +1 -1
  29. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  30. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  31. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  32. snowflake/ml/modeling/cluster/k_means.py +1 -1
  33. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  34. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  35. snowflake/ml/modeling/cluster/optics.py +1 -1
  36. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  37. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  38. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  39. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  40. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  41. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  42. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  43. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  44. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  45. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  46. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  47. snowflake/ml/modeling/covariance/oas.py +1 -1
  48. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  49. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  50. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  51. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  52. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  53. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  54. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  55. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  56. snowflake/ml/modeling/decomposition/pca.py +1 -1
  57. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  58. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  59. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  60. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  61. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  62. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  63. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  64. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  65. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  66. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  67. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  68. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  69. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  70. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  71. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  72. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  73. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  74. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  75. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  76. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  77. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  78. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  79. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  80. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  81. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  82. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  83. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  84. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  85. snowflake/ml/modeling/framework/base.py +3 -8
  86. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  87. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  88. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  89. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  90. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  91. snowflake/ml/modeling/impute/simple_imputer.py +8 -4
  92. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  93. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  94. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  95. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  96. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  97. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  98. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  99. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  100. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  101. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  102. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  103. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  104. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  105. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  106. snowflake/ml/modeling/linear_model/lars.py +1 -1
  107. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  108. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  109. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  110. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  111. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  112. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  113. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  114. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  115. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  116. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  117. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  118. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  119. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  121. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  122. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  123. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  124. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  125. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  126. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  127. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  128. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  129. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  130. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  131. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  132. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  133. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  134. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  135. snowflake/ml/modeling/manifold/isomap.py +1 -1
  136. snowflake/ml/modeling/manifold/mds.py +1 -1
  137. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  138. snowflake/ml/modeling/manifold/tsne.py +1 -1
  139. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  140. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +1 -5
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +1 -5
  143. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  144. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  145. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  146. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  147. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  148. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  149. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  150. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  151. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  152. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  153. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  154. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  155. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  156. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  157. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  158. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  159. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  160. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  161. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  162. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  163. snowflake/ml/modeling/pipeline/pipeline.py +5 -0
  164. snowflake/ml/modeling/preprocessing/binarizer.py +7 -3
  165. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +7 -2
  166. snowflake/ml/modeling/preprocessing/label_encoder.py +8 -7
  167. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +7 -3
  168. snowflake/ml/modeling/preprocessing/min_max_scaler.py +7 -4
  169. snowflake/ml/modeling/preprocessing/normalizer.py +7 -3
  170. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +10 -2
  171. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +8 -5
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  173. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -4
  174. snowflake/ml/modeling/preprocessing/standard_scaler.py +7 -3
  175. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  176. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  177. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  178. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  179. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  180. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  181. snowflake/ml/modeling/svm/svc.py +1 -1
  182. snowflake/ml/modeling/svm/svr.py +1 -1
  183. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  184. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  185. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  186. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  187. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  189. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/METADATA +21 -5
  193. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/RECORD +196 -195
  194. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/LICENSE.txt +0 -0
  195. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/WHEEL +0 -0
  196. {snowflake_ml_python-1.5.2.dist-info → snowflake_ml_python-1.5.3.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,12 @@
1
- from typing import Optional, Union
1
+ from typing import Iterator, Optional, Union
2
2
 
3
3
  from snowflake import snowpark
4
- from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
4
+ from snowflake.cortex._util import (
5
+ CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
6
+ call_rest_function,
7
+ call_sql_function,
8
+ process_rest_response,
9
+ )
5
10
  from snowflake.ml._internal import telemetry
6
11
 
7
12
 
@@ -10,19 +15,35 @@ from snowflake.ml._internal import telemetry
10
15
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
11
16
  )
def Complete(
    model: Union[str, snowpark.Column],
    prompt: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
    use_rest_api_experimental: bool = False,
    stream: bool = False,
) -> Union[str, Iterator[str], snowpark.Column]:
    """Complete calls into the LLM inference service to perform completion.

    Args:
        model: A Column of strings representing model types.
        prompt: A Column of prompts to send to the LLM.
        session: The snowpark session to use. Will be inferred by context if not specified.
        use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
            experimental and can be removed at any time.
        stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
            output as it is received. Each update is a string containing the new text content since the previous
            update. The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.

    Raises:
        ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.

    Returns:
        Through the SQL path, the result of `_complete_impl` (a column of string responses for column inputs).
        Through the REST path, the completion text as a string, or, when `stream` is True, an iterator that
        yields the text deltas as they arrive.
    """
    if stream and not use_rest_api_experimental:
        raise ValueError("If stream is set to True use_rest_api_experimental must also be set to True")
    if use_rest_api_experimental:
        response = call_rest_function("complete", model, prompt, session=session, stream=stream)
        # Forward `stream`: a streamed (server-sent events) body must be consumed incrementally,
        # not decoded as a single JSON document.
        return process_rest_response(response, stream=stream)
    return _complete_impl("snowflake.cortex.complete", model, prompt, session=session)
27
48
 
28
49
 
@@ -0,0 +1,81 @@
1
+ from typing import Iterator, cast
2
+
3
+ import requests
4
+
5
+
6
class Event:
    """One parsed server-sent event: its event type and its (possibly multi-line) data payload."""

    def __init__(self, event: str = "message", data: str = "") -> None:
        self.event = event
        self.data = data

    def __str__(self) -> str:
        # Note: the reported size is len() of the decoded text (characters).
        size_desc = f"{len(self.data)} bytes" if self.data else "no data"
        return f"{self.event} event, {size_desc}"
18
+
19
+
20
class SSEClient:
    """Minimal server-sent events (SSE) reader over an HTTP response body.

    Only events whose type is ``message`` (the default) and which carry data are
    surfaced. ``event`` and ``data`` are the only fields interpreted; all others are ignored.
    """

    def __init__(self, response: requests.Response) -> None:
        self.response = response

    def _read(self) -> Iterator[bytes]:
        """Yield raw event blocks as bytes, each ending with a blank-line terminator.

        Incoming chunks are re-split on line boundaries and accumulated until the
        buffer ends with ``\\r\\r``, ``\\n\\n`` or ``\\r\\n\\r\\n``. Any trailing,
        unterminated data is yielded once the response is exhausted.
        """
        lines = b""
        for chunk in self.response:
            for line in chunk.splitlines(True):
                lines += line
                if lines.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    yield lines
                    lines = b""
        if lines:
            yield lines

    def events(self) -> Iterator[Event]:
        """Parse the stream into ``message`` events that carry data.

        ``data`` lines are concatenated with newlines (the final trailing newline is
        removed) and ``event`` sets the event type.
        """
        for raw_event in self._read():
            event = Event()
            # bytes.splitlines() only splits on \r, \n and \r\n.
            for raw_line in raw_event.splitlines():
                line = raw_line.decode("utf-8")

                data = line.split(":", 1)
                field = data[0]

                if len(data) > 1:
                    # "If value starts with a single U+0020 SPACE character,
                    # remove it from value. .strip() would remove all white spaces"
                    value = data[1][1:] if data[1].startswith(" ") else data[1]
                else:
                    value = ""

                # The data field may come over multiple lines and their values
                # are concatenated with each other.
                if field == "data":
                    event.data += value + "\n"
                elif field == "event":
                    event.event = value

            if not event.data:
                continue

            # Remove exactly one trailing newline - rstrip would remove multiple.
            if event.data.endswith("\n"):
                event.data = event.data[0:-1]

            # Empty event names default to 'message'.
            event.event = event.event or "message"

            if event.event != "message":  # ignore anything but "message" or default event
                continue

            yield event

    def close(self) -> None:
        """Close the underlying HTTP response."""
        self.response.close()
snowflake/cortex/_util.py CHANGED
@@ -1,15 +1,34 @@
1
- from typing import Optional, Union, cast
1
+ import json
2
+ from typing import Iterator, Optional, Union, cast
3
+ from urllib.parse import urljoin, urlparse
4
+
5
+ import requests
2
6
 
3
7
  from snowflake import snowpark
8
+ from snowflake.cortex._sse_client import SSEClient
4
9
  from snowflake.snowpark import context, functions
5
10
 
6
11
  CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
7
12
 
8
13
 
14
class SSEParseException(Exception):
    """Raised when a payload received from the inference server (such as a server-sent event) cannot be parsed."""
18
+
19
+
20
class SnowflakeAuthenticationException(Exception):
    """Raised when no usable session is available, e.g. no session at all or no REST token on its connection."""
24
+
25
+
9
26
  # Calls a sql function, handling both immediate (e.g. python types) and batch
10
27
  # (e.g. snowpark column and literal type modes).
11
28
  def call_sql_function(
12
- function: str, session: Optional[snowpark.Session], *args: Union[str, snowpark.Column]
29
+ function: str,
30
+ session: Optional[snowpark.Session],
31
+ *args: Union[str, snowpark.Column],
13
32
  ) -> Union[str, snowpark.Column]:
14
33
  handle_as_column = False
15
34
  for arg in args:
@@ -17,21 +36,29 @@ def call_sql_function(
17
36
  handle_as_column = True
18
37
 
19
38
  if handle_as_column:
20
- return cast(Union[str, snowpark.Column], call_sql_function_column(function, *args))
21
- return cast(Union[str, snowpark.Column], call_sql_function_immediate(function, session, *args))
39
+ return cast(Union[str, snowpark.Column], _call_sql_function_column(function, *args))
40
+ return cast(
41
+ Union[str, snowpark.Column],
42
+ _call_sql_function_immediate(function, session, *args),
43
+ )
22
44
 
23
45
 
24
- def call_sql_function_column(function: str, *args: Union[str, snowpark.Column]) -> snowpark.Column:
46
def _call_sql_function_column(function: str, *args: Union[str, snowpark.Column]) -> snowpark.Column:
    """Build a Snowpark column expression that applies the builtin SQL `function` to `args`."""
    sql_function = functions.builtin(function)
    return cast(snowpark.Column, sql_function(*args))
26
48
 
27
49
 
28
- def call_sql_function_immediate(
29
- function: str, session: Optional[snowpark.Session], *args: Union[str, snowpark.Column]
50
+ def _call_sql_function_immediate(
51
+ function: str,
52
+ session: Optional[snowpark.Session],
53
+ *args: Union[str, snowpark.Column],
30
54
  ) -> str:
31
55
  if session is None:
32
56
  session = context.get_active_session()
33
57
  if session is None:
34
- raise Exception("No session available in the current context nor specified as an argument.")
58
+ raise SnowflakeAuthenticationException(
59
+ """Session required. Provide the session through a session=... argument or ensure an active session is
60
+ available in your environment."""
61
+ )
35
62
 
36
63
  lit_args = []
37
64
  for arg in args:
@@ -40,3 +67,73 @@ def call_sql_function_immediate(
40
67
  empty_df = session.create_dataframe([snowpark.Row()])
41
68
  df = empty_df.select(functions.builtin(function)(*lit_args))
42
69
  return cast(str, df.collect()[0][0])
70
+
71
+
72
def call_rest_function(
    function: str,
    model: Union[str, snowpark.Column],
    prompt: Union[str, snowpark.Column],
    session: Optional[snowpark.Session] = None,
    stream: bool = False,
) -> requests.Response:
    """POST a Cortex inference request to the Snowflake REST API.

    Args:
        function: Inference endpoint name, appended to ``api/v2/cortex/inference/``.
        model: Model name placed in the request body.
        prompt: Prompt placed in the request body as the content of a single message.
        session: Session supplying the host and REST token; the active session is used when omitted.
        stream: Requests a streamed response and is passed through to ``requests.post(stream=...)``
            so the body is not read eagerly.

    Returns:
        The HTTP response, already checked with ``raise_for_status()``.

    Raises:
        SnowflakeAuthenticationException: If no session is available, or its REST token is missing or empty.
        requests.HTTPError: If the server answers with an error status.
    """
    if session is None:
        session = context.get_active_session()
    if session is None:
        raise SnowflakeAuthenticationException(
            """Session required. Provide the session through a session=... argument or ensure an active session is
available in your environment."""
        )

    if not hasattr(session.connection.rest, "token"):
        raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")

    if session.connection.rest.token is None or session.connection.rest.token == "":  # type: ignore[union-attr]
        raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")

    host = session.connection.host
    # The host may come without a scheme. urljoin() treats a scheme-less base as a relative *path*
    # and replaces it ("acct.host.com" + "api/..." -> "api/..."), losing the host entirely, so the
    # scheme has to be added *before* joining. A trailing "/" keeps any base path intact.
    if not urlparse(host).netloc:
        host = "https://" + host
    url = urljoin(host.rstrip("/") + "/", f"api/v2/cortex/inference/{function}")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f'Snowflake Token="{session.connection.rest.token}"',  # type: ignore[union-attr]
        "Accept": "application/json, text/event-stream",
    }

    data = {
        "model": model,
        "messages": [{"content": prompt}],
        "stream": stream,
    }

    response = requests.post(
        url,
        json=data,
        headers=headers,
        stream=stream,
    )
    response.raise_for_status()
    return response
116
+
117
+
118
def process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
    """Extract completion text from a Cortex REST response.

    Args:
        response: Response returned by `call_rest_function`.
        stream: Whether the response body is a server-sent event stream.

    Returns:
        The full message content when `stream` is False, otherwise a generator yielding
        the text delta carried by each streamed event.

    Raises:
        SSEParseException: If a non-streamed body is not JSON or lacks ``choices[0].message``.
    """
    if stream:
        return _return_gen(response)
    try:
        message = response.json()["choices"][0]["message"]
        content = message.get("content", "")
    except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
        # ValueError covers JSON decoding failures; the others cover an unexpected payload shape.
        raise SSEParseException("Failed to parse response.") from e
    return str(content)
128
+
129
+
130
def _return_gen(response: requests.Response) -> Iterator[str]:
    """Yield the text delta of each streamed completion event.

    The HTTP response is closed when the generator is exhausted, raises, or is closed
    early by the consumer, so the connection is not leaked.

    Raises:
        SSEParseException: If an event's data is not JSON or lacks ``choices[0].delta``.
    """
    client = SSEClient(response)
    try:
        for event in client.events():
            try:
                delta = json.loads(event.data)["choices"][0]["delta"]
                output = str(delta.get("content", ""))
            except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
                # ValueError covers json.JSONDecodeError for malformed event data.
                raise SSEParseException("Failed to parse streamed response.") from e
            yield output
    finally:
        client.close()
@@ -1,21 +1,11 @@
1
1
  import copy
2
2
  import functools
3
- from typing import Any, Callable, List
3
+ from typing import Any, Callable, List, Optional
4
4
 
5
5
  from snowflake import snowpark
6
6
  from snowflake.ml._internal.lineage import data_source
7
7
 
8
- DATA_SOURCES_ATTR = "_data_sources"
9
-
10
-
11
- def _get_datasources(*args: Any) -> List[data_source.DataSource]:
12
- """Helper method for extracting data sources attribute from DataFrames in an argument list"""
13
- result = []
14
- for arg in args:
15
- srcs = getattr(arg, DATA_SOURCES_ATTR, None)
16
- if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
17
- result += srcs
18
- return result
8
+ _DATA_SOURCES_ATTR = "_data_sources"
19
9
 
20
10
 
21
11
  def _wrap_func(
@@ -32,6 +22,37 @@ def _wrap_func(
32
22
  return wrapped
33
23
 
34
24
 
25
def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
    """Wrap a DataFrame-returning callable so lineage flows to its result.

    Data sources found on any positional or keyword argument are attached, in place,
    to the DataFrame the wrapped callable returns.
    """

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
        result_df = fn(*args, **kwargs)
        sources = get_data_sources(*args, *kwargs.values())
        if not sources:
            return result_df
        patch_dataframe(result_df, sources, inplace=True)
        return result_df

    return wrapper
35
+
36
+
37
def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
    """Collect lineage data sources attached to any of `args`.

    Only attributes that are lists made up entirely of `DataSource` objects are
    considered. Returns None when no argument carries such a list; otherwise returns
    the concatenation of all the lists found (possibly empty).
    """
    collected: Optional[List[data_source.DataSource]] = None
    for candidate in args:
        sources = getattr(candidate, _DATA_SOURCES_ATTR, None)
        if not isinstance(sources, list):
            continue
        if not all(isinstance(src, data_source.DataSource) for src in sources):
            continue
        if collected is None:
            collected = []
        collected.extend(sources)
    return collected
47
+
48
+
49
def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
    """Attach lineage data sources to an object.

    Does nothing when `data_sources` is None or empty.

    Raises:
        TypeError: If any element of `data_sources` is not a `DataSource`.
    """
    if not data_sources:
        return
    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if not all(isinstance(ds, data_source.DataSource) for ds in data_sources):
        raise TypeError("data_sources must contain only DataSource instances")
    setattr(obj, _DATA_SOURCES_ATTR, data_sources)
54
+
55
+
35
56
  def patch_dataframe(
36
57
  df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
37
58
  ) -> snowpark.DataFrame:
@@ -62,7 +83,7 @@ def patch_dataframe(
62
83
  ]
63
84
  if not inplace:
64
85
  df = copy.copy(df)
65
- setattr(df, DATA_SOURCES_ATTR, data_sources)
86
+ set_data_sources(df, data_sources)
66
87
  for func in funcs:
67
88
  fn = getattr(df, func, None)
68
89
  if fn is not None:
@@ -70,18 +91,6 @@ def patch_dataframe(
70
91
  return df
71
92
 
72
93
 
73
- def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., snowpark.DataFrame]:
74
- @functools.wraps(fn)
75
- def wrapped(*args: Any, **kwargs: Any) -> snowpark.DataFrame:
76
- df = fn(*args, **kwargs)
77
- data_sources = _get_datasources(*args) + _get_datasources(*kwargs.values())
78
- if data_sources:
79
- patch_dataframe(df, data_sources, inplace=True)
80
- return df
81
-
82
- return wrapped
83
-
84
-
85
94
  # Class-level monkey-patches
86
95
  for klass, func_list in {
87
96
  snowpark.DataFrame: [
@@ -65,6 +65,20 @@ class DatasetVersion:
65
65
  comment: Optional[str] = self._get_property("comment")
66
66
  return comment
67
67
 
68
@property
def label_cols(self) -> List[str]:
    """Label column names recorded in this version's metadata, or an empty list if none are recorded."""
    metadata = self._get_metadata()
    cols = None if metadata is None else metadata.label_cols
    return [] if cols is None else cols
74
+
75
@property
def exclude_cols(self) -> List[str]:
    """Excluded column names recorded in this version's metadata, or an empty list if none are recorded."""
    metadata = self._get_metadata()
    cols = None if metadata is None else metadata.exclude_cols
    return [] if cols is None else cols
81
+
68
82
  def _get_property(self, property_name: str, default: Any = None) -> Any:
69
83
  if self._properties is None:
70
84
  sql_result = (
@@ -91,17 +105,6 @@ class DatasetVersion:
91
105
  warnings.warn(f"Metadata parsing failed with error: {e}", UserWarning, stacklevel=2)
92
106
  return self._metadata
93
107
 
94
- def _get_exclude_cols(self) -> List[str]:
95
- metadata = self._get_metadata()
96
- if metadata is None:
97
- return []
98
- cols = []
99
- if metadata.exclude_cols:
100
- cols.extend(metadata.exclude_cols)
101
- if metadata.label_cols:
102
- cols.extend(metadata.label_cols)
103
- return cols
104
-
105
108
  def url(self) -> str:
106
109
  """Returns the URL of the DatasetVersion contents in Snowflake.
107
110
 
@@ -168,7 +171,7 @@ class Dataset:
168
171
  fully_qualified_name=self._fully_qualified_name,
169
172
  version=v.name,
170
173
  url=v.url(),
171
- exclude_cols=v._get_exclude_cols(),
174
+ exclude_cols=(v.label_cols + v.exclude_cols),
172
175
  )
173
176
  ],
174
177
  )
@@ -16,8 +16,7 @@ def create_from_dataframe(
16
16
  **version_kwargs: Any,
17
17
  ) -> dataset.Dataset:
18
18
  """
19
- Create a new versioned Dataset from a DataFrame and returns
20
- a DatasetReader for the newly created Dataset version.
19
+ Create a new versioned Dataset from a DataFrame.
21
20
 
22
21
  Args:
23
22
  session: The Snowpark Session instance to use.
@@ -39,7 +38,7 @@ def create_from_dataframe(
39
38
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
40
39
  def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.Dataset:
41
40
  """
42
- Load a versioned Dataset into a DatasetReader.
41
+ Load a versioned Dataset.
43
42
 
44
43
  Args:
45
44
  session: The Snowpark Session instance to use.
@@ -47,7 +46,7 @@ def load_dataset(session: snowpark.Session, name: str, version: str) -> dataset.
47
46
  version: The dataset version name.
48
47
 
49
48
  Returns:
50
- A DatasetReader object.
49
+ A Dataset object.
51
50
  """
52
51
  ds: dataset.Dataset = dataset.Dataset.load(session, name).select_version(version)
53
52
  return ds
@@ -920,7 +920,7 @@ class FeatureStore:
920
920
  try:
921
921
  if output_type == "table":
922
922
  table_name = f"{name}_{version}"
923
- result_df.write.mode("errorifexists").save_as_table(table_name) # type: ignore[call-overload]
923
+ result_df.write.mode("errorifexists").save_as_table(table_name)
924
924
  ds_df = self._session.table(table_name)
925
925
  return ds_df
926
926
  else:
@@ -1762,7 +1762,7 @@ class FeatureStore:
1762
1762
  f"""
1763
1763
  SELECT * FROM TABLE(
1764
1764
  {self._config.database}.INFORMATION_SCHEMA.TAG_REFERENCES_INTERNAL(
1765
- TAG_NAME => '{_FEATURE_STORE_OBJECT_TAG}'
1765
+ TAG_NAME => '{self._get_fully_qualified_name(_FEATURE_STORE_OBJECT_TAG)}'
1766
1766
  )
1767
1767
  ) LIMIT 1;
1768
1768
  """
@@ -272,7 +272,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
272
272
  actual_schema_name.identifier(),
273
273
  tmp_table_name,
274
274
  )
275
- input_df.write.save_as_table( # type: ignore[call-overload]
275
+ input_df.write.save_as_table(
276
276
  table_name=INTERMEDIATE_TABLE_NAME,
277
277
  mode="errorifexists",
278
278
  table_type="temporary",
@@ -348,7 +348,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
348
348
  actual_schema_name.identifier(),
349
349
  tmp_table_name,
350
350
  )
351
- input_df.write.save_as_table( # type: ignore[call-overload]
351
+ input_df.write.save_as_table(
352
352
  table_name=INTERMEDIATE_TABLE_NAME,
353
353
  mode="errorifexists",
354
354
  table_type="temporary",
@@ -182,9 +182,9 @@ class ModelComposer:
def _get_data_sources(
    self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
) -> Optional[List[data_source.DataSource]]:
    """Find lineage data sources, preferring those on the model over those on the sample input.

    Returns None unless the result is a list made up entirely of `DataSource` objects.
    """
    sources = lineage_utils.get_data_sources(model)
    if not sources and sample_input_data is not None:
        sources = lineage_utils.get_data_sources(sample_input_data)
    is_valid = isinstance(sources, list) and all(isinstance(item, data_source.DataSource) for item in sources)
    return sources if is_valid else None
@@ -74,4 +74,6 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
74
74
  class {function_name}:
75
75
  @vectorized(input=pd.DataFrame)
76
76
  def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
77
- return runner(df)
77
# Rename the incoming columns positionally to `input_cols` (defined elsewhere in the template;
# presumably the model's input feature names in signature order - confirm), cast them with
# `dtype_map` (built from each signature feature's `as_dtype()`), and pass the columns to
# `runner` in `input_cols` order. Note: `df.columns = ...` mutates the incoming frame.
+ df.columns = input_cols
78
+ input_df = df.astype(dtype=dtype_map)
79
+ return runner(input_df[input_cols])
@@ -6,6 +6,6 @@ REQUIREMENTS = [
6
6
  "packaging>=20.9,<24",
7
7
  "pandas>=1.0.0,<3",
8
8
  "pyyaml>=6.0,<7",
9
- "snowflake-snowpark-python>=1.11.1,<2,!=1.12.0",
9
+ "snowflake-snowpark-python>=1.15.0,<2",
10
10
  "typing-extensions>=4.1.0,<5"
11
11
  ]
@@ -5,6 +5,6 @@ REQUIREMENTS = [
5
5
  "packaging>=20.9,<24",
6
6
  "pandas>=1.0.0,<3",
7
7
  "pyyaml>=6.0,<7",
8
- "snowflake-snowpark-python>=1.11.1,<2,!=1.12.0",
8
+ "snowflake-snowpark-python>=1.15.0,<2",
9
9
  "typing-extensions>=4.1.0,<5"
10
10
  ]
@@ -1,3 +1,4 @@
1
+ import datetime
1
2
  from collections import abc
2
3
  from typing import Literal, Sequence
3
4
 
@@ -24,7 +25,7 @@ class ListOfBuiltinHandler(base_handler.BaseDataHandler[model_types._SupportedBu
24
25
  # String is a Sequence but we take them as an whole
25
26
  if isinstance(element, abc.Sequence) and not isinstance(element, str):
26
27
  can_handle = ListOfBuiltinHandler.can_handle(element)
27
- elif not isinstance(element, (int, float, bool, str)):
28
+ elif not isinstance(element, (int, float, bool, str, datetime.datetime)):
28
29
  can_handle = False
29
30
  break
30
31
  return can_handle
@@ -53,6 +53,8 @@ class DataType(Enum):
53
53
  STRING = ("string", spt.StringType, np.str_)
54
54
  BYTES = ("bytes", spt.BinaryType, np.bytes_)
55
55
 
56
+ TIMESTAMP_NTZ = ("datetime64[ns]", spt.TimestampType, "datetime64[ns]")
57
+
56
58
  def as_snowpark_type(self) -> spt.DataType:
57
59
  """Convert to corresponding Snowpark Type.
58
60
 
@@ -78,6 +80,13 @@ class DataType(Enum):
78
80
  Corresponding DataType.
79
81
  """
80
82
  np_to_snowml_type_mapping = {i._numpy_type: i for i in DataType}
83
+
84
+ # Add datetime types:
85
+ datetime_res = ["Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"]
86
+
87
+ for res in datetime_res:
88
+ np_to_snowml_type_mapping[f"datetime64[{res}]"] = DataType.TIMESTAMP_NTZ
89
+
81
90
  for potential_type in np_to_snowml_type_mapping.keys():
82
91
  if np.can_cast(np_type, potential_type, casting="no"):
83
92
  # This is used since the same dtype might represented in different ways.
@@ -247,9 +256,12 @@ class FeatureSpec(BaseFeatureSpec):
247
256
  result_type = spt.ArrayType(result_type)
248
257
  return result_type
249
258
 
250
- def as_dtype(self) -> npt.DTypeLike:
259
+ def as_dtype(self) -> Union[npt.DTypeLike, str]:
251
260
  """Convert to corresponding local Type."""
252
261
  if not self._shape:
262
+ # scalar dtype: use keys from `np.sctypeDict` to prevent unit-less dtype 'datetime64'
263
+ if "datetime64" in self._dtype._value:
264
+ return self._dtype._value
253
265
  return self._dtype._numpy_type
254
266
  return np.object_
255
267
 
@@ -147,6 +147,8 @@ class PandasDataFrameHandler(base_handler.BaseDataHandler[pd.DataFrame]):
147
147
  specs.append(core.FeatureSpec(dtype=core.DataType.STRING, name=ft_name))
148
148
  elif isinstance(data[df_col].iloc[0], bytes):
149
149
  specs.append(core.FeatureSpec(dtype=core.DataType.BYTES, name=ft_name))
150
+ elif isinstance(data[df_col].iloc[0], np.datetime64):
151
+ specs.append(core.FeatureSpec(dtype=core.DataType.TIMESTAMP_NTZ, name=ft_name))
150
152
  else:
151
153
  specs.append(core.FeatureSpec(dtype=core.DataType.from_numpy_type(df_col_dtype), name=ft_name))
152
154
  return specs
@@ -107,6 +107,9 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
107
107
  if not features:
108
108
  features = pandas_handler.PandasDataFrameHandler.infer_signature(df, role="input")
109
109
  # Role will be no effect on the column index. That is to say, the feature name is the actual column name.
110
+ if keep_order:
111
+ df = df.reset_index(drop=True)
112
+ df[infer_template._KEEP_ORDER_COL_NAME] = df.index
110
113
  sp_df = session.create_dataframe(df)
111
114
  column_names = []
112
115
  columns = []
@@ -122,7 +125,4 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
122
125
 
123
126
  sp_df = sp_df.with_columns(column_names, columns)
124
127
 
125
- if keep_order:
126
- sp_df = sp_df.with_column(infer_template._KEEP_ORDER_COL_NAME, F.monotonically_increasing_id())
127
-
128
128
  return sp_df
@@ -168,6 +168,8 @@ def _validate_numpy_array(
168
168
  max_v <= np.finfo(feature_type._numpy_type).max # type: ignore[arg-type]
169
169
  and min_v >= np.finfo(feature_type._numpy_type).min # type: ignore[arg-type]
170
170
  )
171
+ elif feature_type in [core.DataType.TIMESTAMP_NTZ]:
172
+ return np.issubdtype(arr.dtype, np.datetime64)
171
173
  else:
172
174
  return np.can_cast(arr.dtype, feature_type._numpy_type, casting="no")
173
175
 
@@ -54,6 +54,7 @@ _SupportedNumpyDtype = Union[
54
54
  "np.bool_",
55
55
  "np.str_",
56
56
  "np.bytes_",
57
+ "np.datetime64",
57
58
  ]
58
59
  _SupportedNumpyArray = npt.NDArray[_SupportedNumpyDtype]
59
60
  _SupportedBuiltinsList = Sequence[_SupportedBuiltins]