snowflake-ml-python 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194) hide show
  1. snowflake/ml/_internal/env_utils.py +16 -13
  2. snowflake/ml/_internal/exceptions/modeling_error_messages.py +5 -1
  3. snowflake/ml/_internal/telemetry.py +19 -0
  4. snowflake/ml/feature_store/__init__.py +9 -0
  5. snowflake/ml/feature_store/entity.py +73 -0
  6. snowflake/ml/feature_store/feature_store.py +1657 -0
  7. snowflake/ml/feature_store/feature_view.py +459 -0
  8. snowflake/ml/model/_client/ops/model_ops.py +16 -38
  9. snowflake/ml/model/_client/sql/model.py +1 -7
  10. snowflake/ml/model/_client/sql/model_version.py +20 -15
  11. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +9 -1
  12. snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
  13. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +12 -2
  14. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +7 -3
  15. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -6
  16. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +0 -2
  17. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  18. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -2
  19. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  20. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  21. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  22. snowflake/ml/model/model_signature.py +72 -16
  23. snowflake/ml/model/type_hints.py +12 -0
  24. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -41
  25. snowflake/ml/modeling/_internal/model_trainer_builder.py +13 -9
  26. snowflake/ml/modeling/_internal/{distributed_hpo_trainer.py → snowpark_implementations/distributed_hpo_trainer.py} +66 -96
  27. snowflake/ml/modeling/_internal/{snowpark_handlers.py → snowpark_implementations/snowpark_handlers.py} +9 -6
  28. snowflake/ml/modeling/_internal/{xgboost_external_memory_trainer.py → snowpark_implementations/xgboost_external_memory_trainer.py} +3 -1
  29. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +19 -3
  30. snowflake/ml/modeling/cluster/affinity_propagation.py +19 -3
  31. snowflake/ml/modeling/cluster/agglomerative_clustering.py +19 -3
  32. snowflake/ml/modeling/cluster/birch.py +19 -3
  33. snowflake/ml/modeling/cluster/bisecting_k_means.py +19 -3
  34. snowflake/ml/modeling/cluster/dbscan.py +19 -3
  35. snowflake/ml/modeling/cluster/feature_agglomeration.py +19 -3
  36. snowflake/ml/modeling/cluster/k_means.py +19 -3
  37. snowflake/ml/modeling/cluster/mean_shift.py +19 -3
  38. snowflake/ml/modeling/cluster/mini_batch_k_means.py +19 -3
  39. snowflake/ml/modeling/cluster/optics.py +19 -3
  40. snowflake/ml/modeling/cluster/spectral_biclustering.py +19 -3
  41. snowflake/ml/modeling/cluster/spectral_clustering.py +19 -3
  42. snowflake/ml/modeling/cluster/spectral_coclustering.py +19 -3
  43. snowflake/ml/modeling/compose/column_transformer.py +19 -3
  44. snowflake/ml/modeling/compose/transformed_target_regressor.py +19 -3
  45. snowflake/ml/modeling/covariance/elliptic_envelope.py +19 -3
  46. snowflake/ml/modeling/covariance/empirical_covariance.py +19 -3
  47. snowflake/ml/modeling/covariance/graphical_lasso.py +19 -3
  48. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +19 -3
  49. snowflake/ml/modeling/covariance/ledoit_wolf.py +19 -3
  50. snowflake/ml/modeling/covariance/min_cov_det.py +19 -3
  51. snowflake/ml/modeling/covariance/oas.py +19 -3
  52. snowflake/ml/modeling/covariance/shrunk_covariance.py +19 -3
  53. snowflake/ml/modeling/decomposition/dictionary_learning.py +19 -3
  54. snowflake/ml/modeling/decomposition/factor_analysis.py +19 -3
  55. snowflake/ml/modeling/decomposition/fast_ica.py +19 -3
  56. snowflake/ml/modeling/decomposition/incremental_pca.py +19 -3
  57. snowflake/ml/modeling/decomposition/kernel_pca.py +19 -3
  58. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +19 -3
  59. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +19 -3
  60. snowflake/ml/modeling/decomposition/pca.py +19 -3
  61. snowflake/ml/modeling/decomposition/sparse_pca.py +19 -3
  62. snowflake/ml/modeling/decomposition/truncated_svd.py +19 -3
  63. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +19 -3
  64. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +19 -3
  65. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +19 -3
  66. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +19 -3
  67. snowflake/ml/modeling/ensemble/bagging_classifier.py +19 -3
  68. snowflake/ml/modeling/ensemble/bagging_regressor.py +19 -3
  69. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +19 -3
  70. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +19 -3
  71. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +19 -3
  72. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +19 -3
  73. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +19 -3
  74. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +19 -3
  75. snowflake/ml/modeling/ensemble/isolation_forest.py +19 -3
  76. snowflake/ml/modeling/ensemble/random_forest_classifier.py +19 -3
  77. snowflake/ml/modeling/ensemble/random_forest_regressor.py +19 -3
  78. snowflake/ml/modeling/ensemble/stacking_regressor.py +19 -3
  79. snowflake/ml/modeling/ensemble/voting_classifier.py +19 -3
  80. snowflake/ml/modeling/ensemble/voting_regressor.py +19 -3
  81. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +19 -3
  82. snowflake/ml/modeling/feature_selection/select_fdr.py +19 -3
  83. snowflake/ml/modeling/feature_selection/select_fpr.py +19 -3
  84. snowflake/ml/modeling/feature_selection/select_fwe.py +19 -3
  85. snowflake/ml/modeling/feature_selection/select_k_best.py +19 -3
  86. snowflake/ml/modeling/feature_selection/select_percentile.py +19 -3
  87. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +19 -3
  88. snowflake/ml/modeling/feature_selection/variance_threshold.py +19 -3
  89. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +19 -3
  90. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +19 -3
  91. snowflake/ml/modeling/impute/iterative_imputer.py +19 -3
  92. snowflake/ml/modeling/impute/knn_imputer.py +19 -3
  93. snowflake/ml/modeling/impute/missing_indicator.py +19 -3
  94. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +19 -3
  95. snowflake/ml/modeling/kernel_approximation/nystroem.py +19 -3
  96. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +19 -3
  97. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +19 -3
  98. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +19 -3
  99. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +19 -3
  100. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +19 -3
  101. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +19 -3
  102. snowflake/ml/modeling/linear_model/ard_regression.py +19 -3
  103. snowflake/ml/modeling/linear_model/bayesian_ridge.py +19 -3
  104. snowflake/ml/modeling/linear_model/elastic_net.py +19 -3
  105. snowflake/ml/modeling/linear_model/elastic_net_cv.py +19 -3
  106. snowflake/ml/modeling/linear_model/gamma_regressor.py +19 -3
  107. snowflake/ml/modeling/linear_model/huber_regressor.py +19 -3
  108. snowflake/ml/modeling/linear_model/lars.py +19 -3
  109. snowflake/ml/modeling/linear_model/lars_cv.py +19 -3
  110. snowflake/ml/modeling/linear_model/lasso.py +19 -3
  111. snowflake/ml/modeling/linear_model/lasso_cv.py +19 -3
  112. snowflake/ml/modeling/linear_model/lasso_lars.py +19 -3
  113. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +19 -3
  114. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +19 -3
  115. snowflake/ml/modeling/linear_model/linear_regression.py +19 -3
  116. snowflake/ml/modeling/linear_model/logistic_regression.py +19 -3
  117. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +19 -3
  118. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +19 -3
  119. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +19 -3
  120. snowflake/ml/modeling/linear_model/multi_task_lasso.py +19 -3
  121. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +19 -3
  122. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +19 -3
  123. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +19 -3
  124. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +19 -3
  125. snowflake/ml/modeling/linear_model/perceptron.py +19 -3
  126. snowflake/ml/modeling/linear_model/poisson_regressor.py +19 -3
  127. snowflake/ml/modeling/linear_model/ransac_regressor.py +19 -3
  128. snowflake/ml/modeling/linear_model/ridge.py +19 -3
  129. snowflake/ml/modeling/linear_model/ridge_classifier.py +19 -3
  130. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +19 -3
  131. snowflake/ml/modeling/linear_model/ridge_cv.py +19 -3
  132. snowflake/ml/modeling/linear_model/sgd_classifier.py +19 -3
  133. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +19 -3
  134. snowflake/ml/modeling/linear_model/sgd_regressor.py +19 -3
  135. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +19 -3
  136. snowflake/ml/modeling/linear_model/tweedie_regressor.py +19 -3
  137. snowflake/ml/modeling/manifold/isomap.py +19 -3
  138. snowflake/ml/modeling/manifold/mds.py +19 -3
  139. snowflake/ml/modeling/manifold/spectral_embedding.py +19 -3
  140. snowflake/ml/modeling/manifold/tsne.py +19 -3
  141. snowflake/ml/modeling/metrics/classification.py +5 -6
  142. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  143. snowflake/ml/modeling/metrics/ranking.py +7 -3
  144. snowflake/ml/modeling/metrics/regression.py +6 -3
  145. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +19 -3
  146. snowflake/ml/modeling/mixture/gaussian_mixture.py +19 -3
  147. snowflake/ml/modeling/model_selection/grid_search_cv.py +3 -13
  148. snowflake/ml/modeling/model_selection/randomized_search_cv.py +3 -13
  149. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +19 -3
  150. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +19 -3
  151. snowflake/ml/modeling/multiclass/output_code_classifier.py +19 -3
  152. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +19 -3
  153. snowflake/ml/modeling/naive_bayes/categorical_nb.py +19 -3
  154. snowflake/ml/modeling/naive_bayes/complement_nb.py +19 -3
  155. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +19 -3
  156. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +19 -3
  157. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +19 -3
  158. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +19 -3
  159. snowflake/ml/modeling/neighbors/kernel_density.py +19 -3
  160. snowflake/ml/modeling/neighbors/local_outlier_factor.py +19 -3
  161. snowflake/ml/modeling/neighbors/nearest_centroid.py +19 -3
  162. snowflake/ml/modeling/neighbors/nearest_neighbors.py +19 -3
  163. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +19 -3
  164. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +19 -3
  165. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +19 -3
  166. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +19 -3
  167. snowflake/ml/modeling/neural_network/mlp_classifier.py +19 -3
  168. snowflake/ml/modeling/neural_network/mlp_regressor.py +19 -3
  169. snowflake/ml/modeling/preprocessing/polynomial_features.py +19 -3
  170. snowflake/ml/modeling/semi_supervised/label_propagation.py +19 -3
  171. snowflake/ml/modeling/semi_supervised/label_spreading.py +19 -3
  172. snowflake/ml/modeling/svm/linear_svc.py +19 -3
  173. snowflake/ml/modeling/svm/linear_svr.py +19 -3
  174. snowflake/ml/modeling/svm/nu_svc.py +19 -3
  175. snowflake/ml/modeling/svm/nu_svr.py +19 -3
  176. snowflake/ml/modeling/svm/svc.py +19 -3
  177. snowflake/ml/modeling/svm/svr.py +19 -3
  178. snowflake/ml/modeling/tree/decision_tree_classifier.py +19 -3
  179. snowflake/ml/modeling/tree/decision_tree_regressor.py +19 -3
  180. snowflake/ml/modeling/tree/extra_tree_classifier.py +19 -3
  181. snowflake/ml/modeling/tree/extra_tree_regressor.py +19 -3
  182. snowflake/ml/modeling/xgboost/xgb_classifier.py +19 -3
  183. snowflake/ml/modeling/xgboost/xgb_regressor.py +19 -3
  184. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +19 -3
  185. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +19 -3
  186. snowflake/ml/registry/registry.py +2 -0
  187. snowflake/ml/version.py +1 -1
  188. snowflake_ml_python-1.2.2.dist-info/LICENSE.txt +202 -0
  189. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/METADATA +276 -50
  190. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/RECORD +204 -197
  191. {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/WHEEL +2 -1
  192. snowflake_ml_python-1.2.2.dist-info/top_level.txt +1 -0
  193. /snowflake/ml/modeling/_internal/{pandas_trainer.py → local_implementations/pandas_trainer.py} +0 -0
  194. /snowflake/ml/modeling/_internal/{snowpark_trainer.py → snowpark_implementations/snowpark_trainer.py} +0 -0
@@ -0,0 +1,459 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import re
5
+ from collections import OrderedDict
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from typing import Dict, List, Optional
9
+
10
+ from snowflake.ml._internal.exceptions import (
11
+ error_codes,
12
+ exceptions as snowml_exceptions,
13
+ )
14
+ from snowflake.ml._internal.utils.identifier import concat_names
15
+ from snowflake.ml._internal.utils.sql_identifier import (
16
+ SqlIdentifier,
17
+ to_sql_identifiers,
18
+ )
19
+ from snowflake.ml.feature_store.entity import Entity
20
+ from snowflake.snowpark import DataFrame, Session
21
+ from snowflake.snowpark.types import (
22
+ DateType,
23
+ StructType,
24
+ TimestampType,
25
+ TimeType,
26
+ _NumericType,
27
+ )
28
+
29
+ _FEATURE_VIEW_NAME_DELIMITER = "$"
30
+ _TIMESTAMP_COL_PLACEHOLDER = "FS_TIMESTAMP_COL_PLACEHOLDER_VAL"
31
+ _FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE"
32
+ _FEATURE_VIEW_VERSION_RE = re.compile("^([A-Za-z0-9_]*)$")
33
+
34
+
35
class FeatureViewVersion(str):
    """Validated version label of a FeatureView.

    A version may only contain ASCII letters, digits and underscores. The
    stored string value is the upper-cased form of the input.
    """

    def __new__(cls, version: str) -> FeatureViewVersion:
        """Validate ``version`` and build the upper-cased string instance.

        Args:
            version: raw version label supplied by the caller.

        Returns:
            A FeatureViewVersion whose value is ``version.upper()``.

        Raises:
            SnowflakeMLException: wrapping a ValueError (error code INVALID_ARGUMENT)
                when ``version`` contains anything other than letters, digits
                and underscores.
        """
        # fullmatch() instead of match(): the pattern ends with `$`, which also matches
        # right before a trailing newline, so match() would let "v1\n" through and the
        # newline would end up inside the version string.
        if not _FEATURE_VIEW_VERSION_RE.fullmatch(version):
            raise snowml_exceptions.SnowflakeMLException(
                error_code=error_codes.INVALID_ARGUMENT,
                original_exception=ValueError(
                    f"`{version}` is not a valid feature view version. Only letter, number and underscore is allowed."
                ),
            )
        return super().__new__(cls, version.upper())

    def __init__(self, version: str) -> None:
        # The value is fixed in __new__; __init__ only exists to accept the same argument.
        return super().__init__()
48
+
49
+
50
class FeatureViewStatus(Enum):
    # Status values a FeatureView can carry. Each member's value is its own name.

    # Status assigned by FeatureView.__init__; FeatureView.physical_name() raises for it.
    DRAFT = "DRAFT"
    # Rejected (together with DRAFT) by the refresh_freq and warehouse setters of FeatureView.
    # NOTE(review): presumably a view registered without refresh_freq -- confirm.
    STATIC = "STATIC"
    # NOTE(review): RUNNING / SUSPENDED presumably mirror the backend scheduling state
    # of a refreshable feature view -- confirm against feature_store.py.
    RUNNING = "RUNNING"
    SUSPENDED = "SUSPENDED"
55
+
56
+
57
@dataclass(frozen=True)
class FeatureViewSlice:
    """Frozen pairing of a FeatureView with a list of feature names selected from it."""

    # The FeatureView the selected names belong to.
    feature_view_ref: FeatureView
    # Identifiers of the selected features.
    names: List[SqlIdentifier]

    def __repr__(self) -> str:
        """Render as ``ClassName(attr=value, ...)`` over the instance attributes."""
        rendered = [f"{attr}={value}" for attr, value in vars(self).items()]
        return f"{type(self).__name__}({', '.join(rendered)})"

    def __eq__(self, other: object) -> bool:
        """Two slices are equal when both their names and their feature views are equal."""
        if isinstance(other, FeatureViewSlice):
            return self.names == other.names and self.feature_view_ref == other.feature_view_ref
        return False

    def to_json(self) -> str:
        """Serialize to JSON; the referenced FeatureView is embedded as its own JSON string."""
        payload = {}
        payload["feature_view_ref"] = self.feature_view_ref.to_json()
        payload["names"] = self.names
        # Type tag checked by from_json() before rebuilding the object.
        payload[_FEATURE_OBJ_TYPE] = self.__class__.__name__
        return json.dumps(payload)

    @classmethod
    def from_json(cls, json_str: str, session: Session) -> FeatureViewSlice:
        """Rebuild a slice from the output of to_json().

        Args:
            json_str: JSON string produced by to_json().
            session: session handed to FeatureView.from_json to rebuild the referenced view.

        Returns:
            The reconstructed FeatureViewSlice.

        Raises:
            ValueError: if the type tag is missing or does not name this class.
        """
        parsed = json.loads(json_str)
        has_expected_tag = _FEATURE_OBJ_TYPE in parsed and parsed[_FEATURE_OBJ_TYPE] == cls.__name__
        if not has_expected_tag:
            raise ValueError(f"Invalid json str for {cls.__name__}: {json_str}")
        # The tag is not a constructor argument; drop it before expanding the dict.
        parsed.pop(_FEATURE_OBJ_TYPE)
        parsed["feature_view_ref"] = FeatureView.from_json(parsed["feature_view_ref"], session)
        return cls(**parsed)
88
+
89
+
90
class FeatureView:
    """
    A FeatureView instance encapsulates a logical group of features.
    """

    def __init__(
        self,
        name: str,
        entities: List[Entity],
        feature_df: DataFrame,
        timestamp_col: Optional[str] = None,
        refresh_freq: Optional[str] = None,
        desc: str = "",
    ) -> None:
        """
        Create a FeatureView instance.

        Args:
            name: name of the FeatureView. NOTE: FeatureView name will be capitalized.
            entities: entities that the FeatureView is associated with.
            feature_df: Snowpark DataFrame containing the data source and all feature transformation logic.
                Final projection of the DataFrame should contain feature names, join keys and timestamp(if applicable).
            timestamp_col: name of the timestamp column for point-in-time lookup when consuming the
                feature values.
            refresh_freq: Time unit defining how often the new feature data should be generated.
                Valid args are { <num> { seconds | minutes | hours | days } | DOWNSTREAM | <cron expr> <time zone>}.
                NOTE: Currently minimum refresh frequency is 1 minute.
                NOTE: If refresh_freq is in cron expression format, there must be a valid time zone as well.
                    E.g. * * * * * UTC
                NOTE: If refresh_freq is not provided, then FeatureView will be registered as View on Snowflake backend
                    and there won't be extra storage cost.
            desc: description of the FeatureView.
        """

        # NOTE: attribute assignment order is significant. _to_dict()/to_df()/__repr__ iterate
        # over __dict__, so reordering these lines changes their output order.
        self._name: SqlIdentifier = SqlIdentifier(name)
        self._entities: List[Entity] = entities
        self._feature_df: DataFrame = feature_df
        self._timestamp_col: Optional[SqlIdentifier] = (
            SqlIdentifier(timestamp_col) if timestamp_col is not None else None
        )
        self._desc: str = desc
        self._query: str = self._get_query()
        self._version: Optional[FeatureViewVersion] = None
        self._status: FeatureViewStatus = FeatureViewStatus.DRAFT
        self._feature_desc: OrderedDict[SqlIdentifier, str] = OrderedDict((f, "") for f in self._get_feature_names())
        self._refresh_freq: Optional[str] = refresh_freq
        self._database: Optional[SqlIdentifier] = None
        self._schema: Optional[SqlIdentifier] = None
        self._warehouse: Optional[SqlIdentifier] = None
        self._refresh_mode: Optional[str] = None
        self._refresh_mode_reason: Optional[str] = None
        self._validate()

    def slice(self, names: List[str]) -> FeatureViewSlice:
        """
        Select a subset of features within the FeatureView.

        Args:
            names: feature names to select.

        Returns:
            FeatureViewSlice instance containing selected features.

        Raises:
            ValueError: if selected feature names is not found in the FeatureView.
        """

        # feature_names builds a new list on every access; resolve it once for the whole loop.
        available = self.feature_names
        res = []
        for raw_name in names:
            ident = SqlIdentifier(raw_name)
            if ident not in available:
                raise ValueError(f"Feature name {ident} not found in FeatureView {self.name}.")
            res.append(ident)
        return FeatureViewSlice(self, res)

    def physical_name(self) -> SqlIdentifier:
        """Returns the physical name for this feature in Snowflake.

        Returns:
            Physical name string.

        Raises:
            RuntimeError: if the FeatureView is not materialized.
        """
        if self.status == FeatureViewStatus.DRAFT or self.version is None:
            raise RuntimeError(f"FeatureView {self.name} has not been materialized.")
        return FeatureView._get_physical_name(self.name, self.version)

    def fully_qualified_name(self) -> str:
        """Returns the fully qualified name (<database_name>.<schema_name>.<feature_view_name>) for the
            FeatureView in Snowflake.

        The database and schema parts are formatted as-is, so they read ``None`` while unset.
        Raises RuntimeError (from physical_name) if the FeatureView is not materialized.

        Returns:
            fully qualified name string.
        """
        return f"{self._database}.{self._schema}.{self.physical_name()}"

    def attach_feature_desc(self, descs: Dict[str, str]) -> FeatureView:
        """
        Associate feature level descriptions to the FeatureView.

        Args:
            descs: Dictionary contains feature name and corresponding descriptions.

        Returns:
            FeatureView with feature level desc attached.

        Raises:
            ValueError: if feature name is not found in the FeatureView.
        """
        for f, d in descs.items():
            f = SqlIdentifier(f)
            if f not in self._feature_desc:
                raise ValueError(
                    f"Feature name {f} is not found in FeatureView {self.name}, "
                    f"valid feature names are: {self.feature_names}"
                )
            self._feature_desc[f] = d
        return self

    @property
    def name(self) -> SqlIdentifier:
        return self._name

    @property
    def entities(self) -> List[Entity]:
        return self._entities

    @property
    def feature_df(self) -> DataFrame:
        return self._feature_df

    @property
    def timestamp_col(self) -> Optional[SqlIdentifier]:
        return self._timestamp_col

    @property
    def desc(self) -> str:
        return self._desc

    @property
    def query(self) -> str:
        return self._query

    @property
    def version(self) -> Optional[FeatureViewVersion]:
        return self._version

    @property
    def status(self) -> FeatureViewStatus:
        return self._status

    @property
    def feature_names(self) -> List[SqlIdentifier]:
        # A fresh list on every access; mutating it does not affect the view.
        return list(self._feature_desc.keys())

    @property
    def feature_descs(self) -> Dict[SqlIdentifier, str]:
        # Returns the internal mapping itself, not a copy.
        return self._feature_desc

    @property
    def refresh_freq(self) -> Optional[str]:
        return self._refresh_freq

    @refresh_freq.setter
    def refresh_freq(self, new_value: str) -> None:
        # Only allowed once the view is neither DRAFT nor STATIC.
        if self.status == FeatureViewStatus.DRAFT or self.status == FeatureViewStatus.STATIC:
            raise RuntimeError(
                f"Feature view {self.name}/{self.version} must be registered and non-static to update refresh_freq."
            )
        self._refresh_freq = new_value

    @property
    def database(self) -> Optional[SqlIdentifier]:
        return self._database

    @property
    def schema(self) -> Optional[SqlIdentifier]:
        return self._schema

    @property
    def warehouse(self) -> Optional[SqlIdentifier]:
        return self._warehouse

    @warehouse.setter
    def warehouse(self, new_value: str) -> None:
        # Only allowed once the view is neither DRAFT nor STATIC.
        if self.status == FeatureViewStatus.DRAFT or self.status == FeatureViewStatus.STATIC:
            raise RuntimeError(
                f"Feature view {self.name}/{self.version} must be registered and non-static to update warehouse."
            )
        self._warehouse = SqlIdentifier(new_value)

    @property
    def output_schema(self) -> StructType:
        return self._feature_df.schema

    @property
    def refresh_mode(self) -> Optional[str]:
        return self._refresh_mode

    @property
    def refresh_mode_reason(self) -> Optional[str]:
        return self._refresh_mode_reason

    def _get_query(self) -> str:
        """Return the single SQL query behind feature_df.

        Raises:
            ValueError: if feature_df is not backed by exactly one query.
        """
        queries = self._feature_df.queries["queries"]
        if len(queries) != 1:
            raise ValueError(
                "feature_df dataframe must contain only 1 query.\n"
                f"Got {len(queries)}: {queries}\n"
            )
        return str(queries[0])

    def _validate(self) -> None:
        """Check the name, the entity join keys and the timestamp column against feature_df.

        Raises:
            ValueError: if the name contains the name/version delimiter, a join key or the
                timestamp column is missing from feature_df, the timestamp column uses the
                reserved placeholder name, or its data type is not date/time/timestamp/numeric.
        """
        if _FEATURE_VIEW_NAME_DELIMITER in self._name:
            raise ValueError(
                f"FeatureView name `{self._name}` contains invalid character `{_FEATURE_VIEW_NAME_DELIMITER}`."
            )

        # Converted once and reused for both the join-key and the timestamp checks.
        unescaped_df_cols = to_sql_identifiers(self._feature_df.columns)
        for e in self._entities:
            for k in e.join_keys:
                if k not in unescaped_df_cols:
                    raise ValueError(
                        f"join_key {k} in Entity {e.name} is not found in input dataframe: {unescaped_df_cols}"
                    )

        if self._timestamp_col is not None:
            ts_col = self._timestamp_col
            if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER):
                raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.")
            if ts_col not in unescaped_df_cols:
                raise ValueError(f"timestamp_col {ts_col} is not found in input dataframe.")

            col_type = self._feature_df.schema[ts_col].datatype
            if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
                raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")

    def _get_feature_names(self) -> List[SqlIdentifier]:
        """Return the feature_df columns that are neither entity join keys nor the timestamp column."""
        join_keys = [k for e in self._entities for k in e.join_keys]
        ts_col = [self._timestamp_col] if self._timestamp_col is not None else []
        # Built once instead of once per column.
        excluded = join_keys + ts_col
        feature_names = to_sql_identifiers(self._feature_df.columns, case_sensitive=True)
        return [c for c in feature_names if c not in excluded]

    def __repr__(self) -> str:
        states = (f"{k}={v}" for k, v in vars(self).items())
        return f"{type(self).__name__}({', '.join(states)})"

    def __eq__(self, other: object) -> bool:
        """Field-by-field equality; feature_df itself is not compared, only its query."""
        if not isinstance(other, FeatureView):
            return False

        return (
            self.name == other.name
            and self.version == other.version
            and self.timestamp_col == other.timestamp_col
            and self.entities == other.entities
            and self.desc == other.desc
            and self.feature_descs == other.feature_descs
            and self.feature_names == other.feature_names
            and self.query == other.query
            and self.refresh_freq == other.refresh_freq
            # Compared as strings so an enum member and its serialized form are treated alike.
            and str(self.status) == str(other.status)
            and self.database == other.database
            # schema was previously skipped, so views differing only by schema compared equal.
            and self.schema == other.schema
            and self.warehouse == other.warehouse
            and self.refresh_mode == other.refresh_mode
            and self.refresh_mode_reason == other.refresh_mode_reason
        )

    def _to_dict(self) -> Dict[str, str]:
        """Return the instance state as a JSON-friendly dict keyed by attribute name.

        feature_df is dropped (it is represented by ``_query``); identifiers and the status
        are stringified, entities are converted with their own ``_to_dict``.
        """
        fv_dict = self.__dict__.copy()
        if "_feature_df" in fv_dict:
            fv_dict.pop("_feature_df")
        fv_dict["_entities"] = [e._to_dict() for e in self._entities]
        fv_dict["_status"] = str(self._status)
        fv_dict["_name"] = str(self._name) if self._name is not None else None
        fv_dict["_version"] = str(self._version) if self._version is not None else None
        fv_dict["_database"] = str(self._database) if self._database is not None else None
        fv_dict["_schema"] = str(self._schema) if self._schema is not None else None
        fv_dict["_warehouse"] = str(self._warehouse) if self._warehouse is not None else None
        fv_dict["_timestamp_col"] = str(self._timestamp_col) if self._timestamp_col is not None else None

        feature_desc_dict = {}
        for k, v in self._feature_desc.items():
            feature_desc_dict[k.identifier()] = v
        fv_dict["_feature_desc"] = feature_desc_dict

        return fv_dict

    def to_df(self, session: Session) -> DataFrame:
        """Return a one-row DataFrame describing this FeatureView.

        Columns are the ``_to_dict`` keys with leading underscores stripped, plus ``physical_name``.
        Raises RuntimeError (from physical_name) if the FeatureView is not materialized.
        """
        # Build the state once so values and column names are guaranteed to line up.
        state = self._to_dict()
        values = list(state.values())
        schema = [key.lstrip("_") for key in state]
        values.append(str(self.physical_name()))
        schema.append("physical_name")
        return session.create_dataframe([values], schema=schema)

    def to_json(self) -> str:
        """Serialize the state from ``_to_dict`` plus a type tag to a JSON string."""
        state_dict = self._to_dict()
        state_dict[_FEATURE_OBJ_TYPE] = self.__class__.__name__
        return json.dumps(state_dict)

    @classmethod
    def from_json(cls, json_str: str, session: Session) -> FeatureView:
        """Rebuild a FeatureView from the output of to_json().

        Args:
            json_str: JSON string produced by to_json().
            session: session used to turn the stored query back into a DataFrame.

        Returns:
            The reconstructed FeatureView.

        Raises:
            ValueError: if the type tag is missing or does not name this class.
        """
        json_dict = json.loads(json_str)
        if _FEATURE_OBJ_TYPE not in json_dict or json_dict[_FEATURE_OBJ_TYPE] != cls.__name__:
            raise ValueError(f"Invalid json str for {cls.__name__}: {json_str}")

        return FeatureView._construct_feature_view(
            name=json_dict["_name"],
            entities=[Entity(**e) for e in json_dict["_entities"]],
            feature_df=session.sql(json_dict["_query"]),
            timestamp_col=json_dict["_timestamp_col"],
            desc=json_dict["_desc"],
            version=json_dict["_version"],
            status=json_dict["_status"],
            feature_descs=json_dict["_feature_desc"],
            refresh_freq=json_dict["_refresh_freq"],
            database=json_dict["_database"],
            schema=json_dict["_schema"],
            warehouse=json_dict["_warehouse"],
            refresh_mode=json_dict["_refresh_mode"],
            refresh_mode_reason=json_dict["_refresh_mode_reason"],
        )

    @staticmethod
    def _get_physical_name(fv_name: SqlIdentifier, fv_version: FeatureViewVersion) -> SqlIdentifier:
        """Join name, delimiter and version into the identifier used for the backing object."""
        return SqlIdentifier(
            concat_names(
                [
                    str(fv_name),
                    _FEATURE_VIEW_NAME_DELIMITER,
                    str(fv_version),
                ]
            )
        )

    @staticmethod
    def _coerce_status(status: object) -> FeatureViewStatus:
        """Map a serialized status (``"FeatureViewStatus.X"`` or ``"X"``) back to its enum member.

        Values that are not strings, or strings that name no member, are returned unchanged.
        """
        if isinstance(status, str):
            member = FeatureViewStatus.__members__.get(status.rsplit(".", 1)[-1])
            if member is not None:
                return member
        return status  # type: ignore[return-value]

    @staticmethod
    def _construct_feature_view(
        name: str,
        entities: List[Entity],
        feature_df: DataFrame,
        timestamp_col: Optional[str],
        desc: str,
        version: str,
        status: FeatureViewStatus,
        feature_descs: Dict[str, str],
        refresh_freq: Optional[str],
        database: Optional[str],
        schema: Optional[str],
        warehouse: Optional[str],
        refresh_mode: Optional[str],
        refresh_mode_reason: Optional[str],
    ) -> FeatureView:
        """Build a FeatureView and overwrite the state that __init__ always starts as draft/unset."""
        fv = FeatureView(
            name=name,
            entities=entities,
            feature_df=feature_df,
            timestamp_col=timestamp_col,
            desc=desc,
        )
        fv._version = FeatureViewVersion(version) if version is not None else None
        # from_json() hands over the string written by _to_dict(); without coercion the
        # `== FeatureViewStatus.DRAFT/STATIC` guards would never match on a deserialized view.
        fv._status = FeatureView._coerce_status(status)
        fv._refresh_freq = refresh_freq
        fv._database = SqlIdentifier(database) if database is not None else None
        fv._schema = SqlIdentifier(schema) if schema is not None else None
        fv._warehouse = SqlIdentifier(warehouse) if warehouse is not None else None
        fv._refresh_mode = refresh_mode
        fv._refresh_mode_reason = refresh_mode_reason
        fv.attach_feature_desc(feature_descs)
        return fv
