snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166):
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -11,18 +11,7 @@ import sys
11
11
  import tarfile
12
12
  import tempfile
13
13
  import zipfile
14
- from typing import (
15
- Any,
16
- Callable,
17
- Dict,
18
- Generator,
19
- List,
20
- Literal,
21
- Optional,
22
- Set,
23
- Tuple,
24
- Union,
25
- )
14
+ from typing import Any, Callable, Generator, Literal, Optional, Union
26
15
  from urllib import parse
27
16
 
28
17
  import cloudpickle
@@ -37,7 +26,7 @@ GENERATED_PY_FILE_EXT = (".pyc", ".pyo", ".pyd", ".pyi")
37
26
  def copytree(
38
27
  src: "Union[str, os.PathLike[str]]",
39
28
  dst: "Union[str, os.PathLike[str]]",
40
- ignore: Optional[Callable[..., Set[str]]] = None,
29
+ ignore: Optional[Callable[..., set[str]]] = None,
41
30
  dirs_exist_ok: bool = False,
42
31
  ) -> "Union[str, os.PathLike[str]]":
43
32
  """This is a forked version of shutil.copytree that remove all copystat, to make sure it works in Sproc.
@@ -170,7 +159,7 @@ def zip_python_package(zipfile_path: str, package_name: str, ignore_generated_py
170
159
 
171
160
 
172
161
  def hash_directory(
173
- directory: Union[str, pathlib.Path], *, ignore_hidden: bool = False, excluded_files: Optional[List[str]] = None
162
+ directory: Union[str, pathlib.Path], *, ignore_hidden: bool = False, excluded_files: Optional[list[str]] = None
174
163
  ) -> str:
175
164
  """Hash the **content** of a folder recursively using SHA-1.
176
165
 
