snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  import collections
2
2
  import logging
3
3
  from functools import partial
4
- from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
4
+ from typing import Any, Callable, Optional, Union, cast
5
5
 
6
6
  import fsspec
7
7
 
@@ -100,7 +100,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
100
100
  raise ValueError("Either sf_connection or snowpark_session has to be non-empty!")
101
101
  self._conn = self._session._conn._conn # Telemetry wrappers expect connection under `conn_attr_name="_conn"``
102
102
  self._kwargs = kwargs
103
- self._stage_fs_set: Dict[Tuple[str, str, str], stage_fs.SFStageFileSystem] = {}
103
+ self._stage_fs_set: dict[tuple[str, str, str], stage_fs.SFStageFileSystem] = {}
104
104
 
105
105
  super().__init__(**kwargs)
106
106
 
@@ -133,7 +133,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
133
133
  assert isinstance(session, snowpark.Session)
134
134
  return session
135
135
 
136
- def __reduce__(self) -> Tuple[Callable[[], Type["SFFileSystem"]], Tuple[()], Dict[str, Any]]:
136
+ def __reduce__(self) -> tuple[Callable[[], type["SFFileSystem"]], tuple[()], dict[str, Any]]:
137
137
  """Returns a state dictionary for use in serialization.
138
138
 
139
139
  Returns:
@@ -145,7 +145,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
145
145
 
146
146
  return partial(self.__class__, **{_RECREATE_FROM_SERIALIZED: True}), (), state_dictionary
147
147
 
148
- def __setstate__(self, state_dict: Dict[str, Any]) -> None:
148
+ def __setstate__(self, state_dict: dict[str, Any]) -> None:
149
149
  """Sets the dictionary state at deserialization time, and rebuilds a snowflake connection.
150
150
 
151
151
  Args:
@@ -191,7 +191,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
191
191
  func_params_to_log=["detail"],
192
192
  conn_attr_name="_conn",
193
193
  )
194
- def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[List[str], List[Dict[str, Any]]]:
194
+ def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[list[str], list[dict[str, Any]]]:
195
195
  """Override fsspec `ls` method. List single "directory" with or without details.
196
196
 
197
197
  Args:
@@ -214,14 +214,14 @@ class SFFileSystem(fsspec.AbstractFileSystem):
214
214
  file_path = self._parse_file_path(path)
215
215
  stage_fs = self._get_stage_fs(file_path)
216
216
  stage_path_list = stage_fs.ls(file_path.filepath, detail=True, **kwargs)
217
- stage_path_list = cast(List[Dict[str, Any]], stage_path_list)
217
+ stage_path_list = cast(list[dict[str, Any]], stage_path_list)
218
218
  return self._decorate_ls_res(stage_fs, stage_path_list, detail)
219
219
 
220
220
  @telemetry.send_api_usage_telemetry(
221
221
  project=_PROJECT,
222
222
  conn_attr_name="_conn",
223
223
  )
224
- def optimize_read(self, files: Optional[List[str]] = None) -> None:
224
+ def optimize_read(self, files: Optional[list[str]] = None) -> None:
225
225
  """Prefetch and cache the presigned urls for all the given files to speed up the file opening.
226
226
 
227
227
  All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
