snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
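Most of the per-file hunks that follow make one mechanical change: the deprecated `typing.Dict`, `List`, `Tuple`, `Type`, and `Set` aliases are replaced by the builtin generics that became subscriptable in Python 3.9 (PEP 585). A minimal before/after sketch (function names hypothetical):

```python
# Before (1.8.2 style): container types imported from typing.
from typing import Dict, List, Optional, Tuple


def fetch_urls_old(files: List[str]) -> Dict[str, Tuple[str, str]]:
    ...


# After (1.8.3 style): builtins are subscripted directly; typing keeps only
# what has no builtin equivalent here (e.g. Optional, Union, Callable).
from typing import Optional


def fetch_urls_new(files: Optional[list[str]] = None) -> dict[str, tuple[str, str]]:
    ...
```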
snowflake/ml/fileset/sfcfs.py (+13 -13)

```diff
@@ -1,7 +1,7 @@
 import collections
 import logging
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
+from typing import Any, Callable, Optional, Union, cast
 
 import fsspec
 
@@ -100,7 +100,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
             raise ValueError("Either sf_connection or snowpark_session has to be non-empty!")
         self._conn = self._session._conn._conn  # Telemetry wrappers expect connection under `conn_attr_name="_conn"`
         self._kwargs = kwargs
-        self._stage_fs_set: Dict[Tuple[str, str, str], stage_fs.SFStageFileSystem] = {}
+        self._stage_fs_set: dict[tuple[str, str, str], stage_fs.SFStageFileSystem] = {}
 
         super().__init__(**kwargs)
 
@@ -133,7 +133,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         assert isinstance(session, snowpark.Session)
         return session
 
-    def __reduce__(self) -> Tuple[Callable[[], Type["SFFileSystem"]], Tuple[()], Dict[str, Any]]:
+    def __reduce__(self) -> tuple[Callable[[], type["SFFileSystem"]], tuple[()], dict[str, Any]]:
         """Returns a state dictionary for use in serialization.
 
         Returns:
@@ -145,7 +145,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
 
         return partial(self.__class__, **{_RECREATE_FROM_SERIALIZED: True}), (), state_dictionary
 
-    def __setstate__(self, state_dict: Dict[str, Any]) -> None:
+    def __setstate__(self, state_dict: dict[str, Any]) -> None:
         """Sets the dictionary state at deserialization time, and rebuilds a snowflake connection.
 
         Args:
@@ -191,7 +191,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         func_params_to_log=["detail"],
         conn_attr_name="_conn",
     )
-    def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[List[str], List[Dict[str, Any]]]:
+    def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[list[str], list[dict[str, Any]]]:
         """Override fsspec `ls` method. List single "directory" with or without details.
 
         Args:
@@ -214,14 +214,14 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         file_path = self._parse_file_path(path)
         stage_fs = self._get_stage_fs(file_path)
         stage_path_list = stage_fs.ls(file_path.filepath, detail=True, **kwargs)
-        stage_path_list = cast(List[Dict[str, Any]], stage_path_list)
+        stage_path_list = cast(list[dict[str, Any]], stage_path_list)
         return self._decorate_ls_res(stage_fs, stage_path_list, detail)
 
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
         conn_attr_name="_conn",
     )
-    def optimize_read(self, files: Optional[List[str]] = None) -> None:
+    def optimize_read(self, files: Optional[list[str]] = None) -> None:
         """Prefetch and cache the presigned urls for all the given files to speed up the file opening.
 
         All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
@@ -232,8 +232,8 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         """
         if not files:
             return
-        stage_fs_dict: Dict[str, stage_fs.SFStageFileSystem] = {}
-        stage_file_paths: Dict[str, List[str]] = collections.defaultdict(list)
+        stage_fs_dict: dict[str, stage_fs.SFStageFileSystem] = {}
+        stage_file_paths: dict[str, list[str]] = collections.defaultdict(list)
         for file in files:
             path_info = self._parse_file_path(file)
             fs = self._get_stage_fs(path_info)
@@ -271,11 +271,11 @@ class SFFileSystem(fsspec.AbstractFileSystem):
         project=_PROJECT,
         conn_attr_name="_conn",
     )
-    def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
+    def info(self, path: str, **kwargs: Any) -> dict[str, Any]:
         """Override fsspec `info` method. Give details of entry at path."""
         file_path = self._parse_file_path(path)
         stage_fs = self._get_stage_fs(file_path)
-        res: Dict[str, Any] = stage_fs.info(file_path.filepath, **kwargs)
+        res: dict[str, Any] = stage_fs.info(file_path.filepath, **kwargs)
         if res:
             res["name"] = self._stage_path_to_absolute_path(stage_fs, res["name"])
         return res
@@ -283,9 +283,9 @@ class SFFileSystem(fsspec.AbstractFileSystem):
     def _decorate_ls_res(
         self,
         stage_fs: stage_fs.SFStageFileSystem,
-        stage_path_list: List[Dict[str, Any]],
+        stage_path_list: list[dict[str, Any]],
         detail: bool,
-    ) -> Union[List[str], List[Dict[str, Any]]]:
+    ) -> Union[list[str], list[dict[str, Any]]]:
         """Add the stage location as the prefix of file names returned by ls() of stagefs"""
         for path in stage_path_list:
             path["name"] = self._stage_path_to_absolute_path(stage_fs, path["name"])
```
snowflake/ml/fileset/stage_fs.py (+15 -15)

```diff
@@ -2,7 +2,7 @@ import inspect
 import logging
 import time
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
+from typing import Any, Optional, Union, cast
 
 import fsspec
 from fsspec.implementations import http as httpfs
@@ -44,7 +44,7 @@ class _PresignedUrl:
         return not self.expire_at or time.time() > self.expire_at - headroom_sec
 
 
-def _get_httpfs_kwargs(**kwargs: Any) -> Dict[str, Any]:
+def _get_httpfs_kwargs(**kwargs: Any) -> dict[str, Any]:
     """Extract kwargs that are meaningful to HTTPFileSystem."""
     httpfs_related_keys = [
         "block_size",
@@ -124,7 +124,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
         self._db = db
         self._schema = schema
         self._stage = stage
-        self._url_cache: Dict[str, _PresignedUrl] = {}
+        self._url_cache: dict[str, _PresignedUrl] = {}
 
         httpfs_kwargs = _get_httpfs_kwargs(**kwargs)
         self._fs = httpfs.HTTPFileSystem(**httpfs_kwargs)
@@ -145,7 +145,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
         project=_PROJECT,
         func_params_to_log=["detail"],
     )
-    def ls(self, path: str, detail: bool = False) -> Union[List[str], List[Dict[str, Any]]]:
+    def ls(self, path: str, detail: bool = False) -> Union[list[str], list[dict[str, Any]]]:
         """Override fsspec `ls` method. List single "directory" with or without details.
 
         Args:
@@ -169,7 +169,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
             loc = self.stage_name
             path = path.lstrip("/")
             async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
-            objects: List[snowpark.Row] = _resolve_async_job(async_job)
+            objects: list[snowpark.Row] = _resolve_async_job(async_job)
         except snowpark_exceptions.SnowparkSQLException as e:
             if e.sql_error_code == fileset_errors.ERRNO_DOMAIN_NOT_EXIST:
                 raise snowml_exceptions.SnowflakeMLException(
@@ -192,7 +192,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
     @telemetry.send_api_usage_telemetry(
         project=_PROJECT,
     )
-    def optimize_read(self, files: Optional[List[str]] = None) -> None:
+    def optimize_read(self, files: Optional[list[str]] = None) -> None:
         """Prefetch and cache the presigned urls for all the given files to speed up the read performance.
 
         All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
@@ -271,7 +271,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
                 original_exception=fileset_errors.StageFileNotFoundError(f"Stage file {path} doesn't exist."),
             )
 
-    def _open_with_snowpark(self, path: str, **kwargs: Dict[str, Any]) -> fsspec.spec.AbstractBufferedFile:
+    def _open_with_snowpark(self, path: str, **kwargs: dict[str, Any]) -> fsspec.spec.AbstractBufferedFile:
         """Open the a file for reading using snowflake.snowpark.file_operation
 
         Args:
@@ -299,7 +299,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
                 original_exception=e,
             )
 
-    def _parse_list_result(self, list_result: List[snowpark.Row], search_path: str) -> List[Dict[str, Any]]:
+    def _parse_list_result(self, list_result: list[snowpark.Row], search_path: str) -> list[dict[str, Any]]:
         """Convert the result from LIST query to the expected format of fsspec ls() method.
 
         Note that Snowflake LIST query has different behavior with ls(). LIST query will return all the stage files
@@ -318,7 +318,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
         Returns:
             A list of dict, where each dict contains key-value pairs as the properties of a file.
         """
-        files: Dict[str, Dict[str, Any]] = {}
+        files: dict[str, dict[str, Any]] = {}
         search_path = search_path.strip("/")
         for row in list_result:
             name, size, md5, last_modified = row["name"], row["size"], row["md5"], row["last_modified"]
@@ -360,7 +360,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
 
     def _add_file_info_helper(
         self,
-        files: Dict[str, Dict[str, Any]],
+        files: dict[str, dict[str, Any]],
         object_path: str,
         file_size: int,
         file_type: str,
@@ -379,12 +379,12 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
         )
 
     def _fetch_presigned_urls(
-        self, files: List[str], url_lifetime: float = _PRESIGNED_URL_LIFETIME_SEC
-    ) -> List[Tuple[str, str]]:
+        self, files: list[str], url_lifetime: float = _PRESIGNED_URL_LIFETIME_SEC
+    ) -> list[tuple[str, str]]:
         """Fetch presigned urls for the given files."""
         file_df = self._session.create_dataframe(files).to_df("name")
         try:
-            presigned_urls: List[Tuple[str, str]] = file_df.select_expr(
+            presigned_urls: list[tuple[str, str]] = file_df.select_expr(
                 f"name, get_presigned_url('{self.stage_name}', name, {url_lifetime}) as url"
             ).collect(
                 statement_params=telemetry.get_function_usage_statement_params(
@@ -418,10 +418,10 @@ def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code:
 
 
 @snowflake_plan.SnowflakePlan.Decorator.wrap_exception  # type: ignore[misc]
-def _resolve_async_job(async_job: snowpark.AsyncJob) -> List[snowpark.Row]:
+def _resolve_async_job(async_job: snowpark.AsyncJob) -> list[snowpark.Row]:
     # Make sure Snowpark exceptions are properly caught and converted by wrap_exception wrapper
     try:
-        query_result = cast(List[snowpark.Row], async_job.result("row"))
+        query_result = cast(list[snowpark.Row], async_job.result("row"))
         return query_result
     except snowpark_errors.DatabaseError as e:
         # HACK: Snowpark surfaces a generic exception if query doesn't complete immediately
```
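stage_fs.py revolves around a presigned-URL cache: `optimize_read` prefetches URLs, `_PresignedUrl.is_expiring` treats an entry as stale slightly before its real deadline, and `_fetch_presigned_urls` refreshes entries via `get_presigned_url`. A self-contained sketch of that caching pattern, assuming illustrative names and headroom values:

```python
# Minimal sketch of the URL-cache pattern: a cached presigned URL is treated
# as stale a little before its real expiry (headroom) so in-flight reads
# don't race the deadline. PresignedUrl stands in for stage_fs._PresignedUrl.
import time
from dataclasses import dataclass


@dataclass
class PresignedUrl:
    url: str
    expire_at: float  # absolute epoch seconds; 0 means "unknown, treat as stale"

    def is_expiring(self, headroom_sec: float = 30.0) -> bool:
        return not self.expire_at or time.time() > self.expire_at - headroom_sec


cache: dict[str, PresignedUrl] = {}


def get_url(path: str, lifetime_sec: float = 3600.0) -> str:
    entry = cache.get(path)
    if entry is None or entry.is_expiring():
        # In stage_fs this would be a get_presigned_url query against the stage.
        entry = cache[path] = PresignedUrl(f"https://example.invalid/{path}", time.time() + lifetime_sec)
    return entry.url
```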
snowflake/ml/jobs/_utils/interop_utils.py (+10 -10)

```diff
@@ -10,7 +10,7 @@ import traceback
 from collections import namedtuple
 from dataclasses import dataclass
 from types import TracebackType
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, cast
+from typing import Any, Callable, Optional, Union, cast
 
 from snowflake import snowpark
 from snowflake.snowpark import exceptions as sp_exceptions
@@ -33,7 +33,7 @@ class ExecutionResult:
     def success(self) -> bool:
         return self.exception is None
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Return the serializable dictionary."""
         if isinstance(self.exception, BaseException):
             exc_type = type(self.exception)
@@ -50,7 +50,7 @@
         }
 
     @classmethod
-    def from_dict(cls, result_dict: Dict[str, Any]) -> "ExecutionResult":
+    def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult":
         if not isinstance(result_dict.get("success"), bool):
             raise ValueError("Invalid result dictionary")
 
@@ -242,11 +242,11 @@ def _install_sys_excepthook() -> None:
     original_excepthook = sys.excepthook
 
     def custom_excepthook(
-        exc_type: Type[BaseException],
+        exc_type: type[BaseException],
         exc_value: BaseException,
         exc_tb: Optional[TracebackType],
         *,
-        seen_exc_ids: Optional[Set[int]] = None,
+        seen_exc_ids: Optional[set[int]] = None,
     ) -> None:
         if seen_exc_ids is None:
             seen_exc_ids = set()
@@ -331,7 +331,7 @@ def _install_ipython_hook() -> bool:
     except ImportError:
         return False
 
-    def parse_traceback_str(traceback_str: str) -> List[Tuple[str, int, str, str]]:
+    def parse_traceback_str(traceback_str: str) -> list[tuple[str, int, str, str]]:
         return [
             (m.group("filename"), int(m.group("lineno")), m.group("name"), m.group("line"))
             for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, traceback_str)
@@ -342,13 +342,13 @@ def _install_ipython_hook() -> bool:
 
     def custom_format_exception_as_a_whole(
         self: VerboseTB,
-        etype: Type[BaseException],
+        etype: type[BaseException],
         evalue: Optional[BaseException],
         etb: Optional[TracebackType],
         number_of_lines_of_context: int,
         tb_offset: Optional[int],
         **kwargs: Any,
-    ) -> List[List[str]]:
+    ) -> list[list[str]]:
         if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
             # Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole
             head = self.prepare_header(remote_err.exc_type, long_version=False).replace(
@@ -388,7 +388,7 @@ def _install_ipython_hook() -> bool:
         etb: Optional[TracebackType],
         tb_offset: Optional[int] = None,
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
             tb_list = [
                 (m.group("filename"), m.group("lineno"), m.group("name"), m.group("line"))
@@ -400,7 +400,7 @@ def _install_ipython_hook() -> bool:
                 "(most recent call last)",
                 "(from remote execution)",
             )
-            return cast(List[str], out_list)
+            return cast(list[str], out_list)
         return original_structured_traceback(  # type: ignore[no-any-return]
             self, etype, evalue, etb, tb_offset, **kwargs
         )
```
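The `_install_sys_excepthook` hunk shows the standard hook-chaining idiom: stash the current `sys.excepthook`, install a wrapper that adds behavior, and delegate back so default reporting still runs. A minimal sketch of the same idiom (the remote-error handling from interop_utils is elided):

```python
# Sketch: wrap sys.excepthook while preserving the original hook's behavior.
import sys
from types import TracebackType
from typing import Optional

original_excepthook = sys.excepthook


def custom_excepthook(
    exc_type: type[BaseException],
    exc_value: BaseException,
    exc_tb: Optional[TracebackType],
) -> None:
    # Extra handling would go here (e.g. surfacing remote error details).
    original_excepthook(exc_type, exc_value, exc_tb)


sys.excepthook = custom_excepthook
```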
snowflake/ml/jobs/_utils/payload_utils.py (+6 -16)

```diff
@@ -6,17 +6,7 @@ import pickle
 import sys
 import textwrap
 from pathlib import Path, PurePath
-from typing import (
-    Any,
-    Callable,
-    List,
-    Optional,
-    Type,
-    Union,
-    cast,
-    get_args,
-    get_origin,
-)
+from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
 
 import cloudpickle as cp
 
@@ -277,7 +267,7 @@ class JobPayload:
         source: Union[str, Path, Callable[..., Any]],
         entrypoint: Optional[Union[str, Path]] = None,
         *,
-        pip_requirements: Optional[List[str]] = None,
+        pip_requirements: Optional[list[str]] = None,
     ) -> None:
         self.source = Path(source) if isinstance(source, str) else source
         self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
@@ -364,7 +354,7 @@
             auto_compress=False,
         )
 
-        python_entrypoint: List[Union[str, PurePath]] = [
+        python_entrypoint: list[Union[str, PurePath]] = [
             PurePath("mljob_launcher.py"),
             entrypoint.file_path.relative_to(source),
         ]
@@ -381,7 +371,7 @@
     )
 
 
-def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
+def _get_parameter_type(param: inspect.Parameter) -> Optional[type[object]]:
     # Unwrap Optional type annotations
     param_type = param.annotation
     if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
@@ -390,10 +380,10 @@ def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
     # Return None for empty type annotations
     if param_type == inspect.Parameter.empty:
         return None
-    return cast(Type[object], param_type)
+    return cast(type[object], param_type)
 
 
-def _validate_parameter_type(param_type: Type[object], param_name: str) -> None:
+def _validate_parameter_type(param_type: type[object], param_name: str) -> None:
     # Validate param_type is a supported type
     if param_type not in _SUPPORTED_ARG_TYPES:
         raise ValueError(
```
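`_get_parameter_type` relies on the fact that `Optional[X]` is just `Union[X, None]`; the `get_origin`/`get_args` check in the hunk above detects exactly that shape. A standalone sketch (helper name hypothetical):

```python
# Sketch of the Optional-unwrapping check: an annotation is "optional" when
# its origin is Union, it has exactly two arguments, and one is NoneType.
from typing import Optional, Union, get_args, get_origin


def unwrap_optional(annotation: object) -> object:
    if get_origin(annotation) is Union and len(get_args(annotation)) == 2 and type(None) in get_args(annotation):
        return next(t for t in get_args(annotation) if t is not type(None))
    return annotation


assert unwrap_optional(Optional[int]) is int
assert unwrap_optional(str) is str
```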
snowflake/ml/jobs/_utils/scripts/mljob_launcher.py (+7 -4)

```diff
@@ -1,4 +1,5 @@
 import argparse
+import copy
 import importlib.util
 import json
 import os
@@ -7,7 +8,7 @@ import sys
 import traceback
 import warnings
 from pathlib import Path
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import cloudpickle
 
@@ -27,7 +28,7 @@ except ImportError:
     from dataclasses import dataclass
 
     @dataclass(frozen=True)
-    class ExecutionResult:
+    class ExecutionResult:  # type: ignore[no-redef]
         result: Optional[Any] = None
         exception: Optional[BaseException] = None
 
@@ -35,7 +36,7 @@ except ImportError:
         def success(self) -> bool:
             return self.exception is None
 
-        def to_dict(self) -> Dict[str, Any]:
+        def to_dict(self) -> dict[str, Any]:
             """Return the serializable dictionary."""
             if isinstance(self.exception, BaseException):
                 exc_type = type(self.exception)
@@ -136,7 +137,9 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
         while tb and tb.tb_frame.f_code.co_filename in skip_files:
             # Skip any frames preceding user script execution
             tb = tb.tb_next
-        result_obj = ExecutionResult(exception=e.with_traceback(tb))
+        cleaned_ex = copy.copy(e)  # Need to create a mutable copy of exception to set __traceback__
+        cleaned_ex = cleaned_ex.with_traceback(tb)
+        result_obj = ExecutionResult(exception=cleaned_ex)
         raise
     finally:
         result_dict = result_obj.to_dict()
```
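The launcher fix above addresses a subtle mutation bug: `e.with_traceback(tb)` modifies the exception that is about to be re-raised, so 1.8.3 copies the exception first and attaches the trimmed traceback to the copy stored in the result. A sketch of the pattern (helper name hypothetical):

```python
# Sketch: trim launcher frames from a traceback without mutating the
# original in-flight exception; the shallow copy gets the trimmed traceback.
import copy


def capture_trimmed(e: BaseException, skip_files: set[str]) -> BaseException:
    tb = e.__traceback__
    while tb and tb.tb_frame.f_code.co_filename in skip_files:
        tb = tb.tb_next  # drop frames preceding the user script
    cleaned = copy.copy(e)  # original exception keeps its full traceback
    return cleaned.with_traceback(tb)


try:
    raise ValueError("boom")
except ValueError as err:
    stored = capture_trimmed(err, skip_files=set())
    assert stored is not err and stored.__traceback__ is err.__traceback__
```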
snowflake/ml/jobs/_utils/scripts/signal_workers.py (+8 -8)

```diff
@@ -9,7 +9,7 @@ import logging
 import socket
 import sys
 import time
-from typing import Any, Dict, List, Set
+from typing import Any
 
 import ray
 from constants import (
@@ -33,34 +33,34 @@ class ShutdownSignal:
         self.acknowledged_workers = set()
         logging.info(f"ShutdownSignal actor created on {self.hostname}")
 
-    def request_shutdown(self) -> Dict[str, Any]:
+    def request_shutdown(self) -> dict[str, Any]:
         """Signal workers to shut down"""
         self.shutdown_requested = True
         self.timestamp = time.time()
         logging.info(f"Shutdown requested by head node at {self.timestamp}")
         return {"status": "shutdown_requested", "timestamp": self.timestamp, "host": self.hostname}
 
-    def should_shutdown(self) -> Dict[str, Any]:
+    def should_shutdown(self) -> dict[str, Any]:
         """Check if shutdown has been requested"""
         return {"shutdown": self.shutdown_requested, "timestamp": self.timestamp, "host": self.hostname}
 
-    def ping(self) -> Dict[str, Any]:
+    def ping(self) -> dict[str, Any]:
         """Simple method to test connectivity"""
         return {"status": "alive", "host": self.hostname}
 
-    def acknowledge_shutdown(self, worker_id: str) -> Dict[str, Any]:
+    def acknowledge_shutdown(self, worker_id: str) -> dict[str, Any]:
         """Worker acknowledges it has received the shutdown signal and is terminating"""
         self.acknowledged_workers.add(worker_id)
         logging.info(f"Worker {worker_id} acknowledged shutdown. Total acknowledged: {len(self.acknowledged_workers)}")
 
         return {"status": "acknowledged", "worker_id": worker_id, "acknowledged_count": len(self.acknowledged_workers)}
 
-    def get_acknowledgment_workers(self) -> Set[str]:
+    def get_acknowledgment_workers(self) -> set[str]:
         """Get the set of workers who have acknowledged shutdown"""
         return self.acknowledged_workers
 
 
-def get_worker_node_ids() -> List[str]:
+def get_worker_node_ids() -> list[str]:
     """Get the IDs of all active worker nodes.
 
     Returns:
@@ -127,7 +127,7 @@ def verify_shutdown(shutdown_signal: ActorHandle) -> None:
         logging.debug(f"Shutdown status check: {check}")
 
 
-def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: List[str], wait_time: int) -> None:
+def wait_for_acknowledgments(shutdown_signal: ActorHandle, worker_node_ids: list[str], wait_time: int) -> None:
     """Wait for workers to acknowledge shutdown.
 
     Args:
```
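signal_workers.py coordinates shutdown through a Ray actor: the head node flips a flag, workers poll it and acknowledge before exiting. A stripped-down sketch of that pattern (class and method bodies simplified, not the shipped script):

```python
# Hedged sketch of the shutdown-coordination pattern: a Ray actor holds a
# shutdown flag plus the set of workers that have acknowledged it.
import time
from typing import Any

import ray


@ray.remote
class ShutdownFlag:
    def __init__(self) -> None:
        self.requested = False
        self.acknowledged: set[str] = set()

    def request_shutdown(self) -> dict[str, Any]:
        self.requested = True
        return {"status": "shutdown_requested", "timestamp": time.time()}

    def should_shutdown(self) -> bool:
        return self.requested

    def acknowledge_shutdown(self, worker_id: str) -> int:
        self.acknowledged.add(worker_id)
        return len(self.acknowledged)


ray.init()
flag = ShutdownFlag.remote()
ray.get(flag.request_shutdown.remote())  # head node signals shutdown
assert ray.get(flag.should_shutdown.remote())  # workers poll this
```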
snowflake/ml/jobs/_utils/spec_utils.py (+17 -28)

```diff
@@ -1,7 +1,7 @@
 import logging
 from math import ceil
 from pathlib import PurePath
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
@@ -15,10 +15,7 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.C
     if not rows:
         raise ValueError(f"Compute pool '{compute_pool}' not found")
     instance_family: str = rows[0]["instance_family"]
-
-    # Get the cloud we're using (AWS, Azure, etc)
-    region = snowflake_env.get_regions(session)[snowflake_env.get_current_region_id(session)]
-    cloud = region["cloud"]
+    cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
 
     return (
         constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
@@ -26,22 +23,14 @@
     )
 
 
-def _get_image_spec(session: snowpark.Session, compute_pool: str, image_tag: Optional[str] = None) -> types.ImageSpec:
+def _get_image_spec(session: snowpark.Session, compute_pool: str) -> types.ImageSpec:
     # Retrieve compute pool node resources
     resources = _get_node_resources(session, compute_pool=compute_pool)
 
     # Use MLRuntime image
     image_repo = constants.DEFAULT_IMAGE_REPO
     image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-
-    # Try to pull latest image tag from server side if possible
-    if not image_tag:
-        query_result = session.sql("SHOW PARAMETERS LIKE 'constants.RUNTIME_BASE_IMAGE_TAG' IN ACCOUNT").collect()
-        if query_result:
-            image_tag = query_result[0]["value"]
-
-    if image_tag is None:
-        image_tag = constants.DEFAULT_IMAGE_TAG
+    image_tag = constants.DEFAULT_IMAGE_TAG
 
     # TODO: Should each instance consume the entire pod?
     return types.ImageSpec(
@@ -54,9 +43,9 @@
 
 
 def generate_spec_overrides(
-    environment_vars: Optional[Dict[str, str]] = None,
-    custom_overrides: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
+    environment_vars: Optional[dict[str, str]] = None,
+    custom_overrides: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
     """
     Generate a dictionary of service specification overrides.
 
@@ -68,7 +57,7 @@ def generate_spec_overrides(
         Resulting service specifiation patch dict. Empty if no overrides were supplied.
     """
     # Generate container level overrides
-    container_spec: Dict[str, Any] = {
+    container_spec: dict[str, Any] = {
         "name": constants.DEFAULT_CONTAINER_NAME,
     }
     if environment_vars:
@@ -95,10 +84,10 @@ def generate_service_spec(
     session: snowpark.Session,
     compute_pool: str,
     payload: types.UploadedPayload,
-    args: Optional[List[str]] = None,
+    args: Optional[list[str]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Generate a service specification for a job.
 
@@ -117,11 +106,11 @@
     image_spec = _get_image_spec(session, compute_pool)
 
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-    resource_requests: Dict[str, Union[str, int]] = {
+    resource_requests: dict[str, Union[str, int]] = {
         "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
         "memory": f"{image_spec.resource_limits.memory}Gi",
     }
-    resource_limits: Dict[str, Union[str, int]] = {
+    resource_limits: dict[str, Union[str, int]] = {
         "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
         "memory": f"{image_spec.resource_limits.memory}Gi",
     }
@@ -130,8 +119,8 @@
         resource_limits["nvidia.com/gpu"] = image_spec.resource_limits.gpu
 
     # Add local volumes for ephemeral logs and artifacts
-    volumes: List[Dict[str, str]] = []
-    volume_mounts: List[Dict[str, str]] = []
+    volumes: list[dict[str, str]] = []
+    volume_mounts: list[dict[str, str]] = []
     for volume_name, mount_path in [
         ("system-logs", "/var/log/managedservices/system/mlrs"),
         ("user-logs", "/var/log/managedservices/user/mlrs"),
@@ -302,11 +291,11 @@ def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
 
 
 def _merge_lists_of_dicts(
-    base: List[Dict[str, Any]],
-    patch: List[Dict[str, Any]],
+    base: list[dict[str, Any]],
+    patch: list[dict[str, Any]],
     merge_key: str = "name",
     display_name: str = "",
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
     """
     Attempts to merge lists of dicts by matching on a merge key (default "name").
     - If the merge key is missing, the behavior falls back to overwriting the list.
```
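Per its docstring, `_merge_lists_of_dicts` merges two lists of dicts by matching entries on a merge key (default `"name"`), falling back to overwriting the list when the key is missing. A sketch of that strategy (simplified; the shipped function takes extra parameters and may differ in detail):

```python
# Sketch: merge lists of dicts by key, as used for container/volume specs.
from typing import Any


def merge_lists_of_dicts(
    base: list[dict[str, Any]],
    patch: list[dict[str, Any]],
    merge_key: str = "name",
) -> list[dict[str, Any]]:
    if any(merge_key not in d for d in base + patch):
        return patch  # fall back to overwriting the list
    merged = {d[merge_key]: dict(d) for d in base}
    for d in patch:
        merged.setdefault(d[merge_key], {}).update(d)
    return list(merged.values())


assert merge_lists_of_dicts(
    [{"name": "main", "image": "a"}],
    [{"name": "main", "image": "b"}, {"name": "sidecar"}],
) == [{"name": "main", "image": "b"}, {"name": "sidecar"}]
```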
snowflake/ml/jobs/_utils/types.py (+2 -2)

```diff
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from pathlib import PurePath
-from typing import List, Literal, Optional, Union
+from typing import Literal, Optional, Union
 
 JOB_STATUS = Literal[
     "PENDING",
@@ -21,7 +21,7 @@ class PayloadEntrypoint:
 class UploadedPayload:
     # TODO: Include manifest of payload files for validation
     stage_path: PurePath
-    entrypoint: List[Union[str, PurePath]]
+    entrypoint: list[Union[str, PurePath]]
 
 
 @dataclass(frozen=True)
```
snowflake/ml/jobs/decorators.py (+4 -5)

```diff
@@ -1,6 +1,6 @@
 import copy
 import functools
-from typing import Callable, Dict, List, Optional, TypeVar
+from typing import Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
@@ -15,16 +15,15 @@ _Args = ParamSpec("_Args")
 _ReturnValue = TypeVar("_ReturnValue")
 
 
-@snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
 def remote(
     compute_pool: str,
     *,
     stage_name: str,
-    pip_requirements: Optional[List[str]] = None,
-    external_access_integrations: Optional[List[str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
-    env_vars: Optional[Dict[str, str]] = None,
+    env_vars: Optional[dict[str, str]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
```
  session: Optional[snowpark.Session] = None,