snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -90,7 +90,6 @@ def _call_complete_rest(
90
90
  prompt: Union[str, List[ConversationMessage]],
91
91
  options: Optional[CompleteOptions] = None,
92
92
  session: Optional[snowpark.Session] = None,
93
- stream: bool = False,
94
93
  ) -> requests.Response:
95
94
  session = session or context.get_active_session()
96
95
  if session is None:
@@ -121,7 +120,7 @@ def _call_complete_rest(
121
120
 
122
121
  data = {
123
122
  "model": model,
124
- "stream": stream,
123
+ "stream": True,
125
124
  }
126
125
  if isinstance(prompt, List):
127
126
  data["messages"] = prompt
@@ -137,32 +136,15 @@ def _call_complete_rest(
137
136
  if "top_p" in options:
138
137
  data["top_p"] = options["top_p"]
139
138
 
140
- logger.debug(f"making POST request to {url} (model={model}, stream={stream})")
139
+ logger.debug(f"making POST request to {url} (model={model})")
141
140
  return requests.post(
142
141
  url,
143
142
  json=data,
144
143
  headers=headers,
145
- stream=stream,
144
+ stream=True,
146
145
  )
147
146
 
148
147
 
149
- def _process_rest_response(
150
- response: requests.Response,
151
- stream: bool = False,
152
- deadline: Optional[float] = None,
153
- ) -> Union[str, Iterator[str]]:
154
- if stream:
155
- return _return_stream_response(response, deadline)
156
-
157
- try:
158
- content = response.json()["choices"][0]["message"]["content"]
159
- assert isinstance(content, str)
160
- return content
161
- except (KeyError, IndexError, AssertionError) as e:
162
- # Unlike the streaming case, errors are not ignored because a message must be returned.
163
- raise ResponseParseException("Failed to parse message from response.") from e
164
-
165
-
166
148
  def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
167
149
  client = SSEClient(response)
168
150
  for event in client.events():
@@ -243,7 +225,6 @@ def _complete_impl(
243
225
  prompt: Union[str, List[ConversationMessage], snowpark.Column],
244
226
  options: Optional[CompleteOptions] = None,
245
227
  session: Optional[snowpark.Session] = None,
246
- use_rest_api_experimental: bool = False,
247
228
  stream: bool = False,
248
229
  function: str = "snowflake.cortex.complete",
249
230
  timeout: Optional[float] = None,
@@ -253,16 +234,14 @@ def _complete_impl(
253
234
  raise ValueError('only one of "timeout" and "deadline" must be set')
254
235
  if timeout is not None:
255
236
  deadline = time.time() + timeout
256
- if use_rest_api_experimental:
237
+ if stream:
257
238
  if not isinstance(model, str):
258
239
  raise ValueError("in REST mode, 'model' must be a string")
259
240
  if not isinstance(prompt, str) and not isinstance(prompt, List):
260
241
  raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
261
- response = _call_complete_rest(model, prompt, options, session=session, stream=stream, deadline=deadline)
242
+ response = _call_complete_rest(model, prompt, options, session=session, deadline=deadline)
262
243
  assert response.status_code >= 200 and response.status_code < 300
263
- return _process_rest_response(response, stream=stream)
264
- if stream is True:
265
- raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
244
+ return _return_stream_response(response, deadline)
266
245
  return _complete_sql_impl(function, model, prompt, options, session)
267
246
 
268
247
 
@@ -275,7 +254,6 @@ def Complete(
275
254
  *,
276
255
  options: Optional[CompleteOptions] = None,
277
256
  session: Optional[snowpark.Session] = None,
278
- use_rest_api_experimental: bool = False,
279
257
  stream: bool = False,
280
258
  timeout: Optional[float] = None,
281
259
  deadline: Optional[float] = None,
@@ -287,16 +265,13 @@ def Complete(
287
265
  prompt: A Column of prompts to send to the LLM.
288
266
  options: A instance of snowflake.cortex.CompleteOptions
289
267
  session: The snowpark session to use. Will be inferred by context if not specified.
290
- use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
291
- experimental and can be removed at any time.
292
268
  stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
293
269
  output as it is received. Each update is a string containing the new text content since the previous update.
294
- The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
295
270
  timeout (float): Timeout in seconds to retry failed REST requests.
296
271
  deadline (float): Time in seconds since the epoch (as returned by time.time()) to retry failed REST requests.
297
272
 
298
273
  Raises:
299
- ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
274
+ ValueError: incorrect argument.
300
275
 
301
276
  Returns:
302
277
  A column of string responses.
@@ -307,7 +282,6 @@ def Complete(
307
282
  prompt,
308
283
  options=options,
309
284
  session=session,
310
- use_rest_api_experimental=use_rest_api_experimental,
311
285
  stream=stream,
312
286
  timeout=timeout,
313
287
  deadline=deadline,
@@ -27,7 +27,6 @@ class CONDA_OS(Enum):
27
27
  NO_ARCH = "noarch"
28
28
 
29
29
 
30
- _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
31
30
  _NODEFAULTS = "nodefaults"
32
31
  _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
33
32
  _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
@@ -36,6 +35,7 @@ _SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"]
36
35
  DEFAULT_CHANNEL_NAME = ""
37
36
  SNOWML_SPROC_ENV = "IN_SNOWML_SPROC"
38
37
  SNOWPARK_ML_PKG_NAME = "snowflake-ml-python"
38
+ SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
39
39
 
40
40
 
41
41
  def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
@@ -370,7 +370,7 @@ def get_matched_package_versions_in_snowflake_conda_channel(
370
370
 
371
371
  assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
372
372
 
373
- url = f"{_SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
373
+ url = f"{SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
374
374
 
375
375
  if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
376
376
  try:
@@ -477,6 +477,7 @@ def save_conda_env_file(
477
477
  path: pathlib.Path,
478
478
  conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
479
479
  python_version: str,
480
+ default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
480
481
  ) -> None:
481
482
  """Generate conda.yml file given a dict of dependencies after validation.
482
483
  The channels part of conda.yml file will contains Snowflake Anaconda Channel, nodefaults and all channel names
@@ -489,6 +490,7 @@ def save_conda_env_file(
489
490
  path: Path to the conda.yml file.
490
491
  conda_chan_deps: Dict of conda dependencies after validated.
491
492
  python_version: A string 'major.minor' showing python version relate to model.
493
+ default_channel_override: The default channel to be put in the first place of the channels section.
492
494
  """
493
495
  assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
494
496
  path.parent.mkdir(parents=True, exist_ok=True)
@@ -499,7 +501,11 @@ def save_conda_env_file(
499
501
  channels = list(dict(sorted(conda_chan_deps.items(), key=lambda item: len(item[1]), reverse=True)).keys())
500
502
  if DEFAULT_CHANNEL_NAME in channels:
501
503
  channels.remove(DEFAULT_CHANNEL_NAME)
502
- env["channels"] = [_SNOWFLAKE_CONDA_CHANNEL_URL] + channels + [_NODEFAULTS]
504
+
505
+ if default_channel_override in channels:
506
+ channels.remove(default_channel_override)
507
+
508
+ env["channels"] = [default_channel_override] + channels + [_NODEFAULTS]
503
509
  env["dependencies"] = [f"python=={python_version}.*"]
504
510
  for chan, reqs in conda_chan_deps.items():
505
511
  env["dependencies"].extend(
@@ -567,8 +573,8 @@ def load_conda_env_file(
567
573
  python_version = None
568
574
 
569
575
  channels = env.get("channels", [])
570
- if _SNOWFLAKE_CONDA_CHANNEL_URL in channels:
571
- channels.remove(_SNOWFLAKE_CONDA_CHANNEL_URL)
576
+ if len(channels) >= 1:
577
+ channels = channels[1:] # Skip the first channel which is the default channel
572
578
  if _NODEFAULTS in channels:
573
579
  channels.remove(_NODEFAULTS)
574
580
 
@@ -4,7 +4,10 @@ ATTRIBUTE_NOT_SET = (
4
4
  "-differences."
5
5
  )
6
6
  SIZE_MISMATCH = "Size mismatch: {}={}, {}={}."
7
- INVALID_MODEL_PARAM = "Invalid parameter {} for model {}. Valid parameters: {}."
7
+ INVALID_MODEL_PARAM = (
8
+ "Invalid parameter {} for model {}. Valid parameters: {}."
9
+ "Note: Scikit learn params cannot be set until the model has been fit."
10
+ )
8
11
  UNSUPPORTED_MODEL_CONVERSION = "Object doesn't support {}. Please use {}."
9
12
  INCOMPATIBLE_NEW_SKLEARN_PARAM = "Incompatible scikit-learn version: {} requires scikit-learn>={}. Installed: {}."
10
13
  REMOVED_SKLEARN_PARAM = "Incompatible scikit-learn version: {} is removed in scikit-learn>={}. Installed: {}."
@@ -1,4 +1,5 @@
1
1
  #!/usr/bin/env python3
2
+ import contextvars
2
3
  import enum
3
4
  import functools
4
5
  import inspect
@@ -12,6 +13,7 @@ from typing import (
12
13
  List,
13
14
  Mapping,
14
15
  Optional,
16
+ Set,
15
17
  Tuple,
16
18
  TypeVar,
17
19
  Union,
@@ -28,7 +30,7 @@ from snowflake.ml._internal.exceptions import (
28
30
  exceptions as snowml_exceptions,
29
31
  )
30
32
  from snowflake.snowpark import dataframe, exceptions as snowpark_exceptions, session
31
- from snowflake.snowpark._internal import utils
33
+ from snowflake.snowpark._internal import server_connection, utils
32
34
 
33
35
  _log_counter = 0
34
36
  _FLUSH_SIZE = 10
@@ -44,6 +46,20 @@ _Args = ParamSpec("_Args")
44
46
  _ReturnValue = TypeVar("_ReturnValue")
45
47
 
46
48
 
49
+ @enum.unique
50
+ class TelemetryProject(enum.Enum):
51
+ MLOPS = "MLOps"
52
+ MODELING = "ModelDevelopment"
53
+ # TODO: Update with remaining projects.
54
+
55
+
56
+ @enum.unique
57
+ class TelemetrySubProject(enum.Enum):
58
+ MONITORING = "Monitoring"
59
+ REGISTRY = "ModelManagement"
60
+ # TODO: Update with remaining subprojects.
61
+
62
+
47
63
  @enum.unique
48
64
  class TelemetryField(enum.Enum):
49
65
  # constants
@@ -71,6 +87,122 @@ class TelemetryField(enum.Enum):
71
87
  FUNC_CAT_USAGE = "usage"
72
88
 
73
89
 
90
+ class _TelemetrySourceType(enum.Enum):
91
+ # Automatically inferred telemetry/statement parameters
92
+ AUTO_TELEMETRY = "SNOWML_AUTO_TELEMETRY"
93
+ # Mixture of manual and automatic telemetry/statement parameters
94
+ AUGMENT_TELEMETRY = "SNOWML_AUGMENT_TELEMETRY"
95
+
96
+
97
+ _statement_params_context_var: contextvars.ContextVar[Dict[str, str]] = contextvars.ContextVar("statement_params")
98
+
99
+
100
+ class _StatementParamsPatchManager:
101
+ def __init__(self) -> None:
102
+ self._patch_cache: Set[server_connection.ServerConnection] = set()
103
+ self._context_var: contextvars.ContextVar[Dict[str, str]] = _statement_params_context_var
104
+
105
+ def apply_patches(self) -> None:
106
+ try:
107
+ # Apply patching to all active sessions in case of multiple
108
+ for sess in session._get_active_sessions():
109
+ # Check patch cache here to avoid unnecessary context switches
110
+ if self._get_target(sess) not in self._patch_cache:
111
+ self._patch_session(sess)
112
+ except snowpark_exceptions.SnowparkSessionException:
113
+ pass
114
+
115
+ def set_statement_params(self, statement_params: Dict[str, str]) -> None:
116
+ # Only set value if not already set in context
117
+ if not self._context_var.get({}):
118
+ self._context_var.set(statement_params)
119
+
120
+ def _get_target(self, session: session.Session) -> server_connection.ServerConnection:
121
+ return cast(server_connection.ServerConnection, session._conn)
122
+
123
+ def _patch_session(self, session: session.Session, throw_on_patch_fail: bool = False) -> None:
124
+ # Extract target
125
+ try:
126
+ target = self._get_target(session)
127
+ except AttributeError:
128
+ if throw_on_patch_fail:
129
+ raise
130
+ # TODO: Log a warning, this probably means there was a breaking change in Snowpark/SnowflakeConnection
131
+ return
132
+
133
+ # Check if session has already been patched
134
+ if target in self._patch_cache:
135
+ return
136
+ self._patch_cache.add(target)
137
+
138
+ functions = [
139
+ ("execute_and_notify_query_listener", "_statement_params"),
140
+ ("execute_async_and_notify_query_listener", "_statement_params"),
141
+ ]
142
+
143
+ for func, param_name in functions:
144
+ try:
145
+ self._patch_with_statement_params(target, func, param_name=param_name)
146
+ except AttributeError:
147
+ if throw_on_patch_fail: # primarily used for testing
148
+ raise
149
+ # TODO: Log a warning, this probably means there was a breaking change in Snowpark/SnowflakeConnection
150
+ pass
151
+
152
+ def _patch_with_statement_params(
153
+ self, target: object, function_name: str, param_name: str = "statement_params"
154
+ ) -> None:
155
+ func = getattr(target, function_name)
156
+ assert callable(func)
157
+
158
+ @functools.wraps(func)
159
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
160
+ # Retrieve context level statement parameters
161
+ context_params = self._context_var.get(dict())
162
+ if not context_params:
163
+ # Exit early if not in SnowML (decorator) context
164
+ return func(*args, **kwargs)
165
+
166
+ # Extract any explicitly provided statement parameters
167
+ orig_kwargs = dict(kwargs)
168
+ in_params = kwargs.pop(param_name, None) or {}
169
+
170
+ # Inject a special flag to statement parameters so we can filter out these patched logs if necessary
171
+ # Calls that include SnowML telemetry are tagged with "SNOWML_AUGMENT_TELEMETRY"
172
+ # and calls without SnowML telemetry are tagged with "SNOWML_AUTO_TELEMETRY"
173
+ if TelemetryField.KEY_PROJECT.value in in_params:
174
+ context_params["snowml_telemetry_type"] = _TelemetrySourceType.AUGMENT_TELEMETRY.value
175
+ else:
176
+ context_params["snowml_telemetry_type"] = _TelemetrySourceType.AUTO_TELEMETRY.value
177
+
178
+ # Apply any explicitly provided statement parameters and result into function call
179
+ context_params.update(in_params)
180
+ kwargs[param_name] = context_params
181
+
182
+ try:
183
+ return func(*args, **kwargs)
184
+ except TypeError as e:
185
+ if str(e).endswith(f"unexpected keyword argument '{param_name}'"):
186
+ # TODO: Log warning that this patch is invalid
187
+ # Unwrap function for future invocations
188
+ setattr(target, function_name, func)
189
+ return func(*args, **orig_kwargs)
190
+ else:
191
+ raise
192
+
193
+ setattr(target, function_name, wrapper)
194
+
195
+ def __getstate__(self) -> Dict[str, Any]:
196
+ return {}
197
+
198
+ def __setstate__(self, state: Dict[str, Any]) -> None:
199
+ # unpickling does not call __init__ by default, do it manually here
200
+ self.__init__() # type: ignore[misc]
201
+
202
+
203
+ _patch_manager = _StatementParamsPatchManager()
204
+
205
+
74
206
  def get_statement_params(
75
207
  project: str, subproject: Optional[str] = None, class_name: Optional[str] = None
76
208
  ) -> Dict[str, Any]:
@@ -361,7 +493,18 @@ def send_api_usage_telemetry(
361
493
  obj._statement_params = statement_params # type: ignore[assignment]
362
494
  return obj
363
495
 
496
+ # Set up framework-level credit usage instrumentation
497
+ ctx = contextvars.copy_context()
498
+ _patch_manager.apply_patches()
499
+
500
+ # This function should be executed with ctx.run()
501
+ def execute_func_with_statement_params() -> _ReturnValue:
502
+ _patch_manager.set_statement_params(statement_params)
503
+ result = func(*args, **kwargs)
504
+ return update_stmt_params_if_snowpark_df(result, statement_params)
505
+
364
506
  # prioritize `conn_attr_name` over the active session
507
+ telemetry_enabled = True
365
508
  if conn_attr_name:
366
509
  # raise AttributeError if conn attribute does not exist in `self`
367
510
  conn = operator.attrgetter(conn_attr_name)(args[0])
@@ -373,22 +516,17 @@ def send_api_usage_telemetry(
373
516
  else:
374
517
  try:
375
518
  active_session = next(iter(session._get_active_sessions()))
376
- # server no default session
519
+ conn = active_session._conn._conn
520
+ telemetry_enabled = active_session.telemetry_enabled
377
521
  except snowpark_exceptions.SnowparkSessionException:
378
- try:
379
- return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params)
380
- except Exception as e:
381
- if isinstance(e, snowml_exceptions.SnowflakeMLException):
382
- raise e.original_exception.with_traceback(e.__traceback__) from None
383
- # suppress SnowparkSessionException from telemetry in the stack trace
384
- raise e from None
385
-
386
- conn = active_session._conn._conn
387
- if (not active_session.telemetry_enabled) or (conn is None):
388
- try:
389
- return update_stmt_params_if_snowpark_df(func(*args, **kwargs), statement_params)
390
- except snowml_exceptions.SnowflakeMLException as e:
391
- raise e.original_exception from e
522
+ conn = None
523
+
524
+ if conn is None or not telemetry_enabled:
525
+ # Telemetry not enabled, just execute without our additional telemetry logic
526
+ try:
527
+ return ctx.run(execute_func_with_statement_params)
528
+ except snowml_exceptions.SnowflakeMLException as e:
529
+ raise e.original_exception from e
392
530
 
393
531
  # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
394
532
  telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject_name)
@@ -401,11 +539,11 @@ def send_api_usage_telemetry(
401
539
  custom_tags=custom_tags,
402
540
  )
403
541
  try:
404
- res = func(*args, **kwargs)
542
+ return ctx.run(execute_func_with_statement_params)
405
543
  except Exception as e:
406
544
  if not isinstance(e, snowml_exceptions.SnowflakeMLException):
407
545
  # already handled via a nested decorated function
408
- if hasattr(e, "_snowflake_ml_handled") and e._snowflake_ml_handled:
546
+ if getattr(e, "_snowflake_ml_handled", False):
409
547
  raise e
410
548
  if isinstance(e, snowpark_exceptions.SnowparkClientException):
411
549
  me = snowml_exceptions.SnowflakeMLException(
@@ -424,8 +562,6 @@ def send_api_usage_telemetry(
424
562
  raise me.original_exception from None
425
563
  else:
426
564
  raise me.original_exception from e
427
- else:
428
- return update_stmt_params_if_snowpark_df(res, statement_params)
429
565
  finally:
430
566
  telemetry.send_function_usage_telemetry(**telemetry_args)
431
567
  global _log_counter
@@ -10,9 +10,11 @@ SF_QUOTED_IDENTIFIER = '"(?:[^"]|"")*"'
10
10
  _SF_IDENTIFIER = f"({_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER}|{SF_QUOTED_IDENTIFIER})"
11
11
  SF_IDENTIFIER_RE = re.compile(_SF_IDENTIFIER)
12
12
  _SF_SCHEMA_LEVEL_OBJECT = (
13
- rf"(?:(?:(?P<db>{_SF_IDENTIFIER})\.)?(?P<schema>{_SF_IDENTIFIER})\.)?(?P<object>{_SF_IDENTIFIER})(?P<others>.*)"
13
+ rf"(?:(?:(?P<db>{_SF_IDENTIFIER})\.)?(?P<schema>{_SF_IDENTIFIER})\.)?(?P<object>{_SF_IDENTIFIER})"
14
14
  )
15
+ _SF_STAGE_PATH = rf"{_SF_SCHEMA_LEVEL_OBJECT}(?P<path>.*)"
15
16
  _SF_SCHEMA_LEVEL_OBJECT_RE = re.compile(_SF_SCHEMA_LEVEL_OBJECT)
17
+ _SF_STAGE_PATH_RE = re.compile(_SF_STAGE_PATH)
16
18
 
17
19
  UNQUOTED_CASE_INSENSITIVE_RE = re.compile(f"^({_SF_UNQUOTED_CASE_INSENSITIVE_IDENTIFIER})$")
18
20
  UNQUOTED_CASE_SENSITIVE_RE = re.compile(f"^({_SF_UNQUOTED_CASE_SENSITIVE_IDENTIFIER})$")
@@ -139,29 +141,61 @@ def rename_to_valid_snowflake_identifier(name: str) -> str:
139
141
 
140
142
 
141
143
  def parse_schema_level_object_identifier(
144
+ object_name: str,
145
+ ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
146
+ """Parse a string which starts with schema level object.
147
+
148
+ Args:
149
+ object_name: A string starts with a schema level object path, which is in the format
150
+ '<db>.<schema>.<object_name>'. Here, '<db>', '<schema>' and '<object_name>' are all snowflake identifiers.
151
+
152
+ Returns:
153
+ A tuple of 3 strings in the form of (db, schema, object_name).
154
+
155
+ Raises:
156
+ ValueError: If the id is invalid.
157
+ """
158
+ res = _SF_SCHEMA_LEVEL_OBJECT_RE.fullmatch(object_name)
159
+ if not res:
160
+ raise ValueError(
161
+ "Invalid identifier because it does not follow the pattern. "
162
+ f"It should start with [[database.]schema.]object. Getting {object_name}"
163
+ )
164
+ return (
165
+ res.group("db"),
166
+ res.group("schema"),
167
+ res.group("object"),
168
+ )
169
+
170
+
171
+ def parse_snowflake_stage_path(
142
172
  path: str,
143
173
  ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
144
- """Parse a string which starts with schema level object.
174
+ """Parse a string which represents a snowflake stage path.
145
175
 
146
176
  Args:
147
- path: A string starts with a schema level object path, which is in the format '<db>.<schema>.<object_name>'.
148
- Here, '<db>', '<schema>' and '<object_name>' are all snowflake identifiers.
177
+ path: A string starts with a schema level object path, which is in the format
178
+ '<db>.<schema>.<object_name><path>'. Here, '<db>', '<schema>' and '<object_name>' are all snowflake
179
+ identifiers.
149
180
 
150
181
  Returns:
151
- A tuple of 4 strings in the form of (db, schema, object_name, others). 'db', 'schema', 'object_name' are parsed
152
- from the schema level object and 'others' are all the content post to the object.
182
+ A tuple of 4 strings in the form of (db, schema, object_name, path). 'db', 'schema', 'object_name' are parsed
183
+ from the schema level object and 'path' are all the content post to the object.
153
184
 
154
185
  Raises:
155
186
  ValueError: If the id is invalid.
156
187
  """
157
- res = _SF_SCHEMA_LEVEL_OBJECT_RE.fullmatch(path)
188
+ res = _SF_STAGE_PATH_RE.fullmatch(path)
158
189
  if not res:
159
- raise ValueError(f"Invalid identifier. It should start with database.schema.object. Getting {path}")
190
+ raise ValueError(
191
+ "Invalid identifier because it does not follow the pattern. "
192
+ f"It should start with [[database.]schema.]object. Getting {path}"
193
+ )
160
194
  return (
161
195
  res.group("db"),
162
196
  res.group("schema"),
163
197
  res.group("object"),
164
- res.group("others"),
198
+ res.group("path"),
165
199
  )
166
200
 
167
201
 
@@ -175,8 +209,11 @@ def is_fully_qualified_name(name: str) -> bool:
175
209
  Returns:
176
210
  bool: True if the name is fully qualified, False otherwise.
177
211
  """
178
- res = parse_schema_level_object_identifier(name)
179
- return res[0] is not None and res[1] is not None and res[2] is not None and not res[3]
212
+ try:
213
+ res = parse_schema_level_object_identifier(name)
214
+ return all(res)
215
+ except ValueError:
216
+ return False
180
217
 
181
218
 
182
219
  def get_schema_level_object_identifier(
@@ -26,30 +26,11 @@ def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
26
26
  pkg_versions: List[str], session: Session, subproject: Optional[str] = None
27
27
  ) -> List[str]:
28
28
  if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
29
- return _get_valid_pkg_versions_supported_in_snowflake_conda_channel_sync(pkg_versions, session, subproject)
29
+ return pkg_versions
30
30
  else:
31
31
  return _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(pkg_versions, session, subproject)
32
32
 
33
33
 
34
- def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_sync(
35
- pkg_versions: List[str], session: Session, subproject: Optional[str] = None
36
- ) -> List[str]:
37
- for pkg_version in pkg_versions:
38
- if pkg_version not in cache:
39
- pkg_version_list = _query_pkg_version_supported_in_snowflake_conda_channel(
40
- pkg_version=pkg_version, session=session, block=True, subproject=subproject
41
- )
42
- assert isinstance(pkg_version_list, list) # keep mypy happy
43
- try:
44
- cache[pkg_version] = pkg_version_list[0]["VERSION"]
45
- except IndexError:
46
- cache[pkg_version] = None
47
-
48
- pkg_version_conda_list = _get_conda_packages_and_emit_warnings(pkg_versions)
49
-
50
- return pkg_version_conda_list
51
-
52
-
53
34
  def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
54
35
  pkg_versions: List[str], session: Session, subproject: Optional[str] = None
55
36
  ) -> List[str]:
@@ -60,7 +41,11 @@ def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
60
41
  async_job = _query_pkg_version_supported_in_snowflake_conda_channel(
61
42
  pkg_version=pkg_version, session=session, block=False, subproject=subproject
62
43
  )
63
- assert isinstance(async_job, AsyncJob)
44
+ if isinstance(async_job, list):
45
+ raise RuntimeError(
46
+ "Async job was expected, executed query was returned. Please contact Snowflake support."
47
+ )
48
+
64
49
  pkg_version_async_job_list.append((pkg_version, async_job))
65
50
 
66
51
  # Populate the cache.
@@ -143,7 +128,8 @@ def _get_conda_packages_and_emit_warnings(pkg_versions: List[str]) -> List[str]:
143
128
  warnings.warn(
144
129
  f"Package {', '.join([pkg[0] for pkg in pkg_version_warning_list])} is not supported "
145
130
  f"in snowflake conda channel for python runtime "
146
- f"{', '.join([pkg[1] for pkg in pkg_version_warning_list])}."
131
+ f"{', '.join([pkg[1] for pkg in pkg_version_warning_list])}.",
132
+ stacklevel=1,
147
133
  )
148
134
 
149
135
  return pkg_version_conda_list
@@ -2,7 +2,7 @@ import enum
2
2
  from typing import Any, Dict, Optional, TypedDict, cast
3
3
 
4
4
  from packaging import version
5
- from typing_extensions import Required
5
+ from typing_extensions import NotRequired, Required
6
6
 
7
7
  from snowflake.ml._internal.utils import query_result_checker
8
8
  from snowflake.snowpark import session
@@ -52,7 +52,7 @@ class SnowflakeCloudType(enum.Enum):
52
52
 
53
53
 
54
54
  class SnowflakeRegion(TypedDict):
55
- region_group: Required[str]
55
+ region_group: NotRequired[str]
56
56
  snowflake_region: Required[str]
57
57
  cloud: Required[SnowflakeCloudType]
58
58
  region: Required[str]
@@ -64,23 +64,33 @@ def get_regions(
64
64
  ) -> Dict[str, SnowflakeRegion]:
65
65
  res = (
66
66
  query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params)
67
- .has_column("region_group")
68
67
  .has_column("snowflake_region")
69
68
  .has_column("cloud")
70
69
  .has_column("region")
71
70
  .has_column("display_name")
72
71
  .validate()
73
72
  )
74
- return {
75
- f"{r.region_group}.{r.snowflake_region}": SnowflakeRegion(
76
- region_group=r.region_group,
77
- snowflake_region=r.snowflake_region,
78
- cloud=SnowflakeCloudType.from_value(r.cloud),
79
- region=r.region,
80
- display_name=r.display_name,
81
- )
82
- for r in res
83
- }
73
+ res_dict = {}
74
+ for r in res:
75
+ if hasattr(r, "region_group") and r.region_group:
76
+ key = f"{r.region_group}.{r.snowflake_region}"
77
+ res_dict[key] = SnowflakeRegion(
78
+ region_group=r.region_group,
79
+ snowflake_region=r.snowflake_region,
80
+ cloud=SnowflakeCloudType.from_value(r.cloud),
81
+ region=r.region,
82
+ display_name=r.display_name,
83
+ )
84
+ else:
85
+ key = r.snowflake_region
86
+ res_dict[key] = SnowflakeRegion(
87
+ snowflake_region=r.snowflake_region,
88
+ cloud=SnowflakeCloudType.from_value(r.cloud),
89
+ region=r.region,
90
+ display_name=r.display_name,
91
+ )
92
+
93
+ return res_dict
84
94
 
85
95
 
86
96
  def get_current_region_id(sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
@@ -84,7 +84,7 @@ def to_sql_identifiers(list_of_str: List[str], *, case_sensitive: bool = False)
84
84
  def parse_fully_qualified_name(
85
85
  name: str,
86
86
  ) -> Tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]:
87
- db, schema, object, _ = identifier.parse_schema_level_object_identifier(name)
87
+ db, schema, object = identifier.parse_schema_level_object_identifier(name)
88
88
 
89
89
  assert name is not None, f"Unable parse the input name `{name}` as fully qualified."
90
90
  return (