snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/__init__.py +4 -1
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +281 -21
  4. snowflake/cortex/_extract_answer.py +0 -1
  5. snowflake/cortex/_sentiment.py +0 -1
  6. snowflake/cortex/_summarize.py +0 -1
  7. snowflake/cortex/_translate.py +0 -1
  8. snowflake/cortex/_util.py +12 -85
  9. snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
  10. snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
  11. snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
  12. snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
  13. snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
  14. snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
  15. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  16. snowflake/ml/_internal/telemetry.py +38 -2
  17. snowflake/ml/_internal/utils/identifier.py +14 -0
  18. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
  19. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  20. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  21. snowflake/ml/data/data_connector.py +133 -0
  22. snowflake/ml/data/data_ingestor.py +28 -0
  23. snowflake/ml/data/data_source.py +23 -0
  24. snowflake/ml/dataset/dataset.py +39 -32
  25. snowflake/ml/dataset/dataset_reader.py +18 -118
  26. snowflake/ml/feature_store/access_manager.py +7 -1
  27. snowflake/ml/feature_store/entity.py +19 -2
  28. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  29. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  30. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  31. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  32. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  34. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  35. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  36. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  37. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  38. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  39. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  40. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  43. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  44. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  45. snowflake/ml/feature_store/feature_store.py +987 -264
  46. snowflake/ml/feature_store/feature_view.py +228 -13
  47. snowflake/ml/fileset/embedded_stage_fs.py +25 -21
  48. snowflake/ml/fileset/fileset.py +2 -2
  49. snowflake/ml/fileset/snowfs.py +4 -15
  50. snowflake/ml/fileset/stage_fs.py +24 -18
  51. snowflake/ml/lineage/__init__.py +3 -0
  52. snowflake/ml/lineage/lineage_node.py +139 -0
  53. snowflake/ml/model/_client/model/model_impl.py +47 -14
  54. snowflake/ml/model/_client/model/model_version_impl.py +82 -2
  55. snowflake/ml/model/_client/ops/model_ops.py +77 -5
  56. snowflake/ml/model/_client/sql/model.py +1 -0
  57. snowflake/ml/model/_client/sql/model_version.py +45 -2
  58. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  59. snowflake/ml/model/_model_composer/model_composer.py +15 -17
  60. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
  61. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  62. snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
  63. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  64. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
  65. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
  66. snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
  67. snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
  68. snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
  69. snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
  70. snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
  71. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
  72. snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
  73. snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
  74. snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
  75. snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
  76. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
  77. snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
  78. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
  79. snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
  80. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  81. snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
  82. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  83. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
  84. snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
  85. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  86. snowflake/ml/model/_packager/model_packager.py +9 -4
  87. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  88. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  89. snowflake/ml/model/custom_model.py +22 -2
  90. snowflake/ml/model/model_signature.py +4 -4
  91. snowflake/ml/model/type_hints.py +77 -4
  92. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
  93. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  94. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
  95. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
  96. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
  97. snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
  98. snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
  99. snowflake/ml/modeling/cluster/birch.py +4 -2
  100. snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
  101. snowflake/ml/modeling/cluster/dbscan.py +4 -2
  102. snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
  103. snowflake/ml/modeling/cluster/k_means.py +4 -2
  104. snowflake/ml/modeling/cluster/mean_shift.py +4 -2
  105. snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
  106. snowflake/ml/modeling/cluster/optics.py +4 -2
  107. snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
  108. snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
  109. snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
  110. snowflake/ml/modeling/compose/column_transformer.py +4 -2
  111. snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
  112. snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
  113. snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
  114. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
  115. snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
  116. snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
  117. snowflake/ml/modeling/covariance/oas.py +4 -2
  118. snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
  119. snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
  120. snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
  121. snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
  122. snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
  123. snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
  124. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
  125. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
  126. snowflake/ml/modeling/decomposition/pca.py +4 -2
  127. snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
  128. snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
  129. snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
  130. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
  131. snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
  132. snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
  133. snowflake/ml/modeling/impute/knn_imputer.py +4 -2
  134. snowflake/ml/modeling/impute/missing_indicator.py +4 -2
  135. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  136. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
  137. snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
  138. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
  139. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
  140. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
  141. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
  142. snowflake/ml/modeling/manifold/isomap.py +4 -2
  143. snowflake/ml/modeling/manifold/mds.py +4 -2
  144. snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
  145. snowflake/ml/modeling/manifold/tsne.py +4 -2
  146. snowflake/ml/modeling/metrics/ranking.py +3 -0
  147. snowflake/ml/modeling/metrics/regression.py +3 -0
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
  150. snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
  151. snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
  152. snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
  153. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
  154. snowflake/ml/modeling/pipeline/pipeline.py +5 -4
  155. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
  156. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
  157. snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
  158. snowflake/ml/registry/_manager/model_manager.py +16 -3
  159. snowflake/ml/registry/registry.py +100 -13
  160. snowflake/ml/version.py +1 -1
  161. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
  162. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
  163. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  164. snowflake/ml/_internal/lineage/data_source.py +0 -10
  165. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,14 @@