@@ -4,9 +4,8 @@ import tempfile
4
4
  from typing import Any, Dict, List, Optional, Union, cast
5
5
 
6
6
  import yaml
7
- from packaging import version
8
7
 
9
- from snowflake.ml._internal.utils import identifier, snowflake_env, sql_identifier
8
+ from snowflake.ml._internal.utils import identifier, sql_identifier
10
9
  from snowflake.ml.model import model_signature, type_hints
11
10
  from snowflake.ml.model._client.ops import metadata_ops
12
11
  from snowflake.ml.model._client.sql import (
@@ -25,8 +24,6 @@ from snowflake.ml.model._signatures import snowpark_handler
25
24
  from snowflake.snowpark import dataframe, row, session
26
25
  from snowflake.snowpark._internal import utils as snowpark_utils
27
26
 
28
- _TAG_ON_MODEL_AVAILABLE_VERSION = version.parse("8.2.0")
29
-
30
27
 
31
28
  class ModelOperator:
32
29
  def __init__(
@@ -296,21 +293,14 @@ class ModelOperator:
296
293
  tag_value: str,
297
294
  statement_params: Optional[Dict[str, Any]] = None,
298
295
  ) -> None:
299
- sf_version = snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
300
- if sf_version >= _TAG_ON_MODEL_AVAILABLE_VERSION:
301
- self._tag_client.set_tag_on_model(
302
- model_name=model_name,
303
- tag_database_name=tag_database_name,
304
- tag_schema_name=tag_schema_name,
305
- tag_name=tag_name,
306
- tag_value=tag_value,
307
- statement_params=statement_params,
308
- )
309
- else:
310
- raise NotImplementedError(
311
- f"`set_tag` won't work before Snowflake version {_TAG_ON_MODEL_AVAILABLE_VERSION},"
312
- f" currently is {sf_version}"
313
- )
296
+ self._tag_client.set_tag_on_model(
297
+ model_name=model_name,
298
+ tag_database_name=tag_database_name,
299
+ tag_schema_name=tag_schema_name,
300
+ tag_name=tag_name,
301
+ tag_value=tag_value,
302
+ statement_params=statement_params,
303
+ )
314
304
 
315
305
  def unset_tag(
316
306
  self,
@@ -321,20 +311,13 @@ class ModelOperator:
321
311
  tag_name: sql_identifier.SqlIdentifier,
322
312
  statement_params: Optional[Dict[str, Any]] = None,
323
313
  ) -> None:
324
- sf_version = snowflake_env.get_current_snowflake_version(self._session, statement_params=statement_params)
325
- if sf_version >= _TAG_ON_MODEL_AVAILABLE_VERSION:
326
- self._tag_client.unset_tag_on_model(
327
- model_name=model_name,
328
- tag_database_name=tag_database_name,
329
- tag_schema_name=tag_schema_name,
330
- tag_name=tag_name,
331
- statement_params=statement_params,
332
- )
333
- else:
334
- raise NotImplementedError(
335
- f"`unset_tag` won't work before Snowflake version {_TAG_ON_MODEL_AVAILABLE_VERSION},"
336
- f" currently is {sf_version}"
337
- )
314
+ self._tag_client.unset_tag_on_model(
315
+ model_name=model_name,
316
+ tag_database_name=tag_database_name,
317
+ tag_schema_name=tag_schema_name,
318
+ tag_name=tag_name,
319
+ statement_params=statement_params,
320
+ )
338
321
 
339
322
  def get_model_version_manifest(
340
323
  self,
@@ -382,11 +365,6 @@ class ModelOperator:
382
365
  version_name: sql_identifier.SqlIdentifier,
383
366
  statement_params: Optional[Dict[str, Any]] = None,
384
367
  ) -> model_manifest_schema.SnowparkMLDataDict:
385
- if (
386
- snowflake_env.get_current_snowflake_version(self._session)
387
- < model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION
388
- ):
389
- raise NotImplementedError("User_data has not been supported yet.")
390
368
  raw_user_data_json_string = self._model_client.show_versions(
391
369
  model_name=model_name,
392
370
  version_name=version_name,
@@ -3,10 +3,8 @@ from typing import Any, Dict, List, Optional
3
3
  from snowflake.ml._internal.utils import (
4
4
  identifier,
5
5
  query_result_checker,
6
- snowflake_env,
7
6
  sql_identifier,
8
7
  )
9
- from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
10
8
  from snowflake.snowpark import row, session
11
9
 
12
10
 
@@ -89,12 +87,8 @@ class ModelSQLClient:
89
87
  .has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
90
88
  .has_column(ModelSQLClient.MODEL_VERSION_COMMENT_COL_NAME, allow_empty=True)
91
89
  .has_column(ModelSQLClient.MODEL_VERSION_METADATA_COL_NAME, allow_empty=True)
90
+ .has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True)
92
91
  )