@@ -232,8 +232,8 @@ class SFFileSystem(fsspec.AbstractFileSystem):
232
232
  """
233
233
  if not files:
234
234
  return
235
- stage_fs_dict: Dict[str, stage_fs.SFStageFileSystem] = {}
236
- stage_file_paths: Dict[str, List[str]] = collections.defaultdict(list)
235
+ stage_fs_dict: dict[str, stage_fs.SFStageFileSystem] = {}
236
+ stage_file_paths: dict[str, list[str]] = collections.defaultdict(list)
237
237
  for file in files:
238
238
  path_info = self._parse_file_path(file)
239
239
  fs = self._get_stage_fs(path_info)
@@ -271,11 +271,11 @@ class SFFileSystem(fsspec.AbstractFileSystem):
271
271
  project=_PROJECT,
272
272
  conn_attr_name="_conn",
273
273
  )
274
- def info(self, path: str, **kwargs: Any) -> Dict[str, Any]:
274
+ def info(self, path: str, **kwargs: Any) -> dict[str, Any]:
275
275
  """Override fsspec `info` method. Give details of entry at path."""
276
276
  file_path = self._parse_file_path(path)
277
277
  stage_fs = self._get_stage_fs(file_path)
278
- res: Dict[str, Any] = stage_fs.info(file_path.filepath, **kwargs)
278
+ res: dict[str, Any] = stage_fs.info(file_path.filepath, **kwargs)
279
279
  if res:
280
280
  res["name"] = self._stage_path_to_absolute_path(stage_fs, res["name"])
281
281
  return res
@@ -283,9 +283,9 @@ class SFFileSystem(fsspec.AbstractFileSystem):
283
283
  def _decorate_ls_res(
284
284
  self,
285
285
  stage_fs: stage_fs.SFStageFileSystem,
286
- stage_path_list: List[Dict[str, Any]],
286
+ stage_path_list: list[dict[str, Any]],
287
287
  detail: bool,
288
- ) -> Union[List[str], List[Dict[str, Any]]]:
288
+ ) -> Union[list[str], list[dict[str, Any]]]:
289
289
  """Add the stage location as the prefix of file names returned by ls() of stagefs"""
290
290
  for path in stage_path_list:
291
291
  path["name"] = self._stage_path_to_absolute_path(stage_fs, path["name"])
@@ -2,7 +2,7 @@ import inspect
2
2
  import logging
3
3
  import time
4
4
  from dataclasses import dataclass
5
- from typing import Any, Dict, List, Optional, Tuple, Union, cast
5
+ from typing import Any, Optional, Union, cast
6
6
 
7
7
  import fsspec
8
8
  from fsspec.implementations import http as httpfs
@@ -44,7 +44,7 @@ class _PresignedUrl:
44
44
  return not self.expire_at or time.time() > self.expire_at - headroom_sec
45
45
 
46
46
 
47
- def _get_httpfs_kwargs(**kwargs: Any) -> Dict[str, Any]:
47
+ def _get_httpfs_kwargs(**kwargs: Any) -> dict[str, Any]:
48
48
  """Extract kwargs that are meaningful to HTTPFileSystem."""
49
49
  httpfs_related_keys = [
50
50
  "block_size",
@@ -124,7 +124,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
124
124
  self._db = db
125
125
  self._schema = schema
126
126
  self._stage = stage
127
- self._url_cache: Dict[str, _PresignedUrl] = {}
127
+ self._url_cache: dict[str, _PresignedUrl] = {}
128
128
 
129
129
  httpfs_kwargs = _get_httpfs_kwargs(**kwargs)
130
130
  self._fs = httpfs.HTTPFileSystem(**httpfs_kwargs)
@@ -145,7 +145,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
145
145
  project=_PROJECT,
146
146
  func_params_to_log=["detail"],
147
147
  )
148
- def ls(self, path: str, detail: bool = False) -> Union[List[str], List[Dict[str, Any]]]:
148
+ def ls(self, path: str, detail: bool = False) -> Union[list[str], list[dict[str, Any]]]:
149
149
  """Override fsspec `ls` method. List single "directory" with or without details.
150
150
 
151
151
  Args:
@@ -169,7 +169,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
169
169
  loc = self.stage_name
170
170
  path = path.lstrip("/")
171
171
  async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
172
- objects: List[snowpark.Row] = _resolve_async_job(async_job)
172
+ objects: list[snowpark.Row] = _resolve_async_job(async_job)
173
173
  except snowpark_exceptions.SnowparkSQLException as e:
174
174
  if e.sql_error_code == fileset_errors.ERRNO_DOMAIN_NOT_EXIST:
