snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -8,25 +8,13 @@ import sys
|
|
8
8
|
import time
|
9
9
|
import traceback
|
10
10
|
import types
|
11
|
-
from typing import
|
12
|
-
Any,
|
13
|
-
Callable,
|
14
|
-
Dict,
|
15
|
-
Iterable,
|
16
|
-
List,
|
17
|
-
Mapping,
|
18
|
-
Optional,
|
19
|
-
Set,
|
20
|
-
Tuple,
|
21
|
-
TypeVar,
|
22
|
-
Union,
|
23
|
-
cast,
|
24
|
-
)
|
11
|
+
from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union, cast
|
25
12
|
|
26
13
|
from typing_extensions import ParamSpec
|
27
14
|
|
28
15
|
from snowflake import connector
|
29
16
|
from snowflake.connector import telemetry as connector_telemetry, time_util
|
17
|
+
from snowflake.ml import version as snowml_version
|
30
18
|
from snowflake.ml._internal import env
|
31
19
|
from snowflake.ml._internal.exceptions import (
|
32
20
|
error_codes,
|
@@ -99,13 +87,13 @@ class _TelemetrySourceType(enum.Enum):
|
|
99
87
|
AUGMENT_TELEMETRY = "SNOWML_AUGMENT_TELEMETRY"
|
100
88
|
|
101
89
|
|
102
|
-
_statement_params_context_var: contextvars.ContextVar[
|
90
|
+
_statement_params_context_var: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("statement_params")
|
103
91
|
|
104
92
|
|
105
93
|
class _StatementParamsPatchManager:
|
106
94
|
def __init__(self) -> None:
|
107
|
-
self._patch_cache:
|
108
|
-
self._context_var: contextvars.ContextVar[
|
95
|
+
self._patch_cache: set[server_connection.ServerConnection] = set()
|
96
|
+
self._context_var: contextvars.ContextVar[dict[str, str]] = _statement_params_context_var
|
109
97
|
|
110
98
|
def apply_patches(self) -> None:
|
111
99
|
try:
|
@@ -117,7 +105,7 @@ class _StatementParamsPatchManager:
|
|
117
105
|
except snowpark_exceptions.SnowparkSessionException:
|
118
106
|
pass
|
119
107
|
|
120
|
-
def set_statement_params(self, statement_params:
|
108
|
+
def set_statement_params(self, statement_params: dict[str, str]) -> None:
|
121
109
|
# Only set value if not already set in context
|
122
110
|
if not self._context_var.get({}):
|
123
111
|
self._context_var.set(statement_params)
|
@@ -152,7 +140,6 @@ class _StatementParamsPatchManager:
|
|
152
140
|
if throw_on_patch_fail: # primarily used for testing
|
153
141
|
raise
|
154
142
|
# TODO: Log a warning, this probably means there was a breaking change in Snowpark/SnowflakeConnection
|
155
|
-
pass
|
156
143
|
|
157
144
|
def _patch_with_statement_params(
|
158
145
|
self, target: object, function_name: str, param_name: str = "statement_params"
|
@@ -197,10 +184,10 @@ class _StatementParamsPatchManager:
|
|
197
184
|
|
198
185
|
setattr(target, function_name, wrapper)
|
199
186
|
|
200
|
-
def __getstate__(self) ->
|
187
|
+
def __getstate__(self) -> dict[str, Any]:
|
201
188
|
return {}
|
202
189
|
|
203
|
-
def __setstate__(self, state:
|
190
|
+
def __setstate__(self, state: dict[str, Any]) -> None:
|
204
191
|
# unpickling does not call __init__ by default, do it manually here
|
205
192
|
self.__init__() # type: ignore[misc]
|
206
193
|
|
@@ -210,7 +197,7 @@ _patch_manager = _StatementParamsPatchManager()
|
|
210
197
|
|
211
198
|
def get_statement_params(
|
212
199
|
project: str, subproject: Optional[str] = None, class_name: Optional[str] = None
|
213
|
-
) ->
|
200
|
+
) -> dict[str, Any]:
|
214
201
|
"""
|
215
202
|
Get telemetry statement parameters.
|
216
203
|
|
@@ -231,8 +218,8 @@ def get_statement_params(
|
|
231
218
|
|
232
219
|
|
233
220
|
def add_statement_params_custom_tags(
|
234
|
-
statement_params: Optional[
|
235
|
-
) ->
|
221
|
+
statement_params: Optional[dict[str, Any]], custom_tags: Mapping[str, Any]
|
222
|
+
) -> dict[str, Any]:
|
236
223
|
"""
|
237
224
|
Add custom_tags to existing statement_params. Overwrite keys in custom_tags dict that already exist.
|
238
225
|
If existing statement_params are not provided, do nothing as the information cannot be effectively tracked.
|
@@ -246,7 +233,7 @@ def add_statement_params_custom_tags(
|
|
246
233
|
"""
|
247
234
|
if not statement_params:
|
248
235
|
return {}
|
249
|
-
existing_custom_tags:
|
236
|
+
existing_custom_tags: dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
|
250
237
|
existing_custom_tags.update(custom_tags)
|
251
238
|
# NOTE: This can be done with | operator after upgrade from py3.8
|
252
239
|
return {
|
@@ -289,17 +276,17 @@ def get_function_usage_statement_params(
|
|
289
276
|
*,
|
290
277
|
function_category: str = TelemetryField.FUNC_CAT_USAGE.value,
|
291
278
|
function_name: Optional[str] = None,
|
292
|
-
function_parameters: Optional[
|
279
|
+
function_parameters: Optional[dict[str, Any]] = None,
|
293
280
|
api_calls: Optional[
|
294
|
-
|
281
|
+
list[
|
295
282
|
Union[
|
296
|
-
|
283
|
+
dict[str, Union[Callable[..., Any], str]],
|
297
284
|
Union[Callable[..., Any], str],
|
298
285
|
]
|
299
286
|
]
|
300
287
|
] = None,
|
301
|
-
custom_tags: Optional[
|
302
|
-
) ->
|
288
|
+
custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
|
289
|
+
) -> dict[str, Any]:
|
303
290
|
"""
|
304
291
|
Get function usage statement parameters.
|
305
292
|
|
@@ -321,12 +308,12 @@ def get_function_usage_statement_params(
|
|
321
308
|
>>> df.collect(statement_params=statement_params)
|
322
309
|
"""
|
323
310
|
telemetry_type = f"{env.SOURCE.lower()}_{TelemetryField.TYPE_FUNCTION_USAGE.value}"
|
324
|
-
statement_params:
|
311
|
+
statement_params: dict[str, Any] = {
|
325
312
|
connector_telemetry.TelemetryField.KEY_SOURCE.value: env.SOURCE,
|
326
313
|
TelemetryField.KEY_PROJECT.value: project,
|
327
314
|
TelemetryField.KEY_SUBPROJECT.value: subproject,
|
328
315
|
TelemetryField.KEY_OS.value: env.OS,
|
329
|
-
TelemetryField.KEY_VERSION.value:
|
316
|
+
TelemetryField.KEY_VERSION.value: snowml_version.VERSION,
|
330
317
|
TelemetryField.KEY_PYTHON_VERSION.value: env.PYTHON_VERSION,
|
331
318
|
connector_telemetry.TelemetryField.KEY_TYPE.value: telemetry_type,
|
332
319
|
TelemetryField.KEY_CATEGORY.value: function_category,
|
@@ -339,7 +326,7 @@ def get_function_usage_statement_params(
|
|
339
326
|
if api_calls:
|
340
327
|
statement_params[TelemetryField.KEY_API_CALLS.value] = []
|
341
328
|
for api_call in api_calls:
|
342
|
-
if isinstance(api_call,
|
329
|
+
if isinstance(api_call, dict):
|
343
330
|
telemetry_api_call = api_call.copy()
|
344
331
|
# convert Callable to str
|
345
332
|
for field, api in api_call.items():
|
@@ -388,7 +375,7 @@ def send_custom_usage(
|
|
388
375
|
*,
|
389
376
|
telemetry_type: str,
|
390
377
|
subproject: Optional[str] = None,
|
391
|
-
data: Optional[
|
378
|
+
data: Optional[dict[str, Any]] = None,
|
392
379
|
**kwargs: Any,
|
393
380
|
) -> None:
|
394
381
|
active_session = next(iter(session._get_active_sessions()))
|
@@ -409,17 +396,17 @@ def send_api_usage_telemetry(
|
|
409
396
|
api_calls_extractor: Optional[
|
410
397
|
Callable[
|
411
398
|
...,
|
412
|
-
|
399
|
+
list[
|
413
400
|
Union[
|
414
|
-
|
401
|
+
dict[str, Union[Callable[..., Any], str]],
|
415
402
|
Union[Callable[..., Any], str],
|
416
403
|
]
|
417
404
|
],
|
418
405
|
]
|
419
406
|
] = None,
|
420
|
-
sfqids_extractor: Optional[Callable[...,
|
407
|
+
sfqids_extractor: Optional[Callable[..., list[str]]] = None,
|
421
408
|
subproject_extractor: Optional[Callable[[Any], str]] = None,
|
422
|
-
custom_tags: Optional[
|
409
|
+
custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
|
423
410
|
) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
|
424
411
|
"""
|
425
412
|
Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
|
@@ -454,7 +441,7 @@ def send_api_usage_telemetry(
|
|
454
441
|
def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
|
455
442
|
params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
|
456
443
|
|
457
|
-
api_calls:
|
444
|
+
api_calls: list[Union[dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
|
458
445
|
if api_calls_extractor:
|
459
446
|
extracted_api_calls = api_calls_extractor(args[0])
|
460
447
|
for api_call in extracted_api_calls:
|
@@ -484,7 +471,7 @@ def send_api_usage_telemetry(
|
|
484
471
|
custom_tags=custom_tags,
|
485
472
|
)
|
486
473
|
|
487
|
-
def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params:
|
474
|
+
def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: dict[str, Any]) -> _ReturnValue:
|
488
475
|
"""
|
489
476
|
Update SnowML function usage statement parameters to the object if it is a Snowpark DataFrame.
|
490
477
|
Used to track APIs returning a Snowpark DataFrame.
|
@@ -614,7 +601,7 @@ def _get_full_func_name(func: Callable[..., Any]) -> str:
|
|
614
601
|
|
615
602
|
def _get_func_params(
|
616
603
|
func: Callable[..., Any], func_params_to_log: Optional[Iterable[str]], args: Any, kwargs: Any
|
617
|
-
) ->
|
604
|
+
) -> dict[str, Any]:
|
618
605
|
"""
|
619
606
|
Get function parameters.
|
620
607
|
|
@@ -639,7 +626,7 @@ def _get_func_params(
|
|
639
626
|
return params
|
640
627
|
|
641
628
|
|
642
|
-
def _extract_arg_value(field: str, func_spec: inspect.FullArgSpec, args: Any, kwargs: Any) ->
|
629
|
+
def _extract_arg_value(field: str, func_spec: inspect.FullArgSpec, args: Any, kwargs: Any) -> tuple[bool, Any]:
|
643
630
|
"""
|
644
631
|
Function to extract a specified argument value.
|
645
632
|
|
@@ -702,11 +689,11 @@ class _SourceTelemetryClient:
|
|
702
689
|
self.source: str = env.SOURCE
|
703
690
|
self.project: Optional[str] = project
|
704
691
|
self.subproject: Optional[str] = subproject
|
705
|
-
self.version =
|
692
|
+
self.version = snowml_version.VERSION
|
706
693
|
self.python_version: str = env.PYTHON_VERSION
|
707
694
|
self.os: str = env.OS
|
708
695
|
|
709
|
-
def _send(self, msg:
|
696
|
+
def _send(self, msg: dict[str, Any], timestamp: Optional[int] = None) -> None:
|
710
697
|
"""
|
711
698
|
Add telemetry data to a batch in connector client.
|
712
699
|
|
@@ -720,7 +707,7 @@ class _SourceTelemetryClient:
|
|
720
707
|
telemetry_data = connector_telemetry.TelemetryData(message=msg, timestamp=timestamp)
|
721
708
|
self._telemetry.try_add_log_to_batch(telemetry_data)
|
722
709
|
|
723
|
-
def _create_basic_telemetry_data(self, telemetry_type: str) ->
|
710
|
+
def _create_basic_telemetry_data(self, telemetry_type: str) -> dict[str, Any]:
|
724
711
|
message = {
|
725
712
|
connector_telemetry.TelemetryField.KEY_SOURCE.value: self.source,
|
726
713
|
TelemetryField.KEY_PROJECT.value: self.project,
|
@@ -738,10 +725,10 @@ class _SourceTelemetryClient:
|
|
738
725
|
func_name: str,
|
739
726
|
function_category: str,
|
740
727
|
duration: float,
|
741
|
-
func_params: Optional[
|
742
|
-
api_calls: Optional[
|
743
|
-
sfqids: Optional[
|
744
|
-
custom_tags: Optional[
|
728
|
+
func_params: Optional[dict[str, Any]] = None,
|
729
|
+
api_calls: Optional[list[dict[str, Any]]] = None,
|
730
|
+
sfqids: Optional[list[Any]] = None,
|
731
|
+
custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
|
745
732
|
error: Optional[str] = None,
|
746
733
|
error_code: Optional[str] = None,
|
747
734
|
stack_trace: Optional[str] = None,
|
@@ -761,7 +748,7 @@ class _SourceTelemetryClient:
|
|
761
748
|
error_code: Error code.
|
762
749
|
stack_trace: Error stack trace.
|
763
750
|
"""
|
764
|
-
data:
|
751
|
+
data: dict[str, Any] = {
|
765
752
|
TelemetryField.KEY_FUNC_NAME.value: func_name,
|
766
753
|
TelemetryField.KEY_CATEGORY.value: function_category,
|
767
754
|
}
|
@@ -775,7 +762,7 @@ class _SourceTelemetryClient:
|
|
775
762
|
data[TelemetryField.KEY_CUSTOM_TAGS.value] = custom_tags
|
776
763
|
|
777
764
|
telemetry_type = f"{self.source.lower()}_{TelemetryField.TYPE_FUNCTION_USAGE.value}"
|
778
|
-
message:
|
765
|
+
message: dict[str, Any] = {
|
779
766
|
**self._create_basic_telemetry_data(telemetry_type),
|
780
767
|
TelemetryField.KEY_DATA.value: data,
|
781
768
|
TelemetryField.KEY_DURATION.value: duration,
|
@@ -795,7 +782,7 @@ class _SourceTelemetryClient:
|
|
795
782
|
self._telemetry.send_batch()
|
796
783
|
|
797
784
|
|
798
|
-
def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params:
|
785
|
+
def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params: dict[str, Any]) -> dict[str, Any]:
|
799
786
|
"""
|
800
787
|
Get statement_params keyword argument for sproc call.
|
801
788
|
|
@@ -11,7 +11,7 @@ T = TypeVar("T")
|
|
11
11
|
class LazyType(Generic[T]):
|
12
12
|
"""Utility type to help defer need of importing."""
|
13
13
|
|
14
|
-
def __init__(self, klass: Union[str,
|
14
|
+
def __init__(self, klass: Union[str, type[T]]) -> None:
|
15
15
|
self.qualname = ""
|
16
16
|
if isinstance(klass, str):
|
17
17
|
parts = klass.rsplit(".", 1)
|
@@ -30,7 +30,7 @@ class LazyType(Generic[T]):
|
|
30
30
|
return self.isinstance(obj)
|
31
31
|
|
32
32
|
@classmethod
|
33
|
-
def from_type(cls, typ_: Union["LazyType[T]",
|
33
|
+
def from_type(cls, typ_: Union["LazyType[T]", type[T]]) -> "LazyType[T]":
|
34
34
|
if isinstance(typ_, LazyType):
|
35
35
|
return typ_
|
36
36
|
return cls(typ_)
|
@@ -48,7 +48,7 @@ class LazyType(Generic[T]):
|
|
48
48
|
def __repr__(self) -> str:
|
49
49
|
return f'LazyType("{self.module}", "{self.qualname}")'
|
50
50
|
|
51
|
-
def get_class(self) ->
|
51
|
+
def get_class(self) -> type[T]:
|
52
52
|
if self._runtime_class is None:
|
53
53
|
try:
|
54
54
|
m = importlib.import_module(self.module)
|
@@ -1,5 +1,5 @@
|
|
1
1
|
from enum import Enum
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional
|
3
3
|
|
4
4
|
from snowflake.ml._internal.utils import query_result_checker, sql_identifier
|
5
5
|
from snowflake.snowpark import session
|
@@ -19,7 +19,7 @@ def db_object_exists(
|
|
19
19
|
*,
|
20
20
|
database_name: Optional[sql_identifier.SqlIdentifier] = None,
|
21
21
|
schema_name: Optional[sql_identifier.SqlIdentifier] = None,
|
22
|
-
statement_params: Optional[
|
22
|
+
statement_params: Optional[dict[str, Any]] = None,
|
23
23
|
) -> bool:
|
24
24
|
"""Check if object exists in database.
|
25
25
|
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import re
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, Union, overload
|
3
3
|
|
4
4
|
from snowflake.snowpark._internal.analyzer import analyzer_utils
|
5
5
|
|
@@ -112,7 +112,7 @@ def get_inferred_name(name: str) -> str:
|
|
112
112
|
return escaped_id
|
113
113
|
|
114
114
|
|
115
|
-
def concat_names(names:
|
115
|
+
def concat_names(names: list[str]) -> str:
|
116
116
|
"""Concatenates `names` to form one valid id.
|
117
117
|
|
118
118
|
|
@@ -142,7 +142,7 @@ def rename_to_valid_snowflake_identifier(name: str) -> str:
|
|
142
142
|
|
143
143
|
def parse_schema_level_object_identifier(
|
144
144
|
object_name: str,
|
145
|
-
) ->
|
145
|
+
) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
|
146
146
|
"""Parse a string which starts with schema level object.
|
147
147
|
|
148
148
|
Args:
|
@@ -172,7 +172,7 @@ def parse_schema_level_object_identifier(
|
|
172
172
|
|
173
173
|
def parse_snowflake_stage_path(
|
174
174
|
path: str,
|
175
|
-
) ->
|
175
|
+
) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
|
176
176
|
"""Parse a string which represents a snowflake stage path.
|
177
177
|
|
178
178
|
Args:
|
@@ -260,11 +260,11 @@ def get_unescaped_names(ids: str) -> str:
|
|
260
260
|
|
261
261
|
|
262
262
|
@overload
|
263
|
-
def get_unescaped_names(ids:
|
263
|
+
def get_unescaped_names(ids: list[str]) -> list[str]:
|
264
264
|
...
|
265
265
|
|
266
266
|
|
267
|
-
def get_unescaped_names(ids: Optional[Union[str,
|
267
|
+
def get_unescaped_names(ids: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
|
268
268
|
"""Given a user provided identifier(s), this method will compute the equivalent column name identifier(s) in the
|
269
269
|
response pandas dataframe(i.e., in the response of snowpark_df.to_pandas()) using the rules defined here
|
270
270
|
https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
|
@@ -308,11 +308,11 @@ def get_inferred_names(names: str) -> str:
|
|
308
308
|
|
309
309
|
|
310
310
|
@overload
|
311
|
-
def get_inferred_names(names:
|
311
|
+
def get_inferred_names(names: list[str]) -> list[str]:
|
312
312
|
...
|
313
313
|
|
314
314
|
|
315
|
-
def get_inferred_names(names: Optional[Union[str,
|
315
|
+
def get_inferred_names(names: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
|
316
316
|
"""Given a user provided *string(s)*, this method will compute the equivalent column name identifier(s)
|
317
317
|
in case of column name contains special characters, and maintains case-sensitivity
|
318
318
|
https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
|
@@ -1,5 +1,5 @@
|
|
1
1
|
import importlib
|
2
|
-
from typing import Any
|
2
|
+
from typing import Any
|
3
3
|
|
4
4
|
|
5
5
|
class MissingOptionalDependency:
|
@@ -46,7 +46,7 @@ def import_with_fallbacks(*targets: str) -> Any:
|
|
46
46
|
raise ImportError(f"None of the requested targets could be imported. Requested: {', '.join(targets)}")
|
47
47
|
|
48
48
|
|
49
|
-
def import_or_get_dummy(target: str) ->
|
49
|
+
def import_or_get_dummy(target: str) -> tuple[Any, bool]:
|
50
50
|
"""Try to import the the given target or return a dummy object.
|
51
51
|
|
52
52
|
If the import target (package/module/symbol) is available, the target will be returned. If it is not available,
|
@@ -1,7 +1,7 @@
|
|
1
1
|
import math
|
2
2
|
from contextlib import contextmanager
|
3
3
|
from timeit import default_timer
|
4
|
-
from typing import Any, Callable,
|
4
|
+
from typing import Any, Callable, Generator, Iterable, Optional
|
5
5
|
|
6
6
|
import snowflake.snowpark.functions as F
|
7
7
|
from snowflake import snowpark
|
@@ -17,17 +17,17 @@ def timer() -> Generator[Callable[[], float], None, None]:
|
|
17
17
|
yield lambda: elapser()
|
18
18
|
|
19
19
|
|
20
|
-
def _flatten(L: Iterable[
|
20
|
+
def _flatten(L: Iterable[list[Any]]) -> list[Any]:
|
21
21
|
return [val for sublist in L for val in sublist]
|
22
22
|
|
23
23
|
|
24
24
|
def map_dataframe_by_column(
|
25
25
|
df: snowpark.DataFrame,
|
26
|
-
cols:
|
27
|
-
map_func: Callable[[snowpark.DataFrame,
|
26
|
+
cols: list[str],
|
27
|
+
map_func: Callable[[snowpark.DataFrame, list[str]], snowpark.DataFrame],
|
28
28
|
partition_size: int,
|
29
|
-
statement_params: Optional[
|
30
|
-
) ->
|
29
|
+
statement_params: Optional[dict[str, Any]] = None,
|
30
|
+
) -> list[list[Any]]:
|
31
31
|
"""Applies the `map_func` to the input DataFrame by parallelizing it over subsets of the column.
|
32
32
|
|
33
33
|
Because the return results are materialized as Python lists *in memory*, this method should
|
@@ -84,7 +84,7 @@ def map_dataframe_by_column(
|
|
84
84
|
unioned_df = mapped_df if unioned_df is None else unioned_df.union_all(mapped_df)
|
85
85
|
|
86
86
|
# Store results in a list of size |n_partitions| x |n_rows| x |n_output_cols|
|
87
|
-
all_results:
|
87
|
+
all_results: list[list[list[Any]]] = [[] for _ in range(n_partitions - 1)]
|
88
88
|
|
89
89
|
# Collect the results of the first n-1 partitions, removing the partition_id column
|
90
90
|
unioned_result = unioned_df.collect(statement_params=statement_params) if unioned_df is not None else []
|
@@ -1,6 +1,6 @@
|
|
1
1
|
import sys
|
2
2
|
import warnings
|
3
|
-
from typing import
|
3
|
+
from typing import Optional, Union
|
4
4
|
|
5
5
|
from packaging.version import Version
|
6
6
|
|
@@ -8,7 +8,7 @@ from snowflake.ml._internal import telemetry
|
|
8
8
|
from snowflake.snowpark import AsyncJob, Row, Session
|
9
9
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
10
10
|
|
11
|
-
cache:
|
11
|
+
cache: dict[str, Optional[str]] = {}
|
12
12
|
|
13
13
|
_PROJECT = "ModelDevelopment"
|
14
14
|
_SUBPROJECT = "utils"
|
@@ -23,8 +23,8 @@ def is_relaxed() -> bool:
|
|
23
23
|
|
24
24
|
|
25
25
|
def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
26
|
-
pkg_versions:
|
27
|
-
) ->
|
26
|
+
pkg_versions: list[str], session: Session, subproject: Optional[str] = None
|
27
|
+
) -> list[str]:
|
28
28
|
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
29
29
|
return pkg_versions
|
30
30
|
else:
|
@@ -32,9 +32,9 @@ def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
|
|
32
32
|
|
33
33
|
|
34
34
|
def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
|
35
|
-
pkg_versions:
|
36
|
-
) ->
|
37
|
-
pkg_version_async_job_list:
|
35
|
+
pkg_versions: list[str], session: Session, subproject: Optional[str] = None
|
36
|
+
) -> list[str]:
|
37
|
+
pkg_version_async_job_list: list[tuple[str, AsyncJob]] = []
|
38
38
|
for pkg_version in pkg_versions:
|
39
39
|
if pkg_version not in cache:
|
40
40
|
# Execute pkg version queries asynchronously.
|
@@ -64,7 +64,7 @@ def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
|
|
64
64
|
|
65
65
|
def _query_pkg_version_supported_in_snowflake_conda_channel(
|
66
66
|
pkg_version: str, session: Session, block: bool, subproject: Optional[str] = None
|
67
|
-
) -> Union[AsyncJob,
|
67
|
+
) -> Union[AsyncJob, list[Row]]:
|
68
68
|
tokens = pkg_version.split("==")
|
69
69
|
if len(tokens) != 2:
|
70
70
|
raise RuntimeError(
|
@@ -102,9 +102,9 @@ def _query_pkg_version_supported_in_snowflake_conda_channel(
|
|
102
102
|
return pkg_version_list_or_async_job
|
103
103
|
|
104
104
|
|
105
|
-
def _get_conda_packages_and_emit_warnings(pkg_versions:
|
106
|
-
pkg_version_conda_list:
|
107
|
-
pkg_version_warning_list:
|
105
|
+
def _get_conda_packages_and_emit_warnings(pkg_versions: list[str]) -> list[str]:
|
106
|
+
pkg_version_conda_list: list[str] = []
|
107
|
+
pkg_version_warning_list: list[list[str]] = []
|
108
108
|
for pkg_version in pkg_versions:
|
109
109
|
try:
|
110
110
|
conda_pkg_version = cache[pkg_version]
|
@@ -1,7 +1,7 @@
|
|
1
1
|
from __future__ import annotations # for return self methods
|
2
2
|
|
3
3
|
from functools import partial
|
4
|
-
from typing import Any, Callable,
|
4
|
+
from typing import Any, Callable, Optional
|
5
5
|
|
6
6
|
from snowflake import connector, snowpark
|
7
7
|
from snowflake.ml._internal.utils import formatting
|
@@ -123,7 +123,7 @@ def cell_value_by_column_matcher(
|
|
123
123
|
return True
|
124
124
|
|
125
125
|
|
126
|
-
_DEFAULT_MATCHERS:
|
126
|
+
_DEFAULT_MATCHERS: list[Callable[[list[snowpark.Row], Optional[str]], bool]] = [
|
127
127
|
partial(result_dimension_matcher, 1, 1),
|
128
128
|
partial(column_name_matcher, "status"),
|
129
129
|
]
|
@@ -252,12 +252,12 @@ class SqlResultValidator(ResultValidator):
|
|
252
252
|
"""
|
253
253
|
|
254
254
|
def __init__(
|
255
|
-
self, session: snowpark.Session, query: str, statement_params: Optional[
|
255
|
+
self, session: snowpark.Session, query: str, statement_params: Optional[dict[str, Any]] = None
|
256
256
|
) -> None:
|
257
257
|
self._session: snowpark.Session = session
|
258
258
|
self._query: str = query
|
259
259
|
self._success_matchers: list[Callable[[list[snowpark.Row], Optional[str]], bool]] = []
|
260
|
-
self._statement_params: Optional[
|
260
|
+
self._statement_params: Optional[dict[str, Any]] = statement_params
|
261
261
|
|
262
262
|
def _get_result(self) -> list[snowpark.Row]:
|
263
263
|
"""Collect the result of the given SQL query."""
|
@@ -1,15 +1,15 @@
|
|
1
1
|
import enum
|
2
|
-
from typing import Any,
|
2
|
+
from typing import Any, Optional, TypedDict, cast
|
3
3
|
|
4
4
|
from packaging import version
|
5
5
|
from typing_extensions import NotRequired, Required
|
6
6
|
|
7
7
|
from snowflake.ml._internal.utils import query_result_checker
|
8
|
-
from snowflake.snowpark import session
|
8
|
+
from snowflake.snowpark import exceptions as sp_exceptions, session
|
9
9
|
|
10
10
|
|
11
11
|
def get_current_snowflake_version(
|
12
|
-
sess: session.Session, *, statement_params: Optional[
|
12
|
+
sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None
|
13
13
|
) -> version.Version:
|
14
14
|
"""Get Snowflake Version as a version.Version object follow PEP way of versioning, that is to say:
|
15
15
|
"7.44.2 b202312132139364eb71238" to <Version('7.44.2+b202312132139364eb71238')>
|
@@ -60,8 +60,8 @@ class SnowflakeRegion(TypedDict):
|
|
60
60
|
|
61
61
|
|
62
62
|
def get_regions(
|
63
|
-
sess: session.Session, *, statement_params: Optional[
|
64
|
-
) ->
|
63
|
+
sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None
|
64
|
+
) -> dict[str, SnowflakeRegion]:
|
65
65
|
res = (
|
66
66
|
query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params)
|
67
67
|
.has_column("snowflake_region")
|
@@ -93,7 +93,7 @@ def get_regions(
|
|
93
93
|
return res_dict
|
94
94
|
|
95
95
|
|
96
|
-
def get_current_region_id(sess: session.Session, *, statement_params: Optional[
|
96
|
+
def get_current_region_id(sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None) -> str:
|
97
97
|
res = (
|
98
98
|
query_result_checker.SqlResultValidator(
|
99
99
|
sess, "SELECT CURRENT_REGION() AS CURRENT_REGION", statement_params=statement_params
|
@@ -103,3 +103,25 @@ def get_current_region_id(sess: session.Session, *, statement_params: Optional[D
|
|
103
103
|
)
|
104
104
|
|
105
105
|
return cast(str, res.CURRENT_REGION)
|
106
|
+
|
107
|
+
|
108
|
+
def get_current_cloud(
|
109
|
+
sess: session.Session,
|
110
|
+
default: Optional[SnowflakeCloudType] = None,
|
111
|
+
*,
|
112
|
+
statement_params: Optional[dict[str, Any]] = None,
|
113
|
+
) -> SnowflakeCloudType:
|
114
|
+
region_id = get_current_region_id(sess, statement_params=statement_params)
|
115
|
+
try:
|
116
|
+
region = get_regions(sess, statement_params=statement_params)[region_id]
|
117
|
+
return region["cloud"]
|
118
|
+
except sp_exceptions.SnowparkSQLException:
|
119
|
+
# SHOW REGIONS not available, try to infer cloud from region name
|
120
|
+
region_name = region_id.split(".", 1)[-1] # Drop region group if any, e.g. PUBLIC
|
121
|
+
cloud_name_maybe = region_name.split("_", 1)[0] # Extract cloud name, e.g. AWS_US_WEST -> AWS
|
122
|
+
try:
|
123
|
+
return SnowflakeCloudType.from_value(cloud_name_maybe)
|
124
|
+
except ValueError:
|
125
|
+
if default:
|
126
|
+
return default
|
127
|
+
raise
|
@@ -1,13 +1,13 @@
|
|
1
1
|
import logging
|
2
2
|
import warnings
|
3
|
-
from typing import
|
3
|
+
from typing import Optional
|
4
4
|
|
5
5
|
from snowflake import snowpark
|
6
6
|
from snowflake.ml._internal.utils import sql_identifier
|
7
7
|
from snowflake.snowpark import functions, types
|
8
8
|
|
9
9
|
|
10
|
-
def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[
|
10
|
+
def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[list[str]] = None) -> snowpark.DataFrame:
|
11
11
|
"""Cast columns in the dataframe to types that are compatible with tensor.
|
12
12
|
|
13
13
|
It assists FileSet.make() in performing implicit data casting.
|
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import
|
1
|
+
from typing import Optional, Union
|
2
2
|
|
3
3
|
from snowflake.ml._internal.utils import identifier
|
4
4
|
|
@@ -77,13 +77,13 @@ class SqlIdentifier(str):
|
|
77
77
|
return super().__hash__()
|
78
78
|
|
79
79
|
|
80
|
-
def to_sql_identifiers(list_of_str:
|
80
|
+
def to_sql_identifiers(list_of_str: list[str], *, case_sensitive: bool = False) -> list[SqlIdentifier]:
|
81
81
|
return [SqlIdentifier(val, case_sensitive=case_sensitive) for val in list_of_str]
|
82
82
|
|
83
83
|
|
84
84
|
def parse_fully_qualified_name(
|
85
85
|
name: str,
|
86
|
-
) ->
|
86
|
+
) -> tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]:
|
87
87
|
db, schema, object = identifier.parse_schema_level_object_identifier(name)
|
88
88
|
|
89
89
|
assert name is not None, f"Unable parse the input name `{name}` as fully qualified."
|