snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +14 -0
  5. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  6. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  7. snowflake/ml/data/data_connector.py +59 -6
  8. snowflake/ml/data/data_ingestor.py +18 -1
  9. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  10. snowflake/ml/data/torch_dataset.py +33 -0
  11. snowflake/ml/dataset/dataset_metadata.py +3 -1
  12. snowflake/ml/dataset/dataset_reader.py +9 -3
  13. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  16. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  25. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  26. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  27. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  29. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  30. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  31. snowflake/ml/feature_store/feature_store.py +59 -24
  32. snowflake/ml/feature_store/feature_view.py +148 -4
  33. snowflake/ml/model/_client/model/model_impl.py +11 -2
  34. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  35. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  36. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  37. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  38. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  39. snowflake/ml/model/_client/sql/model_version.py +13 -4
  40. snowflake/ml/model/_client/sql/service.py +129 -0
  41. snowflake/ml/model/_model_composer/model_composer.py +3 -0
  42. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +10 -2
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  44. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  45. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +19 -12
  47. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  48. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +27 -18
  50. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  51. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  52. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  53. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  54. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  55. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  57. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  58. snowflake/ml/model/_packager/model_handlers/xgboost.py +25 -16
  59. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  60. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  61. snowflake/ml/model/_packager/model_packager.py +2 -1
  62. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  63. snowflake/ml/model/type_hints.py +1 -3
  64. snowflake/ml/modeling/framework/base.py +28 -19
  65. snowflake/ml/modeling/pipeline/pipeline.py +3 -0
  66. snowflake/ml/registry/_manager/model_manager.py +16 -2
  67. snowflake/ml/utils/sql_client.py +22 -0
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +35 -2
  70. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +73 -62
  71. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  72. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +0 -0
  74. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/cortex/_complete.py

@@ -90,7 +90,6 @@ def _call_complete_rest(
     prompt: Union[str, List[ConversationMessage]],
     options: Optional[CompleteOptions] = None,
     session: Optional[snowpark.Session] = None,
-    stream: bool = False,
 ) -> requests.Response:
     session = session or context.get_active_session()
     if session is None:
@@ -121,7 +120,7 @@ def _call_complete_rest(

     data = {
         "model": model,
-        "stream": stream,
+        "stream": True,
     }
     if isinstance(prompt, List):
         data["messages"] = prompt
@@ -137,32 +136,15 @@ def _call_complete_rest(
     if "top_p" in options:
         data["top_p"] = options["top_p"]

-    logger.debug(f"making POST request to {url} (model={model}, stream={stream})")
+    logger.debug(f"making POST request to {url} (model={model})")
     return requests.post(
         url,
         json=data,
         headers=headers,
-        stream=stream,
+        stream=True,
     )


-def _process_rest_response(
-    response: requests.Response,
-    stream: bool = False,
-    deadline: Optional[float] = None,
-) -> Union[str, Iterator[str]]:
-    if stream:
-        return _return_stream_response(response, deadline)
-
-    try:
-        content = response.json()["choices"][0]["message"]["content"]
-        assert isinstance(content, str)
-        return content
-    except (KeyError, IndexError, AssertionError) as e:
-        # Unlike the streaming case, errors are not ignored because a message must be returned.
-        raise ResponseParseException("Failed to parse message from response.") from e
-
-
 def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
     client = SSEClient(response)
     for event in client.events():
@@ -243,7 +225,6 @@ def _complete_impl(
     prompt: Union[str, List[ConversationMessage], snowpark.Column],
     options: Optional[CompleteOptions] = None,
     session: Optional[snowpark.Session] = None,
-    use_rest_api_experimental: bool = False,
     stream: bool = False,
     function: str = "snowflake.cortex.complete",
     timeout: Optional[float] = None,
@@ -253,16 +234,14 @@ def _complete_impl(
         raise ValueError('only one of "timeout" and "deadline" must be set')
     if timeout is not None:
         deadline = time.time() + timeout
-    if use_rest_api_experimental:
+    if stream:
         if not isinstance(model, str):
             raise ValueError("in REST mode, 'model' must be a string")
         if not isinstance(prompt, str) and not isinstance(prompt, List):
             raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
-        response = _call_complete_rest(model, prompt, options, session=session, stream=stream, deadline=deadline)
+        response = _call_complete_rest(model, prompt, options, session=session, deadline=deadline)
         assert response.status_code >= 200 and response.status_code < 300
-        return _process_rest_response(response, stream=stream)
-    if stream is True:
-        raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
+        return _return_stream_response(response, deadline)
     return _complete_sql_impl(function, model, prompt, options, session)


@@ -275,7 +254,6 @@ def Complete(
     *,
     options: Optional[CompleteOptions] = None,
     session: Optional[snowpark.Session] = None,
-    use_rest_api_experimental: bool = False,
     stream: bool = False,
     timeout: Optional[float] = None,
     deadline: Optional[float] = None,
@@ -287,16 +265,13 @@ def Complete(
         prompt: A Column of prompts to send to the LLM.
         options: A instance of snowflake.cortex.CompleteOptions
         session: The snowpark session to use. Will be inferred by context if not specified.
-        use_rest_api_experimental (bool): Toggles between the use of SQL and REST implementation. This feature is
-            experimental and can be removed at any time.
         stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
            output as it is received. Each update is a string containing the new text content since the previous update.
-            The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
         timeout (float): Timeout in seconds to retry failed REST requests.
         deadline (float): Time in seconds since the epoch (as returned by time.time()) to retry failed REST requests.

     Raises:
-        ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
+        ValueError: incorrect argument.

     Returns:
         A column of string responses.
@@ -307,7 +282,6 @@ def Complete(
         prompt,
         options=options,
         session=session,
-        use_rest_api_experimental=use_rest_api_experimental,
         stream=stream,
         timeout=timeout,
         deadline=deadline,
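
In 1.6.1 the REST code path is selected by stream=True alone, so streaming no longer needs the removed use_rest_api_experimental flag; a minimal usage sketch (the model name is illustrative, and an active Snowpark session is assumed to be inferable from context):

    from snowflake.cortex import Complete

    # stream=True routes the call through the REST endpoint and yields incremental text chunks.
    for delta in Complete("mistral-large", "Explain Snowpark in one sentence.", stream=True):
        print(delta, end="")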
@@ -27,7 +27,6 @@ class CONDA_OS(Enum):
27
27
  NO_ARCH = "noarch"
28
28
 
29
29
 
30
- _SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
31
30
  _NODEFAULTS = "nodefaults"
32
31
  _SNOWFLAKE_INFO_SCHEMA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
33
32
  _SNOWFLAKE_CONDA_PACKAGE_CACHE: Dict[str, List[version.Version]] = {}
@@ -36,6 +35,7 @@ _SUPPORTED_PACKAGE_SPEC_OPS = ["==", ">=", "<=", ">", "<"]
36
35
  DEFAULT_CHANNEL_NAME = ""
37
36
  SNOWML_SPROC_ENV = "IN_SNOWML_SPROC"
38
37
  SNOWPARK_ML_PKG_NAME = "snowflake-ml-python"
38
+ SNOWFLAKE_CONDA_CHANNEL_URL = "https://repo.anaconda.com/pkgs/snowflake"
39
39
 
40
40
 
41
41
  def _validate_pip_requirement_string(req_str: str) -> requirements.Requirement:
@@ -370,7 +370,7 @@ def get_matched_package_versions_in_snowflake_conda_channel(
370
370
 
371
371
  assert not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
372
372
 
373
- url = f"{_SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
373
+ url = f"{SNOWFLAKE_CONDA_CHANNEL_URL}/{conda_os.value}/repodata.json"
374
374
 
375
375
  if req.name not in _SNOWFLAKE_CONDA_PACKAGE_CACHE:
376
376
  try:
@@ -477,6 +477,7 @@ def save_conda_env_file(
477
477
  path: pathlib.Path,
478
478
  conda_chan_deps: DefaultDict[str, List[requirements.Requirement]],
479
479
  python_version: str,
480
+ default_channel_override: str = SNOWFLAKE_CONDA_CHANNEL_URL,
480
481
  ) -> None:
481
482
  """Generate conda.yml file given a dict of dependencies after validation.
482
483
  The channels part of conda.yml file will contains Snowflake Anaconda Channel, nodefaults and all channel names
@@ -489,6 +490,7 @@ def save_conda_env_file(
489
490
  path: Path to the conda.yml file.
490
491
  conda_chan_deps: Dict of conda dependencies after validated.
491
492
  python_version: A string 'major.minor' showing python version relate to model.
493
+ default_channel_override: The default channel to be put in the first place of the channels section.
492
494
  """
493
495
  assert path.suffix in [".yml", ".yaml"], "Conda environment file should have extension of yml or yaml."
494
496
  path.parent.mkdir(parents=True, exist_ok=True)
@@ -499,7 +501,11 @@ def save_conda_env_file(
499
501
  channels = list(dict(sorted(conda_chan_deps.items(), key=lambda item: len(item[1]), reverse=True)).keys())
500
502
  if DEFAULT_CHANNEL_NAME in channels:
501
503
  channels.remove(DEFAULT_CHANNEL_NAME)
502
- env["channels"] = [_SNOWFLAKE_CONDA_CHANNEL_URL] + channels + [_NODEFAULTS]
504
+
505
+ if default_channel_override in channels:
506
+ channels.remove(default_channel_override)
507
+
508
+ env["channels"] = [default_channel_override] + channels + [_NODEFAULTS]
503
509
  env["dependencies"] = [f"python=={python_version}.*"]
504
510
  for chan, reqs in conda_chan_deps.items():
505
511
  env["dependencies"].extend(
@@ -567,8 +573,8 @@ def load_conda_env_file(
567
573
  python_version = None
568
574
 
569
575
  channels = env.get("channels", [])
570
- if _SNOWFLAKE_CONDA_CHANNEL_URL in channels:
571
- channels.remove(_SNOWFLAKE_CONDA_CHANNEL_URL)
576
+ if len(channels) >= 1:
577
+ channels = channels[1:] # Skip the first channel which is the default channel
572
578
  if _NODEFAULTS in channels:
573
579
  channels.remove(_NODEFAULTS)
574
580
 
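The new default_channel_override parameter controls which channel ends up first in the generated conda.yml; a minimal sketch of the internal helper (the dependency and override URL are illustrative assumptions, not part of this diff):

    import collections
    import pathlib

    from packaging import requirements

    from snowflake.ml._internal import env_utils

    deps = collections.defaultdict(list)
    deps[env_utils.DEFAULT_CHANNEL_NAME].append(requirements.Requirement("xgboost==1.7.6"))

    # conda.yml will list the override first, then any named channels, then "nodefaults".
    env_utils.save_conda_env_file(
        pathlib.Path("conda.yml"),
        deps,
        python_version="3.9",
        default_channel_override="https://conda.anaconda.org/conda-forge",
    )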
snowflake/ml/_internal/exceptions/modeling_error_messages.py

@@ -4,7 +4,10 @@ ATTRIBUTE_NOT_SET = (
    "-differences."
 )
 SIZE_MISMATCH = "Size mismatch: {}={}, {}={}."
-INVALID_MODEL_PARAM = "Invalid parameter {} for model {}. Valid parameters: {}."
+INVALID_MODEL_PARAM = (
+    "Invalid parameter {} for model {}. Valid parameters: {}."
+    "Note: Scikit learn params cannot be set until the model has been fit."
+)
 UNSUPPORTED_MODEL_CONVERSION = "Object doesn't support {}. Please use {}."
 INCOMPATIBLE_NEW_SKLEARN_PARAM = "Incompatible scikit-learn version: {} requires scikit-learn>={}. Installed: {}."
 REMOVED_SKLEARN_PARAM = "Incompatible scikit-learn version: {} is removed in scikit-learn>={}. Installed: {}."
snowflake/ml/_internal/telemetry.py

@@ -44,6 +44,20 @@ _Args = ParamSpec("_Args")
 _ReturnValue = TypeVar("_ReturnValue")


+@enum.unique
+class TelemetryProject(enum.Enum):
+    MLOPS = "MLOps"
+    MODELING = "ModelDevelopment"
+    # TODO: Update with remaining projects.
+
+
+@enum.unique
+class TelemetrySubProject(enum.Enum):
+    MONITORING = "Monitoring"
+    REGISTRY = "ModelManagement"
+    # TODO: Update with remaining subprojects.
+
+
 @enum.unique
 class TelemetryField(enum.Enum):
     # constants
snowflake/ml/_internal/utils/pkg_version_utils.py

@@ -26,30 +26,11 @@ def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
     pkg_versions: List[str], session: Session, subproject: Optional[str] = None
 ) -> List[str]:
     if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
-        return _get_valid_pkg_versions_supported_in_snowflake_conda_channel_sync(pkg_versions, session, subproject)
+        return pkg_versions
     else:
         return _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(pkg_versions, session, subproject)


-def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_sync(
-    pkg_versions: List[str], session: Session, subproject: Optional[str] = None
-) -> List[str]:
-    for pkg_version in pkg_versions:
-        if pkg_version not in cache:
-            pkg_version_list = _query_pkg_version_supported_in_snowflake_conda_channel(
-                pkg_version=pkg_version, session=session, block=True, subproject=subproject
-            )
-            assert isinstance(pkg_version_list, list)  # keep mypy happy
-            try:
-                cache[pkg_version] = pkg_version_list[0]["VERSION"]
-            except IndexError:
-                cache[pkg_version] = None
-
-    pkg_version_conda_list = _get_conda_packages_and_emit_warnings(pkg_versions)
-
-    return pkg_version_conda_list
-
-
 def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
     pkg_versions: List[str], session: Session, subproject: Optional[str] = None
@@ -60,7 +41,11 @@ def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
         async_job = _query_pkg_version_supported_in_snowflake_conda_channel(
             pkg_version=pkg_version, session=session, block=False, subproject=subproject
         )
-        assert isinstance(async_job, AsyncJob)
+        if isinstance(async_job, list):
+            raise RuntimeError(
+                "Async job was expected, executed query was returned. Please contact Snowflake support."
+            )
+
         pkg_version_async_job_list.append((pkg_version, async_job))

     # Populate the cache.
@@ -143,7 +128,8 @@ def _get_conda_packages_and_emit_warnings(pkg_versions: List[str]) -> List[str]:
         warnings.warn(
             f"Package {', '.join([pkg[0] for pkg in pkg_version_warning_list])} is not supported "
             f"in snowflake conda channel for python runtime "
-            f"{', '.join([pkg[1] for pkg in pkg_version_warning_list])}."
+            f"{', '.join([pkg[1] for pkg in pkg_version_warning_list])}.",
+            stacklevel=1,
         )

     return pkg_version_conda_list
snowflake/ml/data/_internal/arrow_ingestor.py

@@ -2,17 +2,17 @@ import collections
 import logging
 import os
 import time
-from typing import Any, Deque, Dict, Iterator, List, Optional
+from typing import Any, Deque, Dict, Iterator, List, Optional, Union

 import numpy as np
 import numpy.typing as npt
 import pandas as pd
 import pyarrow as pa
-import pyarrow.dataset as ds
+import pyarrow.dataset as pds

 from snowflake import snowpark
-from snowflake.ml.data import data_ingestor, data_source
-from snowflake.ml.data._internal import ingestor_utils
+from snowflake.connector import result_batch
+from snowflake.ml.data import data_ingestor, data_source, ingestor_utils

 _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])

@@ -67,6 +67,10 @@ class ArrowIngestor(data_ingestor.DataIngestor):

         self._schema: Optional[pa.Schema] = None

+    @classmethod
+    def from_sources(cls, session: snowpark.Session, sources: List[data_source.DataSource]) -> "ArrowIngestor":
+        return cls(session, sources)
+
     @property
     def data_sources(self) -> List[data_source.DataSource]:
         return self._data_sources
@@ -115,9 +119,9 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         table = ds.to_table() if limit is None else ds.head(num_rows=limit)
         return table.to_pandas()

-    def _get_dataset(self, shuffle: bool) -> ds.Dataset:
+    def _get_dataset(self, shuffle: bool) -> pds.Dataset:
         format = self._format
-        sources = []
+        sources: List[Any] = []
         source_format = None
         for source in self._data_sources:
             if isinstance(source, str):
@@ -137,8 +141,16 @@ class ArrowIngestor(data_ingestor.DataIngestor):
                 # in-memory (first batch) and file URLs (subsequent batches) and creating a
                 # union dataset.
                 result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source)
-                sources.extend(b.to_arrow() for b in result_batches)
-                source_format = "arrow"
+                sources.extend(
+                    b.to_arrow(self._session.connection)
+                    if isinstance(b, result_batch.ArrowResultBatch)
+                    else b.to_arrow()
+                    for b in result_batches
+                )
+                # HACK: Mitigate typing inconsistencies in Snowpark results
+                if len(sources) > 0:
+                    sources = [_cast_if_needed(s, sources[-1].schema) for s in sources]
+                source_format = None  # Arrow Dataset expects "None" for in-memory datasets
             else:
                 raise RuntimeError(f"Unsupported data source type: {type(source)}")

@@ -150,7 +162,7 @@ class ArrowIngestor(data_ingestor.DataIngestor):
         # Re-shuffle input files on each iteration start
         if shuffle:
             np.random.shuffle(sources)
-        pa_dataset: ds.Dataset = ds.dataset(sources, format=format, **self._kwargs)
+        pa_dataset: pds.Dataset = pds.dataset(sources, format=format, **self._kwargs)
         return pa_dataset

     def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
@@ -201,7 +213,7 @@ def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:


 def _retryable_batches(
-    dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
+    dataset: pds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
 ) -> Iterator[pa.RecordBatch]:
     """Make the Dataset to_batches retryable."""
     retries = 0
@@ -226,3 +238,47 @@
                 time.sleep(delay)
             else:
                 raise e
+
+
+def _cast_if_needed(
+    batch: Union[pa.Table, pa.RecordBatch], schema: Optional[pa.Schema] = None
+) -> Union[pa.Table, pa.RecordBatch]:
+    """
+    Cast the batch to be compatible with downstream frameworks. Returns original batch if cast is not necessary.
+    Besides casting types to match `schema` (if provided), this function also applies the following casting:
+    - Decimal (fixed-point) types: Convert to float or integer types based on scale and byte length
+
+    Args:
+        batch: The PyArrow batch to cast if needed
+        schema: Optional schema the batch should be casted to match. Note that compatibility type casting takes
+            precedence over the provided schema, e.g. if the schema has decimal types the result will be further
+            cast into integer/float types.
+
+    Returns:
+        The type-casted PyArrow batch, or the original batch if casting was not necessary
+    """
+    schema = schema or batch.schema
+    assert len(batch.schema) == len(schema)
+    fields = []
+    cast_needed = False
+    for field, target in zip(batch.schema, schema):
+        # Need to convert decimal types to supported types. This behavior supersedes target schema data types
+        if pa.types.is_decimal(target.type):
+            byte_length = int(target.metadata.get(b"byteLength", 8))
+            if int(target.metadata.get(b"scale", 0)) > 0:
+                target = target.with_type(pa.float32() if byte_length == 4 else pa.float64())
+            else:
+                if byte_length == 2:
+                    target = target.with_type(pa.int16())
+                elif byte_length == 4:
+                    target = target.with_type(pa.int32())
+                else:  # Cap out at 64-bit
+                    target = target.with_type(pa.int64())
+        if not field.equals(target):
+            cast_needed = True
+            field = target
+        fields.append(field)
+
+    if cast_needed:
+        return batch.cast(pa.schema(fields))
+    return batch
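
The decimal handling above means Snowflake NUMBER columns come out as native integer/float Arrow types rather than decimals; a rough sketch of the effect (the column name and metadata values are illustrative assumptions, and calling the private helper directly is only for demonstration):

    from decimal import Decimal

    import pyarrow as pa

    from snowflake.ml.data._internal.arrow_ingestor import _cast_if_needed

    # Snowflake Arrow result batches tag NUMBER columns with scale/byteLength metadata;
    # scale 0 with a 4-byte length maps to int32 in _cast_if_needed.
    field = pa.field("ID", pa.decimal128(10, 0), metadata={b"scale": b"0", b"byteLength": b"4"})
    table = pa.Table.from_arrays(
        [pa.array([Decimal("1"), Decimal("2")], type=pa.decimal128(10, 0))],
        schema=pa.schema([field]),
    )

    print(_cast_if_needed(table).schema)  # ID: int32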
snowflake/ml/data/data_connector.py

@@ -1,11 +1,12 @@
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Type, TypeVar

 import numpy.typing as npt
+from typing_extensions import deprecated

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml.data import data_ingestor, data_source
-from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor as DefaultIngestor
+from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor

 if TYPE_CHECKING:
     import pandas as pd
@@ -24,6 +25,8 @@ DataConnectorType = TypeVar("DataConnectorType", bound="DataConnector")
 class DataConnector:
     """Snowflake data reader which provides application integration connectors"""

+    DEFAULT_INGESTOR_CLASS: Type[data_ingestor.DataIngestor] = ArrowIngestor
+
     def __init__(
         self,
         ingestor: data_ingestor.DataIngestor,
@@ -31,22 +34,48 @@ class DataConnector:
         self._ingestor = ingestor

     @classmethod
-    def from_dataframe(cls: Type[DataConnectorType], df: snowpark.DataFrame, **kwargs: Any) -> DataConnectorType:
+    @snowpark._internal.utils.private_preview(version="1.6.0")
+    def from_dataframe(
+        cls: Type[DataConnectorType],
+        df: snowpark.DataFrame,
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
         if len(df.queries["queries"]) != 1 or len(df.queries["post_actions"]) != 0:
             raise ValueError("DataFrames with multiple queries and/or post-actions not supported")
         source = data_source.DataFrameInfo(df.queries["queries"][0])
         assert df._session is not None
-        ingestor = DefaultIngestor(df._session, [source])
-        return cls(ingestor, **kwargs)
+        return cls.from_sources(df._session, [source], ingestor_class=ingestor_class, **kwargs)

     @classmethod
-    def from_dataset(cls: Type[DataConnectorType], ds: "dataset.Dataset", **kwargs: Any) -> DataConnectorType:
+    def from_dataset(
+        cls: Type[DataConnectorType],
+        ds: "dataset.Dataset",
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
         dsv = ds.selected_version
         assert dsv is not None
         source = data_source.DatasetInfo(
             ds.fully_qualified_name, dsv.name, dsv.url(), exclude_cols=(dsv.label_cols + dsv.exclude_cols)
         )
-        ingestor = DefaultIngestor(ds._session, [source])
+        return cls.from_sources(ds._session, [source], ingestor_class=ingestor_class, **kwargs)
+
+    @classmethod
+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda cls: cls.__name__,
+        func_params_to_log=["sources", "ingestor_class"],
+    )
+    def from_sources(
+        cls: Type[DataConnectorType],
+        session: snowpark.Session,
+        sources: List[data_source.DataSource],
+        ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None,
+        **kwargs: Any
+    ) -> DataConnectorType:
+        ingestor_class = ingestor_class or cls.DEFAULT_INGESTOR_CLASS
+        ingestor = ingestor_class.from_sources(session, sources)
         return cls(ingestor, **kwargs)

     @property
@@ -87,6 +116,9 @@ class DataConnector:

         return tf.data.Dataset.from_generator(generator, output_signature=tf_signature)

+    @deprecated(
+        "to_torch_datapipe() is deprecated and will be removed in a future release. Use to_torch_dataset() instead"
+    )
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
         subproject_extractor=lambda self: type(self).__name__,
@@ -116,6 +148,27 @@ class DataConnector:
             self._ingestor.to_batches(batch_size, shuffle, drop_last_batch)
         )

+    @telemetry.send_api_usage_telemetry(
+        project=_PROJECT,
+        subproject_extractor=lambda self: type(self).__name__,
+        func_params_to_log=["shuffle"],
+    )
+    def to_torch_dataset(self, *, shuffle: bool = False) -> "torch_data.IterableDataset":  # type: ignore[type-arg]
+        """Transform the Snowflake data into a PyTorch Iterable Dataset to be used with a DataLoader.
+
+        Return a PyTorch Dataset which iterates on rows of data.
+
+        Args:
+            shuffle: It specifies whether the data will be shuffled. If True, files will be shuffled, and
+                rows in each file will also be shuffled.
+
+        Returns:
+            A PyTorch Iterable Dataset that yields data.
+        """
+        from snowflake.ml.data import torch_dataset
+
+        return torch_dataset.TorchDataset(self._ingestor, shuffle)
+
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
         subproject_extractor=lambda self: type(self).__name__,
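
DataConnector construction is now funneled through the new from_sources() classmethod, with ingestor_class as an optional hook for swapping out the default ArrowIngestor; a minimal sketch (the query text is an illustrative assumption, and an active Snowpark session named `session` is assumed):

    from snowflake.ml.data import data_source
    from snowflake.ml.data.data_connector import DataConnector

    # Roughly equivalent to DataConnector.from_dataframe(session.sql(query)); ingestor_class
    # defaults to DataConnector.DEFAULT_INGESTOR_CLASS (ArrowIngestor) when omitted.
    dc = DataConnector.from_sources(
        session,
        [data_source.DataFrameInfo('SELECT * FROM "MY_DB"."MY_SCHEMA"."MY_TABLE"')],
    )
    pandas_df = dc.to_pandas()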
snowflake/ml/data/data_ingestor.py

@@ -1,7 +1,18 @@
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Protocol, TypeVar
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Protocol,
+    Type,
+    TypeVar,
+)

 from numpy import typing as npt

+from snowflake import snowpark
 from snowflake.ml.data import data_source

 if TYPE_CHECKING:
@@ -12,6 +23,12 @@ DataIngestorType = TypeVar("DataIngestorType", bound="DataIngestor")


 class DataIngestor(Protocol):
+    @classmethod
+    def from_sources(
+        cls: Type[DataIngestorType], session: snowpark.Session, sources: List[data_source.DataSource]
+    ) -> DataIngestorType:
+        raise NotImplementedError
+
     @property
     def data_sources(self) -> List[data_source.DataSource]:
         raise NotImplementedError
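
The DataIngestor protocol now includes a from_sources() constructor, which is what DataConnector.from_sources() invokes on whichever ingestor_class it receives; a rough sketch of a conforming custom ingestor (the subclass name and print statement are illustrative assumptions):

    from typing import List

    from snowflake import snowpark
    from snowflake.ml.data import data_source
    from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor

    class LoggingArrowIngestor(ArrowIngestor):
        """ArrowIngestor variant that reports which sources it was built from."""

        @classmethod
        def from_sources(
            cls, session: snowpark.Session, sources: List[data_source.DataSource]
        ) -> "LoggingArrowIngestor":
            print(f"building ingestor over {len(sources)} source(s)")
            return cls(session, sources)

    # Plugged in through the new hook, e.g.:
    # DataConnector.from_dataframe(df, ingestor_class=LoggingArrowIngestor)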
snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py}

@@ -13,6 +13,7 @@ _TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
 def get_dataframe_result_batches(
     session: snowpark.Session, df_info: data_source.DataFrameInfo
 ) -> List[result_batch.ResultBatch]:
+    """Retrieve the ResultBatches for a given query"""
     cursor = session._conn._cursor

     if df_info.query_id:
@@ -39,6 +40,7 @@
 def get_dataset_filesystem(
     session: snowpark.Session, ds_info: Optional[data_source.DatasetInfo] = None
 ) -> fsspec.AbstractFileSystem:
+    """Get the fsspec filesystem for a given Dataset"""
     # We can't directly load the Dataset to avoid a circular dependency
     # Dataset -> DatasetReader -> DataConnector -> DataIngestor -> (?) ingestor_utils -> Dataset
     # TODO: Automatically pick appropriate fsspec implementation based on protocol in URL
@@ -52,7 +54,9 @@
 def get_dataset_files(
     session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
 ) -> List[str]:
+    """Get the list of files in a given Dataset"""
     if filesystem is None:
         filesystem = get_dataset_filesystem(session, ds_info)
     assert bool(ds_info.url)  # Not null or empty
-    return sorted(filesystem.ls(ds_info.url))
+    files = sorted(filesystem.ls(ds_info.url))
+    return [filesystem.unstrip_protocol(f) for f in files]
snowflake/ml/data/torch_dataset.py

@@ -0,0 +1,33 @@
+from typing import Any, Dict, Iterator
+
+import torch.utils.data
+
+from snowflake.ml.data import data_ingestor
+
+
+class TorchDataset(torch.utils.data.IterableDataset[Dict[str, Any]]):
+    """Implementation of PyTorch IterableDataset"""
+
+    def __init__(self, ingestor: data_ingestor.DataIngestor, shuffle: bool = False) -> None:
+        """Not intended for direct usage. Use DataConnector.to_torch_dataset() instead"""
+        self._ingestor = ingestor
+        self._shuffle = shuffle
+
+    def __iter__(self) -> Iterator[Dict[str, Any]]:
+        max_idx = 0
+        filter_idx = 0
+        worker_info = torch.utils.data.get_worker_info()
+        if worker_info is not None:
+            max_idx = worker_info.num_workers - 1
+            filter_idx = worker_info.id
+
+        counter = 0
+        for batch in self._ingestor.to_batches(batch_size=1, shuffle=self._shuffle, drop_last_batch=False):
+            # Skip indices during multi-process data loading to prevent data duplication
+            if counter == filter_idx:
+                yield {k: v.item() for k, v in batch.items()}
+
+            if counter < max_idx:
+                counter += 1
+            else:
+                counter = 0
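
to_torch_dataset() wraps this class; a minimal usage sketch with a DataLoader (the table name is an illustrative assumption, and an active Snowpark session named `session` is assumed):

    import torch.utils.data

    from snowflake.ml.data.data_connector import DataConnector

    dc = DataConnector.from_dataframe(session.table("MY_DB.MY_SCHEMA.TRAIN"))
    ds = dc.to_torch_dataset(shuffle=True)

    # TorchDataset yields one row at a time as a dict of scalars; batching is done by the
    # DataLoader, and the worker_info bookkeeping above de-duplicates rows when num_workers > 1.
    loader = torch.utils.data.DataLoader(ds, batch_size=32, num_workers=2)
    for batch in loader:
        ...  # batch maps column names to tensors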
snowflake/ml/dataset/dataset_metadata.py

@@ -15,11 +15,13 @@ class FeatureStoreMetadata:
     Properties:
         spine_query: The input query on source table which will be joined with features.
         serialized_feature_views: A list of serialized feature objects in the feature store.
+        compact_feature_views: A compact representation of a FeatureView or FeatureViewSlice.
         spine_timestamp_col: Timestamp column which was used for point-in-time correct feature lookup.
     """

     spine_query: str
-    serialized_feature_views: List[str]
+    serialized_feature_views: Optional[List[str]] = None
+    compact_feature_views: Optional[List[str]] = None
     spine_timestamp_col: Optional[str] = None

     def to_json(self) -> str:
snowflake/ml/dataset/dataset_reader.py

@@ -1,10 +1,9 @@
-from typing import List, Optional
+from typing import Any, List, Optional, Type

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.lineage import lineage_utils
-from snowflake.ml.data import data_connector, data_ingestor, data_source
-from snowflake.ml.data._internal import ingestor_utils
+from snowflake.ml.data import data_connector, data_ingestor, data_source, ingestor_utils
 from snowflake.ml.fileset import snowfs

 _PROJECT = "Dataset"
@@ -27,6 +26,13 @@ class DatasetReader(data_connector.DataConnector):
         self._fs: snowfs.SnowFileSystem = ingestor_utils.get_dataset_filesystem(self._session)
         self._files: Optional[List[str]] = None

+    @classmethod
+    def from_dataframe(
+        cls, df: snowpark.DataFrame, ingestor_class: Optional[Type[data_ingestor.DataIngestor]] = None, **kwargs: Any
+    ) -> "DatasetReader":
+        # Block superclass constructor from Snowpark DataFrames
+        raise RuntimeError("Creating DatasetReader from DataFrames not supported")
+
     def _list_files(self) -> List[str]:
         """Private helper function that lists all files in this DatasetVersion and caches the results."""
         if self._files:
snowflake/ml/feature_store/examples/airline_features/entities.py

@@ -0,0 +1,16 @@
+from typing import List
+
+from snowflake.ml.feature_store import Entity
+
+zipcode_entity = Entity(
+    name="AIRPORT_ZIP_CODE",
+    join_keys=["AIRPORT_ZIP_CODE"],
+    desc="Zip code of the airport.",
+)
+
+plane_entity = Entity(name="PLANE_MODEL", join_keys=["PLANE_MODEL"], desc="The model of an airplane.")
+
+
+# This will be invoked by example_helper.py. Do not change function name.
+def get_all_entities() -> List[Entity]:
+    return [zipcode_entity, plane_entity]