175
175
  raise snowml_exceptions.SnowflakeMLException(
@@ -192,7 +192,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
192
192
  @telemetry.send_api_usage_telemetry(
193
193
  project=_PROJECT,
194
194
  )
195
- def optimize_read(self, files: Optional[List[str]] = None) -> None:
195
+ def optimize_read(self, files: Optional[list[str]] = None) -> None:
196
196
  """Prefetch and cache the presigned urls for all the given files to speed up the read performance.
197
197
 
198
198
  All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
@@ -271,7 +271,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
271
271
  original_exception=fileset_errors.StageFileNotFoundError(f"Stage file {path} doesn't exist."),
272
272
  )
273
273
 
274
- def _open_with_snowpark(self, path: str, **kwargs: Dict[str, Any]) -> fsspec.spec.AbstractBufferedFile:
274
+ def _open_with_snowpark(self, path: str, **kwargs: dict[str, Any]) -> fsspec.spec.AbstractBufferedFile:
275
275
  """Open a file for reading using snowflake.snowpark.file_operation
276
276
 
277
277
  Args:
@@ -299,7 +299,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
299
299
  original_exception=e,
300
300
  )
301
301
 
302
- def _parse_list_result(self, list_result: List[snowpark.Row], search_path: str) -> List[Dict[str, Any]]:
302
+ def _parse_list_result(self, list_result: list[snowpark.Row], search_path: str) -> list[dict[str, Any]]:
303
303
  """Convert the result from LIST query to the expected format of fsspec ls() method.
304
304
 
305
305
  Note that Snowflake LIST query has different behavior with ls(). LIST query will return all the stage files
@@ -318,7 +318,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
318
318
  Returns:
319
319
  A list of dict, where each dict contains key-value pairs as the properties of a file.
320
320
  """
321
- files: Dict[str, Dict[str, Any]] = {}
321
+ files: dict[str, dict[str, Any]] = {}
322
322
  search_path = search_path.strip("/")
323
323
  for row in list_result:
324
324
  name, size, md5, last_modified = row["name"], row["size"], row["md5"], row["last_modified"]
@@ -360,7 +360,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
360
360
 
361
361
  def _add_file_info_helper(
362
362
  self,
363
- files: Dict[str, Dict[str, Any]],
363
+ files: dict[str, dict[str, Any]],
364
364
  object_path: str,
365
365
  file_size: int,
366
366
  file_type: str,
@@ -379,12 +379,12 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
379
379
  )
380
380
 
381
381
  def _fetch_presigned_urls(
382
- self, files: List[str], url_lifetime: float = _PRESIGNED_URL_LIFETIME_SEC
383
- ) -> List[Tuple[str, str]]:
382
+ self, files: list[str], url_lifetime: float = _PRESIGNED_URL_LIFETIME_SEC
383
+ ) -> list[tuple[str, str]]:
384
384
  """Fetch presigned urls for the given files."""
385
385
  file_df = self._session.create_dataframe(files).to_df("name")
386
386
  try:
