snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +101 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/platform_capabilities.py +87 -0
  12. snowflake/ml/_internal/utils/identifier.py +4 -2
  13. snowflake/ml/data/__init__.py +3 -0
  14. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  15. snowflake/ml/data/data_connector.py +53 -11
  16. snowflake/ml/data/data_ingestor.py +2 -1
  17. snowflake/ml/data/torch_utils.py +18 -5
  18. snowflake/ml/dataset/dataset.py +0 -1
  19. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  20. snowflake/ml/fileset/fileset.py +24 -18
  21. snowflake/ml/jobs/__init__.py +21 -0
  22. snowflake/ml/jobs/_utils/constants.py +51 -0
  23. snowflake/ml/jobs/_utils/payload_utils.py +352 -0
  24. snowflake/ml/jobs/_utils/spec_utils.py +298 -0
  25. snowflake/ml/jobs/_utils/types.py +39 -0
  26. snowflake/ml/jobs/decorators.py +91 -0
  27. snowflake/ml/jobs/job.py +113 -0
  28. snowflake/ml/jobs/manager.py +298 -0
  29. snowflake/ml/model/_client/model/model_version_impl.py +5 -3
  30. snowflake/ml/model/_client/ops/model_ops.py +13 -8
  31. snowflake/ml/model/_client/ops/service_ops.py +1 -11
  32. snowflake/ml/model/_client/sql/model_version.py +11 -0
  33. snowflake/ml/model/_client/sql/service.py +13 -6
  34. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  35. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  37. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  38. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  39. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  40. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  41. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  42. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  43. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  44. snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
  45. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  46. snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  49. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  50. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
  51. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  52. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  53. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  54. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  55. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  56. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  57. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  58. snowflake/ml/model/_signatures/base_handler.py +1 -2
  59. snowflake/ml/model/_signatures/builtins_handler.py +2 -2
  60. snowflake/ml/model/_signatures/numpy_handler.py +6 -7
  61. snowflake/ml/model/_signatures/pandas_handler.py +3 -3
  62. snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
  63. snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
  64. snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
  65. snowflake/ml/model/model_signature.py +17 -4
  66. snowflake/ml/model/type_hints.py +1 -0
  67. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  68. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  69. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
  70. snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
  71. snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
  72. snowflake/ml/modeling/cluster/birch.py +6 -3
  73. snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
  74. snowflake/ml/modeling/cluster/dbscan.py +6 -3
  75. snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
  76. snowflake/ml/modeling/cluster/k_means.py +6 -3
  77. snowflake/ml/modeling/cluster/mean_shift.py +6 -3
  78. snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
  79. snowflake/ml/modeling/cluster/optics.py +6 -3
  80. snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
  81. snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
  82. snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
  83. snowflake/ml/modeling/compose/column_transformer.py +6 -3
  84. snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
  85. snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
  86. snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
  87. snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
  88. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
  89. snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
  90. snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
  91. snowflake/ml/modeling/covariance/oas.py +6 -3
  92. snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
  93. snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
  94. snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
  95. snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
  96. snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
  97. snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
  98. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
  99. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
  100. snowflake/ml/modeling/decomposition/pca.py +6 -3
  101. snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
  102. snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
  103. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
  104. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
  105. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
  106. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
  107. snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
  108. snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
  109. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
  110. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
  111. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
  112. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
  113. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
  114. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
  115. snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
  116. snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
  117. snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
  118. snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
  119. snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
  120. snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
  121. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
  122. snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
  123. snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
  124. snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
  125. snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
  126. snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
  127. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
  128. snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
  129. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
  130. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
  131. snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
  132. snowflake/ml/modeling/impute/knn_imputer.py +6 -3
  133. snowflake/ml/modeling/impute/missing_indicator.py +6 -3
  134. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
  135. snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
  136. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
  137. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
  138. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
  139. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
  140. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
  141. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
  142. snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
  143. snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
  144. snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
  145. snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
  146. snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
  147. snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
  148. snowflake/ml/modeling/linear_model/lars.py +6 -3
  149. snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
  150. snowflake/ml/modeling/linear_model/lasso.py +6 -3
  151. snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
  152. snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
  153. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
  154. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
  155. snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
  156. snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
  157. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
  158. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
  159. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
  160. snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
  161. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
  162. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
  163. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
  164. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
  165. snowflake/ml/modeling/linear_model/perceptron.py +6 -3
  166. snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
  167. snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
  168. snowflake/ml/modeling/linear_model/ridge.py +6 -3
  169. snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
  170. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
  171. snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
  172. snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
  173. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
  174. snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
  175. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
  176. snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
  177. snowflake/ml/modeling/manifold/isomap.py +6 -3
  178. snowflake/ml/modeling/manifold/mds.py +6 -3
  179. snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
  180. snowflake/ml/modeling/manifold/tsne.py +6 -3
  181. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
  182. snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
  183. snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
  184. snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
  185. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
  186. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
  187. snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
  188. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
  189. snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
  190. snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
  191. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
  192. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
  193. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
  194. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
  195. snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
  196. snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
  197. snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
  198. snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
  199. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
  200. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
  201. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
  202. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
  203. snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
  204. snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
  205. snowflake/ml/modeling/pipeline/pipeline.py +16 -178
  206. snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
  207. snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
  208. snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
  209. snowflake/ml/modeling/svm/linear_svc.py +6 -3
  210. snowflake/ml/modeling/svm/linear_svr.py +6 -3
  211. snowflake/ml/modeling/svm/nu_svc.py +6 -3
  212. snowflake/ml/modeling/svm/nu_svr.py +6 -3
  213. snowflake/ml/modeling/svm/svc.py +6 -3
  214. snowflake/ml/modeling/svm/svr.py +6 -3
  215. snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
  216. snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
  217. snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
  218. snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
  219. snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
  220. snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
  221. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
  222. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
  223. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
  224. snowflake/ml/registry/_manager/model_manager.py +70 -33
  225. snowflake/ml/registry/registry.py +41 -22
  226. snowflake/ml/version.py +1 -1
  227. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
  228. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
  229. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
  230. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  231. snowflake/ml/fileset/parquet_parser.py +0 -170
  232. snowflake/ml/fileset/tf_dataset.py +0 -88
  233. snowflake/ml/fileset/torch_datapipe.py +0 -57
  234. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  235. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  236. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
  237. {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
@@ -158,8 +158,10 @@ def parse_schema_level_object_identifier(
158
158
  res = _SF_SCHEMA_LEVEL_OBJECT_RE.fullmatch(object_name)
159
159
  if not res:
160
160
  raise ValueError(
161
- "Invalid identifier because it does not follow the pattern. "
162
- f"It should start with [[database.]schema.]object. Getting {object_name}"
161
+ f"Invalid object name `{object_name}` cannot be parsed as a SQL identifier. "
162
+ "Alphanumeric characters and underscores are permitted. "
163
+ "See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for "
164
+ "more information."
163
165
  )
164
166
  return (
165
167
  res.group("db"),
@@ -1,5 +1,8 @@
1
+ from pkgutil import extend_path
2
+
1
3
  from .data_connector import DataConnector
2
4
  from .data_ingestor import DataIngestor, DataIngestorType
3
5
  from .data_source import DataFrameInfo, DatasetInfo, DataSource
4
6
 
5
7
  __all__ = ["DataConnector", "DataSource", "DataFrameInfo", "DatasetInfo", "DataIngestor", "DataIngestorType"]
8
+ __path__ = extend_path(__path__, __name__)
@@ -2,7 +2,7 @@ import collections
2
2
  import logging
3
3
  import os
4
4
  import time
5
- from typing import Any, Deque, Dict, Iterator, List, Optional, Union
5
+ from typing import Any, Deque, Dict, Iterator, List, Optional, Sequence, Union
6
6
 
7
7
  import numpy as np
8
8
  import numpy.typing as npt
@@ -47,7 +47,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
47
47
  def __init__(
48
48
  self,
49
49
  session: snowpark.Session,
50
- data_sources: List[data_source.DataSource],
50
+ data_sources: Sequence[data_source.DataSource],
51
51
  format: Optional[str] = None,
52
52
  **kwargs: Any,
53
53
  ) -> None:
@@ -60,14 +60,14 @@ class ArrowIngestor(data_ingestor.DataIngestor):
60
60
  kwargs: Miscellaneous arguments passed to underlying PyArrow Dataset initializer.
61
61
  """
62
62
  self._session = session
63
- self._data_sources = data_sources
63
+ self._data_sources = list(data_sources)
64
64
  self._format = format
65
65
  self._kwargs = kwargs
66
66
 
67
67
  self._schema: Optional[pa.Schema] = None
68
68
 
69
69
  @classmethod
70
- def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "ArrowIngestor":
70
+ def from_sources(cls, session: snowpark.Session, sources: Sequence[data_source.DataSource]) -> "ArrowIngestor":
71
71
  return cls(session, sources)
72
72
 
73
73
  @property
@@ -1,5 +1,16 @@
1
1
  import os
2
- from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
2
+ from typing import (
3
+ TYPE_CHECKING,
4
+ Any,
5
+ Dict,
6
+ Generator,
7
+ List,
8
+ Optional,
9
+ Sequence,
10
+ Type,
11
+ TypeVar,
12
+ cast,
13
+ )
3
14
 
4
15
  import numpy.typing as npt
5
16
  from typing_extensions import deprecated
@@ -12,6 +23,7 @@ from snowflake.ml.modeling._internal.constants import (
12
23
  IN_ML_RUNTIME_ENV_VAR,
13
24
  USE_OPTIMIZED_DATA_INGESTOR,
14
25
  )
26
+ from snowflake.snowpark import context as sf_context
15
27
 
16
28
  if TYPE_CHECKING:
17
29
  import pandas as pd
@@ -35,8 +47,10 @@ class DataConnector:
35
47
  def __init__(
36
48
  self,
37
49
  ingestor: data_ingestor.DataIngestor,
50
+ **kwargs: Any,
38
51
  ) -> None:
39
52
  self._ingestor = ingestor
53
+ self._kwargs = kwargs
40
54
 
41
55
  @classmethod
42
56
  @snowpark._internal.utils.private_preview(version="1.6.0")
@@ -44,20 +58,34 @@ class DataConnector:
44
58
  cls: Type[DataConnectorType],
45
59
  df: snowpark.DataFrame,
46
60
  ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
47
- **kwargs: Any
61
+ **kwargs: Any,
48
62
  ) -> DataConnectorType:
49
63
  if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
50
64
  raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
51
- source = data_source.DataFrameInfo(df.queries["queries"][0])
52
- assert df._session is not None
53
- return cls.from_sources(df._session, [source], ingestor_class=ingestor_class, **kwargs)
65
+ return cast(
66
+ DataConnectorType,
67
+ cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
68
+ )
69
+
70
+ @classmethod
71
+ @snowpark._internal.utils.private_preview(version="1.7.3")
72
+ def from_sql(
73
+ cls: Type[DataConnectorType],
74
+ query: str,
75
+ session: Optional[snowpark.Session] = None,
76
+ ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
77
+ **kwargs: Any,
78
+ ) -> DataConnectorType:
79
+ session = session or sf_context.get_active_session()
80
+ source = data_source.DataFrameInfo(query)
81
+ return cls.from_sources(session, [source], ingestor_class=ingestor_class, **kwargs)
54
82
 
55
83
  @classmethod
56
84
  def from_dataset(
57
85
  cls: Type[DataConnectorType],
58
86
  ds: "dataset.Dataset",
59
87
  ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
60
- **kwargs: Any
88
+ **kwargs: Any,
61
89
  ) -> DataConnectorType:
62
90
  dsv = ds.selected_version
63
91
  assert dsv is not None
@@ -75,9 +103,9 @@ class DataConnector:
75
103
  def from_sources(
76
104
  cls: Type[DataConnectorType],
77
105
  session: snowpark.Session,
78
- sources: List[data_source.DataSource],
106
+ sources: Sequence[data_source.DataSource],
79
107
  ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
80
- **kwargs: Any
108
+ **kwargs: Any,
81
109
  ) -> DataConnectorType:
82
110
  ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
83
111
  ingestor = ingestor_class.from_sources(session, sources)
@@ -130,7 +158,11 @@ class DataConnector:
130
158
  func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
131
159
  )
132
160
  def to_torch_datapipe(
133
- self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
161
+ self,
162
+ *,
163
+ batch_size: int,
164
+ shuffle: bool = False,
165
+ drop_last_batch: bool = True,
134
166
  ) -> "torch_data.IterDataPipe": # type: ignore[type-arg]
135
167
  """Transform the Snowflake data into a ready-to-use Pytorch datapipe.
136
168
 
@@ -149,8 +181,13 @@ class DataConnector:
149
181
  """
150
182
  from snowflake.ml.data import torch_utils
151
183
 
184
+ expand_dims = self._kwargs.get("expand_dims", True)
152
185
  return torch_utils.TorchDataPipeWrapper(
153
- self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
186
+ self._ingestor,
187
+ batch_size=batch_size,
188
+ shuffle=shuffle,
189
+ drop_last=drop_last_batch,
190
+ expand_dims=expand_dims,
154
191
  )
155
192
 
156
193
  @telemetry.send_api_usage_telemetry(
@@ -179,8 +216,13 @@ class DataConnector:
179
216
  """
180
217
  from snowflake.ml.data import torch_utils
181
218
 
219
+ expand_dims = self._kwargs.get("expand_dims", True)
182
220
  return torch_utils.TorchDatasetWrapper(
183
- self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
221
+ self._ingestor,
222
+ batch_size=batch_size,
223
+ shuffle=shuffle,
224
+ drop_last=drop_last_batch,
225
+ expand_dims=expand_dims,
184
226
  )
185
227
 
186
228
  @telemetry.send_api_usage_telemetry(
@@ -6,6 +6,7 @@ from typing import (
6
6
  List,
7
7
  Optional,
8
8
  Protocol,
9
+ Sequence,
9
10
  Type,
10
11
  TypeVar,
11
12
  )
@@ -25,7 +26,7 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
25
26
  class DataIngestor(Protocol):
26
27
  @classmethod
27
28
  def from_sources(
28
- cls: Type[DataIngestorType], session: snowpark.Session, sources: List[data_source.DataSource]
29
+ cls: Type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
29
30
  ) -> DataIngestorType:
30
31
  raise NotImplementedError
31
32
 
@@ -17,6 +17,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
17
17
  batch_size: Optional[int],
18
18
  shuffle: bool = False,
19
19
  drop_last: bool = False,
20
+ expand_dims: bool = True,
20
21
  ) -> None:
21
22
  """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
22
23
  squeeze = False
@@ -29,6 +30,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
29
30
  self._shuffle = shuffle
30
31
  self._drop_last = drop_last
31
32
  self._squeeze_outputs = squeeze
33
+ self._expand_dims = expand_dims
32
34
 
33
35
  def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
34
36
  max_idx = 0
@@ -47,7 +49,10 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
47
49
  ):
48
50
  # Skip indices during multi-process data loading to prevent data duplication
49
51
  if counter == filter_idx:
50
- yield {k: _preprocess_array(v, squeeze=self._squeeze_outputs) for k, v in batch.items()}
52
+ yield {
53
+ k: _preprocess_array(v, squeeze=self._squeeze_outputs, expand_dims=self._expand_dims)
54
+ for k, v in batch.items()
55
+ }
51
56
  if counter < max_idx:
52
57
  counter += 1
53
58
  else:
@@ -58,13 +63,21 @@ class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Di
58
63
  """Wrap a DataIngestor into a PyTorch IterDataPipe"""
59
64
 
60
65
  def __init__(
61
- self, ingestor: data_ingestor.DataIngestor, *, batch_size: int, shuffle: bool = False, drop_last: bool = False
66
+ self,
67
+ ingestor: data_ingestor.DataIngestor,
68
+ *,
69
+ batch_size: int,
70
+ shuffle: bool = False,
71
+ drop_last: bool = False,
72
+ expand_dims: bool = True,
62
73
  ) -> None:
63
74
  """Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead"""
64
- super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
75
+ super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, expand_dims=expand_dims)
65
76
 
66
77
 
67
- def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt.NDArray[Any], List[np.object_]]:
78
+ def _preprocess_array(
79
+ arr: npt.NDArray[Any], squeeze: bool = False, expand_dims: bool = True
80
+ ) -> Union[npt.NDArray[Any], List[np.object_]]:
68
81
  """Preprocesses batch column values."""
69
82
  single_dimensional = arr.ndim < 2 and not arr.dtype == np.object_
70
83
 
@@ -73,7 +86,7 @@ def _preprocess_array(arr: npt.NDArray[Any], squeeze: bool = False) -> Union[npt
73
86
  arr = arr.squeeze(axis=0)
74
87
 
75
88
  # For single dimensional data,
76
- if single_dimensional:
89
+ if single_dimensional and expand_dims:
77
90
  axis = 0 if arr.ndim == 0 else 1
78
91
  arr = np.expand_dims(arr, axis=axis)
79
92
 
@@ -419,7 +419,6 @@ class Dataset(lineage_node.LineageNode):
419
419
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
420
420
  def delete(self) -> None:
421
421
  """Delete Dataset and all contained versions"""
422
- # TODO: Check and warn if any versions exist
423
422
  self._session.sql(f"DROP DATASET {self.fully_qualified_name}").collect(
424
423
  statement_params=_TELEMETRY_STATEMENT_PARAMS
425
424
  )
@@ -45,8 +45,9 @@ class ExampleHelper:
45
45
  """Return a dataframe object about descriptions of all examples."""
46
46
  root_dir = Path(__file__).parent
47
47
  rows = []
48
+ hide_folders = ["citibike_trip_features", "source_data"]
48
49
  for f_name in os.listdir(root_dir):
49
- if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name != "source_data":
50
+ if os.path.isdir(os.path.join(root_dir, f_name)) and f_name[0].isalpha() and f_name not in hide_folders:
50
51
  source_file_path = root_dir.joinpath(f"{f_name}/source.yaml")
51
52
  source_dict = self._read_yaml(str(source_file_path))
52
53
  rows.append((f_name, source_dict["model_category"], source_dict["desc"], source_dict["label_columns"]))
@@ -2,6 +2,8 @@ import functools
2
2
  import inspect
3
3
  from typing import Any, Callable, List, Optional
4
4
 
5
+ from typing_extensions import deprecated
6
+
5
7
  from snowflake import snowpark
6
8
  from snowflake.connector import connection
7
9
  from snowflake.ml._internal import telemetry
@@ -11,11 +13,9 @@ from snowflake.ml._internal.exceptions import (
11
13
  fileset_error_messages,
12
14
  fileset_errors,
13
15
  )
14
- from snowflake.ml._internal.utils import (
15
- identifier,
16
- import_utils,
17
- snowpark_dataframe_utils,
18
- )
16
+ from snowflake.ml._internal.utils import identifier, snowpark_dataframe_utils
17
+ from snowflake.ml.data import data_connector
18
+ from snowflake.ml.data._internal import arrow_ingestor
19
19
  from snowflake.ml.fileset import sfcfs
20
20
  from snowflake.snowpark import exceptions as snowpark_exceptions, functions
21
21
 
@@ -44,6 +44,10 @@ def _raise_if_deleted(func: Callable[..., Any]) -> Callable[..., Any]:
44
44
  return raise_if_deleted_helper
45
45
 
46
46
 
47
+ @deprecated(
48
+ "FileSet is deprecated and will be removed in a future release."
49
+ " Use snowflake.ml.dataset.Dataset and snowflake.ml.data.DataConnector instead"
50
+ )
47
51
  class FileSet:
48
52
  """A FileSet represents an immutable snapshot of the result of a query in the form of files."""
49
53
 
@@ -285,6 +289,16 @@ class FileSet:
285
289
  """Get the Snowflake absolute path to this FileSet directory."""
286
290
  return _fileset_absolute_path(self._target_stage_loc, self.name)
287
291
 
292
+ def _to_data_connector(self) -> data_connector.DataConnector:
293
+ self._fs.optimize_read(self._list_files())
294
+ ingester = arrow_ingestor.ArrowIngestor(
295
+ self._snowpark_session,
296
+ self._list_files(),
297
+ format="parquet",
298
+ filesystem=self._fs,
299
+ )
300
+ return data_connector.DataConnector(ingester, expand_dims=False)
301
+
288
302
  @telemetry.send_api_usage_telemetry(
289
303
  project=_PROJECT,
290
304
  )
@@ -362,13 +376,9 @@ class FileSet:
362
376
  ----
363
377
  {'_COL_1':[10]}
364
378
  """
365
- IterableWrapper, _ = import_utils.import_or_get_dummy("torchdata.datapipes.iter.IterableWrapper")
366
- torch_datapipe_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.torch_datapipe")
367
-
368
- self._fs.optimize_read(self._list_files())
369
-
370
- input_dp = IterableWrapper(self._list_files())
371
- return torch_datapipe_module.ReadAndParseParquet(input_dp, self._fs, batch_size, shuffle, drop_last_batch)
379
+ return self._to_data_connector().to_torch_datapipe(
380
+ batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
381
+ )
372
382
 
373
383
  @telemetry.send_api_usage_telemetry(
374
384
  project=_PROJECT,
@@ -402,12 +412,8 @@ class FileSet:
402
412
  ----
403
413
  {'_COL_1': <tf.Tensor: shape=(1,), dtype=int64, numpy=[10]>}
404
414
  """
405
- tf_dataset_module, _ = import_utils.import_or_get_dummy("snowflake.ml.fileset.tf_dataset")
406
-
407
- self._fs.optimize_read(self._list_files())
408
-
409
- return tf_dataset_module.read_and_parse_parquet(
410
- self._list_files(), self._fs, batch_size, shuffle, drop_last_batch
415
+ return self._to_data_connector().to_tf_dataset(
416
+ batch_size=batch_size, shuffle=shuffle, drop_last_batch=drop_last_batch
411
417
  )
412
418
 
413
419
  @telemetry.send_api_usage_telemetry(
@@ -0,0 +1,21 @@
1
+ from snowflake.ml.jobs._utils.types import JOB_STATUS
2
+ from snowflake.ml.jobs.decorators import remote
3
+ from snowflake.ml.jobs.job import MLJob
4
+ from snowflake.ml.jobs.manager import (
5
+ delete_job,
6
+ get_job,
7
+ list_jobs,
8
+ submit_directory,
9
+ submit_file,
10
+ )
11
+
12
+ __all__ = [
13
+ "remote",
14
+ "submit_file",
15
+ "submit_directory",
16
+ "list_jobs",
17
+ "get_job",
18
+ "delete_job",
19
+ "MLJob",
20
+ "JOB_STATUS",
21
+ ]
@@ -0,0 +1,51 @@
1
+ from snowflake.ml._internal.utils.snowflake_env import SnowflakeCloudType
2
+ from snowflake.ml.jobs._utils.types import ComputeResources
3
+
4
+ # SPCS specification constants
5
+ DEFAULT_CONTAINER_NAME = "main"
6
+ PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
7
+
8
+ # Default container image information
9
+ DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
10
+ DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
11
+ DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
12
+ DEFAULT_IMAGE_TAG = "0.8.0"
13
+ DEFAULT_ENTRYPOINT_PATH = "func.py"
14
+
15
+ # Percent of container memory to allocate for /dev/shm volume
16
+ MEMORY_VOLUME_SIZE = 0.3
17
+
18
+ # Job status polling constants
19
+ JOB_POLL_INITIAL_DELAY_SECONDS = 0.1
20
+ JOB_POLL_MAX_DELAY_SECONDS = 1
21
+
22
+ # Compute pool resource information
23
+ # TODO: Query Snowflake for resource information instead of relying on this hardcoded
24
+ # table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
25
+ COMMON_INSTANCE_FAMILIES = {
26
+ "CPU_X64_XS": ComputeResources(cpu=1, memory=6),
27
+ "CPU_X64_S": ComputeResources(cpu=3, memory=13),
28
+ "CPU_X64_M": ComputeResources(cpu=6, memory=28),
29
+ "CPU_X64_L": ComputeResources(cpu=28, memory=116),
30
+ "HIGHMEM_X64_S": ComputeResources(cpu=6, memory=58),
31
+ }
32
+ AWS_INSTANCE_FAMILIES = {
33
+ "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=240),
34
+ "HIGHMEM_X64_L": ComputeResources(cpu=124, memory=984),
35
+ "GPU_NV_S": ComputeResources(cpu=6, memory=27, gpu=1, gpu_type="A10G"),
36
+ "GPU_NV_M": ComputeResources(cpu=44, memory=178, gpu=4, gpu_type="A10G"),
37
+ "GPU_NV_L": ComputeResources(cpu=92, memory=1112, gpu=8, gpu_type="A100"),
38
+ }
39
+ AZURE_INSTANCE_FAMILIES = {
40
+ "HIGHMEM_X64_M": ComputeResources(cpu=28, memory=244),
41
+ "HIGHMEM_X64_L": ComputeResources(cpu=92, memory=654),
42
+ "GPU_NV_XS": ComputeResources(cpu=3, memory=26, gpu=1, gpu_type="T4"),
43
+ "GPU_NV_SM": ComputeResources(cpu=32, memory=424, gpu=1, gpu_type="A10"),
44
+ "GPU_NV_2M": ComputeResources(cpu=68, memory=858, gpu=2, gpu_type="A10"),
45
+ "GPU_NV_3M": ComputeResources(cpu=44, memory=424, gpu=2, gpu_type="A100"),
46
+ "GPU_NV_SL": ComputeResources(cpu=92, memory=858, gpu=4, gpu_type="A100"),
47
+ }
48
+ CLOUD_INSTANCE_FAMILIES = {
49
+ SnowflakeCloudType.AWS: AWS_INSTANCE_FAMILIES,
50
+ SnowflakeCloudType.AZURE: AZURE_INSTANCE_FAMILIES,
51
+ }