snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +6 -5
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/spec_utils.py +6 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +179 -58
- snowflake/ml/jobs/manager.py +194 -145
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +4 -2
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +119 -42
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +52 -16
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +30 -26
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/utils/identifier.py
CHANGED
@@ -240,7 +240,7 @@ def get_schema_level_object_identifier(
     """
 
     for identifier in (db, schema, object_name):
-        if identifier is not None and SF_IDENTIFIER_RE.
+        if identifier is not None and SF_IDENTIFIER_RE.fullmatch(identifier) is None:
            raise ValueError(f"Invalid identifier {identifier}")
 
    if others is None:
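The removed line is truncated in this rendering, but the replacement shows the intent: identifier validation now requires the pattern to consume the entire string rather than just a prefix. A toy illustration of the match/fullmatch distinction (the pattern below is an assumption; SF_IDENTIFIER_RE's actual definition is not shown in this hunk):

import re

# Toy stand-in for SF_IDENTIFIER_RE; the real pattern lives elsewhere in identifier.py.
pat = re.compile(r"[A-Za-z_][A-Za-z0-9_$]*")

print(bool(pat.match("my_table; drop table t")))      # True: a valid prefix is enough
print(bool(pat.fullmatch("my_table; drop table t")))  # False: the whole string must match
print(bool(pat.fullmatch("my_table")))                # True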
snowflake/ml/_internal/utils/mixins.py
ADDED
@@ -0,0 +1,61 @@
+from typing import Any, Optional
+
+from snowflake.ml._internal.utils import identifier
+from snowflake.snowpark import session
+
+
+class SerializableSessionMixin:
+    """Mixin that provides pickling capabilities for objects with Snowpark sessions."""
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Customize pickling to exclude non-serializable session and related components."""
+        state = self.__dict__.copy()
+
+        # Save session metadata for validation during unpickling
+        if hasattr(self, "_session") and self._session is not None:
+            try:
+                state["__session-account__"] = self._session.get_current_account()
+                state["__session-role__"] = self._session.get_current_role()
+                state["__session-database__"] = self._session.get_current_database()
+                state["__session-schema__"] = self._session.get_current_schema()
+            except Exception:
+                pass
+
+        state["_session"] = None
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        """Restore session from context during unpickling."""
+        saved_account = state.pop("__session-account__", None)
+        saved_role = state.pop("__session-role__", None)
+        saved_database = state.pop("__session-database__", None)
+        saved_schema = state.pop("__session-schema__", None)
+        self.__dict__.update(state)
+
+        if saved_account is not None:
+
+            def identifiers_match(saved: Optional[str], current: Optional[str]) -> bool:
+                saved_resolved = identifier.resolve_identifier(saved) if saved is not None else saved
+                current_resolved = identifier.resolve_identifier(current) if current is not None else current
+                return saved_resolved == current_resolved
+
+            for active_session in session._get_active_sessions():
+                try:
+                    current_account = active_session.get_current_account()
+                    current_role = active_session.get_current_role()
+                    current_database = active_session.get_current_database()
+                    current_schema = active_session.get_current_schema()
+
+                    if (
+                        identifiers_match(saved_account, current_account)
+                        and identifiers_match(saved_role, current_role)
+                        and identifiers_match(saved_database, current_database)
+                        and identifiers_match(saved_schema, current_schema)
+                    ):
+                        self._session = active_session
+                        return
+                except Exception:
+                    continue
+
+        # No matching session found or no metadata available
+        raise RuntimeError("No active Snowpark session available. Please create a session.")
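A minimal sketch of the round-trip behavior the mixin enables (the StageFileRef class is hypothetical; restoring assumes an active session whose account, role, database, and schema match the pickled metadata):

import pickle

from snowflake.ml._internal.utils.mixins import SerializableSessionMixin


class StageFileRef(SerializableSessionMixin):  # hypothetical consumer of the mixin
    def __init__(self, session, path: str) -> None:
        self._session = session
        self.path = path


# With an active Snowpark session `sess`:
# ref = StageFileRef(sess, "@my_stage/results/output.pkl")
# blob = pickle.dumps(ref)        # __getstate__ drops _session, keeps its metadata
# restored = pickle.loads(blob)   # __setstate__ rebinds a matching active session,
#                                 # or raises RuntimeError if none matches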
snowflake/ml/jobs/_utils/constants.py
CHANGED
@@ -15,7 +15,7 @@ STAGE_VOLUME_MOUNT_PATH = "/mnt/app"
 DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
 DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
 DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
-DEFAULT_IMAGE_TAG = "1.
+DEFAULT_IMAGE_TAG = "1.5.0"
 DEFAULT_ENTRYPOINT_PATH = "func.py"
 
 # Percent of container memory to allocate for /dev/shm volume
snowflake/ml/jobs/_utils/interop_utils.py
CHANGED
@@ -75,16 +75,75 @@ def fetch_result(session: snowpark.Session, result_path: str) -> ExecutionResult
 
     Returns:
         A dictionary containing the execution result if available, None otherwise.
+
+    Raises:
+        RuntimeError: If both pickle and JSON result retrieval fail.
     """
     try:
         # TODO: Check if file exists
         with session.file.get_stream(result_path) as result_stream:
             return ExecutionResult.from_dict(pickle.load(result_stream))
-    except (
+    except (
+        sp_exceptions.SnowparkSQLException,
+        pickle.UnpicklingError,
+        TypeError,
+        ImportError,
+        AttributeError,
+        MemoryError,
+    ) as pickle_error:
         # Fall back to JSON result if loading pickled result fails for any reason
-
-
-
+        try:
+            result_json_path = os.path.splitext(result_path)[0] + ".json"
+            with session.file.get_stream(result_json_path) as result_stream:
+                return ExecutionResult.from_dict(json.load(result_stream))
+        except Exception as json_error:
+            # Both pickle and JSON failed - provide helpful error message
+            raise RuntimeError(_fetch_result_error_message(pickle_error, result_path, json_error)) from pickle_error
+
+
+def _fetch_result_error_message(error: Exception, result_path: str, json_error: Optional[Exception] = None) -> str:
+    """Create helpful error messages for common result retrieval failures."""
+
+    # Package import issues
+    if isinstance(error, ImportError):
+        return f"Failed to retrieve job result: Package not installed in your local environment. Error: {str(error)}"
+
+    # Package versions differ between runtime and local environment
+    if isinstance(error, AttributeError):
+        return f"Failed to retrieve job result: Package version mismatch. Error: {str(error)}"
+
+    # Serialization issues
+    if isinstance(error, TypeError):
+        return f"Failed to retrieve job result: Non-serializable objects were returned. Error: {str(error)}"
+
+    # Python version pickling incompatibility
+    if isinstance(error, pickle.UnpicklingError) and "protocol" in str(error).lower():
+        # TODO: Update this once we support different Python versions
+        client_version = f"Python {sys.version_info.major}.{sys.version_info.minor}"
+        runtime_version = "Python 3.10"
+        return (
+            f"Failed to retrieve job result: Python version mismatch - job ran on {runtime_version}, "
+            f"local environment using Python {client_version}. Error: {str(error)}"
+        )
+
+    # File access issues
+    if isinstance(error, sp_exceptions.SnowparkSQLException):
+        if "not found" in str(error).lower() or "does not exist" in str(error).lower():
+            return (
+                f"Failed to retrieve job result: No result file found. Check job.get_logs() for execution "
+                f"errors. Error: {str(error)}"
+            )
+        else:
+            return f"Failed to retrieve job result: Cannot access result file. Error: {str(error)}"
+
+    if isinstance(error, MemoryError):
+        return f"Failed to retrieve job result: Result too large for memory. Error: {str(error)}"
+
+    # Generic fallback
+    base_message = f"Failed to retrieve job result: {str(error)}"
+    if json_error:
+        base_message += f" (JSON fallback also failed: {str(json_error)})"
+    return base_message
 
 
 def load_exception(exc_type_name: str, exc_value: Union[Exception, str], exc_tb: str) -> Exception:
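From a caller's perspective the pickle-then-JSON fallback is invisible until both paths fail, at which point the new RuntimeError carries the diagnostic text built above. A sketch (the result path is an assumption, and an active Snowpark `session` is assumed):

from snowflake.ml.jobs._utils import interop_utils

try:
    result = interop_utils.fetch_result(session, "@my_stage/jobs/42/result.pkl")
except RuntimeError as e:
    # e.g. "Failed to retrieve job result: Package version mismatch. Error: ..."
    # possibly suffixed with "(JSON fallback also failed: ...)"
    print(e)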
snowflake/ml/jobs/_utils/payload_utils.py
CHANGED
@@ -12,13 +12,13 @@ import cloudpickle as cp
 from packaging import version
 
 from snowflake import snowpark
+from snowflake.connector import errors
 from snowflake.ml.jobs._utils import (
     constants,
     function_payload_utils,
     stage_utils,
     types,
 )
-from snowflake.snowpark import exceptions as sp_exceptions
 from snowflake.snowpark._internal import code_generation
 
 cp.register_pickle_by_value(function_payload_utils)
@@ -312,14 +312,15 @@ class JobPayload:
         stage_name = stage_path.parts[0].lstrip("@")
         # Explicitly check if stage exists first since we may not have CREATE STAGE privilege
         try:
-            session.
-        except
-            session.
+            session._conn.run_query("describe stage identifier(?)", params=[stage_name], _force_qmark_paramstyle=True)
+        except errors.ProgrammingError:
+            session._conn.run_query(
                 "create stage if not exists identifier(?)"
                 " encryption = ( type = 'SNOWFLAKE_SSE' )"
                 " comment = 'Created by snowflake.ml.jobs Python API'",
                 params=[stage_name],
-
+                _force_qmark_paramstyle=True,
+            )
 
         # Upload payload to stage
         if not isinstance(source, (Path, stage_utils.StagePath)):
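The same stage check expressed against the public Session.sql API, for readers unfamiliar with the private _conn.run_query helper (illustration only; the release itself uses qmark-style binding through the private connection as shown above):

from snowflake.snowpark import Session, exceptions as sp_exceptions


def ensure_stage(session: Session, stage_name: str) -> None:
    # Probe the stage first, since the current role may lack CREATE STAGE privilege.
    try:
        session.sql("describe stage identifier(?)", params=[stage_name]).collect()
    except sp_exceptions.SnowparkSQLException:
        session.sql(
            "create stage if not exists identifier(?)"
            " encryption = ( type = 'SNOWFLAKE_SSE' )"
            " comment = 'Created by snowflake.ml.jobs Python API'",
            params=[stage_name],
        ).collect()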
snowflake/ml/jobs/_utils/query_helper.py
ADDED
@@ -0,0 +1,9 @@
+from snowflake import snowpark
+
+
+def get_attribute_map(session: snowpark.Session, requested_attributes: dict[str, int]) -> dict[str, int]:
+    metadata = session._conn._cursor.description
+    for index in range(len(metadata)):
+        if metadata[index].name in requested_attributes.keys():
+            requested_attributes[metadata[index].name] = index
+    return requested_attributes
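A short usage sketch (assumes an active Snowpark `session`; the private run_query call and the {"data": ...} result shape mirror the spec_utils change below): the helper mutates and returns the passed dict, replacing each default index with the column's actual position in the most recent cursor description.

from snowflake.ml.jobs._utils import query_helper

rows = session._conn.run_query("show compute pools", _force_qmark_paramstyle=True)
attrs = query_helper.get_attribute_map(session, {"instance_family": 4})
instance_family = rows["data"][0][attrs["instance_family"]]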
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED
@@ -6,16 +6,18 @@ from typing import Any, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import constants, query_helper, types
 
 
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
     """Extract resource information for the specified compute pool"""
     # Get the instance family
-    rows = session.
-    if not rows:
+    rows = session._conn.run_query("show compute pools like ?", params=[compute_pool], _force_qmark_paramstyle=True)
+    if not rows or not isinstance(rows, dict) or not rows.get("data"):
         raise ValueError(f"Compute pool '{compute_pool}' not found")
-
+    requested_attributes = query_helper.get_attribute_map(session, {"instance_family": 4})
+    compute_pool_info = rows["data"]
+    instance_family: str = compute_pool_info[0][requested_attributes["instance_family"]]
     cloud = snowflake_env.get_current_cloud(session, default=snowflake_env.SnowflakeCloudType.AWS)
 
     return (
snowflake/ml/jobs/decorators.py
CHANGED
@@ -1,6 +1,6 @@
 import copy
 import functools
-from typing import Callable, Optional, TypeVar
+from typing import Any, Callable, Optional, TypeVar
 
 from typing_extensions import ParamSpec
 
@@ -20,16 +20,11 @@ def remote(
     compute_pool: str,
     *,
     stage_name: str,
+    target_instances: int = 1,
     pip_requirements: Optional[list[str]] = None,
     external_access_integrations: Optional[list[str]] = None,
-    query_warehouse: Optional[str] = None,
-    env_vars: Optional[dict[str, str]] = None,
-    target_instances: int = 1,
-    min_instances: Optional[int] = None,
-    enable_metrics: bool = False,
-    database: Optional[str] = None,
-    schema: Optional[str] = None,
     session: Optional[snowpark.Session] = None,
+    **kwargs: Any,
 ) -> Callable[[Callable[_Args, _ReturnValue]], Callable[_Args, jb.MLJob[_ReturnValue]]]:
     """
     Submit a job to the compute pool.
@@ -37,17 +32,20 @@ def remote(
     Args:
         compute_pool: The compute pool to use for the job.
         stage_name: The name of the stage where the job payload will be uploaded.
+        target_instances: The number of nodes in the job. If none specified, create a single node job.
         pip_requirements: A list of pip requirements for the job.
         external_access_integrations: A list of external access integrations.
-        query_warehouse: The query warehouse to use. Defaults to session warehouse.
-        env_vars: Environment variables to set in container
-        target_instances: The number of nodes in the job. If none specified, create a single node job.
-        min_instances: The minimum number of nodes required to start the job. If none specified,
-            defaults to target_instances. If set, the job will not start until the minimum number of nodes is available.
-        enable_metrics: Whether to enable metrics publishing for the job.
-        database: The database to use for the job.
-        schema: The schema to use for the job.
         session: The Snowpark session to use. If none specified, uses active session.
+        kwargs: Additional keyword arguments. Supported arguments:
+            database (str): The database to use for the job.
+            schema (str): The schema to use for the job.
+            min_instances (int): The minimum number of nodes required to start the job.
+                If none specified, defaults to target_instances. If set, the job
+                will not start until the minimum number of nodes is available.
+            env_vars (dict): Environment variables to set in container.
+            enable_metrics (bool): Whether to enable metrics publishing for the job.
+            query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
+            spec_overrides (dict): A dictionary of overrides for the service spec.
 
     Returns:
         Decorator that dispatches invocations of the decorated function as remote jobs.
@@ -61,22 +59,17 @@ def remote(
         wrapped_func.__code__ = wrapped_func.__code__.replace(co_firstlineno=func.__code__.co_firstlineno + 1)
 
         @functools.wraps(func)
-        def wrapper(*
-            payload = payload_utils.create_function_payload(func, *
+        def wrapper(*_args: _Args.args, **_kwargs: _Args.kwargs) -> jb.MLJob[_ReturnValue]:
+            payload = payload_utils.create_function_payload(func, *_args, **_kwargs)
             job = jm._submit_job(
                 source=payload,
                 stage_name=stage_name,
                 compute_pool=compute_pool,
+                target_instances=target_instances,
                 pip_requirements=pip_requirements,
                 external_access_integrations=external_access_integrations,
-                query_warehouse=query_warehouse,
-                env_vars=env_vars,
-                target_instances=target_instances,
-                min_instances=min_instances,
-                enable_metrics=enable_metrics,
-                database=database,
-                schema=schema,
                 session=payload.session or session,
+                **kwargs,
             )
            assert isinstance(job, jb.MLJob), f"Unexpected job type: {type(job)}"
            return job
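A sketch of the narrowed decorator surface (pool, stage, and warehouse names are assumptions): the common parameters stay explicit, and everything else is forwarded through **kwargs to _submit_job.

from snowflake.ml.jobs import remote


@remote(
    "MY_COMPUTE_POOL",
    stage_name="payload_stage",
    target_instances=2,
    # Formerly explicit keyword parameters, now forwarded via **kwargs:
    env_vars={"LOG_LEVEL": "DEBUG"},
    enable_metrics=True,
    query_warehouse="MY_WH",
)
def train(n: int) -> int:
    return n * 2


# job = train(21)  # dispatches as an MLJob instead of running locally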