snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/table_manager.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Optional

 from snowflake import snowpark
 from snowflake.ml._internal.utils import formatting, identifier, query_result_checker
@@ -24,8 +24,8 @@ def create_single_table(
     database_name: str,
     schema_name: str,
     table_name: str,
-    table_schema:
-    statement_params: Optional[
+    table_schema: list[tuple[str, str]],
+    statement_params: Optional[dict[str, Any]] = None,
 ) -> str:
     """Creates a single table for registry and returns the fully qualified name of the table.

@@ -55,7 +55,7 @@ def create_single_table(
     return fully_qualified_table_name


-def insert_table_entry(session: snowpark.Session, table: str, columns:
+def insert_table_entry(session: snowpark.Session, table: str, columns: dict[str, Any]) -> list[snowpark.Row]:
    """Insert an entry into an internal Model Registry table.

    Args:
@@ -99,9 +99,9 @@ def validate_table_exist(session: snowpark.Session, table: str, qualified_schema
     return len(tables) == 1


-def get_table_schema(session: snowpark.Session, table_name: str, qualified_schema_name: str) ->
+def get_table_schema(session: snowpark.Session, table_name: str, qualified_schema_name: str) -> dict[str, str]:
     result = session.sql(f"DESC TABLE {qualified_schema_name}.{table_name}").collect()
-    schema_dict:
+    schema_dict: dict[str, str] = {}
     for row in result:
         schema_dict[row["name"]] = row["type"]
     return schema_dict
@@ -112,13 +112,13 @@ def get_table_schema_types(
     database: str,
     schema: str,
     table_name: str,
-) ->
+) -> dict[str, types.DataType]:
     fully_qualified_table_name = identifier.get_schema_level_object_identifier(
         db=database, schema=schema, object_name=table_name
     )
-    struct_fields:
+    struct_fields: list[types.StructField] = session.table(fully_qualified_table_name).schema.fields

-    schema_dict:
+    schema_dict: dict[str, types.DataType] = {}
     for field in struct_fields:
         schema_dict[field.name] = field.datatype
     return schema_dict
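Nearly every hunk in this release follows the pattern visible above: annotations move from the typing aliases (Dict, List, Tuple, Type) to the built-in generics standardized by PEP 585, which require Python 3.9+. A minimal before/after sketch of that migration; the functions below are illustrative and not taken from the package:

# Illustrative only. schema_from_rows_old uses the typing aliases found in 1.8.1;
# schema_from_rows_new uses the built-in generics (PEP 585) adopted in 1.8.3.
# The two are equivalent at runtime; only the annotation style differs.
from typing import Dict, List


def schema_from_rows_old(rows: List[tuple]) -> Dict[str, str]:
    return {name: dtype for name, dtype in rows}


def schema_from_rows_new(rows: list[tuple[str, str]]) -> dict[str, str]:
    return {name: dtype for name, dtype in rows}


assert schema_from_rows_old([("NAME", "VARCHAR")]) == schema_from_rows_new([("NAME", "VARCHAR")])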
snowflake/ml/data/_internal/arrow_ingestor.py
CHANGED
@@ -2,7 +2,7 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque,
+from typing import Any, Deque, Iterator, Optional, Sequence, Union

 import numpy as np
 import numpy.typing as npt
@@ -71,7 +71,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         return cls(session, sources)

     @property
-    def data_sources(self) ->
+    def data_sources(self) -> list[data_source.DataSource]:
         return self._data_sources

     def to_batches(
@@ -79,7 +79,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         batch_size: int,
         shuffle: bool = True,
         drop_last_batch: bool = True,
-    ) -> Iterator[
+    ) -> Iterator[dict[str, npt.NDArray[Any]]]:
         """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.

         As we are generating batches with the exactly same length, the last few rows in each file might get left as they
@@ -120,7 +120,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):

     def _get_dataset(self, shuffle: bool) -> pds.Dataset:
         format = self._format
-        sources:
+        sources: list[Any] = []
         source_format = None
         for source in self._data_sources:
             if isinstance(source, str):
@@ -155,7 +155,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
         return pa_dataset

-    def _get_batches_from_buffer(self, batch_size: int) ->
+    def _get_batches_from_buffer(self, batch_size: int) -> dict[str, npt.NDArray[Any]]:
         """Generate new batches from the existing record batch buffer."""
         cnt_rbs_num_rows = 0
         candidates = []
@@ -180,7 +180,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         return _record_batch_to_arrays(res)


-def _merge_record_batches(record_batches:
+def _merge_record_batches(record_batches: list[pa.RecordBatch]) -> pa.RecordBatch:
    """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
    if not record_batches:
        return _EMPTY_RECORD_BATCH
@@ -192,7 +192,7 @@ def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatc
     return batches[0]


-def _record_batch_to_arrays(rb: pa.RecordBatch) ->
+def _record_batch_to_arrays(rb: pa.RecordBatch) -> dict[str, npt.NDArray[Any]]:
    """Transform the record batch to a (string, numpy array) dict."""
    batch_dict = {}
    for column, column_schema in zip(rb, rb.schema):
snowflake/ml/data/data_connector.py
CHANGED
@@ -1,32 +1,18 @@
 import os
-from typing import
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Generator,
-    List,
-    Optional,
-    Sequence,
-    Type,
-    TypeVar,
-    cast,
-)
+from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence, TypeVar

 import numpy.typing as npt
 from typing_extensions import deprecated

 from snowflake import snowpark
-from snowflake.ml._internal import telemetry
+from snowflake.ml._internal import env, telemetry
 from snowflake.ml.data import data_ingestor, data_source
 from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
-from snowflake.ml.modeling._internal.constants import (
-    IN_ML_RUNTIME_ENV_VAR,
-    USE_OPTIMIZED_DATA_INGESTOR,
-)
 from snowflake.snowpark import context as sf_context

 if TYPE_CHECKING:
     import pandas as pd
+    import ray
     import tensorflow as tf
     from torch.utils import data as torch_data

@@ -42,7 +28,7 @@ DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
 class DataConnector:
     """Snowflake data reader which provides application integration connectors"""

-    DEFAULT_INGESTOR_CLASS:
+    DEFAULT_INGESTOR_CLASS: type[data_ingestor.DataIngestor] = ArrowIngestor

     def __init__(
         self,
@@ -53,27 +39,22 @@ class DataConnector:
         self._kwargs = kwargs

     @classmethod
-    @snowpark._internal.utils.private_preview(version="1.6.0")
     def from_dataframe(
-        cls:
+        cls: type[DataConnectorType],
         df: snowpark.DataFrame,
-        ingestor_class: Optional[
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
-        return
-            DataConnectorType,
-            cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs),
-        )
+        return cls.from_sql(df.queries["queries"][0], session=df._session, ingestor_class=ingestor_class, **kwargs)

     @classmethod
-    @snowpark._internal.utils.private_preview(version="1.7.3")
     def from_sql(
-        cls:
+        cls: type[DataConnectorType],
         query: str,
         session: Optional[snowpark.Session] = None,
-        ingestor_class: Optional[
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
         session = session or sf_context.get_active_session()
@@ -82,9 +63,9 @@ class DataConnector:

     @classmethod
     def from_dataset(
-        cls:
+        cls: type[DataConnectorType],
         ds: "dataset.Dataset",
-        ingestor_class: Optional[
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
         dsv = ds.selected_version
@@ -101,10 +82,10 @@ class DataConnector:
         func_params_to_log=["sources", "ingestor_class"],
     )
     def from_sources(
-        cls:
+        cls: type[DataConnectorType],
         session: snowpark.Session,
         sources: Sequence[data_source.DataSource],
-        ingestor_class: Optional[
+        ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None,
         **kwargs: Any,
     ) -> DataConnectorType:
         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
@@ -112,7 +93,7 @@ class DataConnector:
         return cls(ingestor, **kwargs)

     @property
-    def data_sources(self) ->
+    def data_sources(self) -> list[data_source.DataSource]:
         return self._ingestor.data_sources

     @telemetry.send_api_usage_telemetry(
@@ -138,7 +119,7 @@ class DataConnector:
         """
         import tensorflow as tf

-        def generator() -> Generator[
+        def generator() -> Generator[dict[str, npt.NDArray[Any]], None, None]:
             yield from self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)

         # Derive TensorFlow signature
@@ -241,14 +222,37 @@ class DataConnector:
         """
         return self._ingestor.to_pandas(limit)

+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["limit"],
+    )
+    def to_ray_dataset(self) -> "ray.data.Dataset":
+        """Retrieve the Snowflake data as a Ray Dataset.
+
+        Returns:
+            A Ray Dataset.
+
+        Raises:
+            ImportError: If Ray is not installed in the local environment.
+        """
+        if hasattr(self._ingestor, "to_ray_dataset"):
+            return self._ingestor.to_ray_dataset()
+
+        try:
+            import ray
+
+            return ray.data.from_pandas(self._ingestor.to_pandas())
+        except ImportError as e:
+            raise ImportError("Ray is not installed, please install ray in your local environment.") from e
+

 # Switch to use Runtime's Data Ingester if running in ML runtime
 # Fail silently if the data ingester is not found
-if
+if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
     try:
         from runtime_external_entities import get_ingester_class

         DataConnector.DEFAULT_INGESTOR_CLASS = get_ingester_class()
     except ImportError:
         """Runtime Default Ingester not found, ignore"""
-        pass
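The main functional addition in data_connector.py is the to_ray_dataset() method: it defers to the ingestor when the ingestor exposes its own Ray path and otherwise falls back to ray.data.from_pandas over the pandas conversion. A rough usage sketch, assuming an existing Snowpark session and a table named MY_TABLE (both placeholders, not from the package):

# Sketch only: connection parameters and the table name are placeholders.
from snowflake.ml.data.data_connector import DataConnector
from snowflake.snowpark import Session

connection_parameters: dict = {}  # account, user, password, role, warehouse, ...
session = Session.builder.configs(connection_parameters).create()

df = session.table("MY_TABLE")
connector = DataConnector.from_dataframe(df)

# Requires `ray` to be installed locally; otherwise this raises ImportError,
# as documented in the new method's docstring above.
ray_ds = connector.to_ray_dataset()
print(ray_ds.count())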
snowflake/ml/data/data_ingestor.py
CHANGED
@@ -1,15 +1,4 @@
-from typing import
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterator,
-    List,
-    Optional,
-    Protocol,
-    Sequence,
-    Type,
-    TypeVar,
-)
+from typing import TYPE_CHECKING, Any, Iterator, Optional, Protocol, Sequence, TypeVar

 from numpy import typing as npt

@@ -26,12 +15,12 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
 class DataIngestor(Protocol):
     @classmethod
     def from_sources(
-        cls:
+        cls: type[DataIngestorType], session: snowpark.Session, sources: Sequence[data_source.DataSource]
     ) -> DataIngestorType:
         raise NotImplementedError

     @property
-    def data_sources(self) ->
+    def data_sources(self) -> list[data_source.DataSource]:
         raise NotImplementedError

     def to_batches(
@@ -39,7 +28,7 @@ class DataIngestor(Protocol):
         batch_size: int,
         shuffle: bool = True,
         drop_last_batch: bool = True,
-    ) -> Iterator[
+    ) -> Iterator[dict[str, npt.NDArray[Any]]]:
         raise NotImplementedError

     def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
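Because DataIngestor is a typing.Protocol, the connector only checks structural compatibility: any object exposing from_sources, data_sources, to_batches, and to_pandas can be passed through the ingestor_class hooks shown earlier. A minimal in-memory sketch of such an implementation; the class and its behavior are illustrative and not part of the package:

# Minimal structural sketch of a DataIngestor-compatible class. The member set
# mirrors the Protocol above; everything else is illustrative.
from typing import Any, Iterator, Optional, Sequence

import numpy.typing as npt
import pandas as pd


class InMemoryIngestor:
    def __init__(self, frame: pd.DataFrame) -> None:
        self._frame = frame

    @classmethod
    def from_sources(cls, session: Any, sources: Sequence[Any]) -> "InMemoryIngestor":
        # Real ingestors resolve DataSource objects through the session; this sketch ignores them.
        return cls(pd.DataFrame())

    @property
    def data_sources(self) -> list[Any]:
        return []

    def to_batches(
        self, batch_size: int, shuffle: bool = True, drop_last_batch: bool = True
    ) -> Iterator[dict[str, npt.NDArray[Any]]]:
        frame = self._frame.sample(frac=1) if shuffle else self._frame
        for start in range(0, len(frame), batch_size):
            chunk = frame.iloc[start : start + batch_size]
            if drop_last_batch and len(chunk) < batch_size:
                break
            yield {col: chunk[col].to_numpy() for col in chunk.columns}

    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
        return self._frame.head(limit) if limit is not None else self._frame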
snowflake/ml/data/data_source.py
CHANGED
@@ -1,5 +1,5 @@
 import dataclasses
-from typing import
+from typing import Optional, Union


 @dataclasses.dataclass(frozen=True)
@@ -17,7 +17,7 @@ class DatasetInfo:
     fully_qualified_name: str
     version: str
     url: Optional[str] = None
-    exclude_cols: Optional[
+    exclude_cols: Optional[list[str]] = None


 DataSource = Union[DataFrameInfo, DatasetInfo, str]
snowflake/ml/data/ingestor_utils.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional

 import fsspec
 import pyarrow as pa
@@ -33,7 +33,7 @@ def _get_dataframe_cursor(session: snowpark.Session, df_info: data_source.DataFr

 def get_dataframe_result_batches(
     session: snowpark.Session, df_info: data_source.DataFrameInfo
-) ->
+) -> list[result_batch.ResultBatch]:
     """Retrieve the ResultBatches for a given query"""
     cursor = _get_dataframe_cursor(session, df_info)
     batches = cursor.get_result_batches()
@@ -63,7 +63,7 @@ def get_dataset_filesystem(

 def get_dataset_files(
     session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
-) ->
+) -> list[str]:
     """Get the list of files in a given Dataset"""
     if filesystem is None:
         filesystem = get_dataset_filesystem(session, ds_info)
snowflake/ml/data/torch_utils.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any,
+from typing import Any, Iterator, Optional, Union

 import numpy as np
 import numpy.typing as npt
@@ -7,7 +7,7 @@ import torch.utils.data
 from snowflake.ml.data import data_ingestor


-class TorchDatasetWrapper(torch.utils.data.IterableDataset[
+class TorchDatasetWrapper(torch.utils.data.IterableDataset[dict[str, Any]]):
    """Wrap a DataIngestor into a PyTorch IterableDataset"""

    def __init__(
@@ -32,7 +32,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         self._squeeze_outputs = squeeze
         self._expand_dims = expand_dims

-    def __iter__(self) -> Iterator[
+    def __iter__(self) -> Iterator[dict[str, Union[npt.NDArray[Any], list[Any]]]]:
         max_idx = 0
         filter_idx = 0
         worker_info = torch.utils.data.get_worker_info()
@@ -59,7 +59,7 @@ class TorchDatasetWrapper(torch.utils.data.IterableDataset[Dict[str, Any]]):
         counter = 0


-class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[
+class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[dict[str, Any]]):
    """Wrap a DataIngestor into a PyTorch IterDataPipe"""

    def __init__(
@@ -77,7 +77,7 @@ class TorchDataPipeWrapper(TorchDatasetWrapper, torch.utils.data.IterDataPipe[Di

 def _preprocess_array(
     arr: npt.NDArray[Any], squeeze: bool = False, expand_dims: bool = True
-) -> Union[npt.NDArray[Any],
+) -> Union[npt.NDArray[Any], list[np.object_]]:
    """Preprocesses batch column values."""
    single_dimensional = arr.ndim < 2 and not arr.dtype == np.object_

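TorchDatasetWrapper above streams dict-of-array batches out of a DataIngestor as a PyTorch IterableDataset; downstream, those already-formed batches are typically consumed through a DataLoader with automatic batching disabled. A generic sketch of that consumption pattern; the batches() generator stands in for a DataIngestor.to_batches call and is illustrative only:

# Sketch of consuming an IterableDataset of dict-of-array batches.
from typing import Any, Iterator

import numpy as np
import torch
import torch.utils.data


def batches() -> Iterator[dict[str, np.ndarray]]:
    # Stand-in for DataIngestor.to_batches(batch_size=4, ...).
    for _ in range(3):
        yield {"FEATURE": np.random.rand(4), "LABEL": np.random.randint(0, 2, size=4)}


class DictBatchDataset(torch.utils.data.IterableDataset):
    def __iter__(self) -> Iterator[dict[str, Any]]:
        yield from batches()


# batch_size=None because the upstream batches are already formed.
loader = torch.utils.data.DataLoader(DictBatchDataset(), batch_size=None)
for batch in loader:
    print(batch["FEATURE"].shape)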
snowflake/ml/dataset/dataset.py
CHANGED
@@ -1,7 +1,7 @@
 import json
 import warnings
 from datetime import datetime
-from typing import Any,
+from typing import Any, Optional, Union

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
@@ -46,8 +46,8 @@ class DatasetVersion:
         self._version = version
         self._session: snowpark.Session = self._parent._session

-        self._properties: Optional[
-        self._raw_metadata: Optional[
+        self._properties: Optional[dict[str, Any]] = None
+        self._raw_metadata: Optional[dict[str, Any]] = None
         self._metadata: Optional[dataset_metadata.DatasetMetadata] = None

     @property
@@ -66,14 +66,14 @@ class DatasetVersion:
         return comment

     @property
-    def label_cols(self) ->
+    def label_cols(self) -> list[str]:
         metadata = self._get_metadata()
         if metadata is None or metadata.label_cols is None:
             return []
         return metadata.label_cols

     @property
-    def exclude_cols(self) ->
+    def exclude_cols(self) -> list[str]:
         metadata = self._get_metadata()
         if metadata is None or metadata.exclude_cols is None:
             return []
@@ -115,7 +115,7 @@ class DatasetVersion:
         return path

     @telemetry.send_api_usage_telemetry(project=_PROJECT)
-    def list_files(self, subdir: Optional[str] = None) ->
+    def list_files(self, subdir: Optional[str] = None) -> list[snowpark.Row]:
         """Get the list of remote file paths for the current DatasetVersion."""
         return self._session.sql(f"LIST {self.url()}{subdir or ''}").collect(
             statement_params=_TELEMETRY_STATEMENT_PARAMS
@@ -244,7 +244,7 @@ class Dataset(lineage_node.LineageNode):
             raise

     @telemetry.send_api_usage_telemetry(project=_PROJECT)
-    def list_versions(self, detailed: bool = False) -> Union[
+    def list_versions(self, detailed: bool = False) -> Union[list[str], list[snowpark.Row]]:
         """Return list of versions"""
         versions = self._list_versions()
         versions.sort(key=lambda r: r[_DATASET_VERSION_NAME_COL])
@@ -271,8 +271,8 @@ class Dataset(lineage_node.LineageNode):
         version: str,
         input_dataframe: snowpark.DataFrame,
         shuffle: bool = False,
-        exclude_cols: Optional[
-        label_cols: Optional[
+        exclude_cols: Optional[list[str]] = None,
+        label_cols: Optional[list[str]] = None,
         properties: Optional[dataset_metadata.DatasetPropertiesType] = None,
         partition_by: Optional[str] = None,
         comment: Optional[str] = None,
@@ -423,7 +423,7 @@ class Dataset(lineage_node.LineageNode):
             statement_params=_TELEMETRY_STATEMENT_PARAMS
         )

-    def _list_versions(self, pattern: Optional[str] = None) ->
+    def _list_versions(self, pattern: Optional[str] = None) -> list[snowpark.Row]:
         """Return list of versions"""
         try:
             pattern_clause = f" LIKE '{pattern}'" if pattern else ""
@@ -469,7 +469,7 @@ lineage_node.DOMAIN_LINEAGE_REGISTRY["dataset"] = Dataset
 # Utility methods


-def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) ->
+def _get_schema_level_identifier(session: snowpark.Session, dataset_name: str) -> tuple[str, str, str]:
    """Resolve a dataset name into a validated schema-level location identifier"""
    db, schema, object_name = identifier.parse_schema_level_object_identifier(dataset_name)
    db = db or session.get_current_database()
snowflake/ml/dataset/dataset_metadata.py
CHANGED
@@ -1,7 +1,7 @@
 import dataclasses
 import json
 import typing
-from typing import Any,
+from typing import Any, Optional, Union

 _PROPERTY_TYPE_KEY = "$proptype$"
 DATASET_SCHEMA_VERSION = "1"
@@ -20,15 +20,15 @@ class FeatureStoreMetadata:
     """

     spine_query: str
-    serialized_feature_views: Optional[
-    compact_feature_views: Optional[
+    serialized_feature_views: Optional[list[str]] = None
+    compact_feature_views: Optional[list[str]] = None
     spine_timestamp_col: Optional[str] = None

     def to_json(self) -> str:
         return json.dumps(dataclasses.asdict(self))

     @classmethod
-    def from_json(cls, input_json: Union[
+    def from_json(cls, input_json: Union[dict[str, Any], str, bytes]) -> "FeatureStoreMetadata":
         if isinstance(input_json, dict):
             return cls(**input_json)
         return cls(**json.loads(input_json))
@@ -61,8 +61,8 @@ class DatasetMetadata:

     source_query: str
     owner: str
-    exclude_cols: Optional[
-    label_cols: Optional[
+    exclude_cols: Optional[list[str]] = None
+    label_cols: Optional[list[str]] = None
     properties: Optional[DatasetPropertiesType] = None
     schema_version: str = dataclasses.field(default=DATASET_SCHEMA_VERSION, init=False)

@@ -78,11 +78,11 @@ class DatasetMetadata:
         return json.dumps(state_dict)

     @classmethod
-    def from_json(cls, input_json: Union[
+    def from_json(cls, input_json: Union[dict[str, Any], str, bytes]) -> "DatasetMetadata":
         if not input_json:
             raise ValueError("json_str was empty or None")
         try:
-            state_dict:
+            state_dict: dict[str, Any] = (
                 input_json if isinstance(input_json, dict) else json.loads(input_json, strict=False)
             )

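FeatureStoreMetadata and DatasetMetadata keep the same JSON round-trip shape after the annotation change: to_json serializes dataclasses.asdict(self), and from_json accepts a dict, str, or bytes. A stand-in dataclass demonstrating the same pattern (illustrative, not the package's class):

# Illustrative dataclass mirroring the to_json/from_json round-trip shown above.
import dataclasses
import json
from typing import Any, Optional, Union


@dataclasses.dataclass
class Metadata:
    source_query: str
    owner: str
    exclude_cols: Optional[list[str]] = None

    def to_json(self) -> str:
        return json.dumps(dataclasses.asdict(self))

    @classmethod
    def from_json(cls, input_json: Union[dict[str, Any], str, bytes]) -> "Metadata":
        if isinstance(input_json, dict):
            return cls(**input_json)
        return cls(**json.loads(input_json))


meta = Metadata(source_query="SELECT * FROM MY_TABLE", owner="ML_ADMIN", exclude_cols=["ID"])
assert Metadata.from_json(meta.to_json()) == meta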
snowflake/ml/dataset/dataset_reader.py
CHANGED
@@ -1,10 +1,11 @@
-from typing import Any,
+from typing import Any, Optional

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs
+from snowflake.snowpark._internal import utils as snowpark_utils

 _PROJECT = "Dataset"
 _SUBPROJECT = "DatasetReader"
@@ -24,21 +25,21 @@ class DatasetReader(data_connector.DataConnector):

         self._session: snowpark.Session = snowpark_session
         self._fs: snowfs.SnowFileSystem = ingestor_utils.get_dataset_filesystem(self._session)
-        self._files: Optional[
+        self._files: Optional[list[str]] = None

     @classmethod
     def from_dataframe(
-        cls, df: snowpark.DataFrame, ingestor_class: Optional[
+        cls, df: snowpark.DataFrame, ingestor_class: Optional[type[data_ingestor.DataIngestor]] = None, **kwargs: Any
     ) -> "DatasetReader":
         # Block superclass constructor from Snowpark DataFrames
         raise RuntimeError("Creating DatasetReader from DataFrames not supported")

-    def _list_files(self) ->
+    def _list_files(self) -> list[str]:
         """Private helper function that lists all files in this DatasetVersion and caches the results."""
         if self._files:
             return self._files

-        files:
+        files: list[str] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo)
             files.extend(ingestor_utils.get_dataset_files(self._session, source, filesystem=self._fs))
@@ -48,7 +49,7 @@ class DatasetReader(data_connector.DataConnector):
         return self._files

     @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
-    def files(self) ->
+    def files(self) -> list[str]:
         """Get the list of remote file paths for the current DatasetVersion.

         The file paths follows the snow protocol.
@@ -91,10 +92,13 @@ class DatasetReader(data_connector.DataConnector):
         For example, an OBJECT column may be scanned back as a STRING column.
         """
         file_path_pattern = ".*data_.*[.]parquet"
-        dfs:
+        dfs: list[snowpark.DataFrame] = []
         for source in self.data_sources:
             assert isinstance(source, data_source.DatasetInfo) and source.url is not None
-
+            stage_reader = self._session.read.option("pattern", file_path_pattern)
+            if "INFER_SCHEMA_OPTIONS" in snowpark_utils.NON_FORMAT_TYPE_OPTIONS:
+                stage_reader = stage_reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
+            df = stage_reader.parquet(source.url)
             if only_feature_cols and source.exclude_cols:
                 df = df.drop(source.exclude_cols)
             dfs.append(df)
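The reader change above also illustrates a small compatibility guard: the INFER_SCHEMA_OPTIONS reader option is applied only when the installed snowpark-python advertises it in NON_FORMAT_TYPE_OPTIONS, so older clients keep working unchanged. The same probe-then-configure pattern, sketched generically (the helper is illustrative, not part of the package):

# Illustrative helper: apply an optional reader setting only when the installed
# dependency reports support for it, mirroring the guard in the hunk above.
from typing import Any


def configure_parquet_reader(reader: Any, supported_options: set[str]) -> Any:
    if "INFER_SCHEMA_OPTIONS" in supported_options:
        # Limit schema inference to one file so the metadata scan stays cheap.
        reader = reader.option("INFER_SCHEMA_OPTIONS", {"MAX_FILE_COUNT": 1})
    return reader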
snowflake/ml/feature_store/__init__.py
CHANGED
@@ -4,7 +4,7 @@ from snowflake.ml._internal import init_utils

 from .access_manager import setup_feature_store

-pkg_dir = os.path.dirname(
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_classes.items():
snowflake/ml/feature_store/access_manager.py
CHANGED
@@ -1,6 +1,6 @@
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import
+from typing import Optional
 from warnings import warn

 from snowflake.ml._internal import telemetry
@@ -28,7 +28,7 @@ class _FeatureStoreRole(Enum):
 class _Privilege:
     object_type: str
     object_name: str
-    privileges:
+    privileges: list[str]
     scope: Optional[str] = None
     optional: bool = False

@@ -41,7 +41,7 @@ class _SessionInfo:


 # Lists of permissions as tuples of (OBJECT_TYPE, [PRIVILEGES, ...])
-_PRE_INIT_PRIVILEGES:
+_PRE_INIT_PRIVILEGES: dict[_FeatureStoreRole, list[_Privilege]] = {
     _FeatureStoreRole.PRODUCER: [
         _Privilege("DATABASE", "{database}", ["USAGE"]),
         _Privilege("SCHEMA", "{database}.{schema}", ["USAGE"]),
@@ -78,7 +78,7 @@ _PRE_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {
     _FeatureStoreRole.NONE: [],
 }

-_POST_INIT_PRIVILEGES:
+_POST_INIT_PRIVILEGES: dict[_FeatureStoreRole, list[_Privilege]] = {
     _FeatureStoreRole.PRODUCER: [
         _Privilege("TAG", f"{{database}}.{{schema}}.{_FEATURE_VIEW_METADATA_TAG}", ["APPLY"]),
         _Privilege("TAG", f"{{database}}.{{schema}}.{_FEATURE_STORE_OBJECT_TAG}", ["APPLY"]),
@@ -89,7 +89,7 @@ _POST_INIT_PRIVILEGES: Dict[_FeatureStoreRole, List[_Privilege]] = {


 def _grant_privileges(
-    session: Session, role_name: str, privileges:
+    session: Session, role_name: str, privileges: list[_Privilege], session_info: _SessionInfo
 ) -> None:
     session_info_dict = asdict(session_info)
     for p in privileges:
@@ -129,7 +129,7 @@ def _grant_privileges(
 def _configure_pre_init_privileges(
     session: Session,
     session_info: _SessionInfo,
-    roles_to_create:
+    roles_to_create: dict[_FeatureStoreRole, str],
 ) -> None:
     """
     Configure Feature Store role privileges. Must be run with ACCOUNTADMIN
@@ -172,7 +172,7 @@ def _configure_pre_init_privileges(
 def _configure_post_init_privileges(
     session: Session,
     session_info: _SessionInfo,
-    roles_to_create:
+    roles_to_create: dict[_FeatureStoreRole, str],
 ) -> None:
     for role_type, role in roles_to_create.items():
         _grant_privileges(session, role, _POST_INIT_PRIVILEGES[role_type], session_info)