snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (68)
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
@@ -3,14 +3,14 @@ import importlib
  import inspect
  import io
  import itertools
- import keyword
  import logging
  import pickle
  import sys
  import textwrap
  from importlib.abc import Traversable
  from pathlib import Path, PurePath
- from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
+ from types import ModuleType
+ from typing import IO, Any, Callable, Optional, Union, cast, get_args, get_origin

  import cloudpickle as cp
  from packaging import version
@@ -25,279 +25,109 @@ from snowflake.ml.jobs._utils import (
  )
  from snowflake.snowpark import exceptions as sp_exceptions
  from snowflake.snowpark._internal import code_generation
+ from snowflake.snowpark._internal.utils import zip_file_or_directory_to_stream

  logger = logging.getLogger(__name__)

  cp.register_pickle_by_value(function_payload_utils)
-
+ ImportType = Union[str, Path, ModuleType]

  _SUPPORTED_ARG_TYPES = {str, int, float}
  _SUPPORTED_ENTRYPOINT_EXTENSIONS = {".py"}
  _ENTRYPOINT_FUNC_NAME = "func"
  _STARTUP_SCRIPT_PATH = PurePath("startup.sh")
- _STARTUP_SCRIPT_CODE = textwrap.dedent(
-     f"""
-     #!/bin/bash
-
-     ##### Get system scripts directory #####
-     SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)
-
-     ##### Perform common set up steps #####
-     set -e # exit if a command fails
-
-     echo "Creating log directories..."
-     mkdir -p /var/log/managedservices/user/mlrs
-     mkdir -p /var/log/managedservices/system/mlrs
-     mkdir -p /var/log/managedservices/system/ray
-
-     echo "*/1 * * * * root /etc/ray_copy_cron.sh" >> /etc/cron.d/ray_copy_cron
-     echo "" >> /etc/cron.d/ray_copy_cron
-     chmod 744 /etc/cron.d/ray_copy_cron
-
-     service cron start
-
-     mkdir -p /tmp/prometheus-multi-dir
-
-     # Change directory to user payload directory
-     if [ -n "${constants.PAYLOAD_DIR_ENV_VAR}" ]; then
-         cd ${constants.STAGE_MOUNT_PATH_ENV_VAR}/${constants.PAYLOAD_DIR_ENV_VAR}
-     fi
-
-     ##### Set up Python environment #####
-     export PYTHONPATH=/opt/env/site-packages/
-     MLRS_SYSTEM_REQUIREMENTS_FILE=${{MLRS_SYSTEM_REQUIREMENTS_FILE:-"${{SYSTEM_DIR}}/requirements.txt"}}
-
-     if [ -f "${{MLRS_SYSTEM_REQUIREMENTS_FILE}}" ]; then
-         echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
-         if ! pip install --no-index -r $MLRS_SYSTEM_REQUIREMENTS_FILE; then
-             echo "Offline install failed, falling back to regular pip install"
-             pip install -r $MLRS_SYSTEM_REQUIREMENTS_FILE
-         fi
-     fi
-
-     MLRS_REQUIREMENTS_FILE=${{MLRS_REQUIREMENTS_FILE:-"requirements.txt"}}
-     if [ -f "${{MLRS_REQUIREMENTS_FILE}}" ]; then
-         # TODO: Prevent collisions with MLRS packages using virtualenvs
-         echo "Installing packages from $MLRS_REQUIREMENTS_FILE"
-         pip install -r $MLRS_REQUIREMENTS_FILE
-     fi
-
-     MLRS_CONDA_ENV_FILE=${{MLRS_CONDA_ENV_FILE:-"environment.yml"}}
-     if [ -f "${{MLRS_CONDA_ENV_FILE}}" ]; then
-         # TODO: Handle conda environment
-         echo "Custom conda environments not currently supported"
-         exit 1
-     fi
-     ##### End Python environment setup #####
-
-     ##### Ray configuration #####
-     shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
-
-     # Check if the local get_instance_ip.py script exists
-     HELPER_EXISTS=$(
-         [ -f "${{SYSTEM_DIR}}/get_instance_ip.py" ] && echo "true" || echo "false"
-     )
-

-     # Configure IP address and logging directory
-     if [ "$HELPER_EXISTS" = "true" ]; then
-         eth0Ip=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" \
-             "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
-     else
-         eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
-     fi
-     log_dir="/tmp/ray"
-
-     # Check if eth0Ip is a valid IP address and fall back to default if necessary
-     if [[ ! $eth0Ip =~ ^[0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+$ ]]; then
-         eth0Ip="127.0.0.1"
-     fi
-
-     # Get the environment values of SNOWFLAKE_JOBS_COUNT and SNOWFLAKE_JOB_INDEX for batch jobs
-     # These variables don't exist for non-batch jobs, so set defaults
-     if [ -z "$SNOWFLAKE_JOBS_COUNT" ]; then
-         SNOWFLAKE_JOBS_COUNT=1
-     fi
-
-     if [ -z "$SNOWFLAKE_JOB_INDEX" ]; then
-         SNOWFLAKE_JOB_INDEX=0
-     fi
-
-     # Determine if it should be a worker or a head node for batch jobs
-     if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 && "$HELPER_EXISTS" = "true" ]]; then
-         head_info=$(python3 "${{SYSTEM_DIR}}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head)
-         if [ $? -eq 0 ]; then
-             # Parse the output using read
-             read head_index head_ip head_status<<< "$head_info"
-
-             if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
-                 NODE_TYPE="worker"
-                 echo "{constants.LOG_START_MSG}"
-             fi
-
-             # Use the parsed variables
-             echo "Head Instance Index: $head_index"
-             echo "Head Instance IP: $head_ip"
-             echo "Head Instance Status: $head_status"
-
-             # If the head status is not "READY" or "PENDING", exit early
-             if [ "$head_status" != "READY" ] && [ "$head_status" != "PENDING" ]; then
-                 echo "Head instance status is not READY or PENDING. Exiting."
-                 exit 0
-             fi
-
-         else
-             echo "Error: Failed to get head instance information."
-             echo "$head_info" # Print the error message
-             exit 1
-         fi
-
-
-     fi
-
-     # Common parameters for both head and worker nodes
-     common_params=(
-         "--node-ip-address=$eth0Ip"
-         "--object-manager-port=${{RAY_OBJECT_MANAGER_PORT:-12011}}"
-         "--node-manager-port=${{RAY_NODE_MANAGER_PORT:-12012}}"
-         "--runtime-env-agent-port=${{RAY_RUNTIME_ENV_AGENT_PORT:-12013}}"
-         "--dashboard-agent-grpc-port=${{RAY_DASHBOARD_AGENT_GRPC_PORT:-12014}}"
-         "--dashboard-agent-listen-port=${{RAY_DASHBOARD_AGENT_LISTEN_PORT:-12015}}"
-         "--min-worker-port=${{RAY_MIN_WORKER_PORT:-12031}}"
-         "--max-worker-port=${{RAY_MAX_WORKER_PORT:-13000}}"
-         "--metrics-export-port=11502"
-         "--temp-dir=$log_dir"
-         "--disable-usage-stats"
-     )
-
-     if [ "$NODE_TYPE" = "worker" ]; then
-         # Use head_ip as head address if it exists
-         if [ ! -z "$head_ip" ]; then
-             RAY_HEAD_ADDRESS="$head_ip"
-         fi
-
-         # If RAY_HEAD_ADDRESS is still empty, exit with an error
-         if [ -z "$RAY_HEAD_ADDRESS" ]; then
-             echo "Error: Failed to determine head node address using default instance-index=0"
-             exit 1
-         fi
-
-         if [ -z "$SERVICE_NAME" ]; then
-             SERVICE_NAME="$SNOWFLAKE_SERVICE_NAME"
-         fi
-
-         if [ -z "$RAY_HEAD_ADDRESS" ] || [ -z "$SERVICE_NAME" ]; then
-             echo "Error: RAY_HEAD_ADDRESS and SERVICE_NAME must be set."
-             exit 1
-         fi
-
-         # Additional worker-specific parameters
-         worker_params=(
-             "--address=${{RAY_HEAD_ADDRESS}}:12001" # Connect to head node
-             "--resources={{\\"${{SERVICE_NAME}}\\":1, \\"node_tag:worker\\":1}}" # Tag for node identification
-             "--object-store-memory=${{shm_size}}"
-         )

-         # Start Ray on a worker node - run in background
-         ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &
-
-         echo "Worker node started on address $eth0Ip. See more logs in the head node."
-
-         echo "{constants.LOG_END_MSG}"
-
-         # Start the worker shutdown listener in the background
-         echo "Starting worker shutdown listener..."
-         python "${{SYSTEM_DIR}}/worker_shutdown_listener.py"
-         WORKER_EXIT_CODE=$?
-
-         echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
-         exit $WORKER_EXIT_CODE
-     else
-         # Additional head-specific parameters
-         head_params=(
-             "--head"
-             "--port=${{RAY_HEAD_GCS_PORT:-12001}}" # Port of Ray (GCS server)
-             "--ray-client-server-port=${{RAY_HEAD_CLIENT_SERVER_PORT:-10001}}" # Rort for Ray Client Server
-             "--dashboard-host=${{NODE_IP_ADDRESS}}" # Host to bind the dashboard server
-             "--dashboard-grpc-port=${{RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}}" # Dashboard head to listen for grpc
-             "--dashboard-port=${{DASHBOARD_PORT}}" # Port to bind the dashboard server for debugging
-             "--resources={{\\"node_tag:head\\":1}}" # Resource tag for selecting head as coordinator
+ def _compress_and_upload_file(
+     session: snowpark.Session, source_path: Path, stage_path: PurePath, import_path: Optional[str] = None
+ ) -> None:
+     absolute_source_path = source_path.absolute()
+     leading_path = absolute_source_path.as_posix()[: -len(import_path)] if import_path else None
+     filename = f"{source_path.name}.zip" if source_path.is_dir() or source_path.suffix == ".py" else source_path.name
+     with zip_file_or_directory_to_stream(source_path.absolute().as_posix(), leading_path) as stream:
+         session.file.put_stream(
+             cast(IO[bytes], stream),
+             stage_path.joinpath(filename).as_posix(),
+             auto_compress=False,
+             overwrite=True,
          )

-         # Start Ray on the head node
-         ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
-
-         ##### End Ray configuration #####
-
-         # TODO: Monitor MLRS and handle process crashes
-         python -m web.ml_runtime_grpc_server &
-
-         # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
-         echo Running command: python "$@"
-
-         # Run user's Python entrypoint
-         python "$@"
-
-         # After the user's job completes, signal workers to shut down
-         echo "User job completed. Signaling workers to shut down..."
-         python "${{SYSTEM_DIR}}/signal_workers.py" --wait-time 15
-         echo "Head node job completed. Exiting."
-     fi
-     """
- ).strip()

+ def _upload_directory(session: snowpark.Session, source_path: Path, payload_stage_path: PurePath) -> None:
+     # Manually traverse the directory and upload each file, since Snowflake PUT
+     # can't handle directories. Reduce the number of PUT operations by using
+     # wildcard patterns to batch upload files with the same extension.
+     upload_path_patterns = set()
+     for p in source_path.rglob("*"):
+         if p.is_dir():
+             continue
+         # Skip python cache files
+         if "__pycache__" in p.parts or p.suffix == ".pyc":
+             continue
+         if p.name.startswith("."):
+             # Hidden files: use .* pattern for batch upload
+             if p.suffix:
+                 upload_path_patterns.add(p.parent.joinpath(f".*{p.suffix}"))
+             else:
+                 upload_path_patterns.add(p.parent.joinpath(".*"))
+         else:
+             # Regular files: use * pattern for batch upload
+             if p.suffix:
+                 upload_path_patterns.add(p.parent.joinpath(f"*{p.suffix}"))
+             else:
+                 upload_path_patterns.add(p)

- def resolve_path(path: str) -> types.PayloadPath:
-     try:
-         stage_path = stage_utils.StagePath(path)
-     except ValueError:
-         return Path(path)
-     return stage_path
+     for path in upload_path_patterns:
+         session.file.put(
+             str(path),
+             payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
+             overwrite=True,
+             auto_compress=False,
+         )


  def upload_payloads(session: snowpark.Session, stage_path: PurePath, *payload_specs: types.PayloadSpec) -> None:
-     for source_path, remote_relative_path in payload_specs:
+     for spec in payload_specs:
+         source_path = spec.source_path
+         remote_relative_path = spec.remote_relative_path
+         compress = spec.compress
          payload_stage_path = stage_path.joinpath(remote_relative_path) if remote_relative_path else stage_path
          if isinstance(source_path, stage_utils.StagePath):
              # only copy files into one stage directory from another stage directory, not from stage file
              # due to incomplete of StagePath functionality
-             session.sql(f"copy files into {payload_stage_path.as_posix()}/ from {source_path.as_posix()}/").collect()
+             if source_path.as_posix().endswith(".py"):
+                 session.sql(f"copy files into {stage_path.as_posix()}/ from {source_path.as_posix()}").collect()
+             else:
+                 session.sql(
+                     f"copy files into {payload_stage_path.as_posix()}/ from {source_path.as_posix()}/"
+                 ).collect()
          elif isinstance(source_path, Path):
              if source_path.is_dir():
-                 # Manually traverse the directory and upload each file, since Snowflake PUT
-                 # can't handle directories. Reduce the number of PUT operations by using
-                 # wildcard patterns to batch upload files with the same extension.
-                 upload_path_patterns = set()
-                 for p in source_path.rglob("*"):
-                     if p.is_dir():
-                         continue
-                     if p.name.startswith("."):
-                         # Hidden files: use .* pattern for batch upload
-                         if p.suffix:
-                             upload_path_patterns.add(p.parent.joinpath(f".*{p.suffix}"))
-                         else:
-                             upload_path_patterns.add(p.parent.joinpath(".*"))
-                     else:
-                         # Regular files: use * pattern for batch upload
-                         if p.suffix:
-                             upload_path_patterns.add(p.parent.joinpath(f"*{p.suffix}"))
-                         else:
-                             upload_path_patterns.add(p)
-
-                 for path in upload_path_patterns:
+                 if compress:
+                     _compress_and_upload_file(
+                         session,
+                         source_path,
+                         stage_path,
+                         remote_relative_path.as_posix() if remote_relative_path else None,
+                     )
+                 else:
+                     _upload_directory(session, source_path, payload_stage_path)
+
+             elif source_path.is_file():
+                 if compress and source_path.suffix == ".py":
+                     _compress_and_upload_file(
+                         session,
+                         source_path,
+                         stage_path,
+                         remote_relative_path.as_posix() if remote_relative_path else None,
+                     )
+                 else:
                      session.file.put(
-                         str(path),
-                         payload_stage_path.joinpath(path.parent.relative_to(source_path)).as_posix(),
+                         str(source_path.resolve()),
+                         payload_stage_path.as_posix(),
                          overwrite=True,
                          auto_compress=False,
                      )
-         else:
-             session.file.put(
-                 str(source_path.resolve()),
-                 payload_stage_path.as_posix(),
-                 overwrite=True,
-                 auto_compress=False,
-             )


  def upload_system_resources(session: snowpark.Session, stage_path: PurePath) -> None:
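
For reference, the batching step in the new _upload_directory helper can be sketched on its own. This is a minimal illustration using only pathlib against a hypothetical local directory (the real helper then issues one session.file.put per pattern):

    from pathlib import Path

    def build_put_patterns(source_path: Path) -> set:
        """One wildcard per (directory, extension), mirroring _upload_directory above."""
        patterns = set()
        for p in source_path.rglob("*"):
            if p.is_dir() or "__pycache__" in p.parts or p.suffix == ".pyc":
                continue  # skip directories and Python bytecode caches, as the diff does
            if p.name.startswith("."):
                # hidden files are batched with a leading-dot wildcard
                patterns.add(p.parent / (f".*{p.suffix}" if p.suffix else ".*"))
            elif p.suffix:
                patterns.add(p.parent / f"*{p.suffix}")
            else:
                patterns.add(p)  # extensionless regular file: uploaded individually
        return patterns

A directory holding a.py, b.py, and data.csv collapses to the patterns *.py and *.csv, i.e. two PUT operations instead of three.
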
@@ -336,8 +166,30 @@ def resolve_source(

  def resolve_entrypoint(
      source: Union[types.PayloadPath, Callable[..., Any]],
-     entrypoint: Optional[types.PayloadPath],
- ) -> types.PayloadEntrypoint:
+     entrypoint: Optional[Union[types.PayloadPath, list[str]]],
+ ) -> Union[types.PayloadEntrypoint, list[str]]:
+     """Resolve and validate the entrypoint for a job payload.
+
+     Args:
+         source: The source path or callable for the job payload.
+         entrypoint: The entrypoint specification. Can be:
+             - A path (str or Path) to a Python script file
+             - A list of strings representing a custom command (passed through as-is)
+             - None (inferred from source if source is a file)
+
+     Returns:
+         Either a PayloadEntrypoint object for file-based entrypoints, or the list
+         of strings passed through unchanged for custom command entrypoints.
+
+     Raises:
+         ValueError: If the entrypoint is invalid or cannot be resolved.
+         FileNotFoundError: If the entrypoint file does not exist.
+     """
+     # If entrypoint is a list, pass it through without resolution/validation
+     # This allows users to specify custom entrypoints (e.g., installed CLI tools)
+     if isinstance(entrypoint, (list, tuple)):
+         return entrypoint
+
      if callable(source):
          # Entrypoint is generated for callable payloads
          return types.PayloadEntrypoint(
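
As the docstring above spells out, an entrypoint may now be either a script path or a command list that resolve_entrypoint forwards untouched. A hedged sketch of the two call shapes via the JobPayload class introduced later in this diff (the project layout and CLI flags are hypothetical; "arctic_training" is the example the docstrings themselves use):

    from snowflake.ml.jobs._utils.payload_utils import JobPayload

    # File-based entrypoint: resolved and validated against the payload source.
    payload = JobPayload("./my_project", entrypoint="my_project/train.py")

    # Custom-command entrypoint: the list bypasses resolution entirely, so an
    # installed console script can act as the job command.
    payload = JobPayload("./my_project", entrypoint=["arctic_training", "--config", "cfg.yaml"])
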
@@ -385,87 +237,223 @@
      )


- def resolve_additional_payloads(
-     additional_payloads: Optional[list[Union[str, tuple[str, str]]]]
- ) -> list[types.PayloadSpec]:
+ def get_zip_file_from_path(path: types.PayloadPath) -> types.PayloadPath:
+     """Finds the path of the outermost zip archive from a given file path.
+
+     Examples:
+         >>> get_zip_file_from_path("/path/to/archive.zip/nested_file.py")
+         "/path/to/archive.zip"
+         >>> get_zip_file_from_path("/path/to/archive.zip")
+         "/path/to/archive.zip"
+         >>> get_zip_file_from_path("/path/to/regular_file.py")
+         "/path/to/regular_file.py"
+
+     Args:
+         path: The file path to inspect.
+
+     Returns:
+         str: The path to the outermost zip file, or the original path if
+             none is found.
      """
-     Determine how to stage local packages so that imports continue to work.
+
+     path_str = path.as_posix()
+
+     index = path_str.rfind(".zip/")
+     if index != -1:
+         return stage_utils.resolve_path(path_str[: index + 4])
+     return path
+
+
+ def _finalize_payload_pair(
+     p: types.PayloadPath, base_import_path: Optional[str]
+ ) -> tuple[types.PayloadPath, Optional[str]]:
+     """Finalize the `(payload_path, import_path)` pair based on source type.
+
+     - Zip file: ignore import path (returns `(p, None)`).
+     - Python file: if `base_import_path` is provided, append ".py"; otherwise None.
+     - Directory: preserve `base_import_path` as-is.
+     - Stage file: use `base_import_path` as-is since we do not compress stage files.
+     - Other files: ignore import path (None).

      Args:
-         additional_payloads: A list of directory paths, each optionally paired with a dot-separated
-             import path
-             e.g. [("proj/src/utils", "src.utils"), "proj/src/helper"]
-             if there is no import path, the last part of path will be considered as import path
-             e.g. the import path of "proj/src/helper" is "helper"
+         p (types.PayloadPath): The resolved source path
+         base_import_path (Optional[str]): Slash-separated import path

      Returns:
-         A list of payloadSpec for additional payloads.
+         tuple[types.PayloadPath, Optional[str]]: `(p, final_import_path)` where:
+             - `final_import_path` is None for zip archives and non-Python files.
+             - `final_import_path` is `base_import_path + ".py"` for Python files when
+               `base_import_path` is provided; otherwise None.
+             - `final_import_path` is `base_import_path` for directories.
+
+     """
+     if p.suffix == ".zip":
+         final_import_path = None
+     elif isinstance(p, stage_utils.StagePath):
+         final_import_path = base_import_path
+     elif p.is_file():
+         if p.suffix == ".py":
+             final_import_path = (base_import_path + ".py") if base_import_path else None
+         else:
+             final_import_path = None
+     else:
+         final_import_path = base_import_path
+
+     validate_import_path(p, final_import_path)
+     return (p, None) if p.suffix == ".zip" else (p, final_import_path)

-     Raises:
-         FileNotFoundError: If any specified package path does not exist.
-         ValueError: If the format of local_packages is invalid.

+ def resolve_import_path(
+     path: Union[types.PayloadPath, ModuleType],
+     import_path: Optional[str] = None,
+ ) -> list[tuple[types.PayloadPath, Optional[str]]]:
      """
-     if not additional_payloads:
-         return []
+     Resolve and normalize the import path for modules, Python files, or zip payloads.

-     logger.warning(
-         "When providing a stage path as an additional payload, "
-         "please ensure it points to a directory. "
-         "Files are not currently supported."
-     )
+     Args:
+         path (Union[types.PayloadPath, ModuleType]): The source path or module to resolve.
+             - If a directory is provided, it is compressed as a zip archive preserving its structure.
+             - If a single Python file is provided, the file itself is zipped.
+             - If a module is provided, it is treated as a directory or Python file.
+             - If a zip file is provided, it is uploaded as it is.
+             - If a stage file is provided, we only support stage file when the import path is provided
+         import_path (Optional[str], optional): Explicit import path to use. If None,
+             the function infers it from `path`.

-     additional_payloads_paths = []
-     for pkg in additional_payloads:
-         if isinstance(pkg, str):
-             source_path = resolve_path(pkg).absolute()
-             module_path = source_path.name
-         elif isinstance(pkg, tuple):
-             try:
-                 source_path_str, module_path = pkg
-             except ValueError:
-                 raise ValueError(
-                     f"Invalid format in `additional_payloads`. "
-                     f"Expected a tuple of (source_path, module_path). Got {pkg}"
-                 )
-             source_path = resolve_path(source_path_str).absolute()
+     Returns:
+         list[tuple[types.PayloadPath, Optional[str]]]: A list of tuples where each tuple
+             contains the resolved payload path and its corresponding import path (if any).
+
+     Raises:
+         FileNotFoundError: If the provided `path` does not exist.
+         NotImplementedError: If the stage file is provided without an import path.
+         ValueError: If the import path cannot be resolved or is invalid.
+     """
+     if import_path is None:
+         import_path = path.stem if isinstance(path, types.PayloadPath) else path.__name__
+     import_path = import_path.strip().replace(".", "/") if import_path else None
+     if isinstance(path, Path):
+         if not path.exists():
+             raise FileNotFoundError(f"{path} is not found")
+         return [_finalize_payload_pair(path.absolute(), import_path)]
+     elif isinstance(path, stage_utils.StagePath):
+         if import_path:
+             return [_finalize_payload_pair(path.absolute(), import_path)]
+         raise NotImplementedError("We only support stage file when the import path is provided")
+     elif isinstance(path, ModuleType):
+         if hasattr(path, "__path__"):
+             paths = [get_zip_file_from_path(stage_utils.resolve_path(p).absolute()) for p in path.__path__]
+             return [_finalize_payload_pair(p, import_path) for p in paths]
+         elif hasattr(path, "__file__") and path.__file__:
+             p = get_zip_file_from_path(Path(path.__file__).absolute())
+             return [_finalize_payload_pair(p, import_path)]
          else:
-             raise ValueError("the format of additional payload is not correct")
+             raise ValueError(f"Module {path} is not a valid module")
+     else:
+         raise ValueError(f"Module {path} is not a valid imports")

-         if not source_path.exists():
-             raise FileNotFoundError(f"{source_path} does not exist")

-         if isinstance(source_path, Path):
-             if source_path.is_file():
-                 raise ValueError(f"file is not supported for additional payloads: {source_path}")
+ def validate_import_path(source: Union[str, types.PayloadPath], import_path: Optional[str]) -> None:
+     """Validate the import path for local python file or directory."""
+     if import_path is None:
+         return

-         module_parts = module_path.split(".")
-         for part in module_parts:
-             if not part.isidentifier() or keyword.iskeyword(part):
-                 raise ValueError(
-                     f"Invalid module import path '{module_path}'. "
-                     f"'{part}' is not a valid Python identifier or is a keyword."
-                 )
+     source_path = stage_utils.resolve_path(source) if isinstance(source, str) else source
+     if isinstance(source_path, stage_utils.StagePath):
+         if not source_path.as_posix().endswith(import_path + ".py"):
+             raise ValueError(f"Import path {import_path} must end with the source name {source_path}")
+     elif (source_path.is_file() and source_path.suffix == ".py") or source_path.is_dir():
+         if not source_path.as_posix().endswith(import_path):
+             raise ValueError(f"Import path {import_path} must end with the source name {source_path}")

-         dest_path = PurePath(*module_parts)
-         additional_payloads_paths.append(types.PayloadSpec(source_path, dest_path))
-     return additional_payloads_paths
+
+ def upload_imports(
+     imports: Optional[list[Union[str, Path, ModuleType, tuple[Union[str, Path, ModuleType], Optional[str]]]]],
+     session: snowpark.Session,
+     stage_path: PurePath,
+ ) -> None:
+     """Resolve paths and upload imports for ML Jobs.
+
+     Args:
+         imports: Optional list of paths/modules, or tuples of
+             ``(path_or_module, import_path)``. The path can be a local
+             directory, a local ``.py`` file, a local ``.zip`` file, or a stage
+             path (for example, ``@stage/path``). If a tuple is provided and the
+             first element is a local directory or ``.py`` file, the second
+             element denotes the Python import path (dot or slash separated) to
+             which the content should be mounted. If not provided for local
+             sources, it defaults to the stem of the path/module. For stage
+             paths or non-Python local files, the import path is ignored.
+         session: Active Snowpark session used to upload files.
+         stage_path: Destination stage subpath where payloads will be uploaded.
+
+     Raises:
+         ValueError: If a import has an invalid format or the
+             provided import path is incompatible with the source.
+
+     """
+     if not imports:
+         return
+     for additional_payload in imports:
+         if isinstance(additional_payload, tuple):
+             source, import_path = additional_payload
+         elif isinstance(additional_payload, str) or isinstance(additional_payload, ModuleType):
+             source = additional_payload
+             import_path = None
+         else:
+             raise ValueError(f"Invalid import format: {additional_payload}")
+         resolved_imports = resolve_import_path(
+             stage_utils.resolve_path(source) if not isinstance(source, ModuleType) else source, import_path
+         )
+         for source_path, import_path in resolved_imports:
+             # TODO(SNOW-2467038): support import path for stage files or directories
+             if isinstance(source_path, stage_utils.StagePath):
+                 remote = None
+                 compress = False
+             elif source_path.as_posix().endswith(".zip"):
+                 remote = None
+                 compress = False
+             elif source_path.is_dir() or source_path.suffix == ".py":
+                 remote = PurePath(import_path) if import_path else None
+                 compress = True
+             else:
+                 # if the file is not a python file, ignore the import path
+                 remote = None
+                 compress = False
+
+             upload_payloads(session, stage_path, types.PayloadSpec(source_path, remote, compress=compress))


  class JobPayload:
      def __init__(
          self,
          source: Union[str, Path, Callable[..., Any]],
-         entrypoint: Optional[Union[str, Path]] = None,
+         entrypoint: Optional[Union[str, Path, list[str]]] = None,
          *,
          pip_requirements: Optional[list[str]] = None,
-         additional_payloads: Optional[list[Union[str, tuple[str, str]]]] = None,
+         imports: Optional[list[Union[ImportType, tuple[ImportType, Optional[str]]]]] = None,
      ) -> None:
+         """Initialize a job payload.
+
+         Args:
+             source: The source for the job payload. Can be a file path, directory path,
+                 stage path, or a callable.
+             entrypoint: The entrypoint for job execution. Can be:
+                 - A path (str or Path) to a Python script file
+                 - A list of strings representing a custom command (e.g., ["arctic_training"])
+                   which is passed through as-is without resolution or validation
+                 - None (inferred from source if source is a file)
+             pip_requirements: A list of pip requirements for the job.
+             imports: A list of additional imports/payloads for the job.
+         """
          # for stage path like snow://domain....., Path(path) will remove duplicate /, it will become snow:/ domain...
-         self.source = resolve_path(source) if isinstance(source, str) else source
-         self.entrypoint = resolve_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
+         self.source = stage_utils.resolve_path(source) if isinstance(source, str) else source
+         if isinstance(entrypoint, list):
+             self.entrypoint: Optional[Union[types.PayloadPath, list[str]]] = entrypoint
+         else:
+             self.entrypoint = stage_utils.resolve_path(entrypoint) if isinstance(entrypoint, str) else entrypoint
          self.pip_requirements = pip_requirements
-         self.additional_payloads = additional_payloads
+         self.imports = imports

      def upload(self, session: snowpark.Session, stage_path: Union[str, PurePath]) -> types.UploadedPayload:
@@ -473,7 +461,6 @@ class JobPayload:
          source = resolve_source(self.source)
          entrypoint = resolve_entrypoint(source, self.entrypoint)
          pip_requirements = self.pip_requirements or []
-         additional_payload_specs = resolve_additional_payloads(self.additional_payloads)

          # Create stage if necessary
          stage_name = stage_path.parts[0].lstrip("@")
@@ -488,70 +475,81 @@ class JobPayload:
              " comment = 'Created by snowflake.ml.jobs Python API'",
              params=[stage_name],
          )
-         payload_name = None
+
          # Upload payload to stage - organize into app/ subdirectory
          app_stage_path = stage_path.joinpath(constants.APP_STAGE_SUBPATH)
-         if not isinstance(source, types.PayloadPath):
-             if isinstance(source, function_payload_utils.FunctionPayload):
-                 payload_name = source.function.__name__
-
-             source_code = generate_python_code(source, source_code_display=True)
-             _ = session.file.put_stream(
-                 io.BytesIO(source_code.encode()),
-                 stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
+         upload_imports(self.imports, session, app_stage_path)
+
+         # Handle list entrypoints (custom commands like ["arctic_training"])
+         if isinstance(entrypoint, (list, tuple)):
+             payload_name = entrypoint[0] if entrypoint else None
+             # For list entrypoints, still upload source if it's a path
+             if isinstance(source, Path):
+                 upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+             elif isinstance(source, stage_utils.StagePath):
+                 upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+             python_entrypoint: list[Union[str, PurePath]] = list(entrypoint)
+         else:
+             # Standard file-based entrypoint handling
+             payload_name = None
+             if not isinstance(source, types.PayloadPath):
+                 if isinstance(source, function_payload_utils.FunctionPayload):
+                     payload_name = source.function.__name__
+
+                 source_code = generate_python_code(source, source_code_display=True)
+                 _ = session.file.put_stream(
+                     io.BytesIO(source_code.encode()),
+                     stage_location=app_stage_path.joinpath(entrypoint.file_path).as_posix(),
+                     auto_compress=False,
+                     overwrite=True,
+                 )
+                 source = Path(entrypoint.file_path.parent)
+
+             elif isinstance(source, stage_utils.StagePath):
+                 payload_name = entrypoint.file_path.stem
+                 # copy payload to stage
+                 if source == entrypoint.file_path:
+                     source = source.parent
+                 upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+
+             elif isinstance(source, Path):
+                 payload_name = entrypoint.file_path.stem
+                 upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
+                 if source.is_file():
+                     source = source.parent
+
+             python_entrypoint = [
+                 PurePath(
+                     constants.STAGE_VOLUME_MOUNT_PATH,
+                     constants.APP_STAGE_SUBPATH,
+                     entrypoint.file_path.relative_to(source).as_posix(),
+                 ),
+             ]
+             if entrypoint.main_func:
+                 python_entrypoint += ["--script_main_func", entrypoint.main_func]
+
+         if pip_requirements:
+             session.file.put_stream(
+                 io.BytesIO("\n".join(pip_requirements).encode()),
+                 stage_location=app_stage_path.joinpath("requirements.txt").as_posix(),
                  auto_compress=False,
                  overwrite=True,
              )
-             source = Path(entrypoint.file_path.parent)
-
-         elif isinstance(source, stage_utils.StagePath):
-             payload_name = entrypoint.file_path.stem
-             # copy payload to stage
-             if source == entrypoint.file_path:
-                 source = source.parent
-             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))

-         elif isinstance(source, Path):
-             payload_name = entrypoint.file_path.stem
-             upload_payloads(session, app_stage_path, types.PayloadSpec(source, None))
-             if source.is_file():
-                 source = source.parent
-
-         upload_payloads(session, app_stage_path, *additional_payload_specs)
-
-         if not any(r.startswith("cloudpickle") for r in pip_requirements):
-             pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
-
-         # Upload system scripts and requirements.txt generated by pip_requirements to system/ directory
+         # Upload system scripts and other assets to system/ directory
          system_stage_path = stage_path.joinpath(constants.SYSTEM_STAGE_SUBPATH)
-         if pip_requirements:
-             # Upload requirements.txt to stage
+         system_pip_requirements = []
+         if not any(r.startswith("cloudpickle") for r in pip_requirements):
+             system_pip_requirements.append(f"cloudpickle~={version.parse(cp.__version__).major}.0")
+         if system_pip_requirements:
+             # Upload requirements.txt to system path in stage
              session.file.put_stream(
-                 io.BytesIO("\n".join(pip_requirements).encode()),
+                 io.BytesIO("\n".join(system_pip_requirements).encode()),
                  stage_location=system_stage_path.joinpath("requirements.txt").as_posix(),
                  auto_compress=False,
                  overwrite=True,
              )
-
-         # TODO: Make sure payload does not include file with same name
-         session.file.put_stream(
-             io.BytesIO(_STARTUP_SCRIPT_CODE.encode()),
-             stage_location=system_stage_path.joinpath(_STARTUP_SCRIPT_PATH).as_posix(),
-             auto_compress=False,
-             overwrite=False, # FIXME
-         )
-
          upload_system_resources(session, system_stage_path)
-         python_entrypoint: list[Union[str, PurePath]] = [
-             PurePath(constants.STAGE_VOLUME_MOUNT_PATH, constants.SYSTEM_STAGE_SUBPATH, "mljob_launcher.py"),
-             PurePath(
-                 constants.STAGE_VOLUME_MOUNT_PATH,
-                 constants.APP_STAGE_SUBPATH,
-                 entrypoint.file_path.relative_to(source).as_posix(),
-             ),
-         ]
-         if entrypoint.main_func:
-             python_entrypoint += ["--script_main_func", entrypoint.main_func]

          env_vars = {
              constants.STAGE_MOUNT_PATH_ENV_VAR: constants.STAGE_VOLUME_MOUNT_PATH,
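
Taken together, the rewritten upload method stages user code under the app/ subdirectory and launcher assets under system/, per the comments above. A rough end-to-end sketch under those assumptions (the session, stage name, and CLI are placeholders, and the exact fields of types.UploadedPayload are not shown in this diff):

    from snowflake.snowpark import Session

    session = Session.builder.getOrCreate()  # placeholder connection
    payload = JobPayload("./my_project", entrypoint=["my_cli", "--epochs", "2"])
    uploaded = payload.upload(session, "@my_stage/jobs/run1")
    # Imports and source land under @my_stage/jobs/run1/app/; system scripts and the
    # generated requirements.txt land under @my_stage/jobs/run1/system/.
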