snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -0
- snowflake/cortex/_classify_text.py +2 -2
- snowflake/cortex/_embed_text_1024.py +37 -0
- snowflake/cortex/_embed_text_768.py +37 -0
- snowflake/cortex/_extract_answer.py +2 -2
- snowflake/cortex/_sentiment.py +2 -2
- snowflake/cortex/_summarize.py +2 -2
- snowflake/cortex/_translate.py +2 -2
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +5 -5
- snowflake/ml/_internal/exceptions/error_codes.py +2 -0
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/db_utils.py +50 -0
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/service_logger.py +63 -0
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +26 -2
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/ingestor_utils.py +20 -10
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/access_manager.py +3 -3
- snowflake/ml/feature_store/feature_store.py +60 -19
- snowflake/ml/feature_store/feature_view.py +84 -30
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/fileset/stage_fs.py +2 -1
- snowflake/ml/lineage/lineage_node.py +7 -2
- snowflake/ml/model/__init__.py +1 -2
- snowflake/ml/model/_client/model/model_version_impl.py +96 -12
- snowflake/ml/model/_client/ops/model_ops.py +124 -6
- snowflake/ml/model/_client/ops/service_ops.py +309 -9
- snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
- snowflake/ml/model/_client/sql/_base.py +5 -0
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +9 -5
- snowflake/ml/model/_client/sql/service.py +121 -20
- snowflake/ml/model/_model_composer/model_composer.py +11 -39
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
- snowflake/ml/model/_packager/model_env/model_env.py +4 -38
- snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
- snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
- snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
- snowflake/ml/model/_packager/model_packager.py +14 -8
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/type_hints.py +12 -145
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
- snowflake/ml/modeling/cluster/birch.py +61 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
- snowflake/ml/modeling/cluster/dbscan.py +61 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
- snowflake/ml/modeling/cluster/k_means.py +61 -21
- snowflake/ml/modeling/cluster/mean_shift.py +61 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
- snowflake/ml/modeling/cluster/optics.py +61 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
- snowflake/ml/modeling/compose/column_transformer.py +61 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
- snowflake/ml/modeling/covariance/oas.py +61 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/pca.py +61 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
- snowflake/ml/modeling/impute/knn_imputer.py +61 -21
- snowflake/ml/modeling/impute/missing_indicator.py +61 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/lars.py +61 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/perceptron.py +61 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/ridge.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
- snowflake/ml/modeling/manifold/isomap.py +61 -21
- snowflake/ml/modeling/manifold/mds.py +61 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
- snowflake/ml/modeling/manifold/tsne.py +61 -21
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
- snowflake/ml/modeling/metrics/ranking.py +0 -3
- snowflake/ml/modeling/metrics/regression.py +0 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -13
- snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
- snowflake/ml/modeling/svm/linear_svc.py +61 -21
- snowflake/ml/modeling/svm/linear_svr.py +61 -21
- snowflake/ml/modeling/svm/nu_svc.py +61 -21
- snowflake/ml/modeling/svm/nu_svr.py +61 -21
- snowflake/ml/modeling/svm/svc.py +61 -21
- snowflake/ml/modeling/svm/svr.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
- snowflake/ml/monitoring/_client/model_monitor.py +126 -0
- snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
- snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
- snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
- snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
- snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
- snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
- snowflake/ml/monitoring/entities/output_score_type.py +90 -0
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/registry.py +166 -8
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
- snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
- snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
- snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
- snowflake/ml/_internal/utils/session_token_manager.py +0 -46
- snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
- snowflake/ml/_internal/utils/uri.py +0 -77
- snowflake/ml/data/torch_dataset.py +0 -33
- snowflake/ml/model/_api.py +0 -568
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
- snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
- snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
- snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
- snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
- snowflake/ml/model/deploy_platforms.py +0 -6
- snowflake/ml/model/models/llm.py +0 -104
- snowflake/ml/monitoring/monitor.py +0 -203
- snowflake/ml/registry/_initial_schema.py +0 -142
- snowflake/ml/registry/_schema.py +0 -82
- snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
- snowflake/ml/registry/_schema_version_manager.py +0 -163
- snowflake/ml/registry/model_registry.py +0 -2048
- snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,63 @@
|
|
1
|
+
import enum
|
2
|
+
import logging
|
3
|
+
import sys
|
4
|
+
|
5
|
+
|
6
|
+
class LogColor(enum.Enum):
|
7
|
+
GREY = "\x1b[38;20m"
|
8
|
+
RED = "\x1b[31;20m"
|
9
|
+
BOLD_RED = "\x1b[31;1m"
|
10
|
+
YELLOW = "\x1b[33;20m"
|
11
|
+
BLUE = "\x1b[34;20m"
|
12
|
+
GREEN = "\x1b[32;20m"
|
13
|
+
|
14
|
+
|
15
|
+
class CustomFormatter(logging.Formatter):
|
16
|
+
|
17
|
+
reset = "\x1b[0m"
|
18
|
+
log_format = "%(name)s [%(asctime)s] [%(levelname)s] %(message)s"
|
19
|
+
|
20
|
+
def __init__(self, info_color: LogColor) -> None:
|
21
|
+
super().__init__()
|
22
|
+
self.level_colors = {
|
23
|
+
logging.DEBUG: LogColor.GREY.value,
|
24
|
+
logging.INFO: info_color.value,
|
25
|
+
logging.WARNING: LogColor.YELLOW.value,
|
26
|
+
logging.ERROR: LogColor.RED.value,
|
27
|
+
logging.CRITICAL: LogColor.BOLD_RED.value,
|
28
|
+
}
|
29
|
+
|
30
|
+
def format(self, record: logging.LogRecord) -> str:
|
31
|
+
# default to DEBUG color
|
32
|
+
fmt = self.level_colors.get(record.levelno, self.level_colors[logging.DEBUG]) + self.log_format + self.reset
|
33
|
+
formatter = logging.Formatter(fmt)
|
34
|
+
|
35
|
+
# split the log message by lines and format each line individually
|
36
|
+
original_message = record.getMessage()
|
37
|
+
message_lines = original_message.splitlines()
|
38
|
+
formatted_lines = [
|
39
|
+
formatter.format(
|
40
|
+
logging.LogRecord(
|
41
|
+
name=record.name,
|
42
|
+
level=record.levelno,
|
43
|
+
pathname=record.pathname,
|
44
|
+
lineno=record.lineno,
|
45
|
+
msg=line,
|
46
|
+
args=None,
|
47
|
+
exc_info=None,
|
48
|
+
)
|
49
|
+
)
|
50
|
+
for line in message_lines
|
51
|
+
]
|
52
|
+
|
53
|
+
return "\n".join(formatted_lines)
|
54
|
+
|
55
|
+
|
56
|
+
def get_logger(logger_name: str, info_color: LogColor) -> logging.Logger:
|
57
|
+
logger = logging.getLogger(logger_name)
|
58
|
+
logger.setLevel(logging.INFO)
|
59
|
+
handler = logging.StreamHandler(sys.stdout)
|
60
|
+
handler.setLevel(logging.INFO)
|
61
|
+
handler.setFormatter(CustomFormatter(info_color))
|
62
|
+
logger.addHandler(handler)
|
63
|
+
return logger
|
@@ -2,7 +2,7 @@ import enum
|
|
2
2
|
from typing import Any, Dict, Optional, TypedDict, cast
|
3
3
|
|
4
4
|
from packaging import version
|
5
|
-
from typing_extensions import Required
|
5
|
+
from typing_extensions import NotRequired, Required
|
6
6
|
|
7
7
|
from snowflake.ml._internal.utils import query_result_checker
|
8
8
|
from snowflake.snowpark import session
|
@@ -52,7 +52,7 @@ class SnowflakeCloudType(enum.Enum):
|
|
52
52
|
|
53
53
|
|
54
54
|
class SnowflakeRegion(TypedDict):
|
55
|
-
region_group:
|
55
|
+
region_group: NotRequired[str]
|
56
56
|
snowflake_region: Required[str]
|
57
57
|
cloud: Required[SnowflakeCloudType]
|
58
58
|
region: Required[str]
|
@@ -64,23 +64,33 @@ def get_regions(
|
|
64
64
|
) -> Dict[str, SnowflakeRegion]:
|
65
65
|
res = (
|
66
66
|
query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params)
|
67
|
-
.has_column("region_group")
|
68
67
|
.has_column("snowflake_region")
|
69
68
|
.has_column("cloud")
|
70
69
|
.has_column("region")
|
71
70
|
.has_column("display_name")
|
72
71
|
.validate()
|
73
72
|
)
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
73
|
+
res_dict = {}
|
74
|
+
for r in res:
|
75
|
+
if hasattr(r, "region_group") and r.region_group:
|
76
|
+
key = f"{r.region_group}.{r.snowflake_region}"
|
77
|
+
res_dict[key] = SnowflakeRegion(
|
78
|
+
region_group=r.region_group,
|
79
|
+
snowflake_region=r.snowflake_region,
|
80
|
+
cloud=SnowflakeCloudType.from_value(r.cloud),
|
81
|
+
region=r.region,
|
82
|
+
display_name=r.display_name,
|
83
|
+
)
|
84
|
+
else:
|
85
|
+
key = r.snowflake_region
|
86
|
+
res_dict[key] = SnowflakeRegion(
|
87
|
+
snowflake_region=r.snowflake_region,
|
88
|
+
cloud=SnowflakeCloudType.from_value(r.cloud),
|
89
|
+
region=r.region,
|
90
|
+
display_name=r.display_name,
|
91
|
+
)
|
92
|
+
|
93
|
+
return res_dict
|
84
94
|
|
85
95
|
|
86
96
|
def get_current_region_id(sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import List, Optional, Tuple
|
1
|
+
from typing import List, Optional, Tuple, Union
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import identifier
|
4
4
|
|
@@ -84,7 +84,7 @@ def to_sql_identifiers(list_of_str: List[str], *, case_sensitive: bool = False)
|
|
84
84
|
def parse_fully_qualified_name(
|
85
85
|
name: str,
|
86
86
|
) -> Tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]:
|
87
|
-
db, schema, object
|
87
|
+
db, schema, object = identifier.parse_schema_level_object_identifier(name)
|
88
88
|
|
89
89
|
assert name is not None, f"Unable parse the input name `{name}` as fully qualified."
|
90
90
|
return (
|
@@ -92,3 +92,27 @@ def parse_fully_qualified_name(
|
|
92
92
|
SqlIdentifier(schema) if schema else None,
|
93
93
|
SqlIdentifier(object),
|
94
94
|
)
|
95
|
+
|
96
|
+
|
97
|
+
def get_fully_qualified_name(
|
98
|
+
db: Union[SqlIdentifier, str, None],
|
99
|
+
schema: Union[SqlIdentifier, str, None],
|
100
|
+
object: Union[SqlIdentifier, str],
|
101
|
+
session_db: Optional[str] = None,
|
102
|
+
session_schema: Optional[str] = None,
|
103
|
+
) -> str:
|
104
|
+
db_name: Optional[SqlIdentifier] = None
|
105
|
+
schema_name: Optional[SqlIdentifier] = None
|
106
|
+
if not db and session_db:
|
107
|
+
db_name = SqlIdentifier(session_db)
|
108
|
+
elif isinstance(db, str):
|
109
|
+
db_name = SqlIdentifier(db)
|
110
|
+
if not schema and session_schema:
|
111
|
+
schema_name = SqlIdentifier(session_schema)
|
112
|
+
elif isinstance(schema, str):
|
113
|
+
schema_name = SqlIdentifier(schema)
|
114
|
+
return identifier.get_schema_level_object_identifier(
|
115
|
+
db=db_name.identifier() if db_name else None,
|
116
|
+
schema=schema_name.identifier() if schema_name else None,
|
117
|
+
object_name=object.identifier() if isinstance(object, SqlIdentifier) else SqlIdentifier(object).identifier(),
|
118
|
+
)
|
@@ -1,7 +1,8 @@
|
|
1
1
|
from typing import Any, Dict, List, Optional, Tuple
|
2
2
|
|
3
3
|
from snowflake import snowpark
|
4
|
-
from snowflake.ml._internal.utils import formatting, query_result_checker
|
4
|
+
from snowflake.ml._internal.utils import formatting, identifier, query_result_checker
|
5
|
+
from snowflake.snowpark import types
|
5
6
|
|
6
7
|
"""Table_manager is a set of utils that helps create tables.
|
7
8
|
|
@@ -104,3 +105,20 @@ def get_table_schema(session: snowpark.Session, table_name: str, qualified_schem
|
|
104
105
|
for row in result:
|
105
106
|
schema_dict[row["name"]] = row["type"]
|
106
107
|
return schema_dict
|
108
|
+
|
109
|
+
|
110
|
+
def get_table_schema_types(
|
111
|
+
session: snowpark.Session,
|
112
|
+
database: str,
|
113
|
+
schema: str,
|
114
|
+
table_name: str,
|
115
|
+
) -> Dict[str, types.DataType]:
|
116
|
+
fully_qualified_table_name = identifier.get_schema_level_object_identifier(
|
117
|
+
db=database, schema=schema, object_name=table_name
|
118
|
+
)
|
119
|
+
struct_fields: List[types.StructField] = session.table(fully_qualified_table_name).schema.fields
|
120
|
+
|
121
|
+
schema_dict: Dict[str, types.DataType] = {}
|
122
|
+
for field in struct_fields:
|
123
|
+
schema_dict[field.name] = field.datatype
|
124
|
+
return schema_dict
|
@@ -11,7 +11,6 @@ import pyarrow as pa
|
|
11
11
|
import pyarrow.dataset as pds
|
12
12
|
|
13
13
|
from snowflake import snowpark
|
14
|
-
from snowflake.connector import result_batch
|
15
14
|
from snowflake.ml.data import data_ingestor, data_source, ingestor_utils
|
16
15
|
|
17
16
|
_EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
|
@@ -140,16 +139,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
|
|
140
139
|
# We may be able to optimize this by splitting the result batches into
|
141
140
|
# in-memory (first batch) and file URLs (subsequent batches) and creating a
|
142
141
|
# union dataset.
|
143
|
-
|
144
|
-
sources.extend(
|
145
|
-
b.to_arrow(self._session.connection)
|
146
|
-
if isinstance(b, result_batch.ArrowResultBatch)
|
147
|
-
else b.to_arrow()
|
148
|
-
for b in result_batches
|
149
|
-
)
|
150
|
-
# HACK: Mitigate typing inconsistencies in Snowpark results
|
151
|
-
if len(sources) > 0:
|
152
|
-
sources = [_cast_if_needed(s, sources[-1].schema) for s in sources]
|
142
|
+
sources.append(_cast_if_needed(ingestor_utils.get_dataframe_arrow_table(self._session, source)))
|
153
143
|
source_format = None # Arrow Dataset expects "None" for in-memory datasets
|
154
144
|
else:
|
155
145
|
raise RuntimeError(f"Unsupported data source type: {type(source)}")
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import os
|
1
2
|
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
|
2
3
|
|
3
4
|
import numpy.typing as npt
|
@@ -7,6 +8,10 @@ from snowflake import snowpark
|
|
7
8
|
from snowflake.ml._internal import telemetry
|
8
9
|
from snowflake.ml.data import data_ingestor, data_source
|
9
10
|
from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
|
11
|
+
from snowflake.ml.modeling._internal.constants import (
|
12
|
+
IN_ML_RUNTIME_ENV_VAR,
|
13
|
+
USE_OPTIMIZED_DATA_INGESTOR,
|
14
|
+
)
|
10
15
|
|
11
16
|
if TYPE_CHECKING:
|
12
17
|
import pandas as pd
|
@@ -142,32 +147,41 @@ class DataConnector:
|
|
142
147
|
Returns:
|
143
148
|
A Pytorch iterable datapipe that yield data.
|
144
149
|
"""
|
145
|
-
from
|
150
|
+
from snowflake.ml.data import torch_utils
|
146
151
|
|
147
|
-
return
|
148
|
-
self._ingestor
|
152
|
+
return torch_utils.TorchDataPipeWrapper(
|
153
|
+
self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
|
149
154
|
)
|
150
155
|
|
151
156
|
@telemetry.send_api_usage_telemetry(
|
152
157
|
project=_PROJECT,
|
153
158
|
subproject_extractor=lambda self: type(self).__name__,
|
154
|
-
func_params_to_log=["shuffle"],
|
159
|
+
func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
|
155
160
|
)
|
156
|
-
def to_torch_dataset(
|
161
|
+
def to_torch_dataset(
|
162
|
+
self, *, batch_size: int = 1, shuffle: bool = False, drop_last_batch: bool = True
|
163
|
+
) -> "torch_data.IterableDataset": # type: ignore[type-arg]
|
157
164
|
"""Transform the Snowflake data into a PyTorch Iterable Dataset to be used with a DataLoader.
|
158
165
|
|
159
166
|
Return a PyTorch Dataset which iterates on rows of data.
|
160
167
|
|
161
168
|
Args:
|
169
|
+
batch_size: It specifies the size of each data batch which will be yielded in the result dataset.
|
170
|
+
Batching is pushed down to data ingestion level which may be more performant than DataLoader
|
171
|
+
batching.
|
162
172
|
shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
|
163
173
|
rows in each file will also be shuffled.
|
174
|
+
drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
|
175
|
+
then the last batch will get dropped if its size is smaller than the given batch_size.
|
164
176
|
|
165
177
|
Returns:
|
166
178
|
A PyTorch Iterable Dataset that yields data.
|
167
179
|
"""
|
168
|
-
from snowflake.ml.data import
|
180
|
+
from snowflake.ml.data import torch_utils
|
169
181
|
|
170
|
-
return
|
182
|
+
return torch_utils.TorchDatasetWrapper(
|
183
|
+
self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
|
184
|
+
)
|
171
185
|
|
172
186
|
@telemetry.send_api_usage_telemetry(
|
173
187
|
project=_PROJECT,
|
@@ -184,3 +198,15 @@ class DataConnector:
|
|
184
198
|
A Pandas DataFrame.
|
185
199
|
"""
|
186
200
|
return self._ingestor.to_pandas(limit)
|
201
|
+
|
202
|
+
|
203
|
+
# Switch to use Runtime's Data Ingester if running in ML runtime
|
204
|
+
# Fail silently if the data ingester is not found
|
205
|
+
if os.getenv(IN_ML_RUNTIME_ENV_VAR) and os.getenv(USE_OPTIMIZED_DATA_INGESTOR):
|
206
|
+
try:
|
207
|
+
from runtime_external_entities import get_ingester_class
|
208
|
+
|
209
|
+
DataConnector.DEFAULT_INGESTOR_CLASS = get_ingester_class()
|
210
|
+
except ImportError:
|
211
|
+
"""Runtime Default Ingester not found, ignore"""
|
212
|
+
pass
|
@@ -1,19 +1,17 @@
|
|
1
1
|
from typing import List, Optional
|
2
2
|
|
3
3
|
import fsspec
|
4
|
+
import pyarrow as pa
|
4
5
|
|
5
6
|
from snowflake import snowpark
|
6
|
-
from snowflake.connector import result_batch
|
7
|
+
from snowflake.connector import cursor as sf_cursor, result_batch
|
7
8
|
from snowflake.ml.data import data_source
|
8
9
|
from snowflake.ml.fileset import snowfs
|
9
10
|
|
10
11
|
_TARGET_FILE_SIZE = 32 * 2**20 # The max file size for data loading.
|
11
12
|
|
12
13
|
|
13
|
-
def
|
14
|
-
session: snowpark.Session, df_info: data_source.DataFrameInfo
|
15
|
-
) -> List[result_batch.ResultBatch]:
|
16
|
-
"""Retrieve the ResultBatches for a given query"""
|
14
|
+
def _get_dataframe_cursor(session: snowpark.Session, df_info: data_source.DataFrameInfo) -> sf_cursor.SnowflakeCursor:
|
17
15
|
cursor = session._conn._cursor
|
18
16
|
|
19
17
|
if df_info.query_id:
|
@@ -29,12 +27,24 @@ def get_dataframe_result_batches(
|
|
29
27
|
if cursor._prefetch_hook is None:
|
30
28
|
raise RuntimeError("Loading data from result query failed unexpectedly. Please contact Snowflake support.")
|
31
29
|
cursor._prefetch_hook()
|
30
|
+
|
31
|
+
return cursor
|
32
|
+
|
33
|
+
|
34
|
+
def get_dataframe_result_batches(
|
35
|
+
session: snowpark.Session, df_info: data_source.DataFrameInfo
|
36
|
+
) -> List[result_batch.ResultBatch]:
|
37
|
+
"""Retrieve the ResultBatches for a given query"""
|
38
|
+
cursor = _get_dataframe_cursor(session, df_info)
|
32
39
|
batches = cursor.get_result_batches()
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
40
|
+
return batches or []
|
41
|
+
|
42
|
+
|
43
|
+
def get_dataframe_arrow_table(session: snowpark.Session, df_info: data_source.DataFrameInfo) -> pa.Table:
|
44
|
+
"""Retrieve the full in-memory result for a given query"""
|
45
|
+
cursor = _get_dataframe_cursor(session, df_info)
|
46
|
+
table = cursor.fetch_arrow_all() # type: ignore[call-overload]
|
47
|
+
return table
|
38
48
|
|
39
49
|
|
40
50
|
def get_dataset_filesystem(
|
@@ -0,0 +1,68 @@
|
|
1
|
+
from typing import Any, Dict, Iterator, List, Union
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import numpy.typing as npt
|
5
|
+
import torch.utils.data
|
6
|
+
|
7
|
+
from snowflake.ml.data import data_ingestor
|
8
|
+
|
9
|
+
|
10
|
+
class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
|
11
|
+
"""Wrap a DataIngestor into a PyTorch IterableDataset"""
|
12
|
+
|
13
|
+
def __init__(
|
14
|
+
self,
|
15
|
+
ingestor: data_ingestor.DataIngestor,
|
16
|
+
*,
|
17
|
+
batch_size: int,
|
18
|
+
shuffle: bool = False,
|
19
|
+
drop_last: bool = False,
|
20
|
+
squeeze_outputs: bool = True
|
21
|
+
) -> None:
|
22
|
+
"""Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
|
23
|
+
self._ingestor = ingestor
|
24
|
+
self._batch_size = batch_size
|
25
|
+
self._shuffle = shuffle
|
26
|
+
self._drop_last = drop_last
|
27
|
+
self._squeeze_outputs = squeeze_outputs
|
28
|
+
|
29
|
+
def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
|
30
|
+
max_idx = 0
|
31
|
+
filter_idx = 0
|
32
|
+
worker_info = torch.utils.data.get_worker_info()
|
33
|
+
if worker_info is not None:
|
34
|
+
max_idx = worker_info.num_workers - 1
|
35
|
+
filter_idx = worker_info.id
|
36
|
+
|
37
|
+
if self._shuffle and worker_info is not None:
|
38
|
+
raise RuntimeError("Dataset shuffling not currently supported with multithreading")
|
39
|
+
|
40
|
+
counter = 0
|
41
|
+
for batch in self._ingestor.to_batches(
|
42
|
+
batch_size=self._batch_size, shuffle=self._shuffle, drop_last_batch=self._drop_last
|
43
|
+
):
|
44
|
+
# Skip indices during multi-process data loading to prevent data duplication
|
45
|
+
if counter == filter_idx:
|
46
|
+
# Basic preprocessing on batch values: squeeze away extra dimensions
|
47
|
+
# and convert object arrays (e.g. strings) to lists
|
48
|
+
if self._squeeze_outputs:
|
49
|
+
yield {
|
50
|
+
k: (v.squeeze().tolist() if v.dtype == np.object_ else v.squeeze()) for k, v in batch.items()
|
51
|
+
}
|
52
|
+
else:
|
53
|
+
yield batch # type: ignore[misc]
|
54
|
+
|
55
|
+
if counter < max_idx:
|
56
|
+
counter += 1
|
57
|
+
else:
|
58
|
+
counter = 0
|
59
|
+
|
60
|
+
|
61
|
+
class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Dict[str, Any]]):
|
62
|
+
"""Wrap a DataIngestor into a PyTorch IterDataPipe"""
|
63
|
+
|
64
|
+
def __init__(
|
65
|
+
self, ingestor: data_ingestor.DataIngestor, *, batch_size: int, shuffle: bool = False, drop_last: bool = False
|
66
|
+
) -> None:
|
67
|
+
"""Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead"""
|
68
|
+
super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, squeeze_outputs=False)
|
snowflake/ml/dataset/dataset.py
CHANGED
@@ -472,9 +472,7 @@ lineage_node.DOMAIN_LINEAGE_REGISTRY["dataset"] = Dataset
|
|
472
472
|
|
473
473
|
def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) -> Tuple[str, str, str]:
|
474
474
|
"""Resolve a dataset name into a validated schema-level location identifier"""
|
475
|
-
db, schema, object_name
|
476
|
-
if others:
|
477
|
-
raise ValueError(f"Invalid identifier: unexpected '{others}'")
|
475
|
+
db, schema, object_name = identifier.parse_schema_level_object_identifier(dataset_name)
|
478
476
|
db = db or session.get_current_database()
|
479
477
|
schema = schema or session.get_current_schema()
|
480
478
|
return str(db), str(schema), str(object_name)
|
@@ -30,6 +30,7 @@ class _Privilege:
|
|
30
30
|
object_name: str
|
31
31
|
privileges: List[str]
|
32
32
|
scope: Optional[str] = None
|
33
|
+
optional: bool = False
|
33
34
|
|
34
35
|
|
35
36
|
@dataclass(frozen=True)
|
@@ -72,8 +73,7 @@ _PRE_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
|
|
72
73
|
_Privilege("VIEW", _ALL_OBJECTS, ["SELECT", "REFERENCES"], "SCHEMA {database}.{schema}"),
|
73
74
|
_Privilege("TABLE", _ALL_OBJECTS, ["SELECT", "REFERENCES"], "SCHEMA {database}.{schema}"),
|
74
75
|
_Privilege("DATASET", _ALL_OBJECTS, ["USAGE"], "SCHEMA {database}.{schema}"),
|
75
|
-
|
76
|
-
# _Privilege("WAREHOUSE", "{warehouse}", ["USAGE"]),
|
76
|
+
_Privilege("WAREHOUSE", "{warehouse}", ["USAGE"], optional=True),
|
77
77
|
],
|
78
78
|
_FeatureStoreRole.NONE: [],
|
79
79
|
}
|
@@ -109,7 +109,7 @@ def _grant_privileges(
|
|
109
109
|
query += f" TO ROLE {role_name}"
|
110
110
|
session.sql(query).collect()
|
111
111
|
except exceptions.SnowparkSQLException as e:
|
112
|
-
if any(
|
112
|
+
if p.optional or any(
|
113
113
|
s in e.message
|
114
114
|
for s in (
|
115
115
|
"Ask your account admin",
|
@@ -122,6 +122,14 @@ _DT_OR_VIEW_QUERY_PATTERN = re.compile(
|
|
122
122
|
flags=re.DOTALL | re.IGNORECASE | re.X,
|
123
123
|
)
|
124
124
|
|
125
|
+
_DT_INITIALIZE_PATTERN = re.compile(
|
126
|
+
r"""CREATE\ DYNAMIC\ TABLE\ .*
|
127
|
+
initialize\ =\ '(?P<initialize>.*)'\ .*?
|
128
|
+
AS\ .*
|
129
|
+
""",
|
130
|
+
flags=re.DOTALL | re.IGNORECASE | re.X,
|
131
|
+
)
|
132
|
+
|
125
133
|
_LIST_FEATURE_VIEW_SCHEMA = StructType(
|
126
134
|
[
|
127
135
|
StructField("name", StringType()),
|
@@ -565,11 +573,15 @@ class FeatureStore:
|
|
565
573
|
tagging_clause_str = ",\n".join(tagging_clause)
|
566
574
|
|
567
575
|
def create_col_desc(col: StructField) -> str:
|
568
|
-
desc = feature_view.feature_descs.get(SqlIdentifier(col.name), None)
|
576
|
+
desc = feature_view.feature_descs.get(SqlIdentifier(col.name), None) # type: ignore[union-attr]
|
569
577
|
desc = "" if desc is None else f"COMMENT '{desc}'"
|
570
578
|
return f"{col.name} {desc}"
|
571
579
|
|
572
|
-
column_descs =
|
580
|
+
column_descs = (
|
581
|
+
", ".join([f"{create_col_desc(col)}" for col in feature_view.output_schema.fields])
|
582
|
+
if feature_view.feature_descs is not None
|
583
|
+
else ""
|
584
|
+
)
|
573
585
|
|
574
586
|
if refresh_freq is not None:
|
575
587
|
schedule_task = refresh_freq != "DOWNSTREAM" and timeparse(refresh_freq) is None
|
@@ -604,7 +616,7 @@ class FeatureStore:
|
|
604
616
|
logger.info(f"Registered FeatureView {feature_view.name}/{version} successfully.")
|
605
617
|
return self.get_feature_view(feature_view.name, str(version))
|
606
618
|
|
607
|
-
@
|
619
|
+
@overload
|
608
620
|
def update_feature_view(
|
609
621
|
self,
|
610
622
|
name: str,
|
@@ -613,13 +625,37 @@ class FeatureStore:
|
|
613
625
|
refresh_freq: Optional[str] = None,
|
614
626
|
warehouse: Optional[str] = None,
|
615
627
|
desc: Optional[str] = None,
|
628
|
+
) -> FeatureView:
|
629
|
+
...
|
630
|
+
|
631
|
+
@overload
|
632
|
+
def update_feature_view(
|
633
|
+
self,
|
634
|
+
name: FeatureView,
|
635
|
+
version: Optional[str] = None,
|
636
|
+
*,
|
637
|
+
refresh_freq: Optional[str] = None,
|
638
|
+
warehouse: Optional[str] = None,
|
639
|
+
desc: Optional[str] = None,
|
640
|
+
) -> FeatureView:
|
641
|
+
...
|
642
|
+
|
643
|
+
@dispatch_decorator() # type: ignore[misc]
|
644
|
+
def update_feature_view(
|
645
|
+
self,
|
646
|
+
name: Union[FeatureView, str],
|
647
|
+
version: Optional[str] = None,
|
648
|
+
*,
|
649
|
+
refresh_freq: Optional[str] = None,
|
650
|
+
warehouse: Optional[str] = None,
|
651
|
+
desc: Optional[str] = None,
|
616
652
|
) -> FeatureView:
|
617
653
|
"""Update a registered feature view.
|
618
654
|
Check feature_view.py for which fields are allowed to be updated after registration.
|
619
655
|
|
620
656
|
Args:
|
621
|
-
name:
|
622
|
-
version: version of
|
657
|
+
name: FeatureView object or name to update.
|
658
|
+
version: Optional version of feature view. Must be set when argument name is a str.
|
623
659
|
refresh_freq: updated refresh frequency.
|
624
660
|
warehouse: updated warehouse.
|
625
661
|
desc: description of feature view.
|
@@ -661,7 +697,7 @@ class FeatureStore:
|
|
661
697
|
SnowflakeMLException: [RuntimeError] If FeatureView is not managed and refresh_freq is defined.
|
662
698
|
SnowflakeMLException: [RuntimeError] Failed to update feature view.
|
663
699
|
"""
|
664
|
-
feature_view = self.
|
700
|
+
feature_view = self._validate_feature_view_name_and_version_input(name, version)
|
665
701
|
new_desc = desc if desc is not None else feature_view.desc
|
666
702
|
|
667
703
|
if feature_view.status == FeatureViewStatus.STATIC:
|
@@ -696,7 +732,7 @@ class FeatureStore:
|
|
696
732
|
f"Update feature view {feature_view.name}/{feature_view.version} failed: {e}"
|
697
733
|
),
|
698
734
|
) from e
|
699
|
-
return self.get_feature_view(name=name, version=version)
|
735
|
+
return self.get_feature_view(name=feature_view.name, version=str(feature_view.version))
|
700
736
|
|
701
737
|
@overload
|
702
738
|
def read_feature_view(self, feature_view: str, version: str) -> DataFrame:
|
@@ -1795,6 +1831,7 @@ class FeatureStore:
|
|
1795
1831
|
)
|
1796
1832
|
WAREHOUSE = {warehouse}
|
1797
1833
|
REFRESH_MODE = {feature_view.refresh_mode}
|
1834
|
+
INITIALIZE = {feature_view.initialize}
|
1798
1835
|
AS {feature_view.query}
|
1799
1836
|
"""
|
1800
1837
|
self._session.sql(query).collect(block=block, statement_params=self._telemetry_stmp)
|
@@ -2121,7 +2158,7 @@ class FeatureStore:
|
|
2121
2158
|
if "." not in name:
|
2122
2159
|
return f"{self._config.full_schema_path}.{name}"
|
2123
2160
|
|
2124
|
-
db_name, schema_name, object_name
|
2161
|
+
db_name, schema_name, object_name = identifier.parse_schema_level_object_identifier(name)
|
2125
2162
|
return "{}.{}.{}".format(
|
2126
2163
|
db_name or self._config.database,
|
2127
2164
|
schema_name or self._config.schema,
|
@@ -2186,11 +2223,7 @@ class FeatureStore:
|
|
2186
2223
|
if len(fv_maps.keys()) == 0:
|
2187
2224
|
return self._session.create_dataframe([], schema=_LIST_FEATURE_VIEW_SCHEMA)
|
2188
2225
|
|
2189
|
-
filters = (
|
2190
|
-
[lambda d: d["entityName"].startswith(feature_view_name.resolved())] # type: ignore[union-attr]
|
2191
|
-
if feature_view_name
|
2192
|
-
else None
|
2193
|
-
)
|
2226
|
+
filters = [lambda d: d["entityName"].startswith(feature_view_name.resolved())] if feature_view_name else None
|
2194
2227
|
res = self._lookup_tagged_objects(self._get_entity_name(entity_name), filters)
|
2195
2228
|
|
2196
2229
|
output_values: List[List[Any]] = []
|
@@ -2273,6 +2306,8 @@ class FeatureStore:
|
|
2273
2306
|
entities = [find_and_compose_entity(n) for n in fv_metadata.entities]
|
2274
2307
|
ts_col = fv_metadata.timestamp_col
|
2275
2308
|
timestamp_col = ts_col if ts_col not in _LEGACY_TIMESTAMP_COL_PLACEHOLDER_VALS else None
|
2309
|
+
re_initialize = re.match(_DT_INITIALIZE_PATTERN, row["text"])
|
2310
|
+
initialize = re_initialize.group("initialize") if re_initialize is not None else "ON_CREATE"
|
2276
2311
|
|
2277
2312
|
fv = FeatureView._construct_feature_view(
|
2278
2313
|
name=name,
|
@@ -2281,18 +2316,23 @@ class FeatureStore:
|
|
2281
2316
|
timestamp_col=timestamp_col,
|
2282
2317
|
desc=desc,
|
2283
2318
|
version=version,
|
2284
|
-
status=
|
2285
|
-
|
2286
|
-
|
2319
|
+
status=(
|
2320
|
+
FeatureViewStatus(row["scheduling_state"])
|
2321
|
+
if len(row["scheduling_state"]) > 0
|
2322
|
+
else FeatureViewStatus.MASKED
|
2323
|
+
),
|
2287
2324
|
feature_descs=self._fetch_column_descs("DYNAMIC TABLE", fv_name),
|
2288
2325
|
refresh_freq=row["target_lag"],
|
2289
2326
|
database=self._config.database.identifier(),
|
2290
2327
|
schema=self._config.schema.identifier(),
|
2291
|
-
warehouse=
|
2292
|
-
|
2293
|
-
|
2328
|
+
warehouse=(
|
2329
|
+
SqlIdentifier(row["warehouse"], case_sensitive=True).identifier()
|
2330
|
+
if len(row["warehouse"]) > 0
|
2331
|
+
else None
|
2332
|
+
),
|
2294
2333
|
refresh_mode=row["refresh_mode"],
|
2295
2334
|
refresh_mode_reason=row["refresh_mode_reason"],
|
2335
|
+
initialize=initialize,
|
2296
2336
|
owner=row["owner"],
|
2297
2337
|
infer_schema_df=infer_schema_df,
|
2298
2338
|
session=self._session,
|
@@ -2319,6 +2359,7 @@ class FeatureStore:
|
|
2319
2359
|
warehouse=None,
|
2320
2360
|
refresh_mode=None,
|
2321
2361
|
refresh_mode_reason=None,
|
2362
|
+
initialize="ON_CREATE",
|
2322
2363
|
owner=row["owner"],
|
2323
2364
|
infer_schema_df=infer_schema_df,
|
2324
2365
|
session=self._session,
|