snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170) hide show
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -8,25 +8,13 @@ import sys
8
8
  import time
9
9
  import traceback
10
10
  import types
11
- from typing import (
12
- Any,
13
- Callable,
14
- Dict,
15
- Iterable,
16
- List,
17
- Mapping,
18
- Optional,
19
- Set,
20
- Tuple,
21
- TypeVar,
22
- Union,
23
- cast,
24
- )
11
+ from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union, cast
25
12
 
26
13
  from typing_extensions import ParamSpec
27
14
 
28
15
  from snowflake import connector
29
16
  from snowflake.connector import telemetry as connector_telemetry, time_util
17
+ from snowflake.ml import version as snowml_version
30
18
  from snowflake.ml._internal import env
31
19
  from snowflake.ml._internal.exceptions import (
32
20
  error_codes,
@@ -99,13 +87,13 @@ class _TelemetrySourceType(enum.Enum):
99
87
  AUGMENT_TELEMETRY = "SNOWML_AUGMENT_TELEMETRY"
100
88
 
101
89
 
102
- _statement_params_context_var: contextvars.ContextVar[Dict[str, str]] = contextvars.ContextVar("statement_params")
90
+ _statement_params_context_var: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("statement_params")
103
91
 
104
92
 
105
93
  class _StatementParamsPatchManager:
106
94
  def __init__(self) -> None:
107
- self._patch_cache: Set[server_connection.ServerConnection] = set()
108
- self._context_var: contextvars.ContextVar[Dict[str, str]] = _statement_params_context_var
95
+ self._patch_cache: set[server_connection.ServerConnection] = set()
96
+ self._context_var: contextvars.ContextVar[dict[str, str]] = _statement_params_context_var
109
97
 
110
98
  def apply_patches(self) -> None:
111
99
  try:
@@ -117,7 +105,7 @@ class _StatementParamsPatchManager:
117
105
  except snowpark_exceptions.SnowparkSessionException:
118
106
  pass
119
107
 
120
- def set_statement_params(self, statement_params: Dict[str, str]) -> None:
108
+ def set_statement_params(self, statement_params: dict[str, str]) -> None:
121
109
  # Only set value if not already set in context
122
110
  if not self._context_var.get({}):
123
111
  self._context_var.set(statement_params)
@@ -152,7 +140,6 @@ class _StatementParamsPatchManager:
152
140
  if throw_on_patch_fail: # primarily used for testing
153
141
  raise
154
142
  # TODO: Log a warning, this probably means there was a breaking change in Snowpark/SnowflakeConnection
155
- pass
156
143
 
157
144
  def _patch_with_statement_params(
158
145
  self, target: object, function_name: str, param_name: str = "statement_params"
@@ -197,10 +184,10 @@ class _StatementParamsPatchManager:
197
184
 
198
185
  setattr(target, function_name, wrapper)
199
186
 
200
- def __getstate__(self) -> Dict[str, Any]:
187
+ def __getstate__(self) -> dict[str, Any]:
201
188
  return {}
202
189
 
203
- def __setstate__(self, state: Dict[str, Any]) -> None:
190
+ def __setstate__(self, state: dict[str, Any]) -> None:
204
191
  # unpickling does not call __init__ by default, do it manually here
205
192
  self.__init__() # type: ignore[misc]
206
193
 
@@ -210,7 +197,7 @@ _patch_manager = _StatementParamsPatchManager()
210
197
 
211
198
  def get_statement_params(
212
199
  project: str, subproject: Optional[str] = None, class_name: Optional[str] = None
213
- ) -> Dict[str, Any]:
200
+ ) -> dict[str, Any]:
214
201
  """
215
202
  Get telemetry statement parameters.
216
203
 
@@ -231,8 +218,8 @@ def get_statement_params(
231
218
 
232
219
 
233
220
  def add_statement_params_custom_tags(
234
- statement_params: Optional[Dict[str, Any]], custom_tags: Mapping[str, Any]
235
- ) -> Dict[str, Any]:
221
+ statement_params: Optional[dict[str, Any]], custom_tags: Mapping[str, Any]
222
+ ) -> dict[str, Any]:
236
223
  """
237
224
  Add custom_tags to existing statement_params. Overwrite keys in custom_tags dict that already exist.
238
225
  If existing statement_params are not provided, do nothing as the information cannot be effectively tracked.
@@ -246,7 +233,7 @@ def add_statement_params_custom_tags(
246
233
  """
247
234
  if not statement_params:
248
235
  return {}
249
- existing_custom_tags: Dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
236
+ existing_custom_tags: dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
250
237
  existing_custom_tags.update(custom_tags)
251
238
  # NOTE: This can be done with | operator after upgrade from py3.8
252
239
  return {
@@ -289,17 +276,17 @@ def get_function_usage_statement_params(
289
276
  *,
290
277
  function_category: str = TelemetryField.FUNC_CAT_USAGE.value,
291
278
  function_name: Optional[str] = None,
292
- function_parameters: Optional[Dict[str, Any]] = None,
279
+ function_parameters: Optional[dict[str, Any]] = None,
293
280
  api_calls: Optional[
294
- List[
281
+ list[
295
282
  Union[
296
- Dict[str, Union[Callable[..., Any], str]],
283
+ dict[str, Union[Callable[..., Any], str]],
297
284
  Union[Callable[..., Any], str],
298
285
  ]
299
286
  ]
300
287
  ] = None,
301
- custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
302
- ) -> Dict[str, Any]:
288
+ custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
289
+ ) -> dict[str, Any]:
303
290
  """
304
291
  Get function usage statement parameters.
305
292
 
@@ -321,12 +308,12 @@ def get_function_usage_statement_params(
321
308
  >>> df.collect(statement_params=statement_params)
322
309
  """
323
310
  telemetry_type = f"{env.SOURCE.lower()}_{TelemetryField.TYPE_FUNCTION_USAGE.value}"
324
- statement_params: Dict[str, Any] = {
311
+ statement_params: dict[str, Any] = {
325
312
  connector_telemetry.TelemetryField.KEY_SOURCE.value: env.SOURCE,
326
313
  TelemetryField.KEY_PROJECT.value: project,
327
314
  TelemetryField.KEY_SUBPROJECT.value: subproject,
328
315
  TelemetryField.KEY_OS.value: env.OS,
329
- TelemetryField.KEY_VERSION.value: env.VERSION,
316
+ TelemetryField.KEY_VERSION.value: snowml_version.VERSION,
330
317
  TelemetryField.KEY_PYTHON_VERSION.value: env.PYTHON_VERSION,
331
318
  connector_telemetry.TelemetryField.KEY_TYPE.value: telemetry_type,
332
319
  TelemetryField.KEY_CATEGORY.value: function_category,
@@ -339,7 +326,7 @@ def get_function_usage_statement_params(
339
326
  if api_calls:
340
327
  statement_params[TelemetryField.KEY_API_CALLS.value] = []
341
328
  for api_call in api_calls:
342
- if isinstance(api_call, Dict):
329
+ if isinstance(api_call, dict):
343
330
  telemetry_api_call = api_call.copy()
344
331
  # convert Callable to str
345
332
  for field, api in api_call.items():
@@ -388,7 +375,7 @@ def send_custom_usage(
388
375
  *,
389
376
  telemetry_type: str,
390
377
  subproject: Optional[str] = None,
391
- data: Optional[Dict[str, Any]] = None,
378
+ data: Optional[dict[str, Any]] = None,
392
379
  **kwargs: Any,
393
380
  ) -> None:
394
381
  active_session = next(iter(session._get_active_sessions()))
@@ -409,17 +396,17 @@ def send_api_usage_telemetry(
409
396
  api_calls_extractor: Optional[
410
397
  Callable[
411
398
  ...,
412
- List[
399
+ list[
413
400
  Union[
414
- Dict[str, Union[Callable[..., Any], str]],
401
+ dict[str, Union[Callable[..., Any], str]],
415
402
  Union[Callable[..., Any], str],
416
403
  ]
417
404
  ],
418
405
  ]
419
406
  ] = None,
420
- sfqids_extractor: Optional[Callable[..., List[str]]] = None,
407
+ sfqids_extractor: Optional[Callable[..., list[str]]] = None,
421
408
  subproject_extractor: Optional[Callable[[Any], str]] = None,
422
- custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
409
+ custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
423
410
  ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
424
411
  """
425
412
  Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
@@ -454,7 +441,7 @@ def send_api_usage_telemetry(
454
441
  def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
455
442
  params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
456
443
 
457
- api_calls: List[Union[Dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
444
+ api_calls: list[Union[dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
458
445
  if api_calls_extractor:
459
446
  extracted_api_calls = api_calls_extractor(args[0])
460
447
  for api_call in extracted_api_calls:
@@ -484,7 +471,7 @@ def send_api_usage_telemetry(
484
471
  custom_tags=custom_tags,
485
472
  )
486
473
 
487
- def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[str, Any]) -> _ReturnValue:
474
+ def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: dict[str, Any]) -> _ReturnValue:
488
475
  """
489
476
  Update SnowML function usage statement parameters to the object if it is a Snowpark DataFrame.
490
477
  Used to track APIs returning a Snowpark DataFrame.
@@ -614,7 +601,7 @@ def _get_full_func_name(func: Callable[..., Any]) -> str:
614
601
 
615
602
  def _get_func_params(
616
603
  func: Callable[..., Any], func_params_to_log: Optional[Iterable[str]], args: Any, kwargs: Any
617
- ) -> Dict[str, Any]:
604
+ ) -> dict[str, Any]:
618
605
  """
619
606
  Get function parameters.
620
607
 
@@ -639,7 +626,7 @@ def _get_func_params(
639
626
  return params
640
627
 
641
628
 
642
- def _extract_arg_value(field: str, func_spec: inspect.FullArgSpec, args: Any, kwargs: Any) -> Tuple[bool, Any]:
629
+ def _extract_arg_value(field: str, func_spec: inspect.FullArgSpec, args: Any, kwargs: Any) -> tuple[bool, Any]:
643
630
  """
644
631
  Function to extract a specified argument value.
645
632
 
@@ -702,11 +689,11 @@ class _SourceTelemetryClient:
702
689
  self.source: str = env.SOURCE
703
690
  self.project: Optional[str] = project
704
691
  self.subproject: Optional[str] = subproject
705
- self.version = env.VERSION
692
+ self.version = snowml_version.VERSION
706
693
  self.python_version: str = env.PYTHON_VERSION
707
694
  self.os: str = env.OS
708
695
 
709
- def _send(self, msg: Dict[str, Any], timestamp: Optional[int] = None) -> None:
696
+ def _send(self, msg: dict[str, Any], timestamp: Optional[int] = None) -> None:
710
697
  """
711
698
  Add telemetry data to a batch in connector client.
712
699
 
@@ -720,7 +707,7 @@ class _SourceTelemetryClient:
720
707
  telemetry_data = connector_telemetry.TelemetryData(message=msg, timestamp=timestamp)
721
708
  self._telemetry.try_add_log_to_batch(telemetry_data)
722
709
 
723
- def _create_basic_telemetry_data(self, telemetry_type: str) -> Dict[str, Any]:
710
+ def _create_basic_telemetry_data(self, telemetry_type: str) -> dict[str, Any]:
724
711
  message = {
725
712
  connector_telemetry.TelemetryField.KEY_SOURCE.value: self.source,
726
713
  TelemetryField.KEY_PROJECT.value: self.project,
@@ -738,10 +725,10 @@ class _SourceTelemetryClient:
738
725
  func_name: str,
739
726
  function_category: str,
740
727
  duration: float,
741
- func_params: Optional[Dict[str, Any]] = None,
742
- api_calls: Optional[List[Dict[str, Any]]] = None,
743
- sfqids: Optional[List[Any]] = None,
744
- custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
728
+ func_params: Optional[dict[str, Any]] = None,
729
+ api_calls: Optional[list[dict[str, Any]]] = None,
730
+ sfqids: Optional[list[Any]] = None,
731
+ custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
745
732
  error: Optional[str] = None,
746
733
  error_code: Optional[str] = None,
747
734
  stack_trace: Optional[str] = None,
@@ -761,7 +748,7 @@ class _SourceTelemetryClient:
761
748
  error_code: Error code.
762
749
  stack_trace: Error stack trace.
763
750
  """
764
- data: Dict[str, Any] = {
751
+ data: dict[str, Any] = {
765
752
  TelemetryField.KEY_FUNC_NAME.value: func_name,
766
753
  TelemetryField.KEY_CATEGORY.value: function_category,
767
754
  }
@@ -775,7 +762,7 @@ class _SourceTelemetryClient:
775
762
  data[TelemetryField.KEY_CUSTOM_TAGS.value] = custom_tags
776
763
 
777
764
  telemetry_type = f"{self.source.lower()}_{TelemetryField.TYPE_FUNCTION_USAGE.value}"
778
- message: Dict[str, Any] = {
765
+ message: dict[str, Any] = {
779
766
  **self._create_basic_telemetry_data(telemetry_type),
780
767
  TelemetryField.KEY_DATA.value: data,
781
768
  TelemetryField.KEY_DURATION.value: duration,
@@ -795,7 +782,7 @@ class _SourceTelemetryClient:
795
782
  self._telemetry.send_batch()
796
783
 
797
784
 
798
- def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params: Dict[str, Any]) -> Dict[str, Any]:
785
+ def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params: dict[str, Any]) -> dict[str, Any]:
799
786
  """
800
787
  Get statement_params keyword argument for sproc call.
801
788
 
@@ -11,7 +11,7 @@ T = TypeVar("T")
11
11
  class LazyType(Generic[T]):
12
12
  """Utility type to help defer need of importing."""
13
13
 
14
- def __init__(self, klass: Union[str, Type[T]]) -> None:
14
+ def __init__(self, klass: Union[str, type[T]]) -> None:
15
15
  self.qualname = ""
16
16
  if isinstance(klass, str):
17
17
  parts = klass.rsplit(".", 1)
@@ -30,7 +30,7 @@ class LazyType(Generic[T]):
30
30
  return self.isinstance(obj)
31
31
 
32
32
  @classmethod
33
- def from_type(cls, typ_: Union["LazyType[T]", Type[T]]) -> "LazyType[T]":
33
+ def from_type(cls, typ_: Union["LazyType[T]", type[T]]) -> "LazyType[T]":
34
34
  if isinstance(typ_, LazyType):
35
35
  return typ_
36
36
  return cls(typ_)
@@ -48,7 +48,7 @@ class LazyType(Generic[T]):
48
48
  def __repr__(self) -> str:
49
49
  return f'LazyType("{self.module}", "{self.qualname}")'
50
50
 
51
- def get_class(self) -> Type[T]:
51
+ def get_class(self) -> type[T]:
52
52
  if self._runtime_class is None:
53
53
  try:
54
54
  m = importlib.import_module(self.module)
@@ -1,5 +1,5 @@
1
1
  from enum import Enum
2
- from typing import Any, Dict, Optional
2
+ from typing import Any, Optional
3
3
 
4
4
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
5
5
  from snowflake.snowpark import session
@@ -19,7 +19,7 @@ def db_object_exists(
19
19
  *,
20
20
  database_name: Optional[sql_identifier.SqlIdentifier] = None,
21
21
  schema_name: Optional[sql_identifier.SqlIdentifier] = None,
22
- statement_params: Optional[Dict[str, Any]] = None,
22
+ statement_params: Optional[dict[str, Any]] = None,
23
23
  ) -> bool:
24
24
  """Check if object exists in database.
25
25
 
@@ -1,5 +1,5 @@
1
1
  import re
2
- from typing import Any, List, Optional, Tuple, Union, overload
2
+ from typing import Any, Optional, Union, overload
3
3
 
4
4
  from snowflake.snowpark._internal.analyzer import analyzer_utils
5
5
 
@@ -112,7 +112,7 @@ def get_inferred_name(name: str) -> str:
112
112
  return escaped_id
113
113
 
114
114
 
115
- def concat_names(names: List[str]) -> str:
115
+ def concat_names(names: list[str]) -> str:
116
116
  """Concatenates `names` to form one valid id.
117
117
 
118
118
 
@@ -142,7 +142,7 @@ def rename_to_valid_snowflake_identifier(name: str) -> str:
142
142
 
143
143
  def parse_schema_level_object_identifier(
144
144
  object_name: str,
145
- ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
145
+ ) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
146
146
  """Parse a string which starts with schema level object.
147
147
 
148
148
  Args:
@@ -172,7 +172,7 @@ def parse_schema_level_object_identifier(
172
172
 
173
173
  def parse_snowflake_stage_path(
174
174
  path: str,
175
- ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
175
+ ) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
176
176
  """Parse a string which represents a snowflake stage path.
177
177
 
178
178
  Args:
@@ -260,11 +260,11 @@ def get_unescaped_names(ids: str) -> str:
260
260
 
261
261
 
262
262
  @overload
263
- def get_unescaped_names(ids: List[str]) -> List[str]:
263
+ def get_unescaped_names(ids: list[str]) -> list[str]:
264
264
  ...
265
265
 
266
266
 
267
- def get_unescaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
267
+ def get_unescaped_names(ids: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
268
268
  """Given a user provided identifier(s), this method will compute the equivalent column name identifier(s) in the
269
269
  response pandas dataframe(i.e., in the response of snowpark_df.to_pandas()) using the rules defined here
270
270
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
@@ -308,11 +308,11 @@ def get_inferred_names(names: str) -> str:
308
308
 
309
309
 
310
310
  @overload
311
- def get_inferred_names(names: List[str]) -> List[str]:
311
+ def get_inferred_names(names: list[str]) -> list[str]:
312
312
  ...
313
313
 
314
314
 
315
- def get_inferred_names(names: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
315
+ def get_inferred_names(names: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
316
316
  """Given a user provided *string(s)*, this method will compute the equivalent column name identifier(s)
317
317
  in case of column name contains special characters, and maintains case-sensitivity
318
318
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
@@ -1,5 +1,5 @@
1
1
  import importlib
2
- from typing import Any, Tuple
2
+ from typing import Any
3
3
 
4
4
 
5
5
  class MissingOptionalDependency:
@@ -46,7 +46,7 @@ def import_with_fallbacks(*targets: str) -> Any:
46
46
  raise ImportError(f"None of the requested targets could be imported. Requested: {', '.join(targets)}")
47
47
 
48
48
 
49
- def import_or_get_dummy(target: str) -> Tuple[Any, bool]:
49
+ def import_or_get_dummy(target: str) -> tuple[Any, bool]:
50
50
  """Try to import the given target or return a dummy object.
51
51
 
52
52
  If the import target (package/module/symbol) is available, the target will be returned. If it is not available,
@@ -1,7 +1,7 @@
1
1
  import math
2
2
  from contextlib import contextmanager
3
3
  from timeit import default_timer
4
- from typing import Any, Callable, Dict, Generator, Iterable, List, Optional
4
+ from typing import Any, Callable, Generator, Iterable, Optional
5
5
 
6
6
  import snowflake.snowpark.functions as F
7
7
  from snowflake import snowpark
@@ -17,17 +17,17 @@ def timer() -> Generator[Callable[[], float], None, None]:
17
17
  yield lambda: elapser()
18
18
 
19
19
 
20
- def _flatten(L: Iterable[List[Any]]) -> List[Any]:
20
+ def _flatten(L: Iterable[list[Any]]) -> list[Any]:
21
21
  return [val for sublist in L for val in sublist]
22
22
 
23
23
 
24
24
  def map_dataframe_by_column(
25
25
  df: snowpark.DataFrame,
26
- cols: List[str],
27
- map_func: Callable[[snowpark.DataFrame, List[str]], snowpark.DataFrame],
26
+ cols: list[str],
27
+ map_func: Callable[[snowpark.DataFrame, list[str]], snowpark.DataFrame],
28
28
  partition_size: int,
29
- statement_params: Optional[Dict[str, Any]] = None,
30
- ) -> List[List[Any]]:
29
+ statement_params: Optional[dict[str, Any]] = None,
30
+ ) -> list[list[Any]]:
31
31
  """Applies the `map_func` to the input DataFrame by parallelizing it over subsets of the column.
32
32
 
33
33
  Because the return results are materialized as Python lists *in memory*, this method should
@@ -84,7 +84,7 @@ def map_dataframe_by_column(
84
84
  unioned_df = mapped_df if unioned_df is None else unioned_df.union_all(mapped_df)
85
85
 
86
86
  # Store results in a list of size |n_partitions| x |n_rows| x |n_output_cols|
87
- all_results: List[List[List[Any]]] = [[] for _ in range(n_partitions - 1)]
87
+ all_results: list[list[list[Any]]] = [[] for _ in range(n_partitions - 1)]
88
88
 
89
89
  # Collect the results of the first n-1 partitions, removing the partition_id column
90
90
  unioned_result = unioned_df.collect(statement_params=statement_params) if unioned_df is not None else []
@@ -1,6 +1,6 @@
1
1
  import sys
2
2
  import warnings
3
- from typing import Dict, List, Optional, Tuple, Union
3
+ from typing import Optional, Union
4
4
 
5
5
  from packaging.version import Version
6
6
 
@@ -8,7 +8,7 @@ from snowflake.ml._internal import telemetry
8
8
  from snowflake.snowpark import AsyncJob, Row, Session
9
9
  from snowflake.snowpark._internal import utils as snowpark_utils
10
10
 
11
- cache: Dict[str, Optional[str]] = {}
11
+ cache: dict[str, Optional[str]] = {}
12
12
 
13
13
  _PROJECT = "ModelDevelopment"
14
14
  _SUBPROJECT = "utils"
@@ -23,8 +23,8 @@ def is_relaxed() -> bool:
23
23
 
24
24
 
25
25
  def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
26
- pkg_versions: List[str], session: Session, subproject: Optional[str] = None
27
- ) -> List[str]:
26
+ pkg_versions: list[str], session: Session, subproject: Optional[str] = None
27
+ ) -> list[str]:
28
28
  if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
29
29
  return pkg_versions
30
30
  else:
@@ -32,9 +32,9 @@ def get_valid_pkg_versions_supported_in_snowflake_conda_channel(
32
32
 
33
33
 
34
34
  def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
35
- pkg_versions: List[str], session: Session, subproject: Optional[str] = None
36
- ) -> List[str]:
37
- pkg_version_async_job_list: List[Tuple[str, AsyncJob]] = []
35
+ pkg_versions: list[str], session: Session, subproject: Optional[str] = None
36
+ ) -> list[str]:
37
+ pkg_version_async_job_list: list[tuple[str, AsyncJob]] = []
38
38
  for pkg_version in pkg_versions:
39
39
  if pkg_version not in cache:
40
40
  # Execute pkg version queries asynchronously.
@@ -64,7 +64,7 @@ def _get_valid_pkg_versions_supported_in_snowflake_conda_channel_async(
64
64
 
65
65
  def _query_pkg_version_supported_in_snowflake_conda_channel(
66
66
  pkg_version: str, session: Session, block: bool, subproject: Optional[str] = None
67
- ) -> Union[AsyncJob, List[Row]]:
67
+ ) -> Union[AsyncJob, list[Row]]:
68
68
  tokens = pkg_version.split("==")
69
69
  if len(tokens) != 2:
70
70
  raise RuntimeError(
@@ -102,9 +102,9 @@ def _query_pkg_version_supported_in_snowflake_conda_channel(
102
102
  return pkg_version_list_or_async_job
103
103
 
104
104
 
105
- def _get_conda_packages_and_emit_warnings(pkg_versions: List[str]) -> List[str]:
106
- pkg_version_conda_list: List[str] = []
107
- pkg_version_warning_list: List[List[str]] = []
105
+ def _get_conda_packages_and_emit_warnings(pkg_versions: list[str]) -> list[str]:
106
+ pkg_version_conda_list: list[str] = []
107
+ pkg_version_warning_list: list[list[str]] = []
108
108
  for pkg_version in pkg_versions:
109
109
  try:
110
110
  conda_pkg_version = cache[pkg_version]
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations # for return self methods
2
2
 
3
3
  from functools import partial
4
- from typing import Any, Callable, Dict, List, Optional
4
+ from typing import Any, Callable, Optional
5
5
 
6
6
  from snowflake import connector, snowpark
7
7
  from snowflake.ml._internal.utils import formatting
@@ -123,7 +123,7 @@ def cell_value_by_column_matcher(
123
123
  return True
124
124
 
125
125
 
126
- _DEFAULT_MATCHERS: List[Callable[[List[snowpark.Row], Optional[str]], bool]] = [
126
+ _DEFAULT_MATCHERS: list[Callable[[list[snowpark.Row], Optional[str]], bool]] = [
127
127
  partial(result_dimension_matcher, 1, 1),
128
128
  partial(column_name_matcher, "status"),
129
129
  ]
@@ -252,12 +252,12 @@ class SqlResultValidator(ResultValidator):
252
252
  """
253
253
 
254
254
  def __init__(
255
- self, session: snowpark.Session, query: str, statement_params: Optional[Dict[str, Any]] = None
255
+ self, session: snowpark.Session, query: str, statement_params: Optional[dict[str, Any]] = None
256
256
  ) -> None:
257
257
  self._session: snowpark.Session = session
258
258
  self._query: str = query
259
259
  self._success_matchers: list[Callable[[list[snowpark.Row], Optional[str]], bool]] = []
260
- self._statement_params: Optional[Dict[str, Any]] = statement_params
260
+ self._statement_params: Optional[dict[str, Any]] = statement_params
261
261
 
262
262
  def _get_result(self) -> list[snowpark.Row]:
263
263
  """Collect the result of the given SQL query."""
@@ -1,15 +1,15 @@
1
1
  import enum
2
- from typing import Any, Dict, Optional, TypedDict, cast
2
+ from typing import Any, Optional, TypedDict, cast
3
3
 
4
4
  from packaging import version
5
5
  from typing_extensions import NotRequired, Required
6
6
 
7
7
  from snowflake.ml._internal.utils import query_result_checker
8
- from snowflake.snowpark import session
8
+ from snowflake.snowpark import exceptions as sp_exceptions, session
9
9
 
10
10
 
11
11
  def get_current_snowflake_version(
12
- sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None
12
+ sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None
13
13
  ) -> version.Version:
14
14
  """Get Snowflake Version as a version.Version object follow PEP way of versioning, that is to say:
15
15
  "7.44.2 b202312132139364eb71238" to <Version('7.44.2+b202312132139364eb71238')>
@@ -60,8 +60,8 @@ class SnowflakeRegion(TypedDict):
60
60
 
61
61
 
62
62
  def get_regions(
63
- sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None
64
- ) -> Dict[str, SnowflakeRegion]:
63
+ sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None
64
+ ) -> dict[str, SnowflakeRegion]:
65
65
  res = (
66
66
  query_result_checker.SqlResultValidator(sess, "SHOW REGIONS", statement_params=statement_params)
67
67
  .has_column("snowflake_region")
@@ -93,7 +93,7 @@ def get_regions(
93
93
  return res_dict
94
94
 
95
95
 
96
- def get_current_region_id(sess: session.Session, *, statement_params: Optional[Dict[str, Any]] = None) -> str:
96
+ def get_current_region_id(sess: session.Session, *, statement_params: Optional[dict[str, Any]] = None) -> str:
97
97
  res = (
98
98
  query_result_checker.SqlResultValidator(
99
99
  sess, "SELECT CURRENT_REGION() AS CURRENT_REGION", statement_params=statement_params
@@ -103,3 +103,25 @@ def get_current_region_id(sess: session.Session, *, statement_params: Optional[D
103
103
  )
104
104
 
105
105
  return cast(str, res.CURRENT_REGION)
106
+
107
+
108
+ def get_current_cloud(
109
+ sess: session.Session,
110
+ default: Optional[SnowflakeCloudType] = None,
111
+ *,
112
+ statement_params: Optional[dict[str, Any]] = None,
113
+ ) -> SnowflakeCloudType:
114
+ region_id = get_current_region_id(sess, statement_params=statement_params)
115
+ try:
116
+ region = get_regions(sess, statement_params=statement_params)[region_id]
117
+ return region["cloud"]
118
+ except sp_exceptions.SnowparkSQLException:
119
+ # SHOW REGIONS not available, try to infer cloud from region name
120
+ region_name = region_id.split(".", 1)[-1] # Drop region group if any, e.g. PUBLIC
121
+ cloud_name_maybe = region_name.split("_", 1)[0] # Extract cloud name, e.g. AWS_US_WEST -> AWS
122
+ try:
123
+ return SnowflakeCloudType.from_value(cloud_name_maybe)
124
+ except ValueError:
125
+ if default:
126
+ return default
127
+ raise
@@ -1,13 +1,13 @@
1
1
  import logging
2
2
  import warnings
3
- from typing import List, Optional
3
+ from typing import Optional
4
4
 
5
5
  from snowflake import snowpark
6
6
  from snowflake.ml._internal.utils import sql_identifier
7
7
  from snowflake.snowpark import functions, types
8
8
 
9
9
 
10
- def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[List[str]] = None) -> snowpark.DataFrame:
10
+ def cast_snowpark_dataframe(df: snowpark.DataFrame, ignore_columns: Optional[list[str]] = None) -> snowpark.DataFrame:
11
11
  """Cast columns in the dataframe to types that are compatible with tensor.
12
12
 
13
13
  It assists FileSet.make() in performing implicit data casting.
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Tuple, Union
1
+ from typing import Optional, Union
2
2
 
3
3
  from snowflake.ml._internal.utils import identifier
4
4
 
@@ -77,13 +77,13 @@ class SqlIdentifier(str):
77
77
  return super().__hash__()
78
78
 
79
79
 
80
- def to_sql_identifiers(list_of_str: List[str], *, case_sensitive: bool = False) -> List[SqlIdentifier]:
80
+ def to_sql_identifiers(list_of_str: list[str], *, case_sensitive: bool = False) -> list[SqlIdentifier]:
81
81
  return [SqlIdentifier(val, case_sensitive=case_sensitive) for val in list_of_str]
82
82
 
83
83
 
84
84
  def parse_fully_qualified_name(
85
85
  name: str,
86
- ) -> Tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]:
86
+ ) -> tuple[Optional[SqlIdentifier], Optional[SqlIdentifier], SqlIdentifier]:
87
87
  db, schema, object = identifier.parse_schema_level_object_identifier(name)
88
88
 
89
89
  assert name is not None, f"Unable parse the input name `{name}` as fully qualified."