snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. snowflake/ml/_internal/telemetry.py +42 -13
  2. snowflake/ml/data/data_connector.py +1 -1
  3. snowflake/ml/jobs/_utils/constants.py +9 -0
  4. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  5. snowflake/ml/jobs/_utils/payload_utils.py +12 -4
  6. snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
  7. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +85 -2
  8. snowflake/ml/jobs/_utils/spec_utils.py +7 -5
  9. snowflake/ml/jobs/decorators.py +7 -3
  10. snowflake/ml/jobs/job.py +158 -25
  11. snowflake/ml/jobs/manager.py +29 -19
  12. snowflake/ml/model/_client/ops/service_ops.py +5 -3
  13. snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
  14. snowflake/ml/model/_client/sql/model_version.py +1 -1
  15. snowflake/ml/model/_client/sql/service.py +16 -19
  16. snowflake/ml/model/_model_composer/model_composer.py +3 -1
  17. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
  18. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
  19. snowflake/ml/monitoring/explain_visualize.py +160 -22
  20. snowflake/ml/utils/connection_params.py +8 -2
  21. snowflake/ml/version.py +1 -1
  22. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +27 -9
  23. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +26 -26
  24. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
  25. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  26. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0

snowflake/ml/_internal/telemetry.py
@@ -4,6 +4,7 @@ import enum
 import functools
 import inspect
 import operator
+import os
 import sys
 import time
 import traceback
@@ -13,7 +14,7 @@ from typing import Any, Callable, Iterable, Mapping, Optional, TypeVar, Union, c
 from typing_extensions import ParamSpec
 
 from snowflake import connector
-from snowflake.connector import telemetry as connector_telemetry, time_util
+from snowflake.connector import connect, telemetry as connector_telemetry, time_util
 from snowflake.ml import version as snowml_version
 from snowflake.ml._internal import env
 from snowflake.ml._internal.exceptions import (
@@ -37,6 +38,37 @@ _Args = ParamSpec("_Args")
 _ReturnValue = TypeVar("_ReturnValue")
 
 
+def _get_login_token() -> Union[str, bytes]:
+    with open("/snowflake/session/token") as f:
+        return f.read()
+
+
+def _get_snowflake_connection() -> Optional[connector.SnowflakeConnection]:
+    conn = None
+    if os.getenv("SNOWFLAKE_HOST") is not None and os.getenv("SNOWFLAKE_ACCOUNT") is not None:
+        try:
+            conn = connect(
+                host=os.getenv("SNOWFLAKE_HOST"),
+                account=os.getenv("SNOWFLAKE_ACCOUNT"),
+                token=_get_login_token(),
+                authenticator="oauth",
+            )
+        except Exception:
+            # Failed to get a new SnowflakeConnection in SPCS. Fall back to using the active session.
+            # This will work in some cases once SPCS enables multiple authentication modes, and users select any auth.
+            pass
+
+    if conn is None:
+        try:
+            active_session = next(iter(session._get_active_sessions()))
+            conn = active_session._conn._conn if active_session.telemetry_enabled else None
+        except snowpark_exceptions.SnowparkSessionException:
+            # Failed to get an active session. No connection available.
+            pass
+
+    return conn
+
+
 @enum.unique
 class TelemetryProject(enum.Enum):
     MLOPS = "MLOps"
@@ -378,10 +410,14 @@ def send_custom_usage(
     data: Optional[dict[str, Any]] = None,
     **kwargs: Any,
 ) -> None:
-    active_session = next(iter(session._get_active_sessions()))
-    assert active_session, "Missing active session object"
+    conn = _get_snowflake_connection()
+    if conn is None:
+        raise ValueError(
+            """Snowflake connection is required to send custom telemetry. This means there
+            must be at least one active session, or that telemetry is being sent from within an SPCS service."""
+        )
 
-    client = _SourceTelemetryClient(conn=active_session._conn._conn, project=project, subproject=subproject)
+    client = _SourceTelemetryClient(conn=conn, project=project, subproject=subproject)
     common_metrics = client._create_basic_telemetry_data(telemetry_type=telemetry_type)
     data = {**common_metrics, TelemetryField.KEY_DATA.value: data, **kwargs}
     client._send(msg=data)
@@ -501,7 +537,6 @@ def send_api_usage_telemetry(
                 return update_stmt_params_if_snowpark_df(result, statement_params)
 
             # prioritize `conn_attr_name` over the active session
-            telemetry_enabled = True
            if conn_attr_name:
                 # raise AttributeError if conn attribute does not exist in `self`
                 conn = operator.attrgetter(conn_attr_name)(args[0])
@@ -509,16 +544,10 @@ def send_api_usage_telemetry(
                     raise TypeError(
                         f"Expected a conn object of type {' or '.join(_CONNECTION_TYPES.keys())} but got {type(conn)}"
                     )
-            # get an active session
            else:
-                try:
-                    active_session = next(iter(session._get_active_sessions()))
-                    conn = active_session._conn._conn
-                    telemetry_enabled = active_session.telemetry_enabled
-                except snowpark_exceptions.SnowparkSessionException:
-                    conn = None
+                conn = _get_snowflake_connection()
 
-            if conn is None or not telemetry_enabled:
+            if conn is None:
                 # Telemetry not enabled, just execute without our additional telemetry logic
                 try:
                     return ctx.run(execute_func_with_statement_params)
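
For reference, a minimal standalone sketch of the token-based SPCS connection that the new `_get_snowflake_connection` helper tries first. The environment variable names and token path come from the diff above; the wrapper function itself is illustrative, not part of the package.

```python
import os

from snowflake.connector import connect


def spcs_connection():
    # Inside a Snowpark Container Services (SPCS) container, the platform injects
    # SNOWFLAKE_HOST / SNOWFLAKE_ACCOUNT and mounts a login token at this path.
    with open("/snowflake/session/token") as f:
        token = f.read()
    return connect(
        host=os.environ["SNOWFLAKE_HOST"],
        account=os.environ["SNOWFLAKE_ACCOUNT"],
        token=token,
        authenticator="oauth",
    )
```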

snowflake/ml/data/data_connector.py
@@ -249,7 +249,7 @@ class DataConnector:
 
         # Switch to use Runtime's Data Ingester if running in ML runtime
         # Fail silently if the data ingester is not found
-        if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR):
+        if env.IN_ML_RUNTIME and os.getenv(env.USE_OPTIMIZED_DATA_INGESTOR, "").lower() in ("true", "1"):
            try:
                from runtime_external_entities import get_ingester_class
 
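
The old check treated any non-empty value of the optimized-ingestor variable as enabled, including the string "false"; the new comparison accepts only explicit true values. A quick illustration (the literal name below stands in for the value of `env.USE_OPTIMIZED_DATA_INGESTOR`):

```python
import os

os.environ["USE_OPTIMIZED_DATA_INGESTOR"] = "false"

# Old behavior: any non-empty string is truthy, so "false" still enabled the ingestor.
print(bool(os.getenv("USE_OPTIMIZED_DATA_INGESTOR")))  # True

# New behavior: only "true" or "1" (case-insensitive) enable it.
print(os.getenv("USE_OPTIMIZED_DATA_INGESTOR", "").lower() in ("true", "1"))  # False
```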

snowflake/ml/jobs/_utils/constants.py
@@ -5,6 +5,7 @@ from snowflake.ml.jobs._utils.types import ComputeResources
 DEFAULT_CONTAINER_NAME = "main"
 PAYLOAD_DIR_ENV_VAR = "MLRS_PAYLOAD_DIR"
 RESULT_PATH_ENV_VAR = "MLRS_RESULT_PATH"
+MIN_INSTANCES_ENV_VAR = "MLRS_MIN_INSTANCES"
 MEMORY_VOLUME_NAME = "dshm"
 STAGE_VOLUME_NAME = "stage-volume"
 STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
@@ -37,6 +38,7 @@ RAY_PORTS = {
 # Node health check configuration
 # TODO(SNOW-1937020): Revisit the health check configuration
 ML_RUNTIME_HEALTH_CHECK_PORT = "5001"
+ENABLE_HEALTH_CHECKS_ENV_VAR = "ENABLE_HEALTH_CHECKS"
 ENABLE_HEALTH_CHECKS = "false"
 
 # Job status polling constants
@@ -47,6 +49,13 @@ JOB_POLL_MAX_DELAY_SECONDS = 1
 IS_MLJOB_REMOTE_ATTR = "_is_mljob_remote_callable"
 RESULT_PATH_DEFAULT_VALUE = "mljob_result.pkl"
 
+# Log start and end messages
+LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
+LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"
+
+# Default setting for verbose logging in get_log function
+DEFAULT_VERBOSE_LOG = False
+
 # Compute pool resource information
 # TODO: Query Snowflake for resource information instead of relying on this hardcoded
 # table from https://docs.snowflake.com/en/sql-reference/sql/create-compute-pool
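
The new `LOG_START_MSG`/`LOG_END_MSG` markers bracket the user script's output in the job log (see the launcher and startup-script changes below). A hedged sketch of how a raw log could be trimmed to the user portion using these markers; the helper is illustrative and not the library's log-retrieval code:

```python
LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"


def user_portion(full_log: str) -> str:
    """Return only the text between the start and end markers, if both are present."""
    start = full_log.find(LOG_START_MSG)
    end = full_log.find(LOG_END_MSG)
    if start == -1 or end == -1:
        return full_log  # markers missing: fall back to the full log
    return full_log[start + len(LOG_START_MSG):end].strip()
```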

snowflake/ml/jobs/_utils/interop_utils.py
@@ -80,7 +80,7 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult
         # TODO: Check if file exists
         with session.file.get_stream(result_path) as result_stream:
             return ExecutionResult.from_dict(pickle.load(result_stream))
-    except (sp_exceptions.SnowparkSQLException, TypeError, pickle.UnpicklingError):
+    except (sp_exceptions.SnowparkSQLException, pickle.UnpicklingError, TypeError, ImportError):
         # Fall back to JSON result if loading pickled result fails for any reason
         result_json_path = os.path.splitext(result_path)[0] + ".json"
         with session.file.get_stream(result_json_path) as result_stream:
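
`ImportError` joins the fallback tuple because unpickling re-imports the defining module of every pickled object; if that module is missing in the reading environment (for example after a SnowML version mismatch), `pickle.load` raises `ImportError` rather than `UnpicklingError`. A self-contained illustration with a made-up module name:

```python
import pickle
import sys
import types

# Build a throwaway module, pickle an instance of one of its classes, then remove
# the module so that unpickling has to import it again and fails.
mod = types.ModuleType("some_missing_module")


class Foo:
    pass


Foo.__module__ = "some_missing_module"
mod.Foo = Foo
sys.modules["some_missing_module"] = mod

payload = pickle.dumps(Foo())
del sys.modules["some_missing_module"]

try:
    pickle.loads(payload)
except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
    print(f"pickle fallback triggered: {exc}")
```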

snowflake/ml/jobs/_utils/payload_utils.py
@@ -100,6 +100,11 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     # Parse the output using read
     read head_index head_ip head_status<<< "$head_info"
 
+    if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
+        NODE_TYPE="worker"
+        echo "{constants.LOG_START_MSG}"
+    fi
+
     # Use the parsed variables
     echo "Head Instance Index: $head_index"
     echo "Head Instance IP: $head_ip"
@@ -117,9 +122,7 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
         exit 1
     fi
 
-    if [ "$SNOWFLAKE_JOB_INDEX" -ne "$head_index" ]; then
-        NODE_TYPE="worker"
-    fi
+
 fi
 
 # Common parameters for both head and worker nodes
@@ -168,6 +171,10 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
     # Start Ray on a worker node - run in background
     ray start "${{common_params[@]}}" "${{worker_params[@]}}" -v --block &
 
+    echo "Worker node started on address $eth0Ip. See more logs in the head node."
+
+    echo "{constants.LOG_END_MSG}"
+
     # Start the worker shutdown listener in the background
     echo "Starting worker shutdown listener..."
     python worker_shutdown_listener.py
@@ -189,15 +196,16 @@ _STARTUP_SCRIPT_CODE = textwrap.dedent(
 
     # Start Ray on the head node
     ray start "${{common_params[@]}}" "${{head_params[@]}}" -v
+
     ##### End Ray configuration #####
 
     # TODO: Monitor MLRS and handle process crashes
     python -m web.ml_runtime_grpc_server &
 
     # TODO: Launch worker service(s) using SQL if Ray and MLRS successfully started
+    echo Running command: python "$@"
 
     # Run user's Python entrypoint
-    echo Running command: python "$@"
     python "$@"
 
     # After the user's job completes, signal workers to shut down

snowflake/ml/jobs/_utils/scripts/constants.py
@@ -2,3 +2,9 @@
 SHUTDOWN_ACTOR_NAME = "ShutdownSignal"
 SHUTDOWN_ACTOR_NAMESPACE = "default"
 SHUTDOWN_RPC_TIMEOUT_SECONDS = 5.0
+
+
+# Log start and end messages
+# Inherited from snowflake.ml.jobs._utils.constants
+LOG_START_MSG = "--------------------------------\nML job started\n--------------------------------"
+LOG_END_MSG = "--------------------------------\nML job finished\n--------------------------------"

snowflake/ml/jobs/_utils/scripts/mljob_launcher.py
@@ -2,25 +2,35 @@ import argparse
 import copy
 import importlib.util
 import json
+import logging
 import os
 import runpy
 import sys
+import time
 import traceback
 import warnings
 from pathlib import Path
 from typing import Any, Optional
 
 import cloudpickle
+from constants import LOG_END_MSG, LOG_START_MSG
 
 from snowflake.ml.jobs._utils import constants
 from snowflake.ml.utils.connection_params import SnowflakeLoginOptions
 from snowflake.snowpark import Session
 
+# Configure logging
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+
 # Fallbacks in case of SnowML version mismatch
 RESULT_PATH_ENV_VAR = getattr(constants, "RESULT_PATH_ENV_VAR", "MLRS_RESULT_PATH")
-
 JOB_RESULT_PATH = os.environ.get(RESULT_PATH_ENV_VAR, "mljob_result.pkl")
 
+# Constants for the wait_for_min_instances function
+CHECK_INTERVAL = 10  # seconds
+TIMEOUT = 720  # seconds
+
 
 try:
     from snowflake.ml.jobs._utils.interop_utils import ExecutionResult
@@ -62,6 +72,66 @@ class SimpleJSONEncoder(json.JSONEncoder):
         return f"Unserializable object: {repr(obj)}"
 
 
+def get_active_node_count() -> int:
+    """
+    Count the number of active nodes in the Ray cluster.
+
+    Returns:
+        int: Total count of active nodes
+    """
+    import ray
+
+    if not ray.is_initialized():
+        ray.init(address="auto", ignore_reinit_error=True, log_to_driver=False)
+    try:
+        nodes = [node for node in ray.nodes() if node.get("Alive")]
+        total_active = len(nodes)
+
+        logger.info(f"Active nodes: {total_active}")
+        return total_active
+    except Exception as e:
+        logger.warning(f"Error getting active node count: {e}")
+        return 0
+
+
+def wait_for_min_instances(min_instances: int) -> None:
+    """
+    Wait until the specified minimum number of instances are available in the Ray cluster.
+
+    Args:
+        min_instances: Minimum number of instances required
+
+    Raises:
+        TimeoutError: If failed to connect to Ray or if minimum instances are not available within timeout
+    """
+    if min_instances <= 1:
+        logger.debug("Minimum instances is 1 or less, no need to wait for additional instances")
+        return
+
+    start_time = time.time()
+    timeout = os.getenv("JOB_MIN_INSTANCES_TIMEOUT", TIMEOUT)
+    check_interval = os.getenv("JOB_MIN_INSTANCES_CHECK_INTERVAL", CHECK_INTERVAL)
+    logger.debug(f"Waiting for at least {min_instances} instances to be ready (timeout: {timeout}s)")
+
+    while time.time() - start_time < timeout:
+        total_nodes = get_active_node_count()
+
+        if total_nodes >= min_instances:
+            elapsed = time.time() - start_time
+            logger.info(f"Minimum instance requirement met: {total_nodes} instances available after {elapsed:.1f}s")
+            return
+
+        logger.debug(
+            f"Waiting for instances: {total_nodes}/{min_instances} available "
+            f"(elapsed: {time.time() - start_time:.1f}s)"
+        )
+        time.sleep(check_interval)
+
+    raise TimeoutError(
+        f"Timed out after {timeout}s waiting for {min_instances} instances, only {get_active_node_count()} available"
    )
+
+
 def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = None) -> Any:
     """
     Execute a Python script and return its result.
@@ -86,6 +156,7 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
     session = Session.builder.configs(SnowflakeLoginOptions()).create()  # noqa: F841
 
     try:
+
         if main_func:
             # Use importlib for scripts with a main function defined
             module_name = Path(script_path).stem
@@ -126,9 +197,21 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
     Raises:
         Exception: Re-raises any exception caught during script execution.
     """
-    # Run the script with the specified arguments
     try:
+        # Wait for minimum required instances if specified
+        min_instances_str = os.environ.get("JOB_MIN_INSTANCES", 1)
+        if min_instances_str and int(min_instances_str) > 1:
+            wait_for_min_instances(int(min_instances_str))
+
+        # Log start marker for user script execution
+        print(LOG_START_MSG)  # noqa: T201
+
+        # Run the script with the specified arguments
         result = run_script(script_path, *script_args, main_func=script_main_func)
+
+        # Log end marker for user script execution
+        print(LOG_END_MSG)  # noqa: T201
+
         result_obj = ExecutionResult(result=result)
         return result_obj
     except Exception as e:
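
The gate in `main` only waits when more than one instance is requested; a small illustration of the values involved (the `JOB_MIN_INSTANCES` variable name comes from the launcher above, the surrounding prints do not):

```python
import os

# With JOB_MIN_INSTANCES unset, the default of 1 never satisfies "> 1", so no wait happens.
min_instances_str = os.environ.get("JOB_MIN_INSTANCES", 1)
print(bool(min_instances_str and int(min_instances_str) > 1))  # False

# With JOB_MIN_INSTANCES set to "3", the launcher calls wait_for_min_instances(3)
# before running the user script.
os.environ["JOB_MIN_INSTANCES"] = "3"
min_instances_str = os.environ.get("JOB_MIN_INSTANCES", 1)
print(bool(min_instances_str and int(min_instances_str) > 1))  # True
```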

snowflake/ml/jobs/_utils/spec_utils.py
@@ -85,7 +85,8 @@ def generate_service_spec(
     compute_pool: str,
     payload: types.UploadedPayload,
     args: Optional[list[str]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
 ) -> dict[str, Any]:
     """
@@ -96,13 +97,13 @@ def generate_service_spec(
         compute_pool: Compute pool for job execution
         payload: Uploaded job payload
         args: Arguments to pass to entrypoint script
-        num_instances: Number of instances for multi-node job
+        target_instances: Number of instances for multi-node job
         enable_metrics: Enable platform metrics for the job
+        min_instances: Minimum number of instances required to start the job
 
     Returns:
         Job service specification
     """
-    is_multi_node = num_instances is not None and num_instances > 1
     image_spec = _get_image_spec(session, compute_pool)
 
     # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
@@ -180,10 +181,11 @@ def generate_service_spec(
     }
     endpoints = []
 
-    if is_multi_node:
+    if target_instances > 1:
         # Update environment variables for multi-node job
         env_vars.update(constants.RAY_PORTS)
-        env_vars["ENABLE_HEALTH_CHECKS"] = constants.ENABLE_HEALTH_CHECKS
+        env_vars[constants.ENABLE_HEALTH_CHECKS_ENV_VAR] = constants.ENABLE_HEALTH_CHECKS
+        env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)
 
         # Define Ray endpoints for intra-service instance communication
         ray_endpoints = [
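
For a multi-node job the container environment now carries the minimum-instance requirement alongside the health-check flag; a hedged sketch of the resulting variables for `target_instances=3, min_instances=2`, with the Ray port entries elided (values follow the constants shown earlier, the dict itself is illustrative):

```python
env_vars = {
    # ... constants.RAY_PORTS entries elided ...
    "ENABLE_HEALTH_CHECKS": "false",  # constants.ENABLE_HEALTH_CHECKS_ENV_VAR -> constants.ENABLE_HEALTH_CHECKS
    "MLRS_MIN_INSTANCES": "2",        # constants.MIN_INSTANCES_ENV_VAR -> str(min_instances)
}
```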

snowflake/ml/jobs/decorators.py
@@ -24,7 +24,8 @@ def remote(
     external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
     env_vars: Optional[dict[str, str]] = None,
-    num_instances: Optional[int] = None,
+    target_instances: int = 1,
+    min_instances: int = 1,
     enable_metrics: bool = False,
     database: Optional[str] = None,
     schema: Optional[str] = None,
@@ -40,7 +41,9 @@ def remote(
         external_access_integrations: A list of external access integrations.
         query_warehouse: The query warehouse to use. Defaults to session warehouse.
         env_vars: Environment variables to set in container
-        num_instances: The number of nodes in the job. If none specified, create a single node job.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
+        min_instances: The minimum number of nodes required to start the job. If none specified, defaults to 1.
+            If set, the job will not start until the minimum number of nodes is available.
         enable_metrics: Whether to enable metrics publishing for the job.
         database: The database to use for the job.
         schema: The schema to use for the job.
@@ -69,7 +72,8 @@ def remote(
                external_access_integrations=external_access_integrations,
                query_warehouse=query_warehouse,
                env_vars=env_vars,
-                num_instances=num_instances,
+                target_instances=target_instances,
+                min_instances=min_instances,
                enable_metrics=enable_metrics,
                database=database,
                schema=schema,
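
A hedged usage sketch of the renamed parameters. The compute pool, stage name, and table are placeholders, and the decorator's leading arguments are assumed to match earlier releases (a positional `compute_pool` plus a `stage_name` keyword); only `target_instances` and `min_instances` are taken from the diff above.

```python
from snowflake.ml.jobs import remote


@remote("MY_COMPUTE_POOL", stage_name="payload_stage", target_instances=3, min_instances=2)
def train(table_name: str) -> float:
    # Runs on up to 3 nodes; the job waits until at least 2 nodes are available.
    return 0.0


job = train("MY_DB.MY_SCHEMA.MY_TABLE")  # submits the job and returns a job handle, as in prior releases
```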