@@ -186,7 +175,7 @@ def hash_directory(
186
175
  excluded_files = []
187
176
 
188
177
  def _update_hash_from_dir(
189
- directory: Union[str, pathlib.Path], hash: "hashlib._Hash", *, ignore_hidden: bool, excluded_files: List[str]
178
+ directory: Union[str, pathlib.Path], hash: "hashlib._Hash", *, ignore_hidden: bool, excluded_files: list[str]
190
179
  ) -> "hashlib._Hash":
191
180
  assert pathlib.Path(directory).is_dir(), "Provided path is not a directory."
192
181
  for path in sorted(pathlib.Path(directory).iterdir(), key=lambda p: str(p).lower()):
@@ -208,7 +197,7 @@ def hash_directory(
208
197
  ).hexdigest()
209
198
 
210
199
 
211
- def get_all_modules(dirname: str, prefix: str = "") -> List[str]:
200
+ def get_all_modules(dirname: str, prefix: str = "") -> list[str]:
212
201
  modules = [mod.name for mod in pkgutil.iter_modules([dirname], prefix=prefix)]
213
202
  subdirs = [f.path for f in os.scandir(dirname) if f.is_dir()]
214
203
  for sub_dirname in subdirs:
@@ -248,7 +237,7 @@ def _create_tar_gz_stream(source_dir: str, arcname: Optional[str] = None) -> Gen
248
237
  yield output_stream
249
238
 
250
239
 
251
- def get_package_path(package_name: str, strategy: Literal["first", "last"] = "first") -> Tuple[str, str]:
240
+ def get_package_path(package_name: str, strategy: Literal["first", "last"] = "first") -> tuple[str, str]:
252
241
  """[Obsolete]Return the path to where a package is defined and its start location.
253
242
  Example 1: snowflake.ml -> path/to/site-packages/snowflake/ml, path/to/site-packages
254
243
  Example 2: zip_imported_module -> path/to/some/zipfile.zip/zip_imported_module, path/to/some/zipfile.zip
@@ -267,7 +256,7 @@ def get_package_path(package_name: str, strategy: Literal["first", "last"] = "fi
267
256
  return pkg_path, pkg_start_path
268
257
 
269
258
 
270
- def stage_object(session: snowpark.Session, object: object, stage_location: str) -> List[snowpark.PutResult]:
259
+ def stage_object(session: snowpark.Session, object: object, stage_location: str) -> list[snowpark.PutResult]:
271
260
  temp_file = tempfile.NamedTemporaryFile(delete=False)
272
261
  temp_file_path = temp_file.name
273
262
  temp_file.close()
@@ -279,7 +268,7 @@ def stage_object(session: snowpark.Session, object: object, stage_location: str)
279
268
 
280
269
 
281
270
  def stage_file_exists(
282
- session: snowpark.Session, stage_location: str, file_name: str, statement_params: Dict[str, Any]
271
+ session: snowpark.Session, stage_location: str, file_name: str, statement_params: dict[str, Any]
283
272
  ) -> bool:
284
273
  try:
285
274
  res = session.sql(f"list {stage_location}/{file_name}").collect(statement_params=statement_params)
@@ -297,7 +286,7 @@ def upload_directory_to_stage(
297
286
  local_path: pathlib.Path,
298
287
  stage_path: Union[pathlib.PurePosixPath, parse.ParseResult],
299
288
  *,
300
- statement_params: Optional[Dict[str, Any]] = None,
289
+ statement_params: Optional[dict[str, Any]] = None,
301
290
  ) -> None:
302
291
  """Upload a local folder recursively to a stage and keep the structure.
303
292
 
@@ -350,7 +339,7 @@ def download_directory_from_stage(
350
339
  stage_path: pathlib.PurePosixPath,
351
340
  local_path: pathlib.Path,
352
341
  *,
353
- statement_params: Optional[Dict[str, Any]] = None,
342
+ statement_params: Optional[dict[str, Any]] = None,
354
343
  ) -> None:
355
344
  """Upload a folder in stage recursively to a folder in local and keep the structure.
356
345
 
@@ -15,7 +15,6 @@ In this module you will find:
15
15
 
16
16
  import math
17
17
  from abc import ABC, abstractmethod
18
- from typing import Dict, List, Tuple
19
18
 
20
19
 
21
20
  class HRIDBase(ABC):
@@ -28,12 +27,11 @@ class HRIDBase(ABC):
28
27
  @abstractmethod
29
28
  def __id_generator__(self) -> int:
30
29
  """The generator to use to generate new IDs. The implementer needs to provide this."""
31
- pass
32
30
 
33
- __hrid_structure__: Tuple[str, ...]
31
+ __hrid_structure__: tuple[str, ...]
34
32
  """The HRID structure to be generated. The implementer needs to provide this."""
35
33
 
36
- __hrid_words__: Dict[str, Tuple[str, ...]]
34
+ __hrid_words__: dict[str, tuple[str, ...]]
37
35
  """The mapping between the HRID parts and the words to use. The implementer needs to provide this."""
38
36
 
39
37
  __separator__ = "_"
@@ -82,7 +80,7 @@ class HRIDBase(ABC):
82
80
  hrid.append(str(values[idxs[i]]))
83
81
  return self.__separator__.join(hrid)
84
82
 
85
- def generate(self) -> Tuple[int, str]:
83
+ def generate(self) -> tuple[int, str]:
86
84
  """Generate an ID and the corresponding HRID.
87
85
 
88
86
  Returns:
@@ -92,7 +90,7 @@ class HRIDBase(ABC):
92
90
  hrid = self.id_to_hrid(id)
93
91
  return (id, hrid)
94
92
 
95
- def _id_to_idxs(self, id: int) -> List[int]:
93
+ def _id_to_idxs(self, id: int) -> list[int]:
96
94
  """Take the ID and convert it to indices into the HRID words.
97
95
 
98
96
  Args:
@@ -109,7 +107,7 @@ class HRIDBase(ABC):
109
107
  idxs.append((id & mask) >> shift)
110
108
  return idxs
111
109
 
112
- def _hrid_to_idxs(self, hrid: str) -> List[int]:
110
+ def _hrid_to_idxs(self, hrid: str) -> list[int]:
113
111
  """Take the HRID and convert it to indices into the HRID words.
114
112
 
115
113
  Args:
@@ -2,10 +2,9 @@ import importlib
2
2
  import inspect
3
3
  import pkgutil
4
4
  from types import FunctionType
5
- from typing import Dict
6
5
 
7
6
 
8
- def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[str, type]:
7
+ def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> dict[str, type]:
9
8
  """Finds classes defined all the python modules in the given package directory.
10
9
 
11
10
  Args:
@@ -36,7 +35,7 @@ def fetch_classes_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[s
36
35
  return exportable_classes
37
36
 
38
37
 
39
- def fetch_functions_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> Dict[str, FunctionType]:
38
+ def fetch_functions_from_modules_in_pkg_dir(pkg_dir: str, pkg_name: str) -> dict[str, FunctionType]:
40
39
  """Finds functions defined all the python modules in the given package directory.
41
40
 
42
41
  Args:
@@ -1,6 +1,6 @@
1
1
  import copy
2
2
  import functools
3
- from typing import Any, Callable, List, Optional, get_args
3
+ from typing import Any, Callable, Optional, get_args
4
4
 
5
5
  from snowflake import snowpark
6
6
  from snowflake.ml.data import data_source
@@ -9,7 +9,7 @@ _DATA_SOURCES_ATTR = "_data_sources"
9
9
 
10
10
 
11
11
  def _wrap_func(
12
- fn: Callable[..., snowpark.DataFrame], data_sources: List[data_source.DataSource]
12
+ fn: Callable[..., snowpark.DataFrame], data_sources: list[data_source.DataSource]
13
13
  ) -> Callable[..., snowpark.DataFrame]:
14
14
  """Wrap a DataFrame transform function to propagate data_sources to derived DataFrames."""
15
15
 
@@ -34,9 +34,9 @@ def _wrap_class_func(fn: Callable[..., snowpark.DataFrame]) -> Callable[..., sno
34
34
  return wrapped
35
35
 
36
36
 
37
- def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
37
+ def get_data_sources(*args: Any) -> Optional[list[data_source.DataSource]]:
38
38
  """Helper method for extracting data sources attribute from DataFrames in an argument list"""
39
- result: Optional[List[data_source.DataSource]] = None
39
+ result: Optional[list[data_source.DataSource]] = None
40
40
  for arg in args:
41
41
  srcs = getattr(arg, _DATA_SOURCES_ATTR, None)
42
42
  if isinstance(srcs, list) and all(isinstance(s, get_args(data_source.DataSource)) for s in srcs):
@@ -46,7 +46,7 @@ def get_data_sources(*args: Any) -> Optional[List[data_source.DataSource]]:
46
46
  return result
47
47
 
48
48
 
49
- def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSource]]) -> None:
49
+ def set_data_sources(obj: Any, data_sources: Optional[list[data_source.DataSource]]) -> None:
50
50
  """Helper method for attaching data sources to an object"""
