snowflake-ml-python 1.7.3__py3-none-any.whl → 1.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (208) hide show
  1. snowflake/cortex/_complete.py +19 -0
  2. snowflake/ml/_internal/env_utils.py +64 -21
  3. snowflake/ml/_internal/platform_capabilities.py +87 -0
  4. snowflake/ml/_internal/relax_version_strategy.py +16 -0
  5. snowflake/ml/_internal/telemetry.py +21 -0
  6. snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
  7. snowflake/ml/dataset/dataset.py +0 -1
  8. snowflake/ml/feature_store/feature_store.py +18 -0
  9. snowflake/ml/feature_store/feature_view.py +46 -1
  10. snowflake/ml/fileset/fileset.py +6 -0
  11. snowflake/ml/jobs/__init__.py +21 -0
  12. snowflake/ml/jobs/_utils/constants.py +57 -0
  13. snowflake/ml/jobs/_utils/payload_utils.py +438 -0
  14. snowflake/ml/jobs/_utils/spec_utils.py +296 -0
  15. snowflake/ml/jobs/_utils/types.py +39 -0
  16. snowflake/ml/jobs/decorators.py +71 -0
  17. snowflake/ml/jobs/job.py +113 -0
  18. snowflake/ml/jobs/manager.py +298 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +11 -2
  20. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  21. snowflake/ml/model/_client/sql/service.py +13 -6
  22. snowflake/ml/model/_packager/model_env/model_env.py +45 -28
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +19 -6
  24. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  25. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +17 -0
  26. snowflake/ml/model/_packager/model_handlers/keras.py +230 -0
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -0
  28. snowflake/ml/model/_packager/model_handlers/sklearn.py +28 -3
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +74 -21
  30. snowflake/ml/model/_packager/model_handlers/tensorflow.py +27 -49
  31. snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
  32. snowflake/ml/model/_packager/model_meta/model_meta.py +1 -1
  33. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +3 -0
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  35. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -1
  36. snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
  37. snowflake/ml/model/_signatures/base_handler.py +1 -2
  38. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  39. snowflake/ml/model/_signatures/core.py +2 -2
  40. snowflake/ml/model/_signatures/numpy_handler.py +11 -12
  41. snowflake/ml/model/_signatures/pandas_handler.py +11 -9
  42. snowflake/ml/model/_signatures/pytorch_handler.py +3 -6
  43. snowflake/ml/model/_signatures/snowpark_handler.py +3 -3
  44. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  45. snowflake/ml/model/model_signature.py +25 -4
  46. snowflake/ml/model/type_hints.py +15 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  51. snowflake/ml/modeling/cluster/birch.py +6 -3
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  53. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  55. snowflake/ml/modeling/cluster/k_means.py +6 -3
  56. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  58. snowflake/ml/modeling/cluster/optics.py +6 -3
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  62. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  69. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  70. snowflake/ml/modeling/covariance/oas.py +6 -3
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  74. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  79. snowflake/ml/modeling/decomposition/pca.py +6 -3
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  110. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  111. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  112. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  113. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  114. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  115. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  116. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  117. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  118. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  119. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  120. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  121. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  122. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  123. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  124. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  125. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  126. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  127. snowflake/ml/modeling/linear_model/lars.py +6 -3
  128. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  129. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  130. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  131. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  132. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  133. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  134. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  135. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  136. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  137. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  139. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  140. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  141. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  142. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  143. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  144. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  145. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  146. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  148. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  149. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  151. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  152. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  153. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  154. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  155. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  156. snowflake/ml/modeling/manifold/isomap.py +6 -3
  157. snowflake/ml/modeling/manifold/mds.py +6 -3
  158. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  159. snowflake/ml/modeling/manifold/tsne.py +6 -3
  160. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  161. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  162. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  163. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  174. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  184. snowflake/ml/modeling/pipeline/pipeline.py +28 -3
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +8 -5
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  188. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  189. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  190. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  191. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  192. snowflake/ml/modeling/svm/svc.py +6 -3
  193. snowflake/ml/modeling/svm/svr.py +6 -3
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +6 -3
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +6 -3
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +6 -3
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +6 -3
  202. snowflake/ml/registry/registry.py +34 -4
  203. snowflake/ml/version.py +1 -1
  204. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/METADATA +81 -33
  205. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/RECORD +208 -196
  206. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/WHEEL +1 -1
  207. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/LICENSE.txt +0 -0
  208. {snowflake_ml_python-1.7.3.dist-info → snowflake_ml_python-1.7.5.dist-info}/top_level.txt +0 -0