1
- from snowflake.cortex._complete import Complete
1
+ from snowflake.cortex._classify_text import ClassifyText
2
+ from snowflake.cortex._complete import Complete, CompleteOptions
2
3
  from snowflake.cortex._extract_answer import ExtractAnswer
3
4
  from snowflake.cortex._sentiment import Sentiment
4
5
  from snowflake.cortex._summarize import Summarize
5
6
  from snowflake.cortex._translate import Translate
6
7
 
7
8
  __all__ = [
9
+ "ClassifyText",
8
10
  "Complete",
11
+ "CompleteOptions",
9
12
  "ExtractAnswer",
10
13
  "Sentiment",
11
14
  "Summarize",
@@ -0,0 +1,36 @@
1
+ from typing import List, Optional, Union
2
+
3
+ from snowflake import snowpark
4
+ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
5
+ from snowflake.ml._internal import telemetry
6
+
7
+
8
+ @telemetry.send_api_usage_telemetry(
9
+ project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
10
+ )
11
+ def ClassifyText(
12
+ str_input: Union[str, snowpark.Column],
13
+ categories: Union[List[str], snowpark.Column],
14
+ session: Optional[snowpark.Session] = None,
15
+ ) -> Union[str, snowpark.Column]:
16
+ """Use the LLM inference service to classify the INPUT text into one of the target CATEGORIES.
17
+
18
+ Args:
19
+ str_input: A Column of strings to classify.
20
+ categories: A list of candidate categories to classify the INPUT text into.
21
+ session: The snowpark session to use. Will be inferred by context if not specified.
22
+
23
+ Returns:
24
+ A column of classification responses.
25
+ """
26
+
27
+ return _classify_text_impl("snowflake.cortex.classify_text", str_input, categories, session=session)
28
+
29
+
30
+ def _classify_text_impl(
31
+ function: str,
32
+ str_input: Union[str, snowpark.Column],
33
+ categories: Union[List[str], snowpark.Column],
34
+ session: Optional[snowpark.Session] = None,
35
+ ) -> Union[str, snowpark.Column]:
36
+ return call_sql_function(function, session, str_input, categories)
@@ -1,37 +1,299 @@
1
- from typing import Iterator, Optional, Union
1
+ import json
2
+ import logging
3
+ import time
4
+ from typing import Any, Callable, Iterator, List, Optional, TypedDict, Union, cast
5
+ from urllib.parse import urlunparse
6
+
7
+ import requests
8
+ from typing_extensions import NotRequired
2
9
 
3
10
  from snowflake import snowpark
11
+ from snowflake.cortex._sse_client import SSEClient
4
12
  from snowflake.cortex._util import (
5
13
  CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
6
- call_rest_function,
7
- call_sql_function,
8
- process_rest_response,
14
+ SnowflakeAuthenticationException,
15
+ SnowflakeConfigurationException,
9
16
  )
10
17
  from snowflake.ml._internal import telemetry
18
+ from snowflake.snowpark import context, functions
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class ConversationMessage(TypedDict):
24
+ """Represents an conversation interaction."""
25
+
26
+ role: str
27
+ """The role of the participant. For example, "user" or "assistant"."""
28
+
29
+ content: str
30
+ """The content of the message."""
31
+
32
+
33
+ class CompleteOptions(TypedDict):
34
+ """Options configuring a snowflake.cortex.Complete call."""
35
+
36
+ max_tokens: NotRequired[int]
37
+ """ Sets the maximum number of output tokens in the response. Small values can result in
38
+ truncated responses. """
39
+ temperature: NotRequired[float]
40
+ """ A value from 0 to 1 (inclusive) that controls the randomness of the output of the language
41
+ model. A higher temperature (for example, 0.7) results in more diverse and random output, while a lower
42
+ temperature (such as 0.2) makes the output more deterministic and focused. """
43
+
44
+ top_p: NotRequired[float]
45
+ """ A value from 0 to 1 (inclusive) that controls the randomness and diversity of the language model,
46
+ generally used as an alternative to temperature. The difference is that top_p restricts the set of possible tokens
47
+ that the model outputs, while temperature influences which tokens are chosen at each step. """
48
+
49
+
50
+ class ResponseParseException(Exception):
51
+ """This exception is raised when the server response cannot be parsed."""
52
+
53
+ pass
54
+
55
+
56
+ _MAX_RETRY_SECONDS = 30
57
+
58
+
59
+ def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Response]:
60
+ def inner(*args: Any, **kwargs: Any) -> requests.Response:
61
+ deadline = cast(Optional[float], kwargs["deadline"])
62
+ kwargs = {key: value for key, value in kwargs.items() if key != "deadline"}
63
+ expRetrySeconds = 0.5
64
+ while True:
65
+ if deadline is not None and time.time() > deadline:
66
+ raise TimeoutError()
67
+ response = func(*args, **kwargs)
68
+ if response.status_code >= 200 and response.status_code < 300:
69
+ return response
70
+ retry_status_codes = [429, 503, 504]
71
+ if response.status_code not in retry_status_codes:
72
+ response.raise_for_status()
73
+ logger.debug(f"request failed with status code {response.status_code}, retrying")
74
+
75
+ # Formula: delay(i) = max(RetryAfterHeader, min(2^i, _MAX_RETRY_SECONDS)).
76
+ expRetrySeconds = min(2 * expRetrySeconds, _MAX_RETRY_SECONDS)
77
+ retrySeconds = expRetrySeconds
78
+ retryAfterHeader = response.headers.get("retry-after")
79
+ if retryAfterHeader is not None:
80
+ retrySeconds = max(retrySeconds, int(retryAfterHeader))
81
+ logger.debug(f"sleeping for {retrySeconds}s before retrying")
82
+ time.sleep(retrySeconds)
83
+
84
+ return inner
85
+
86
+
87
+ @retry
88
+ def _call_complete_rest(
89
+ model: str,
90
+ prompt: Union[str, List[ConversationMessage]],
91
+ options: Optional[CompleteOptions] = None,
92
+ session: Optional[snowpark.Session] = None,
93
+ stream: bool = False,
94
+ ) -> requests.Response:
95
+ session = session or context.get_active_session()
96
+ if session is None:
97
+ raise SnowflakeAuthenticationException(
98
+ """Session required. Provide the session through a session=... argument or ensure an active session is
99
+ available in your environment."""
100
+ )
101
+
102
+ if session.connection.host is None or session.connection.host == "":
103
+ raise SnowflakeConfigurationException("Snowflake connection configuration does not specify 'host'")
104
+
105
+ if session.connection.rest is None or not hasattr(session.connection.rest, "token"):
106
+ raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")
107
+
108
+ if session.connection.rest.token is None or session.connection.rest.token == "":
109
+ raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")
110
+
111
+ scheme = "https"
112
+ if hasattr(session.connection, "scheme"):
113
+ scheme = session.connection.scheme
114
+ url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference:complete", "", "", ""))
115
+
116
+ headers = {
117
+ "Content-Type": "application/json",
118
+ "Authorization": f'Snowflake Token="{session.connection.rest.token}"',
119
+ "Accept": "application/json, text/event-stream",
120
+ }
121
+
122
+ data = {
123
+ "model": model,
124
+ "stream": stream,
125
+ }
126
+ if isinstance(prompt, List):
127
+ data["messages"] = prompt
128
+ else:
129
+ data["messages"] = [{"content": prompt}]
130
+
131
+ if options:
132
+ if "max_tokens" in options:
133
+ data["max_tokens"] = options["max_tokens"]
134
+ data["max_output_tokens"] = options["max_tokens"]
135
+ if "temperature" in options:
136
+ data["temperature"] = options["temperature"]
137
+ if "top_p" in options:
138
+ data["top_p"] = options["top_p"]
139
+
140
+ logger.debug(f"making POST request to {url} (model={model}, stream={stream})")
141
+ return requests.post(
142
+ url,
143
+ json=data,
144
+ headers=headers,
145
+ stream=stream,
146
+ )
147
+
148
+
149
+ def _process_rest_response(
150
+ response: requests.Response,
151
+ stream: bool = False,
152
+ deadline: Optional[float] = None,
153
+ ) -> Union[str, Iterator[str]]:
154
+ if stream:
155
+ return _return_stream_response(response, deadline)
156
+
157
+ try:
158
+ content = response.json()["choices"][0]["message"]["content"]
159
+ assert isinstance(content, str)
160
+ return content
161
+ except (KeyError, IndexError, AssertionError) as e:
162
+ # Unlike the streaming case, errors are not ignored because a message must be returned.
163
+ raise ResponseParseException("Failed to parse message from response.") from e
164
+
165
+
166
+ def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
167
+ client = SSEClient(response)
168
+ for event in client.events():
169
+ if deadline is not None and time.time() > deadline:
170
+ raise TimeoutError()
171
+ try:
172
+ yield json.loads(event.data)["choices"][0]["delta"]["content"]
173
+ except (json.JSONDecodeError, KeyError, IndexError):
174
+ # For the sake of evolution of the output format,
175
+ # ignore stream messages that don't match the expected format.
176
+ pass
177
+
178
+
179
+ def _complete_call_sql_function_snowpark(
180
+ function: str, *args: Union[str, snowpark.Column, CompleteOptions]
181
+ ) -> snowpark.Column:
182
+ return cast(snowpark.Column, functions.builtin(function)(*args))
183
+
184
+
185
+ def _complete_call_sql_function_immediate(
186
+ function: str,
187
+ model: str,
188
+ prompt: Union[str, List[ConversationMessage]],
189
+ options: Optional[CompleteOptions],
190
+ session: Optional[snowpark.Session],
191
+ ) -> str:
192
+ session = session or context.get_active_session()
193
+ if session is None:
194
+ raise SnowflakeAuthenticationException(
195
+ """Session required. Provide the session through a session=... argument or ensure an active session is
196
+ available in your environment."""
197
+ )
198
+
199
+ # https://docs.snowflake.com/en/sql-reference/functions/complete-snowflake-cortex
200
+ if options is not None or not isinstance(prompt, str):
201
+ if isinstance(prompt, List):
202
+ prompt_arg = prompt
203
+ else:
204
+ prompt_arg = [{"role": "user", "content": prompt}]
205
+ options = options or {}
206
+ lit_args = [
207
+ functions.lit(model),
208
+ functions.lit(prompt_arg),
209
+ functions.lit(options),
210
+ ]
211
+ else:
212
+ lit_args = [
213
+ functions.lit(model),
214
+ functions.lit(prompt),
215
+ ]
216
+
217
+ empty_df = session.create_dataframe([snowpark.Row()])
218
+ df = empty_df.select(functions.builtin(function)(*lit_args))
219
+ return cast(str, df.collect()[0][0])
220
+
221
+
222
+ def _complete_sql_impl(
223
+ function: str,
224
+ model: Union[str, snowpark.Column],
225
+ prompt: Union[str, List[ConversationMessage], snowpark.Column],
226
+ options: Optional[Union[CompleteOptions, snowpark.Column]],
227
+ session: Optional[snowpark.Session],
228
+ ) -> Union[str, snowpark.Column]:
229
+ if isinstance(prompt, snowpark.Column):
230
+ if options is not None:
231
+ return _complete_call_sql_function_snowpark(function, model, prompt, options)
232
+ else:
233
+ return _complete_call_sql_function_snowpark(function, model, prompt)
234
+ if isinstance(model, snowpark.Column):
235
+ raise ValueError("'model' cannot be a snowpark.Column when 'prompt' is a string.")
236
+ if isinstance(options, snowpark.Column):
237
+ raise ValueError("'options' cannot be a snowpark.Column when 'prompt' is a string.")
238
+ return _complete_call_sql_function_immediate(function, model, prompt, options, session)
239
+
240
+
241
+ def _complete_impl(
242
+ model: Union[str, snowpark.Column],
243
+ prompt: Union[str, List[ConversationMessage], snowpark.Column],
244
+ options: Optional[CompleteOptions] = None,
245
+ session: Optional[snowpark.Session] = None,
246
+ use_rest_api_experimental: bool = False,
247
+ stream: bool = False,
248
+ function: str = "snowflake.cortex.complete",
249
+ timeout: Optional[float] = None,
250
+ deadline: Optional[float] = None,
251
+ ) -> Union[str, Iterator[str], snowpark.Column]:
252
+ if timeout is not None and deadline is not None:
253
+ raise ValueError('only one of "timeout" and "deadline" must be set')
254
+ if timeout is not None:
255
+ deadline = time.time() + timeout
256
+ if use_rest_api_experimental:
257
+ if not isinstance(model, str):
258
+ raise ValueError("in REST mode, 'model' must be a string")
259
+ if not isinstance(prompt, str) and not isinstance(prompt, List):
260
+ raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
261
+ response = _call_complete_rest(model, prompt, options, session=session, stream=stream, deadline=deadline)
262
+ assert response.status_code >= 200 and response.status_code < 300
263
+ return _process_rest_response(response, stream=stream)
264
+ if stream is True:
265
+ raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
266
+ return _complete_sql_impl(function, model, prompt, options, session)
11
267
 
