snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. snowflake/ml/_internal/utils/identifier.py +1 -1
  2. snowflake/ml/_internal/utils/mixins.py +61 -0
  3. snowflake/ml/jobs/_utils/constants.py +1 -1
  4. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  5. snowflake/ml/jobs/_utils/payload_utils.py +6 -5
  6. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  7. snowflake/ml/jobs/_utils/spec_utils.py +6 -4
  8. snowflake/ml/jobs/decorators.py +18 -25
  9. snowflake/ml/jobs/job.py +179 -58
  10. snowflake/ml/jobs/manager.py +194 -145
  11. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  12. snowflake/ml/model/_client/ops/service_ops.py +4 -2
  13. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  14. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  15. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  16. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  17. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  18. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  19. snowflake/ml/model/target_platform.py +11 -0
  20. snowflake/ml/model/task.py +9 -0
  21. snowflake/ml/model/type_hints.py +5 -13
  22. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  23. snowflake/ml/registry/_manager/model_manager.py +30 -15
  24. snowflake/ml/registry/registry.py +119 -42
  25. snowflake/ml/version.py +1 -1
  26. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +52 -16
  27. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +30 -26
  28. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  29. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  30. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
@@ -240,7 +240,7 @@ def get_schema_level_object_identifier(
     """
 
     for identifier in (db, schema, object_name):
-        if identifier is not None and SF_IDENTIFIER_RE.match(identifier) is None:
+        if identifier is not None and SF_IDENTIFIER_RE.fullmatch(identifier) is None:
             raise ValueError(f"Invalid identifier {identifier}")
 
     if others is None:
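
A quick illustration of the match → fullmatch change above (the pattern here is a simplified stand-in, not the actual SF_IDENTIFIER_RE): re.Pattern.match anchors only at the start of the string, so an identifier with trailing characters could pass validation, whereas fullmatch requires the entire string to match.

import re

# Simplified stand-in for SF_IDENTIFIER_RE, for illustration only.
IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")

candidate = "my_table; drop table t"
print(bool(IDENTIFIER_RE.match(candidate)))      # True  - "my_table" matches at the start
print(bool(IDENTIFIER_RE.fullmatch(candidate)))  # False - the trailing text is rejected
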
@@ -0,0 +1,61 @@
+from typing import Any, Optional
+
+from snowflake.ml._internal.utils import identifier
+from snowflake.snowpark import session
+
+
+class SerializableSessionMixin:
+    """Mixin that provides pickling capabilities for objects with Snowpark sessions."""
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude non-serializable session and related components."""
+        state = self.__dict__.copy()
+
+        # Save session metadata for validation during unpickling
+        if hasattr(self, "_session") and self._session is not None:
+            try:
+                state["__session-account__"] = self._session.get_current_account()
+                state["__session-role__"] = self._session.get_current_role()
+                state["__session-database__"] = self._session.get_current_database()
+                state["__session-schema__"] = self._session.get_current_schema()
+            except Exception:
+                pass
+
+        state["_session"] = None
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        saved_account = state.pop("__session-account__", None)
+        saved_role = state.pop("__session-role__", None)
+        saved_database = state.pop("__session-database__", None)
+        saved_schema = state.pop("__session-schema__", None)
+        self.__dict__.update(state)
+
+        if saved_account is not None:
+
+            def identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
+                saved_resolved = identifier.resolve_identifier(saved) if saved is not None else saved
+                current_resolved = identifier.resolve_identifier(current) if current is not None else current
+                return saved_resolved == current_resolved
+
+            for active_session in session._get_active_sessions():
+                try:
+                    current_account = active_session.get_current_account()
+                    current_role = active_session.get_current_role()
+                    current_database = active_session.get_current_database()
+                    current_schema = active_session.get_current_schema()
+
+                    if (
+                        identifiers_match(saved_account, current_account)
+                        and identifiers_match(saved_role, current_role)
+                        and identifiers_match(saved_database, current_database)
+                        and identifiers_match(saved_schema, current_schema)
+                    ):
+                        self._session = active_session
+                        return
+                except Exception:
+                    continue
+
+        # No matching session found or no metadata available
+        raise RuntimeError("No active Snowpark session available. Please create a session.")
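
A hedged usage sketch for the new mixin, assuming a hypothetical class that keeps its Snowpark session in a `_session` attribute (the name the mixin inspects): pickling records the session's account, role, database, and schema and drops the live connection; unpickling scans active sessions for one whose resolved identifiers match, and raises a RuntimeError if none is found.

import pickle

from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
from snowflake.snowpark import Session


class JobHandle(SerializableSessionMixin):
    """Hypothetical example class; any object holding its session in `_session` behaves the same way."""

    def __init__(self, job_id: str, session: Session) -> None:
        self.job_id = job_id
        self._session = session


# session = Session.builder.configs({...}).create()
# handle = JobHandle("MY_JOB", session)
# blob = pickle.dumps(handle)       # session metadata is saved, the live session is set to None
# restored = pickle.loads(blob)     # reattached to a matching active session, or RuntimeError
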
@@ -15,7 +15,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.4.2"
+DEFAULT_IMAGE_TAG = "1.5.0"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
@@ -75,16 +75,75 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult
 
     Returns:
         A dictionary containing the execution result if available, None otherwise.
+
+    Raises:
+        RuntimeError: If both pickle and JSON result retrieval fail.
     """
     try:
         # TODO: Check if file exists
         with session.file.get_stream(result_path) as result_stream:
             return ExecutionResult.from_dict(pickle.load(result_stream))
-    except (sp_exceptions.SnowparkSQLException, pickle.UnpicklingError, TypeError, ImportError):
+    except (
+        sp_exceptions.SnowparkSQLException,
+        pickle.UnpicklingError,
+        TypeError,
+        ImportError,
+        AttributeError,
+        MemoryError,
+    ) as pickle_error:
         # Fall back to JSON result if loading pickled result fails for any reason
-        result_json_path = os.path.splitext(result_path)[0] + ".json"
-        with session.file.get_stream(result_json_path) as result_stream:
-            return ExecutionResult.from_dict(json.load(result_stream))
+        try:
+            result_json_path = os.path.splitext(result_path)[0] + ".json"
+            with session.file.get_stream(result_json_path) as result_stream:
+                return ExecutionResult.from_dict(json.load(result_stream))
+        except Exception as json_error:
+            # Both pickle and JSON failed - provide a helpful error message
+            raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error
+
+
+def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str:
+    """Create helpful error messages for common result retrieval failures."""
+
+    # Package import issues
+    if isinstance(error, ImportError):
+        return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}"
+
+    # Package versions differ between runtime and local environment
+    if isinstance(error, AttributeError):
+        return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}"
+
+    # Serialization issues
+    if isinstance(error, TypeError):
+        return f"Failed to retrieve job result: Non-serializable objects were returned. Error: {str(error)}"
+
+    # Python version pickling incompatibility
+    if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower():
+        # TODO: Update this once we support different Python versions
+        client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}"
+        runtime_version = "Python 3.10"
+        return (
+            f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, "
+            f"local environment using {client_version}. Error: {str(error)}"
+        )
+
+    # File access issues
+    if isinstance(error, sp_exceptions.SnowparkSQLException):
+        if "not found" in str(error).lower() or "does not exist" in str(error).lower():
+            return (
+                f"Failed to retrieve job result: No result file found. Check job.get_logs() for execution "
+                f"errors. Error: {str(error)}"
+            )
+        else:
+            return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}"
+
+    if isinstance(error, MemoryError):
+        return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}"
+
+    # Generic fallback
+    base_message = f"Failed to retrieve job result: {str(error)}"
+    if json_error:
+        base_message += f" (JSON fallback also failed: {str(json_error)})"
+    return base_message
 
 
 def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
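
The practical effect for callers (a sketch under the assumption that MLJob.result() funnels through fetch_result, as the jobs API does for result retrieval): a failed fetch now raises a RuntimeError carrying one of the targeted messages above instead of a bare unpickling or SQL error.

# `job` is assumed to be an MLJob handle returned by the snowflake.ml.jobs API.
try:
    value = job.result()
except RuntimeError as err:
    # Examples of messages produced by _fetch_result_error_message:
    #   "Failed to retrieve job result: Package not installed in your local environment. ..."
    #   "Failed to retrieve job result: No result file found. Check job.get_logs() ..."
    print(f"Could not fetch result: {err}")
    print(job.get_logs())  # the job's own logs usually contain the underlying failure
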
@@ -12,13 +12,13 @@ import cloudpickle as cp
 from packaging import version
 
 from snowflake import snowpark
+from snowflake.connector import errors
 from snowflake.ml.jobs._utils import (
     constants,
     function_payload_utils,
     stage_utils,
     types,
 )
-from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation
 
 cp.register_pickle_by_value(function_payload_utils)
@@ -312,14 +312,15 @@ class JobPayload:
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-            session.sql("describe stage identifier(?)", params=[stage_name]).collect()
-        except sp_exceptions.SnowparkSQLException:
-            session.sql(
+            session._conn.run_query("describe stage identifier(?)", params=[stage_name], _force_qmark_paramstyle=True)
+        except errors.ProgrammingError:
+            session._conn.run_query(
                 "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
                 " comment = 'Created by snowflake.ml.jobs Python API'",
                 params=[stage_name],
-            ).collect()
+                _force_qmark_paramstyle=True,
+            )
 
         # Upload payload to stage
         if not isinstance(source, (Path, stage_utils.StagePath)):
@@ -0,0 +1,9 @@
+from snowflake import snowpark
+
+
+def get_attribute_map(session: snowpark.Session, requested_attributes: dict[str, int]) -> dict[str, int]:
+    metadata = session._conn._cursor.description
+    for index in range(len(metadata)):
+        if metadata[index].name in requested_attributes.keys():
+            requested_attributes[metadata[index].name] = index
+    return requested_attributes
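
For context, a small sketch of how this helper fits the raw-connection query path used elsewhere in this release (the compute pool name is a placeholder, and `session` is assumed to be an existing snowpark.Session): run_query returns a dict whose "data" entry holds positional rows, and get_attribute_map maps a column name to its index via the cursor description, keeping the passed-in value (here 4) as the fallback position.

from snowflake.ml.jobs._utils import query_helper

# `session` is an existing snowpark.Session; "MY_POOL" is a placeholder name.
rows = session._conn.run_query(
    "show compute pools like ?", params=["MY_POOL"], _force_qmark_paramstyle=True
)
# Map the "instance_family" column name to its index in each positional row.
attrs = query_helper.get_attribute_map(session, {"instance_family": 4})
instance_family = rows["data"][0][attrs["instance_family"]]
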
@@ -6,16 +6,18 @@ from typing import Any, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import constants, query_helper, types
 
 
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows = session.sql("show compute pools like ?", params=[compute_pool]).collect()
-    if not rows:
+    rows = session._conn.run_query("show compute pools like ?", params=[compute_pool], _force_qmark_paramstyle=True)
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
         raise ValueError(f"Compute pool '{compute_pool}' not found")
-    instance_family: str = rows[0]["instance_family"]
+    requested_attributes = query_helper.get_attribute_map(session, {"instance_family": 4})
+    compute_pool_info = rows["data"]
+    instance_family: str = compute_pool_info[0][requested_attributes["instance_family"]]
     cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
 
     return (
@@ -1,6 +1,6 @@
 import copy
 import functools
-from typing import Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
@@ -20,16 +20,11 @@ def remote(
     compute_pool: str,
     *,
     stage_name: str,
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
     Submit a job to the compute pool.
@@ -37,17 +32,20 @@ def remote(
     Args:
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        env_vars: Environment variables to set in container
-        target_instances: The number of nodes in the job. If none specified, create a single node job.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use for the job.
-        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         Decorator that dispatches invocations of the decorated function as remote jobs.
@@ -61,22 +59,17 @@ def remote(
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
         @functools.wraps(func)
-        def wrapper(*args: _Args.args, **kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
-            payload = payload_utils.create_function_payload(func, *args, **kwargs)
+        def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
+            payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
                 compute_pool=compute_pool,
+                target_instances=target_instances,
                 pip_requirements=pip_requirements,
                 external_access_integrations=external_access_integrations,
-                query_warehouse=query_warehouse,
-                env_vars=env_vars,
-                target_instances=target_instances,
-                min_instances=min_instances,
-                enable_metrics=enable_metrics,
-                database=database,
-                schema=schema,
                 session=payload.session or session,
+                **kwargs,
             )
             assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
             return job
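
Putting the decorator changes together, a hedged end-to-end sketch (pool, stage, and option values are placeholders; the import path assumes the public snowflake.ml.jobs namespace): options such as env_vars, enable_metrics, and query_warehouse are no longer named parameters of remote() and are instead forwarded through **kwargs to _submit_job, so existing keyword call sites keep working.

from snowflake.ml.jobs import remote


@remote(
    "MY_COMPUTE_POOL",            # placeholder compute pool
    stage_name="payload_stage",   # placeholder stage
    target_instances=2,
    pip_requirements=["xgboost"],
    # Forwarded via **kwargs in 1.9.0 rather than being named parameters:
    env_vars={"TRAINING_MODE": "full"},
    enable_metrics=True,
)
def train(num_rounds: int) -> float:
    return float(num_rounds)


# job = train(100)        # dispatches the call as a remote MLJob
# result = job.result()   # waits for completion and fetches the return value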