@@ -49,6 +49,10 @@ class CompleteOptions(TypedDict):
49
49
  generally used as an alternative to temperature. The difference is that top_p restricts the set of possible tokens
50
50
  that the model outputs, while temperature influences which tokens are chosen at each step. """
51
51
 
52
+ guardrails: NotRequired[bool]
53
+ """ A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
54
+ from the language model. """
55
+
52
56
 
53
57
  class ResponseParseException(Exception):
54
58
  """This exception is raised when the server response cannot be parsed."""
@@ -56,6 +60,15 @@ class ResponseParseException(Exception):
56
60
  pass
57
61
 
58
62
 
63
+ class GuardrailsOptions(TypedDict):
64
+ enabled: bool
65
+ """A boolean value that controls whether Cortex Guard filters unsafe or harmful responses
66
+ from the language model."""
67
+
68
+ response_when_unsafe: str
69
+ """The response to return when the language model generates unsafe or harmful content."""
70
+
71
+
59
72
  _MAX_RETRY_SECONDS = 30
60
73
 
61
74
 
@@ -117,6 +130,12 @@ def _make_request_body(
117
130
  data["temperature"] = options["temperature"]
118
131
  if "top_p" in options:
119
132
  data["top_p"] = options["top_p"]
133
+ if "guardrails" in options and options["guardrails"]:
134
+ guardrails_options: GuardrailsOptions = {
135
+ "enabled": True,
136
+ "response_when_unsafe": "Response filtered by Cortex Guard",
137
+ }
138
+ data["guardrails"] = guardrails_options
120
139
  return data
121
140
 
122
141
 
@@ -12,7 +12,7 @@ import yaml
12
12
  from packaging import requirements, specifiers, version
13
13
 
14
14
  import snowflake.connector
15
- from snowflake.ml._internal import env as snowml_env
15
+ from snowflake.ml._internal import env as snowml_env, relax_version_strategy
16
16
  from snowflake.ml._internal.utils import query_result_checker
17
17
  from snowflake.snowpark import context, exceptions, session
18
18
 
@@ -56,6 +56,8 @@ def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
56
56
 
57
57
  if r.name == "python":
58
58
  raise ValueError("Don't specify python as a dependency, use python version argument instead.")
59
+ if r.name == "cuda":
60
+ raise ValueError("Don't specify cuda as a dependency, use cuda version argument instead.")
59
61
  except requirements.InvalidRequirement:
60
62
  raise ValueError(f"Invalid package requirement {req_str} found.")
61
63
 
@@ -313,19 +315,14 @@ def get_package_spec_with_supported_ops_only(req: requirements.Requirement) -> r
313
315
  return new_req
314
316
 
315
317
 
316
- def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement:
317
- """Relax version specifier from a requirement. It detects any ==x.y.z in specifiers and replaced with
318
- >=x.y, <(x+1)
319
-
320
- Args:
321
- req: The requirement that version specifier to be removed.
322
-
323
- Returns:
324
- A new requirement object after relaxations.
325
- """
326
- new_req = copy.deepcopy(req)
318
+ def _relax_specifier_set(
319
+ specifier_set: specifiers.SpecifierSet, strategy: relax_version_strategy.RelaxVersionStrategy
320
+ ) -> specifiers.SpecifierSet:
321
+ if strategy == relax_version_strategy.RelaxVersionStrategy.NO_RELAX:
322
+ return specifier_set
323
+ specifier_set = copy.deepcopy(specifier_set)
327
324
  relaxed_specifier_set = set()
328
- for spec in new_req.specifier._specs:
325
+ for spec in specifier_set._specs:
329
326
  if spec.operator != "==":
330
327
  relaxed_specifier_set.add(spec)
331
328
  continue
@@ -337,9 +334,40 @@ def relax_requirement_version(req: requirements.Requirement) -> requirements.Req
337
334
  relaxed_specifier_set.add(spec)
338
335
  continue
339
336
  assert pinned_version is not None
340
- relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
341
- relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
342
- new_req.specifier._specs = frozenset(relaxed_specifier_set)
337
+ if strategy == relax_version_strategy.RelaxVersionStrategy.PATCH:
338
+ relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
339
+ relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major}.{pinned_version.minor+1}"))
340
+ elif strategy == relax_version_strategy.RelaxVersionStrategy.MINOR:
341
+ relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}.{pinned_version.minor}"))
342
+ relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
343
+ elif strategy == relax_version_strategy.RelaxVersionStrategy.MAJOR:
344
+ relaxed_specifier_set.add(specifiers.Specifier(f">={pinned_version.major}"))
345
+ relaxed_specifier_set.add(specifiers.Specifier(f"<{pinned_version.major + 1}"))
346
+ specifier_set._specs = frozenset(relaxed_specifier_set)
347
+ return specifier_set
348
+
349
+
350
+ def relax_requirement_version(req: requirements.Requirement) -> requirements.Requirement:
351
+ """Relax version specifier from a requirement. It detects any ==x.y.z in specifiers and replaced with relaxed
352
+ version specifier based on the strategy defined in RELAX_VERSION_STRATEGY_MAP.
353
+
354
+ NO_RELAX: No relaxation.
355
+ PATCH: >=x.y, <x.(y+1)
356
+ MINOR (default): >=x.y, <(x+1)
357
+ MAJOR: >=x, <(x+1)
358
+
359
+
360
+ Args:
361
+ req: The requirement that version specifier to be removed.
362
+
363
+ Returns:
364
+ A new requirement object after relaxations.
365
+ """
366
+ new_req = copy.deepcopy(req)
367
+ strategy = relax_version_strategy.RELAX_VERSION_STRATEGY_MAP.get(
368
+ req.name, relax_version_strategy.RelaxVersionStrategy.MINOR
369
+ )
370
+ new_req.specifier = _relax_specifier_set(new_req.specifier, strategy)
343
371
  return new_req
344
372
 
345
373
 
@@ -431,10 +459,11 @@ def save_conda_env_file(
431
459
  path: pathlib.Path,
432
460
  conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
433
461
  python_version: str,
462
+ cuda_version: Optional[str] = None,
434
463
  default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
435
464
  ) -> None:
436
465
  """Generate conda.yml file given a dict of dependencies after validation.