12
268
 
13
- @snowpark._internal.utils.experimental(version="1.0.12")
14
269
  @telemetry.send_api_usage_telemetry(
15
270
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
16
271
  )
17
272
  def Complete(
18
273
  model: Union[str, snowpark.Column],
19
- prompt: Union[str, snowpark.Column],
274
+ prompt: Union[str, List[ConversationMessage], snowpark.Column],
275
+ *,
276
+ options: Optional[CompleteOptions] = None,
20
277
  session: Optional[snowpark.Session] = None,
21
278
  use_rest_api_experimental: bool = False,
22
279
  stream: bool = False,
280
+ timeout: Optional[float] = None,
281
+ deadline: Optional[float] = None,
23
282
  ) -> Union[str, Iterator[str], snowpark.Column]:
24
283
  """Complete calls into the LLM inference service to perform completion.
25
284
 
26
285
  Args:
27
286
  model: A Column of strings representing model types.
28
287
  prompt: A Column of prompts to send to the LLM.
288
+ options: A instance of snowflake.cortex.CompleteOptions
29
289
  session: The snowpark session to use. Will be inferred by context if not specified.
30
290
  use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
31
291
  experimental and can be removed at any time.
32
292
  stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
33
293
  output as it is received. Each update is a string containing the new text content since the previous update.
34
294
  The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
295
+ timeout (float): Timeout in seconds to retry failed REST requests.
296
+ deadline (float): Time in seconds since the epoch (as returned by time.time()) to retry failed REST requests.
35
297
 
36
298
  Raises:
37
299
  ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
@@ -39,18 +301,16 @@ def Complete(
39
301
  Returns:
40
302
  A column of string responses.
41
303
  """