51
51
  if data_sources:
52
52
  assert all(isinstance(ds, get_args(data_source.DataSource)) for ds in data_sources)
@@ -54,7 +54,7 @@ def set_data_sources(obj: Any, data_sources: Optional[List[data_source.DataSourc
54
54
 
55
55
 
56
56
  def patch_dataframe(
57
- df: snowpark.DataFrame, data_sources: List[data_source.DataSource], inplace: bool = False
57
+ df: snowpark.DataFrame, data_sources: list[data_source.DataSource], inplace: bool = False
58
58
  ) -> snowpark.DataFrame:
59
59
  """
60
60
  Monkey patch a DataFrame to add attach the provided data_sources as an attribute of the DataFrame.
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from contextlib import contextmanager
3
- from typing import Any, Dict, Optional
3
+ from typing import Any, Optional
4
4
 
5
5
  from absl import logging
6
6
 
@@ -29,7 +29,7 @@ class PlatformCapabilities:
29
29
 
30
30
  _instance: Optional["PlatformCapabilities"] = None
31
31
  # Used for unittesting only. This is to avoid the need to mock the session object or reaching out to Snowflake
32
- _mock_features: Optional[Dict[str, Any]] = None
32
+ _mock_features: Optional[dict[str, Any]] = None
33
33
 
34
34
  @classmethod
35
35
  def get_instance(cls, session: Optional[snowpark_session.Session] = None) -> "PlatformCapabilities":
@@ -41,7 +41,7 @@ class PlatformCapabilities:
41
41
  return cls._instance
42
42
 
43
43
  @classmethod
44
- def set_mock_features(cls, features: Optional[Dict[str, Any]] = None) -> None:
44
+ def set_mock_features(cls, features: Optional[dict[str, Any]] = None) -> None:
45
45
  cls._mock_features = features
46
46
 
47
47
  @classmethod
@@ -52,7 +52,7 @@ class PlatformCapabilities:
52
52
  # Python 3.11. So, we are ignoring the type for this method.
53
53
  @classmethod # type: ignore[arg-type]
54
54
  @contextmanager
55
- def mock_features(cls, features: Dict[str, Any]) -> None: # type: ignore[misc]
55
+ def mock_features(cls, features: dict[str, Any]) -> None: # type: ignore[misc]
56
56
  logging.debug(f"Setting mock features: {features}")
57
57
  cls.set_mock_features(features)
58
58
  try:
@@ -71,7 +71,7 @@ class PlatformCapabilities:
71
71
  return self._get_bool_feature("ENABLE_BUNDLE_MODULE_CHECKOUT", False)
72
72
 
73
73
  @staticmethod
74
- def _get_features(session: snowpark_session.Session) -> Dict[str, Any]:
74
+ def _get_features(session: snowpark_session.Session) -> dict[str, Any]:
75
75
  try:
76
76
  result = (
77
77
  query_result_checker.SqlResultValidator(
@@ -99,7 +99,7 @@ class PlatformCapabilities:
99
99
  return {}
100
100
 
101
101
  def __init__(
102
- self, *, session: Optional[snowpark_session.Session] = None, features: Optional[Dict[str, Any]] = None
102
+ self, *, session: Optional[snowpark_session.Session] = None, features: Optional[dict[str, Any]] = None
103
103
  ) -> None:
104
104
  # This is for testing purposes only.
105
105
  if features:
@@ -8,25 +8,13 @@ import sys
8
8
  import time
9
9
  import traceback
10
10
  import types
11
- from typing import (
12
- Any,
13
- Callable,
14
- Dict,
15
- Iterable,
16
- List,
17
- Mapping,
18
- Optional,
19
- Set,
20
- Tuple,
21
- TypeVar,
22
- Union,
23
- cast,
24
- )
11
+ from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union, cast
25
12
 
26
13
  from typing_extensions import ParamSpec
27
14
 
28
15
  from snowflake import connector
29
16
  from snowflake.connector import telemetry as connector_telemetry, time_util
17
+ from snowflake.ml import version as snowml_version
30
18
  from snowflake.ml._internal import env
31
19
  from snowflake.ml._internal.exceptions import (
32
20
  error_codes,
@@ -99,13 +87,13 @@ class _TelemetrySourceType(enum.Enum):
99
87
  AUGMENT_TELEMETRY = "SNOWML_AUGMENT_TELEMETRY"
100
88
 
101
89
 
102
- _statement_params_context_var: contextvars.ContextVar[Dict[str, str]] = contextvars.ContextVar("statement_params")
90
+ _statement_params_context_var: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar("statement_params")
103
91
 
104
92
 
105
93
  class _StatementParamsPatchManager:
106
94
  def __init__(self) -> None:
107
- self._patch_cache: Set[server_connection.ServerConnection] = set()
108
- self._context_var: contextvars.ContextVar[Dict[str, str]] = _statement_params_context_var
95
+ self._patch_cache: set[server_connection.ServerConnection] = set()
96
+ self._context_var: contextvars.ContextVar[dict[str, str]] = _statement_params_context_var
109
97
 
110
98
  def apply_patches(self) -> None:
111
99
  try:
@@ -117,7 +105,7 @@ class _StatementParamsPatchManager:
117
105
  except snowpark_exceptions.SnowparkSessionException:
118
106
  pass
119
107
 
120
- def set_statement_params(self, statement_params: Dict[str, str]) -> None:
108
+ def set_statement_params(self, statement_params: dict[str, str]) -> None:
121
109
  # Only set value if not already set in context
122
110
  if not self._context_var.get({}):
123
111
  self._context_var.set(statement_params)
@@ -152,7 +140,6 @@ class _StatementParamsPatchManager:
152
140
  if throw_on_patch_fail: # primarily used for testing
153
141
  raise
154
142
  # TODO: Log a warning, this probably means there was a breaking change in Snowpark/SnowflakeConnection
155
- pass
156
143
 
157
144
  def _patch_with_statement_params(
158
145
  self, target: object, function_name: str, param_name: str = "statement_params"
@@ -197,10 +184,10 @@ class _StatementParamsPatchManager:
197
184
 
198
185
  setattr(target, function_name, wrapper)
199
186
 
200
- def __getstate__(self) -> Dict[str, Any]:
187
+ def __getstate__(self) -> dict[str, Any]:
201
188
  return {}
202
189
 
203
- def __setstate__(self, state: Dict[str, Any]) -> None:
190
+ def __setstate__(self, state: dict[str, Any]) -> None:
204
191
  # unpickling does not call __init__ by default, do it manually here
205
192
  self.__init__() # type: ignore[misc]
206
193
 
@@ -210,7 +197,7 @@ _patch_manager = _StatementParamsPatchManager()
210
197
 
211
198
  def get_statement_params(
212
199
  project: str, subproject: Optional[str] = None, class_name: Optional[str] = None
213
- ) -> Dict[str, Any]:
200
+ ) -> dict[str, Any]:
214
201
  """
215
202
  Get telemetry statement parameters.
216
203
 
@@ -231,8 +218,8 @@ def get_statement_params(
231
218
 
232
219
 
233
220
  def add_statement_params_custom_tags(
234
- statement_params: Optional[Dict[str, Any]], custom_tags: Mapping[str, Any]
235
- ) -> Dict[str, Any]:
221
+ statement_params: Optional[dict[str, Any]], custom_tags: Mapping[str, Any]
222
+ ) -> dict[str, Any]:
236
223
  """
237
224
  Add custom_tags to existing statement_params. Overwrite keys in custom_tags dict that already exist.
238
225
  If existing statement_params are not provided, do nothing as the information cannot be effectively tracked.
@@ -246,7 +233,7 @@ def add_statement_params_custom_tags(
246
233
  """
247
234
  if not statement_params:
248
235
  return {}
249
- existing_custom_tags: Dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
236
+ existing_custom_tags: dict[str, Any] = statement_params.pop(TelemetryField.KEY_CUSTOM_TAGS.value, {})
250
237
  existing_custom_tags.update(custom_tags)
251
238
  # NOTE: This can be done with | operator after upgrade from py3.8
252
239
  return {
@@ -289,17 +276,17 @@ def get_function_usage_statement_params(
289
276
  *,
290
277
  function_category: str = TelemetryField.FUNC_CAT_USAGE.value,
291
278
  function_name: Optional[str] = None,
292
- function_parameters: Optional[Dict[str, Any]] = None,
279
+ function_parameters: Optional[dict[str, Any]] = None,
293
280
  api_calls: Optional[
294
- List[
281
+ list[
295
282
  Union[
296
- Dict[str, Union[Callable[..., Any], str]],
283
+ dict[str, Union[Callable[..., Any], str]],
297
284
  Union[Callable[..., Any], str],
298
285
  ]
299
286
  ]
300
287
  ] = None,
301
- custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
302
- ) -> Dict[str, Any]:
288
+ custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
289
+ ) -> dict[str, Any]:
303
290
  """
304
291
  Get function usage statement parameters.
305
292
 
@@ -321,12 +308,12 @@ def get_function_usage_statement_params(
321
308
  >>> df.collect(statement_params=statement_params)
322
309
  """
323
310
  telemetry_type = f"{env.SOURCE.lower()}_{TelemetryField.TYPE_FUNCTION_USAGE.value}"
324
- statement_params: Dict[str, Any] = {
311
+ statement_params: dict[str, Any] = {
325
312
  connector_telemetry.TelemetryField.KEY_SOURCE.value: env.SOURCE,
326
313
  TelemetryField.KEY_PROJECT.value: project,
327
314
  TelemetryField.KEY_SUBPROJECT.value: subproject,
328
315
  TelemetryField.KEY_OS.value: env.OS,
329
- TelemetryField.KEY_VERSION.value: env.VERSION,
316
+ TelemetryField.KEY_VERSION.value: snowml_version.VERSION,
330
317
  TelemetryField.KEY_PYTHON_VERSION.value: env.PYTHON_VERSION,
331
318
  connector_telemetry.TelemetryField.KEY_TYPE.value: telemetry_type,
332
319
  TelemetryField.KEY_CATEGORY.value: function_category,
@@ -339,7 +326,7 @@ def get_function_usage_statement_params(
339
326
  if api_calls:
340
327
  statement_params[TelemetryField.KEY_API_CALLS.value] = []
341
328
  for api_call in api_calls:
342
- if isinstance(api_call, Dict):
329
+ if isinstance(api_call, dict):
343
330
  telemetry_api_call = api_call.copy()
344
331
  # convert Callable to str
345
332
  for field, api in api_call.items():
@@ -388,7 +375,7 @@ def send_custom_usage(
388
375
  *,
389
376
  telemetry_type: str,
390
377
  subproject: Optional[str] = None,
391
- data: Optional[Dict[str, Any]] = None,
378
+ data: Optional[dict[str, Any]] = None,
392
379
  **kwargs: Any,
393
380
  ) -> None:
394
381
  active_session = next(iter(session._get_active_sessions()))
@@ -409,17 +396,17 @@ def send_api_usage_telemetry(
409
396
  api_calls_extractor: Optional[
410
397
  Callable[
411
398
  ...,
412
- List[
399
+ list[
413
400
  Union[
414
- Dict[str, Union[Callable[..., Any], str]],
401
+ dict[str, Union[Callable[..., Any], str]],
415
402
  Union[Callable[..., Any], str],
416
403
  ]
417
404
  ],
418
405
  ]
419
406
  ] = None,
420
- sfqids_extractor: Optional[Callable[..., List[str]]] = None,
407
+ sfqids_extractor: Optional[Callable[..., list[str]]] = None,
421
408
  subproject_extractor: Optional[Callable[[Any], str]] = None,
422
- custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
409
+ custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
423
410
  ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, _ReturnValue]]:
424
411
  """
425
412
  Decorator that sends API usage telemetry and adds function usage statement parameters to the dataframe returned by
@@ -454,7 +441,7 @@ def send_api_usage_telemetry(
454
441
  def wrap(*args: Any, **kwargs: Any) -> _ReturnValue:
455
442
  params = _get_func_params(func, func_params_to_log, args, kwargs) if func_params_to_log else None
456
443
 
457
- api_calls: List[Union[Dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
444
+ api_calls: list[Union[dict[str, Union[Callable[..., Any], str]], Callable[..., Any], str]] = []
458
445
  if api_calls_extractor:
459
446
  extracted_api_calls = api_calls_extractor(args[0])
460
447
  for api_call in extracted_api_calls:
@@ -484,7 +471,7 @@ def send_api_usage_telemetry(
484
471
  custom_tags=custom_tags,
485
472
  )
486
473
 
487
- def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: Dict[str, Any]) -> _ReturnValue:
474
+ def update_stmt_params_if_snowpark_df(obj: _ReturnValue, statement_params: dict[str, Any]) -> _ReturnValue:
488
475
  """
489
476
  Update SnowML function usage statement parameters to the object if it is a Snowpark DataFrame.
490
477
  Used to track APIs returning a Snowpark DataFrame.
@@ -614,7 +601,7 @@ def _get_full_func_name(func: Callable[..., Any]) -> str:
614
601
 
615
602
  def _get_func_params(
616
603
  func: Callable[..., Any], func_params_to_log: Optional[Iterable[str]], args: Any, kwargs: Any
617
- ) -> Dict[str, Any]:
604
+ ) -> dict[str, Any]:
618
605
  """
619
606
  Get function parameters.
620
607
 
@@ -639,7 +626,7 @@ def _get_func_params(
639
626
  return params
640
627
 
641
628
 
642
- def _extract_arg_value(field: str, func_spec: inspect.FullArgSpec, args: Any, kwargs: Any) -> Tuple[bool, Any]:
629
+ def _extract_arg_value(field: str, func_spec: inspect.FullArgSpec, args: Any, kwargs: Any) -> tuple[bool, Any]:
643
630
  """
644
631
  Function to extract a specified argument value.
645
632
 
@@ -702,11 +689,11 @@ class _SourceTelemetryClient:
702
689
  self.source: str = env.SOURCE
703
690
  self.project: Optional[str] = project
704
691
  self.subproject: Optional[str] = subproject
705
- self.version = env.VERSION
692
+ self.version = snowml_version.VERSION
706
693
  self.python_version: str = env.PYTHON_VERSION
707
694
  self.os: str = env.OS
708
695
 
709
- def _send(self, msg: Dict[str, Any], timestamp: Optional[int] = None) -> None:
696
+ def _send(self, msg: dict[str, Any], timestamp: Optional[int] = None) -> None:
710
697
  """
711
698
  Add telemetry data to a batch in connector client.
712
699
 
@@ -720,7 +707,7 @@ class _SourceTelemetryClient:
720
707
  telemetry_data = connector_telemetry.TelemetryData(message=msg, timestamp=timestamp)
721
708
  self._telemetry.try_add_log_to_batch(telemetry_data)
722
709
 
723
- def _create_basic_telemetry_data(self, telemetry_type: str) -> Dict[str, Any]:
710
+ def _create_basic_telemetry_data(self, telemetry_type: str) -> dict[str, Any]:
724
711
  message = {
725
712
  connector_telemetry.TelemetryField.KEY_SOURCE.value: self.source,
726
713
  TelemetryField.KEY_PROJECT.value: self.project,
@@ -738,10 +725,10 @@ class _SourceTelemetryClient:
738
725
  func_name: str,
739
726
  function_category: str,
740
727
  duration: float,
741
- func_params: Optional[Dict[str, Any]] = None,
742
- api_calls: Optional[List[Dict[str, Any]]] = None,
743
- sfqids: Optional[List[Any]] = None,
744
- custom_tags: Optional[Dict[str, Union[bool, int, str, float]]] = None,
728
+ func_params: Optional[dict[str, Any]] = None,
729
+ api_calls: Optional[list[dict[str, Any]]] = None,
730
+ sfqids: Optional[list[Any]] = None,
731
+ custom_tags: Optional[dict[str, Union[bool, int, str, float]]] = None,
745
732
  error: Optional[str] = None,
746
733
  error_code: Optional[str] = None,
747
734
  stack_trace: Optional[str] = None,
@@ -761,7 +748,7 @@ class _SourceTelemetryClient:
761
748
  error_code: Error code.
762
749
  stack_trace: Error stack trace.
763
750
  """
764
- data: Dict[str, Any] = {
751
+ data: dict[str, Any] = {
765
752
  TelemetryField.KEY_FUNC_NAME.value: func_name,
766
753
  TelemetryField.KEY_CATEGORY.value: function_category,
767
754
  }
@@ -775,7 +762,7 @@ class _SourceTelemetryClient:
775
762
  data[TelemetryField.KEY_CUSTOM_TAGS.value] = custom_tags
776
763
 
777
764
  telemetry_type = f"{self.source.lower()}_{TelemetryField.TYPE_FUNCTION_USAGE.value}"
778
- message: Dict[str, Any] = {
765
+ message: dict[str, Any] = {
779
766
  **self._create_basic_telemetry_data(telemetry_type),
780
767
  TelemetryField.KEY_DATA.value: data,
781
768
  TelemetryField.KEY_DURATION.value: duration,
@@ -795,7 +782,7 @@ class _SourceTelemetryClient:
795
782
  self._telemetry.send_batch()
796
783
 
797
784
 
798
- def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params: Dict[str, Any]) -> Dict[str, Any]:
785
+ def get_sproc_statement_params_kwargs(sproc: Callable[..., Any], statement_params: dict[str, Any]) -> dict[str, Any]:
799
786
  """
800
787
  Get statement_params keyword argument for sproc call.
801
788
 
@@ -11,7 +11,7 @@ T = TypeVar("T")
11
11
  class LazyType(Generic[T]):
12
12
  """Utility type to help defer need of importing."""
13
13
 
14
- def __init__(self, klass: Union[str, Type[T]]) -> None:
14
+ def __init__(self, klass: Union[str, type[T]]) -> None:
15
15
  self.qualname = ""
16
16
  if isinstance(klass, str):
17
17
  parts = klass.rsplit(".", 1)
@@ -30,7 +30,7 @@ class LazyType(Generic[T]):
30
30
  return self.isinstance(obj)
31
31
 
32
32
  @classmethod
33
- def from_type(cls, typ_: Union["LazyType[T]", Type[T]]) -> "LazyType[T]":
33
+ def from_type(cls, typ_: Union["LazyType[T]", type[T]]) -> "LazyType[T]":
34
34
  if isinstance(typ_, LazyType):
35
35
  return typ_
36
36
  return cls(typ_)
@@ -48,7 +48,7 @@ class LazyType(Generic[T]):
48
48
  def __repr__(self) -> str:
49
49
  return f'LazyType("{self.module}", "{self.qualname}")'
50
50
 
51
- def get_class(self) -> Type[T]:
51
+ def get_class(self) -> type[T]:
52
52
  if self._runtime_class is None:
53
53
  try:
54
54
  m = importlib.import_module(self.module)
@@ -1,5 +1,5 @@
1
1
  from enum import Enum
2
- from typing import Any, Dict, Optional
2
+ from typing import Any, Optional
3
3
 
4
4
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
5
5
  from snowflake.snowpark import session
@@ -19,7 +19,7 @@ def db_object_exists(
19
19
  *,
20
20
  database_name: Optional[sql_identifier.SqlIdentifier] = None,
21
21
  schema_name: Optional[sql_identifier.SqlIdentifier] = None,
22
- statement_params: Optional[Dict[str, Any]] = None,
22
+ statement_params: Optional[dict[str, Any]] = None,
23
23
  ) -> bool:
24
24
  """Check if object exists in database.
25
25
 
@@ -1,5 +1,5 @@
1
1
  import re
2
- from typing import Any, List, Optional, Tuple, Union, overload
2
+ from typing import Any, Optional, Union, overload
3
3
 
4
4
  from snowflake.snowpark._internal.analyzer import analyzer_utils
5
5
 
@@ -112,7 +112,7 @@ def get_inferred_name(name: str) -> str:
112
112
  return escaped_id
113
113
 
114
114
 
115
- def concat_names(names: List[str]) -> str:
115
+ def concat_names(names: list[str]) -> str:
116
116
  """Concatenates `names` to form one valid id.
117
117
 
118
118
 
@@ -142,7 +142,7 @@ def rename_to_valid_snowflake_identifier(name: str) -> str:
142
142
 
143
143
  def parse_schema_level_object_identifier(
144
144
  object_name: str,
145
- ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
145
+ ) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any]]:
146
146
  """Parse a string which starts with schema level object.
147
147
 
148
148
  Args:
@@ -172,7 +172,7 @@ def parse_schema_level_object_identifier(
172
172
 
173
173
  def parse_snowflake_stage_path(
174
174
  path: str,
175
- ) -> Tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
175
+ ) -> tuple[Union[str, Any], Union[str, Any], Union[str, Any], Union[str, Any]]:
176
176
  """Parse a string which represents a snowflake stage path.
177
177
 
178
178
  Args:
@@ -260,11 +260,11 @@ def get_unescaped_names(ids: str) -> str:
260
260
 
261
261
 
262
262
  @overload
263
- def get_unescaped_names(ids: List[str]) -> List[str]:
263
+ def get_unescaped_names(ids: list[str]) -> list[str]:
264
264
  ...
265
265
 
266
266
 
267
- def get_unescaped_names(ids: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
267
+ def get_unescaped_names(ids: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
268
268
  """Given a user provided identifier(s), this method will compute the equivalent column name identifier(s) in the
269
269
  response pandas dataframe(i.e., in the response of snowpark_df.to_pandas()) using the rules defined here
270
270
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
@@ -308,11 +308,11 @@ def get_inferred_names(names: str) -> str:
308
308
 
309
309
 
310
310
  @overload
311
- def get_inferred_names(names: List[str]) -> List[str]:
311
+ def get_inferred_names(names: list[str]) -> list[str]:
312
312
  ...
313
313
 
314
314
 
315
- def get_inferred_names(names: Optional[Union[str, List[str]]]) -> Optional[Union[str, List[str]]]:
315
+ def get_inferred_names(names: Optional[Union[str, list[str]]]) -> Optional[Union[str, list[str]]]:
316
316
  """Given a user provided *string(s)*, this method will compute the equivalent column name identifier(s)
317
317
  in case of column name contains special characters, and maintains case-sensitivity
318
318
  https://docs.snowflake.com/en/sql-reference/identifiers-syntax.
@@ -1,5 +1,5 @@
1
1
  import importlib
2
- from typing import Any, Tuple
2
+ from typing import Any
3
3
 
4
4
 
5
5
  class MissingOptionalDependency:
@@ -46,7 +46,7 @@ def import_with_fallbacks(*targets: str) -> Any:
46
46
  raise ImportError(f"None of the requested targets could be imported. Requested: {', '.join(targets)}")
47
47
 
48
48
 
49
- def import_or_get_dummy(target: str) -> Tuple[Any, bool]:
49
+ def import_or_get_dummy(target: str) -> tuple[Any, bool]:
50
50
  """Try to import the the given target or return a dummy object.
51
51
 
52
52
  If the import target (package/module/symbol) is available, the target will be returned. If it is not available,