437
- The channels part of conda.yml file will contains Snowflake Anaconda Channel, nodefaults and all channel names
466
+ The channels part of conda.yml file will contain Snowflake Anaconda Channel, nodefaults and all channel names
438
467
  in keys of the dict, ordered by the number of the packages which belongs to.
439
468
  The dependencies part of conda.yml file will contains requirements specifications. If the requirements is in the
440
469
  value list whose key is DEFAULT_CHANNEL_NAME, then the channel won't be specified explicitly. Otherwise, it will be
@@ -443,7 +472,8 @@ def save_conda_env_file(
443
472
  Args:
444
473
  path: Path to the conda.yml file.
445
474
  conda_chan_deps: Dict of conda dependencies after validated.
446
- python_version: A string 'major.minor' showing python version relate to model.
475
+ python_version: A string 'major.minor' for the model's python version.
476
+ cuda_version: A string 'major.minor' for the model's cuda version.
447
477
  default_channel_override: The default channel to be put in the first place of the channels section.
448
478
  """
449
479
  assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
@@ -461,6 +491,10 @@ def save_conda_env_file(
461
491
 
462
492
  env["channels"] = [default_channel_override] + channels + [_NODEFAULTS]
463
493
  env["dependencies"] = [f"python=={python_version}.*"]
494
+
495
+ if cuda_version is not None:
496
+ env["dependencies"].extend([f"nvidia::cuda=={cuda_version}.*"])
497
+
464
498
  for chan, reqs in conda_chan_deps.items():
465
499
  env["dependencies"].extend(
466
500
  [f"{chan}::{str(req)}" if chan != DEFAULT_CHANNEL_NAME else str(req) for req in reqs]
@@ -487,7 +521,12 @@ def save_requirements_file(path: pathlib.Path, pip_deps: List[requirements.Requi
487
521
 
488
522
  def load_conda_env_file(
489
523
  path: pathlib.Path,
490
- ) -> Tuple[DefaultDict[str, List[requirements.Requirement]], Optional[List[requirements.Requirement]], Optional[str]]:
524
+ ) -> Tuple[
525
+ DefaultDict[str, List[requirements.Requirement]],
526
+ Optional[List[requirements.Requirement]],
527
+ Optional[str],
528
+ Optional[str],
529
+ ]:
491
530
  """Read conda.yml file to get a dict of dependencies after validation.
492
531
  The channels part of conda.yml file will be processed with following rules:
493
532
  1. If it is Snowflake Anaconda Channel, ignore as it is default.
@@ -515,7 +554,7 @@ def load_conda_env_file(
515
554
  and a string 'major.minor.patchlevel' of python version.
516
555
  """
517
556
  if not path.exists():
518
- return collections.defaultdict(list), None, None
557
+ return collections.defaultdict(list), None, None, None
519
558
 
520
559
  with open(path, encoding="utf-8") as f:
521
560
  env = yaml.safe_load(stream=f)
@@ -526,6 +565,7 @@ def load_conda_env_file(
526
565
  pip_deps = []
527
566
 
528
567
  python_version = None
568
+ cuda_version = None
529
569
 
530
570
  channels = env.get("channels", [])
531
571
  if len(channels) >= 1:
@@ -541,6 +581,9 @@ def load_conda_env_file(
541
581
  # ver is str: python w/ specifier
542
582
  if ver:
543
583
  python_version = ver
584
+ elif dep.startswith("nvidia::cuda"):
585
+ r = requirements.Requirement(dep.split("nvidia::")[1])
586
+ cuda_version = list(r.specifier)[0].version.strip(".*")
544
587
  elif ver is None:
545
588
  deps.append(dep)
546
589
  elif isinstance(dep, dict) and "pip" in dep:
@@ -555,7 +598,7 @@ def load_conda_env_file(
555
598
  if channel not in conda_dep_dict:
556
599
  conda_dep_dict[channel] = []
557
600
 
558
- return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version
601
+ return conda_dep_dict, pip_deps_list if pip_deps_list else None, python_version, cuda_version
559
602
 
560
603
 
561
604
  def load_requirements_file(path: pathlib.Path) -> List[requirements.Requirement]:
@@ -0,0 +1,87 @@
1
+ import json
2
+ from typing import Any, Dict, Optional
3
+
4
+ from absl import logging
5
+
6
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
7
+ from snowflake.ml._internal.utils import query_result_checker
8
+ from snowflake.snowpark import (
9
+ exceptions as snowpark_exceptions,
10
+ session as snowpark_session,
11
+ )
12
+
13
+
14
+ class PlatformCapabilities:
15
+ """Class that retrieves platform feature values for the currently running server.
16
+
17
+ Example usage:
18
+ ```
19
+ pc = PlatformCapabilities.get_instance(session)
20
+ if pc.is_nested_function_enabled():
21
+ # Nested functions are enabled.
22
+ print("Nested functions are enabled.")
23
+ else:
24
+ # Nested functions are disabled.
25
+ print("Nested functions are disabled or not supported.")
26
+ ```
27
+ """
28
+
29
+ _instance: Optional["PlatformCapabilities"] = None
30
+
31
+ @classmethod
32
+ def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
33
+ if not cls._instance:
34
+ cls._instance = cls(session)
35
+ return cls._instance
36
+
37
+ def is_nested_function_enabled(self) -> bool:
38
+ return self._get_bool_feature("SPCS_MODEL_ENABLE_EMBEDDED_SERVICE_FUNCTIONS", False)
39
+
40
+ @staticmethod
41
+ def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
42
+ try:
43
+ result = (
44
+ query_result_checker.SqlResultValidator(
45
+ session=session,
46
+ query="SELECT SYSTEM$ML_PLATFORM_CAPABILITIES() AS FEATURES;",
47
+ )
48
+ .has_dimensions(expected_rows=1, expected_cols=1)
49
+ .has_column("FEATURES")
50
+ .validate()[0]
51
+ )
52
+ if "FEATURES" in result:
53
+ capabilities_json: str = result["FEATURES"]
54
+ try:
55
+ parsed_json = json.loads(capabilities_json)
56
+ assert isinstance(parsed_json, dict), f"Expected JSON object, got {type(parsed_json)}"
57
+ return parsed_json
58
+ except json.JSONDecodeError as e:
59
+ message = f"""Unable to parse JSON from: "{capabilities_json}"; Error="{e}"."""
60
+ raise exceptions.SnowflakeMLException(
61
+ error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
62
+ )
63
+ except snowpark_exceptions.SnowparkSQLException as e:
64
+ logging.debug(f"Failed to retrieve platform capabilities: {e}")
65
+ # This can happen is server side is older than 9.2. That is fine.
66
+ return {}
67
+
68
+ def __init__(self, session: Optional[snowpark_session.Session] = None) -> None:
69
+ if not session:
70
+ session = next(iter(snowpark_session._get_active_sessions()))
71
+ assert session, "Missing active session object"
72
+ self.features: Dict[str, Any] = PlatformCapabilities._get_features(session)
73
+
74
+ def _get_bool_feature(self, feature_name: str, default_value: bool) -> bool:
75
+ value = self.features.get(feature_name, default_value)
76
+ if isinstance(value, bool):
77
+ return value
78
+ if isinstance(value, int) and value in [0, 1]:
79
+ return value == 1
80
+ if isinstance(value, str):
81
+ if value.lower() in ["true", "1"]:
82
+ return True
83
+ elif value.lower() in ["false", "0"]:
84
+ return False
85
+ else:
86
+ raise ValueError(f"Invalid boolean string: {value} for feature {feature_name}")
87
+ raise ValueError(f"Invalid boolean feature value: {value} for feature {feature_name}")
@@ -0,0 +1,16 @@
1
+ from enum import Enum
2
+
3
+
4
+ class RelaxVersionStrategy(Enum):
5
+ NO_RELAX = "no_relax"
6
+ PATCH = "patch"
7
+ MINOR = "minor"
8
+ MAJOR = "major"
9
+
10
+
11
+ RELAX_VERSION_STRATEGY_MAP = {
12
+ # The version of cloudpickle should not be relaxed as it is used for serialization.
13
+ "cloudpickle": RelaxVersionStrategy.NO_RELAX,
14
+ # The version of scikit-learn should be relaxed only in patch version as it has breaking changes in minor version.
15
+ "scikit-learn": RelaxVersionStrategy.PATCH,
16
+ }
@@ -4,6 +4,9 @@ import enum
4
4
  import functools
5
5
  import inspect
6
6
  import operator
7
+ import sys
8
+ import time
9
+ import traceback
7
10
  import types
8
11
  from typing import (
9
12
  Any,
@@ -75,6 +78,8 @@ class TelemetryField(enum.Enum):
75
78
  KEY_FUNC_PARAMS = "func_params"
76
79
  KEY_ERROR_INFO = "error_info"
77
80
  KEY_ERROR_CODE = "error_code"
81
+ KEY_STACK_TRACE = "stack_trace"
82
+ KEY_DURATION = "duration"
78
83
  KEY_VERSION = "version"
79
84
  KEY_PYTHON_VERSION = "python_version"
80
85
  KEY_OS = "operating_system"
@@ -435,6 +440,7 @@ def send_api_usage_telemetry(
435
440
 
436
441
  # noqa: DAR402
437
442
  """
443
+ start_time = time.perf_counter()
438
444
 
439
445
  if subproject is not None and subproject_extractor is not None:
440
446
  raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
@@ -555,8 +561,16 @@ def send_api_usage_telemetry(
555
561
  )
556
562
  else:
557
563
  me = e
564
+
558
565
  telemetry_args["error"] = repr(me)
559
566
  telemetry_args["error_code"] = me.error_code
567
+ # exclude telemetry frames
568
+ excluded_frames = 2
569
+ tb = traceback.extract_tb(sys.exc_info()[2])
570
+ formatted_tb = "".join(traceback.format_list(tb[excluded_frames:]))
571
+ formatted_exception = traceback.format_exception_only(*sys.exc_info()[:2])[0] # error type + message
572
+ telemetry_args["stack_trace"] = formatted_tb + formatted_exception
573
+
560
574
  me.original_exception._snowflake_ml_handled = True # type: ignore[attr-defined]
561
575
  if e is not me:
562
576
  raise # Directly raise non-wrapped exceptions to preserve original stacktrace
@@ -565,6 +579,7 @@ def send_api_usage_telemetry(
565
579
  else:
566
580
  raise me.original_exception from e
567
581
  finally:
582
+ telemetry_args["duration"] = time.perf_counter() - start_time # type: ignore[assignment]
568
583
  telemetry.send_function_usage_telemetry(**telemetry_args)
569
584
  global _log_counter
570
585
  _log_counter += 1
@@ -718,12 +733,14 @@ class _SourceTelemetryClient:
718
733
  self,
719
734
  func_name: str,
720
735
  function_category: str,
736
+ duration: float,
721
737
  func_params: Optional[Dict[str, Any]] = None,
722
738
  api_calls: Optional[List[Dict[str, Any]]] = None,
723
739
  sfqids: Optional[List[Any]] = None,
724
740
  custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
725
741
  error: Optional[str] = None,
726
742
  error_code: Optional[str] = None,
743
+ stack_trace: Optional[str] = None,
727
744
  ) -> None:
728
745
  """
729
746
  Send function usage telemetry message.
@@ -731,12 +748,14 @@ class _SourceTelemetryClient:
731
748
  Args:
732
749
  func_name: Function name.
733
750
  function_category: Function category.
751
+ duration: Function duration.
734
752
  func_params: Function parameters.
735
753
  api_calls: API calls.
736
754
  sfqids: Snowflake query IDs.
737
755
  custom_tags: Custom tags.
738
756
  error: Error.
739
757
  error_code: Error code.
758
+ stack_trace: Error stack trace.
740
759
  """
741
760
  data: Dict[str, Any] = {
742
761
  TelemetryField.KEY_FUNC_NAME.value: func_name,
@@ -755,11 +774,13 @@ class _SourceTelemetryClient:
755
774
  message: Dict[str, Any] = {
756
775
  **self._create_basic_telemetry_data(telemetry_type),
757
776
  TelemetryField.KEY_DATA.value: data,
777
+ TelemetryField.KEY_DURATION.value: duration,
758
778
  }
759
779
 
760
780
  if error:
761
781
  message[TelemetryField.KEY_ERROR_INFO.value] = error
762
782
  message[TelemetryField.KEY_ERROR_CODE.value] = error_code
783
+ message[TelemetryField.KEY_STACK_TRACE.value] = stack_trace
763
784
 
764
785
  self._send(message)
765
786
 
@@ -116,7 +116,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
116
116
  def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
117
117
  ds = self._get_dataset(shuffle=False)
118
118
  table = ds.to_table() if limit is None else ds.head(num_rows=limit)
119
- return table.to_pandas()
119
+ return table.to_pandas(split_blocks=True, self_destruct=True)
120
120
 
121
121
  def _get_dataset(self, shuffle: bool) -> pds.Dataset:
122
122
  format = self._format
@@ -419,7 +419,6 @@ class Dataset(lineage_node.LineageNode):
419
419
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
420
420
  def delete(self) -> None:
421
421
  """Delete Dataset and all contained versions"""
422
- # TODO: Check and warn if any versions exist
423
422
  self._session.sql(f"DROP DATASET {self.fully_qualified_name}").collect(
424
423
  statement_params=_TELEMETRY_STATEMENT_PARAMS
425
424
  )
@@ -144,6 +144,7 @@ _LIST_FEATURE_VIEW_SCHEMA = StructType(
144
144
  StructField("refresh_mode", StringType()),
145
145
  StructField("scheduling_state", StringType()),
146
146
  StructField("warehouse", StringType()),
147
+ StructField("cluster_by", StringType()),
147
148
  ]
148
149
  )
149
150
 
@@ -1832,6 +1833,12 @@ class FeatureStore:
1832
1833
  WAREHOUSE = {warehouse}
1833
1834
  REFRESH_MODE = {feature_view.refresh_mode}
1834
1835
  INITIALIZE = {feature_view.initialize}
1836
+ """
1837
+ if feature_view.cluster_by:
1838
+ cluster_by_clause = f"CLUSTER BY ({', '.join(feature_view.cluster_by)})"
1839
+ query += f"{cluster_by_clause}"
1840
+
1841
+ query += f"""
1835
1842
  AS {feature_view.query}
1836
1843
  """
1837
1844
  self._session.sql(query).collect(block=block, statement_params=self._telemetry_stmp)
@@ -2249,6 +2256,7 @@ class FeatureStore:
2249
2256
  values.append(row["refresh_mode"] if "refresh_mode" in row else None)
2250
2257
  values.append(row["scheduling_state"] if "scheduling_state" in row else None)
2251
2258
  values.append(row["warehouse"] if "warehouse" in row else None)
2259
+ values.append(json.dumps(self._extract_cluster_by_columns(row["cluster_by"])) if "cluster_by" in row else None)
2252
2260
  output_values.append(values)
2253
2261
 
2254
2262
  def _lookup_feature_view_metadata(self, row: Row, fv_name: str) -> Tuple[_FeatureViewMetadata, str]:
@@ -2335,6 +2343,7 @@ class FeatureStore:
2335
2343
  owner=row["owner"],
2336
2344
  infer_schema_df=infer_schema_df,
2337
2345
  session=self._session,
2346
+ cluster_by=self._extract_cluster_by_columns(row["cluster_by"]),
2338
2347
  )
2339
2348
  return fv
2340
2349
  else:
@@ -2625,3 +2634,12 @@ class FeatureStore:
2625
2634
  )
2626
2635
 
2627
2636
  return feature_view
2637
+
2638
+ @staticmethod
2639
+ def _extract_cluster_by_columns(cluster_by_clause: str) -> List[str]:
2640
+ # Use regex to extract elements inside the parentheses.
2641
+ match = re.search(r"\((.*?)\)", cluster_by_clause)
2642
+ if match:
2643
+ # Handle both quoted and unquoted column names.
2644
+ return re.findall(identifier.SF_IDENTIFIER_RE, match.group(1))
2645
+ return []
@@ -170,6 +170,7 @@ class FeatureView(lineage_node.LineageNode):
170
170
  warehouse: Optional[str] = None,
171
171
  initialize: str = "ON_CREATE",
172
172
  refresh_mode: str = "AUTO",
173
+ cluster_by: Optional[List[str]] = None,
173
174
  **_kwargs: Any,
174
175
  ) -> None:
175
176
  """
@@ -200,6 +201,9 @@ class FeatureView(lineage_node.LineageNode):
200
201
  refresh_mode: The refresh mode of managed feature view. The value can be 'AUTO', 'FULL' or 'INCREMENETAL'.
201
202
  For managed feature view, the default value is 'AUTO'. For static feature view it has no effect.
202
203
  Check https://docs.snowflake.com/en/sql-reference/sql/create-dynamic-table for for details.
204
+ cluster_by: Columns to cluster the feature view by.
205
+ - Defaults to the join keys from entities.
206
+ - If `timestamp_col` is provided, it is added to the default clustering keys.
203
207
  _kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
204
208
 
205
209
  Example::
@@ -224,6 +228,7 @@ class FeatureView(lineage_node.LineageNode):
224
228
  >>> print(registered_fv.status)
225
229
  FeatureViewStatus.ACTIVE
226
230
 
231
+ # noqa: DAR401
227
232
  """
228
233
 
229
234
  self._name: SqlIdentifier = SqlIdentifier(name)
@@ -233,7 +238,7 @@ class FeatureView(lineage_node.LineageNode):
233
238
  SqlIdentifier(timestamp_col) if timestamp_col is not None else None
234
239
  )
235
240
  self._desc: str = desc
236
- self._infer_schema_df: DataFrame = _kwargs.get("_infer_schema_df", self._feature_df)
241
+ self._infer_schema_df: DataFrame = _kwargs.pop("_infer_schema_df", self._feature_df)
237
242
  self._query: str = self._get_query()
238
243
  self._version: Optional[FeatureViewVersion] = None
239
244
  self._status: FeatureViewStatus = FeatureViewStatus.DRAFT
@@ -249,6 +254,14 @@ class FeatureView(lineage_node.LineageNode):
249
254
  self._refresh_mode: Optional[str] = refresh_mode
250
255
  self._refresh_mode_reason: Optional[str] = None
251
256
  self._owner: Optional[str] = None
257
+ self._cluster_by: List[SqlIdentifier] = (
258
+ [SqlIdentifier(col) for col in cluster_by] if cluster_by is not None else self._get_default_cluster_by()
259
+ )
260
+
261
+ # Validate kwargs
262
+ if _kwargs:
263
+ raise TypeError(f"FeatureView.__init__ got an unexpected keyword argument: '{next(iter(_kwargs.keys()))}'")
264
+
252
265
  self._validate()
253
266
 
254
267
  def slice(self, names: List[str]) -> FeatureViewSlice:
@@ -394,6 +407,10 @@ class FeatureView(lineage_node.LineageNode):
394
407
  def timestamp_col(self) -> Optional[SqlIdentifier]:
395
408
  return self._timestamp_col
396
409
 
410
+ @property
411
+ def cluster_by(self) -> Optional[List[SqlIdentifier]]:
412
+ return self._cluster_by
413
+
397
414
  @property
398
415
  def desc(self) -> str:
399
416
  return self._desc
@@ -656,6 +673,14 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
656
673
  if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
657
674
  raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
658
675
 
676
+ if self.cluster_by is not None:
677
+ for column in self.cluster_by:
678
+ if column not in df_cols:
679
+ raise ValueError(
680
+ f"Column '{column}' in `cluster_by` is not in the feature DataFrame schema. "
681
+ f"{df_cols}, {self.cluster_by}"
682
+ )
683
+
659
684
  if re.match(_RESULT_SCAN_QUERY_PATTERN, self._query) is not None:
660
685
  raise ValueError(f"feature_df should not be reading from RESULT_SCAN. Invalid query: {self._query}")
661
686
 
@@ -890,6 +915,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
890
915
  owner: Optional[str],
891
916
  infer_schema_df: Optional[DataFrame],
892
917
  session: Session,
918
+ cluster_by: Optional[List[str]] = None,
893
919
  ) -> FeatureView:
894
920
  fv = FeatureView(
895
921
  name=name,
@@ -898,6 +924,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
898
924
  timestamp_col=timestamp_col,
899
925
  desc=desc,
900
926
  _infer_schema_df=infer_schema_df,
927
+ cluster_by=cluster_by,
901
928
  )
902
929
  fv._version = FeatureViewVersion(version) if version is not None else None
903
930
  fv._status = status
@@ -916,5 +943,23 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
916
943
  )
917
944
  return fv
918
945
 
946
+ #
947
+ def _get_default_cluster_by(self) -> List[SqlIdentifier]:
948
+ """
949
+ Get default columns to cluster the feature view by.
950
+ Default cluster_by columns are join keys from entities and timestamp_col if it exists
951
+
952
+ Returns:
953
+ List of SqlIdentifiers representing the default columns to cluster the feature view by.
954
+ """
955
+ # We don't focus on the order of entities here, as users can define a custom 'cluster_by'
956
+ # if a specific order is required.
957
+ default_cluster_by_cols = [key for entity in self.entities if entity.join_keys for key in entity.join_keys]
958
+
959
+ if self.timestamp_col:
960
+ default_cluster_by_cols.append(self.timestamp_col)
961
+
962
+ return default_cluster_by_cols
963
+
919
964
 
920
965
  lineage_node.DOMAIN_LINEAGE_REGISTRY["feature_view"] = FeatureView
@@ -2,6 +2,8 @@ import functools
2
2
  import inspect
3
3
  from typing import Any, Callable, List, Optional
4
4
 
5
+ from typing_extensions import deprecated
6
+
5
7
  from snowflake import snowpark
6
8
  from snowflake.connector import connection
7
9
  from snowflake.ml._internal import telemetry
@@ -42,6 +44,10 @@ def _raise_if_deleted(func: Callable[..., Any]) -> Callable[..., Any]:
42
44
  return raise_if_deleted_helper
43
45
 
44
46
 
47
+ @deprecated(
48
+ "FileSet is deprecated and will be removed in a future release."
49
+ " Use snowflake.ml.dataset.Dataset and snowflake.ml.data.DataConnector instead"
50
+ )
45
51
  class FileSet:
46
52
  """A FileSet represents an immutable snapshot of the result of a query in the form of files."""
47
53
 
@@ -0,0 +1,21 @@
1
+ from snowflake.ml.jobs._utils.types import JOB_STATUS
2
+ from snowflake.ml.jobs.decorators import remote
3
+ from snowflake.ml.jobs.job import MLJob
4
+ from snowflake.ml.jobs.manager import (
5
+ delete_job,
6
+ get_job,
7
+ list_jobs,
8
+ submit_directory,
9
+ submit_file,
10
+ )
11
+
12
+ __all__ = [
13
+ "remote",
14
+ "submit_file",
15
+ "submit_directory",
16
+ "list_jobs",
17
+ "get_job",
18
+ "delete_job",
19
+ "MLJob",
20
+ "JOB_STATUS",
21
+ ]