snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/fileset/sfcfs.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import collections
|
2
2
|
import logging
|
3
3
|
from functools import partial
|
4
|
-
from typing import Any, Callable,
|
4
|
+
from typing import Any, Callable, Optional, Union, cast
|
5
5
|
|
6
6
|
import fsspec
|
7
7
|
|
@@ -100,7 +100,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
100
100
|
raise ValueError("Either sf_connection or snowpark_session has to be non-empty!")
|
101
101
|
self._conn = self._session._conn._conn # Telemetry wrappers expect connection under `conn_attr_name="_conn"``
|
102
102
|
self._kwargs = kwargs
|
103
|
-
self._stage_fs_set:
|
103
|
+
self._stage_fs_set: dict[tuple[str, str, str], stage_fs.SFStageFileSystem] = {}
|
104
104
|
|
105
105
|
super().__init__(**kwargs)
|
106
106
|
|
@@ -133,7 +133,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
133
133
|
assert isinstance(session, snowpark.Session)
|
134
134
|
return session
|
135
135
|
|
136
|
-
def __reduce__(self) ->
|
136
|
+
def __reduce__(self) -> tuple[Callable[[], type["SFFileSystem"]], tuple[()], dict[str, Any]]:
|
137
137
|
"""Returns a state dictionary for use in serialization.
|
138
138
|
|
139
139
|
Returns:
|
@@ -145,7 +145,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
145
145
|
|
146
146
|
return partial(self.__class__, **{_RECREATE_FROM_SERIALIZED: True}), (), state_dictionary
|
147
147
|
|
148
|
-
def __setstate__(self, state_dict:
|
148
|
+
def __setstate__(self, state_dict: dict[str, Any]) -> None:
|
149
149
|
"""Sets the dictionary state at deserialization time, and rebuilds a snowflake connection.
|
150
150
|
|
151
151
|
Args:
|
@@ -191,7 +191,7 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
191
191
|
func_params_to_log=["detail"],
|
192
192
|
conn_attr_name="_conn",
|
193
193
|
)
|
194
|
-
def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[
|
194
|
+
def ls(self, path: str, detail: bool = False, **kwargs: Any) -> Union[list[str], list[dict[str, Any]]]:
|
195
195
|
"""Override fsspec `ls` method. List single "directory" with or without details.
|
196
196
|
|
197
197
|
Args:
|
@@ -214,14 +214,14 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
214
214
|
file_path = self._parse_file_path(path)
|
215
215
|
stage_fs = self._get_stage_fs(file_path)
|
216
216
|
stage_path_list = stage_fs.ls(file_path.filepath, detail=True, **kwargs)
|
217
|
-
stage_path_list = cast(
|
217
|
+
stage_path_list = cast(list[dict[str, Any]], stage_path_list)
|
218
218
|
return self._decorate_ls_res(stage_fs, stage_path_list, detail)
|
219
219
|
|
220
220
|
@telemetry.send_api_usage_telemetry(
|
221
221
|
project=_PROJECT,
|
222
222
|
conn_attr_name="_conn",
|
223
223
|
)
|
224
|
-
def optimize_read(self, files: Optional[
|
224
|
+
def optimize_read(self, files: Optional[list[str]] = None) -> None:
|
225
225
|
"""Prefetch and cache the presigned urls for all the given files to speed up the file opening.
|
226
226
|
|
227
227
|
All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
|
@@ -232,8 +232,8 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
232
232
|
"""
|
233
233
|
if not files:
|
234
234
|
return
|
235
|
-
stage_fs_dict:
|
236
|
-
stage_file_paths:
|
235
|
+
stage_fs_dict: dict[str, stage_fs.SFStageFileSystem] = {}
|
236
|
+
stage_file_paths: dict[str, list[str]] = collections.defaultdict(list)
|
237
237
|
for file in files:
|
238
238
|
path_info = self._parse_file_path(file)
|
239
239
|
fs = self._get_stage_fs(path_info)
|
@@ -271,11 +271,11 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
271
271
|
project=_PROJECT,
|
272
272
|
conn_attr_name="_conn",
|
273
273
|
)
|
274
|
-
def info(self, path: str, **kwargs: Any) ->
|
274
|
+
def info(self, path: str, **kwargs: Any) -> dict[str, Any]:
|
275
275
|
"""Override fsspec `info` method. Give details of entry at path."""
|
276
276
|
file_path = self._parse_file_path(path)
|
277
277
|
stage_fs = self._get_stage_fs(file_path)
|
278
|
-
res:
|
278
|
+
res: dict[str, Any] = stage_fs.info(file_path.filepath, **kwargs)
|
279
279
|
if res:
|
280
280
|
res["name"] = self._stage_path_to_absolute_path(stage_fs, res["name"])
|
281
281
|
return res
|
@@ -283,9 +283,9 @@ class SFFileSystem(fsspec.AbstractFileSystem):
|
|
283
283
|
def _decorate_ls_res(
|
284
284
|
self,
|
285
285
|
stage_fs: stage_fs.SFStageFileSystem,
|
286
|
-
stage_path_list:
|
286
|
+
stage_path_list: list[dict[str, Any]],
|
287
287
|
detail: bool,
|
288
|
-
) -> Union[
|
288
|
+
) -> Union[list[str], list[dict[str, Any]]]:
|
289
289
|
"""Add the stage location as the prefix of file names returned by ls() of stagefs"""
|
290
290
|
for path in stage_path_list:
|
291
291
|
path["name"] = self._stage_path_to_absolute_path(stage_fs, path["name"])
|
snowflake/ml/fileset/stage_fs.py
CHANGED
@@ -2,7 +2,7 @@ import inspect
|
|
2
2
|
import logging
|
3
3
|
import time
|
4
4
|
from dataclasses import dataclass
|
5
|
-
from typing import Any,
|
5
|
+
from typing import Any, Optional, Union, cast
|
6
6
|
|
7
7
|
import fsspec
|
8
8
|
from fsspec.implementations import http as httpfs
|
@@ -44,7 +44,7 @@ class _PresignedUrl:
|
|
44
44
|
return not self.expire_at or time.time() > self.expire_at - headroom_sec
|
45
45
|
|
46
46
|
|
47
|
-
def _get_httpfs_kwargs(**kwargs: Any) ->
|
47
|
+
def _get_httpfs_kwargs(**kwargs: Any) -> dict[str, Any]:
|
48
48
|
"""Extract kwargs that are meaningful to HTTPFileSystem."""
|
49
49
|
httpfs_related_keys = [
|
50
50
|
"block_size",
|
@@ -124,7 +124,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
124
124
|
self._db = db
|
125
125
|
self._schema = schema
|
126
126
|
self._stage = stage
|
127
|
-
self._url_cache:
|
127
|
+
self._url_cache: dict[str, _PresignedUrl] = {}
|
128
128
|
|
129
129
|
httpfs_kwargs = _get_httpfs_kwargs(**kwargs)
|
130
130
|
self._fs = httpfs.HTTPFileSystem(**httpfs_kwargs)
|
@@ -145,7 +145,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
145
145
|
project=_PROJECT,
|
146
146
|
func_params_to_log=["detail"],
|
147
147
|
)
|
148
|
-
def ls(self, path: str, detail: bool = False) -> Union[
|
148
|
+
def ls(self, path: str, detail: bool = False) -> Union[list[str], list[dict[str, Any]]]:
|
149
149
|
"""Override fsspec `ls` method. List single "directory" with or without details.
|
150
150
|
|
151
151
|
Args:
|
@@ -169,7 +169,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
169
169
|
loc = self.stage_name
|
170
170
|
path = path.lstrip("/")
|
171
171
|
async_job: snowpark.AsyncJob = self._session.sql(f"LIST '{loc}/{path}'").collect(block=False)
|
172
|
-
objects:
|
172
|
+
objects: list[snowpark.Row] = _resolve_async_job(async_job)
|
173
173
|
except snowpark_exceptions.SnowparkSQLException as e:
|
174
174
|
if e.sql_error_code == fileset_errors.ERRNO_DOMAIN_NOT_EXIST:
|
175
175
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -192,7 +192,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
192
192
|
@telemetry.send_api_usage_telemetry(
|
193
193
|
project=_PROJECT,
|
194
194
|
)
|
195
|
-
def optimize_read(self, files: Optional[
|
195
|
+
def optimize_read(self, files: Optional[list[str]] = None) -> None:
|
196
196
|
"""Prefetch and cache the presigned urls for all the given files to speed up the read performance.
|
197
197
|
|
198
198
|
All the files introduced here will have their urls cached. Further open() on any of cached urls will lead to a
|
@@ -271,7 +271,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
271
271
|
original_exception=fileset_errors.StageFileNotFoundError(f"Stage file {path} doesn't exist."),
|
272
272
|
)
|
273
273
|
|
274
|
-
def _open_with_snowpark(self, path: str, **kwargs:
|
274
|
+
def _open_with_snowpark(self, path: str, **kwargs: dict[str, Any]) -> fsspec.spec.AbstractBufferedFile:
|
275
275
|
"""Open the a file for reading using snowflake.snowpark.file_operation
|
276
276
|
|
277
277
|
Args:
|
@@ -299,7 +299,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
299
299
|
original_exception=e,
|
300
300
|
)
|
301
301
|
|
302
|
-
def _parse_list_result(self, list_result:
|
302
|
+
def _parse_list_result(self, list_result: list[snowpark.Row], search_path: str) -> list[dict[str, Any]]:
|
303
303
|
"""Convert the result from LIST query to the expected format of fsspec ls() method.
|
304
304
|
|
305
305
|
Note that Snowflake LIST query has different behavior with ls(). LIST query will return all the stage files
|
@@ -318,7 +318,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
318
318
|
Returns:
|
319
319
|
A list of dict, where each dict contains key-value pairs as the properties of a file.
|
320
320
|
"""
|
321
|
-
files:
|
321
|
+
files: dict[str, dict[str, Any]] = {}
|
322
322
|
search_path = search_path.strip("/")
|
323
323
|
for row in list_result:
|
324
324
|
name, size, md5, last_modified = row["name"], row["size"], row["md5"], row["last_modified"]
|
@@ -360,7 +360,7 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
360
360
|
|
361
361
|
def _add_file_info_helper(
|
362
362
|
self,
|
363
|
-
files:
|
363
|
+
files: dict[str, dict[str, Any]],
|
364
364
|
object_path: str,
|
365
365
|
file_size: int,
|
366
366
|
file_type: str,
|
@@ -379,12 +379,12 @@ class SFStageFileSystem(fsspec.AbstractFileSystem):
|
|
379
379
|
)
|
380
380
|
|
381
381
|
def _fetch_presigned_urls(
|
382
|
-
self, files:
|
383
|
-
) ->
|
382
|
+
self, files: list[str], url_lifetime: float = _PRESIGNED_URL_LIFETIME_SEC
|
383
|
+
) -> list[tuple[str, str]]:
|
384
384
|
"""Fetch presigned urls for the given files."""
|
385
385
|
file_df = self._session.create_dataframe(files).to_df("name")
|
386
386
|
try:
|
387
|
-
presigned_urls:
|
387
|
+
presigned_urls: list[tuple[str, str]] = file_df.select_expr(
|
388
388
|
f"name, get_presigned_url('{self.stage_name}', name, {url_lifetime}) as url"
|
389
389
|
).collect(
|
390
390
|
statement_params=telemetry.get_function_usage_statement_params(
|
@@ -418,10 +418,10 @@ def _match_error_code(ex: snowpark_exceptions.SnowparkSQLException, error_code:
|
|
418
418
|
|
419
419
|
|
420
420
|
@snowflake_plan.SnowflakePlan.Decorator.wrap_exception # type: ignore[misc]
|
421
|
-
def _resolve_async_job(async_job: snowpark.AsyncJob) ->
|
421
|
+
def _resolve_async_job(async_job: snowpark.AsyncJob) -> list[snowpark.Row]:
|
422
422
|
# Make sure Snowpark exceptions are properly caught and converted by wrap_exception wrapper
|
423
423
|
try:
|
424
|
-
query_result = cast(
|
424
|
+
query_result = cast(list[snowpark.Row], async_job.result("row"))
|
425
425
|
return query_result
|
426
426
|
except snowpark_errors.DatabaseError as e:
|
427
427
|
# HACK: Snowpark surfaces a generic exception if query doesn't complete immediately
|
@@ -4,6 +4,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
|
|
4
4
|
# SPCS specification constants
|
5
5
|
DEFAULT_CONTAINER_NAME = "main"
|
6
6
|
PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
|
7
|
+
RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
|
7
8
|
MEMORY_VOLUME_NAME = "dshm"
|
8
9
|
STAGE_VOLUME_NAME = "stage-volume"
|
9
10
|
STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
|
@@ -18,10 +19,6 @@ DEFAULT_ENTRYPOINT_PATH = "func.py"
|
|
18
19
|
# Percent of container memory to allocate for /dev/shm volume
|
19
20
|
MEMORY_VOLUME_SIZE = 0.3
|
20
21
|
|
21
|
-
# Multi Node Headless prototype constants
|
22
|
-
# TODO: Replace this placeholder with the actual container runtime image tag.
|
23
|
-
MULTINODE_HEADLESS_IMAGE_TAG = "latest"
|
24
|
-
|
25
22
|
# Ray port configuration
|
26
23
|
RAY_PORTS = {
|
27
24
|
"HEAD_CLIENT_SERVER_PORT": "10001",
|
@@ -48,6 +45,7 @@ JOB_POLL_MAX_DELAY_SECONDS = 1
|
|
48
45
|
|
49
46
|
# Magic attributes
|
50
47
|
IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
|
48
|
+
RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
|
51
49
|
|
52
50
|
# Compute pool resource information
|
53
51
|
# TODO: Query Snowflake for resource information instead of relying on this hardcoded
|
@@ -0,0 +1,442 @@
|
|
1
|
+
import builtins
|
2
|
+
import functools
|
3
|
+
import importlib
|
4
|
+
import json
|
5
|
+
import os
|
6
|
+
import pickle
|
7
|
+
import re
|
8
|
+
import sys
|
9
|
+
import traceback
|
10
|
+
from collections import namedtuple
|
11
|
+
from dataclasses import dataclass
|
12
|
+
from types import TracebackType
|
13
|
+
from typing import Any, Callable, Optional, Union, cast
|
14
|
+
|
15
|
+
from snowflake import snowpark
|
16
|
+
from snowflake.snowpark import exceptions as sp_exceptions
|
17
|
+
|
18
|
+
_TRACEBACK_ENTRY_PATTERN = re.compile(
|
19
|
+
r'File "(?P<filename>[^"]+)", line (?P<lineno>\d+), in (?P<name>[^\n]+)(?:\n(?!^\s*File)^\s*(?P<line>[^\n]+))?\n',
|
20
|
+
flags=re.MULTILINE,
|
21
|
+
)
|
22
|
+
_REMOTE_ERROR_ATTR_NAME = "_remote_error"
|
23
|
+
|
24
|
+
RemoteError = namedtuple("RemoteError", ["exc_type", "exc_msg", "exc_tb"])
|
25
|
+
|
26
|
+
|
27
|
+
@dataclass(frozen=True)
|
28
|
+
class ExecutionResult:
|
29
|
+
result: Any = None
|
30
|
+
exception: Optional[BaseException] = None
|
31
|
+
|
32
|
+
@property
|
33
|
+
def success(self) -> bool:
|
34
|
+
return self.exception is None
|
35
|
+
|
36
|
+
def to_dict(self) -> dict[str, Any]:
|
37
|
+
"""Return the serializable dictionary."""
|
38
|
+
if isinstance(self.exception, BaseException):
|
39
|
+
exc_type = type(self.exception)
|
40
|
+
return {
|
41
|
+
"success": False,
|
42
|
+
"exc_type": f"{exc_type.__module__}.{exc_type.__name__}",
|
43
|
+
"exc_value": self.exception,
|
44
|
+
"exc_tb": "".join(traceback.format_tb(self.exception.__traceback__)),
|
45
|
+
}
|
46
|
+
return {
|
47
|
+
"success": True,
|
48
|
+
"result_type": type(self.result).__qualname__,
|
49
|
+
"result": self.result,
|
50
|
+
}
|
51
|
+
|
52
|
+
@classmethod
|
53
|
+
def from_dict(cls, result_dict: dict[str, Any]) -> "ExecutionResult":
|
54
|
+
if not isinstance(result_dict.get("success"), bool):
|
55
|
+
raise ValueError("Invalid result dictionary")
|
56
|
+
|
57
|
+
if result_dict["success"]:
|
58
|
+
# Load successful result
|
59
|
+
return cls(result=result_dict.get("result"))
|
60
|
+
|
61
|
+
# Load exception
|
62
|
+
exc_type = result_dict.get("exc_type", "RuntimeError")
|
63
|
+
exc_value = result_dict.get("exc_value", "Unknown error")
|
64
|
+
exc_tb = result_dict.get("exc_tb", "")
|
65
|
+
return cls(exception=load_exception(exc_type, exc_value, exc_tb))
|
66
|
+
|
67
|
+
|
68
|
+
def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult:
|
69
|
+
"""
|
70
|
+
Fetch the serialized result from the specified path.
|
71
|
+
|
72
|
+
Args:
|
73
|
+
session: Snowpark Session to use for file operations.
|
74
|
+
result_path: The path to the serialized result file.
|
75
|
+
|
76
|
+
Returns:
|
77
|
+
A dictionary containing the execution result if available, None otherwise.
|
78
|
+
"""
|
79
|
+
try:
|
80
|
+
# TODO: Check if file exists
|
81
|
+
with session.file.get_stream(result_path) as result_stream:
|
82
|
+
return ExecutionResult.from_dict(pickle.load(result_stream))
|
83
|
+
except (sp_exceptions.SnowparkSQLException, TypeError, pickle.UnpicklingError):
|
84
|
+
# Fall back to JSON result if loading pickled result fails for any reason
|
85
|
+
result_json_path = os.path.splitext(result_path)[0] + ".json"
|
86
|
+
with session.file.get_stream(result_json_path) as result_stream:
|
87
|
+
return ExecutionResult.from_dict(json.load(result_stream))
|
88
|
+
|
89
|
+
|
90
|
+
def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
|
91
|
+
"""
|
92
|
+
Create an exception with a string-formatted traceback.
|
93
|
+
|
94
|
+
When this exception is raised and not caught, it will display the original traceback.
|
95
|
+
When caught, it behaves like a regular exception without showing the traceback.
|
96
|
+
|
97
|
+
Args:
|
98
|
+
exc_type_name: Name of the exception type (e.g., 'ValueError', 'RuntimeError')
|
99
|
+
exc_value: The deserialized exception value or exception string (i.e. message)
|
100
|
+
exc_tb: String representation of the traceback
|
101
|
+
|
102
|
+
Returns:
|
103
|
+
An exception object with the original traceback information
|
104
|
+
|
105
|
+
# noqa: DAR401
|
106
|
+
"""
|
107
|
+
if isinstance(exc_value, Exception):
|
108
|
+
exception = exc_value
|
109
|
+
else:
|
110
|
+
# Try to load the original exception type if possible
|
111
|
+
try:
|
112
|
+
# First check built-in exceptions
|
113
|
+
exc_type = getattr(builtins, exc_type_name, None)
|
114
|
+
if exc_type is None and "." in exc_type_name:
|
115
|
+
# Try to import from module path if it's a qualified name
|
116
|
+
module_path, class_name = exc_type_name.rsplit(".", 1)
|
117
|
+
module = importlib.import_module(module_path)
|
118
|
+
exc_type = getattr(module, class_name)
|
119
|
+
if exc_type is None or not issubclass(exc_type, Exception):
|
120
|
+
raise TypeError(f"{exc_type_name} is not a known exception type")
|
121
|
+
# Create the exception instance
|
122
|
+
exception = exc_type(exc_value)
|
123
|
+
except (ImportError, AttributeError, TypeError):
|
124
|
+
# Fall back to a generic exception
|
125
|
+
exception = RuntimeError(
|
126
|
+
f"Exception deserialization failed, original exception: {exc_type_name}: {exc_value}"
|
127
|
+
)
|
128
|
+
|
129
|
+
# Attach the traceback information to the exception
|
130
|
+
return _attach_remote_error_info(exception, exc_type_name, str(exc_value), exc_tb)
|
131
|
+
|
132
|
+
|
133
|
+
def _attach_remote_error_info(ex: Exception, exc_type: str, exc_msg: str, traceback_str: str) -> Exception:
|
134
|
+
"""
|
135
|
+
Attach a string-formatted traceback to an exception.
|
136
|
+
|
137
|
+
When the exception is raised and not caught, it will display the original traceback.
|
138
|
+
When caught, it behaves like a regular exception without showing the traceback.
|
139
|
+
|
140
|
+
Args:
|
141
|
+
ex: The exception object to modify
|
142
|
+
exc_type: The original exception type name
|
143
|
+
exc_msg: The original exception message
|
144
|
+
traceback_str: String representation of the traceback
|
145
|
+
|
146
|
+
Returns:
|
147
|
+
An exception object with the original traceback information
|
148
|
+
"""
|
149
|
+
# Store the traceback information
|
150
|
+
exc_type = exc_type.rsplit(".", 1)[-1] # Remove module path
|
151
|
+
setattr(ex, _REMOTE_ERROR_ATTR_NAME, RemoteError(exc_type=exc_type, exc_msg=exc_msg, exc_tb=traceback_str))
|
152
|
+
return ex
|
153
|
+
|
154
|
+
|
155
|
+
def _retrieve_remote_error_info(ex: Optional[BaseException]) -> Optional[RemoteError]:
|
156
|
+
"""
|
157
|
+
Retrieve the string-formatted traceback from an exception if it exists.
|
158
|
+
|
159
|
+
Args:
|
160
|
+
ex: The exception to retrieve the traceback from
|
161
|
+
|
162
|
+
Returns:
|
163
|
+
The remote error tuple if it exists, None otherwise
|
164
|
+
"""
|
165
|
+
if not ex:
|
166
|
+
return None
|
167
|
+
return getattr(ex, _REMOTE_ERROR_ATTR_NAME, None)
|
168
|
+
|
169
|
+
|
170
|
+
# ###############################################################################
|
171
|
+
# ------------------------------- !!! NOTE !!! -------------------------------- #
|
172
|
+
# ###############################################################################
|
173
|
+
# Job execution results (including uncaught exceptions) are serialized to file(s)
|
174
|
+
# in mljob_launcher.py. When the job is executed remotely, the serialized results
|
175
|
+
# are fetched and deserialized in the local environment. If the result contains
|
176
|
+
# an exception the original traceback is reconstructed and displayed to the user.
|
177
|
+
#
|
178
|
+
# It's currently impossible to recreate the original traceback object, so the
|
179
|
+
# following overrides are necessary to attach and display the deserialized
|
180
|
+
# traceback during exception handling.
|
181
|
+
#
|
182
|
+
# The following code implements the necessary overrides including sys.excepthook
|
183
|
+
# modifications and IPython traceback formatting. The hooks are applied on init
|
184
|
+
# and will be active for the duration of the process. The hooks are designed to
|
185
|
+
# self-uninstall in the event of an error in case of future compatibility issues.
|
186
|
+
# ###############################################################################
|
187
|
+
|
188
|
+
|
189
|
+
def _revert_func_wrapper(
|
190
|
+
patched_func: Callable[..., Any],
|
191
|
+
original_func: Callable[..., Any],
|
192
|
+
uninstall_func: Callable[[], None],
|
193
|
+
) -> Callable[..., Any]:
|
194
|
+
"""
|
195
|
+
Create a wrapper function that uninstalls the original function if an error occurs during execution.
|
196
|
+
|
197
|
+
This wrapper provides a fallback mechanism where if the patched function fails, it will:
|
198
|
+
1. Uninstall the patched function using the provided uninstall_func, reverting back to using the original function
|
199
|
+
2. Re-execute the current call using the original (unpatched) function with the same arguments
|
200
|
+
|
201
|
+
Args:
|
202
|
+
patched_func: The patched function to call.
|
203
|
+
original_func: The original function to call if patched_func fails.
|
204
|
+
uninstall_func: The function to call to uninstall the patched function.
|
205
|
+
|
206
|
+
Returns:
|
207
|
+
A wrapped function that calls patched_func and uninstalls on failure.
|
208
|
+
"""
|
209
|
+
|
210
|
+
@functools.wraps(patched_func)
|
211
|
+
def wrapped(*args: Any, **kwargs: Any) -> Any:
|
212
|
+
try:
|
213
|
+
return patched_func(*args, **kwargs)
|
214
|
+
except Exception:
|
215
|
+
# Uninstall and revert to original on failure
|
216
|
+
uninstall_func()
|
217
|
+
return original_func(*args, **kwargs)
|
218
|
+
|
219
|
+
return wrapped
|
220
|
+
|
221
|
+
|
222
|
+
def _install_sys_excepthook() -> None:
    """
    Install a custom sys.excepthook to handle remote exception tracebacks.

    sys.excepthook is the global hook that Python calls when an unhandled exception occurs.
    By default it prints the exception type, message and traceback to stderr.

    We override sys.excepthook to intercept exceptions for which _retrieve_remote_error_info
    returns a RemoteError. These exceptions come from deserialized remote execution results
    and carry the original traceback text from where they occurred.

    When such an exception is detected, we print the stored remote traceback instead of the
    local one. This shows where the error actually happened during remote execution. All other
    exceptions are printed with traceback.print_exception.

    The custom hook preserves exception chaining for both __cause__ (from ``raise ... from``)
    and __context__ (from implicit exception chaining). Chained exceptions are printed first,
    followed by CPython's separator messages. Each exception object in a chain is printed at
    most once, so cyclic cause/context chains terminate instead of recursing forever.

    The hook is installed at most once per process. The previous hook is saved on
    ``sys._original_excepthook``, and the presence of that attribute marks the hook as
    installed. If the custom hook raises, _revert_func_wrapper uninstalls it via
    _uninstall_sys_excepthook, and the saved hook handles the exception instead.
    """
    # Attach the custom excepthook for standard Python scripts if not already attached
    if not hasattr(sys, "_original_excepthook"):
        original_excepthook = sys.excepthook

        def custom_excepthook(
            exc_type: type[BaseException],
            exc_value: BaseException,
            exc_tb: Optional[TracebackType],
            *,
            seen_exc_ids: Optional[set[int]] = None,
        ) -> None:
            # seen_exc_ids is shared across the recursive calls for one chain so that an
            # exception reachable twice (e.g. via a cycle) is only printed once.
            if seen_exc_ids is None:
                seen_exc_ids = set()
            seen_exc_ids.add(id(exc_value))

            cause = getattr(exc_value, "__cause__", None)
            context = getattr(exc_value, "__context__", None)
            if cause is not None and id(cause) not in seen_exc_ids:
                # Handle cause-chained exceptions (explicit `raise ... from cause`)
                custom_excepthook(type(cause), cause, cause.__traceback__, seen_exc_ids=seen_exc_ids)
                print(  # noqa: T201
                    "\nThe above exception was the direct cause of the following exception:\n", file=sys.stderr
                )
            elif (
                context is not None
                and not getattr(exc_value, "__suppress_context__", False)
                and id(context) not in seen_exc_ids
            ):
                # Handle context-chained exceptions (raised while handling another exception).
                # This branch runs only when no unseen cause was printed, mirroring CPython,
                # which shows at most one chained predecessor per exception.
                custom_excepthook(type(context), context, context.__traceback__, seen_exc_ids=seen_exc_ids)
                print(  # noqa: T201
                    "\nDuring handling of the above exception, another exception occurred:\n", file=sys.stderr
                )

            if (remote_err := _retrieve_remote_error_info(exc_value)) and isinstance(remote_err, RemoteError):
                # Display stored traceback for deserialized exceptions
                print("Traceback (from remote execution):", file=sys.stderr)  # noqa: T201
                print(remote_err.exc_tb, end="", file=sys.stderr)  # noqa: T201
                print(f"{remote_err.exc_type}: {remote_err.exc_msg}", file=sys.stderr)  # noqa: T201
            else:
                # Print the local exception only; chaining is handled above (chain=False).
                traceback.print_exception(exc_type, exc_value, exc_tb, file=sys.stderr, chain=False)

        sys._original_excepthook = original_excepthook  # type: ignore[attr-defined]
        sys.excepthook = _revert_func_wrapper(custom_excepthook, original_excepthook, _uninstall_sys_excepthook)
|
282
|
+
|
283
|
+
|
284
|
+
def _uninstall_sys_excepthook() -> None:
|
285
|
+
"""
|
286
|
+
Restore the original excepthook for the current process.
|
287
|
+
|
288
|
+
This is useful when we want to revert to the default behavior after installing a custom excepthook.
|
289
|
+
"""
|
290
|
+
if hasattr(sys, "_original_excepthook"):
|
291
|
+
sys.excepthook = sys._original_excepthook
|
292
|
+
del sys._original_excepthook
|
293
|
+
|
294
|
+
|
295
|
+
def _install_ipython_hook() -> bool:
|
296
|
+
"""Install IPython-specific exception handling hook to improve remote error reporting.
|
297
|
+
|
298
|
+
This function enhances IPython's error formatting capabilities by intercepting and customizing
|
299
|
+
how remote execution errors are displayed. It modifies two key IPython traceback formatters:
|
300
|
+
|
301
|
+
1. VerboseTB.format_exception_as_a_whole: Customizes the full traceback formatting for remote
|
302
|
+
errors by:
|
303
|
+
- Adding a "(from remote execution)" header instead of "(most recent call last)"
|
304
|
+
- Properly formatting the remote traceback entries
|
305
|
+
- Maintaining original behavior for non-remote errors
|
306
|
+
|
307
|
+
2. ListTB.structured_traceback: Modifies the structured traceback output by:
|
308
|
+
- Parsing and formatting remote tracebacks appropriately
|
309
|
+
- Adding remote execution context to the output
|
310
|
+
- Preserving original functionality for local errors
|
311
|
+
|
312
|
+
The modifications are needed because IPython's default error handling doesn't properly display
|
313
|
+
remote execution errors that occur in Snowpark/Snowflake operations. The custom formatters
|
314
|
+
ensure that error messages from remote executions are properly captured, formatted and displayed
|
315
|
+
with the correct context and traceback information.
|
316
|
+
|
317
|
+
Returns:
|
318
|
+
bool: True if IPython hooks were successfully installed, False if IPython is not available
|
319
|
+
or not in an IPython environment.
|
320
|
+
|
321
|
+
Note:
|
322
|
+
This function maintains the ability to revert changes through _uninstall_ipython_hook by
|
323
|
+
storing original implementations before applying modifications.
|
324
|
+
"""
|
325
|
+
try:
|
326
|
+
from IPython.core.getipython import get_ipython
|
327
|
+
from IPython.core.ultratb import ListTB, VerboseTB
|
328
|
+
|
329
|
+
if get_ipython() is None:
|
330
|
+
return False
|
331
|
+
except ImportError:
|
332
|
+
return False
|
333
|
+
|
334
|
+
def parse_traceback_str(traceback_str: str) -> list[tuple[str, int, str, str]]:
|
335
|
+
return [
|
336
|
+
(m.group("filename"), int(m.group("lineno")), m.group("name"), m.group("line"))
|
337
|
+
for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, traceback_str)
|
338
|
+
]
|
339
|
+
|
340
|
+
if not hasattr(VerboseTB, "_original_format_exception_as_a_whole"):
|
341
|
+
original_format_exception_as_a_whole = VerboseTB.format_exception_as_a_whole
|
342
|
+
|
343
|
+
def custom_format_exception_as_a_whole(
|
344
|
+
self: VerboseTB,
|
345
|
+
etype: type[BaseException],
|
346
|
+
evalue: Optional[BaseException],
|
347
|
+
etb: Optional[TracebackType],
|
348
|
+
number_of_lines_of_context: int,
|
349
|
+
tb_offset: Optional[int],
|
350
|
+
**kwargs: Any,
|
351
|
+
) -> list[list[str]]:
|
352
|
+
if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
|
353
|
+
# Implementation forked from IPython.core.ultratb.VerboseTB.format_exception_as_a_whole
|
354
|
+
head = self.prepare_header(remote_err.exc_type, long_version=False).replace(
|
355
|
+
"(most recent call last)",
|
356
|
+
"(from remote execution)",
|
357
|
+
)
|
358
|
+
|
359
|
+
frames = ListTB._format_list(
|
360
|
+
self,
|
361
|
+
parse_traceback_str(remote_err.exc_tb),
|
362
|
+
)
|
363
|
+
formatted_exception = self.format_exception(remote_err.exc_type, remote_err.exc_msg)
|
364
|
+
|
365
|
+
return [[head] + frames + formatted_exception]
|
366
|
+
return original_format_exception_as_a_whole( # type: ignore[no-any-return]
|
367
|
+
self,
|
368
|
+
etype=etype,
|
369
|
+
evalue=evalue,
|
370
|
+
etb=etb,
|
371
|
+
number_of_lines_of_context=number_of_lines_of_context,
|
372
|
+
tb_offset=tb_offset,
|
373
|
+
**kwargs,
|
374
|
+
)
|
375
|
+
|
376
|
+
VerboseTB._original_format_exception_as_a_whole = original_format_exception_as_a_whole
|
377
|
+
VerboseTB.format_exception_as_a_whole = _revert_func_wrapper(
|
378
|
+
custom_format_exception_as_a_whole, original_format_exception_as_a_whole, _uninstall_ipython_hook
|
379
|
+
)
|
380
|
+
|
381
|
+
if not hasattr(ListTB, "_original_structured_traceback"):
|
382
|
+
original_structured_traceback = ListTB.structured_traceback
|
383
|
+
|
384
|
+
def structured_traceback(
|
385
|
+
self: ListTB,
|
386
|
+
etype: type,
|
387
|
+
evalue: Optional[BaseException],
|
388
|
+
etb: Optional[TracebackType],
|
389
|
+
tb_offset: Optional[int] = None,
|
390
|
+
**kwargs: Any,
|
391
|
+
) -> list[str]:
|
392
|
+
if (remote_err := _retrieve_remote_error_info(evalue)) and isinstance(remote_err, RemoteError):
|
393
|
+
tb_list = [
|
394
|
+
(m.group("filename"), m.group("lineno"), m.group("name"), m.group("line"))
|
395
|
+
for m in re.finditer(_TRACEBACK_ENTRY_PATTERN, remote_err.exc_tb or "")
|
396
|
+
]
|
397
|
+
out_list = original_structured_traceback(self, etype, evalue, tb_list, tb_offset, **kwargs)
|
398
|
+
if out_list:
|
399
|
+
out_list[0] = out_list[0].replace(
|
400
|
+
"(most recent call last)",
|
401
|
+
"(from remote execution)",
|
402
|
+
)
|
403
|
+
return cast(list[str], out_list)
|
404
|
+
return original_structured_traceback( # type: ignore[no-any-return]
|
405
|
+
self, etype, evalue, etb, tb_offset, **kwargs
|
406
|
+
)
|
407
|
+
|
408
|
+
ListTB._original_structured_traceback = original_structured_traceback
|
409
|
+
ListTB.structured_traceback = _revert_func_wrapper(
|
410
|
+
structured_traceback, original_structured_traceback, _uninstall_ipython_hook
|
411
|
+
)
|
412
|
+
|
413
|
+
return True
|
414
|
+
|
415
|
+
|
416
|
+
def _uninstall_ipython_hook() -> None:
|
417
|
+
"""
|
418
|
+
Restore the original IPython traceback formatting if it was modified.
|
419
|
+
|
420
|
+
This is useful when we want to revert to the default behavior after installing a custom hook.
|
421
|
+
"""
|
422
|
+
try:
|
423
|
+
from IPython.core.ultratb import ListTB, VerboseTB
|
424
|
+
|
425
|
+
if hasattr(VerboseTB, "_original_format_exception_as_a_whole"):
|
426
|
+
VerboseTB.format_exception_as_a_whole = VerboseTB._original_format_exception_as_a_whole
|
427
|
+
del VerboseTB._original_format_exception_as_a_whole
|
428
|
+
|
429
|
+
if hasattr(ListTB, "_original_structured_traceback"):
|
430
|
+
ListTB.structured_traceback = ListTB._original_structured_traceback
|
431
|
+
del ListTB._original_structured_traceback
|
432
|
+
except ImportError:
|
433
|
+
pass
|
434
|
+
|
435
|
+
|
436
|
+
def install_exception_display_hooks() -> None:
    """
    Install the remote-traceback display hooks appropriate for the current process.

    Inside an active IPython session the IPython formatter patches are used. Otherwise
    (no IPython, or IPython not running) the process-wide sys.excepthook is replaced.
    """
    ipython_hooks_active = _install_ipython_hook()
    if ipython_hooks_active:
        return
    _install_sys_excepthook()
|
439
|
+
|
440
|
+
|
441
|
+
# ------ Install the custom traceback hooks by default ------ #
# This module-level call runs at import time, so importing this module is enough to activate
# the remote-traceback display. The install functions skip re-patching when their saved
# originals are already present. The hooks can be reverted with _uninstall_ipython_hook /
# _uninstall_sys_excepthook.
install_exception_display_hooks()
|