snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +67 -10
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +12 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
- snowflake/ml/data/_internal/ingestor_utils.py +58 -0
- snowflake/ml/data/data_connector.py +133 -0
- snowflake/ml/data/data_ingestor.py +28 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_reader.py +18 -118
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
- snowflake/ml/feature_store/examples/example_helper.py +240 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
- snowflake/ml/feature_store/feature_store.py +579 -53
- snowflake/ml/feature_store/feature_view.py +168 -5
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +11 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +24 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +11 -1
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +42 -0
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +68 -0
- snowflake/ml/model/_packager/model_handlers/xgboost.py +59 -0
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +4 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +4 -4
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +48 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +64 -42
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
snowflake/cortex/__init__.py
CHANGED
@@ -1,3 +1,4 @@
+from snowflake.cortex._classify_text import ClassifyText
 from snowflake.cortex._complete import Complete, CompleteOptions
 from snowflake.cortex._extract_answer import ExtractAnswer
 from snowflake.cortex._sentiment import Sentiment
@@ -5,6 +6,7 @@ from snowflake.cortex._summarize import Summarize
 from snowflake.cortex._translate import Translate
 
 __all__ = [
+    "ClassifyText",
     "Complete",
     "CompleteOptions",
     "ExtractAnswer",
snowflake/cortex/_classify_text.py
ADDED
@@ -0,0 +1,36 @@
+from typing import List, Optional, Union
+
+from snowflake import snowpark
+from snowflake.cortex._util import CORTEX_FUNCTIONS_TELEMETRY_PROJECT, call_sql_function
+from snowflake.ml._internal import telemetry
+
+
+@telemetry.send_api_usage_telemetry(
+    project=CORTEX_FUNCTIONS_TELEMETRY_PROJECT,
+)
+def ClassifyText(
+    str_input: Union[str, snowpark.Column],
+    categories: Union[List[str], snowpark.Column],
+    session: Optional[snowpark.Session] = None,
+) -> Union[str, snowpark.Column]:
+    """Use the LLM inference service to classify the INPUT text into one of the target CATEGORIES.
+
+    Args:
+        str_input: A Column of strings to classify.
+        categories: A list of candidate categories to classify the INPUT text into.
+        session: The snowpark session to use. Will be inferred by context if not specified.
+
+    Returns:
+        A column of classification responses.
+    """
+
+    return _classify_text_impl("snowflake.cortex.classify_text", str_input, categories, session=session)
+
+
+def _classify_text_impl(
+    function: str,
+    str_input: Union[str, snowpark.Column],
+    categories: Union[List[str], snowpark.Column],
+    session: Optional[snowpark.Session] = None,
+) -> Union[str, snowpark.Column]:
+    return call_sql_function(function, session, str_input, categories)
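The new ClassifyText export follows the same pattern as the other Cortex wrappers: called with a plain string it runs eagerly, called with a Column it builds a lazy expression. A minimal usage sketch, assuming an active Snowpark session in context; reviews_df and the column name REVIEW_TEXT are illustrative, not part of the package:

from snowflake.cortex import ClassifyText

# Eager: classify one string (session inferred from context).
response = ClassifyText("The pitch was low and outside", ["baseball", "football"])

# Lazy: build a Column expression over a hypothetical DataFrame column.
labeled_df = reviews_df.with_column(
    "CATEGORY", ClassifyText(reviews_df["REVIEW_TEXT"], ["positive", "negative"])
)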
snowflake/cortex/_complete.py
CHANGED
@@ -1,6 +1,7 @@
 import json
 import logging
-
+import time
+from typing import Any, Callable, Iterator, List, Optional, TypedDict, Union, cast
 from urllib.parse import urlunparse
 
 import requests
@@ -52,6 +53,38 @@ class ResponseParseException(Exception):
     pass
 
 
+_MAX_RETRY_SECONDS = 30
+
+
+def retry(func: Callable[..., requests.Response]) -> Callable[..., requests.Response]:
+    def inner(*args: Any, **kwargs: Any) -> requests.Response:
+        deadline = cast(Optional[float], kwargs["deadline"])
+        kwargs = {key: value for key, value in kwargs.items() if key != "deadline"}
+        expRetrySeconds = 0.5
+        while True:
+            if deadline is not None and time.time() > deadline:
+                raise TimeoutError()
+            response = func(*args, **kwargs)
+            if response.status_code >= 200 and response.status_code < 300:
+                return response
+            retry_status_codes = [429, 503, 504]
+            if response.status_code not in retry_status_codes:
+                response.raise_for_status()
+            logger.debug(f"request failed with status code {response.status_code}, retrying")
+
+            # Formula: delay(i) = max(RetryAfterHeader, min(2^i, _MAX_RETRY_SECONDS)).
+            expRetrySeconds = min(2 * expRetrySeconds, _MAX_RETRY_SECONDS)
+            retrySeconds = expRetrySeconds
+            retryAfterHeader = response.headers.get("retry-after")
+            if retryAfterHeader is not None:
+                retrySeconds = max(retrySeconds, int(retryAfterHeader))
+            logger.debug(f"sleeping for {retrySeconds}s before retrying")
+            time.sleep(retrySeconds)
+
+    return inner
+
+
+@retry
 def _call_complete_rest(
     model: str,
     prompt: Union[str, List[ConversationMessage]],
@@ -78,7 +111,7 @@ def _call_complete_rest(
     scheme = "https"
     if hasattr(session.connection, "scheme"):
         scheme = session.connection.scheme
-    url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference
+    url = urlunparse((scheme, session.connection.host, "api/v2/cortex/inference:complete", "", "", ""))
 
     headers = {
         "Content-Type": "application/json",
@@ -105,19 +138,21 @@ def _call_complete_rest(
         data["top_p"] = options["top_p"]
 
     logger.debug(f"making POST request to {url} (model={model}, stream={stream})")
-
+    return requests.post(
         url,
         json=data,
         headers=headers,
         stream=stream,
     )
-    response.raise_for_status()
-    return response
 
 
-def _process_rest_response(
+def _process_rest_response(
+    response: requests.Response,
+    stream: bool = False,
+    deadline: Optional[float] = None,
+) -> Union[str, Iterator[str]]:
     if stream:
-        return _return_stream_response(response)
+        return _return_stream_response(response, deadline)
 
     try:
         content = response.json()["choices"][0]["message"]["content"]
@@ -128,9 +163,11 @@ def _process_rest_response(response: requests.Response, stream: bool = False) ->
         raise ResponseParseException("Failed to parse message from response.") from e
 
 
-def _return_stream_response(response: requests.Response) -> Iterator[str]:
+def _return_stream_response(response: requests.Response, deadline: Optional[float]) -> Iterator[str]:
     client = SSEClient(response)
     for event in client.events():
+        if deadline is not None and time.time() > deadline:
+            raise TimeoutError()
         try:
             yield json.loads(event.data)["choices"][0]["delta"]["content"]
         except (json.JSONDecodeError, KeyError, IndexError):
@@ -209,13 +246,20 @@ def _complete_impl(
     use_rest_api_experimental: bool = False,
     stream: bool = False,
     function: str = "snowflake.cortex.complete",
+    timeout: Optional[float] = None,
+    deadline: Optional[float] = None,
 ) -> Union[str, Iterator[str], snowpark.Column]:
+    if timeout is not None and deadline is not None:
+        raise ValueError('only one of "timeout" and "deadline" must be set')
+    if timeout is not None:
+        deadline = time.time() + timeout
     if use_rest_api_experimental:
        if not isinstance(model, str):
            raise ValueError("in REST mode, 'model' must be a string")
        if not isinstance(prompt, str) and not isinstance(prompt, List):
            raise ValueError("in REST mode, 'prompt' must be a string or a list of ConversationMessage")
-        response = _call_complete_rest(model, prompt, options, session=session, stream=stream)
+        response = _call_complete_rest(model, prompt, options, session=session, stream=stream, deadline=deadline)
+        assert response.status_code >= 200 and response.status_code < 300
        return _process_rest_response(response, stream=stream)
    if stream is True:
        raise ValueError("streaming can only be enabled in REST mode, set use_rest_api_experimental=True")
@@ -233,6 +277,8 @@ def Complete(
     session: Optional[snowpark.Session] = None,
     use_rest_api_experimental: bool = False,
     stream: bool = False,
+    timeout: Optional[float] = None,
+    deadline: Optional[float] = None,
 ) -> Union[str, Iterator[str], snowpark.Column]:
     """Complete calls into the LLM inference service to perform completion.
 
@@ -246,6 +292,8 @@ def Complete(
         stream (bool): Enables streaming. When enabled, a generator function is returned that provides the streaming
             output as it is received. Each update is a string containing the new text content since the previous update.
            The use of streaming requires the experimental use_rest_api_experimental flag to be enabled.
+        timeout (float): Timeout in seconds to retry failed REST requests.
+        deadline (float): Time in seconds since the epoch (as returned by time.time()) to retry failed REST requests.
 
     Raises:
         ValueError: If `stream` is set to True and `use_rest_api_experimental` is set to False.
@@ -254,6 +302,15 @@ def Complete(
         A column of string responses.
     """
     try:
-        return _complete_impl(
+        return _complete_impl(
+            model,
+            prompt,
+            options=options,
+            session=session,
+            use_rest_api_experimental=use_rest_api_experimental,
+            stream=stream,
+            timeout=timeout,
+            deadline=deadline,
+        )
     except ValueError as err:
         raise err
snowflake/cortex/_util.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Union, cast
+from typing import Dict, List, Optional, Union, cast
 
 from snowflake import snowpark
 from snowflake.snowpark import context, functions
@@ -23,7 +23,7 @@ class SnowflakeConfigurationException(Exception):
 def call_sql_function(
     function: str,
     session: Optional[snowpark.Session],
-    *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
+    *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
 ) -> Union[str, snowpark.Column]:
     handle_as_column = False
 
@@ -40,7 +40,7 @@ def call_sql_function(
 
 
 def _call_sql_function_column(
-    function: str, *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]]
+    function: str, *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]]
 ) -> snowpark.Column:
     return cast(snowpark.Column, functions.builtin(function)(*args))
 
@@ -48,7 +48,7 @@ def _call_sql_function_column(
 def _call_sql_function_immediate(
     function: str,
     session: Optional[snowpark.Session],
-    *args: Union[str, snowpark.Column, Dict[str, Union[int, float]]],
+    *args: Union[str, List[str], snowpark.Column, Dict[str, Union[int, float]]],
 ) -> str:
     session = session or context.get_active_session()
     if session is None:
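The widened *args annotation is what lets ClassifyText forward its candidate categories as a plain Python list alongside strings and Columns. A rough sketch of the same call _classify_text_impl makes, assuming an active Snowpark session in context (this is an internal helper, shown only to illustrate the new list argument):

from snowflake.cortex._util import call_sql_function

response = call_sql_function(
    "snowflake.cortex.classify_text",
    None,  # session resolved from context
    "free text to classify",
    ["complaint", "praise"],  # List[str] now accepted by the signature
)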
snowflake/ml/_internal/lineage/lineage_utils.py
CHANGED
@@ -1,9 +1,9 @@
 import copy
 import functools
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, get_args
 
 from snowflake import snowpark
-from snowflake.ml.
+from snowflake.ml.data import data_source
 
 _DATA_SOURCES_ATTR = "_data_sources"
 
@@ -39,7 +39,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
     result: Optional[List[data_source.DataSource]] = None
     for arg in args:
         srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
-        if isinstance(srcs, list) and all(isinstance(s, data_source.DataSource) for s in srcs):
+        if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
             if result is None:
                 result = []
             result += srcs
@@ -49,7 +49,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
 def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
     """Helper method for attaching data sources to an object"""
     if data_sources:
-        assert all(isinstance(ds, data_source.DataSource) for ds in data_sources)
+        assert all(isinstance(ds, get_args(data_source.DataSource)) for ds in data_sources)
         setattr(obj, _DATA_SOURCES_ATTR, data_sources)
 
 
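The switch to isinstance(s, get_args(data_source.DataSource)) suggests DataSource is now a typing alias (a Union over the new data source types) rather than a concrete class: isinstance cannot take a Union, but it does accept the tuple that get_args returns. A generic, self-contained illustration of the pattern; the names below are standalone stand-ins, not the package's actual definitions:

from typing import Union, get_args

class DataFrameInfo: ...
class DatasetInfo: ...

# Hypothetical alias mirroring the package's DataSource union.
DataSource = Union[DataFrameInfo, DatasetInfo, str]

sources = [DatasetInfo(), "stage/path/part-0.parquet"]
# get_args unpacks the union into (DataFrameInfo, DatasetInfo, str), which isinstance accepts.
assert all(isinstance(s, get_args(DataSource)) for s in sources)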
snowflake/ml/_internal/telemetry.py
CHANGED
@@ -277,6 +277,7 @@ def send_api_usage_telemetry(
         ]
     ] = None,
     sfqids_extractor: Optional[Callable[..., List[str]]] = None,
+    subproject_extractor: Optional[Callable[[Any], str]] = None,
     custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
     """
@@ -290,6 +291,7 @@ def send_api_usage_telemetry(
         conn_attr_name: Name of the SnowflakeConnection attribute in `self`.
         api_calls_extractor: Extract API calls from `self`.
         sfqids_extractor: Extract sfqids from `self`.
+        subproject_extractor: Extract subproject at runtime from `self`.
         custom_tags: Custom tags.
 
     Returns:
@@ -297,10 +299,14 @@ def send_api_usage_telemetry(
 
     Raises:
         TypeError: If `conn_attr_name` is provided but the conn attribute is not of type SnowflakeConnection.
+        ValueError: If both `subproject` and `subproject_extractor` are provided
 
     # noqa: DAR402
     """
+    if subproject is not None and subproject_extractor is not None:
+        raise ValueError("Specifying both subproject and subproject_extractor is not allowed")
+
     def decorator(func: Callable[_Args, _ReturnValue]) -> Callable[_Args, _ReturnValue]:
         @functools.wraps(func)
         def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
@@ -322,9 +328,13 @@ def send_api_usage_telemetry(
             if sfqids_extractor:
                 sfqids = sfqids_extractor(args[0])
 
+            subproject_name = subproject
+            if subproject_extractor is not None:
+                subproject_name = subproject_extractor(args[0])
+
             statement_params = get_function_usage_statement_params(
                 project=project,
-                subproject=
+                subproject=subproject_name,
                 function_category=TelemetryField.FUNC_CAT_USAGE.value,
                 function_name=_get_full_func_name(func),
                 function_parameters=params,
@@ -381,7 +391,7 @@ def send_api_usage_telemetry(
                     raise e.original_exception from e
 
             # TODO(hayu): [SNOW-750287] Optimize telemetry client to a singleton.
-            telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=
+            telemetry = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject_name)
             telemetry_args = dict(
                 func_name=_get_full_func_name(func),
                 function_category=TelemetryField.FUNC_CAT_USAGE.value,
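The new subproject_extractor lets the telemetry subproject be computed from the decorated method's self at call time instead of being fixed when the decorator is applied; supplying both subproject and subproject_extractor raises ValueError. A sketch of the intended decorator usage; the class, attribute, and project names here are illustrative, not from the package:

from snowflake.ml._internal import telemetry

class ModelOperator:
    def __init__(self, subproject: str) -> None:
        self._subproject = subproject

    @telemetry.send_api_usage_telemetry(
        project="MLOps",
        # Called with `self` (args[0]) when `deploy` runs, so each instance can report its own subproject.
        subproject_extractor=lambda self: self._subproject,
    )
    def deploy(self) -> None:
        ...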
snowflake/ml/data/_internal/arrow_ingestor.py
ADDED
@@ -0,0 +1,228 @@
+import collections
+import logging
+import os
+import time
+from typing import Any, Deque, Dict, Iterator, List, Optional
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import pyarrow as pa
+import pyarrow.dataset as ds
+
+from snowflake import snowpark
+from snowflake.ml.data import data_ingestor, data_source
+from snowflake.ml.data._internal import ingestor_utils
+
+_EMPTY_RECORD_BATCH = pa.RecordBatch.from_arrays([], [])
+
+# The row count for batches read from PyArrow Dataset. This number should be large enough so that
+# dataset.to_batches() would read in a very large portion of, if not entirely, a parquet file.
+_DEFAULT_DATASET_BATCH_SIZE = 1000000
+
+
+class _RecordBatchesBuffer:
+    """A queue that stores record batches and tracks the total num of rows in it."""
+
+    def __init__(self) -> None:
+        self.buffer: Deque[pa.RecordBatch] = collections.deque()
+        self.num_rows = 0
+
+    def append(self, rb: pa.RecordBatch) -> None:
+        self.buffer.append(rb)
+        self.num_rows += rb.num_rows
+
+    def appendleft(self, rb: pa.RecordBatch) -> None:
+        self.buffer.appendleft(rb)
+        self.num_rows += rb.num_rows
+
+    def popleft(self) -> pa.RecordBatch:
+        popped = self.buffer.popleft()
+        self.num_rows -= popped.num_rows
+        return popped
+
+
+class ArrowIngestor(data_ingestor.DataIngestor):
+    """Read and parse the data sources into an Arrow Dataset and yield batched numpy array in dict."""
+
+    def __init__(
+        self,
+        session: snowpark.Session,
+        data_sources: List[data_source.DataSource],
+        format: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Args:
+            session: The Snowpark Session to use.
+            data_sources: List of data sources to ingest.
+            format: Currently "parquet", "ipc"/"arrow"/"feather", "csv", "json", and "orc" are supported.
+                Will be inferred if not specified.
+            kwargs: Miscellaneous arguments passed to underlying PyArrow Dataset initializer.
+        """
+        self._session = session
+        self._data_sources = data_sources
+        self._format = format
+        self._kwargs = kwargs
+
+        self._schema: Optional[pa.Schema] = None
+
+    @property
+    def data_sources(self) -> List[data_source.DataSource]:
+        return self._data_sources
+
+    def to_batches(
+        self,
+        batch_size: int,
+        shuffle: bool = True,
+        drop_last_batch: bool = True,
+    ) -> Iterator[Dict[str, npt.NDArray[Any]]]:
+        """Iterate through PyArrow Dataset to generate batches whose length equals to expected batch size.
+
+        As we are generating batches with the exactly same length, the last few rows in each file might get left as they
+        are not long enough to form a batch. These rows will be put into a temporary buffer and combine with the first
+        few rows of the next file to generate a new batch.
+
+        Args:
+            batch_size: Specifies the size of each batch that will be yield
+            shuffle: Whether the data in the file will be shuffled. If set to be true, it will first randomly shuffle
+                the order of files, and then shuflle the order of rows in each file.
+            drop_last_batch: Whether the last batch of data should be dropped. If set to be true, then the last
+                batch will get dropped if its size is smaller than the given batch_size.
+
+        Yields:
+            A dict mapping column names to the corresponding data fetch from that column.
+        """
+        self._rb_buffer = _RecordBatchesBuffer()
+
+        # Extract schema if not already known
+        dataset = self._get_dataset(shuffle)
+        if self._schema is None:
+            self._schema = dataset.schema
+
+        for rb in _retryable_batches(dataset, batch_size=max(_DEFAULT_DATASET_BATCH_SIZE, batch_size)):
+            if shuffle:
+                rb = rb.take(np.random.permutation(rb.num_rows))
+            self._rb_buffer.append(rb)
+            while self._rb_buffer.num_rows >= batch_size:
+                yield self._get_batches_from_buffer(batch_size)
+
+        if self._rb_buffer.num_rows and not drop_last_batch:
+            yield self._get_batches_from_buffer(batch_size)
+
+    def to_pandas(self, limit: Optional[int] = None) -> pd.DataFrame:
+        ds = self._get_dataset(shuffle=False)
+        table = ds.to_table() if limit is None else ds.head(num_rows=limit)
+        return table.to_pandas()
+
+    def _get_dataset(self, shuffle: bool) -> ds.Dataset:
+        format = self._format
+        sources = []
+        source_format = None
+        for source in self._data_sources:
+            if isinstance(source, str):
+                sources.append(source)
+                source_format = format or os.path.splitext(source)[-1]
+            elif isinstance(source, data_source.DatasetInfo):
+                if not self._kwargs.get("filesystem"):
+                    self._kwargs["filesystem"] = ingestor_utils.get_dataset_filesystem(self._session, source)
+                sources.extend(
+                    ingestor_utils.get_dataset_files(self._session, source, filesystem=self._kwargs["filesystem"])
+                )
+                source_format = "parquet"
+            elif isinstance(source, data_source.DataFrameInfo):
+                # FIXME: This currently loads all result batches into memory so that it
+                # can be passed into pyarrow.dataset as a list/tuple of pa.RecordBatches
+                # We may be able to optimize this by splitting the result batches into
+                # in-memory (first batch) and file URLs (subsequent batches) and creating a
+                # union dataset.
+                result_batches = ingestor_utils.get_dataframe_result_batches(self._session, source)
+                sources.extend(b.to_arrow() for b in result_batches)
+                source_format = "arrow"
+            else:
+                raise RuntimeError(f"Unsupported data source type: {type(source)}")
+
+            # Make sure source types not mixed
+            if format and format != source_format:
+                raise RuntimeError(f"Unexpected data source format (expected {format}, found {source_format})")
+            format = source_format
+
+        # Re-shuffle input files on each iteration start
+        if shuffle:
+            np.random.shuffle(sources)
+        pa_dataset: ds.Dataset = ds.dataset(sources, format=format, **self._kwargs)
+        return pa_dataset
+
+    def _get_batches_from_buffer(self, batch_size: int) -> Dict[str, npt.NDArray[Any]]:
+        """Generate new batches from the existing record batch buffer."""
+        cnt_rbs_num_rows = 0
+        candidates = []
+
+        # Keep popping record batches in buffer until there are enough rows for a batch.
+        while self._rb_buffer.num_rows and cnt_rbs_num_rows < batch_size:
+            candidate = self._rb_buffer.popleft()
+            cnt_rbs_num_rows += candidate.num_rows
+            candidates.append(candidate)
+
+        # When there are more rows than needed, slice the last popped batch to fit batch_size.
+        if cnt_rbs_num_rows > batch_size:
+            row_diff = cnt_rbs_num_rows - batch_size
+            slice_target = candidates[-1]
+            cut_off = slice_target.num_rows - row_diff
+            to_merge = slice_target.slice(length=cut_off)
+            left_over = slice_target.slice(offset=cut_off)
+            candidates[-1] = to_merge
+            self._rb_buffer.appendleft(left_over)
+
+        res = _merge_record_batches(candidates)
+        return _record_batch_to_arrays(res)
+
+
+def _merge_record_batches(record_batches: List[pa.RecordBatch]) -> pa.RecordBatch:
+    """Merge a list of arrow RecordBatches into one. Similar to MergeTables."""
+    if not record_batches:
+        return _EMPTY_RECORD_BATCH
+    if len(record_batches) == 1:
+        return record_batches[0]
+    record_batches = list(filter(lambda rb: rb.num_rows > 0, record_batches))
+    one_chunk_table = pa.Table.from_batches(record_batches).combine_chunks()
+    batches = one_chunk_table.to_batches(max_chunksize=None)
+    return batches[0]
+
+
+def _record_batch_to_arrays(rb: pa.RecordBatch) -> Dict[str, npt.NDArray[Any]]:
+    """Transform the record batch to a (string, numpy array) dict."""
+    batch_dict = {}
+    for column, column_schema in zip(rb, rb.schema):
+        # zero_copy_only=False because of nans. Ideally nans should have been imputed in feature engineering.
+        array = column.to_numpy(zero_copy_only=False)
+        batch_dict[column_schema.name] = array
+    return batch_dict
+
+
+def _retryable_batches(
+    dataset: ds.Dataset, batch_size: int, max_retries: int = 3, delay: int = 0
+) -> Iterator[pa.RecordBatch]:
+    """Make the Dataset to_batches retryable."""
+    retries = 0
+    current_batch_index = 0
+
+    while True:
+        try:
+            for batch_index, batch in enumerate(dataset.to_batches(batch_size=batch_size)):
+                if batch_index < current_batch_index:
+                    # Skip batches that have already been processed
+                    continue
+
+                yield batch
+                current_batch_index = batch_index + 1
+            # Exit the loop once all batches are processed
+            break
+
+        except Exception as e:
+            if retries < max_retries:
+                retries += 1
+                logging.info(f"Error encountered: {e}. Retrying {retries}/{max_retries}...")
+                time.sleep(delay)
+            else:
+                raise e
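ArrowIngestor unifies file paths, Dataset files, and Snowpark DataFrame result batches behind a single PyArrow Dataset, and to_batches yields fixed-size numpy batches, carrying leftover rows over between files. A hedged sketch of direct usage, assuming a Snowpark session and local parquet files; in practice this class is normally driven through the new DataConnector, and the column names below are illustrative:

from snowflake.ml.data._internal.arrow_ingestor import ArrowIngestor

ingestor = ArrowIngestor(session, ["/tmp/train-0.parquet", "/tmp/train-1.parquet"], format="parquet")

for batch in ingestor.to_batches(batch_size=1024, shuffle=True, drop_last_batch=True):
    # Each batch is a dict of column name -> numpy array, each of length 1024.
    features, labels = batch["FEATURE_COL"], batch["LABEL_COL"]

preview_df = ingestor.to_pandas(limit=10)  # small pandas preview of the same sources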
snowflake/ml/data/_internal/ingestor_utils.py
ADDED
@@ -0,0 +1,58 @@
+from typing import List, Optional
+
+import fsspec
+
+from snowflake import snowpark
+from snowflake.connector import result_batch
+from snowflake.ml.data import data_source
+from snowflake.ml.fileset import snowfs
+
+_TARGET_FILE_SIZE = 32 * 2**20  # The max file size for data loading.
+
+
+def get_dataframe_result_batches(
+    session: snowpark.Session, df_info: data_source.DataFrameInfo
+) -> List[result_batch.ResultBatch]:
+    cursor = session._conn._cursor
+
+    if df_info.query_id:
+        query_id = df_info.query_id
+    else:
+        query_id = session.sql(df_info.sql).collect_nowait().query_id
+
+    # TODO: Check if query result cache is still live
+    cursor.get_results_from_sfqid(sfqid=query_id)
+
+    # Prefetch hook should be set by `get_results_from_sfqid`
+    # This call blocks until the query results are ready
+    if cursor._prefetch_hook is None:
+        raise RuntimeError("Loading data from result query failed unexpectedly. Please contact Snowflake support.")
+    cursor._prefetch_hook()
+    batches = cursor.get_result_batches()
+    if batches is None:
+        raise ValueError(
+            "Failed to retrieve training data. Query status:" f" {session._conn._conn.get_query_status(query_id)}"
+        )
+    return batches
+
+
+def get_dataset_filesystem(
+    session: snowpark.Session, ds_info: Optional[data_source.DatasetInfo] = None
+) -> fsspec.AbstractFileSystem:
+    # We can't directly load the Dataset to avoid a circular dependency
+    # Dataset -> DatasetReader -> DataConnector -> DataIngestor -> (?) ingestor_utils -> Dataset
+    # TODO: Automatically pick appropriate fsspec implementation based on protocol in URL
+    return snowfs.SnowFileSystem(
+        snowpark_session=session,
+        cache_type="bytes",
+        block_size=2 * _TARGET_FILE_SIZE,
+    )
+
+
+def get_dataset_files(
+    session: snowpark.Session, ds_info: data_source.DatasetInfo, filesystem: Optional[fsspec.AbstractFileSystem] = None
+) -> List[str]:
+    if filesystem is None:
+        filesystem = get_dataset_filesystem(session, ds_info)
+    assert bool(ds_info.url)  # Not null or empty
+    return sorted(filesystem.ls(ds_info.url))
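These helpers back the DatasetInfo and DataFrameInfo branches of ArrowIngestor._get_dataset: dataset sources resolve to staged file paths via a SnowFileSystem, while DataFrame sources resolve to connector ResultBatch objects fetched by query id. A rough sketch of the dataset path, assuming a session and a DatasetInfo (dataset_info below is a hypothetical instance pointing at a dataset version URL):

from snowflake.ml.data._internal import ingestor_utils

fs = ingestor_utils.get_dataset_filesystem(session)
files = ingestor_utils.get_dataset_files(session, dataset_info, filesystem=fs)

# SnowFileSystem is an fsspec filesystem, so standard open/ls semantics apply.
with fs.open(files[0], "rb") as f:
    head = f.read(1024)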