snowflake-ml-python 1.2.0__py3-none-any.whl → 1.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +16 -13
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +5 -1
- snowflake/ml/_internal/telemetry.py +19 -0
- snowflake/ml/feature_store/__init__.py +9 -0
- snowflake/ml/feature_store/entity.py +73 -0
- snowflake/ml/feature_store/feature_store.py +1657 -0
- snowflake/ml/feature_store/feature_view.py +459 -0
- snowflake/ml/model/_client/ops/model_ops.py +16 -38
- snowflake/ml/model/_client/sql/model.py +1 -7
- snowflake/ml/model/_client/sql/model_version.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +9 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +12 -2
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +7 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -6
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +0 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
- snowflake/ml/model/model_signature.py +72 -16
- snowflake/ml/model/type_hints.py +12 -0
- snowflake/ml/modeling/_internal/estimator_protocols.py +1 -41
- snowflake/ml/modeling/_internal/model_trainer_builder.py +13 -9
- snowflake/ml/modeling/_internal/{distributed_hpo_trainer.py → snowpark_implementations/distributed_hpo_trainer.py} +66 -96
- snowflake/ml/modeling/_internal/{snowpark_handlers.py → snowpark_implementations/snowpark_handlers.py} +9 -6
- snowflake/ml/modeling/_internal/{xgboost_external_memory_trainer.py → snowpark_implementations/xgboost_external_memory_trainer.py} +3 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +19 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +19 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +19 -3
- snowflake/ml/modeling/cluster/birch.py +19 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +19 -3
- snowflake/ml/modeling/cluster/dbscan.py +19 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +19 -3
- snowflake/ml/modeling/cluster/k_means.py +19 -3
- snowflake/ml/modeling/cluster/mean_shift.py +19 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +19 -3
- snowflake/ml/modeling/cluster/optics.py +19 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +19 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +19 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +19 -3
- snowflake/ml/modeling/compose/column_transformer.py +19 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +19 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +19 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +19 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +19 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +19 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +19 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +19 -3
- snowflake/ml/modeling/covariance/oas.py +19 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +19 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +19 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +19 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +19 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +19 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +19 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +19 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +19 -3
- snowflake/ml/modeling/decomposition/pca.py +19 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +19 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +19 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +19 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +19 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +19 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +19 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +19 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +19 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +19 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +19 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +19 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +19 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +19 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +19 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +19 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +19 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +19 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +19 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +19 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +19 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +19 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +19 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +19 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +19 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +19 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +19 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +19 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +19 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +19 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +19 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +19 -3
- snowflake/ml/modeling/impute/knn_imputer.py +19 -3
- snowflake/ml/modeling/impute/missing_indicator.py +19 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +19 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +19 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +19 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +19 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +19 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +19 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +19 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +19 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +19 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +19 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +19 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +19 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +19 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +19 -3
- snowflake/ml/modeling/linear_model/lars.py +19 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +19 -3
- snowflake/ml/modeling/linear_model/lasso.py +19 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +19 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +19 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +19 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +19 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +19 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +19 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +19 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +19 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +19 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +19 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +19 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +19 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +19 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +19 -3
- snowflake/ml/modeling/linear_model/perceptron.py +19 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +19 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +19 -3
- snowflake/ml/modeling/linear_model/ridge.py +19 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +19 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +19 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +19 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +19 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +19 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +19 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +19 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +19 -3
- snowflake/ml/modeling/manifold/isomap.py +19 -3
- snowflake/ml/modeling/manifold/mds.py +19 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +19 -3
- snowflake/ml/modeling/manifold/tsne.py +19 -3
- snowflake/ml/modeling/metrics/classification.py +5 -6
- snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
- snowflake/ml/modeling/metrics/ranking.py +7 -3
- snowflake/ml/modeling/metrics/regression.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +19 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +19 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +3 -13
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +3 -13
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +19 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +19 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +19 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +19 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +19 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +19 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +19 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +19 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +19 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +19 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +19 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +19 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +19 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +19 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +19 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +19 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +19 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +19 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +19 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +19 -3
- snowflake/ml/modeling/preprocessing/polynomial_features.py +19 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +19 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +19 -3
- snowflake/ml/modeling/svm/linear_svc.py +19 -3
- snowflake/ml/modeling/svm/linear_svr.py +19 -3
- snowflake/ml/modeling/svm/nu_svc.py +19 -3
- snowflake/ml/modeling/svm/nu_svr.py +19 -3
- snowflake/ml/modeling/svm/svc.py +19 -3
- snowflake/ml/modeling/svm/svr.py +19 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +19 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +19 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +19 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +19 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +19 -3
- snowflake/ml/modeling/xgboost/xgb_regressor.py +19 -3
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +19 -3
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +19 -3
- snowflake/ml/registry/registry.py +2 -0
- snowflake/ml/version.py +1 -1
- snowflake_ml_python-1.2.2.dist-info/LICENSE.txt +202 -0
- {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/METADATA +276 -50
- {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/RECORD +204 -197
- {snowflake_ml_python-1.2.0.dist-info → snowflake_ml_python-1.2.2.dist-info}/WHEEL +2 -1
- snowflake_ml_python-1.2.2.dist-info/top_level.txt +1 -0
- /snowflake/ml/modeling/_internal/{pandas_trainer.py → local_implementations/pandas_trainer.py} +0 -0
- /snowflake/ml/modeling/_internal/{snowpark_trainer.py → snowpark_implementations/snowpark_trainer.py} +0 -0
@@ -0,0 +1,459 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import json
|
4
|
+
import re
|
5
|
+
from collections import OrderedDict
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from enum import Enum
|
8
|
+
from typing import Dict, List, Optional
|
9
|
+
|
10
|
+
from snowflake.ml._internal.exceptions import (
|
11
|
+
error_codes,
|
12
|
+
exceptions as snowml_exceptions,
|
13
|
+
)
|
14
|
+
from snowflake.ml._internal.utils.identifier import concat_names
|
15
|
+
from snowflake.ml._internal.utils.sql_identifier import (
|
16
|
+
SqlIdentifier,
|
17
|
+
to_sql_identifiers,
|
18
|
+
)
|
19
|
+
from snowflake.ml.feature_store.entity import Entity
|
20
|
+
from snowflake.snowpark import DataFrame, Session
|
21
|
+
from snowflake.snowpark.types import (
|
22
|
+
DateType,
|
23
|
+
StructType,
|
24
|
+
TimestampType,
|
25
|
+
TimeType,
|
26
|
+
_NumericType,
|
27
|
+
)
|
28
|
+
|
29
|
+
# Separator inserted between a FeatureView name and its version when the physical
# object name is built (see FeatureView._get_physical_name); it is therefore not
# allowed inside a FeatureView name (see FeatureView._validate).
_FEATURE_VIEW_NAME_DELIMITER = "$"

# Reserved column name that a user-supplied timestamp_col is not allowed to use
# (see FeatureView._validate).
_TIMESTAMP_COL_PLACEHOLDER = "FS_TIMESTAMP_COL_PLACEHOLDER_VAL"

# Key added to serialized JSON payloads recording which feature object class
# produced them; from_json implementations check it before deserializing.
_FEATURE_OBJ_TYPE = "FEATURE_OBJ_TYPE"

# Characters permitted in a FeatureView version: ASCII letters, digits, underscore.
_FEATURE_VIEW_VERSION_RE = re.compile(r"^([A-Za-z0-9_]*)$")
|
33
|
+
|
34
|
+
|
35
|
+
class FeatureViewVersion(str):
    """Validated version label of a FeatureView, stored upper-cased.

    A version must be non-empty and consist only of ASCII letters, digits and underscores.
    """

    def __new__(cls, version: str) -> FeatureViewVersion:
        """Validate ``version`` and create the upper-cased version string.

        Args:
            version: raw version label supplied by the user.

        Returns:
            A FeatureViewVersion holding ``version.upper()``.

        Raises:
            SnowflakeMLException: (INVALID_ARGUMENT, wrapping a ValueError) if ``version`` is empty or
                contains characters other than letters, digits and underscore.
        """
        # `fullmatch` rather than `match`: with `match`, the `$` anchor also matches just before a
        # trailing newline, so inputs such as "v1\n" slipped through validation.
        # Empty input is rejected explicitly because the pattern's `*` would otherwise accept it,
        # producing a dangling "<name>$" physical name.
        if not version or not _FEATURE_VIEW_VERSION_RE.fullmatch(version):
            raise snowml_exceptions.SnowflakeMLException(
                error_code=error_codes.INVALID_ARGUMENT,
                original_exception=ValueError(
                    f"`{version}` is not a valid feature view version. Only letter, number and underscore is allowed."
                ),
            )
        return super().__new__(cls, version.upper())

    def __init__(self, version: str) -> None:
        # The value is fully set in __new__; accept `version` only to match the constructor signature.
        return super().__init__()
|
48
|
+
|
49
|
+
|
50
|
+
class FeatureViewStatus(Enum):
    """State of a FeatureView.

    Every member's value is the same string as its name.
    """

    # Default for a newly constructed FeatureView; physical_name() refuses to run in this state.
    DRAFT = "DRAFT"
    # FeatureView whose refresh_freq and warehouse may not be updated (see the property setters).
    STATIC = "STATIC"
    # NOTE(review): presumably a registered, actively refreshing FeatureView — confirm in feature_store.
    RUNNING = "RUNNING"
    # NOTE(review): presumably a registered FeatureView whose refresh is paused — confirm in feature_store.
    SUSPENDED = "SUSPENDED"
|
55
|
+
|
56
|
+
|
57
|
+
@dataclass(frozen=True)
class FeatureViewSlice:
    """Immutable selection of a subset of the features of a FeatureView.

    Instances are normally produced by ``FeatureView.slice``.
    """

    # FeatureView the selected features belong to.
    feature_view_ref: FeatureView
    # Names of the selected features.
    names: List[SqlIdentifier]

    def __repr__(self) -> str:
        states = (f"{k}={v}" for k, v in vars(self).items())
        return f"{type(self).__name__}({', '.join(states)})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, FeatureViewSlice):
            return False

        return self.names == other.names and self.feature_view_ref == other.feature_view_ref

    def to_json(self) -> str:
        """Serialize the slice, including its FeatureView, to a JSON string.

        Returns:
            JSON string tagged with this class name under the feature-object-type key.
        """
        fvs_dict = {
            "feature_view_ref": self.feature_view_ref.to_json(),
            "names": self.names,
            _FEATURE_OBJ_TYPE: self.__class__.__name__,
        }
        return json.dumps(fvs_dict)

    @classmethod
    def from_json(cls, json_str: str, session: Session) -> FeatureViewSlice:
        """Rebuild a FeatureViewSlice from the output of ``to_json``.

        Args:
            json_str: JSON produced by ``FeatureViewSlice.to_json``.
            session: Snowpark session used to rebuild the referenced FeatureView's DataFrame.

        Returns:
            The reconstructed FeatureViewSlice.

        Raises:
            ValueError: if ``json_str`` was not produced for this class.
        """
        json_dict = json.loads(json_str)
        if _FEATURE_OBJ_TYPE not in json_dict or json_dict[_FEATURE_OBJ_TYPE] != cls.__name__:
            raise ValueError(f"Invalid json str for {cls.__name__}: {json_str}")
        del json_dict[_FEATURE_OBJ_TYPE]
        json_dict["feature_view_ref"] = FeatureView.from_json(json_dict["feature_view_ref"], session)
        # JSON stores the names as plain strings; restore SqlIdentifier so a round-tripped slice holds
        # the same element type as one created by FeatureView.slice().
        json_dict["names"] = [SqlIdentifier(n) for n in json_dict["names"]]
        return cls(**json_dict)
|
88
|
+
|
89
|
+
|
90
|
+
class FeatureView:
    """
    A FeatureView instance encapsulates a logical group of features.
    """

    def __init__(
        self,
        name: str,
        entities: List[Entity],
        feature_df: DataFrame,
        timestamp_col: Optional[str] = None,
        refresh_freq: Optional[str] = None,
        desc: str = "",
    ) -> None:
        """
        Create a FeatureView instance.

        Args:
            name: name of the FeatureView. NOTE: FeatureView name will be capitalized.
            entities: entities that the FeatureView is associated with.
            feature_df: Snowpark DataFrame containing the data source and all feature transformation logic.
                Final projection of the DataFrame should contain feature names, join keys and timestamp(if applicable).
            timestamp_col: name of the timestamp column for point-in-time lookup when consuming the
                feature values.
            refresh_freq: Time unit defining how often the new feature data should be generated.
                Valid args are { <num> { seconds | minutes | hours | days } | DOWNSTREAM | <cron expr> <time zone>}.
                NOTE: Currently minimum refresh frequency is 1 minute.
                NOTE: If refresh_freq is in cron expression format, there must be a valid time zone as well.
                    E.g. * * * * * UTC
                NOTE: If refresh_freq is not provided, then FeatureView will be registered as View on Snowflake backend
                    and there won't be extra storage cost.
            desc: description of the FeatureView.

        Raises:
            ValueError: if feature_df does not consist of exactly one query, if the name contains the reserved
                name/version delimiter, if an entity join key or timestamp_col is missing from feature_df, or if
                timestamp_col uses the reserved placeholder name or an unsupported data type.
        """

        # NOTE: attribute assignment order is significant: _to_dict/to_df/__repr__ follow __dict__ order.
        self._name: SqlIdentifier = SqlIdentifier(name)
        self._entities: List[Entity] = entities
        self._feature_df: DataFrame = feature_df
        self._timestamp_col: Optional[SqlIdentifier] = (
            SqlIdentifier(timestamp_col) if timestamp_col is not None else None
        )
        self._desc: str = desc
        self._query: str = self._get_query()
        self._version: Optional[FeatureViewVersion] = None
        self._status: FeatureViewStatus = FeatureViewStatus.DRAFT
        # Every feature column starts with an empty description; see attach_feature_desc.
        self._feature_desc: OrderedDict[SqlIdentifier, str] = OrderedDict((f, "") for f in self._get_feature_names())
        self._refresh_freq: Optional[str] = refresh_freq
        self._database: Optional[SqlIdentifier] = None
        self._schema: Optional[SqlIdentifier] = None
        self._warehouse: Optional[SqlIdentifier] = None
        self._refresh_mode: Optional[str] = None
        self._refresh_mode_reason: Optional[str] = None
        self._validate()

    def slice(self, names: List[str]) -> FeatureViewSlice:
        """
        Select a subset of features within the FeatureView.

        Args:
            names: feature names to select.

        Returns:
            FeatureViewSlice instance containing selected features.

        Raises:
            ValueError: if selected feature names is not found in the FeatureView.
        """

        res = []
        for name in names:
            name = SqlIdentifier(name)
            if name not in self.feature_names:
                raise ValueError(f"Feature name {name} not found in FeatureView {self.name}.")
            res.append(name)
        return FeatureViewSlice(self, res)

    def physical_name(self) -> SqlIdentifier:
        """Returns the physical name for this FeatureView in Snowflake.

        Returns:
            Physical name string.

        Raises:
            RuntimeError: if the FeatureView is not materialized.
        """
        if self.status == FeatureViewStatus.DRAFT or self.version is None:
            raise RuntimeError(f"FeatureView {self.name} has not been materialized.")
        return FeatureView._get_physical_name(self.name, self.version)

    def fully_qualified_name(self) -> str:
        """Returns the fully qualified name (<database_name>.<schema_name>.<feature_view_name>) for the
            FeatureView in Snowflake.

        Returns:
            fully qualified name string.

        Raises:
            RuntimeError: if the FeatureView is not materialized (raised by physical_name).
        """
        return f"{self._database}.{self._schema}.{self.physical_name()}"

    def attach_feature_desc(self, descs: Dict[str, str]) -> FeatureView:
        """
        Associate feature level descriptions to the FeatureView.

        Args:
            descs: Dictionary contains feature name and corresponding descriptions.

        Returns:
            FeatureView with feature level desc attached.

        Raises:
            ValueError: if feature name is not found in the FeatureView.
        """
        for f, d in descs.items():
            f = SqlIdentifier(f)
            if f not in self._feature_desc:
                raise ValueError(
                    f"Feature name {f} is not found in FeatureView {self.name}, "
                    f"valid feature names are: {self.feature_names}"
                )
            self._feature_desc[f] = d
        return self

    @property
    def name(self) -> SqlIdentifier:
        return self._name

    @property
    def entities(self) -> List[Entity]:
        return self._entities

    @property
    def feature_df(self) -> DataFrame:
        return self._feature_df

    @property
    def timestamp_col(self) -> Optional[SqlIdentifier]:
        return self._timestamp_col

    @property
    def desc(self) -> str:
        return self._desc

    @property
    def query(self) -> str:
        return self._query

    @property
    def version(self) -> Optional[FeatureViewVersion]:
        return self._version

    @property
    def status(self) -> FeatureViewStatus:
        return self._status

    @property
    def feature_names(self) -> List[SqlIdentifier]:
        return list(self._feature_desc.keys())

    @property
    def feature_descs(self) -> Dict[SqlIdentifier, str]:
        return self._feature_desc

    @property
    def refresh_freq(self) -> Optional[str]:
        return self._refresh_freq

    @refresh_freq.setter
    def refresh_freq(self, new_value: str) -> None:
        if self.status == FeatureViewStatus.DRAFT or self.status == FeatureViewStatus.STATIC:
            raise RuntimeError(
                f"Feature view {self.name}/{self.version} must be registered and non-static to update refresh_freq."
            )
        self._refresh_freq = new_value

    @property
    def database(self) -> Optional[SqlIdentifier]:
        return self._database

    @property
    def schema(self) -> Optional[SqlIdentifier]:
        return self._schema

    @property
    def warehouse(self) -> Optional[SqlIdentifier]:
        return self._warehouse

    @warehouse.setter
    def warehouse(self, new_value: str) -> None:
        if self.status == FeatureViewStatus.DRAFT or self.status == FeatureViewStatus.STATIC:
            raise RuntimeError(
                f"Feature view {self.name}/{self.version} must be registered and non-static to update warehouse."
            )
        self._warehouse = SqlIdentifier(new_value)

    @property
    def output_schema(self) -> StructType:
        return self._feature_df.schema

    @property
    def refresh_mode(self) -> Optional[str]:
        return self._refresh_mode

    @property
    def refresh_mode_reason(self) -> Optional[str]:
        return self._refresh_mode_reason

    def _get_query(self) -> str:
        """Return the single SQL query backing feature_df.

        Raises:
            ValueError: if feature_df does not consist of exactly one query.
        """
        queries = self._feature_df.queries["queries"]
        if len(queries) != 1:
            raise ValueError(
                f"""feature_df dataframe must contain only 1 query.
                Got {len(queries)}: {queries}
                """
            )
        return str(queries[0])

    def _validate(self) -> None:
        """Check the name, entity join keys and timestamp column against feature_df.

        Raises:
            ValueError: on any of the problems listed in ``__init__``'s Raises section (except the query count).
        """
        if _FEATURE_VIEW_NAME_DELIMITER in self._name:
            raise ValueError(
                f"FeatureView name `{self._name}` contains invalid character `{_FEATURE_VIEW_NAME_DELIMITER}`."
            )

        # Computed once and reused for both the join-key and the timestamp-column checks.
        unescaped_df_cols = to_sql_identifiers(self._feature_df.columns)
        for e in self._entities:
            for k in e.join_keys:
                if k not in unescaped_df_cols:
                    raise ValueError(
                        f"join_key {k} in Entity {e.name} is not found in input dataframe: {unescaped_df_cols}"
                    )

        if self._timestamp_col is not None:
            ts_col = self._timestamp_col
            if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER):
                raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.")
            if ts_col not in unescaped_df_cols:
                raise ValueError(f"timestamp_col {ts_col} is not found in input dataframe.")

            col_type = self._feature_df.schema[ts_col].datatype
            if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
                raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")

    def _get_feature_names(self) -> List[SqlIdentifier]:
        """Return the feature_df columns that are neither entity join keys nor the timestamp column."""
        join_keys = [k for e in self._entities for k in e.join_keys]
        ts_col = [self._timestamp_col] if self._timestamp_col is not None else []
        # Built once instead of re-concatenating the lists for every column.
        excluded = join_keys + ts_col
        feature_names = to_sql_identifiers(self._feature_df.columns, case_sensitive=True)
        return [c for c in feature_names if c not in excluded]

    def __repr__(self) -> str:
        states = (f"{k}={v}" for k, v in vars(self).items())
        return f"{type(self).__name__}({', '.join(states)})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, FeatureView):
            return False

        return (
            self.name == other.name
            and self.version == other.version
            and self.timestamp_col == other.timestamp_col
            and self.entities == other.entities
            and self.desc == other.desc
            and self.feature_descs == other.feature_descs
            and self.feature_names == other.feature_names
            and self.query == other.query
            and self.refresh_freq == other.refresh_freq
            # Compared as strings so a status stored as its string form still compares equal.
            and str(self.status) == str(other.status)
            and self.database == other.database
            # Previously missing: FeatureViews in different schemas must not compare equal.
            and self.schema == other.schema
            and self.warehouse == other.warehouse
            and self.refresh_mode == other.refresh_mode
            and self.refresh_mode_reason == other.refresh_mode_reason
        )

    def _to_dict(self) -> Dict[str, str]:
        """Return the instance state with every value converted to a JSON-friendly form.

        Keys are the private attribute names (e.g. ``_name``); ``_feature_df`` is dropped, since it is
        represented by ``_query``. The status is stored as ``str(FeatureViewStatus.X)``.
        """
        fv_dict = self.__dict__.copy()
        if "_feature_df" in fv_dict:
            fv_dict.pop("_feature_df")
        fv_dict["_entities"] = [e._to_dict() for e in self._entities]
        fv_dict["_status"] = str(self._status)
        fv_dict["_name"] = str(self._name) if self._name is not None else None
        fv_dict["_version"] = str(self._version) if self._version is not None else None
        fv_dict["_database"] = str(self._database) if self._database is not None else None
        fv_dict["_schema"] = str(self._schema) if self._schema is not None else None
        fv_dict["_warehouse"] = str(self._warehouse) if self._warehouse is not None else None
        fv_dict["_timestamp_col"] = str(self._timestamp_col) if self._timestamp_col is not None else None

        feature_desc_dict = {}
        for k, v in self._feature_desc.items():
            feature_desc_dict[k.identifier()] = v
        fv_dict["_feature_desc"] = feature_desc_dict

        return fv_dict

    def to_df(self, session: Session) -> DataFrame:
        """Return a one-row DataFrame describing this FeatureView.

        Column names are the ``_to_dict`` keys without their leading underscores, plus ``physical_name``.

        Raises:
            RuntimeError: if the FeatureView is not materialized (raised by physical_name).
        """
        # Build the dict once so values and column names come from the same snapshot.
        fv_dict = self._to_dict()
        values = list(fv_dict.values())
        schema = [x.lstrip("_") for x in fv_dict.keys()]
        values.append(str(self.physical_name()))
        schema.append("physical_name")
        return session.create_dataframe([values], schema=schema)

    def to_json(self) -> str:
        """Serialize this FeatureView to a JSON string tagged with its class name."""
        state_dict = self._to_dict()
        state_dict[_FEATURE_OBJ_TYPE] = self.__class__.__name__
        return json.dumps(state_dict)

    @classmethod
    def from_json(cls, json_str: str, session: Session) -> FeatureView:
        """Rebuild a FeatureView from the output of ``to_json``.

        Args:
            json_str: JSON produced by ``FeatureView.to_json``.
            session: Snowpark session used to recreate feature_df from the stored query.

        Returns:
            The reconstructed FeatureView.

        Raises:
            ValueError: if ``json_str`` was not produced for this class or holds an unknown status.
        """
        json_dict = json.loads(json_str)
        if _FEATURE_OBJ_TYPE not in json_dict or json_dict[_FEATURE_OBJ_TYPE] != cls.__name__:
            raise ValueError(f"Invalid json str for {cls.__name__}: {json_str}")

        # _to_dict stores the status as str(FeatureViewStatus.X), i.e. "FeatureViewStatus.X".
        # Map it back to the enum member: the raw string would make every
        # `status == FeatureViewStatus...` comparison (physical_name, property setters) silently fail.
        status_name = str(json_dict["_status"]).rsplit(".", 1)[-1]
        if status_name not in FeatureViewStatus.__members__:
            raise ValueError(f"Invalid json str for {cls.__name__}: {json_str}")

        return FeatureView._construct_feature_view(
            name=json_dict["_name"],
            entities=[Entity(**e) for e in json_dict["_entities"]],
            feature_df=session.sql(json_dict["_query"]),
            timestamp_col=json_dict["_timestamp_col"],
            desc=json_dict["_desc"],
            version=json_dict["_version"],
            status=FeatureViewStatus[status_name],
            feature_descs=json_dict["_feature_desc"],
            refresh_freq=json_dict["_refresh_freq"],
            database=json_dict["_database"],
            schema=json_dict["_schema"],
            warehouse=json_dict["_warehouse"],
            refresh_mode=json_dict["_refresh_mode"],
            refresh_mode_reason=json_dict["_refresh_mode_reason"],
        )

    @staticmethod
    def _get_physical_name(fv_name: SqlIdentifier, fv_version: FeatureViewVersion) -> SqlIdentifier:
        """Combine name and version, separated by the name/version delimiter, into one identifier."""
        return SqlIdentifier(
            concat_names(
                [
                    str(fv_name),
                    _FEATURE_VIEW_NAME_DELIMITER,
                    str(fv_version),
                ]
            )
        )

    @staticmethod
    def _construct_feature_view(
        name: str,
        entities: List[Entity],
        feature_df: DataFrame,
        timestamp_col: Optional[str],
        desc: str,
        version: Optional[str],
        status: FeatureViewStatus,
        feature_descs: Dict[str, str],
        refresh_freq: Optional[str],
        database: Optional[str],
        schema: Optional[str],
        warehouse: Optional[str],
        refresh_mode: Optional[str],
        refresh_mode_reason: Optional[str],
    ) -> FeatureView:
        """Build a FeatureView and then overwrite the state that the public constructor does not accept."""
        fv = FeatureView(
            name=name,
            entities=entities,
            feature_df=feature_df,
            timestamp_col=timestamp_col,
            desc=desc,
        )
        fv._version = FeatureViewVersion(version) if version is not None else None
        fv._status = status
        fv._refresh_freq = refresh_freq
        fv._database = SqlIdentifier(database) if database is not None else None
        fv._schema = SqlIdentifier(schema) if schema is not None else None
        fv._warehouse = SqlIdentifier(warehouse) if warehouse is not None else None
        fv._refresh_mode = refresh_mode
        fv._refresh_mode_reason = refresh_mode_reason
        fv.attach_feature_desc(feature_descs)
        return fv
|
@@ -4,9 +4,8 @@ import tempfile
|
|
4
4
|
from typing import Any, Dict, List, Optional, Union, cast
|
5
5
|
|
6
6
|
import yaml
|
7
|
-
from packaging import version
|
8
7
|
|
9
|
-
from snowflake.ml._internal.utils import identifier,
|
8
|
+
from snowflake.ml._internal.utils import identifier, sql_identifier
|
10
9
|
from snowflake.ml.model import model_signature, type_hints
|
11
10
|
from snowflake.ml.model._client.ops import metadata_ops
|
12
11
|
from snowflake.ml.model._client.sql import (
|
@@ -25,8 +24,6 @@ from snowflake.ml.model._signatures import snowpark_handler
|
|
25
24
|
from snowflake.snowpark import dataframe, row, session
|
26
25
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
27
26
|
|
28
|
-
_TAG_ON_MODEL_AVAILABLE_VERSION = version.parse("8.2.0")
|
29
|
-
|
30
27
|
|
31
28
|
class ModelOperator:
|
32
29
|
def __init__(
|
@@ -296,21 +293,14 @@ class ModelOperator:
|
|
296
293
|
tag_value: str,
|
297
294
|
statement_params: Optional[Dict[str, Any]] = None,
|
298
295
|
) -> None:
|
299
|
-
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
statement_params=statement_params,
|
308
|
-
)
|
309
|
-
else:
|
310
|
-
raise NotImplementedError(
|
311
|
-
f"`set_tag` won't work before Snowflake version {_TAG_ON_MODEL_AVAILABLE_VERSION},"
|
312
|
-
f" currently is {sf_version}"
|
313
|
-
)
|
296
|
+
self._tag_client.set_tag_on_model(
|
297
|
+
model_name=model_name,
|
298
|
+
tag_database_name=tag_database_name,
|
299
|
+
tag_schema_name=tag_schema_name,
|
300
|
+
tag_name=tag_name,
|
301
|
+
tag_value=tag_value,
|
302
|
+
statement_params=statement_params,
|
303
|
+
)
|
314
304
|
|
315
305
|
def unset_tag(
|
316
306
|
self,
|
@@ -321,20 +311,13 @@ class ModelOperator:
|
|
321
311
|
tag_name: sql_identifier.SqlIdentifier,
|
322
312
|
statement_params: Optional[Dict[str, Any]] = None,
|
323
313
|
) -> None:
|
324
|
-
|
325
|
-
|
326
|
-
|
327
|
-
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
statement_params=statement_params,
|
332
|
-
)
|
333
|
-
else:
|
334
|
-
raise NotImplementedError(
|
335
|
-
f"`unset_tag` won't work before Snowflake version {_TAG_ON_MODEL_AVAILABLE_VERSION},"
|
336
|
-
f" currently is {sf_version}"
|
337
|
-
)
|
314
|
+
self._tag_client.unset_tag_on_model(
|
315
|
+
model_name=model_name,
|
316
|
+
tag_database_name=tag_database_name,
|
317
|
+
tag_schema_name=tag_schema_name,
|
318
|
+
tag_name=tag_name,
|
319
|
+
statement_params=statement_params,
|
320
|
+
)
|
338
321
|
|
339
322
|
def get_model_version_manifest(
|
340
323
|
self,
|
@@ -382,11 +365,6 @@ class ModelOperator:
|
|
382
365
|
version_name: sql_identifier.SqlIdentifier,
|
383
366
|
statement_params: Optional[Dict[str, Any]] = None,
|
384
367
|
) -> model_manifest_schema.SnowparkMLDataDict:
|
385
|
-
if (
|
386
|
-
snowflake_env.get_current_snowflake_version(self._session)
|
387
|
-
< model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION
|
388
|
-
):
|
389
|
-
raise NotImplementedError("User_data has not been supported yet.")
|
390
368
|
raw_user_data_json_string = self._model_client.show_versions(
|
391
369
|
model_name=model_name,
|
392
370
|
version_name=version_name,
|
@@ -3,10 +3,8 @@ from typing import Any, Dict, List, Optional
|
|
3
3
|
from snowflake.ml._internal.utils import (
|
4
4
|
identifier,
|
5
5
|
query_result_checker,
|
6
|
-
snowflake_env,
|
7
6
|
sql_identifier,
|
8
7
|
)
|
9
|
-
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
10
8
|
from snowflake.snowpark import row, session
|
11
9
|
|
12
10
|
|
@@ -89,12 +87,8 @@ class ModelSQLClient:
|
|
89
87
|
.has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
|
90
88
|
.has_column(ModelSQLClient.MODEL_VERSION_COMMENT_COL_NAME, allow_empty=True)
|
91
89
|
.has_column(ModelSQLClient.MODEL_VERSION_METADATA_COL_NAME, allow_empty=True)
|
90
|
+
.has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True)
|
92
91
|
)
|
93
|
-
if (
|
94
|
-
snowflake_env.get_current_snowflake_version(self._session)
|
95
|
-
>= model_manifest_schema.MANIFEST_USER_DATA_ENABLE_VERSION
|
96
|
-
):
|
97
|
-
res = res.has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True)
|
98
92
|
if validate_result and version_name:
|
99
93
|
res = res.has_dimensions(expected_rows=1)
|
100
94
|
|
@@ -146,24 +146,29 @@ class ModelVersionSQLClient:
|
|
146
146
|
returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
147
147
|
statement_params: Optional[Dict[str, Any]] = None,
|
148
148
|
) -> dataframe.DataFrame:
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
149
|
+
with_statements = []
|
150
|
+
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
151
|
+
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
152
|
+
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
153
|
+
else:
|
154
|
+
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
155
|
+
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
156
|
+
self._database_name.identifier(),
|
157
|
+
self._schema_name.identifier(),
|
158
|
+
tmp_table_name,
|
159
|
+
)
|
160
|
+
input_df.write.save_as_table( # type: ignore[call-overload]
|
161
|
+
table_name=INTERMEDIATE_TABLE_NAME,
|
162
|
+
mode="errorifexists",
|
163
|
+
table_type="temporary",
|
164
|
+
statement_params=statement_params,
|
165
|
+
)
|
161
166
|
|
162
167
|
INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
|
163
168
|
|
164
169
|
module_version_alias = "MODEL_VERSION_ALIAS"
|
165
|
-
|
166
|
-
f"
|
170
|
+
with_statements.append(
|
171
|
+
f"{module_version_alias} AS "
|
167
172
|
f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
|
168
173
|
)
|
169
174
|
|
@@ -174,7 +179,7 @@ class ModelVersionSQLClient:
|
|
174
179
|
args_sql = ", ".join(args_sql_list)
|
175
180
|
|
176
181
|
sql = textwrap.dedent(
|
177
|
-
f"""{
|
182
|
+
f"""WITH {','.join(with_statements)}
|
178
183
|
SELECT *,
|
179
184
|
{module_version_alias}!{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
|
180
185
|
FROM {INTERMEDIATE_TABLE_NAME}"""
|
@@ -2,6 +2,7 @@ import logging
|
|
2
2
|
import os
|
3
3
|
import posixpath
|
4
4
|
from string import Template
|
5
|
+
from typing import List
|
5
6
|
|
6
7
|
import importlib_resources
|
7
8
|
|
@@ -36,6 +37,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
36
37
|
session: snowpark.Session,
|
37
38
|
artifact_stage_location: str,
|
38
39
|
compute_pool: str,
|
40
|
+
external_access_integrations: List[str],
|
39
41
|
) -> None:
|
40
42
|
"""Initialization
|
41
43
|
|
@@ -47,6 +49,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
47
49
|
artifact_stage_location: Spec file and future deployment related artifacts will be stored under
|
48
50
|
{stage}/models/{model_id}
|
49
51
|
compute_pool: The compute pool used to run docker image build workload.
|
52
|
+
external_access_integrations: EAIs for network connection.
|
50
53
|
"""
|
51
54
|
self.context_dir = context_dir
|
52
55
|
self.image_repo = image_repo
|
@@ -54,6 +57,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
54
57
|
self.session = session
|
55
58
|
self.artifact_stage_location = artifact_stage_location
|
56
59
|
self.compute_pool = compute_pool
|
60
|
+
self.external_access_integrations = external_access_integrations
|
57
61
|
self.client = snowservice_client.SnowServiceClient(session)
|
58
62
|
|
59
63
|
assert artifact_stage_location.startswith(
|
@@ -202,4 +206,8 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
202
206
|
|
203
207
|
def _launch_kaniko_job(self, spec_stage_location: str) -> None:
|
204
208
|
logger.debug("Submitting job for building docker image with kaniko")
|
205
|
-
self.client.create_job(
|
209
|
+
self.client.create_job(
|
210
|
+
compute_pool=self.compute_pool,
|
211
|
+
spec_stage_location=spec_stage_location,
|
212
|
+
external_access_integrations=self.external_access_integrations,
|
213
|
+
)
|
@@ -465,6 +465,7 @@ class SnowServiceDeployment:
|
|
465
465
|
session=self.session,
|
466
466
|
artifact_stage_location=self._model_artifact_stage_location,
|
467
467
|
compute_pool=self.options.compute_pool,
|
468
|
+
external_access_integrations=self.options.external_access_integrations,
|
468
469
|
)
|
469
470
|
else:
|
470
471
|
image_builder = client_image_builder.ClientImageBuilder(
|
@@ -587,6 +588,7 @@ class SnowServiceDeployment:
|
|
587
588
|
spec_stage_location=spec_stage_location,
|
588
589
|
min_instances=self.options.min_instances,
|
589
590
|
max_instances=self.options.max_instances,
|
591
|
+
external_access_integrations=self.options.external_access_integrations,
|
590
592
|
)
|
591
593
|
logger.info(f"Wait for service {self._service_name} to become ready...")
|
592
594
|
client.block_until_resource_is_ready(
|