42
- if stream is True and use_rest_api_experimental is False:
43
- raise ValueError("If stream is set to True use_rest_api_experimental must also be set to True")
44
- if use_rest_api_experimental:
45
- response = call_rest_function("complete", model, prompt, session=session, stream=stream)
46
- return process_rest_response(response)
47
- return _complete_impl("snowflake.cortex.complete", model, prompt, session=session)
48
-
49
-
50
- def _complete_impl(
51
- function: str,
52
- model: Union[str, snowpark.Column],
53
- prompt: Union[str, snowpark.Column],
54
- session: Optional[snowpark.Session] = None,
55
- ) -> Union[str, snowpark.Column]:
56
- return call_sql_function(function, session, model, prompt)
304
+ try:
305
+ return _complete_impl(
306
+ model,
307
+ prompt,
308
+ options=options,
309
+ session=session,
310
+ use_rest_api_experimental=use_rest_api_experimental,
311
+ stream=stream,
312
+ timeout=timeout,
313
+ deadline=deadline,
314
+ )
315
+ except ValueError as err:
316
+ raise err
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
5
5
  from snowflake.ml._internal import telemetry
6
6
 
7
7
 
8
- @snowpark._internal.utils.experimental(version="1.0.12")
9
8
  @telemetry.send_api_usage_telemetry(
10
9
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
11
10
  )
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
5
5
  from snowflake.ml._internal import telemetry
