snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
@@ -6,9 +6,13 @@ import logging
6
6
  import math
7
7
  import os
8
8
  import runpy
9
+ import shutil
10
+ import subprocess
9
11
  import sys
10
12
  import time
11
13
  import traceback
14
+ import zipfile
15
+ from pathlib import Path
12
16
  from typing import Any, Optional
13
17
 
14
18
  # Ensure payload directory is in sys.path for module imports before importing other modules
@@ -18,11 +22,17 @@ from typing import Any, Optional
18
22
  STAGE_MOUNT_PATH = os.environ.get("MLRS_STAGE_MOUNT_PATH", "/mnt/job_stage")
19
23
  JOB_RESULT_PATH = os.environ.get("MLRS_RESULT_PATH", "output/mljob_result.pkl")
20
24
  PAYLOAD_PATH = os.environ.get("MLRS_PAYLOAD_DIR")
25
+
21
26
  if PAYLOAD_PATH and not os.path.isabs(PAYLOAD_PATH):
22
27
  PAYLOAD_PATH = os.path.join(STAGE_MOUNT_PATH, PAYLOAD_PATH)
23
- if PAYLOAD_PATH and PAYLOAD_PATH not in sys.path:
24
- sys.path.insert(0, PAYLOAD_PATH)
25
28
 
29
+ if PAYLOAD_PATH:
30
+ if PAYLOAD_PATH not in sys.path:
31
+ sys.path.insert(0, PAYLOAD_PATH)
32
+ for zip_file in Path(PAYLOAD_PATH).rglob("*.zip"):
33
+ fpath = str(zip_file)
34
+ if fpath not in sys.path and zipfile.is_zipfile(fpath):
35
+ sys.path.insert(0, fpath)
26
36
  # Imports below must come after sys.path modification to support module overrides
27
37
  import snowflake.ml.jobs._utils.constants # noqa: E402
28
38
  import snowflake.snowpark # noqa: E402
@@ -81,6 +91,76 @@ TIMEOUT = float(os.getenv(INSTANCES_TIMEOUT_ENV_VAR) or 720) # seconds
81
91
  CHECK_INTERVAL = float(os.getenv(INSTANCES_CHECK_INTERVAL_ENV_VAR) or 10) # seconds
82
92
 
83
93
 
94
def is_python_script(file_path: str) -> bool:
    """Check whether a file looks like a Python script based on its shebang.

    Args:
        file_path: Path to the file to inspect.

    Returns:
        True if the file begins with a ``#!`` line that mentions 'python',
        False otherwise (including unreadable or missing files).
    """
    try:
        with open(file_path, "rb") as fh:
            header = fh.readline()
    except OSError:
        # Unreadable or missing files are simply not treated as Python scripts.
        return False
    if not header.startswith(b"#!"):
        return False
    # Decode leniently: the shebang may contain arbitrary bytes on disk.
    return "python" in header.decode("utf-8", errors="ignore").lower()
112
+
113
+
114
def resolve_entrypoint(entrypoint: str) -> tuple[str, bool]:
    """Determine how a job entrypoint should be executed.

    Args:
        entrypoint: The entrypoint string (file path or command name).

    Returns:
        A tuple of (resolved_path, is_python):
            - resolved_path: The path to the executable/script.
            - is_python: True if this should be run as a Python script.
    """
    # An existing file is always executed as a Python script, preserving
    # backward compatibility with the previous script-path-only behavior.
    if os.path.isfile(entrypoint):
        return entrypoint, True

    # Not a file on disk: try to locate it on PATH as an executable command.
    located = shutil.which(entrypoint)
    if located is None:
        # Unresolvable: fall back to treating it as a Python script path so
        # the downstream failure produces a clear error (backward compatible).
        return entrypoint, True

    if is_python_script(located):
        # A PATH-resolved executable with a python shebang: run via the
        # interpreter using its resolved location.
        return located, True

    # A PATH-resolved executable without a python shebang: run as a command.
    return entrypoint, False
142
+
143
+
144
def run_command(command: str, *args: Any) -> None:
    """Execute a command as a subprocess, streaming output and raising on failure.

    Args:
        command: Path to (or name of) the executable.
        args: Arguments to pass to the command; each is stringified.

    Raises:
        subprocess.CalledProcessError: If the subprocess exits with a non-zero
            return code.
    """
    cmd = [command, *(str(arg) for arg in args)]
    logger.debug(f"Running subprocess: {' '.join(cmd)}")

    # Run without capturing output so stdout/stderr flow directly to the
    # console in real time. check=True raises subprocess.CalledProcessError on
    # a non-zero exit, replacing the previous manual returncode check with the
    # idiomatic equivalent.
    subprocess.run(cmd, check=True)
