snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +64 -31
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +41 -5
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +40 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +12 -8
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/constants.py +2 -4
- snowflake/ml/jobs/_utils/interop_utils.py +442 -0
- snowflake/ml/jobs/_utils/payload_utils.py +86 -62
- snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
- snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
- snowflake/ml/jobs/_utils/spec_utils.py +22 -36
- snowflake/ml/jobs/_utils/types.py +8 -2
- snowflake/ml/jobs/decorators.py +7 -8
- snowflake/ml/jobs/job.py +158 -26
- snowflake/ml/jobs/manager.py +78 -30
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +230 -50
- snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +22 -18
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +46 -25
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +35 -26
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +12 -8
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +5 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +50 -29
- snowflake/ml/registry/registry.py +34 -23
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -6,17 +6,7 @@ import pickle
|
|
6
6
|
import sys
|
7
7
|
import textwrap
|
8
8
|
from pathlib import Path, PurePath
|
9
|
-
from typing import
|
10
|
-
Any,
|
11
|
-
Callable,
|
12
|
-
List,
|
13
|
-
Optional,
|
14
|
-
Type,
|
15
|
-
Union,
|
16
|
-
cast,
|
17
|
-
get_args,
|
18
|
-
get_origin,
|
19
|
-
)
|
9
|
+
from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
|
20
10
|
|
21
11
|
import cloudpickle as cp
|
22
12
|
|
@@ -27,6 +17,7 @@ from snowflake.snowpark._internal import code_generation
|
|
27
17
|
|
28
18
|
_SUPPORTED_ARG_TYPES = {str, int, float}
|
29
19
|
_SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
|
20
|
+
_ENTRYPOINT_FUNC_NAME = "func"
|
30
21
|
_STARTUP_SCRIPT_PATH = PurePath("startup.sh")
|
31
22
|
_STARTUP_SCRIPT_CODE = textwrap.dedent(
|
32
23
|
f"""
|
@@ -73,14 +64,14 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
73
64
|
##### Ray configuration #####
|
74
65
|
shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
|
75
66
|
|
76
|
-
# Check if the
|
67
|
+
# Check if the local get_instance_ip.py script exists
|
77
68
|
HELPER_EXISTS=$(
|
78
|
-
|
69
|
+
[ -f "get_instance_ip.py" ] && echo "true" || echo "false"
|
79
70
|
)
|
80
71
|
|
81
72
|
# Configure IP address and logging directory
|
82
73
|
if [ "$HELPER_EXISTS" = "true" ]; then
|
83
|
-
eth0Ip=$(python3
|
74
|
+
eth0Ip=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
|
84
75
|
else
|
85
76
|
eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
|
86
77
|
fi
|
@@ -103,7 +94,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
103
94
|
|
104
95
|
# Determine if it should be a worker or a head node for batch jobs
|
105
96
|
if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
|
106
|
-
head_info=$(python3
|
97
|
+
head_info=$(python3 get_instance_ip.py "$SNOWFLAKE_SERVICE_NAME" --head)
|
107
98
|
if [ $? -eq 0 ]; then
|
108
99
|
# Parse the output using read
|
109
100
|
read head_index head_ip <<< "$head_info"
|
@@ -166,10 +157,17 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
166
157
|
"--object-store-memory=${{shm_size}}"
|
167
158
|
)
|
168
159
|
|
169
|
-
# Start Ray on a worker node
|
170
|
-
ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block
|
171
|
-
|
160
|
+
# Start Ray on a worker node - run in background
|
161
|
+
ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &
|
162
|
+
|
163
|
+
# Start the worker shutdown listener in the background
|
164
|
+
echo "Starting worker shutdown listener..."
|
165
|
+
python worker_shutdown_listener.py
|
166
|
+
WORKER_EXIT_CODE=$?
|
172
167
|
|
168
|
+
echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
|
169
|
+
exit $WORKER_EXIT_CODE
|
170
|
+
else
|
173
171
|
# Additional head-specific parameters
|
174
172
|
head_params=(
|
175
173
|
"--head"
|
@@ -193,13 +191,39 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
|
|
193
191
|
# Run user's Python entrypoint
|
194
192
|
echo Running command: python "$@"
|
195
193
|
python "$@"
|
194
|
+
|
195
|
+
# After the user's job completes, signal workers to shut down
|
196
|
+
echo "User job completed. Signaling workers to shut down..."
|
197
|
+
python signal_workers.py --wait-time 15
|
198
|
+
echo "Head node job completed. Exiting."
|
196
199
|
fi
|
197
200
|
"""
|
198
201
|
).strip()
|
199
202
|
|
200
203
|
|
201
|
-
def
|
202
|
-
|
204
|
+
def resolve_source(source: Union[str, Path, Callable[..., Any]]) -> Union[Path, Callable[..., Any]]:
    """Validate a job payload source and normalize it.

    Args:
        source: A callable to be serialized into a generated entrypoint, or a path
            (``Path`` or ``str``) to a file or directory to upload.

    Returns:
        The callable unchanged, or the path converted to an absolute ``Path``.

    Raises:
        FileNotFoundError: If ``source`` is a path that does not exist.
        ValueError: If ``source`` is neither a callable nor a path.
    """
    if callable(source):
        return source
    if isinstance(source, str):
        # Accept plain string paths as a convenience; JobPayload already converts these itself.
        source = Path(source)
    if isinstance(source, Path):
        if not source.exists():
            raise FileNotFoundError(f"{source} does not exist")
        return source.absolute()
    raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