387
- presigned_urls: List[Tuple[str, str]] = file_df.select_expr(
387
+ presigned_urls: list[tuple[str, str]] = file_df.select_expr(
388
388
  f"name, get_presigned_url('{self.stage_name}', name, {url_lifetime}) as url"
389
389
  ).collect(
390
390
  statement_params=telemetry.get_function_usage_statement_params(
@@ -418,10 +418,10 @@ def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code:
418
418
 
419
419
 
420
420
  @snowflake_plan.SnowflakePlan.Decorator.wrap_exception # type: ignore[misc]
421
- def _resolve_async_job(async_job: snowpark.AsyncJob) -> List[snowpark.Row]:
421
+ def _resolve_async_job(async_job: snowpark.AsyncJob) -> list[snowpark.Row]:
422
422
  # Make sure Snowpark exceptions are properly caught and converted by wrap_exception wrapper
423
423
  try:
424
- query_result = cast(List[snowpark.Row], async_job.result("row"))
424
+ query_result = cast(list[snowpark.Row], async_job.result("row"))
425
425
  return query_result
426
426
  except snowpark_errors.DatabaseError as e:
427
427
  # HACK: Snowpark surfaces a generic exception if query doesn't complete immediately
@@ -4,6 +4,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
4
4
  # SPCS specification constants
5
5
  DEFAULT_CONTAINER_NAME = "main"
6
6
  PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
7
+ RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
7
8
  MEMORY_VOLUME_NAME = "dshm"
8
9
  STAGE_VOLUME_NAME = "stage-volume"
9
10
  STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
@@ -18,10 +19,6 @@ DEFAULT_ENTRYPOINT_PATH = "func.py"
18
19
  # Percent of container memory to allocate for /dev/shm volume
19
20
  MEMORY_VOLUME_SIZE = 0.3
20
21
 
21
- # Multi Node Headless prototype constants
22
- # TODO: Replace this placeholder with the actual container runtime image tag.
23
- MULTINODE_HEADLESS_IMAGE_TAG = "latest"
24
-
25
22
  # Ray port configuration
26
23
  RAY_PORTS = {
27
24
  "HEAD_CLIENT_SERVER_PORT": "10001",
@@ -48,6 +45,7 @@ JOB_POLL_MAX_DELAY_SECONDS = 1
48
45
 
49
46
  # Magic attributes
50
47
  IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
48
+ RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
51
49
 
52
50
  # Compute pool resource information
53
51
  # TODO: Query Snowflake for resource information instead of relying on this hardcoded
@@ -0,0 +1,442 @@
1
+ import builtins
2
+ import functools
3
+ import importlib
4
+ import json
5
+ import os
6
+ import pickle
7
+ import re
8
+ import sys
9
+ import traceback
10
+ from collections import namedtuple
11
+ from dataclasses import dataclass
12
+ from types import TracebackType
13
+ from typing import Any, Callable, Optional, Union, cast
14
+
15
+ from snowflake import snowpark
16
+ from snowflake.snowpark import exceptions as sp_exceptions
17
+
18
+ _TRACEBACK_ENTRY_PATTERN = re.compile(
19
+ r'File "(?P<filename>[^"]+)", line (?P<lineno>\d+), in (?P<name>[^\n]+)(?:\n(?!^\s*File)^\s*(?P<line>[^\n]+))?\n',
20
+ flags=re.MULTILINE,
21
+ )
22
+ _REMOTE_ERROR_ATTR_NAME = "_remote_error"
23
+
24
+ RemoteError = namedtuple("RemoteError", ["exc_type", "exc_msg", "exc_tb"])
25
+
26
+
27
+ @dataclass(frozen=True)
28
+ class ExecutionResult:
29
+ result: Any = None
30
+ exception: Optional[BaseException] = None
31
+
32
+ @property
33
+ def success(self) -> bool:
34
+ return self.exception is None
35
+
36
+ def to_dict(self) -> dict[str, Any]:
37
+ """Return the serializable dictionary."""
38
+ if isinstance(self.exception, BaseException):
39
+ exc_type = type(self.exception)
40
+ return {
41
+ "success": False,
42
+ "exc_type": f"{exc_type.__module__}.{exc_type.__name__}",
43
+ "exc_value": self.exception,
44
+ "exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)),
45
+ }
46
+ return {
47
+ "success": True,
48
+ "result_type": type(self.result).__qualname__,
49
+ "result": self.result,
50
+ }
51
+
52
+ @classmethod
53
+ def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult":
54
+ if not isinstance(result_dict.get("success"), bool):
55
+ raise ValueError("Invalid result dictionary")
56
+
57
+ if result_dict["success"]:
58
+ # Load successful result
59
+ return cls(result=result_dict.get("result"))
60
+
61
+ # Load exception
62
+ exc_type = result_dict.get("exc_type", "RuntimeError")
63
+ exc_value = result_dict.get("exc_value", "Unknown error")
64
+ exc_tb = result_dict.get("exc_tb", "")
65
+ return cls(exception=load_exception(exc_type, exc_value, exc_tb))
66
+
67
+
68
+ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult:
69
+ """
70
+ Fetch the serialized result from the specified path.
71
+
72
+ Args:
73
+ session: Snowpark Session to use for file operations.
74
+ result_path: The path to the serialized result file.
75
+
76
+ Returns:
77
+ An ExecutionResult loaded from the pickled result file, or from the JSON fallback file if the pickled result cannot be loaded.
78
+ """
79
+ try:
80
+ # TODO: Check if file exists
81
+ with session.file.get_stream(result_path) as result_stream:
82
+ return ExecutionResult.from_dict(pickle.load(result_stream))
83
+ except (sp_exceptions.SnowparkSQLException, TypeError, pickle.UnpicklingError):
84
+ # Fall back to JSON result if loading pickled result fails for any reason
85
+ result_json_path = os.path.splitext(result_path)[0] + ".json"
86
+ with session.file.get_stream(result_json_path) as result_stream:
87
+ return ExecutionResult.from_dict(json.load(result_stream))
88
+
89
+
90
+ def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
91
+ """
92
+ Create an exception with a string-formatted traceback.
93
+
94
+ When this exception is raised and not caught, it will display the original traceback.
95
+ When caught, it behaves like a regular exception without showing the traceback.
96
+
97
+ Args:
98
+ exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError')
99
+ exc_value: The deserialized exception value or exception string (i.e. message)
100
+ exc_tb: String representation of the traceback
101
+
102
+ Returns:
103
+ An exception object with the original traceback information
104
+
105
+ # noqa: DAR401
106
+ """
107
+ if isinstance(exc_value, Exception):
108
+ exception = exc_value
109
+ else:
110
+ # Try to load the original exception type if possible
111
+ try:
112
+ # First check built-in exceptions
113
+ exc_type = getattr(builtins, exc_type_name, None)
114
+ if exc_type is None and "." in exc_type_name:
115
+ # Try to import from module path if it's a qualified name
116
+ module_path, class_name = exc_type_name.rsplit(".", 1)
117
+ module = importlib.import_module(module_path)
118
+ exc_type = getattr(module, class_name)
119
+ if exc_type is None or not issubclass(exc_type, Exception):
120
+ raise TypeError(f"{exc_type_name} is not a known exception type")
121
+ # Create the exception instance
122
+ exception = exc_type(exc_value)
123
+ except (ImportError, AttributeError, TypeError):
124
+ # Fall back to a generic exception
125
+ exception = RuntimeError(
126
+ f"Exception deserialization failed, original exception: {exc_type_name}: {exc_value}"
127
+ )
128
+
129
+ # Attach the traceback information to the exception
130
+ return _attach_remote_error_info(exception, exc_type_name, str(exc_value), exc_tb)
131
+
132
+
133
+ def _attach_remote_error_info(ex: Exception, exc_type: str, exc_msg: str, traceback_str: str) -> Exception:
134
+ """
135
+ Attach a string-formatted traceback to an exception.
136
+
137
+ When the exception is raised and not caught, it will display the original traceback.
138
+ When caught, it behaves like a regular exception without showing the traceback.
139
+
140
+ Args:
141
+ ex: The exception object to modify
142
+ exc_type: The original exception type name
143
+ exc_msg: The original exception message
144
+ traceback_str: String representation of the traceback
145
+
146
+ Returns:
147
+ An exception object with the original traceback information
148
+ """
149
+ # Store the traceback information
150
+ exc_type = exc_type.rsplit(".", 1)[-1] # Remove module path
151
+ setattr(ex, _REMOTE_ERROR_ATTR_NAME, RemoteError(exc_type=exc_type, exc_msg=exc_msg, exc_tb=traceback_str))
152
+ return ex
153
+
154
+
155
+ def _retrieve_remote_error_info(ex: Optional[BaseException]) -> Optional[RemoteError]:
156
+ """
157
+ Retrieve the string-formatted traceback from an exception if it exists.
158
+
159
+ Args:
160
+ ex: The exception to retrieve the traceback from
161
+
162
+ Returns:
163
+ The remote error tuple if it exists, None otherwise
164
+ """
165
+ if not ex:
166
+ return None
167
+ return getattr(ex, _REMOTE_ERROR_ATTR_NAME, None)
168
+
169
+
170
+ # ###############################################################################
171
+ # ------------------------------- !!! NOTE !!! -------------------------------- #
172
+ # ###############################################################################
173
+ # Job execution results (including uncaught exceptions) are serialized to file(s)
174
+ # in mljob_launcher.py. When the job is executed remotely, the serialized results
175
+ # are fetched and deserialized in the local environment. If the result contains
176
+ # an exception the original traceback is reconstructed and displayed to the user.
177
+ #
178
+ # It's currently impossible to recreate the original traceback object, so the
179
+ # following overrides are necessary to attach and display the deserialized
180
+ # traceback during exception handling.
181
+ #
182
+ # The following code implements the necessary overrides including sys.excepthook
183
+ # modifications and IPython traceback formatting. The hooks are applied on init
184
+ # and will be active for the duration of the process. The hooks are designed to
185
+ # self-uninstall in the event of an error in case of future compatibility issues.
186
+ # ###############################################################################
187
+
188
+
189
+ def _revert_func_wrapper(
190
+ patched_func: Callable[..., Any],
191
+ original_func: Callable[..., Any],
192
+ uninstall_func: Callable[[], None],
193
+ ) -> Callable[..., Any]:
194
+ """
195
+ Create a wrapper function that uninstalls the patched function if an error occurs during execution.
196
+
197
+ This wrapper provides a fallback mechanism where if the patched function fails, it will:
198
+ 1. Uninstall the patched function using the provided uninstall_func, reverting back to using the original function
199
+ 2. Re-execute the current call using the original (unpatched) function with the same arguments
200
+
201
+ Args:
202
+ patched_func: The patched function to call.
203
+ original_func: The original function to call if patched_func fails.
204
+ uninstall_func: The function to call to uninstall the patched function.
205
+
206
+ Returns:
207
+ A wrapped function that calls patched_func and uninstalls on failure.
208
+ """
209
+
210
+ @functools.wraps(patched_func)
211
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
212
+ try:
213
+ return patched_func(*args, **kwargs)
214
+ except Exception:
215
+ # Uninstall and revert to original on failure
216
+ uninstall_func()
217
+ return original_func(*args, **kwargs)
218
+
219
+ return wrapped
220
+
221
+
222
def _install_sys_excepthook() -> None:
    """
    Install a custom sys.excepthook that displays remote exception tracebacks.

    sys.excepthook is the global hook Python calls when an unhandled exception occurs.
    By default it prints the exception type, message and traceback to stderr.

    The replacement hook looks up remote error info on each exception via
    ``_retrieve_remote_error_info``. When a ``RemoteError`` is found, its stored traceback text,
    exception type and message are printed under a "Traceback (from remote execution):" header
    instead of the local traceback. All other exceptions are printed with
    ``traceback.print_exception``.

    Exception chaining is rendered manually, one link at a time: ``__cause__`` (from
    ``raise ... from``) takes priority over ``__context__`` (implicit chaining), and
    ``__context__`` is skipped when ``__suppress_context__`` is set. Exceptions that were already
    printed in the current chain are not revisited, so cyclic chains terminate.

    The previous hook is stored as ``sys._original_excepthook``; the presence of that attribute
    also marks the hook as installed, making repeated calls a no-op. The installed hook is wrapped
    with ``_revert_func_wrapper`` so that any failure inside it uninstalls it and falls back to the
    previous hook.
    """
    # Already installed: the marker attribute doubles as storage for the hook to restore.
    if hasattr(sys, "_original_excepthook"):
        return

    original_excepthook = sys.excepthook

    def custom_excepthook(
        exc_type: type[BaseException],
        exc_value: BaseException,
        exc_tb: Optional[TracebackType],
        *,
        seen_exc_ids: Optional[set[int]] = None,
    ) -> None:
        if seen_exc_ids is None:
            seen_exc_ids = set()
        seen_exc_ids.add(id(exc_value))

        cause = getattr(exc_value, "__cause__", None)
        context = getattr(exc_value, "__context__", None)

        # Pick the single chained exception to print before this one (cause wins over context).
        chained: Optional[BaseException] = None
        separator = ""
        if cause is not None:
            chained = cause
            separator = "\nThe above exception was the direct cause of the following exception:\n"
        elif context is not None and not getattr(exc_value, "__suppress_context__", False):
            chained = context
            separator = "\nDuring handling of the above exception, another exception occurred:\n"

        # Skip links that were already printed so a cyclic chain cannot recurse forever.
        if chained is not None and id(chained) not in seen_exc_ids:
            custom_excepthook(type(chained), chained, chained.__traceback__, seen_exc_ids=seen_exc_ids)
            print(separator, file=sys.stderr)  # noqa: T201

        if (remote_err := _retrieve_remote_error_info(exc_value)) and isinstance(remote_err, RemoteError):
            # Display stored traceback for deserialized exceptions
            print("Traceback (from remote execution):", file=sys.stderr)  # noqa: T201
            print(remote_err.exc_tb, end="", file=sys.stderr)  # noqa: T201
            print(f"{remote_err.exc_type}: {remote_err.exc_msg}", file=sys.stderr)  # noqa: T201
        else:
            # Standard formatting for this exception only; chaining was handled above
            traceback.print_exception(exc_type, exc_value, exc_tb, file=sys.stderr, chain=False)

    sys._original_excepthook = original_excepthook  # type: ignore[attr-defined]
    sys.excepthook = _revert_func_wrapper(custom_excepthook, original_excepthook, _uninstall_sys_excepthook)
282
+
283
+
284
def _uninstall_sys_excepthook() -> None:
    """
    Put back the excepthook that was active before the custom one was installed.

    The saved hook is read from ``sys._original_excepthook``. If that attribute is absent,
    nothing was installed and the call does nothing. After restoring, the attribute is removed.
    """
    if not hasattr(sys, "_original_excepthook"):
        return

    previous_hook = getattr(sys, "_original_excepthook")
    sys.excepthook = previous_hook
    delattr(sys, "_original_excepthook")
293
+
294
+
295
def _install_ipython_hook() -> bool:
    """Patch IPython's traceback formatters so remote execution errors show their remote traceback.

    Two formatters in ``IPython.core.ultratb`` are replaced at class level:

    1. VerboseTB.format_exception_as_a_whole: for exceptions carrying a ``RemoteError``, builds the
       output from the stored remote traceback by:
        - Using a "(from remote execution)" header instead of "(most recent call last)"
        - Formatting the parsed remote traceback entries with ``ListTB._format_list``
        - Delegating to the original implementation for all other exceptions

    2. ListTB.structured_traceback: for exceptions carrying a ``RemoteError``, calls the original
       implementation with the parsed remote traceback entries in place of the local traceback and
       rewrites the header of the first output element. Other exceptions are passed through
       unchanged.

    Each original is saved on its class as ``_original_<method name>``. That attribute also marks
    the patch as applied, so repeated calls do not patch twice. Both replacements are wrapped with
    ``_revert_func_wrapper``: if a replacement raises, ``_uninstall_ipython_hook`` is called and the
    original implementation handles the call.

    Returns:
        bool: True when running inside an IPython session (hooks are installed, or were already),
            False if IPython cannot be imported or ``get_ipython()`` returns None.
    """
    try:
        from IPython.core.getipython import get_ipython
        from IPython.core.ultratb import ListTB, VerboseTB

        if get_ipython() is None:
            return False
    except ImportError:
        return False

    def parse_traceback_str(traceback_str: Optional[str]) -> list[tuple[str, int, str, str]]:
        # Convert formatted traceback text into (filename, lineno, name, line) entries.
        # A missing traceback yields no entries instead of raising inside re.finditer.
        return [
            (m.group("filename"), int(m.group("lineno")), m.group("name"), m.group("line"))
            for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, traceback_str or "")
        ]

    if not hasattr(VerboseTB, "_original_format_exception_as_a_whole"):
        original_format_exception_as_a_whole = VerboseTB.format_exception_as_a_whole

        def custom_format_exception_as_a_whole(
            self: VerboseTB,
            etype: type[BaseException],
            evalue: Optional[BaseException],
            etb: Optional[TracebackType],
            number_of_lines_of_context: int,
            tb_offset: Optional[int],
            **kwargs: Any,
        ) -> list[list[str]]:
            if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
                # Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole
                head = self.prepare_header(remote_err.exc_type, long_version=False).replace(
                    "(most recent call last)",
                    "(from remote execution)",
                )

                frames = ListTB._format_list(
                    self,
                    parse_traceback_str(remote_err.exc_tb),
                )
                formatted_exception = self.format_exception(remote_err.exc_type, remote_err.exc_msg)

                return [[head] + frames + formatted_exception]
            return original_format_exception_as_a_whole(  # type: ignore[no-any-return]
                self,
                etype=etype,
                evalue=evalue,
                etb=etb,
                number_of_lines_of_context=number_of_lines_of_context,
                tb_offset=tb_offset,
                **kwargs,
            )

        VerboseTB._original_format_exception_as_a_whole = original_format_exception_as_a_whole
        VerboseTB.format_exception_as_a_whole = _revert_func_wrapper(
            custom_format_exception_as_a_whole, original_format_exception_as_a_whole, _uninstall_ipython_hook
        )

    if not hasattr(ListTB, "_original_structured_traceback"):
        original_structured_traceback = ListTB.structured_traceback

        def structured_traceback(
            self: ListTB,
            etype: type,
            evalue: Optional[BaseException],
            etb: Optional[TracebackType],
            tb_offset: Optional[int] = None,
            **kwargs: Any,
        ) -> list[str]:
            if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
                # Same parser as the VerboseTB hook, so line numbers are ints in both code paths.
                tb_list = parse_traceback_str(remote_err.exc_tb)
                out_list = original_structured_traceback(self, etype, evalue, tb_list, tb_offset, **kwargs)
                if out_list:
                    out_list[0] = out_list[0].replace(
                        "(most recent call last)",
                        "(from remote execution)",
                    )
                return cast(list[str], out_list)
            return original_structured_traceback(  # type: ignore[no-any-return]
                self, etype, evalue, etb, tb_offset, **kwargs
            )

        ListTB._original_structured_traceback = original_structured_traceback
        ListTB.structured_traceback = _revert_func_wrapper(
            structured_traceback, original_structured_traceback, _uninstall_ipython_hook
        )

    return True
414
+
415
+
416
def _uninstall_ipython_hook() -> None:
    """
    Undo the IPython traceback formatter patches, if any were applied.

    For each patched formatter, the implementation saved under its ``_original_*`` class
    attribute is put back and the saved attribute is removed. Does nothing when IPython cannot
    be imported or when a formatter was never patched.
    """
    try:
        from IPython.core.ultratb import ListTB, VerboseTB
    except ImportError:
        return

    patched_methods = (
        (VerboseTB, "format_exception_as_a_whole", "_original_format_exception_as_a_whole"),
        (ListTB, "structured_traceback", "_original_structured_traceback"),
    )
    for formatter_cls, method_name, backup_name in patched_methods:
        if not hasattr(formatter_cls, backup_name):
            continue
        setattr(formatter_cls, method_name, getattr(formatter_cls, backup_name))
        delattr(formatter_cls, backup_name)
434
+
435
+
436
def install_exception_display_hooks() -> None:
    """
    Install the remote-traceback display hooks for the current process.

    The IPython formatter hooks are tried first; the sys.excepthook override is installed only
    when the IPython installation reports that it did not apply.
    """
    ipython_hooked = _install_ipython_hook()
    if ipython_hooked:
        return
    _install_sys_excepthook()


# ------ Install the custom traceback hooks by default ------ #
install_exception_display_hooks()