snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +66 -35
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/env_utils.py +11 -5
  6. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  7. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  8. snowflake/ml/_internal/telemetry.py +26 -2
  9. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  10. snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
  11. snowflake/ml/data/data_connector.py +186 -0
  12. snowflake/ml/data/data_ingestor.py +45 -0
  13. snowflake/ml/data/data_source.py +23 -0
  14. snowflake/ml/data/ingestor_utils.py +62 -0
  15. snowflake/ml/data/torch_dataset.py +33 -0
  16. snowflake/ml/dataset/dataset.py +1 -13
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +23 -117
  19. snowflake/ml/feature_store/access_manager.py +7 -1
  20. snowflake/ml/feature_store/entity.py +19 -2
  21. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  22. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  23. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  24. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  26. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
  27. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
  28. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
  29. snowflake/ml/feature_store/examples/example_helper.py +278 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  31. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
  32. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
  34. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  35. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  36. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  37. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  38. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  39. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  40. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
  43. snowflake/ml/feature_store/feature_store.py +637 -76
  44. snowflake/ml/feature_store/feature_view.py +316 -9
  45. snowflake/ml/fileset/stage_fs.py +18 -10
  46. snowflake/ml/lineage/lineage_node.py +1 -1
  47. snowflake/ml/model/_client/model/model_impl.py +11 -2
  48. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  49. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  50. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  51. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  52. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  53. snowflake/ml/model/_client/sql/model_version.py +13 -4
  54. snowflake/ml/model/_client/sql/service.py +129 -0
  55. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  56. snowflake/ml/model/_model_composer/model_composer.py +14 -14
  57. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
  58. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
  59. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  60. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  61. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  62. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  63. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  64. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  65. snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
  66. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  67. snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
  68. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  69. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  70. snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
  71. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  72. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  73. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  74. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  75. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  76. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  77. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  78. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  79. snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
  80. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  81. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  82. snowflake/ml/model/_packager/model_packager.py +2 -1
  83. snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
  84. snowflake/ml/model/model_signature.py +4 -4
  85. snowflake/ml/model/type_hints.py +2 -0
  86. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  87. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  88. snowflake/ml/modeling/framework/base.py +28 -19
  89. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  90. snowflake/ml/modeling/pipeline/pipeline.py +7 -4
  91. snowflake/ml/registry/_manager/model_manager.py +16 -2
  92. snowflake/ml/registry/registry.py +100 -13
  93. snowflake/ml/utils/sql_client.py +22 -0
  94. snowflake/ml/version.py +1 -1
  95. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
  96. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
  97. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
  98. snowflake/ml/_internal/lineage/data_source.py +0 -10
  99. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  100. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/ml/data/_internal/arrow_ingestor.py
@@ -0,0 +1,284 @@
+ import collections
+ import logging
+ import os
+ import time
+ from typing import Any, Deque, Dict, Iterator, List, Optional, Union
+
+ import numpy as np
+ import numpy.typing as npt
+ import pandas as pd
+ import pyarrow as pa
+ import pyarrow.dataset as pds
+
+ from snowflake import snowpark
+ from snowflake.connector import result_batch
+ from snowflake.ml.data import data_ingestor, data_source, ingestor_utils
+
+ _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
+
+ # The row count for batches read from PyArrow Dataset. This number should be large enough so that
+ # dataset.to_batches() would read in a very large portion of, if not entirely, a parquet file.
+ _DEFAULT_DATASET_BATCH_SIZE = 1000000
+
+
+ class _RecordBatchesBuffer:
+     """A queue that stores record batches and tracks the total number of rows in it."""
+
+     def __init__(self) -> None:
+         self.buffer: Deque[pa.RecordBatch] = collections.deque()
+         self.num_rows = 0
+
+     def append(self, rb: pa.RecordBatch) -> None:
+         self.buffer.append(rb)
+         self.num_rows += rb.num_rows
+
+     def appendleft(self, rb: pa.RecordBatch) -> None:
+         self.buffer.appendleft(rb)
+         self.num_rows += rb.num_rows
+
+     def popleft(self) -> pa.RecordBatch:
+         popped = self.buffer.popleft()
+         self.num_rows -= popped.num_rows
+         return popped
+
+
+ class ArrowIngestor(data_ingestor.DataIngestor):
+     """Read and parse the data sources into an Arrow Dataset and yield batched numpy arrays in a dict."""
+
+     def __init__(
+         self,
+         session: snowpark.Session,
+         data_sources: List[data_source.DataSource],
+         format: Optional[str] = None,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Args:
+             session: The Snowpark Session to use.
+             data_sources: List of data sources to ingest.
+             format: Currently "parquet", "ipc"/"arrow"/"feather", "csv", "json", and "orc" are supported.
+                 Will be inferred if not specified.
+             kwargs: Miscellaneous arguments passed to the underlying PyArrow Dataset initializer.
+         """
+         self._session = session
+         self._data_sources = data_sources
+         self._format = format
+         self._kwargs = kwargs
+
+         self._schema: Optional[pa.Schema] = None
+
+     @classmethod
+     def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "ArrowIngestor":
+         return cls(session, sources)
+
+     @property
+     def data_sources(self) -> List[data_source.DataSource]:
+         return self._data_sources
+
+     def to_batches(
+         self,
+         batch_size: int,
+         shuffle: bool = True,
+         drop_last_batch: bool = True,
+     ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+         """Iterate through the PyArrow Dataset to generate batches whose length equals the expected batch size.
+
+         Because batches are generated with exactly the same length, the last few rows in each file may be left over
+         if they are not enough to form a batch. These rows are put into a temporary buffer and combined with the
+         first few rows of the next file to generate a new batch.
+
+         Args:
+             batch_size: Specifies the size of each batch that will be yielded.
+             shuffle: Whether the data in the file will be shuffled. If set to true, it will first randomly shuffle
+                 the order of files, and then shuffle the order of rows in each file.
+             drop_last_batch: Whether the last batch of data should be dropped. If set to true, the last
+                 batch will get dropped if its size is smaller than the given batch_size.
+
+         Yields:
+             A dict mapping column names to the corresponding data fetched from that column.
+         """
+         self._rb_buffer = _RecordBatchesBuffer()
+
+         # Extract schema if not already known
+         dataset = self._get_dataset(shuffle)
+         if self._schema is None:
+             self._schema = dataset.schema
+
+         for rb in _retryable_batches(dataset, batch_size=max(_DEFAULT_DATASET_BATCH_SIZE, batch_size)):
+             if shuffle:
+                 rb = rb.take(np.random.permutation(rb.num_rows))
+             self._rb_buffer.append(rb)
+             while self._rb_buffer.num_rows >= batch_size:
+                 yield self._get_batches_from_buffer(batch_size)
+
+         if self._rb_buffer.num_rows and not drop_last_batch:
+             yield self._get_batches_from_buffer(batch_size)
+
+     def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
+         ds = self._get_dataset(shuffle=False)
+         table = ds.to_table() if limit is None else ds.head(num_rows=limit)
+         return table.to_pandas()
+
+     def _get_dataset(self, shuffle: bool) -> pds.Dataset:
+         format = self._format
+         sources: List[Any] = []
+         source_format = None
+         for source in self._data_sources:
+             if isinstance(source, str):
+                 sources.append(source)
+                 source_format = format or os.path.splitext(source)[-1]
+             elif isinstance(source, data_source.DatasetInfo):
+                 if not self._kwargs.get("filesystem"):
+                     self._kwargs["filesystem"] = ingestor_utils.get_dataset_filesystem(self._session, source)
+                 sources.extend(
+                     ingestor_utils.get_dataset_files(self._session, source, filesystem=self._kwargs["filesystem"])
+                 )
+                 source_format = "parquet"
+             elif isinstance(source, data_source.DataFrameInfo):
+                 # FIXME: This currently loads all result batches into memory so that they
+                 # can be passed into pyarrow.dataset as a list/tuple of pa.RecordBatches.
+                 # We may be able to optimize this by splitting the result batches into
+                 # in-memory (first batch) and file URLs (subsequent batches) and creating a
+                 # union dataset.
+                 result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source)
+                 sources.extend(
+                     b.to_arrow(self._session.connection)
+                     if isinstance(b, result_batch.ArrowResultBatch)
+                     else b.to_arrow()
+                     for b in result_batches
+                 )
+                 # HACK: Mitigate typing inconsistencies in Snowpark results
+                 if len(sources) > 0:
+                     sources = [_cast_if_needed(s, sources[-1].schema) for s in sources]
+                 source_format = None  # Arrow Dataset expects None for in-memory datasets
+             else:
+                 raise RuntimeError(f"Unsupported data source type: {type(source)}")
+
+             # Make sure source types are not mixed
+             if format and format != source_format:
+                 raise RuntimeError(f"Unexpected data source format (expected {format}, found {source_format})")
+             format = source_format
+
+         # Re-shuffle input files on each iteration start
+         if shuffle:
+             np.random.shuffle(sources)
+         pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
+         return pa_dataset
+
+     def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
+         """Generate new batches from the existing record batch buffer."""
+         cnt_rbs_num_rows = 0
+         candidates = []
+
+         # Keep popping record batches in buffer until there are enough rows for a batch.
+         while self._rb_buffer.num_rows and cnt_rbs_num_rows < batch_size:
+             candidate = self._rb_buffer.popleft()
+             cnt_rbs_num_rows += candidate.num_rows
+             candidates.append(candidate)
+
+         # When there are more rows than needed, slice the last popped batch to fit batch_size.
+         if cnt_rbs_num_rows > batch_size:
+             row_diff = cnt_rbs_num_rows - batch_size
+             slice_target = candidates[-1]
+             cut_off = slice_target.num_rows - row_diff
+             to_merge = slice_target.slice(length=cut_off)
+             left_over = slice_target.slice(offset=cut_off)
+             candidates[-1] = to_merge
+             self._rb_buffer.appendleft(left_over)
+
+         res = _merge_record_batches(candidates)
+         return _record_batch_to_arrays(res)
+
+
+ def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
+     """Merge a list of Arrow RecordBatches into one. Similar to MergeTables."""
+     if not record_batches:
+         return _EMPTY_RECORD_BATCH
+     if len(record_batches) == 1:
+         return record_batches[0]
+     record_batches = list(filter(lambda rb: rb.num_rows > 0, record_batches))
+     one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks()
+     batches = one_chunk_table.to_batches(max_chunksize=None)
+     return batches[0]
+
+
+ def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
+     """Transform the record batch into a (string, numpy array) dict."""
+     batch_dict = {}
+     for column, column_schema in zip(rb, rb.schema):
+         # zero_copy_only=False because of NaNs. Ideally NaNs should have been imputed in feature engineering.
+         array = column.to_numpy(zero_copy_only=False)
+         batch_dict[column_schema.name] = array
+     return batch_dict
+
+
+ def _retryable_batches(
+     dataset: pds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
+ ) -> Iterator[pa.RecordBatch]:
+     """Make the Dataset to_batches retryable."""
+     retries = 0
+     current_batch_index = 0
+
+     while True:
+         try:
+             for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
+                 if batch_index < current_batch_index:
+                     # Skip batches that have already been processed
+                     continue
+
+                 yield batch
+                 current_batch_index = batch_index + 1
+             # Exit the loop once all batches are processed
+             break
+
+         except Exception as e:
+             if retries < max_retries:
+                 retries += 1
+                 logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
+                 time.sleep(delay)
+             else:
+                 raise e
+
+
+ def _cast_if_needed(
+     batch: Union[pa.Table, pa.RecordBatch], schema: Optional[pa.Schema] = None
+ ) -> Union[pa.Table, pa.RecordBatch]:
+     """
+     Cast the batch to be compatible with downstream frameworks. Returns the original batch if a cast is not necessary.
+     Besides casting types to match `schema` (if provided), this function also applies the following casting:
+         - Decimal (fixed-point) types: Convert to float or integer types based on scale and byte length
+
+     Args:
+         batch: The PyArrow batch to cast if needed.
+         schema: Optional schema the batch should be cast to match. Note that compatibility type casting takes
+             precedence over the provided schema, e.g. if the schema has decimal types the result will be further
+             cast into integer/float types.
+
+     Returns:
+         The type-casted PyArrow batch, or the original batch if casting was not necessary.
+     """
+     schema = schema or batch.schema
+     assert len(batch.schema) == len(schema)
+     fields = []
+     cast_needed = False
+     for field, target in zip(batch.schema, schema):
+         # Need to convert decimal types to supported types. This behavior supersedes the target schema data types.
+         if pa.types.is_decimal(target.type):
+             byte_length = int(target.metadata.get(b"byteLength", 8))
+             if int(target.metadata.get(b"scale", 0)) > 0:
+                 target = target.with_type(pa.float32() if byte_length == 4 else pa.float64())
+             else:
+                 if byte_length == 2:
+                     target = target.with_type(pa.int16())
+                 elif byte_length == 4:
+                     target = target.with_type(pa.int32())
+                 else:  # Cap out at 64-bit
+                     target = target.with_type(pa.int64())
+         if not field.equals(target):
+             cast_needed = True
+             field = target
+         fields.append(field)
+
+     if cast_needed:
+         return batch.cast(pa.schema(fields))
+     return batch
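
For orientation, here is a minimal usage sketch of the ingestor added above (not part of the diff). It relies only on the constructor and methods shown in this hunk; the connection parameters, query text, and batch size are illustrative assumptions, and in normal use the class is reached through DataConnector rather than imported from the _internal package directly.

# Hypothetical usage sketch for ArrowIngestor (illustrative only).
from snowflake import snowpark
from snowflake.ml.data import data_source
from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor

session = snowpark.Session.builder.configs(connection_parameters).create()  # connection_parameters: assumed dict
sources = [data_source.DataFrameInfo(sql="SELECT FEATURE_1, FEATURE_2, LABEL FROM MY_TRAINING_TABLE")]
ingestor = ArrowIngestor(session, sources)

for batch in ingestor.to_batches(batch_size=1024, shuffle=True, drop_last_batch=True):
    # Each batch is a Dict[str, np.ndarray] with exactly 1024 rows per column, because leftover
    # rows are held in _RecordBatchesBuffer and merged with rows from the next result batch.
    print({name: arr.shape for name, arr in batch.items()})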
snowflake/ml/data/data_connector.py
@@ -0,0 +1,186 @@
+ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar
+
+ import numpy.typing as npt
+ from typing_extensions import deprecated
+
+ from snowflake import snowpark
+ from snowflake.ml._internal import telemetry
+ from snowflake.ml.data import data_ingestor, data_source
+ from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor
+
+ if TYPE_CHECKING:
+     import pandas as pd
+     import tensorflow as tf
+     from torch.utils import data as torch_data
+
+     # This module can't actually depend on dataset to avoid a circular dependency
+     # Dataset -> DatasetReader -> DataConnector -!-> Dataset
+     from snowflake.ml import dataset
+
+ _PROJECT = "DataConnector"
+
+ DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
+
+
+ class DataConnector:
+     """Snowflake data reader which provides application integration connectors"""
+
+     DEFAULT_INGESTOR_CLASS: Type[data_ingestor.DataIngestor] = ArrowIngestor
+
+     def __init__(
+         self,
+         ingestor: data_ingestor.DataIngestor,
+     ) -> None:
+         self._ingestor = ingestor
+
+     @classmethod
+     @snowpark._internal.utils.private_preview(version="1.6.0")
+     def from_dataframe(
+         cls: Type[DataConnectorType],
+         df: snowpark.DataFrame,
+         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+         **kwargs: Any
+     ) -> DataConnectorType:
+         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
+             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
+         source = data_source.DataFrameInfo(df.queries["queries"][0])
+         assert df._session is not None
+         return cls.from_sources(df._session, [source], ingestor_class=ingestor_class, **kwargs)
+
+     @classmethod
+     def from_dataset(
+         cls: Type[DataConnectorType],
+         ds: "dataset.Dataset",
+         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+         **kwargs: Any
+     ) -> DataConnectorType:
+         dsv = ds.selected_version
+         assert dsv is not None
+         source = data_source.DatasetInfo(
+             ds.fully_qualified_name, dsv.name, dsv.url(), exclude_cols=(dsv.label_cols + dsv.exclude_cols)
+         )
+         return cls.from_sources(ds._session, [source], ingestor_class=ingestor_class, **kwargs)
+
+     @classmethod
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda cls: cls.__name__,
+         func_params_to_log=["sources", "ingestor_class"],
+     )
+     def from_sources(
+         cls: Type[DataConnectorType],
+         session: snowpark.Session,
+         sources: List[data_source.DataSource],
+         ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+         **kwargs: Any
+     ) -> DataConnectorType:
+         ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
+         ingestor = ingestor_class.from_sources(session, sources)
+         return cls(ingestor, **kwargs)
+
+     @property
+     def data_sources(self) -> List[data_source.DataSource]:
+         return self._ingestor.data_sources
+
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda self: type(self).__name__,
+         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+     )
+     def to_tf_dataset(
+         self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+     ) -> "tf.data.Dataset":
+         """Transform the Snowflake data into a ready-to-use TensorFlow tf.data.Dataset.
+
+         Args:
+             batch_size: Specifies the size of each data batch which will be
+                 yielded in the result dataset.
+             shuffle: Specifies whether the data will be shuffled. If True, files will be shuffled, and
+                 rows in each file will also be shuffled.
+             drop_last_batch: Whether the last batch of data should be dropped. If set to true,
+                 then the last batch will get dropped if its size is smaller than the given batch_size.
+
+         Returns:
+             A tf.data.Dataset that yields batched tf.Tensors.
+         """
+         import tensorflow as tf
+
+         def generator() -> Generator[Dict[str, npt.NDArray[Any]], None, None]:
+             yield from self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
+
+         # Derive TensorFlow signature
+         first_batch = next(self._ingestor.to_batches(1, shuffle=False, drop_last_batch=False))
+         tf_signature = {
+             k: tf.TensorSpec(shape=(None,), dtype=tf.dtypes.as_dtype(v.dtype), name=k) for k, v in first_batch.items()
+         }
+
+         return tf.data.Dataset.from_generator(generator, output_signature=tf_signature)
+
+     @deprecated(
+         "to_torch_datapipe() is deprecated and will be removed in a future release. Use to_torch_dataset() instead"
+     )
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda self: type(self).__name__,
+         func_params_to_log=["batch_size", "shuffle", "drop_last_batch"],
+     )
+     def to_torch_datapipe(
+         self, *, batch_size: int, shuffle: bool = False, drop_last_batch: bool = True
+     ) -> "torch_data.IterDataPipe":  # type: ignore[type-arg]
+         """Transform the Snowflake data into a ready-to-use PyTorch datapipe.
+
+         Return a PyTorch datapipe which iterates on rows of data.
+
+         Args:
+             batch_size: Specifies the size of each data batch which will be
+                 yielded in the result datapipe.
+             shuffle: Specifies whether the data will be shuffled. If True, files will be shuffled, and
+                 rows in each file will also be shuffled.
+             drop_last_batch: Whether the last batch of data should be dropped. If set to true,
+                 then the last batch will get dropped if its size is smaller than the given batch_size.
+
+         Returns:
+             A PyTorch iterable datapipe that yields data.
+         """
+         from torch.utils.data.datapipes import iter as torch_iter
+
+         return torch_iter.IterableWrapper(  # type: ignore[no-untyped-call]
+             self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
+         )
+
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda self: type(self).__name__,
+         func_params_to_log=["shuffle"],
+     )
+     def to_torch_dataset(self, *, shuffle: bool = False) -> "torch_data.IterableDataset":  # type: ignore[type-arg]
+         """Transform the Snowflake data into a PyTorch Iterable Dataset to be used with a DataLoader.
+
+         Return a PyTorch Dataset which iterates on rows of data.
+
+         Args:
+             shuffle: Specifies whether the data will be shuffled. If True, files will be shuffled, and
+                 rows in each file will also be shuffled.
+
+         Returns:
+             A PyTorch Iterable Dataset that yields data.
+         """
+         from snowflake.ml.data import torch_dataset
+
+         return torch_dataset.TorchDataset(self._ingestor, shuffle)
+
+     @telemetry.send_api_usage_telemetry(
+         project=_PROJECT,
+         subproject_extractor=lambda self: type(self).__name__,
+         func_params_to_log=["limit"],
+     )
+     def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
+         """Retrieve the Snowflake data as a Pandas DataFrame.
+
+         Args:
+             limit: If specified, the maximum number of rows to load into the DataFrame.
+
+         Returns:
+             A Pandas DataFrame.
+         """
+         return self._ingestor.to_pandas(limit)
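
A short sketch of how the pieces above fit together (not part of the diff). DataConnector is the public entry point that wraps the ingestor; the Snowpark session and table name below are assumptions.

# Hypothetical usage sketch for DataConnector (illustrative only); assumes an existing Snowpark `session`.
from snowflake.ml.data.data_connector import DataConnector

df = session.table("MY_TRAINING_TABLE")       # must resolve to a single query with no post-actions
dc = DataConnector.from_dataframe(df)         # private preview as of 1.6.0, per the decorator above

pandas_df = dc.to_pandas(limit=1000)          # eager load, capped at 1,000 rows
torch_ds = dc.to_torch_dataset(shuffle=True)  # lazy, row-wise PyTorch IterableDataset
tf_ds = dc.to_tf_dataset(batch_size=256)      # lazy, batched tf.data.Dataset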
snowflake/ml/data/data_ingestor.py
@@ -0,0 +1,45 @@
+ from typing import (
+     TYPE_CHECKING,
+     Any,
+     Dict,
+     Iterator,
+     List,
+     Optional,
+     Protocol,
+     Type,
+     TypeVar,
+ )
+
+ from numpy import typing as npt
+
+ from snowflake import snowpark
+ from snowflake.ml.data import data_source
+
+ if TYPE_CHECKING:
+     import pandas as pd
+
+
+ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")
+
+
+ class DataIngestor(Protocol):
+     @classmethod
+     def from_sources(
+         cls: Type[DataIngestorType], session: snowpark.Session, sources: List[data_source.DataSource]
+     ) -> DataIngestorType:
+         raise NotImplementedError
+
+     @property
+     def data_sources(self) -> List[data_source.DataSource]:
+         raise NotImplementedError
+
+     def to_batches(
+         self,
+         batch_size: int,
+         shuffle: bool = True,
+         drop_last_batch: bool = True,
+     ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+         raise NotImplementedError
+
+     def to_pandas(self, limit: Optional[int] = None) -> "pd.DataFrame":
+         raise NotImplementedError
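
Because DataIngestor is a Protocol, a custom implementation can be supplied to DataConnector.from_sources(..., ingestor_class=...). The toy in-memory ingestor below is a sketch of the required surface only; its data and class name are assumptions, not part of the package.

# Hypothetical custom DataIngestor (illustrative only).
from typing import Any, Dict, Iterator, List, Optional

import numpy as np
import pandas as pd

from snowflake import snowpark
from snowflake.ml.data import data_source


class InMemoryIngestor:
    """Toy ingestor backed by an in-memory pandas frame; satisfies the DataIngestor protocol."""

    def __init__(self, session: snowpark.Session, sources: List[data_source.DataSource]) -> None:
        self._sources = sources
        self._frame = pd.DataFrame({"X": np.arange(10), "Y": np.arange(10) % 2})  # stand-in data

    @classmethod
    def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "InMemoryIngestor":
        return cls(session, sources)

    @property
    def data_sources(self) -> List[data_source.DataSource]:
        return self._sources

    def to_batches(
        self, batch_size: int, shuffle: bool = True, drop_last_batch: bool = True
    ) -> Iterator[Dict[str, np.ndarray]]:
        # Walk the frame in fixed-size slices; optionally drop the short final slice.
        for start in range(0, len(self._frame), batch_size):
            chunk = self._frame.iloc[start : start + batch_size]
            if drop_last_batch and len(chunk) < batch_size:
                return
            yield {col: chunk[col].to_numpy() for col in chunk.columns}

    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
        return self._frame if limit is None else self._frame.head(limit)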
snowflake/ml/data/data_source.py
@@ -0,0 +1,23 @@
+ import dataclasses
+ from typing import List, Optional, Union
+
+
+ @dataclasses.dataclass(frozen=True)
+ class DataFrameInfo:
+     """Serializable information from Snowpark DataFrames"""
+
+     sql: str
+     query_id: Optional[str] = None
+
+
+ @dataclasses.dataclass(frozen=True)
+ class DatasetInfo:
+     """Serializable information from SnowML Datasets"""
+
+     fully_qualified_name: str
+     version: str
+     url: Optional[str] = None
+     exclude_cols: Optional[List[str]] = None
+
+
+ DataSource = Union[DataFrameInfo, DatasetInfo, str]
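
The union type above accepts three shapes of data source. A small sketch (not part of the diff; all literal values are placeholders):

# Hypothetical DataSource values (illustrative only).
from typing import List

from snowflake.ml.data import data_source

query_source = data_source.DataFrameInfo(sql="SELECT * FROM MY_TABLE")  # Snowpark DataFrame query
dataset_source = data_source.DatasetInfo(
    fully_qualified_name="MY_DB.MY_SCHEMA.MY_DATASET",  # values normally come from Dataset.selected_version
    version="V1",
)
file_source = "data/train.parquet"  # plain file path or URL (format inferred if not specified)

sources: List[data_source.DataSource] = [query_source, dataset_source, file_source]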
snowflake/ml/data/ingestor_utils.py
@@ -0,0 +1,62 @@
+ from typing import List, Optional
+
+ import fsspec
+
+ from snowflake import snowpark
+ from snowflake.connector import result_batch
+ from snowflake.ml.data import data_source
+ from snowflake.ml.fileset import snowfs
+
+ _TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
+
+
+ def get_dataframe_result_batches(
+     session: snowpark.Session, df_info: data_source.DataFrameInfo
+ ) -> List[result_batch.ResultBatch]:
+     """Retrieve the ResultBatches for a given query"""
+     cursor = session._conn._cursor
+
+     if df_info.query_id:
+         query_id = df_info.query_id
+     else:
+         query_id = session.sql(df_info.sql).collect_nowait().query_id
+
+     # TODO: Check if query result cache is still live
+     cursor.get_results_from_sfqid(sfqid=query_id)
+
+     # Prefetch hook should be set by `get_results_from_sfqid`
+     # This call blocks until the query results are ready
+     if cursor._prefetch_hook is None:
+         raise RuntimeError("Loading data from result query failed unexpectedly. Please contact Snowflake support.")
+     cursor._prefetch_hook()
+     batches = cursor.get_result_batches()
+     if batches is None:
+         raise ValueError(
+             "Failed to retrieve training data. Query status:" f" {session._conn._conn.get_query_status(query_id)}"
+         )
+     return batches
+
+
+ def get_dataset_filesystem(
+     session: snowpark.Session, ds_info: Optional[data_source.DatasetInfo] = None
+ ) -> fsspec.AbstractFileSystem:
+     """Get the fsspec filesystem for a given Dataset"""
+     # We can't directly load the Dataset to avoid a circular dependency
+     # Dataset -> DatasetReader -> DataConnector -> DataIngestor -> (?) ingestor_utils -> Dataset
+     # TODO: Automatically pick appropriate fsspec implementation based on protocol in URL
+     return snowfs.SnowFileSystem(
+         snowpark_session=session,
+         cache_type="bytes",
+         block_size=2 * _TARGET_FILE_SIZE,
+     )
+
+
+ def get_dataset_files(
+     session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
+ ) -> List[str]:
+     """Get the list of files in a given Dataset"""
+     if filesystem is None:
+         filesystem = get_dataset_filesystem(session, ds_info)
+     assert bool(ds_info.url)  # Not null or empty
+     files = sorted(filesystem.ls(ds_info.url))
+     return [filesystem.unstrip_protocol(f) for f in files]
snowflake/ml/data/torch_dataset.py
@@ -0,0 +1,33 @@
+ from typing import Any, Dict, Iterator
+
+ import torch.utils.data
+
+ from snowflake.ml.data import data_ingestor
+
+
+ class TorchDataset(torch.utils.data.IterableDataset[Dict[str, Any]]):
+     """Implementation of PyTorch IterableDataset"""
+
+     def __init__(self, ingestor: data_ingestor.DataIngestor, shuffle: bool = False) -> None:
+         """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
+         self._ingestor = ingestor
+         self._shuffle = shuffle
+
+     def __iter__(self) -> Iterator[Dict[str, Any]]:
+         max_idx = 0
+         filter_idx = 0
+         worker_info = torch.utils.data.get_worker_info()
+         if worker_info is not None:
+             max_idx = worker_info.num_workers - 1
+             filter_idx = worker_info.id
+
+         counter = 0
+         for batch in self._ingestor.to_batches(batch_size=1, shuffle=self._shuffle, drop_last_batch=False):
+             # Skip indices during multi-process data loading to prevent data duplication
+             if counter == filter_idx:
+                 yield {k: v.item() for k, v in batch.items()}
+
+             if counter < max_idx:
+                 counter += 1
+             else:
+                 counter = 0
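
TorchDataset yields one row per item, so it pairs with a standard torch DataLoader for batching; the worker round-robin in __iter__ above keeps multi-worker loading free of duplicates. A usage sketch follows (not part of the diff; `dc` and the column names are assumptions carried over from the earlier DataConnector sketch):

# Hypothetical DataLoader usage (illustrative only).
import torch.utils.data as torch_data

torch_ds = dc.to_torch_dataset(shuffle=True)  # dc: a DataConnector
loader = torch_data.DataLoader(torch_ds, batch_size=32, num_workers=2, drop_last=True)

for batch in loader:
    # Default collation turns the per-row dicts into a dict of tensors of shape (32,).
    features, labels = batch["FEATURE_1"], batch["LABEL"]
    break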
snowflake/ml/dataset/dataset.py
@@ -11,7 +11,6 @@ from snowflake.ml._internal.exceptions import (
      error_codes,
      exceptions as snowml_exceptions,
  )
- from snowflake.ml._internal.lineage import data_source
  from snowflake.ml._internal.utils import (
      formatting,
      identifier,
@@ -177,18 +176,7 @@ class Dataset(lineage_node.LineageNode):
                  original_exception=RuntimeError("No Dataset version selected."),
              )
          if self._reader is None:
-             v = self.selected_version
-             self._reader = dataset_reader.DatasetReader(
-                 self._session,
-                 [
-                     data_source.DataSource(
-                         fully_qualified_name=self._lineage_node_name,
-                         version=v.name,
-                         url=v.url(),
-                         exclude_cols=(v.label_cols + v.exclude_cols),
-                     )
-                 ],
-             )
+             self._reader = dataset_reader.DatasetReader.from_dataset(self, snowpark_session=self._session)
          return self._reader

      @staticmethod
snowflake/ml/dataset/dataset_metadata.py
@@ -15,11 +15,13 @@ class FeatureStoreMetadata:
      Properties:
          spine_query: The input query on source table which will be joined with features.
          serialized_feature_views: A list of serialized feature objects in the feature store.
+         compact_feature_views: A compact representation of a FeatureView or FeatureViewSlice.
          spine_timestamp_col: Timestamp column which was used for point-in-time correct feature lookup.
      """

      spine_query: str
-     serialized_feature_views: List[str]
+     serialized_feature_views: Optional[List[str]] = None
+     compact_feature_views: Optional[List[str]] = None
      spine_timestamp_col: Optional[str] = None

      def to_json(self) -> str: