snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/sql_identifier.py +25 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +19 -2
- snowflake/ml/feature_store/feature_view.py +82 -28
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +78 -9
- snowflake/ml/model/_client/ops/model_ops.py +89 -7
- snowflake/ml/model/_client/ops/service_ops.py +200 -91
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +47 -13
- snowflake/ml/model/_model_composer/model_composer.py +11 -41
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
- snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
- snowflake/ml/model/_packager/model_packager.py +14 -10
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/type_hints.py +11 -152
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
- snowflake/ml/modeling/cluster/birch.py +1 -0
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
- snowflake/ml/modeling/cluster/dbscan.py +1 -0
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
- snowflake/ml/modeling/cluster/k_means.py +1 -0
- snowflake/ml/modeling/cluster/mean_shift.py +1 -0
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
- snowflake/ml/modeling/cluster/optics.py +1 -0
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
- snowflake/ml/modeling/compose/column_transformer.py +1 -0
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
- snowflake/ml/modeling/covariance/oas.py +1 -0
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/pca.py +1 -0
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
- snowflake/ml/modeling/impute/knn_imputer.py +1 -0
- snowflake/ml/modeling/impute/missing_indicator.py +1 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/lars.py +1 -0
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/perceptron.py +1 -0
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/ridge.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
- snowflake/ml/modeling/manifold/isomap.py +1 -0
- snowflake/ml/modeling/manifold/mds.py +1 -0
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
- snowflake/ml/modeling/manifold/tsne.py +1 -0
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
- snowflake/ml/modeling/pipeline/pipeline.py +0 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
- snowflake/ml/modeling/svm/linear_svc.py +1 -0
- snowflake/ml/modeling/svm/linear_svr.py +1 -0
- snowflake/ml/modeling/svm/nu_svc.py +1 -0
- snowflake/ml/modeling/svm/nu_svr.py +1 -0
- snowflake/ml/modeling/svm/svc.py +1 -0
- snowflake/ml/modeling/svm/svr.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -4
- snowflake/ml/registry/registry.py +165 -6
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +24 -9
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/RECORD +225 -249
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -106
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -22,6 +22,7 @@ from snowflake.ml.feature_store import feature_store
|
|
22
22
|
from snowflake.ml.feature_store.entity import Entity
|
23
23
|
from snowflake.ml.lineage import lineage_node
|
24
24
|
from snowflake.snowpark import DataFrame, Session
|
25
|
+
from snowflake.snowpark.exceptions import SnowparkSQLException
|
25
26
|
from snowflake.snowpark.types import (
|
26
27
|
DateType,
|
27
28
|
StructType,
|
@@ -167,6 +168,7 @@ class FeatureView(lineage_node.LineageNode):
|
|
167
168
|
refresh_freq: Optional[str] = None,
|
168
169
|
desc: str = "",
|
169
170
|
warehouse: Optional[str] = None,
|
171
|
+
initialize: str = "ON_CREATE",
|
170
172
|
**_kwargs: Any,
|
171
173
|
) -> None:
|
172
174
|
"""
|
@@ -190,6 +192,10 @@ class FeatureView(lineage_node.LineageNode):
|
|
190
192
|
warehouse: warehouse to refresh feature view. Not needed for static feature view (refresh_freq is None).
|
191
193
|
For managed feature view, this warehouse will overwrite the default warehouse of Feature Store if it is
|
192
194
|
specified, otherwise the default warehouse will be used.
|
195
|
+
initialize: Specifies the behavior of the initial refresh of feature view. This property cannot be altered
|
196
|
+
after you register the feature view. It supports ON_CREATE (default) or ON_SCHEDULE. ON_CREATE refreshes
|
197
|
+
the feature view synchronously at creation. ON_SCHEDULE refreshes the feature view at the next scheduled
|
198
|
+
refresh. It is only effective when refresh_freq is not None.
|
193
199
|
_kwargs: reserved kwargs for system generated args. NOTE: DO NOT USE.
|
194
200
|
|
195
201
|
Example::
|
@@ -227,10 +233,14 @@ class FeatureView(lineage_node.LineageNode):
|
|
227
233
|
self._query: str = self._get_query()
|
228
234
|
self._version: Optional[FeatureViewVersion] = None
|
229
235
|
self._status: FeatureViewStatus = FeatureViewStatus.DRAFT
|
230
|
-
|
236
|
+
feature_names = self._get_feature_names()
|
237
|
+
self._feature_desc: Optional[OrderedDict[SqlIdentifier, str]] = (
|
238
|
+
OrderedDict((f, "") for f in feature_names) if feature_names is not None else None
|
239
|
+
)
|
231
240
|
self._refresh_freq: Optional[str] = refresh_freq
|
232
241
|
self._database: Optional[SqlIdentifier] = None
|
233
242
|
self._schema: Optional[SqlIdentifier] = None
|
243
|
+
self._initialize: str = initialize
|
234
244
|
self._warehouse: Optional[SqlIdentifier] = SqlIdentifier(warehouse) if warehouse is not None else None
|
235
245
|
self._refresh_mode: Optional[str] = _kwargs.get("refresh_mode", "AUTO")
|
236
246
|
self._refresh_mode_reason: Optional[str] = None
|
@@ -345,6 +355,15 @@ class FeatureView(lineage_node.LineageNode):
|
|
345
355
|
('START_STATION_LATITUDE', 'Latitude of the start station.')])
|
346
356
|
|
347
357
|
"""
|
358
|
+
if self._feature_desc is None:
|
359
|
+
warnings.warn(
|
360
|
+
"Failed to read feature view schema. Probably feature view is not refreshed yet. "
|
361
|
+
"Schema will be available after initial refresh.",
|
362
|
+
stacklevel=2,
|
363
|
+
category=UserWarning,
|
364
|
+
)
|
365
|
+
return self
|
366
|
+
|
348
367
|
for f, d in descs.items():
|
349
368
|
f = SqlIdentifier(f)
|
350
369
|
if f not in self._feature_desc:
|
@@ -424,10 +443,10 @@ class FeatureView(lineage_node.LineageNode):
|
|
424
443
|
|
425
444
|
@property
|
426
445
|
def feature_names(self) -> List[SqlIdentifier]:
|
427
|
-
return list(self._feature_desc.keys())
|
446
|
+
return list(self._feature_desc.keys()) if self._feature_desc is not None else []
|
428
447
|
|
429
448
|
@property
|
430
|
-
def feature_descs(self) -> Dict[SqlIdentifier, str]:
|
449
|
+
def feature_descs(self) -> Optional[Dict[SqlIdentifier, str]]:
|
431
450
|
return self._feature_desc
|
432
451
|
|
433
452
|
def list_columns(self) -> DataFrame:
|
@@ -463,7 +482,17 @@ class FeatureView(lineage_node.LineageNode):
|
|
463
482
|
|
464
483
|
"""
|
465
484
|
session = self._feature_df.session
|
466
|
-
rows = []
|
485
|
+
rows = [] # type: ignore[var-annotated]
|
486
|
+
|
487
|
+
if self.feature_descs is None:
|
488
|
+
warnings.warn(
|
489
|
+
"Failed to read feature view schema. Probably feature view is not refreshed yet. "
|
490
|
+
"Schema will be available after initial refresh.",
|
491
|
+
stacklevel=2,
|
492
|
+
category=UserWarning,
|
493
|
+
)
|
494
|
+
return session.create_dataframe(rows, schema=["name", "category", "dtype", "desc"])
|
495
|
+
|
467
496
|
for name, type in self._feature_df.dtypes:
|
468
497
|
if SqlIdentifier(name) in self.feature_descs:
|
469
498
|
desc = self.feature_descs[SqlIdentifier(name)]
|
@@ -565,6 +594,10 @@ class FeatureView(lineage_node.LineageNode):
|
|
565
594
|
)
|
566
595
|
self._warehouse = SqlIdentifier(new_value)
|
567
596
|
|
597
|
+
@property
|
598
|
+
def initialize(self) -> str:
|
599
|
+
return self._initialize
|
600
|
+
|
568
601
|
@property
|
569
602
|
def output_schema(self) -> StructType:
|
570
603
|
return self._infer_schema_df.schema
|
@@ -601,33 +634,49 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
601
634
|
f"FeatureView name `{self._name}` contains invalid character `{_FEATURE_VIEW_NAME_DELIMITER}`."
|
602
635
|
)
|
603
636
|
|
604
|
-
|
605
|
-
|
606
|
-
for
|
607
|
-
|
608
|
-
|
609
|
-
f"join_key {k} in Entity {e.name} is not found in input dataframe: {
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
|
614
|
-
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
619
|
-
|
620
|
-
|
621
|
-
raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
|
637
|
+
df_cols = self._get_column_names()
|
638
|
+
if df_cols is not None:
|
639
|
+
for e in self._entities:
|
640
|
+
for k in e.join_keys:
|
641
|
+
if k not in df_cols:
|
642
|
+
raise ValueError(f"join_key {k} in Entity {e.name} is not found in input dataframe: {df_cols}")
|
643
|
+
|
644
|
+
if self._timestamp_col is not None:
|
645
|
+
ts_col = self._timestamp_col
|
646
|
+
if ts_col == SqlIdentifier(_TIMESTAMP_COL_PLACEHOLDER):
|
647
|
+
raise ValueError(f"Invalid timestamp_col name, cannot be {_TIMESTAMP_COL_PLACEHOLDER}.")
|
648
|
+
if ts_col not in df_cols:
|
649
|
+
raise ValueError(f"timestamp_col {ts_col} is not found in input dataframe.")
|
650
|
+
|
651
|
+
col_type = self._infer_schema_df.schema[ts_col].datatype
|
652
|
+
if not isinstance(col_type, (DateType, TimeType, TimestampType, _NumericType)):
|
653
|
+
raise ValueError(f"Invalid data type for timestamp_col {ts_col}: {col_type}.")
|
622
654
|
|
623
655
|
if re.match(_RESULT_SCAN_QUERY_PATTERN, self._query) is not None:
|
624
656
|
raise ValueError(f"feature_df should not be reading from RESULT_SCAN. Invalid query: {self._query}")
|
625
657
|
|
626
|
-
|
658
|
+
if self._initialize not in ["ON_CREATE", "ON_SCHEDULE"]:
|
659
|
+
raise ValueError("'initialize' only supports ON_CREATE or ON_SCHEDULE.")
|
660
|
+
|
661
|
+
def _get_column_names(self) -> Optional[List[SqlIdentifier]]:
|
662
|
+
try:
|
663
|
+
return to_sql_identifiers(self._infer_schema_df.columns)
|
664
|
+
except SnowparkSQLException as e:
|
665
|
+
warnings.warn(
|
666
|
+
"Failed to read feature view schema. Probably feature view is not refreshed yet. "
|
667
|
+
f"Schema will be available after initial refresh. Original exception: {e}",
|
668
|
+
stacklevel=2,
|
669
|
+
category=UserWarning,
|
670
|
+
)
|
671
|
+
return None
|
672
|
+
|
673
|
+
def _get_feature_names(self) -> Optional[List[SqlIdentifier]]:
|
627
674
|
join_keys = [k for e in self._entities for k in e.join_keys]
|
628
675
|
ts_col = [self._timestamp_col] if self._timestamp_col is not None else []
|
629
|
-
feature_names =
|
630
|
-
|
676
|
+
feature_names = self._get_column_names()
|
677
|
+
if feature_names is not None:
|
678
|
+
return [c for c in feature_names if c not in join_keys + ts_col]
|
679
|
+
return None
|
631
680
|
|
632
681
|
def __repr__(self) -> str:
|
633
682
|
states = (f"{k}={v}" for k, v in vars(self).items())
|
@@ -670,11 +719,13 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
670
719
|
fv_dict["_schema"] = str(self._schema) if self._schema is not None else None
|
671
720
|
fv_dict["_warehouse"] = str(self._warehouse) if self._warehouse is not None else None
|
672
721
|
fv_dict["_timestamp_col"] = str(self._timestamp_col) if self._timestamp_col is not None else None
|
722
|
+
fv_dict["_initialize"] = str(self._initialize)
|
673
723
|
|
674
724
|
feature_desc_dict = {}
|
675
|
-
|
676
|
-
|
677
|
-
|
725
|
+
if self._feature_desc is not None:
|
726
|
+
for k, v in self._feature_desc.items():
|
727
|
+
feature_desc_dict[k.identifier()] = v
|
728
|
+
fv_dict["_feature_desc"] = feature_desc_dict
|
678
729
|
|
679
730
|
lineage_node_keys = [key for key in fv_dict if key.startswith("_node") or key == "_session"]
|
680
731
|
|
@@ -760,6 +811,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
760
811
|
warehouse=json_dict["_warehouse"],
|
761
812
|
refresh_mode=json_dict["_refresh_mode"],
|
762
813
|
refresh_mode_reason=json_dict["_refresh_mode_reason"],
|
814
|
+
initialize=json_dict["_initialize"],
|
763
815
|
owner=json_dict["_owner"],
|
764
816
|
infer_schema_df=session.sql(json_dict.get("_infer_schema_query", None)),
|
765
817
|
session=session,
|
@@ -830,6 +882,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
830
882
|
warehouse: Optional[str],
|
831
883
|
refresh_mode: Optional[str],
|
832
884
|
refresh_mode_reason: Optional[str],
|
885
|
+
initialize: str,
|
833
886
|
owner: Optional[str],
|
834
887
|
infer_schema_df: Optional[DataFrame],
|
835
888
|
session: Session,
|
@@ -850,6 +903,7 @@ Got {len(self._feature_df.queries['queries'])}: {self._feature_df.queries['queri
|
|
850
903
|
fv._warehouse = SqlIdentifier(warehouse) if warehouse is not None else None
|
851
904
|
fv._refresh_mode = refresh_mode
|
852
905
|
fv._refresh_mode_reason = refresh_mode_reason
|
906
|
+
fv._initialize = initialize
|
853
907
|
fv._owner = owner
|
854
908
|
fv.attach_feature_desc(feature_descs)
|
855
909
|
|
snowflake/ml/fileset/stage_fs.py
CHANGED
@@ -431,4 +431,5 @@ def _resolve_async_job(async_job: snowpark.AsyncJob) -> List[snowpark.Row]:
|
|
431
431
|
error_code=error_codes.SNOWML_NOT_FOUND,
|
432
432
|
original_exception=fileset_errors.StageNotFoundError("Query failed."),
|
433
433
|
) from e
|
434
|
-
|
434
|
+
assert e.msg is not None
|
435
|
+
raise snowpark_exceptions.SnowparkSQLException(e.msg, conn_error=e) from e
|
@@ -118,16 +118,21 @@ class LineageNode:
|
|
118
118
|
)
|
119
119
|
domain = lineage_object["domain"].lower()
|
120
120
|
if domain_filter is None or domain in domain_filter:
|
121
|
+
obj_name = ".".join(
|
122
|
+
identifier.rename_to_valid_snowflake_identifier(s)
|
123
|
+
for s in identifier.parse_schema_level_object_identifier(lineage_object["name"])
|
124
|
+
if s is not None
|
125
|
+
)
|
121
126
|
if domain in DOMAIN_LINEAGE_REGISTRY and lineage_object["status"] == "ACTIVE":
|
122
127
|
lineage_nodes.append(
|
123
128
|
DOMAIN_LINEAGE_REGISTRY[domain]._load_from_lineage_node(
|
124
|
-
self._session,
|
129
|
+
self._session, obj_name, lineage_object.get("version")
|
125
130
|
)
|
126
131
|
)
|
127
132
|
else:
|
128
133
|
lineage_nodes.append(
|
129
134
|
LineageNode(
|
130
|
-
name=
|
135
|
+
name=obj_name,
|
131
136
|
version=lineage_object.get("version"),
|
132
137
|
domain=domain,
|
133
138
|
status=lineage_object["status"],
|
snowflake/ml/model/__init__.py
CHANGED
@@ -1,6 +1,5 @@
|
|
1
1
|
from snowflake.ml.model._client.model.model_impl import Model
|
2
2
|
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
3
3
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
4
|
-
from snowflake.ml.model.models.llm import LLM, LLMOptions
|
5
4
|
|
6
|
-
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"
|
5
|
+
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"]
|
@@ -310,12 +310,12 @@ class ModelVersion(lineage_node.LineageNode):
|
|
310
310
|
project=_TELEMETRY_PROJECT,
|
311
311
|
subproject=_TELEMETRY_SUBPROJECT,
|
312
312
|
)
|
313
|
-
def
|
313
|
+
def get_model_task(self) -> model_types.Task:
|
314
314
|
statement_params = telemetry.get_statement_params(
|
315
315
|
project=_TELEMETRY_PROJECT,
|
316
316
|
subproject=_TELEMETRY_SUBPROJECT,
|
317
317
|
)
|
318
|
-
return self._model_ops.
|
318
|
+
return self._model_ops.get_model_task(
|
319
319
|
database_name=None,
|
320
320
|
schema_name=None,
|
321
321
|
model_name=self._model_name,
|
@@ -423,6 +423,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
423
423
|
partition_column = sql_identifier.SqlIdentifier(partition_column)
|
424
424
|
|
425
425
|
functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions
|
426
|
+
|
426
427
|
if function_name:
|
427
428
|
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
428
429
|
find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
|
@@ -625,6 +626,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
625
626
|
"image_repo",
|
626
627
|
"gpu_requests",
|
627
628
|
"num_workers",
|
629
|
+
"max_batch_rows",
|
628
630
|
],
|
629
631
|
)
|
630
632
|
def create_service(
|
@@ -638,6 +640,7 @@ class ModelVersion(lineage_node.LineageNode):
|
|
638
640
|
max_instances: int = 1,
|
639
641
|
gpu_requests: Optional[str] = None,
|
640
642
|
num_workers: Optional[int] = None,
|
643
|
+
max_batch_rows: Optional[int] = None,
|
641
644
|
force_rebuild: bool = False,
|
642
645
|
build_external_access_integration: str,
|
643
646
|
) -> str:
|
@@ -646,22 +649,27 @@ class ModelVersion(lineage_node.LineageNode):
|
|
646
649
|
Args:
|
647
650
|
service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
|
648
651
|
schema of the model will be used.
|
649
|
-
image_build_compute_pool: The name of the compute pool used to build the model inference image.
|
652
|
+
image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
|
650
653
|
the service compute pool if None.
|
651
654
|
service_compute_pool: The name of the compute pool used to run the inference service.
|
652
655
|
image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
|
653
656
|
or schema of the model will be used.
|
654
|
-
ingress_enabled:
|
655
|
-
|
657
|
+
ingress_enabled: If true, creates an service endpoint associated with the service. User must have
|
658
|
+
BIND SERVICE ENDPOINT privilege on the account.
|
659
|
+
max_instances: The maximum number of inference service instances to run. The same value it set to
|
660
|
+
MIN_INSTANCES property of the service.
|
656
661
|
gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
|
657
662
|
if None.
|
658
|
-
num_workers: The number of workers
|
659
|
-
|
663
|
+
num_workers: The number of workers to run the inference service for handling requests in parallel within an
|
664
|
+
instance of the service. By default, it is set to 2*vCPU+1 of the node for CPU based inference and 1 for
|
665
|
+
GPU based inference. For GPU based inference, please see best practices before playing with this value.
|
666
|
+
max_batch_rows: The maximum number of rows to batch for inference. Auto determined if None. Minimum 32.
|
660
667
|
force_rebuild: Whether to force a model inference image rebuild.
|
661
|
-
build_external_access_integration: The external access integration for image build.
|
668
|
+
build_external_access_integration: The external access integration for image build. This is usually
|
669
|
+
permitting access to conda & PyPI repositories.
|
662
670
|
|
663
671
|
Returns:
|
664
|
-
|
672
|
+
Result information about service creation from server.
|
665
673
|
"""
|
666
674
|
statement_params = telemetry.get_statement_params(
|
667
675
|
project=_TELEMETRY_PROJECT,
|
@@ -690,10 +698,71 @@ class ModelVersion(lineage_node.LineageNode):
|
|
690
698
|
max_instances=max_instances,
|
691
699
|
gpu_requests=gpu_requests,
|
692
700
|
num_workers=num_workers,
|
701
|
+
max_batch_rows=max_batch_rows,
|
693
702
|
force_rebuild=force_rebuild,
|
694
703
|
build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration),
|
695
704
|
statement_params=statement_params,
|
696
705
|
)
|
697
706
|
|
707
|
+
@telemetry.send_api_usage_telemetry(
|
708
|
+
project=_TELEMETRY_PROJECT,
|
709
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
710
|
+
)
|
711
|
+
def list_services(
|
712
|
+
self,
|
713
|
+
) -> List[str]:
|
714
|
+
"""List all the service names using this model version.
|
715
|
+
|
716
|
+
Returns:
|
717
|
+
List of service_names: The name of the service, can be fully qualified. If not fully qualified, the database
|
718
|
+
or schema of the model will be used.
|
719
|
+
"""
|
720
|
+
statement_params = telemetry.get_statement_params(
|
721
|
+
project=_TELEMETRY_PROJECT,
|
722
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
723
|
+
)
|
724
|
+
|
725
|
+
return self._model_ops.list_inference_services(
|
726
|
+
database_name=None,
|
727
|
+
schema_name=None,
|
728
|
+
model_name=self._model_name,
|
729
|
+
version_name=self._version_name,
|
730
|
+
statement_params=statement_params,
|
731
|
+
)
|
732
|
+
|
733
|
+
@telemetry.send_api_usage_telemetry(
|
734
|
+
project=_TELEMETRY_PROJECT,
|
735
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
736
|
+
)
|
737
|
+
def delete_service(
|
738
|
+
self,
|
739
|
+
service_name: str,
|
740
|
+
) -> None:
|
741
|
+
"""Drops the given service.
|
742
|
+
|
743
|
+
Args:
|
744
|
+
service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
|
745
|
+
schema of the model will be used.
|
746
|
+
|
747
|
+
Raises:
|
748
|
+
ValueError: If the service does not exist or operation is not permitted by user or service does not belong
|
749
|
+
to this model.
|
750
|
+
"""
|
751
|
+
if not service_name:
|
752
|
+
raise ValueError("service_name cannot be empty.")
|
753
|
+
|
754
|
+
statement_params = telemetry.get_statement_params(
|
755
|
+
project=_TELEMETRY_PROJECT,
|
756
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
757
|
+
)
|
758
|
+
self._model_ops.delete_service(
|
759
|
+
database_name=None,
|
760
|
+
schema_name=None,
|
761
|
+
model_name=self._model_name,
|
762
|
+
version_name=self._version_name,
|
763
|
+
service_name=service_name,
|
764
|
+
statement_params=statement_params,
|
765
|
+
)
|
766
|
+
|
698
767
|
|
699
768
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import json
|
1
2
|
import os
|
2
3
|
import pathlib
|
3
4
|
import tempfile
|
@@ -6,6 +7,7 @@ from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
|
|
6
7
|
|
7
8
|
import yaml
|
8
9
|
|
10
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
9
11
|
from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
|
10
12
|
from snowflake.ml.model import model_signature, type_hints
|
11
13
|
from snowflake.ml.model._client.ops import metadata_ops
|
@@ -512,6 +514,71 @@ class ModelOperator:
|
|
512
514
|
statement_params=statement_params,
|
513
515
|
)
|
514
516
|
|
517
|
+
def list_inference_services(
|
518
|
+
self,
|
519
|
+
*,
|
520
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
521
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
522
|
+
model_name: sql_identifier.SqlIdentifier,
|
523
|
+
version_name: sql_identifier.SqlIdentifier,
|
524
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
525
|
+
) -> List[str]:
|
526
|
+
res = self._model_client.show_versions(
|
527
|
+
database_name=database_name,
|
528
|
+
schema_name=schema_name,
|
529
|
+
model_name=model_name,
|
530
|
+
version_name=version_name,
|
531
|
+
statement_params=statement_params,
|
532
|
+
)
|
533
|
+
col_name = self._model_client.MODEL_VERSION_INFERENCE_SERVICES_COL_NAME
|
534
|
+
if col_name not in res[0]:
|
535
|
+
# User need to opt into BCR 2024_08
|
536
|
+
raise exceptions.SnowflakeMLException(
|
537
|
+
error_code=error_codes.OPT_IN_REQUIRED,
|
538
|
+
original_exception=RuntimeError(
|
539
|
+
"Please opt in to BCR Bundle 2024_08 ("
|
540
|
+
"https://docs.snowflake.com/en/release-notes/bcr-bundles/2024_08_bundle)."
|
541
|
+
),
|
542
|
+
)
|
543
|
+
json_array = json.loads(res[0][col_name])
|
544
|
+
# TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
|
545
|
+
return [str(service) for service in json_array if "MODEL_BUILD_" not in service]
|
546
|
+
|
547
|
+
def delete_service(
|
548
|
+
self,
|
549
|
+
*,
|
550
|
+
database_name: Optional[sql_identifier.SqlIdentifier],
|
551
|
+
schema_name: Optional[sql_identifier.SqlIdentifier],
|
552
|
+
model_name: sql_identifier.SqlIdentifier,
|
553
|
+
version_name: sql_identifier.SqlIdentifier,
|
554
|
+
service_name: str,
|
555
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
556
|
+
) -> None:
|
557
|
+
services = self.list_inference_services(
|
558
|
+
database_name=database_name,
|
559
|
+
schema_name=schema_name,
|
560
|
+
model_name=model_name,
|
561
|
+
version_name=version_name,
|
562
|
+
statement_params=statement_params,
|
563
|
+
)
|
564
|
+
db, schema, service_name = sql_identifier.parse_fully_qualified_name(service_name)
|
565
|
+
fully_qualified_service_name = sql_identifier.get_fully_qualified_name(
|
566
|
+
db, schema, service_name, self._session.get_current_database(), self._session.get_current_schema()
|
567
|
+
)
|
568
|
+
|
569
|
+
for service in services:
|
570
|
+
if service == fully_qualified_service_name:
|
571
|
+
self._service_client.drop_service(
|
572
|
+
database_name=db,
|
573
|
+
schema_name=schema,
|
574
|
+
service_name=service_name,
|
575
|
+
statement_params=statement_params,
|
576
|
+
)
|
577
|
+
return
|
578
|
+
raise ValueError(
|
579
|
+
f"Service '{service_name}' does not exist or unauthorized or not associated with this model version."
|
580
|
+
)
|
581
|
+
|
515
582
|
def get_model_version_manifest(
|
516
583
|
self,
|
517
584
|
*,
|
@@ -538,7 +605,8 @@ class ModelOperator:
|
|
538
605
|
def _match_model_spec_with_sql_functions(
|
539
606
|
sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
|
540
607
|
) -> Dict[sql_identifier.SqlIdentifier, str]:
|
541
|
-
res = {}
|
608
|
+
res: Dict[sql_identifier.SqlIdentifier, str] = {}
|
609
|
+
|
542
610
|
for target_method in target_methods:
|
543
611
|
# Here we need to find the SQL function corresponding to the Python function.
|
544
612
|
# If the python function name is `abc`, then SQL function name can be `ABC` or `"abc"`.
|
@@ -574,7 +642,7 @@ class ModelOperator:
|
|
574
642
|
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
575
643
|
return model_spec
|
576
644
|
|
577
|
-
def
|
645
|
+
def get_model_task(
|
578
646
|
self,
|
579
647
|
*,
|
580
648
|
database_name: Optional[sql_identifier.SqlIdentifier],
|
@@ -582,7 +650,7 @@ class ModelOperator:
|
|
582
650
|
model_name: sql_identifier.SqlIdentifier,
|
583
651
|
version_name: sql_identifier.SqlIdentifier,
|
584
652
|
statement_params: Optional[Dict[str, Any]] = None,
|
585
|
-
) -> type_hints.
|
653
|
+
) -> type_hints.Task:
|
586
654
|
model_spec = self._fetch_model_spec(
|
587
655
|
database_name=database_name,
|
588
656
|
schema_name=schema_name,
|
@@ -590,8 +658,8 @@ class ModelOperator:
|
|
590
658
|
version_name=version_name,
|
591
659
|
statement_params=statement_params,
|
592
660
|
)
|
593
|
-
|
594
|
-
return type_hints.
|
661
|
+
task_val = model_spec.get("task", type_hints.Task.UNKNOWN.value)
|
662
|
+
return type_hints.Task(task_val)
|
595
663
|
|
596
664
|
def get_functions(
|
597
665
|
self,
|
@@ -633,6 +701,20 @@ class ModelOperator:
|
|
633
701
|
|
634
702
|
function_names_and_types.append((function_name, function_type))
|
635
703
|
|
704
|
+
if not function_names_and_types:
|
705
|
+
# If function_names_and_types is not populated, there are currently
|
706
|
+
# no warehouse functions for the model version. In order to do inference
|
707
|
+
# we must populate the functions so the mapping can be constructed.
|
708
|
+
model_manifest = self.get_model_version_manifest(
|
709
|
+
database_name=database_name,
|
710
|
+
schema_name=schema_name,
|
711
|
+
model_name=model_name,
|
712
|
+
version_name=version_name,
|
713
|
+
statement_params=statement_params,
|
714
|
+
)
|
715
|
+
for method in model_manifest["methods"]:
|
716
|
+
function_names_and_types.append((sql_identifier.SqlIdentifier(method["name"]), method["type"]))
|
717
|
+
|
636
718
|
signatures = model_spec["signatures"]
|
637
719
|
function_names = [name for name, _ in function_names_and_types]
|
638
720
|
function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
|
@@ -799,7 +881,7 @@ class ModelOperator:
|
|
799
881
|
|
800
882
|
if keep_order:
|
801
883
|
# if it's a partitioned table function, _ID will be null and we won't be able to sort.
|
802
|
-
if df_res.select(
|
884
|
+
if df_res.select(snowpark_handler._KEEP_ORDER_COL_NAME).limit(1).collect()[0][0] is None:
|
803
885
|
warnings.warn(
|
804
886
|
formatting.unwrap(
|
805
887
|
"""
|
@@ -812,7 +894,7 @@ class ModelOperator:
|
|
812
894
|
)
|
813
895
|
else:
|
814
896
|
df_res = df_res.sort(
|
815
|
-
|
897
|
+
snowpark_handler._KEEP_ORDER_COL_NAME,
|
816
898
|
ascending=True,
|
817
899
|
)
|
818
900
|
|