93
- if (
94
- snowflake_env.get_current_snowflake_version(self._session)
95
- >= model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION
96
- ):
97
- res = res.has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True)
98
92
  if validate_result and version_name:
99
93
  res = res.has_dimensions(expected_rows=1)
100
94
 
@@ -146,24 +146,29 @@ class ModelVersionSQLClient:
146
146
  returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
147
147
  statement_params: Optional[Dict[str, Any]] = None,
148
148
  ) -> dataframe.DataFrame:
149
- tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
150
- INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
151
- self._database_name.identifier(),
152
- self._schema_name.identifier(),
153
- tmp_table_name,
154
- )
155
- input_df.write.save_as_table( # type: ignore[call-overload]
156
- table_name=INTERMEDIATE_TABLE_NAME,
157
- mode="errorifexists",
158
- table_type="temporary",
159
- statement_params=statement_params,
160
- )
149
+ with_statements = []
150
+ if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
151
+ INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
152
+ with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
153
+ else:
154
+ tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
155
+ INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
156
+ self._database_name.identifier(),
157
+ self._schema_name.identifier(),
158
+ tmp_table_name,
159
+ )
160
+ input_df.write.save_as_table( # type: ignore[call-overload]
161
+ table_name=INTERMEDIATE_TABLE_NAME,
162
+ mode="errorifexists",
163
+ table_type="temporary",
164
+ statement_params=statement_params,
165
+ )
161
166
 