6
6
 
7
7
 
8
- @snowpark._internal.utils.experimental(version="1.0.12")
9
8
  @telemetry.send_api_usage_telemetry(
10
9
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
11
10
  )
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
5
5
  from snowflake.ml._internal import telemetry
6
6
 
7
7
 
8
- @snowpark._internal.utils.experimental(version="1.0.12")
9
8
  @telemetry.send_api_usage_telemetry(
10
9
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
11
10
  )
@@ -5,7 +5,6 @@ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_
5
5
  from snowflake.ml._internal import telemetry
6
6
 
7
7
 
8
- @snowpark._internal.utils.experimental(version="1.0.12")
9
8
  @telemetry.send_api_usage_telemetry(
10
9
  project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
11
10
  )
snowflake/cortex/_util.py CHANGED
@@ -1,24 +1,19 @@
1
- import json
2
- from typing import Iterator, Optional, Union, cast
3
- from urllib.parse import urljoin, urlparse
4
-
5
- import requests
1
+ from typing import Dict, List, Optional, Union, cast
6
2
 
7
3
  from snowflake import snowpark
8
- from snowflake.cortex._sse_client import SSEClient
9
4
  from snowflake.snowpark import context, functions
10
5
 
11
6
  CORTEX_FUNCTIONS_TELEMETRY_PROJECT = "CortexFunctions"
