snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +156 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +88 -9
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +100 -41
- snowflake/ml/feature_store/feature_view.py +149 -5
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +186 -20
- snowflake/ml/model/_client/ops/model_ops.py +144 -30
- snowflake/ml/model/_client/ops/service_ops.py +312 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +196 -0
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +5 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
- snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
- snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
- snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
- snowflake/ml/model/_packager/model_packager.py +4 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +10 -4
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +20 -2
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/table_manager.py
CHANGED
@@ -1,7 +1,8 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 from snowflake import snowpark
-from snowflake.ml._internal.utils import formatting, query_result_checker
+from snowflake.ml._internal.utils import formatting, identifier, query_result_checker
+from snowflake.snowpark import types
 
 """Table_manager is a set of utils that helps create tables.
 
@@ -104,3 +105,20 @@ def get_table_schema(session: snowpark.Session, table_name: str, qualified_schem
     for row in result:
         schema_dict[row["name"]] = row["type"]
     return schema_dict
+
+
+def get_table_schema_types(
+    session: snowpark.Session,
+    database: str,
+    schema: str,
+    table_name: str,
+) -> Dict[str, types.DataType]:
+    fully_qualified_table_name = identifier.get_schema_level_object_identifier(
+        db=database, schema=schema, object_name=table_name
+    )
+    struct_fields: List[types.StructField] = session.table(fully_qualified_table_name).schema.fields
+
+    schema_dict: Dict[str, types.DataType] = {}
+    for field in struct_fields:
+        schema_dict[field.name] = field.datatype
+    return schema_dict
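The new `get_table_schema_types()` helper returns Snowpark `DataType` objects instead of the raw type strings produced by `get_table_schema()`, so callers can branch on type classes. A minimal usage sketch, assuming an existing `session`; the database/schema/table names are illustrative, not from the diff:

```python
from snowflake.ml._internal.utils import table_manager
from snowflake.snowpark import types

# `session` is assumed to be an existing snowpark.Session
schema_types = table_manager.get_table_schema_types(
    session, database="ML_DB", schema="PUBLIC", table_name="MY_TABLE"
)
for col, dtype in schema_types.items():
    # Each value is a snowflake.snowpark.types.DataType, e.g. DecimalType(38, 0)
    print(col, isinstance(dtype, types.DecimalType))
```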
snowflake/ml/_internal/utils/uri.py
CHANGED
@@ -53,7 +53,7 @@ def get_uri_scheme(uri: str) -> str:
 def get_uri_from_snowflake_stage_path(stage_path: str) -> str:
     """Generates a URI from Snowflake stage path."""
     assert stage_path.startswith("@")
-    (db, schema, stage, path) = identifier.
+    (db, schema, stage, path) = identifier.parse_snowflake_stage_path(
         posixpath.normpath(identifier.remove_prefix(stage_path, "@"))
     )
     return urlunparse(
@@ -70,7 +70,7 @@ def get_uri_from_snowflake_stage_path(stage_path: str) -> str:
 
 def get_stage_and_path(stage_path: str) -> Tuple[str, str]:
     assert stage_path.startswith("@"), f"stage path should start with @, actual: {stage_path}"
-    (db, schema, stage, path) = identifier.
+    (db, schema, stage, path) = identifier.parse_snowflake_stage_path(
         posixpath.normpath(identifier.remove_prefix(stage_path, "@"))
     )
     full_qualified_stage = "@" + identifier.get_schema_level_object_identifier(db, schema, stage)
snowflake/ml/data/_internal/arrow_ingestor.py
CHANGED
@@ -2,17 +2,17 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque, Dict, Iterator, List, Optional
+from typing import Any, Deque, Dict, Iterator, List, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
 import pandas as pd
 import pyarrow as pa
-import pyarrow.dataset as
+import pyarrow.dataset as pds
 
 from snowflake import snowpark
-from snowflake.
-from snowflake.ml.data
+from snowflake.connector import result_batch
+from snowflake.ml.data import data_ingestor, data_source, ingestor_utils
 
 _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
 
@@ -67,6 +67,10 @@ class ArrowIngestor(data_ingestor.DataIngestor):
 
         self._schema: Optional[pa.Schema] = None
 
+    @classmethod
+    def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "ArrowIngestor":
+        return cls(session, sources)
+
     @property
     def data_sources(self) -> List[data_source.DataSource]:
         return self._data_sources
@@ -115,9 +119,9 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         table = ds.to_table() if limit is None else ds.head(num_rows=limit)
         return table.to_pandas()
 
-    def _get_dataset(self, shuffle: bool) ->
+    def _get_dataset(self, shuffle: bool) -> pds.Dataset:
         format = self._format
-        sources = []
+        sources: List[Any] = []
         source_format = None
         for source in self._data_sources:
             if isinstance(source, str):
@@ -137,8 +141,16 @@ class ArrowIngestor(data_ingestor.DataIngestor):
                 # in-memory (first batch) and file URLs (subsequent batches) and creating a
                 # union dataset.
                 result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source)
-                sources.extend(
-
+                sources.extend(
+                    b.to_arrow(self._session.connection)
+                    if isinstance(b, result_batch.ArrowResultBatch)
+                    else b.to_arrow()
+                    for b in result_batches
+                )
+                # HACK: Mitigate typing inconsistencies in Snowpark results
+                if len(sources) > 0:
+                    sources = [_cast_if_needed(s, sources[-1].schema) for s in sources]
+                source_format = None  # Arrow Dataset expects "None" for in-memory datasets
             else:
                 raise RuntimeError(f"Unsupported data source type: {type(source)}")
 
@@ -150,7 +162,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         # Re-shuffle input files on each iteration start
         if shuffle:
             np.random.shuffle(sources)
-        pa_dataset:
+        pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
         return pa_dataset
 
     def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
@@ -201,7 +213,7 @@ def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
 
 
 def _retryable_batches(
-    dataset:
+    dataset: pds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
 ) -> Iterator[pa.RecordBatch]:
     """Make the Dataset to_batches retryable."""
     retries = 0
@@ -226,3 +238,47 @@ def _retryable_batches(
                 time.sleep(delay)
             else:
                 raise e
+
+
+def _cast_if_needed(
+    batch: Union[pa.Table, pa.RecordBatch], schema: Optional[pa.Schema] = None
+) -> Union[pa.Table, pa.RecordBatch]:
+    """
+    Cast the batch to be compatible with downstream frameworks. Returns original batch if cast is not necessary.
+    Besides casting types to match `schema` (if provided), this function also applies the following casting:
+    - Decimal (fixed-point) types: Convert to float or integer types based on scale and byte length
+
+    Args:
+        batch: The PyArrow batch to cast if needed
+        schema: Optional schema the batch should be casted to match. Note that compatibility type casting takes
+            precedence over the provided schema, e.g. if the schema has decimal types the result will be further
+            cast into integer/float types.
+
+    Returns:
+        The type-casted PyArrow batch, or the original batch if casting was not necessary
+    """
+    schema = schema or batch.schema
+    assert len(batch.schema) == len(schema)
+    fields = []
+    cast_needed = False
+    for field, target in zip(batch.schema, schema):
+        # Need to convert decimal types to supported types. This behavior supersedes target schema data types
+        if pa.types.is_decimal(target.type):
+            byte_length = int(target.metadata.get(b"byteLength", 8))
+            if int(target.metadata.get(b"scale", 0)) > 0:
+                target = target.with_type(pa.float32() if byte_length == 4 else pa.float64())
+            else:
+                if byte_length == 2:
+                    target = target.with_type(pa.int16())
+                elif byte_length == 4:
+                    target = target.with_type(pa.int32())
+                else:  # Cap out at 64-bit
+                    target = target.with_type(pa.int64())
+        if not field.equals(target):
+            cast_needed = True
+            field = target
+        fields.append(field)
+
+    if cast_needed:
+        return batch.cast(pa.schema(fields))
+    return batch
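The new `_cast_if_needed()` helper normalizes Snowflake DECIMAL columns into framework-friendly Arrow types, driven by the `byteLength`/`scale` field metadata that Snowflake attaches to Arrow result batches. A minimal standalone sketch of just that mapping rule; the metadata values are assumptions for illustration, and the real helper also reconciles against a target schema:

```python
from decimal import Decimal

import pyarrow as pa

# A DECIMAL column the way Snowflake result batches deliver it, with
# "byteLength"/"scale" field metadata (values assumed for illustration)
tbl = pa.table({"ID": pa.array([Decimal(1), Decimal(2)], type=pa.decimal128(38, 0))})
field = tbl.schema.field("ID").with_metadata({b"byteLength": b"4", b"scale": b"0"})

byte_length = int(field.metadata[b"byteLength"])
scale = int(field.metadata[b"scale"])
# Rule from the diff: scale > 0 -> float32/float64; scale == 0 -> int16/int32/int64
target_type = (
    (pa.float32() if byte_length == 4 else pa.float64())
    if scale > 0
    else {2: pa.int16(), 4: pa.int32()}.get(byte_length, pa.int64())
)
print(tbl.cast(pa.schema([pa.field("ID", target_type)])).schema)  # ID: int32
```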
snowflake/ml/data/data_connector.py
CHANGED
@@ -1,11 +1,17 @@
+import os
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
 
 import numpy.typing as npt
+from typing_extensions import deprecated
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.data import data_ingestor, data_source
-from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
+from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
+from snowflake.ml.modeling._internal.constants import (
+    IN_ML_RUNTIME_ENV_VAR,
+    USE_OPTIMIZED_DATA_INGESTOR,
+)
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -24,6 +30,8 @@ DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
 class DataConnector:
     """Snowflake data reader which provides application integration connectors"""
 
+    DEFAULT_INGESTOR_CLASS: Type[data_ingestor.DataIngestor] = ArrowIngestor
+
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
@@ -31,22 +39,48 @@ class DataConnector:
         self._ingestor = ingestor
 
     @classmethod
-
+    @snowpark._internal.utils.private_preview(version="1.6.0")
+    def from_dataframe(
+        cls: Type[DataConnectorType],
+        df: snowpark.DataFrame,
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
         source = data_source.DataFrameInfo(df.queries["queries"][0])
         assert df._session is not None
-
-        return cls(ingestor, **kwargs)
+        return cls.from_sources(df._session, [source], ingestor_class=ingestor_class, **kwargs)
 
     @classmethod
-    def from_dataset(
+    def from_dataset(
+        cls: Type[DataConnectorType],
+        ds: "dataset.Dataset",
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
         dsv = ds.selected_version
         assert dsv is not None
         source = data_source.DatasetInfo(
             ds.fully_qualified_name, dsv.name, dsv.url(), exclude_cols=(dsv.label_cols + dsv.exclude_cols)
         )
-
+        return cls.from_sources(ds._session, [source], ingestor_class=ingestor_class, **kwargs)
+
+    @classmethod
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda cls: cls.__name__,
+        func_params_to_log=["sources", "ingestor_class"],
+    )
+    def from_sources(
+        cls: Type[DataConnectorType],
+        session: snowpark.Session,
+        sources: List[data_source.DataSource],
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
+        ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
+        ingestor = ingestor_class.from_sources(session, sources)
         return cls(ingestor, **kwargs)
 
     @property
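Construction now funnels through a single path: `from_dataframe()` and `from_dataset()` both delegate to the new `from_sources()`, which resolves the ingestor class (`DEFAULT_INGESTOR_CLASS` unless `ingestor_class` is passed). A hedged usage sketch, assuming an existing `session` and a single-query Snowpark DataFrame `df`:

```python
from snowflake.ml.data import data_connector, data_source

# High level: one-query Snowpark DataFrame -> DataConnector
# (from_dataframe is marked private preview in this release)
dc = data_connector.DataConnector.from_dataframe(df)

# Equivalent lower-level form through the new from_sources() entry point
source = data_source.DataFrameInfo(df.queries["queries"][0])
dc2 = data_connector.DataConnector.from_sources(session, [source])
```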
@@ -87,6 +121,9 @@ class DataConnector:
 
         return tf.data.Dataset.from_generator(generator, output_signature=tf_signature)
 
+    @deprecated(
+        "to_torch_datapipe() is deprecated and will be removed in a future release. Use to_torch_dataset() instead"
+    )
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
         subproject_extractor=lambda self: type(self).__name__,
@@ -110,10 +147,40 @@ class DataConnector:
         Returns:
             A Pytorch iterable datapipe that yield data.
         """
-        from
+        from snowflake.ml.data import torch_utils
 
-        return
-            self._ingestor
+        return torch_utils.TorchDataPipeWrapper(
+            self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
+        )
+
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+    )
+    def to_torch_dataset(
+        self, *, batch_size: int = 1, shuffle: bool = False, drop_last_batch: bool = True
+    ) -> "torch_data.IterableDataset":  # type: ignore[type-arg]
+        """Transform the Snowflake data into a PyTorch Iterable Dataset to be used with a DataLoader.
+
+        Return a PyTorch Dataset which iterates on rows of data.
+
+        Args:
+            batch_size: It specifies the size of each data batch which will be yielded in the result dataset.
+                Batching is pushed down to data ingestion level which may be more performant than DataLoader
+                batching.
+            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                rows in each file will also be shuffled.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true,
+                then the last batch will get dropped if its size is smaller than the given batch_size.
+
+        Returns:
+            A PyTorch Iterable Dataset that yields data.
+        """
+        from snowflake.ml.data import torch_utils
+
+        return torch_utils.TorchDatasetWrapper(
+            self._ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last_batch
         )
 
     @telemetry.send_api_usage_telemetry(
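A sketch of consuming the new `to_torch_dataset()` with a `DataLoader`. Because batching is pushed down into the ingestor, a reasonable pattern (an assumption, not prescribed by the diff) is to create the `DataLoader` with `batch_size=None` so it does not re-batch the already-batched dicts; `dc` is a `DataConnector` as in the sketch above:

```python
import torch.utils.data as torch_data

ds = dc.to_torch_dataset(batch_size=32, shuffle=True, drop_last_batch=True)
loader = torch_data.DataLoader(ds, batch_size=None)  # batching already done upstream
for batch in loader:
    # `batch` is a dict mapping column names to 32 values each
    ...
```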
@@ -131,3 +198,15 @@ class DataConnector:
             A Pandas DataFrame.
         """
         return self._ingestor.to_pandas(limit)
+
+
+# Switch to use Runtime's Data Ingester if running in ML runtime
+# Fail silently if the data ingester is not found
+if os.getenv(IN_ML_RUNTIME_ENV_VAR) and os.getenv(USE_OPTIMIZED_DATA_INGESTOR):
+    try:
+        from runtime_external_entities import get_ingester_class
+
+        DataConnector.DEFAULT_INGESTOR_CLASS = get_ingester_class()
+    except ImportError:
+        """Runtime Default Ingester not found, ignore"""
+        pass
snowflake/ml/data/data_ingestor.py
CHANGED
@@ -1,7 +1,18 @@
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Protocol,
+    Type,
+    TypeVar,
+)
 
 from numpy import typing as npt
 
+from snowflake import snowpark
 from snowflake.ml.data import data_source
 
 if TYPE_CHECKING:
@@ -12,6 +23,12 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 
 
 class DataIngestor(Protocol):
+    @classmethod
+    def from_sources(
+        cls: Type[DataIngestorType], session: snowpark.Session, sources: List[data_source.DataSource]
+    ) -> DataIngestorType:
+        raise NotImplementedError
+
     @property
     def data_sources(self) -> List[data_source.DataSource]:
         raise NotImplementedError
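With `from_sources()` added to the `DataIngestor` protocol, a custom ingestor becomes pluggable via `DataConnector.from_sources(..., ingestor_class=...)` by structural typing alone. A hedged skeleton; the `to_batches()`/`to_pandas()` signatures below are inferred from how the rest of this diff consumes ingestors, not copied from the protocol definition:

```python
from typing import Any, Dict, Iterator, List, Optional

import numpy.typing as npt
import pandas as pd

from snowflake import snowpark
from snowflake.ml.data import data_source


class MyIngestor:
    """Structurally satisfies the DataIngestor Protocol."""

    def __init__(self, session: snowpark.Session, sources: List[data_source.DataSource]) -> None:
        self._session = session
        self._sources = sources

    @classmethod
    def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "MyIngestor":
        return cls(session, sources)

    @property
    def data_sources(self) -> List[data_source.DataSource]:
        return self._sources

    def to_batches(
        self, batch_size: int, shuffle: bool = True, drop_last_batch: bool = True
    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
        raise NotImplementedError  # yield column-name -> ndarray batches here

    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
        raise NotImplementedError
```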
snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py}
RENAMED
@@ -13,6 +13,7 @@ _TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
 def get_dataframe_result_batches(
     session: snowpark.Session, df_info: data_source.DataFrameInfo
 ) -> List[result_batch.ResultBatch]:
+    """Retrieve the ResultBatches for a given query"""
     cursor = session._conn._cursor
 
     if df_info.query_id:
@@ -39,6 +40,7 @@ def get_dataframe_result_batches(
 def get_dataset_filesystem(
     session: snowpark.Session, ds_info: Optional[data_source.DatasetInfo] = None
 ) -> fsspec.AbstractFileSystem:
+    """Get the fsspec filesystem for a given Dataset"""
     # We can't directly load the Dataset to avoid a circular dependency
     # Dataset -> DatasetReader -> DataConnector -> DataIngestor -> (?) ingestor_utils -> Dataset
     # TODO: Automatically pick appropriate fsspec implementation based on protocol in URL
@@ -52,7 +54,9 @@ def get_dataset_filesystem(
 def get_dataset_files(
     session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
 ) -> List[str]:
+    """Get the list of files in a given Dataset"""
     if filesystem is None:
         filesystem = get_dataset_filesystem(session, ds_info)
     assert bool(ds_info.url)  # Not null or empty
-
+    files = sorted(filesystem.ls(ds_info.url))
+    return [filesystem.unstrip_protocol(f) for f in files]
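`get_dataset_files()` now wraps each `ls()` result with fsspec's `unstrip_protocol()`, which re-attaches the filesystem's protocol prefix so downstream consumers receive fully qualified URLs. A small sketch with fsspec's in-memory filesystem (standing in for the Snowflake one) showing the effect:

```python
import fsspec

fs = fsspec.filesystem("memory")
with fs.open("/ds/part-000.parquet", "wb") as f:
    f.write(b"stub")  # placeholder bytes, just to create the file

bare = sorted(fs.ls("/ds", detail=False))      # e.g. ['/ds/part-000.parquet']
full = [fs.unstrip_protocol(p) for p in bare]  # e.g. ['memory:///ds/part-000.parquet']
print(full)
```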
snowflake/ml/data/torch_utils.py
ADDED
@@ -0,0 +1,68 @@
+from typing import Any, Dict, Iterator, List, Union
+
+import numpy as np
+import numpy.typing as npt
+import torch.utils.data
+
+from snowflake.ml.data import data_ingestor
+
+
+class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
+    """Wrap a DataIngestor into a PyTorch IterableDataset"""
+
+    def __init__(
+        self,
+        ingestor: data_ingestor.DataIngestor,
+        *,
+        batch_size: int,
+        shuffle: bool = False,
+        drop_last: bool = False,
+        squeeze_outputs: bool = True
+    ) -> None:
+        """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
+        self._ingestor = ingestor
+        self._batch_size = batch_size
+        self._shuffle = shuffle
+        self._drop_last = drop_last
+        self._squeeze_outputs = squeeze_outputs
+
+    def __iter__(self) -> Iterator[Dict[str, Union[npt.NDArray[Any], List[Any]]]]:
+        max_idx = 0
+        filter_idx = 0
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            max_idx = worker_info.num_workers - 1
+            filter_idx = worker_info.id
+
+        if self._shuffle and worker_info is not None:
+            raise RuntimeError("Dataset shuffling not currently supported with multithreading")
+
+        counter = 0
+        for batch in self._ingestor.to_batches(
+            batch_size=self._batch_size, shuffle=self._shuffle, drop_last_batch=self._drop_last
+        ):
+            # Skip indices during multi-process data loading to prevent data duplication
+            if counter == filter_idx:
+                # Basic preprocessing on batch values: squeeze away extra dimensions
+                # and convert object arrays (e.g. strings) to lists
+                if self._squeeze_outputs:
+                    yield {
+                        k: (v.squeeze().tolist() if v.dtype == np.object_ else v.squeeze()) for k, v in batch.items()
+                    }
+                else:
+                    yield batch  # type: ignore[misc]
+
+            if counter < max_idx:
+                counter += 1
+            else:
+                counter = 0
+
+
+class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Dict[str, Any]]):
+    """Wrap a DataIngestor into a PyTorch IterDataPipe"""
+
+    def __init__(
+        self, ingestor: data_ingestor.DataIngestor, *, batch_size: int, shuffle: bool = False, drop_last: bool = False
+    ) -> None:
+        """Not intended for direct usage. Use DataConnector.to_torch_datapipe() instead"""
+        super().__init__(ingestor, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, squeeze_outputs=False)
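Under multi-worker loading, `TorchDatasetWrapper.__iter__` assigns batches to workers round-robin: worker `i` keeps every `num_workers`-th batch, which is what prevents duplicate rows across `DataLoader` workers. The counter/`filter_idx` bookkeeping above is equivalent to this sketch:

```python
from typing import Iterable, Iterator


def shard(batches: Iterable[int], worker_id: int, num_workers: int) -> Iterator[int]:
    # Same effect as the counter/filter_idx loop in TorchDatasetWrapper
    for i, batch in enumerate(batches):
        if i % num_workers == worker_id:
            yield batch


print(list(shard(range(10), worker_id=1, num_workers=3)))  # [1, 4, 7]
```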
snowflake/ml/dataset/dataset.py
CHANGED
@@ -472,9 +472,7 @@ lineage_node.DOMAIN_LINEAGE_REGISTRY["dataset"] = Dataset
 
 def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) -> Tuple[str, str, str]:
     """Resolve a dataset name into a validated schema-level location identifier"""
-    db, schema, object_name
-    if others:
-        raise ValueError(f"Invalid identifier: unexpected '{others}'")
+    db, schema, object_name = identifier.parse_schema_level_object_identifier(dataset_name)
     db = db or session.get_current_database()
     schema = schema or session.get_current_schema()
     return str(db), str(schema), str(object_name)
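The simplification above relies on the updated `identifier.parse_schema_level_object_identifier()`, which now returns the `(db, schema, object)` triple directly instead of a 4-tuple whose trailing `others` component the caller had to reject. A hedged sketch (the identifier value is illustrative):

```python
from snowflake.ml._internal.utils import identifier

db, schema, name = identifier.parse_schema_level_object_identifier("MY_DB.PUBLIC.MY_DATASET")
print(db, schema, name)  # MY_DB PUBLIC MY_DATASET
```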
snowflake/ml/dataset/dataset_metadata.py
CHANGED
@@ -15,11 +15,13 @@ class FeatureStoreMetadata:
     Properties:
         spine_query: The input query on source table which will be joined with features.
         serialized_feature_views: A list of serialized feature objects in the feature store.
+        compact_feature_views: A compact representation of a FeatureView or FeatureViewSlice.
         spine_timestamp_col: Timestamp column which was used for point-in-time correct feature lookup.
     """
 
     spine_query: str
-    serialized_feature_views: List[str]
+    serialized_feature_views: Optional[List[str]] = None
+    compact_feature_views: Optional[List[str]] = None
     spine_timestamp_col: Optional[str] = None
 
     def to_json(self) -> str:
snowflake/ml/dataset/dataset_reader.py
CHANGED
@@ -1,10 +1,9 @@
-from typing import List, Optional
+from typing import Any, List, Optional, Type
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
-from snowflake.ml.data import data_connector, data_ingestor, data_source
-from snowflake.ml.data._internal import ingestor_utils
+from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs
 
 _PROJECT = "Dataset"
@@ -27,6 +26,13 @@ class DatasetReader(data_connector.DataConnector):
         self._fs: snowfs.SnowFileSystem = ingestor_utils.get_dataset_filesystem(self._session)
         self._files: Optional[List[str]] = None
 
+    @classmethod
+    def from_dataframe(
+        cls, df: snowpark.DataFrame, ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None, **kwargs: Any
+    ) -> "DatasetReader":
+        # Block superclass constructor from Snowpark DataFrames
+        raise RuntimeError("Creating DatasetReader from DataFrames not supported")
+
     def _list_files(self) -> List[str]:
         """Private helper function that lists all files in this DatasetVersion and caches the results."""
         if self._files:
snowflake/ml/feature_store/examples/airline_features/entities.py
ADDED
@@ -0,0 +1,16 @@
+from typing import List
+
+from snowflake.ml.feature_store import Entity
+
+zipcode_entity = Entity(
+    name="AIRPORT_ZIP_CODE",
+    join_keys=["AIRPORT_ZIP_CODE"],
+    desc="Zip code of the airport.",
+)
+
+plane_entity = Entity(name="PLANE_MODEL", join_keys=["PLANE_MODEL"], desc="The model of an airplane.")
+
+
+# This will be invoked by example_helper.py. Do not change function name.
+def get_all_entities() -> List[Entity]:
+    return [zipcode_entity, plane_entity]
snowflake/ml/feature_store/examples/airline_features/features/plane_features.py
ADDED
@@ -0,0 +1,31 @@
+from typing import List
+
+from snowflake.ml.feature_store import FeatureView
+from snowflake.ml.feature_store.examples.airline_features.entities import plane_entity
+from snowflake.snowpark import DataFrame, Session
+
+
+# This function will be invoked by example_helper.py. Do not change the name.
+def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+    """Create a feature view about airplane model."""
+    query = session.sql(
+        """
+        select
+            PLANE_MODEL,
+            SEATING_CAPACITY
+        from
+            PLANE_MODEL_ATTRIBUTES
+        """
+    )
+
+    return FeatureView(
+        name="f_plane",  # name of feature view
+        entities=[plane_entity],  # entities
+        feature_df=query,  # definition query
+        refresh_freq=None,  # refresh frequency
+        desc="Plane features never refresh.",
+    ).attach_feature_desc(
+        {
+            "SEATING_CAPACITY": "The seating capacity of a plane.",
+        }
+    )
snowflake/ml/feature_store/examples/airline_features/features/weather_features.py
ADDED
@@ -0,0 +1,42 @@
+from typing import List
+
+from snowflake.ml.feature_store import FeatureView
+from snowflake.ml.feature_store.examples.airline_features.entities import zipcode_entity
+from snowflake.snowpark import DataFrame, Session
+
+
+# This function will be invoked by example_helper.py. Do not change the name.
+def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], source_tables: List[str]) -> FeatureView:
+    """Create a feature view about airport weather."""
+    query = session.sql(
+        """
+        select
+            DATETIME_UTC AS TS,
+            AIRPORT_ZIP_CODE,
+            sum(RAIN_MM_H) over (
+                partition by AIRPORT_ZIP_CODE
+                order by DATETIME_UTC
+                range between interval '30 minutes' preceding and current row
+            ) RAIN_SUM_30M,
+            sum(RAIN_MM_H) over (
+                partition by AIRPORT_ZIP_CODE
+                order by DATETIME_UTC
+                range between interval '1 day' preceding and current row
+            ) RAIN_SUM_60M
+        from AIRPORT_WEATHER_STATION
+        """
+    )
+
+    return FeatureView(
+        name="f_weather",  # name of feature view
+        entities=[zipcode_entity],  # entities
+        feature_df=query,  # definition query
+        timestamp_col="TS",  # timestamp column
+        refresh_freq="1d",  # refresh frequency
+        desc="Airport weather features refreshed every day.",
+    ).attach_feature_desc(
+        {
+            "RAIN_SUM_30M": "The sum of rain fall over past 30 minutes for one zipcode.",
+            "RAIN_SUM_60M": "The sum of rain fall over past 1 day for one zipcode.",
+        }
+    )
snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py
CHANGED
@@ -14,18 +14,24 @@ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], sou
         f"""
         select
             end_station_id,
-            count(end_station_id) as
-            avg(end_station_latitude) as
-            avg(end_station_longitude) as
+            count(end_station_id) as f_count,
+            avg(end_station_latitude) as f_avg_latitude,
+            avg(end_station_longitude) as f_avg_longtitude
        from {source_tables[0]}
         group by end_station_id
         """
     )
 
     return FeatureView(
-        name="
+        name="f_station",  # name of feature view
         entities=[end_station_id],  # entities
         feature_df=query,  # definition query
         refresh_freq="1d",  # refresh frequency. '1d' means it refreshes everyday
         desc="Station features refreshed every day.",
+    ).attach_feature_desc(
+        {
+            "f_count": "How many times this station appears in 1 day.",
+            "f_avg_latitude": "Averaged latitude of a station.",
+            "f_avg_longtitude": "Averaged longtitude of a station.",
+        }
     )
snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py
CHANGED
@@ -21,4 +21,10 @@ def create_draft_feature_view(session: Session, source_dfs: List[DataFrame], sou
         feature_df=feature_df,  # definition query
         refresh_freq=None,  # refresh frequency. None indicates it never refresh
         desc="Static trip features",
+    ).attach_feature_desc(
+        {
+            "f_birth_year": "The birth year of a trip passenger.",
+            "f_gender": "The gender of a trip passenger.",
+            "f_bikeid": "The bike id of a trip passenger.",
+        }
     )