snowflake-ml-python 1.4.0__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +77 -32
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/_internal/utils/identifier.py +3 -1
- snowflake/ml/_internal/utils/sql_identifier.py +2 -6
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +531 -332
- snowflake/ml/feature_store/feature_view.py +40 -23
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +56 -54
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +49 -17
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +137 -50
- snowflake/ml/model/_client/ops/model_ops.py +159 -40
- snowflake/ml/model/_client/sql/model.py +25 -2
- snowflake/ml/model/_client/sql/model_version.py +131 -2
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -51
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +19 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +6 -10
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_handlers/catboost.py +206 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +218 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -0
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -11
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +20 -1
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -1
- snowflake/ml/model/_packager/model_packager.py +2 -5
- snowflake/ml/model/{_model_composer/model_runtime/_runtime_requirements.py → _packager/model_runtime/_snowml_inference_alternative_requirements.py} +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +137 -0
- snowflake/ml/model/type_hints.py +21 -2
- snowflake/ml/modeling/_internal/estimator_utils.py +16 -11
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +13 -14
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +29 -7
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +246 -175
- snowflake/ml/modeling/cluster/affinity_propagation.py +246 -175
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +246 -175
- snowflake/ml/modeling/cluster/birch.py +248 -175
- snowflake/ml/modeling/cluster/bisecting_k_means.py +248 -175
- snowflake/ml/modeling/cluster/dbscan.py +246 -175
- snowflake/ml/modeling/cluster/feature_agglomeration.py +248 -175
- snowflake/ml/modeling/cluster/k_means.py +248 -175
- snowflake/ml/modeling/cluster/mean_shift.py +246 -175
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +248 -175
- snowflake/ml/modeling/cluster/optics.py +246 -175
- snowflake/ml/modeling/cluster/spectral_biclustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_clustering.py +246 -175
- snowflake/ml/modeling/cluster/spectral_coclustering.py +246 -175
- snowflake/ml/modeling/compose/column_transformer.py +248 -175
- snowflake/ml/modeling/compose/transformed_target_regressor.py +246 -175
- snowflake/ml/modeling/covariance/elliptic_envelope.py +246 -175
- snowflake/ml/modeling/covariance/empirical_covariance.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso.py +246 -175
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +246 -175
- snowflake/ml/modeling/covariance/ledoit_wolf.py +246 -175
- snowflake/ml/modeling/covariance/min_cov_det.py +246 -175
- snowflake/ml/modeling/covariance/oas.py +246 -175
- snowflake/ml/modeling/covariance/shrunk_covariance.py +246 -175
- snowflake/ml/modeling/decomposition/dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/factor_analysis.py +248 -175
- snowflake/ml/modeling/decomposition/fast_ica.py +248 -175
- snowflake/ml/modeling/decomposition/incremental_pca.py +248 -175
- snowflake/ml/modeling/decomposition/kernel_pca.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +248 -175
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/pca.py +248 -175
- snowflake/ml/modeling/decomposition/sparse_pca.py +248 -175
- snowflake/ml/modeling/decomposition/truncated_svd.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +248 -175
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/bagging_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/isolation_forest.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +246 -175
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +246 -175
- snowflake/ml/modeling/ensemble/stacking_regressor.py +248 -175
- snowflake/ml/modeling/ensemble/voting_classifier.py +248 -175
- snowflake/ml/modeling/ensemble/voting_regressor.py +248 -175
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fdr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fpr.py +248 -175
- snowflake/ml/modeling/feature_selection/select_fwe.py +248 -175
- snowflake/ml/modeling/feature_selection/select_k_best.py +248 -175
- snowflake/ml/modeling/feature_selection/select_percentile.py +248 -175
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +248 -175
- snowflake/ml/modeling/feature_selection/variance_threshold.py +248 -175
- snowflake/ml/modeling/framework/_utils.py +8 -1
- snowflake/ml/modeling/framework/base.py +72 -37
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +246 -175
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +246 -175
- snowflake/ml/modeling/impute/iterative_imputer.py +248 -175
- snowflake/ml/modeling/impute/knn_imputer.py +248 -175
- snowflake/ml/modeling/impute/missing_indicator.py +248 -175
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/nystroem.py +248 -175
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +248 -175
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +248 -175
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +248 -175
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +246 -175
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ard_regression.py +246 -175
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/gamma_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/huber_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/lars.py +246 -175
- snowflake/ml/modeling/linear_model/lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +246 -175
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +246 -175
- snowflake/ml/modeling/linear_model/linear_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression.py +246 -175
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +246 -175
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +246 -175
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/perceptron.py +246 -175
- snowflake/ml/modeling/linear_model/poisson_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ransac_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/ridge.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +246 -175
- snowflake/ml/modeling/linear_model/ridge_cv.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_classifier.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +246 -175
- snowflake/ml/modeling/linear_model/sgd_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +246 -175
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +246 -175
- snowflake/ml/modeling/manifold/isomap.py +248 -175
- snowflake/ml/modeling/manifold/mds.py +248 -175
- snowflake/ml/modeling/manifold/spectral_embedding.py +248 -175
- snowflake/ml/modeling/manifold/tsne.py +248 -175
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +246 -175
- snowflake/ml/modeling/mixture/gaussian_mixture.py +246 -175
- snowflake/ml/modeling/model_selection/grid_search_cv.py +63 -41
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +80 -38
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +246 -175
- snowflake/ml/modeling/multiclass/output_code_classifier.py +246 -175
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/complement_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +246 -175
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neighbors/kernel_density.py +246 -175
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_centroid.py +246 -175
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +246 -175
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +248 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +246 -175
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +246 -175
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +248 -175
- snowflake/ml/modeling/neural_network/mlp_classifier.py +246 -175
- snowflake/ml/modeling/neural_network/mlp_regressor.py +246 -175
- snowflake/ml/modeling/pipeline/pipeline.py +517 -35
- snowflake/ml/modeling/preprocessing/binarizer.py +1 -5
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -5
- snowflake/ml/modeling/preprocessing/label_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +10 -12
- snowflake/ml/modeling/preprocessing/normalizer.py +1 -5
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +13 -5
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +1 -5
- snowflake/ml/modeling/preprocessing/polynomial_features.py +248 -175
- snowflake/ml/modeling/preprocessing/robust_scaler.py +1 -5
- snowflake/ml/modeling/preprocessing/standard_scaler.py +11 -11
- snowflake/ml/modeling/semi_supervised/label_propagation.py +246 -175
- snowflake/ml/modeling/semi_supervised/label_spreading.py +246 -175
- snowflake/ml/modeling/svm/linear_svc.py +246 -175
- snowflake/ml/modeling/svm/linear_svr.py +246 -175
- snowflake/ml/modeling/svm/nu_svc.py +246 -175
- snowflake/ml/modeling/svm/nu_svr.py +246 -175
- snowflake/ml/modeling/svm/svc.py +246 -175
- snowflake/ml/modeling/svm/svr.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/decision_tree_regressor.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_classifier.py +246 -175
- snowflake/ml/modeling/tree/extra_tree_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgb_regressor.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +246 -175
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +246 -175
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/registry/registry.py +1 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +129 -57
- snowflake_ml_python-1.5.0.dist-info/RECORD +380 -0
- snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +0 -97
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- snowflake_ml_python-1.4.0.dist-info/RECORD +0 -370
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.0.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
|
|
1
|
-
import
|
1
|
+
import os
|
2
2
|
import pathlib
|
3
3
|
import tempfile
|
4
|
-
from typing import Any, Dict, List, Optional, Union, cast
|
4
|
+
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
5
5
|
|
6
6
|
import yaml
|
7
7
|
|
@@ -19,7 +19,9 @@ from snowflake.ml.model._model_composer.model_manifest import (
|
|
19
19
|
model_manifest,
|
20
20
|
model_manifest_schema,
|
21
21
|
)
|
22
|
-
from snowflake.ml.model._packager.
|
22
|
+
from snowflake.ml.model._packager.model_env import model_env
|
23
|
+
from snowflake.ml.model._packager.model_meta import model_meta
|
24
|
+
from snowflake.ml.model._packager.model_runtime import model_runtime
|
23
25
|
from snowflake.ml.model._signatures import snowpark_handler
|
24
26
|
from snowflake.snowpark import dataframe, row, session
|
25
27
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
@@ -337,52 +339,90 @@ class ModelOperator:
|
|
337
339
|
mm = model_manifest.ModelManifest(pathlib.Path(tmpdir))
|
338
340
|
return mm.load()
|
339
341
|
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
)
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
359
|
-
return model_meta.ModelMetadata._validate_model_metadata(raw_model_meta)
|
342
|
+
@staticmethod
|
343
|
+
def _match_model_spec_with_sql_functions(
|
344
|
+
sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
|
345
|
+
) -> Dict[sql_identifier.SqlIdentifier, str]:
|
346
|
+
res = {}
|
347
|
+
for target_method in target_methods:
|
348
|
+
# Here we need to find the SQL function corresponding to the Python function.
|
349
|
+
# If the python function name is `abc`, then SQL function name can be `ABC` or `"abc"`.
|
350
|
+
# We will try to match`"abc"` first, then `ABC`.
|
351
|
+
# The reason why is because, if we have two python methods whose names are `abc` and `aBc`.
|
352
|
+
# At most 1 of them can be `ABC`, so if we check `"abc"` or `"aBc"` first we could resolve them correctly.
|
353
|
+
function_name = sql_identifier.SqlIdentifier(target_method, case_sensitive=True)
|
354
|
+
if function_name not in sql_functions_names:
|
355
|
+
function_name = sql_identifier.SqlIdentifier(target_method)
|
356
|
+
assert (
|
357
|
+
function_name in sql_functions_names
|
358
|
+
), f"Unable to match {target_method} in {sql_functions_names}."
|
359
|
+
res[function_name] = target_method
|
360
|
+
return res
|
360
361
|
|
361
|
-
def
|
362
|
+
def get_functions(
|
362
363
|
self,
|
363
364
|
*,
|
364
365
|
model_name: sql_identifier.SqlIdentifier,
|
365
366
|
version_name: sql_identifier.SqlIdentifier,
|
366
367
|
statement_params: Optional[Dict[str, Any]] = None,
|
367
|
-
) -> model_manifest_schema.
|
368
|
-
|
368
|
+
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
369
|
+
raw_model_spec_res = self._model_client.show_versions(
|
370
|
+
model_name=model_name,
|
371
|
+
version_name=version_name,
|
372
|
+
check_model_details=True,
|
373
|
+
statement_params={**(statement_params or {}), "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True},
|
374
|
+
)[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
|
375
|
+
model_spec_dict = yaml.safe_load(raw_model_spec_res)
|
376
|
+
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
377
|
+
show_functions_res = self._model_version_client.show_functions(
|
369
378
|
model_name=model_name,
|
370
379
|
version_name=version_name,
|
371
380
|
statement_params=statement_params,
|
372
|
-
)
|
373
|
-
|
374
|
-
|
375
|
-
|
381
|
+
)
|
382
|
+
function_names_and_types = []
|
383
|
+
for r in show_functions_res:
|
384
|
+
function_name = sql_identifier.SqlIdentifier(
|
385
|
+
r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
|
386
|
+
)
|
387
|
+
|
388
|
+
function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
389
|
+
try:
|
390
|
+
return_type = r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME]
|
391
|
+
except KeyError:
|
392
|
+
pass
|
393
|
+
else:
|
394
|
+
if "TABLE" in return_type:
|
395
|
+
function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
396
|
+
|
397
|
+
function_names_and_types.append((function_name, function_type))
|
398
|
+
|
399
|
+
signatures = model_spec["signatures"]
|
400
|
+
function_names = [name for name, _ in function_names_and_types]
|
401
|
+
function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
|
402
|
+
function_names, list(signatures.keys())
|
403
|
+
)
|
404
|
+
|
405
|
+
return [
|
406
|
+
model_manifest_schema.ModelFunctionInfo(
|
407
|
+
name=function_name.identifier(),
|
408
|
+
target_method=function_name_mapping[function_name],
|
409
|
+
target_method_function_type=function_type,
|
410
|
+
signature=model_signature.ModelSignature.from_dict(signatures[function_name_mapping[function_name]]),
|
411
|
+
)
|
412
|
+
for function_name, function_type in function_names_and_types
|
413
|
+
]
|
376
414
|
|
377
415
|
def invoke_method(
|
378
416
|
self,
|
379
417
|
*,
|
380
418
|
method_name: sql_identifier.SqlIdentifier,
|
419
|
+
method_function_type: str,
|
381
420
|
signature: model_signature.ModelSignature,
|
382
421
|
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
383
422
|
model_name: sql_identifier.SqlIdentifier,
|
384
423
|
version_name: sql_identifier.SqlIdentifier,
|
385
424
|
strict_input_validation: bool = False,
|
425
|
+
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
386
426
|
statement_params: Optional[Dict[str, str]] = None,
|
387
427
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
388
428
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
@@ -420,15 +460,27 @@ class ModelOperator:
|
|
420
460
|
if output_name in original_cols:
|
421
461
|
original_cols.remove(output_name)
|
422
462
|
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
428
|
-
|
429
|
-
|
430
|
-
|
431
|
-
|
463
|
+
if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
|
464
|
+
df_res = self._model_version_client.invoke_function_method(
|
465
|
+
method_name=method_name,
|
466
|
+
input_df=s_df,
|
467
|
+
input_args=input_args,
|
468
|
+
returns=returns,
|
469
|
+
model_name=model_name,
|
470
|
+
version_name=version_name,
|
471
|
+
statement_params=statement_params,
|
472
|
+
)
|
473
|
+
elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
474
|
+
df_res = self._model_version_client.invoke_table_function_method(
|
475
|
+
method_name=method_name,
|
476
|
+
input_df=s_df,
|
477
|
+
input_args=input_args,
|
478
|
+
partition_column=partition_column,
|
479
|
+
returns=returns,
|
480
|
+
model_name=model_name,
|
481
|
+
version_name=version_name,
|
482
|
+
statement_params=statement_params,
|
483
|
+
)
|
432
484
|
|
433
485
|
if keep_order:
|
434
486
|
df_res = df_res.sort(
|
@@ -437,7 +489,11 @@ class ModelOperator:
|
|
437
489
|
)
|
438
490
|
|
439
491
|
if not output_with_input_features:
|
440
|
-
|
492
|
+
cols_to_drop = original_cols
|
493
|
+
if partition_column is not None:
|
494
|
+
# don't drop partition column
|
495
|
+
cols_to_drop.remove(partition_column.identifier())
|
496
|
+
df_res = df_res.drop(*cols_to_drop)
|
441
497
|
|
442
498
|
# Get final result
|
443
499
|
if not isinstance(X, dataframe.DataFrame):
|
@@ -463,3 +519,66 @@ class ModelOperator:
|
|
463
519
|
model_name=model_name,
|
464
520
|
statement_params=statement_params,
|
465
521
|
)
|
522
|
+
|
523
|
+
def rename(
|
524
|
+
self,
|
525
|
+
*,
|
526
|
+
model_name: sql_identifier.SqlIdentifier,
|
527
|
+
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
528
|
+
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
529
|
+
new_model_name: sql_identifier.SqlIdentifier,
|
530
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
531
|
+
) -> None:
|
532
|
+
self._model_client.rename(
|
533
|
+
model_name=model_name,
|
534
|
+
new_model_db=new_model_db,
|
535
|
+
new_model_schema=new_model_schema,
|
536
|
+
new_model_name=new_model_name,
|
537
|
+
statement_params=statement_params,
|
538
|
+
)
|
539
|
+
|
540
|
+
# Map indicating in different modes, the path to list and download.
|
541
|
+
# The boolean value indicates if it is a directory,
|
542
|
+
MODEL_FILE_DOWNLOAD_PATTERN = {
|
543
|
+
"minimal": {
|
544
|
+
pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
|
545
|
+
/ model_meta.MODEL_METADATA_FILE: False,
|
546
|
+
pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH) / model_env._DEFAULT_ENV_DIR: True,
|
547
|
+
pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
|
548
|
+
/ model_runtime.ModelRuntime.RUNTIME_DIR_REL_PATH: True,
|
549
|
+
},
|
550
|
+
"model": {pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH): True},
|
551
|
+
"full": {pathlib.PurePosixPath(os.curdir): True},
|
552
|
+
}
|
553
|
+
|
554
|
+
def download_files(
|
555
|
+
self,
|
556
|
+
*,
|
557
|
+
model_name: sql_identifier.SqlIdentifier,
|
558
|
+
version_name: sql_identifier.SqlIdentifier,
|
559
|
+
target_path: pathlib.Path,
|
560
|
+
mode: Literal["full", "model", "minimal"] = "model",
|
561
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
562
|
+
) -> None:
|
563
|
+
for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
|
564
|
+
list_file_res = self._model_version_client.list_file(
|
565
|
+
model_name=model_name,
|
566
|
+
version_name=version_name,
|
567
|
+
file_path=remote_rel_path,
|
568
|
+
is_dir=is_dir,
|
569
|
+
statement_params=statement_params,
|
570
|
+
)
|
571
|
+
file_list = [
|
572
|
+
pathlib.PurePosixPath(*pathlib.PurePosixPath(row.name).parts[2:]) # versions/<version_name>/...
|
573
|
+
for row in list_file_res
|
574
|
+
]
|
575
|
+
for stage_file_path in file_list:
|
576
|
+
local_file_dir = target_path / stage_file_path.parent
|
577
|
+
local_file_dir.mkdir(parents=True, exist_ok=True)
|
578
|
+
self._model_version_client.get_file(
|
579
|
+
model_name=model_name,
|
580
|
+
version_name=version_name,
|
581
|
+
file_path=stage_file_path,
|
582
|
+
target_path=local_file_dir,
|
583
|
+
statement_params=statement_params,
|
584
|
+
)
|
@@ -16,7 +16,7 @@ class ModelSQLClient:
|
|
16
16
|
MODEL_VERSION_NAME_COL_NAME = "name"
|
17
17
|
MODEL_VERSION_COMMENT_COL_NAME = "comment"
|
18
18
|
MODEL_VERSION_METADATA_COL_NAME = "metadata"
|
19
|
-
|
19
|
+
MODEL_VERSION_MODEL_SPEC_COL_NAME = "model_spec"
|
20
20
|
|
21
21
|
def __init__(
|
22
22
|
self,
|
@@ -72,6 +72,7 @@ class ModelSQLClient:
|
|
72
72
|
model_name: sql_identifier.SqlIdentifier,
|
73
73
|
version_name: Optional[sql_identifier.SqlIdentifier] = None,
|
74
74
|
validate_result: bool = True,
|
75
|
+
check_model_details: bool = False,
|
75
76
|
statement_params: Optional[Dict[str, Any]] = None,
|
76
77
|
) -> List[row.Row]:
|
77
78
|
like_sql = ""
|
@@ -87,10 +88,11 @@ class ModelSQLClient:
|
|
87
88
|
.has_column(ModelSQLClient.MODEL_VERSION_NAME_COL_NAME, allow_empty=True)
|
88
89
|
.has_column(ModelSQLClient.MODEL_VERSION_COMMENT_COL_NAME, allow_empty=True)
|
89
90
|
.has_column(ModelSQLClient.MODEL_VERSION_METADATA_COL_NAME, allow_empty=True)
|
90
|
-
.has_column(ModelSQLClient.MODEL_VERSION_USER_DATA_COL_NAME, allow_empty=True)
|
91
91
|
)
|
92
92
|
if validate_result and version_name:
|
93
93
|
res = res.has_dimensions(expected_rows=1)
|
94
|
+
if check_model_details:
|
95
|
+
res = res.has_column(ModelSQLClient.MODEL_VERSION_MODEL_SPEC_COL_NAME, allow_empty=True)
|
94
96
|
|
95
97
|
return res.validate()
|
96
98
|
|
@@ -118,3 +120,24 @@ class ModelSQLClient:
|
|
118
120
|
f"DROP MODEL {self.fully_qualified_model_name(model_name)}",
|
119
121
|
statement_params=statement_params,
|
120
122
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
123
|
+
|
124
|
+
def rename(
|
125
|
+
self,
|
126
|
+
*,
|
127
|
+
model_name: sql_identifier.SqlIdentifier,
|
128
|
+
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
129
|
+
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
130
|
+
new_model_name: sql_identifier.SqlIdentifier,
|
131
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
132
|
+
) -> None:
|
133
|
+
# Use registry's database and schema if a non fully qualified new model name is provided.
|
134
|
+
new_fully_qualified_name = identifier.get_schema_level_object_identifier(
|
135
|
+
new_model_db.identifier() if new_model_db else self._database_name.identifier(),
|
136
|
+
new_model_schema.identifier() if new_model_schema else self._schema_name.identifier(),
|
137
|
+
new_model_name.identifier(),
|
138
|
+
)
|
139
|
+
query_result_checker.SqlResultValidator(
|
140
|
+
self._session,
|
141
|
+
f"ALTER MODEL {self.fully_qualified_model_name(model_name)} RENAME TO {new_fully_qualified_name}",
|
142
|
+
statement_params=statement_params,
|
143
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -9,7 +9,7 @@ from snowflake.ml._internal.utils import (
|
|
9
9
|
query_result_checker,
|
10
10
|
sql_identifier,
|
11
11
|
)
|
12
|
-
from snowflake.snowpark import dataframe, functions as F, session, types as spt
|
12
|
+
from snowflake.snowpark import dataframe, functions as F, row, session, types as spt
|
13
13
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
14
14
|
|
15
15
|
|
@@ -21,6 +21,9 @@ def _normalize_url_for_sql(url: str) -> str:
|
|
21
21
|
|
22
22
|
|
23
23
|
class ModelVersionSQLClient:
|
24
|
+
FUNCTION_NAME_COL_NAME = "name"
|
25
|
+
FUNCTION_RETURN_TYPE_COL_NAME = "return_type"
|
26
|
+
|
24
27
|
def __init__(
|
25
28
|
self,
|
26
29
|
session: session.Session,
|
@@ -93,6 +96,38 @@ class ModelVersionSQLClient:
|
|
93
96
|
statement_params=statement_params,
|
94
97
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
95
98
|
|
99
|
+
def list_file(
|
100
|
+
self,
|
101
|
+
*,
|
102
|
+
model_name: sql_identifier.SqlIdentifier,
|
103
|
+
version_name: sql_identifier.SqlIdentifier,
|
104
|
+
file_path: pathlib.PurePosixPath,
|
105
|
+
is_dir: bool = False,
|
106
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
107
|
+
) -> List[row.Row]:
|
108
|
+
# Workaround for snowURL bug.
|
109
|
+
trailing_slash = "/" if is_dir else ""
|
110
|
+
|
111
|
+
stage_location = (
|
112
|
+
pathlib.PurePosixPath(
|
113
|
+
self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
|
114
|
+
).as_posix()
|
115
|
+
+ trailing_slash
|
116
|
+
)
|
117
|
+
stage_location_url = ParseResult(
|
118
|
+
scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
|
119
|
+
).geturl()
|
120
|
+
|
121
|
+
return (
|
122
|
+
query_result_checker.SqlResultValidator(
|
123
|
+
self._session,
|
124
|
+
f"List {_normalize_url_for_sql(stage_location_url)}",
|
125
|
+
statement_params=statement_params,
|
126
|
+
)
|
127
|
+
.has_column("name")
|
128
|
+
.validate()
|
129
|
+
)
|
130
|
+
|
96
131
|
def get_file(
|
97
132
|
self,
|
98
133
|
*,
|
@@ -124,6 +159,24 @@ class ModelVersionSQLClient:
|
|
124
159
|
).has_dimensions(expected_rows=1).validate()
|
125
160
|
return target_path / file_path.name
|
126
161
|
|
162
|
+
def show_functions(
|
163
|
+
self,
|
164
|
+
*,
|
165
|
+
model_name: sql_identifier.SqlIdentifier,
|
166
|
+
version_name: sql_identifier.SqlIdentifier,
|
167
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
168
|
+
) -> List[row.Row]:
|
169
|
+
res = query_result_checker.SqlResultValidator(
|
170
|
+
self._session,
|
171
|
+
(
|
172
|
+
f"SHOW FUNCTIONS IN MODEL {self.fully_qualified_model_name(model_name)}"
|
173
|
+
f" VERSION {version_name.identifier()}"
|
174
|
+
),
|
175
|
+
statement_params=statement_params,
|
176
|
+
).has_column(ModelVersionSQLClient.FUNCTION_NAME_COL_NAME, allow_empty=True)
|
177
|
+
|
178
|
+
return res.validate()
|
179
|
+
|
127
180
|
def set_comment(
|
128
181
|
self,
|
129
182
|
*,
|
@@ -141,7 +194,7 @@ class ModelVersionSQLClient:
|
|
141
194
|
statement_params=statement_params,
|
142
195
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
143
196
|
|
144
|
-
def
|
197
|
+
def invoke_function_method(
|
145
198
|
self,
|
146
199
|
*,
|
147
200
|
model_name: sql_identifier.SqlIdentifier,
|
@@ -211,6 +264,82 @@ class ModelVersionSQLClient:
|
|
211
264
|
|
212
265
|
return output_df
|
213
266
|
|
267
|
+
def invoke_table_function_method(
|
268
|
+
self,
|
269
|
+
*,
|
270
|
+
model_name: sql_identifier.SqlIdentifier,
|
271
|
+
version_name: sql_identifier.SqlIdentifier,
|
272
|
+
method_name: sql_identifier.SqlIdentifier,
|
273
|
+
input_df: dataframe.DataFrame,
|
274
|
+
input_args: List[sql_identifier.SqlIdentifier],
|
275
|
+
returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
276
|
+
partition_column: Optional[sql_identifier.SqlIdentifier],
|
277
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
278
|
+
) -> dataframe.DataFrame:
|
279
|
+
with_statements = []
|
280
|
+
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
281
|
+
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
282
|
+
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
283
|
+
else:
|
284
|
+
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
285
|
+
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
286
|
+
self._database_name.identifier(),
|
287
|
+
self._schema_name.identifier(),
|
288
|
+
tmp_table_name,
|
289
|
+
)
|
290
|
+
input_df.write.save_as_table( # type: ignore[call-overload]
|
291
|
+
table_name=INTERMEDIATE_TABLE_NAME,
|
292
|
+
mode="errorifexists",
|
293
|
+
table_type="temporary",
|
294
|
+
statement_params=statement_params,
|
295
|
+
)
|
296
|
+
|
297
|
+
module_version_alias = "MODEL_VERSION_ALIAS"
|
298
|
+
with_statements.append(
|
299
|
+
f"{module_version_alias} AS "
|
300
|
+
f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
|
301
|
+
)
|
302
|
+
|
303
|
+
partition_by = partition_column.identifier() if partition_column is not None else "1"
|
304
|
+
|
305
|
+
args_sql_list = []
|
306
|
+
for input_arg_value in input_args:
|
307
|
+
args_sql_list.append(input_arg_value)
|
308
|
+
|
309
|
+
args_sql = ", ".join(args_sql_list)
|
310
|
+
|
311
|
+
sql = textwrap.dedent(
|
312
|
+
f"""WITH {','.join(with_statements)}
|
313
|
+
SELECT *,
|
314
|
+
FROM {INTERMEDIATE_TABLE_NAME},
|
315
|
+
TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
|
316
|
+
OVER (PARTITION BY {partition_by}))"""
|
317
|
+
)
|
318
|
+
|
319
|
+
output_df = self._session.sql(sql)
|
320
|
+
|
321
|
+
# Prepare the output
|
322
|
+
output_cols = []
|
323
|
+
output_names = []
|
324
|
+
|
325
|
+
for output_name, output_type, output_col_name in returns:
|
326
|
+
output_cols.append(F.col(output_name).astype(output_type))
|
327
|
+
output_names.append(output_col_name)
|
328
|
+
|
329
|
+
if partition_column is not None:
|
330
|
+
output_cols.append(F.col(partition_column.identifier()))
|
331
|
+
output_names.append(partition_column)
|
332
|
+
|
333
|
+
output_df = output_df.with_columns(
|
334
|
+
col_names=output_names,
|
335
|
+
values=output_cols,
|
336
|
+
)
|
337
|
+
|
338
|
+
if statement_params:
|
339
|
+
output_df._statement_params = statement_params # type: ignore[assignment]
|
340
|
+
|
341
|
+
return output_df
|
342
|
+
|
214
343
|
def set_metadata(
|
215
344
|
self,
|
216
345
|
metadata_dict: Dict[str, Any],
|
@@ -37,6 +37,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
37
37
|
session: snowpark.Session,
|
38
38
|
artifact_stage_location: str,
|
39
39
|
compute_pool: str,
|
40
|
+
job_name: str,
|
40
41
|
external_access_integrations: List[str],
|
41
42
|
) -> None:
|
42
43
|
"""Initialization
|
@@ -49,6 +50,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
49
50
|
artifact_stage_location: Spec file and future deployment related artifacts will be stored under
|
50
51
|
{stage}/models/{model_id}
|
51
52
|
compute_pool: The compute pool used to run docker image build workload.
|
53
|
+
job_name: job_name to use.
|
52
54
|
external_access_integrations: EAIs for network connection.
|
53
55
|
"""
|
54
56
|
self.context_dir = context_dir
|
@@ -58,6 +60,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
58
60
|
self.artifact_stage_location = artifact_stage_location
|
59
61
|
self.compute_pool = compute_pool
|
60
62
|
self.external_access_integrations = external_access_integrations
|
63
|
+
self.job_name = job_name
|
61
64
|
self.client = snowservice_client.SnowServiceClient(session)
|
62
65
|
|
63
66
|
assert artifact_stage_location.startswith(
|
@@ -203,8 +206,9 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
203
206
|
)
|
204
207
|
|
205
208
|
def _launch_kaniko_job(self, spec_stage_location: str) -> None:
|
206
|
-
logger.debug("Submitting job for building docker image with kaniko")
|
209
|
+
logger.debug(f"Submitting job {self.job_name} for building docker image with kaniko")
|
207
210
|
self.client.create_job(
|
211
|
+
job_name=self.job_name,
|
208
212
|
compute_pool=self.compute_pool,
|
209
213
|
spec_stage_location=spec_stage_location,
|
210
214
|
external_access_integrations=self.external_access_integrations,
|
@@ -30,6 +30,7 @@ USER mambauser
|
|
30
30
|
|
31
31
|
# Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time.
|
32
32
|
ARG MAMBA_DOCKERFILE_ACTIVATE=1
|
33
|
+
ARG MAMBA_NO_LOW_SPEED_LIMIT=1
|
33
34
|
|
34
35
|
# Bitsandbytes uses this ENVVAR to determine CUDA library location
|
35
36
|
ENV CONDA_PREFIX=/opt/conda
|
@@ -346,6 +346,7 @@ class SnowServiceDeployment:
|
|
346
346
|
(db, schema, _, _) = identifier.parse_schema_level_object_identifier(service_func_name)
|
347
347
|
|
348
348
|
self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}")
|
349
|
+
self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}")
|
349
350
|
# Spec file and future deployment related artifacts will be stored under {stage}/models/{model_id}
|
350
351
|
self._model_artifact_stage_location = posixpath.join(deployment_stage_path, "models", self.id)
|
351
352
|
self.debug_dir: Optional[str] = None
|
@@ -468,6 +469,7 @@ class SnowServiceDeployment:
|
|
468
469
|
session=self.session,
|
469
470
|
artifact_stage_location=self._model_artifact_stage_location,
|
470
471
|
compute_pool=self.options.compute_pool,
|
472
|
+
job_name=self._job_name,
|
471
473
|
external_access_integrations=self.options.external_access_integrations,
|
472
474
|
)
|
473
475
|
else:
|
@@ -17,11 +17,6 @@ class ResourceStatus(Enum):
|
|
17
17
|
INTERNAL_ERROR = "INTERNAL_ERROR" # there was an internal service error.
|
18
18
|
|
19
19
|
|
20
|
-
RESOURCE_TO_STATUS_FUNCTION_MAPPING = {
|
21
|
-
ResourceType.SERVICE: "SYSTEM$GET_SERVICE_STATUS",
|
22
|
-
ResourceType.JOB: "SYSTEM$GET_JOB_STATUS",
|
23
|
-
}
|
24
|
-
|
25
20
|
PREDICT = "predict"
|
26
21
|
STAGE = "stage"
|
27
22
|
COMPUTE_POOL = "compute_pool"
|