|
215
|
+
|
216
|
+
|
217
|
+
def resolve_entrypoint(source: Union[Path, Callable[..., Any]], entrypoint: Optional[Path]) -> types.PayloadEntrypoint:
|
218
|
+
if callable(source):
|
219
|
+
# Entrypoint is generated for callable payloads
|
220
|
+
return types.PayloadEntrypoint(
|
221
|
+
file_path=entrypoint or Path(constants.DEFAULT_ENTRYPOINT_PATH),
|
222
|
+
main_func=_ENTRYPOINT_FUNC_NAME,
|
223
|
+
)
|
224
|
+
|
225
|
+
# Resolve entrypoint path for file-based payloads
|
226
|
+
parent = source.absolute()
|
203
227
|
if entrypoint is None:
|
204
228
|
if parent.is_file():
|
205
229
|
# Infer entrypoint from source
|
@@ -218,12 +242,23 @@ def _resolve_entrypoint(parent: Path, entrypoint: Optional[Path]) -> Path:
|
|
218
242
|
else:
|
219
243
|
# Relative to source dir
|
220
244
|
entrypoint = parent.joinpath(entrypoint)
|
245
|
+
|
246
|
+
# Validate resolved entrypoint file
|
221
247
|
if not entrypoint.is_file():
|
222
248
|
raise FileNotFoundError(
|
223
249
|
"Entrypoint not found. Ensure the entrypoint is a valid file and is under"
|
224
250
|
f" the source directory (source={parent}, entrypoint={entrypoint})"
|
225
251
|
)
|
226
|
-
|
252
|
+
if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
|
253
|
+
raise ValueError(
|
254
|
+
"Unsupported entrypoint type:"
|
255
|
+
f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
|
256
|
+
)
|
257
|
+
|
258
|
+
return types.PayloadEntrypoint(
|
259
|
+
file_path=entrypoint, # entrypoint is an absolute path at this point
|
260
|
+
main_func=None,
|
261
|
+
)
|
227
262
|
|
228
263
|
|
229
264
|
class JobPayload:
|
@@ -232,46 +267,17 @@ class JobPayload:
|
|
232
267
|
source: Union[str, Path, Callable[..., Any]],
|
233
268
|
entrypoint: Optional[Union[str, Path]] = None,
|
234
269
|
*,
|
235
|
-
pip_requirements: Optional[
|
270
|
+
pip_requirements: Optional[list[str]] = None,
|
236
271
|
) -> None:
|
237
272
|
self.source = Path(source) if isinstance(source, str) else source
|
238
273
|
self.entrypoint = Path(entrypoint) if isinstance(entrypoint, str) else entrypoint
|
239
274
|
self.pip_requirements = pip_requirements
|
240
275
|
|
241
|
-
def validate(self) -> None:
|
242
|
-
if callable(self.source):
|
243
|
-
# Any entrypoint value is OK for callable payloads (including None aka default)
|
244
|
-
# since we will generate the file from the serialized callable
|
245
|
-
pass
|
246
|
-
elif isinstance(self.source, Path):
|
247
|
-
# Validate source
|
248
|
-
source = self.source
|
249
|
-
if not source.exists():
|
250
|
-
raise FileNotFoundError(f"{source} does not exist")
|
251
|
-
source = source.absolute()
|
252
|
-
|
253
|
-
# Validate entrypoint
|
254
|
-
entrypoint = _resolve_entrypoint(source, self.entrypoint)
|
255
|
-
if entrypoint.suffix not in _SUPPORTED_ENTRYPOINT_EXTENSIONS:
|
256
|
-
raise ValueError(
|
257
|
-
"Unsupported entrypoint type:"
|
258
|
-
f" supported={','.join(_SUPPORTED_ENTRYPOINT_EXTENSIONS)} got={entrypoint.suffix}"
|
259
|
-
)
|
260
|
-
|
261
|
-
# Update fields with normalized values
|
262
|
-
self.source = source
|
263
|
-
self.entrypoint = entrypoint
|
264
|
-
else:
|
265
|
-
raise ValueError("Unsupported source type. Source must be a file, directory, or callable.")
|
266
|
-
|
267
276
|
def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
|
268
|
-
# Validate payload
|
269
|
-
self.validate()
|
270
|
-
|
271
277
|
# Prepare local variables
|
272
278
|
stage_path = PurePath(stage_path) if isinstance(stage_path, str) else stage_path
|
273
|
-
source = self.source
|
274
|
-
entrypoint = self.entrypoint
|
279
|
+
source = resolve_source(self.source)
|
280
|
+
entrypoint = resolve_entrypoint(source, self.entrypoint)
|
275
281
|
|
276
282
|
# Create stage if necessary
|
277
283
|
stage_name = stage_path.parts[0].lstrip("@")
|
@@ -290,11 +296,11 @@ class JobPayload:
|
|
290
296
|
source_code = generate_python_code(source, source_code_display=True)
|
291
297
|
_ = session.file.put_stream(
|
292
298
|
io.BytesIO(source_code.encode()),
|
293
|
-
stage_location=stage_path.joinpath(entrypoint).as_posix(),
|
299
|
+
stage_location=stage_path.joinpath(entrypoint.file_path).as_posix(),
|
294
300
|
auto_compress=False,
|
295
301
|
overwrite=True,
|
296
302
|
)
|
297
|
-
source = entrypoint.parent
|
303
|
+
source = Path(entrypoint.file_path.parent)
|
298
304
|
elif source.is_dir():
|
299
305
|
# Manually traverse the directory and upload each file, since Snowflake PUT
|
300
306
|
# can't handle directories. Reduce the number of PUT operations by using
|
@@ -337,17 +343,35 @@ class JobPayload:
|
|
337
343
|
overwrite=False, # FIXME
|
338
344
|
)
|
339
345
|
|
346
|
+
# Upload system scripts
|
347
|
+
scripts_dir = Path(__file__).parent.joinpath("scripts")
|
348
|
+
for script_file in scripts_dir.glob("*"):
|
349
|
+
if script_file.is_file():
|
350
|
+
session.file.put(
|
351
|
+
script_file.as_posix(),
|
352
|
+
stage_path.as_posix(),
|
353
|
+
overwrite=True,
|
354
|
+
auto_compress=False,
|
355
|
+
)
|
356
|
+
|
357
|
+
python_entrypoint: list[Union[str, PurePath]] = [
|
358
|
+
PurePath("mljob_launcher.py"),
|
359
|
+
entrypoint.file_path.relative_to(source),
|
360
|
+
]
|
361
|
+
if entrypoint.main_func:
|
362
|
+
python_entrypoint += ["--script_main_func", entrypoint.main_func]
|
363
|
+
|
340
364
|
return types.UploadedPayload(
|
341
365
|
stage_path=stage_path,
|
342
366
|
entrypoint=[
|
343
367
|
"bash",
|
344
368
|
_STARTUP_SCRIPT_PATH,
|
345
|
-
|
369
|
+
*python_entrypoint,
|
346
370
|
],
|
347
371
|
)
|
348
372
|
|
349
373
|
|
350
|
-
def _get_parameter_type(param: inspect.Parameter) -> Optional[
|
374
|
+
def _get_parameter_type(param: inspect.Parameter) -> Optional[type[object]]:
|
351
375
|
# Unwrap Optional type annotations
|
352
376
|
param_type = param.annotation
|
353
377
|
if get_origin(param_type) is Union and len(get_args(param_type)) == 2 and type(None) in get_args(param_type):
|
@@ -356,10 +380,10 @@ def _get_parameter_type(param: inspect.Parameter) -> Optional[Type[object]]:
|
|
356
380
|
# Return None for empty type annotations
|
357
381
|
if param_type == inspect.Parameter.empty:
|
358
382
|
return None
|
359
|
-
return cast(
|
383
|
+
return cast(type[object], param_type)
|
360
384
|
|
361
385
|
|
362
|
-
def _validate_parameter_type(param_type:
|
386
|
+
def _validate_parameter_type(param_type: type[object], param_name: str) -> None:
|
363
387
|
# Validate param_type is a supported type
|
364
388
|
if param_type not in _SUPPORTED_ARG_TYPES:
|
365
389
|
raise ValueError(
|
@@ -471,12 +495,11 @@ def generate_python_code(func: Callable[..., Any], source_code_display: bool = F
|
|
471
495
|
# https://github.com/snowflakedb/snowpark-python/blob/main/src/snowflake/snowpark/_internal/udf_utils.py
|
472
496
|
source_code_comment = _generate_source_code_comment(func) if source_code_display else ""
|
473
497
|
|
474
|
-
func_name = "func"
|
475
498
|
func_code = f"""
|
476
499
|
{source_code_comment}
|
477
500
|
|
478
501
|
import pickle
|
479
|
-
{
|
502
|
+
{_ENTRYPOINT_FUNC_NAME} = pickle.loads(bytes.fromhex('{_serialize_callable(func).hex()}'))
|
480
503
|
"""
|
481
504
|
|
482
505
|
arg_dict_name = "kwargs"
|
@@ -487,6 +510,7 @@ import pickle
|
|
487
510
|
|
488
511
|
return f"""
|
489
512
|
### Version guard to check compatibility across Python versions ###
|
513
|
+
import os
|
490
514
|
import sys
|
491
515
|
import warnings
|
492
516
|
|
@@ -508,5 +532,5 @@ if sys.version_info.major != {sys.version_info.major} or sys.version_info.minor
|
|
508
532
|
if __name__ == '__main__':
|
509
533
|
{textwrap.indent(param_code, ' ')}
|
510
534
|
|
511
|
-
{
|
535
|
+
__return__ = {_ENTRYPOINT_FUNC_NAME}(**{arg_dict_name})
|
512
536
|
"""
|
@@ -0,0 +1,136 @@
|
|
1
|
+
#!/usr/bin/env python3
|
2
|
+
# This file is modified from mlruntime/service/snowflake/runtime/utils
|
3
|
+
import argparse
|
4
|
+
import logging
|
5
|
+
import socket
|
6
|
+
import sys
|
7
|
+
import time
|
8
|
+
from typing import Optional
|
9
|
+
|
10
|
+
# Module-wide logging: INFO and above, timestamped, written via the root handler.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
13
|
+
|
14
|
+
|
15
|
+
def get_self_ip() -> Optional[str]:
    """Resolve the IP address of this service instance from its own hostname.

    References:
        - https://docs.snowflake.com/en/developer-guide/snowpark-container-services/working-with-services#general-guidelines-related-to-service-to-service-communications # noqa: E501

    Returns:
        Optional[str]: The resolved IP address, or None if hostname resolution raised OSError.
    """
    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError as err:
        logger.error(f"Error: Unable to get IP address via socket. {err}")
        return None
|
30
|
+
|
31
|
+
|
32
|
+
def get_first_instance(service_name: str) -> Optional[tuple[str, str]]:
    """Get the first instance of a batch job based on start time and instance ID.

    The head is the listed instance with the earliest ``start_time``. Ties are broken by
    the numeric instance id, so the choice is deterministic for a given listing.

    Args:
        service_name (str): The name of the service to query. It is formatted into the SQL
            text without identifier quoting; the startup script passes $SNOWFLAKE_SERVICE_NAME.

    Returns:
        Optional[tuple[str, str]]: (instance_id, ip_address) of the head instance. Returns None
        in three cases: no instances are listed, the head lacks an id or IP, or the IP is
        rejected by ``socket.inet_aton``.
    """
    from snowflake.runtime.utils import session_utils

    session = session_utils.get_session()
    df = session.sql(f"show service instances in service {service_name}")
    result = df.select('"instance_id"', '"ip_address"', '"start_time"').collect()

    if not result:
        return None

    # Earliest start_time wins, then lowest instance id (min keeps the first minimal row,
    # exactly like sorted(...)[0]).
    head_instance = min(result, key=lambda x: (x["start_time"], int(x["instance_id"])))
    instance_id = head_instance["instance_id"]
    ip_address = head_instance["ip_address"]

    # Test for absence explicitly: a plain falsy check would also reject instance id 0 if the
    # column comes back numeric, making the --head polling loop retry until it times out.
    if instance_id is None or instance_id == "" or not ip_address:
        return None

    # Validate head instance IP
    try:
        socket.inet_aton(ip_address)  # Validate IPv4 address
    except OSError:
        logger.error(f"Error: Invalid IP address format: {ip_address}")
        return None
    return (instance_id, ip_address)
|
64
|
+
|
65
|
+
|
66
|
+
def main() -> None:
    """Retrieves the IP address of the current service instance or of the job's head instance.

    Args:
        service_name (str, required): Name of the service to query.
        --instance-index (int, optional): Index of the service instance to query. Default: -1.
            Currently only -1 (the current service instance) is supported.
        --head (bool, optional): Get the head instance information using show services.
            Mutually exclusive with --instance-index. Prints the index and IP address of the
            head instance, separated by a space. Default: False.
        --timeout (int, optional): Maximum time to wait for the head instance, in seconds.
            Only used with --head. Default: 720 seconds.
        --retry-interval (int, optional): Time to wait between head lookups, in seconds.
            Only used with --head. Default: 10 seconds.

    Usage Examples:
        python get_instance_ip.py myservice --instance-index=-1
        python get_instance_ip.py myservice --head --timeout=300 --retry-interval=5

    Returns:
        Does not return normally: prints the result to stdout and exits with status 0 on
        success, or exits with status 1 on failure.
    """

    parser = argparse.ArgumentParser(description="Get IP address of a service instance")
    group = parser.add_mutually_exclusive_group()
    parser.add_argument("service_name", help="Name of the service")
    group.add_argument(
        "--instance-index",
        type=int,
        default=-1,
        help="Index of service instance (default: -1 for self instance)",
    )
    group.add_argument(
        "--head",
        action="store_true",
        help="Get head instance information using show services",
    )
    parser.add_argument("--timeout", type=int, default=720, help="Timeout in seconds (default: 720)")
    parser.add_argument(
        "--retry-interval",
        type=int,
        default=10,
        help="Retry interval in seconds (default: 10)",
    )

    args = parser.parse_args()

    if args.head:
        # A monotonic clock keeps wall-clock adjustments from stretching or cutting the timeout.
        deadline = time.monotonic() + args.timeout
        while time.monotonic() < deadline:
            head_info = get_first_instance(args.service_name)
            if head_info:
                # Print to stdout to allow capture but don't use logger
                sys.stdout.write(f"{head_info[0]} {head_info[1]}\n")
                sys.exit(0)
            # Never sleep past the deadline on the final retry.
            remaining = max(0.0, deadline - time.monotonic())
            time.sleep(min(args.retry_interval, remaining))
        # If we get here, we've timed out
        logger.error("Error: Unable to retrieve head IP address")
        sys.exit(1)

    # If the index is -1, use get_self_ip to get the IP address of the current service
    if args.instance_index == -1:
        ip_address = get_self_ip()
        if ip_address:
            sys.stdout.write(f"{ip_address}\n")
            sys.exit(0)
        else:
            logger.error("Error: Unable to retrieve self IP address")
            sys.exit(1)
    else:
        # We don't support querying a specific instance index other than -1
        logger.error("Error: Invalid arguments. Only --instance-index=-1 is supported for now.")
        sys.exit(1)


if __name__ == "__main__":
    main()
|
@@ -0,0 +1,181 @@
|
|
1
|
+
import argparse
|
2
|
+
import copy
|
3
|
+
import importlib.util
|
4
|
+
import json
|
5
|
+
import os
|
6
|
+
import runpy
|
7
|
+
import sys
|
8
|
+
import traceback
|
9
|
+
import warnings
|
10
|
+
from pathlib import Path
|
11
|
+
from typing import Any, Optional
|
12
|
+
|
13
|
+
import cloudpickle
|
14
|
+
|
15
|
+
from snowflake.ml.jobs._utils import constants
|
16
|
+
from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
|
17
|
+
from snowflake.snowpark import Session
|
18
|
+
|
19
|
+
# Older SnowML builds may lack RESULT_PATH_ENV_VAR in constants, so fall back to the literal name.
RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")

# Where the launcher writes the pickled result; overridable through the environment variable above.
JOB_RESULT_PATH = os.getenv(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
|
23
|
+
|
24
|
+
|
25
|
+
try:
    from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
except ImportError:
    # interop_utils is missing in this SnowML version: define a minimal stand-in whose
    # to_dict() output has the same keys.
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ExecutionResult:  # type: ignore[no-redef]
        """Outcome of a script run: either its return value or the exception it raised."""

        result: Optional[Any] = None
        exception: Optional[BaseException] = None

        @property
        def success(self) -> bool:
            """True when no exception was captured."""
            return self.exception is None

        def to_dict(self) -> dict[str, Any]:
            """Return the serializable dictionary."""
            failure = self.exception
            if not isinstance(failure, BaseException):
                return {
                    "success": True,
                    "result_type": type(self.result).__qualname__,
                    "result": self.result,
                }
            failure_cls = type(failure)
            return {
                "success": False,
                "exc_type": f"{failure_cls.__module__}.{failure_cls.__name__}",
                "exc_value": failure,
                "exc_tb": "".join(traceback.format_tb(failure.__traceback__)),
            }
|
54
|
+
|
55
|
+
|
56
|
+
class SimpleJSONEncoder(json.JSONEncoder):
    """JSON encoder that degrades values the stdlib encoder rejects to their ``str()`` form."""

    def default(self, obj: Any) -> Any:
        try:
            encoded = super().default(obj)
        except TypeError:
            # Not natively JSON-serializable: fall back to the string representation.
            encoded = str(obj)
        return encoded
|
63
|
+
|
64
|
+
|
65
|
+
def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
    """
    Execute a Python script and return its result.

    Args:
        script_path: Path to the Python script
        script_args: Arguments to pass to the script
        main_func: The name of the function to call in the script (if any)

    Returns:
        Result from script execution, either from the main function or the script's __return__ value

    Raises:
        RuntimeError: If the script cannot be loaded as a module, or if the specified main_func
            is not found or not callable
    """
    # Present the script with its own argv for the duration of the run; restored in `finally`.
    original_argv = sys.argv
    sys.argv = [script_path, *script_args]

    try:
        # Create a Snowpark session before running the script. Doing it inside the try means a
        # connection failure still restores sys.argv.
        # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
        session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841

        if main_func:
            # Use importlib for scripts with a main function defined. The module is loaded under
            # its file stem, so its `if __name__ == "__main__"` block does not run.
            module_name = Path(script_path).stem
            spec = importlib.util.spec_from_file_location(module_name, script_path)
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if spec is None or spec.loader is None:
                raise RuntimeError(f"Unable to load {script_path} as a Python module")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # Validate main function (callable(None) is False, so a missing name is caught too)
            func = getattr(module, main_func, None)
            if not callable(func):
                raise RuntimeError(f"Function '{main_func}' not a valid entrypoint for {script_path}")

            return func(*script_args)

        # Use runpy for other scripts; a script can expose a result via a `__return__` global.
        globals_dict = runpy.run_path(script_path, run_name="__main__")
        return globals_dict.get("__return__", None)
    finally:
        # Restore original sys.argv
        sys.argv = original_argv
|
113
|
+
|
114
|
+
|
115
|
+
def _write_result(result_obj: ExecutionResult) -> None:
    """Best-effort persistence of result_obj to JOB_RESULT_PATH (cloudpickle) and a sibling .json."""
    result_dict = result_obj.to_dict()

    result_pickle_path = JOB_RESULT_PATH
    try:
        # Serialize result using cloudpickle
        with open(result_pickle_path, "wb") as f:
            cloudpickle.dump(result_dict, f)  # Pickle dictionary form for compatibility
    except Exception as pkl_exc:
        warnings.warn(f"Failed to pickle result to {result_pickle_path}: {pkl_exc}", RuntimeWarning, stacklevel=1)

    # Serialize result to JSON as fallback path in case of cross version incompatibility
    # TODO: Manually convert non-serializable types to strings
    result_json_path = os.path.splitext(JOB_RESULT_PATH)[0] + ".json"
    try:
        with open(result_json_path, "w") as f:
            json.dump(result_dict, f, indent=2, cls=SimpleJSONEncoder)
    except Exception as json_exc:
        warnings.warn(
            f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
        )


def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
    """Executes a Python script and serializes the result to JOB_RESULT_PATH.

    Args:
        script_path (str): Path to the Python script to execute.
        script_args (Any): Arguments to pass to the script.
        script_main_func (str, optional): The name of the function to call in the script (if any).

    Returns:
        ExecutionResult: Object containing execution results.

    Raises:
        Exception: Re-raises any exception caught during script execution. BaseExceptions that
            are not Exceptions (e.g. SystemExit from sys.exit()) propagate unchanged, and no
            result file is written for them.
    """
    # Stays None when the script ends via a non-Exception BaseException such as SystemExit.
    # Without this, `finally` would hit an unbound local and mask the script's real exit.
    result_obj: Optional[ExecutionResult] = None
    try:
        result = run_script(script_path, *script_args, main_func=script_main_func)
        result_obj = ExecutionResult(result=result)
        return result_obj
    except Exception as e:
        tb = e.__traceback__
        skip_files = {__file__, runpy.__file__}
        while tb and tb.tb_frame.f_code.co_filename in skip_files:
            # Skip any frames preceding user script execution
            tb = tb.tb_next
        try:
            # Need a mutable copy of the exception to set __traceback__ without touching `e`
            cleaned_ex = copy.copy(e).with_traceback(tb)
        except Exception:
            # Some exception types can't be reconstructed by copy.copy (custom __init__
            # signatures); record the original rather than losing the failure.
            cleaned_ex = e
        result_obj = ExecutionResult(exception=cleaned_ex)
        raise
    finally:
        if result_obj is not None:
            _write_result(result_obj)


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
    parser.add_argument("script_path", help="Path to the Python script to execute")
    parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
    parser.add_argument(
        "--script_main_func", required=False, help="The name of the main function to call in the script"
    )
    # Unrecognized options are forwarded to the script. They are appended after the positional
    # script_args, so their original interleaving with positionals is not preserved.
    args, unknown_args = parser.parse_known_args()

    main(
        args.script_path,
        *args.script_args,
        *unknown_args,
        script_main_func=args.script_main_func,
    )
|