162
+
163
+
84
164
  def save_mljob_result_v2(value: Any, is_error: bool, path: str) -> None:
85
165
  from snowflake.ml.jobs._interop import (
86
166
  results as interop_result,
@@ -313,11 +393,11 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
313
393
  sys.argv = original_argv
314
394
 
315
395
 
316
- def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
396
+ def main(entrypoint: str, *script_args: Any, script_main_func: Optional[str] = None) -> Any:
317
397
  """Executes a Python script and serializes the result to JOB_RESULT_PATH.
318
398
 
319
399
  Args:
320
- script_path (str): Path to the Python script to execute.
400
+ entrypoint (str): The job payload entrypoint to execute.
321
401
  script_args (Any): Arguments to pass to the script.
322
402
  script_main_func (str, optional): The name of the function to call in the script (if any).
323
403
 
@@ -361,8 +441,15 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
361
441
  # Log start marker before starting user script execution
362
442
  print(LOG_START_MSG) # noqa: T201
363
443
 
364
- # Run the user script
365
- execution_result_value = run_script(script_path, *script_args, main_func=script_main_func)
444
+ # Resolve entrypoint to determine execution method
445
+ resolved_entrypoint, is_python = resolve_entrypoint(entrypoint)
446
+
447
+ if is_python:
448
+ # Run as Python script
449
+ execution_result_value = run_script(resolved_entrypoint, *script_args, main_func=script_main_func)
450
+ else:
451
+ # Run as subprocess
452
+ run_command(resolved_entrypoint, *script_args)
366
453
 
367
454
  # Log end marker for user script execution
368
455
  print(LOG_END_MSG) # noqa: T201
@@ -395,7 +482,7 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
395
482
 
396
483
  if __name__ == "__main__":
397
484
  parser = argparse.ArgumentParser(description="Launch a Python script and save the result")
398
- parser.add_argument("script_path", help="Path to the Python script to execute")
485
+ parser.add_argument("entrypoint", help="The job payload entrypoint to execute")
399
486
  parser.add_argument("script_args", nargs="*", help="Arguments to pass to the script")
400
487
  parser.add_argument(
401
488
  "--script_main_func", required=False, help="The name of the main function to call in the script"
@@ -403,7 +490,7 @@ if __name__ == "__main__":
403
490
  args, unknown_args = parser.parse_known_args()
404
491
 
405
492
  main(
406
- args.script_path,
493
+ args.entrypoint,
407
494
  *args.script_args,
408
495
  *unknown_args,
409
496
  script_main_func=args.script_main_func,
@@ -0,0 +1,92 @@
1
#!/bin/bash
# Start the ML Runtime Ray node (head or worker, selected via $NODE_TYPE).

set -e # exit if a command fails

echo "Creating log directories..."
mkdir -p /var/log/managedservices/user/mlrs
mkdir -p /var/log/managedservices/system/mlrs
mkdir -p /var/log/managedservices/system/ray

# Install a cron job that periodically copies Ray logs to the managed log dirs.
echo "*/1 * * * * root /etc/ray_copy_cron.sh" >> /etc/cron.d/ray_copy_cron
echo "" >> /etc/cron.d/ray_copy_cron
chmod 744 /etc/cron.d/ray_copy_cron

service cron start

mkdir -p /tmp/prometheus-multi-dir

# Configure IP address and logging directory
eth0Ip=$(ifconfig eth0 | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
log_dir="/tmp/ray"

# Check if eth0Ip is empty and set default if necessary
if [ -z "$eth0Ip" ]; then
    # This should never happen, but just in case eth0Ip is not set, we should default to localhost
    eth0Ip="127.0.0.1"
fi

shm_size=$(df --output=size --block-size=1 /dev/shm | tail -n 1)
total_memory_size=$(awk '/MemTotal/ {print int($2/1024)}' /proc/meminfo)

# Determine if dashboard should be enabled based on total memory size
# Enable dashboard only if total memory size >= 8GB (i.e. not on XS compute pool)
# TODO (SNOW-2860029): use a environment variable to determine the node type
total_memory_threshold=8192
if [ "$total_memory_size" -ge "$total_memory_threshold" ]; then
    enable_dashboard="true"
else
    enable_dashboard="false"
fi

echo "Shared memory size: $shm_size bytes"
echo "Dashboard enabled: $enable_dashboard"

# Common parameters for both head and worker nodes
common_params=(
    "--node-ip-address=$eth0Ip"
    "--object-manager-port=${RAY_OBJECT_MANAGER_PORT:-12011}"
    "--node-manager-port=${RAY_NODE_MANAGER_PORT:-12012}"
    "--runtime-env-agent-port=${RAY_RUNTIME_ENV_AGENT_PORT:-12013}"
    "--dashboard-agent-grpc-port=${RAY_DASHBOARD_AGENT_GRPC_PORT:-12014}"
    "--dashboard-agent-listen-port=${RAY_DASHBOARD_AGENT_LISTEN_PORT:-12015}"
    "--min-worker-port=${RAY_MIN_WORKER_PORT:-12031}"
    "--max-worker-port=${RAY_MAX_WORKER_PORT:-13000}"
    "--metrics-export-port=11502"
    "--temp-dir=$log_dir"
    "--disable-usage-stats"
)

# Specific parameters for head and worker nodes
if [ "$NODE_TYPE" = "worker" ]; then
    # Check mandatory environment variables for worker
    if [ -z "$RAY_HEAD_ADDRESS" ] || [ -z "$SERVICE_NAME" ]; then
        echo "Error: RAY_HEAD_ADDRESS and SERVICE_NAME must be set."
        exit 1
    fi

    # Additional worker-specific parameters
    worker_params=(
        "--address=${RAY_HEAD_ADDRESS}:${RAY_HEAD_GCS_PORT:-12001}" # Connect to head node
        "--resources={\"${SERVICE_NAME}\":1, \"node_tag:worker\":1}" # Custom resource for node identification
        "--object-store-memory=${shm_size}"
    )

    # Start Ray on a worker node
    ray start "${common_params[@]}" "${worker_params[@]}" "$@" -v
else
    # Additional head-specific parameters.
    # NOTE: --disable-usage-stats is already supplied via common_params; the
    # duplicate that used to live here has been removed.
    head_params=(
        "--head"
        "--include-dashboard=$enable_dashboard"
        "--port=${RAY_HEAD_GCS_PORT:-12001}" # Port of Ray (GCS server)
        "--ray-client-server-port=${RAY_HEAD_CLIENT_SERVER_PORT:-10001}" # Listening port for Ray Client Server
        "--dashboard-host=${NODE_IP_ADDRESS}" # Host to bind the dashboard server
        "--dashboard-grpc-port=${RAY_HEAD_DASHBOARD_GRPC_PORT:-12002}" # Dashboard head to listen for grpc on
        "--dashboard-port=${DASHBOARD_PORT}" # Port to bind the dashboard server for local debugging
        "--resources={\"node_tag:head\":1}" # Resource tag for selecting head as coordinator
    )

    # Start Ray
    ray start "${common_params[@]}" "${head_params[@]}" "$@"
fi
@@ -0,0 +1,112 @@
1
#!/bin/bash
# Container entrypoint: set up the Python environment, configure multi-node
# Ray topology, start the ML Runtime, then run the user job (head) or wait for
# shutdown (worker).

set -e # exit if a command fails

# Get and change to system scripts directory
SYSTEM_DIR=$(cd "$(dirname "$0")" && pwd)

# Change directory to user payload directory
if [ -n "${MLRS_PAYLOAD_DIR}" ]; then
    cd "${MLRS_STAGE_MOUNT_PATH}/${MLRS_PAYLOAD_DIR}"
fi

##### Set up Python environment #####
export PYTHONPATH=/opt/env/site-packages/
MLRS_SYSTEM_REQUIREMENTS_FILE=${MLRS_SYSTEM_REQUIREMENTS_FILE:-"${SYSTEM_DIR}/requirements.txt"}
if [ -f "${MLRS_SYSTEM_REQUIREMENTS_FILE}" ]; then
    echo "Installing packages from $MLRS_SYSTEM_REQUIREMENTS_FILE"
    # Prefer the offline (pre-staged) wheels; fall back to a networked install.
    if ! pip install --no-index -r "$MLRS_SYSTEM_REQUIREMENTS_FILE"; then
        echo "Offline install failed, falling back to regular pip install"
        pip install -r "$MLRS_SYSTEM_REQUIREMENTS_FILE"
    fi
fi

MLRS_REQUIREMENTS_FILE=${MLRS_REQUIREMENTS_FILE:-"requirements.txt"}
if [ -f "${MLRS_REQUIREMENTS_FILE}" ]; then
    # TODO: Prevent collisions with MLRS packages using virtualenvs
    echo "Installing packages from $MLRS_REQUIREMENTS_FILE"
    pip install -r "$MLRS_REQUIREMENTS_FILE"
fi

MLRS_CONDA_ENV_FILE=${MLRS_CONDA_ENV_FILE:-"environment.yml"}
if [ -f "${MLRS_CONDA_ENV_FILE}" ]; then
    # TODO: Handle conda environment
    echo "Custom conda environments not currently supported"
    exit 1
fi
##### End Python environment setup #####

##### Set up multi-node configuration #####
# Configure IP address
if [ -f "${SYSTEM_DIR}/get_instance_ip.py" ]; then
    eth0Ip=$(python3 "${SYSTEM_DIR}/get_instance_ip.py" \
        "$SNOWFLAKE_SERVICE_NAME" --instance-index=-1)
else
    eth0Ip=$(ifconfig eth0 2>/dev/null | sed -En -e 's/.*inet ([0-9.]+).*/\1/p')
fi

# Check if eth0Ip is a valid IP address and fall back to default if necessary
if [[ ! $eth0Ip =~ ^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
    eth0Ip="127.0.0.1"
fi

# Set default values for job environment variables if they don't exist
# (e.g. some are only populated by SPCS for batch jobs, others just may not be set at all)
export SNOWFLAKE_JOBS_COUNT=${SNOWFLAKE_JOBS_COUNT:-1}
export SNOWFLAKE_JOB_INDEX=${SNOWFLAKE_JOB_INDEX:-0}
export SERVICE_NAME="${SERVICE_NAME:-$SNOWFLAKE_SERVICE_NAME}"

##### Ray configuration #####

# Determine if it should be a worker or a head node for batch jobs
if [[ "$SNOWFLAKE_JOBS_COUNT" -gt 1 ]]; then
    # BUGFIX: test the command substitution directly inside the `if`. The
    # previous `head_info=$(...)` followed by `[ $? -eq 0 ]` could never reach
    # the error branch: under `set -e` a failing substitution aborts the
    # script before the status check runs.
    if head_info=$(python3 "${SYSTEM_DIR}/get_instance_ip.py" "$SNOWFLAKE_SERVICE_NAME" --head); then
        # Parse the output using read (-r: do not mangle backslashes)
        read -r head_index head_ip head_status <<< "$head_info"

        if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
            NODE_TYPE="worker"
        fi

        # Use the parsed variables
        echo "Head Instance Index: $head_index"
        echo "Head Instance IP: $head_ip"
        echo "Head Instance Status: $head_status"

        # If the head status is not "READY" or "PENDING", exit early
        if [ "$head_status" != "READY" ] && [ "$head_status" != "PENDING" ]; then
            echo "Head instance status is not READY or PENDING. Exiting."
            exit 0
        fi

    else
        echo "Error: Failed to get head instance information."
        echo "$head_info" # Print the error message
        exit 1
    fi
fi

# Start ML Runtime. NOTE(review): `ray start` daemonizes, so this call is
# expected to return and let the script continue — confirm if runtime changes.
NODE_TYPE=$NODE_TYPE RAY_HEAD_ADDRESS="$head_ip" bash "${SYSTEM_DIR}/start_mlruntime.sh"

if [ "$NODE_TYPE" = "worker" ]; then
    echo "Worker node started on address $eth0Ip. See more logs in the head node."

    # Start the worker shutdown listener and capture its exit code without
    # tripping `set -e` (a plain `python ...; CODE=$?` would abort the script
    # on a non-zero exit before the code could be captured).
    echo "Starting worker shutdown listener..."
    WORKER_EXIT_CODE=0
    python "${SYSTEM_DIR}/worker_shutdown_listener.py" || WORKER_EXIT_CODE=$?

    echo "Worker shutdown listener exited with code $WORKER_EXIT_CODE"
    exit $WORKER_EXIT_CODE
else
    # Run user's Python entrypoint via mljob_launcher
    echo Running command: python "${SYSTEM_DIR}/mljob_launcher.py" "$@"
    python "${SYSTEM_DIR}/mljob_launcher.py" "$@"

    # After the user's job completes, signal workers to shut down
    echo "User job completed. Signaling workers to shut down..."
    python "${SYSTEM_DIR}/signal_workers.py" --wait-time 15
    echo "Head node job completed. Exiting."
fi