162
167
  INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
163
168
 
164
169
  module_version_alias = "MODEL_VERSION_ALIAS"
165
- model_version_alias_sql = (
166
- f"WITH {module_version_alias} AS "
170
+ with_statements.append(
171
+ f"{module_version_alias} AS "
167
172
  f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
168
173
  )
169
174
 
@@ -174,7 +179,7 @@ class ModelVersionSQLClient:
174
179
  args_sql = ", ".join(args_sql_list)
175
180
 
176
181
  sql = textwrap.dedent(
177
- f"""{model_version_alias_sql}
182
+ f"""WITH {','.join(with_statements)}
178
183
  SELECT *,
179
184
  {module_version_alias}!{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
180
185
  FROM {INTERMEDIATE_TABLE_NAME}"""
@@ -2,6 +2,7 @@ import logging
2
2
  import os
3
3
  import posixpath
4
4
  from string import Template
5
+ from typing import List
5
6
 
6
7
  import importlib_resources
7
8
 
@@ -36,6 +37,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
36
37
  session: snowpark.Session,
37
38
  artifact_stage_location: str,
38
39
  compute_pool: str,
40
+ external_access_integrations: List[str],
39
41
  ) -> None:
40
42
  """Initialization
41
43
 
@@ -47,6 +49,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
47
49
  artifact_stage_location: Spec file and future deployment related artifacts will be stored under
48
50
  {stage}/models/{model_id}
49
51
  compute_pool: The compute pool used to run docker image build workload.
52
+ external_access_integrations: EAIs for network connection.
50
53
  """
51
54
  self.context_dir = context_dir
52
55
  self.image_repo = image_repo
@@ -54,6 +57,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
54
57
  self.session = session
55
58
  self.artifact_stage_location = artifact_stage_location
56
59
  self.compute_pool = compute_pool
60
+ self.external_access_integrations = external_access_integrations
57
61
  self.client = snowservice_client.SnowServiceClient(session)
58
62
 
59
63
  assert artifact_stage_location.startswith(
@@ -202,4 +206,8 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
202
206
 
203
207
  def _launch_kaniko_job(self, spec_stage_location: str) -> None:
204
208
  logger.debug("Submitting job for building docker image with kaniko")
205
- self.client.create_job(compute_pool=self.compute_pool, spec_stage_location=spec_stage_location)
209
+ self.client.create_job(
210
+ compute_pool=self.compute_pool,
211
+ spec_stage_location=spec_stage_location,
212
+ external_access_integrations=self.external_access_integrations,
213
+ )
@@ -465,6 +465,7 @@ class SnowServiceDeployment:
465
465
  session=self.session,
466
466
  artifact_stage_location=self._model_artifact_stage_location,
467
467
  compute_pool=self.options.compute_pool,
468
+ external_access_integrations=self.options.external_access_integrations,
468
469
  )
469
470
  else:
470
471
  image_builder = client_image_builder.ClientImageBuilder(
@@ -587,6 +588,7 @@ class SnowServiceDeployment:
587
588
  spec_stage_location=spec_stage_location,
588
589
  min_instances=self.options.min_instances,
589
590
  max_instances=self.options.max_instances,
591
+ external_access_integrations=self.options.external_access_integrations,
590
592
  )
591
593
  logger.info(f"Wait for service {self._service_name} to become ready...")
592
594
  client.block_until_resource_is_ready(