12
7
 
13
8
 
14
- class SSEParseException(Exception):
15
- """This exception is raised when an invalid server sent event is received from the server."""
9
+ class SnowflakeAuthenticationException(Exception):
10
+ """This exception is raised when there is an issue with Snowflake's configuration."""
16
11
 
17
12
  pass
18
13
 
19
14
 
20
- class SnowflakeAuthenticationException(Exception):
21
- """This exception is raised when the session object does not have session.connection.rest.token attribute."""
15
+ class SnowflakeConfigurationException(Exception):
16
+ """This exception is raised when there is an issue with Snowflake's configuration."""
22
17
 
23
18
  pass
24
19
 
@@ -28,9 +23,10 @@ class SnowflakeAuthenticationException(Exception):
28
23
  def call_sql_function(
29
24
  function: str,
30
25
  session: Optional[snowpark.Session],
31
- *args: Union[str, snowpark.Column],
26
+ *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
32
27
  ) -> Union[str, snowpark.Column]:
33
28
  handle_as_column = False
29
+
34
30
  for arg in args:
35
31
  if isinstance(arg, snowpark.Column):
36
32
  handle_as_column = True
@@ -43,17 +39,18 @@ def call_sql_function(
43
39
  )
44
40
 
45
41
 
46
- def _call_sql_function_column(function: str, *args: Union[str, snowpark.Column]) -> snowpark.Column:
42
+ def _call_sql_function_column(
43
+ function: str, *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]]
44
+ ) -> snowpark.Column:
47
45
  return cast(snowpark.Column, functions.builtin(function)(*args))
48
46
 
49
47
 
50
48
  def _call_sql_function_immediate(
51
49
  function: str,
52
50
  session: Optional[snowpark.Session],
53
- *args: Union[str, snowpark.Column],
51
+ *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
54
52
  ) -> str:
