snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +67 -10
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  6. snowflake/ml/_internal/telemetry.py +12 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
  8. snowflake/ml/data/_internal/ingestor_utils.py +58 -0
  9. snowflake/ml/data/data_connector.py +133 -0
  10. snowflake/ml/data/data_ingestor.py +28 -0
  11. snowflake/ml/data/data_source.py +23 -0
  12. snowflake/ml/dataset/dataset.py +1 -13
  13. snowflake/ml/dataset/dataset_reader.py +18 -118
  14. snowflake/ml/feature_store/access_manager.py +7 -1
  15. snowflake/ml/feature_store/entity.py +19 -2
  16. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +240 -0
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
  25. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  26. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  27. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  28. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  29. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  30. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
  31. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
  32. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
  33. snowflake/ml/feature_store/feature_store.py +579 -53
  34. snowflake/ml/feature_store/feature_view.py +168 -5
  35. snowflake/ml/fileset/stage_fs.py +18 -10
  36. snowflake/ml/lineage/lineage_node.py +1 -1
  37. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  38. snowflake/ml/model/_model_composer/model_composer.py +11 -14
  39. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +24 -16
  40. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
  41. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  42. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  43. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  44. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  45. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  46. snowflake/ml/model/_packager/model_handlers/_base.py +11 -1
  47. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  48. snowflake/ml/model/_packager/model_handlers/catboost.py +42 -0
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +68 -0
  50. snowflake/ml/model/_packager/model_handlers/xgboost.py +59 -0
  51. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
  52. snowflake/ml/model/model_signature.py +4 -4
  53. snowflake/ml/model/type_hints.py +4 -0
  54. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  56. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  57. snowflake/ml/modeling/pipeline/pipeline.py +4 -4
  58. snowflake/ml/registry/registry.py +100 -13
  59. snowflake/ml/version.py +1 -1
  60. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +48 -2
  61. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +64 -42
  62. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
  63. snowflake/ml/_internal/lineage/data_source.py +0 -10
  64. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py CHANGED
@@ -1,3 +1,4 @@
+ from snowflake.cortex._classify_text import ClassifyText
  from snowflake.cortex._complete import Complete, CompleteOptions
  from snowflake.cortex._extract_answer import ExtractAnswer
  from snowflake.cortex._sentiment import Sentiment
@@ -5,6 +6,7 @@ from snowflake.cortex._summarize import Summarize
  from snowflake.cortex._translate import Translate

  __all__ = [
+     "ClassifyText",
      "Complete",
      "CompleteOptions",
      "ExtractAnswer",
snowflake/cortex/_classify_text.py ADDED
@@ -0,0 +1,36 @@
+ from typing import List, Optional, Union
+
+ from snowflake import snowpark
+ from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
+ from snowflake.ml._internal import telemetry
+
+
+ @telemetry.send_api_usage_telemetry(
+     project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
+ )
+ def ClassifyText(
+     str_input: Union[str, snowpark.Column],
+     categories: Union[List[str], snowpark.Column],
+     session: Optional[snowpark.Session] = None,
+ ) -> Union[str, snowpark.Column]:
+     """Use the LLM inference service to classify the INPUT text into one of the target CATEGORIES.
+
+     Args:
+         str_input: A Column of strings to classify.
+         categories: A list of candidate categories to classify the INPUT text into.
+         session: The snowpark session to use. Will be inferred by context if not specified.
+
+     Returns:
+         A column of classification responses.
+     """
+
+     return _classify_text_impl("snowflake.cortex.classify_text", str_input, categories, session=session)
+
+
+ def _classify_text_impl(
+     function: str,
+     str_input: Union[str, snowpark.Column],
+     categories: Union[List[str], snowpark.Column],
+     session: Optional[snowpark.Session] = None,
+ ) -> Union[str, snowpark.Column]:
+     return call_sql_function(function, session, str_input, categories)
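For context, a minimal usage sketch of the new function (the sample text and category labels are illustrative assumptions; per the docstring above, the active Snowpark session is inferred when none is passed):

from snowflake.cortex import ClassifyText

# Classify a single string into one of the candidate categories. This routes to the
# SQL function "snowflake.cortex.classify_text" via call_sql_function.
label = ClassifyText("One day I will see the world", ["travel", "cooking"])

# str_input may also be a snowpark.Column, in which case a Column expression is
# returned for use inside DataFrame transformations (see the Union type hints above).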
snowflake/cortex/_complete.py CHANGED
@@ -1,6 +1,7 @@
  import json
  import logging
- from typing import Iterator, List, Optional, TypedDict, Union, cast
+ import time
+ from typing import Any, Callable, Iterator, List, Optional, TypedDict, Union, cast
  from urllib.parse import urlunparse

  import requests
@@ -52,6 +53,38 @@ class ResponseParseException(Exception):
      pass


+ _MAX_RETRY_SECONDS = 30
+
+
+ def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Response]:
+     def inner(*args: Any, **kwargs: Any) -> requests.Response:
+         deadline = cast(Optional[float], kwargs["deadline"])
+         kwargs = {key: value for key, value in kwargs.items() if key != "deadline"}
+         expRetrySeconds = 0.5
+         while True:
+             if deadline is not None and time.time() > deadline:
+                 raise TimeoutError()
+             response = func(*args, **kwargs)
+             if response.status_code >= 200 and response.status_code < 300:
+                 return response
+             retry_status_codes = [429, 503, 504]
+             if response.status_code not in retry_status_codes:
+                 response.raise_for_status()
+             logger.debug(f"request failed with status code {response.status_code}, retrying")
+
+             # Formula: delay(i) = max(RetryAfterHeader, min(2^i, _MAX_RETRY_SECONDS)).
+             expRetrySeconds = min(2 * expRetrySeconds, _MAX_RETRY_SECONDS)
+             retrySeconds = expRetrySeconds
+             retryAfterHeader = response.headers.get("retry-after")
+             if retryAfterHeader is not None:
+                 retrySeconds = max(retrySeconds, int(retryAfterHeader))
+             logger.debug(f"sleeping for {retrySeconds}s before retrying")
+             time.sleep(retrySeconds)
+
+     return inner
+
+
+ @retry
  def _call_complete_rest(
      model: str,
      prompt: Union[str, List[ConversationMessage]],
@@ -78,7 +111,7 @@ def _call_complete_rest(
      scheme = "https"
      if hasattr(session.connection, "scheme"):
          scheme = session.connection.scheme
-     url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference/complete", "", "", ""))
+     url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference:complete", "", "", ""))

      headers = {
          "Content-Type": "application/json",
@@ -105,19 +138,21 @@
          data["top_p"] = options["top_p"]

      logger.debug(f"making POST request to {url} (model={model}, stream={stream})")
-     response = requests.post(
+     return requests.post(
          url,
          json=data,
          headers=headers,
          stream=stream,
      )
-     response.raise_for_status()
-     return response


- def _process_rest_response(response: requests.Response, stream: bool = False) -> Union[str, Iterator[str]]:
+ def _process_rest_response(
+     response: requests.Response,
+     stream: bool = False,
+     deadline: Optional[float] = None,
+ ) -> Union[str, Iterator[str]]:
      if stream:
-         return _return_stream_response(response)
+         return _return_stream_response(response, deadline)

      try:
          content = response.json()["choices"][0]["message"]["content"]
@@ -128,9 +163,11 @@ def _process_rest_response(response: requests.Response, stream: bool = False) ->
          raise ResponseParseException("Failed to parse message from response.") from e


- def _return_stream_response(response: requests.Response) -> Iterator[str]:
+ def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
      client = SSEClient(response)
      for event in client.events():
+         if deadline is not None and time.time() > deadline:
+             raise TimeoutError()
          try:
              yield json.loads(event.data)["choices"][0]["delta"]["content"]
          except (json.JSONDecodeError, KeyError, IndexError):
@@ -209,13 +246,20 @@
      use_rest_api_experimental: bool = False,
      stream: bool = False,
      function: str = "snowflake.cortex.complete",
+     timeout: Optional[float] = None,
+     deadline: Optional[float] = None,
  ) -> Union[str, Iterator[str], snowpark.Column]:
+     if timeout is not None and deadline is not None:
+         raise ValueError('only one of "timeout" and "deadline" must be set')
+     if timeout is not None:
+         deadline = time.time() + timeout
      if use_rest_api_experimental:
          if not isinstance(model, str):
              raise ValueError("in REST mode, 'model' must be a string")
          if not isinstance(prompt, str) and not isinstance(prompt, List):
              raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
-         response = _call_complete_rest(model, prompt, options, session=session, stream=stream)
+         response = _call_complete_rest(model, prompt, options, session=session, stream=stream, deadline=deadline)
+         assert response.status_code >= 200 and response.status_code < 300
          return _process_rest_response(response, stream=stream)
      if stream is True:
          raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
@@ -233,6 +277,8 @@ def Complete(
      session: Optional[snowpark.Session] = None,
      use_rest_api_experimental: bool = False,
      stream: bool = False,
+     timeout: Optional[float] = None,
+     deadline: Optional[float] = None,
  ) -> Union[str, Iterator[str], snowpark.Column]:
      """Complete calls into the LLM inference service to perform completion.

@@ -246,6 +292,8 @@
          stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
              output as it is received. Each update is a string containing the new text content since the previous update.
              The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
+         timeout (float): Timeout in seconds to retry failed REST requests.
+         deadline (float): Time in seconds since the epoch (as returned by time.time()) to retry failed REST requests.

      Raises:
          ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
@@ -254,6 +302,15 @@
          A column of string responses.
      """
      try:
-         return _complete_impl(model, prompt, options, session, use_rest_api_experimental, stream)
+         return _complete_impl(
+             model,
+             prompt,
+             options=options,
+             session=session,
+             use_rest_api_experimental=use_rest_api_experimental,
+             stream=stream,
+             timeout=timeout,
+             deadline=deadline,
+         )
      except ValueError as err:
          raise err
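Taken together, the retry decorator and the new timeout/deadline parameters can be exercised roughly as follows (a hedged sketch: the model name and prompts are placeholders, and REST mode still requires the experimental flag):

from snowflake.cortex import Complete

# Give the REST call up to 30 seconds of retries; a TimeoutError is raised once the
# deadline derived from `timeout` passes (see the retry decorator above).
text = Complete(
    "snowflake-arctic",
    "Summarize the benefits of vector databases.",
    use_rest_api_experimental=True,
    timeout=30,
)

# Streaming variant: iterate over incremental text chunks as they arrive.
for chunk in Complete(
    "snowflake-arctic",
    "Write a haiku about snow.",
    use_rest_api_experimental=True,
    stream=True,
):
    print(chunk, end="")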
snowflake/cortex/_util.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Dict, Optional, Union, cast
+ from typing import Dict, List, Optional, Union, cast

  from snowflake import snowpark
  from snowflake.snowpark import context, functions
@@ -23,7 +23,7 @@ class SnowflakeConfigurationException(Exception):
  def call_sql_function(
      function: str,
      session: Optional[snowpark.Session],
-     *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
+     *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
  ) -> Union[str, snowpark.Column]:
      handle_as_column = False

@@ -40,7 +40,7 @@


  def _call_sql_function_column(
-     function: str, *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]]
+     function: str, *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]]
  ) -> snowpark.Column:
      return cast(snowpark.Column, functions.builtin(function)(*args))

@@ -48,7 +48,7 @@ def _call_sql_function_column(
  def _call_sql_function_immediate(
      function: str,
      session: Optional[snowpark.Session],
-     *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
+     *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
  ) -> str:
      session = session or context.get_active_session()
      if session is None:
snowflake/ml/_internal/lineage/lineage_utils.py CHANGED
@@ -1,9 +1,9 @@
  import copy
  import functools
- from typing import Any, Callable, List, Optional
+ from typing import Any, Callable, List, Optional, get_args

  from snowflake import snowpark
- from snowflake.ml._internal.lineage import data_source
+ from snowflake.ml.data import data_source

  _DATA_SOURCES_ATTR = "_data_sources"

@@ -39,7 +39,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
      result: Optional[List[data_source.DataSource]] = None
      for arg in args:
          srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
-         if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
+         if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
              if result is None:
                  result = []
              result += srcs
@@ -49,7 +49,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
  def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
      """Helper method for attaching data sources to an object"""
      if data_sources:
-         assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
+         assert all(isinstance(ds, get_args(data_source.DataSource)) for ds in data_sources)
          setattr(obj, _DATA_SOURCES_ATTR, data_sources)
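The switch from isinstance(s, data_source.DataSource) to isinstance(s, get_args(data_source.DataSource)) suggests the relocated snowflake.ml.data.data_source.DataSource is now a typing Union alias rather than a single class; isinstance rejects a Union directly, so get_args unpacks it into a tuple of concrete types first. A small self-contained sketch of the pattern (the stand-in alias below is illustrative, not the library's actual definition):

from typing import Union, get_args

DataSource = Union[int, str]  # stand-in alias for illustration only

value = "@my_stage/data/file.parquet"
# isinstance(value, DataSource) raises TypeError for a typing.Union alias;
# get_args(DataSource) returns the wrapped classes, here (int, str).
assert isinstance(value, get_args(DataSource))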
snowflake/ml/_internal/telemetry.py CHANGED
@@ -277,6 +277,7 @@
          ]
      ] = None,
      sfqids_extractor: Optional[Callable[..., List[str]]] = None,
+     subproject_extractor: Optional[Callable[[Any], str]] = None,
      custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
  ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
      """
@@ -290,6 +291,7 @@
          conn_attr_name: Name of the SnowflakeConnection attribute in `self`.
          api_calls_extractor: Extract API calls from `self`.
          sfqids_extractor: Extract sfqids from `self`.
+         subproject_extractor: Extract subproject at runtime from `self`.
          custom_tags: Custom tags.

      Returns:
@@ -297,10 +299,14 @@

      Raises:
          TypeError: If `conn_attr_name` is provided but the conn attribute is not of type SnowflakeConnection.
+         ValueError: If both `subproject` and `subproject_extractor` are provided

      # noqa: DAR402
      """

+     if subproject is not None and subproject_extractor is not None:
+         raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
+
      def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
          @functools.wraps(func)
          def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
@@ -322,9 +328,13 @@
              if sfqids_extractor:
                  sfqids = sfqids_extractor(args[0])

+             subproject_name = subproject
+             if subproject_extractor is not None:
+                 subproject_name = subproject_extractor(args[0])
+
              statement_params = get_function_usage_statement_params(
                  project=project,
-                 subproject=subproject,
+                 subproject=subproject_name,
                  function_category=TelemetryField.FUNC_CAT_USAGE.value,
                  function_name=_get_full_func_name(func),
                  function_parameters=params,
@@ -381,7 +391,7 @@
                  raise e.original_exception from e

              # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
-             telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
+             telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject_name)
              telemetry_args = dict(
                  func_name=_get_full_func_name(func),
                  function_category=TelemetryField.FUNC_CAT_USAGE.value,
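A hedged sketch of how the new subproject_extractor hook might be used (the class and project names are made up for illustration); note that passing both subproject and subproject_extractor now raises ValueError:

from snowflake.ml._internal import telemetry


class _ExampleClient:
    def __init__(self, subproject: str) -> None:
        self._subproject = subproject

    # The extractor receives `self` (args[0]) at call time, so the reported
    # subproject can vary per instance instead of being fixed at decoration time.
    @telemetry.send_api_usage_telemetry(
        project="ExampleProject",
        subproject_extractor=lambda self: self._subproject,
    )
    def do_work(self) -> None:
        ...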
snowflake/ml/data/_internal/arrow_ingestor.py ADDED
@@ -0,0 +1,228 @@
+ import collections
+ import logging
+ import os
+ import time
+ from typing import Any, Deque, Dict, Iterator, List, Optional
+
+ import numpy as np
+ import numpy.typing as npt
+ import pandas as pd
+ import pyarrow as pa
+ import pyarrow.dataset as ds
+
+ from snowflake import snowpark
+ from snowflake.ml.data import data_ingestor, data_source
+ from snowflake.ml.data._internal import ingestor_utils
+
+ _EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
+
+ # The row count for batches read from PyArrow Dataset. This number should be large enough so that
+ # dataset.to_batches() would read in a very large portion of, if not entirely, a parquet file.
+ _DEFAULT_DATASET_BATCH_SIZE = 1000000
+
+
+ class _RecordBatchesBuffer:
+     """A queue that stores record batches and tracks the total num of rows in it."""
+
+     def __init__(self) -> None:
+         self.buffer: Deque[pa.RecordBatch] = collections.deque()
+         self.num_rows = 0
+
+     def append(self, rb: pa.RecordBatch) -> None:
+         self.buffer.append(rb)
+         self.num_rows += rb.num_rows
+
+     def appendleft(self, rb: pa.RecordBatch) -> None:
+         self.buffer.appendleft(rb)
+         self.num_rows += rb.num_rows
+
+     def popleft(self) -> pa.RecordBatch:
+         popped = self.buffer.popleft()
+         self.num_rows -= popped.num_rows
+         return popped
+
+
+ class ArrowIngestor(data_ingestor.DataIngestor):
+     """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
+
+     def __init__(
+         self,
+         session: snowpark.Session,
+         data_sources: List[data_source.DataSource],
+         format: Optional[str] = None,
+         **kwargs: Any,
+     ) -> None:
+         """
+         Args:
+             session: The Snowpark Session to use.
+             data_sources: List of data sources to ingest.
+             format: Currently "parquet", "ipc"/"arrow"/"feather", "csv", "json", and "orc" are supported.
+                 Will be inferred if not specified.
+             kwargs: Miscellaneous arguments passed to underlying PyArrow Dataset initializer.
+         """
+         self._session = session
+         self._data_sources = data_sources
+         self._format = format
+         self._kwargs = kwargs
+
+         self._schema: Optional[pa.Schema] = None
+
+     @property
+     def data_sources(self) -> List[data_source.DataSource]:
+         return self._data_sources
+
+     def to_batches(
+         self,
+         batch_size: int,
+         shuffle: bool = True,
+         drop_last_batch: bool = True,
+     ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+         """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
+
+         As we are generating batches with the exactly same length, the last few rows in each file might get left as they
+         are not long enough to form a batch. These rows will be put into a temporary buffer and combine with the first
+         few rows of the next file to generate a new batch.
+
+         Args:
+             batch_size: Specifies the size of each batch that will be yield
+             shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
+                 the order of files, and then shuflle the order of rows in each file.
+             drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last
+                 batch will get dropped if its size is smaller than the given batch_size.
+
+         Yields:
+             A dict mapping column names to the corresponding data fetch from that column.
+         """
+         self._rb_buffer = _RecordBatchesBuffer()
+
+         # Extract schema if not already known
+         dataset = self._get_dataset(shuffle)
+         if self._schema is None:
+             self._schema = dataset.schema
+
+         for rb in _retryable_batches(dataset, batch_size=max(_DEFAULT_DATASET_BATCH_SIZE, batch_size)):
+             if shuffle:
+                 rb = rb.take(np.random.permutation(rb.num_rows))
+             self._rb_buffer.append(rb)
+             while self._rb_buffer.num_rows >= batch_size:
+                 yield self._get_batches_from_buffer(batch_size)
+
+         if self._rb_buffer.num_rows and not drop_last_batch:
+             yield self._get_batches_from_buffer(batch_size)
+
+     def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
+         ds = self._get_dataset(shuffle=False)
+         table = ds.to_table() if limit is None else ds.head(num_rows=limit)
+         return table.to_pandas()
+
+     def _get_dataset(self, shuffle: bool) -> ds.Dataset:
+         format = self._format
+         sources = []
+         source_format = None
+         for source in self._data_sources:
+             if isinstance(source, str):
+                 sources.append(source)
+                 source_format = format or os.path.splitext(source)[-1]
+             elif isinstance(source, data_source.DatasetInfo):
+                 if not self._kwargs.get("filesystem"):
+                     self._kwargs["filesystem"] = ingestor_utils.get_dataset_filesystem(self._session, source)
+                 sources.extend(
+                     ingestor_utils.get_dataset_files(self._session, source, filesystem=self._kwargs["filesystem"])
+                 )
+                 source_format = "parquet"
+             elif isinstance(source, data_source.DataFrameInfo):
+                 # FIXME: This currently loads all result batches into memory so that it
+                 # can be passed into pyarrow.dataset as a list/tuple of pa.RecordBatches
+                 # We may be able to optimize this by splitting the result batches into
+                 # in-memory (first batch) and file URLs (subsequent batches) and creating a
+                 # union dataset.
+                 result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source)
+                 sources.extend(b.to_arrow() for b in result_batches)
+                 source_format = "arrow"
+             else:
+                 raise RuntimeError(f"Unsupported data source type: {type(source)}")
+
+             # Make sure source types not mixed
+             if format and format != source_format:
+                 raise RuntimeError(f"Unexpected data source format (expected {format}, found {source_format})")
+             format = source_format
+
+         # Re-shuffle input files on each iteration start
+         if shuffle:
+             np.random.shuffle(sources)
+         pa_dataset: ds.Dataset = ds.dataset(sources, format=format, **self._kwargs)
+         return pa_dataset
+
+     def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
+         """Generate new batches from the existing record batch buffer."""
+         cnt_rbs_num_rows = 0
+         candidates = []
+
+         # Keep popping record batches in buffer until there are enough rows for a batch.
+         while self._rb_buffer.num_rows and cnt_rbs_num_rows < batch_size:
+             candidate = self._rb_buffer.popleft()
+             cnt_rbs_num_rows += candidate.num_rows
+             candidates.append(candidate)
+
+         # When there are more rows than needed, slice the last popped batch to fit batch_size.
+         if cnt_rbs_num_rows > batch_size:
+             row_diff = cnt_rbs_num_rows - batch_size
+             slice_target = candidates[-1]
+             cut_off = slice_target.num_rows - row_diff
+             to_merge = slice_target.slice(length=cut_off)
+             left_over = slice_target.slice(offset=cut_off)
+             candidates[-1] = to_merge
+             self._rb_buffer.appendleft(left_over)
+
+         res = _merge_record_batches(candidates)
+         return _record_batch_to_arrays(res)
+
+
+ def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
+     """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
+     if not record_batches:
+         return _EMPTY_RECORD_BATCH
+     if len(record_batches) == 1:
+         return record_batches[0]
+     record_batches = list(filter(lambda rb: rb.num_rows > 0, record_batches))
+     one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks()
+     batches = one_chunk_table.to_batches(max_chunksize=None)
+     return batches[0]
+
+
+ def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
+     """Transform the record batch to a (string, numpy array) dict."""
+     batch_dict = {}
+     for column, column_schema in zip(rb, rb.schema):
+         # zero_copy_only=False because of nans. Ideally nans should have been imputed in feature engineering.
+         array = column.to_numpy(zero_copy_only=False)
+         batch_dict[column_schema.name] = array
+     return batch_dict
+
+
+ def _retryable_batches(
+     dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
+ ) -> Iterator[pa.RecordBatch]:
+     """Make the Dataset to_batches retryable."""
+     retries = 0
+     current_batch_index = 0
+
+     while True:
+         try:
+             for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
+                 if batch_index < current_batch_index:
+                     # Skip batches that have already been processed
+                     continue
+
+                 yield batch
+                 current_batch_index = batch_index + 1
+             # Exit the loop once all batches are processed
+             break
+
+         except Exception as e:
+             if retries < max_retries:
+                 retries += 1
+                 logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
+                 time.sleep(delay)
+             else:
+                 raise e
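Although this is an internal module (user code would normally reach it through the new snowflake.ml.data.data_connector.DataConnector), here is a rough sketch of how the ingestor defined above is driven; the parquet paths are assumptions for illustration:

from typing import Any, Dict, Iterator

import numpy.typing as npt
from snowflake import snowpark
from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor


def iter_training_batches(session: snowpark.Session) -> Iterator[Dict[str, npt.NDArray[Any]]]:
    # Plain string sources are handed straight to pyarrow.dataset (see _get_dataset above).
    ingestor = ArrowIngestor(session, ["data/part-0.parquet", "data/part-1.parquet"], format="parquet")
    # Each yielded item maps column names to numpy arrays of exactly 1024 rows;
    # leftover rows are dropped because drop_last_batch defaults to True.
    yield from ingestor.to_batches(batch_size=1024, shuffle=True)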
snowflake/ml/data/_internal/ingestor_utils.py ADDED
@@ -0,0 +1,58 @@
+ from typing import List, Optional
+
+ import fsspec
+
+ from snowflake import snowpark
+ from snowflake.connector import result_batch
+ from snowflake.ml.data import data_source
+ from snowflake.ml.fileset import snowfs
+
+ _TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
+
+
+ def get_dataframe_result_batches(
+     session: snowpark.Session, df_info: data_source.DataFrameInfo
+ ) -> List[result_batch.ResultBatch]:
+     cursor = session._conn._cursor
+
+     if df_info.query_id:
+         query_id = df_info.query_id
+     else:
+         query_id = session.sql(df_info.sql).collect_nowait().query_id
+
+     # TODO: Check if query result cache is still live
+     cursor.get_results_from_sfqid(sfqid=query_id)
+
+     # Prefetch hook should be set by `get_results_from_sfqid`
+     # This call blocks until the query results are ready
+     if cursor._prefetch_hook is None:
+         raise RuntimeError("Loading data from result query failed unexpectedly. Please contact Snowflake support.")
+     cursor._prefetch_hook()
+     batches = cursor.get_result_batches()
+     if batches is None:
+         raise ValueError(
+             "Failed to retrieve training data. Query status:" f" {session._conn._conn.get_query_status(query_id)}"
+         )
+     return batches
+
+
+ def get_dataset_filesystem(
+     session: snowpark.Session, ds_info: Optional[data_source.DatasetInfo] = None
+ ) -> fsspec.AbstractFileSystem:
+     # We can't directly load the Dataset to avoid a circular dependency
+     # Dataset -> DatasetReader -> DataConnector -> DataIngestor -> (?) ingestor_utils -> Dataset
+     # TODO: Automatically pick appropriate fsspec implementation based on protocol in URL
+     return snowfs.SnowFileSystem(
+         snowpark_session=session,
+         cache_type="bytes",
+         block_size=2 * _TARGET_FILE_SIZE,
+     )
+
+
+ def get_dataset_files(
+     session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
+ ) -> List[str]:
+     if filesystem is None:
+         filesystem = get_dataset_filesystem(session, ds_info)
+     assert bool(ds_info.url)  # Not null or empty
+     return sorted(filesystem.ls(ds_info.url))