55
- if session is None:
56
- session = context.get_active_session()
53
+ session = session or context.get_active_session()
57
54
  if session is None:
58
55
  raise SnowflakeAuthenticationException(
59
56
  """Session required. Provide the session through a session=... argument or ensure an active session is
@@ -67,73 +64,3 @@ def _call_sql_function_immediate(
67
64
  empty_df = session.create_dataframe([snowpark.Row()])
68
65
  df = empty_df.select(functions.builtin(function)(*lit_args))
69
66
  return cast(str, df.collect()[0][0])
70
-
71
-
72
- def call_rest_function(
73
- function: str,
74
- model: Union[str, snowpark.Column],
75
- prompt: Union[str, snowpark.Column],
76
- session: Optional[snowpark.Session] = None,
77
- stream: bool = False,
78
- ) -> requests.Response:
79
- if session is None:
80
- session = context.get_active_session()
81
- if session is None:
82
- raise SnowflakeAuthenticationException(
83
- """Session required. Provide the session through a session=... argument or ensure an active session is
84
- available in your environment."""
85
- )
86
-
87
- if not hasattr(session.connection.rest, "token"):
88
- raise SnowflakeAuthenticationException("Snowflake session error: REST token missing.")
89
-
90
- if session.connection.rest.token is None or session.connection.rest.token == "": # type: ignore[union-attr]
91
- raise SnowflakeAuthenticationException("Snowflake session error: REST token is empty.")
92
-
93
- url = urljoin(session.connection.host, f"api/v2/cortex/inference/{function}")
94
- if urlparse(url).scheme == "":
95
- url = "https://" + url
96
- headers = {
97
- "Content-Type": "application/json",
98
- "Authorization": f'Snowflake Token="{session.connection.rest.token}"', # type: ignore[union-attr]
99
- "Accept": "application/json, text/event-stream",
100
- }
101
-
102
- data = {
103
- "model": model,
104
- "messages": [{"content": prompt}],
105
- "stream": stream,
106
- }
107
-
108
- response = requests.post(
109
- url,
110
- json=data,
111
- headers=headers,
112
- stream=stream,
113
- )
114
- response.raise_for_status()
115
- return response
116
-
117
-
118
- def process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
119
- if not stream:
120
- try:
121
- message = response.json()["choices"][0]["message"]
122
- output = str(message.get("content", ""))
123
- return output
124
- except (KeyError, IndexError) as e:
125
- raise SSEParseException("Failed to parse streamed response.") from e
126
- else:
127
- return _return_gen(response)
128
-
129
-
130
- def _return_gen(response: requests.Response) -> Iterator[str]:
131
- client = SSEClient(response)
132
- for event in client.events():
133
- response_loaded = json.loads(event.data)
134
- try:
135
- delta = response_loaded["choices"][0]["delta"]
136
- output = str(delta.get("content", ""))
137
- yield output
138
- except (KeyError, IndexError) as e:
139
- raise SSEParseException("Failed to parse streamed response.") from e
@@ -64,13 +64,20 @@ class ImageRegistryHttpClient:
64
64
  operations. For general use of a retryable HTTP client, consider using the "retryable_http" module.
65
65
  """
66
66
 
67
- def __init__(self, *, session: snowpark.Session, repo_url: str) -> None:
67
+ def __init__(self, *, repo_url: str, session: Optional[snowpark.Session] = None, no_cred: bool = False) -> None:
68
68
  self._repo_url = repo_url
69
- self._session_token_manager = session_token_manager.SessionTokenManager(session)
70
69
  self._retryable_http = retryable_http.get_http_client()
71
- self._bearer_token = ""
70
+ self._no_cred = no_cred
71
+
72
+ if not self._no_cred:
73
+ self._bearer_token = ""
74
+ assert session is not None
75
+ self._session_token_manager = session_token_manager.SessionTokenManager(session)
72
76
 
73
77
  def _with_bearer_token_header(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]:
78
+ if self._no_cred:
79
+ return {} if not headers else headers.copy()
80
+
74
81
  if not self._bearer_token:
75
82
  self._fetch_bearer_token()
76
83
  assert self._bearer_token