snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/platform_capabilities.py +36 -0
- snowflake/ml/_internal/utils/url.py +42 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
- snowflake/ml/data/data_connector.py +103 -1
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
- snowflake/ml/experiment/callback/__init__.py +0 -0
- snowflake/ml/experiment/callback/keras.py +25 -2
- snowflake/ml/experiment/callback/lightgbm.py +27 -2
- snowflake/ml/experiment/callback/xgboost.py +25 -2
- snowflake/ml/experiment/experiment_tracking.py +93 -3
- snowflake/ml/experiment/utils.py +6 -0
- snowflake/ml/feature_store/feature_view.py +34 -24
- snowflake/ml/jobs/_interop/protocols.py +3 -0
- snowflake/ml/jobs/_utils/constants.py +1 -0
- snowflake/ml/jobs/_utils/payload_utils.py +354 -356
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
- snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
- snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
- snowflake/ml/jobs/_utils/spec_utils.py +1 -445
- snowflake/ml/jobs/_utils/stage_utils.py +22 -1
- snowflake/ml/jobs/_utils/types.py +14 -7
- snowflake/ml/jobs/job.py +2 -8
- snowflake/ml/jobs/manager.py +57 -135
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/__init__.py +6 -0
- snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
- snowflake/ml/model/_client/model/model_version_impl.py +130 -14
- snowflake/ml/model/_client/ops/deployment_step.py +36 -0
- snowflake/ml/model/_client/ops/model_ops.py +93 -8
- snowflake/ml/model/_client/ops/service_ops.py +32 -52
- snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
- snowflake/ml/model/_client/sql/model_version.py +30 -6
- snowflake/ml/model/_client/sql/service.py +94 -5
- snowflake/ml/model/_model_composer/model_composer.py +1 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
- snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
- snowflake/ml/model/_packager/model_handler.py +8 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
- snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
- snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
- snowflake/ml/model/_packager/model_packager.py +1 -1
- snowflake/ml/model/_signatures/core.py +390 -8
- snowflake/ml/model/_signatures/utils.py +13 -4
- snowflake/ml/model/code_path.py +104 -0
- snowflake/ml/model/compute_pool.py +2 -0
- snowflake/ml/model/custom_model.py +55 -13
- snowflake/ml/model/model_signature.py +13 -1
- snowflake/ml/model/models/huggingface.py +285 -0
- snowflake/ml/model/models/huggingface_pipeline.py +19 -208
- snowflake/ml/model/type_hints.py +7 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
- snowflake/ml/registry/_manager/model_manager.py +230 -15
- snowflake/ml/registry/registry.py +4 -4
- snowflake/ml/utils/html_utils.py +67 -1
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
- snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/spec_utils.py
CHANGED

@@ -1,17 +1,6 @@
-import logging
-import os
-import re
-import sys
-from math import ceil
-from pathlib import PurePath
-from typing import Any, Literal, Optional, Union
-
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
-from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
-from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
-
-_OCI_TAG_REGEX = re.compile("^[a-zA-Z0-9._-]{1,128}$")
+from snowflake.ml.jobs._utils import constants, query_helper, types


 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
@@ -31,436 +20,3 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
         constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
         or constants.CLOUD_INSTANCE_FAMILIES[cloud][instance_family]
     )
-
-
-def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]:
-    rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
-    if not rows:
-        return None
-    try:
-        runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
-        spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes()
-    except Exception as e:
-        logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
-        return None
-
-    selected_runtime = next(
-        (
-            runtime
-            for runtime in spcs_container_runtimes
-            if (
-                runtime.hardware_type.lower() == target_hardware.lower()
-                and runtime.python_version.major == sys.version_info.major
-                and runtime.python_version.minor == sys.version_info.minor
-            )
-        ),
-        None,
-    )
-    return selected_runtime.runtime_container_image if selected_runtime else None
-
-
-def _check_image_tag_valid(tag: Optional[str]) -> bool:
-    if tag is None:
-        return False
-
-    return _OCI_TAG_REGEX.fullmatch(tag) is not None
-
-
-def _get_image_spec(
-    session: snowpark.Session, compute_pool: str, runtime_environment: Optional[str] = None
-) -> types.ImageSpec:
-    """
-    Resolve image specification (container image and resources) for the job.
-
-    Behavior:
-    - If `runtime_environment` is empty or the feature flag is disabled, use the
-      default image tag and image name.
-    - If `runtime_environment` is a valid image tag, use that tag with the default
-      repository/name.
-    - If `runtime_environment` is a full image URL, use it directly.
-    - If the feature flag is enabled and `runtime_environment` is not provided,
-      select an ML Runtime image matching the local Python major.minor
-    - When multiple inputs are provided, `runtime_environment` takes priority.
-
-    Args:
-        session: Snowflake session.
-        compute_pool: Compute pool used to infer CPU/GPU resources.
-        runtime_environment: Optional image tag or full image URL to override.
-
-    Returns:
-        Image spec including container image and resource requests/limits.
-    """
-    # Retrieve compute pool node resources
-    resources = _get_node_resources(session, compute_pool=compute_pool)
-    hardware = "GPU" if resources.gpu > 0 else "CPU"
-    image_tag = _get_runtime_image_tag()
-    image_repo = constants.DEFAULT_IMAGE_REPO
-    image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-
-    # Use MLRuntime image
-    container_image = None
-    if runtime_environment:
-        if _check_image_tag_valid(runtime_environment):
-            image_tag = runtime_environment
-        else:
-            container_image = runtime_environment
-    elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
-        container_image = _get_runtime_image(session, hardware)  # type: ignore[arg-type]
-
-    container_image = container_image or f"{image_repo}/{image_name}:{image_tag}"
-    # TODO: Should each instance consume the entire pod?
-    return types.ImageSpec(
-        resource_requests=resources,
-        resource_limits=resources,
-        container_image=container_image,
-    )
-
-
-def generate_spec_overrides(
-    environment_vars: Optional[dict[str, str]] = None,
-    custom_overrides: Optional[dict[str, Any]] = None,
-) -> dict[str, Any]:
-    """
-    Generate a dictionary of service specification overrides.
-
-    Args:
-        environment_vars: Environment variables to set in primary container
-        custom_overrides: Custom service specification overrides
-
-    Returns:
-        Resulting service specifiation patch dict. Empty if no overrides were supplied.
-    """
-    # Generate container level overrides
-    container_spec: dict[str, Any] = {
-        "name": constants.DEFAULT_CONTAINER_NAME,
-    }
-
-    if environment_vars:
-        # TODO: Validate environment variables
-        container_spec["env"] = environment_vars
-
-    # Build container override spec only if any overrides were supplied
-    spec = {}
-    if len(container_spec) > 1:
-        spec = {
-            "spec": {
-                "containers": [container_spec],
-            }
-        }
-
-    # Apply custom overrides
-    if custom_overrides:
-        spec = merge_patch(spec, custom_overrides, display_name="custom_overrides")
-
-    return spec
-
-
-def generate_service_spec(
-    session: snowpark.Session,
-    compute_pool: str,
-    payload: types.UploadedPayload,
-    args: Optional[list[str]] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    runtime_environment: Optional[str] = None,
-) -> dict[str, Any]:
-    """
-    Generate a service specification for a job.
-
-    Args:
-        session: Snowflake session
-        compute_pool: Compute pool for job execution
-        payload: Uploaded job payload
-        args: Arguments to pass to entrypoint script
-        target_instances: Number of instances for multi-node job
-        enable_metrics: Enable platform metrics for the job
-        min_instances: Minimum number of instances required to start the job
-        runtime_environment: The runtime image to use. Only support image tag or full image URL.
-
-    Returns:
-        Job service specification
-    """
-    image_spec = _get_image_spec(session, compute_pool, runtime_environment)
-
-    # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-    resource_requests: dict[str, Union[str, int]] = {
-        "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
-        "memory": f"{image_spec.resource_limits.memory}Gi",
-    }
-    resource_limits: dict[str, Union[str, int]] = {
-        "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
-        "memory": f"{image_spec.resource_limits.memory}Gi",
-    }
-    if image_spec.resource_limits.gpu > 0:
-        resource_requests["nvidia.com/gpu"] = image_spec.resource_requests.gpu
-        resource_limits["nvidia.com/gpu"] = image_spec.resource_limits.gpu
-
-    # Add local volumes for ephemeral logs and artifacts
-    volumes: list[dict[str, Any]] = []
-    volume_mounts: list[dict[str, str]] = []
-    for volume_name, mount_path in [
-        ("system-logs", "/var/log/managedservices/system/mlrs"),
-        ("user-logs", "/var/log/managedservices/user/mlrs"),
-    ]:
-        volume_mounts.append(
-            {
-                "name": volume_name,
-                "mountPath": mount_path,
-            }
-        )
-        volumes.append(
-            {
-                "name": volume_name,
-                "source": "local",
-            }
-        )
-
-    # Mount 30% of memory limit as a memory-backed volume
-    memory_volume_size = min(
-        ceil(image_spec.resource_limits.memory * constants.MEMORY_VOLUME_SIZE),
-        image_spec.resource_requests.memory,
-    )
-    volume_mounts.append(
-        {
-            "name": constants.MEMORY_VOLUME_NAME,
-            "mountPath": "/dev/shm",
-        }
-    )
-    volumes.append(
-        {
-            "name": constants.MEMORY_VOLUME_NAME,
-            "source": "memory",
-            "size": f"{memory_volume_size}Gi",
-        }
-    )
-
-    # Mount payload as volume
-    stage_mount = PurePath(constants.STAGE_VOLUME_MOUNT_PATH)
-    volume_mounts.append(
-        {
-            "name": constants.STAGE_VOLUME_NAME,
-            "mountPath": stage_mount.as_posix(),
-        }
-    )
-    volumes.append(
-        {
-            "name": constants.STAGE_VOLUME_NAME,
-            "source": "stage",
-            "stageConfig": {
-                "name": payload.stage_path.as_posix(),
-                "resources": {
-                    "requests": {
-                        "memory": "0Gi",
-                        "cpu": "0",
-                    },
-                },
-            },
-        }
-    )
-
-    # TODO: Add hooks for endpoints for integration with TensorBoard etc
-
-    env_vars = payload.env_vars
-    endpoints: list[dict[str, Any]] = []
-
-    if target_instances > 1:
-        # Update environment variables for multi-node job
-        env_vars.update(constants.RAY_PORTS)
-        env_vars[constants.ENABLE_HEALTH_CHECKS_ENV_VAR] = constants.ENABLE_HEALTH_CHECKS
-        env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)
-
-        # Define Ray endpoints for intra-service instance communication
-        ray_endpoints: list[dict[str, Any]] = [
-            {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
-            {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
-            {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
-            {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
-            {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
-            {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
-            {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
-            {"name": "ray-dashboard-agent-grpc-endpoint", "port": 12014, "protocol": "TCP"},
-            {"name": "ephemeral-port-range", "portRange": "32768-60999", "protocol": "TCP"},
-            {"name": "ray-worker-port-range", "portRange": "12031-13000", "protocol": "TCP"},
-        ]
-        endpoints.extend(ray_endpoints)
-
-    metrics = []
-    if enable_metrics:
-        # https://docs.snowflake.com/en/developer-guide/snowpark-container-services/monitoring-services#label-spcs-available-platform-metrics
-        metrics = [
-            "system",
-            "status",
-            "network",
-            "storage",
-        ]
-
-    spec_dict: dict[str, Any] = {
-        "containers": [
-            {
-                "name": constants.DEFAULT_CONTAINER_NAME,
-                "image": image_spec.container_image,
-                "command": ["/usr/local/bin/_entrypoint.sh"],
-                "args": [
-                    (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
-                ]
-                + (args or []),
-                "env": env_vars,
-                "volumeMounts": volume_mounts,
-                "resources": {
-                    "requests": resource_requests,
-                    "limits": resource_limits,
-                },
-            },
-        ],
-        "volumes": volumes,
-    }
-
-    if target_instances > 1:
-        spec_dict.update(
-            {
-                "resourceManagement": {
-                    "controlPolicy": {
-                        "startupOrder": {
-                            "type": "FirstInstance",
-                        },
-                    },
-                },
-            }
-        )
-    if endpoints:
-        spec_dict["endpoints"] = endpoints
-    if metrics:
-        spec_dict.update(
-            {
-                "platformMonitor": {
-                    "metricConfig": {
-                        "groups": metrics,
-                    },
-                },
-            }
-        )
-
-    # Assemble into service specification dict
-    spec = {"spec": spec_dict}
-
-    return spec
-
-
-def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
-    """
-    Implements a modified RFC7386 JSON Merge Patch
-    https://datatracker.ietf.org/doc/html/rfc7386
-
-    Behavior differs from the RFC in the following ways:
-    1. Empty nested dictionaries resulting from the patch are treated as None and are pruned
-    2. Attempts to merge lists of dicts using a merge key (default "name").
-       See _merge_lists_of_dicts for details on list merge behavior.
-
-    Args:
-        base: The base object to patch.
-        patch: The patch object.
-        display_name: The name of the patch object for logging purposes.
-
-    Returns:
-        The patched object.
-    """
-    if type(base) is not type(patch):
-        if base is not None:
-            logging.warning(f"Type mismatch while merging {display_name} (base={type(base)}, patch={type(patch)})")
-        return patch
-    elif isinstance(patch, list) and all(isinstance(v, dict) for v in base + patch):
-        # TODO: Should we prune empty lists?
-        return _merge_lists_of_dicts(base, patch, display_name=display_name)
-    elif not isinstance(patch, dict) or len(patch) == 0:
-        return patch
-
-    result = dict(base)  # Shallow copy
-    for key, value in patch.items():
-        if value is None:
-            result.pop(key, None)
-        else:
-            merge_result = merge_patch(result.get(key, None), value, display_name=f"{display_name}.{key}")
-            if isinstance(merge_result, dict) and len(merge_result) == 0:
-                result.pop(key, None)
-            else:
-                result[key] = merge_result
-
-    return result
-
-
-def _merge_lists_of_dicts(
-    base: list[dict[str, Any]],
-    patch: list[dict[str, Any]],
-    merge_key: str = "name",
-    display_name: str = "",
-) -> list[dict[str, Any]]:
-    """
-    Attempts to merge lists of dicts by matching on a merge key (default "name").
-    - If the merge key is missing, the behavior falls back to overwriting the list.
-    - If the merge key is present, the behavior is to match the list elements based on the
-      merge key and preserving any unmatched elements from the base list.
-    - Matched entries may be dropped in the following way(s):
-      1. The matching patch entry has a None key entry, e.g. { "name": "foo", None: None }.
-
-    Args:
-        base: The base list of dicts.
-        patch: The patch list of dicts.
-        merge_key: The key to use for merging.
-        display_name: The name of the patch object for logging purposes.
-
-    Returns:
-        The merged list of dicts if merging successful, else returns the patch list.
-    """
-    if any(merge_key not in d for d in base + patch):
-        logging.warning(f"Missing merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
-        return patch
-
-    # Build mapping of merge key values to list elements for the base list
-    result = {d[merge_key]: d for d in base}
-    if len(result) != len(base):
-        logging.warning(f"Duplicate merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
-        return patch
-
-    # Apply patches
-    for d in patch:
-        key = d[merge_key]
-
-        # Removal case 1: `None` key in patch entry
-        if None in d:
-            result.pop(key, None)
-            continue
-
-        # Apply patch
-        if key in result:
-            d = merge_patch(
-                result[key],
-                d,
-                display_name=f"{display_name}[{merge_key}={d[merge_key]}]",
-            )
-            # TODO: Should we drop the item if the patch result is empty save for the merge key?
-            # Can check `d.keys() <= {merge_key}`
-        result[key] = d
-
-    return list(result.values())
-
-
-def _get_runtime_image_tag() -> str:
-    """
-    Detect runtime image tag from container environment.
-
-    Checks in order:
-    1. Environment variable MLRS_CONTAINER_IMAGE_TAG
-    2. Falls back to hardcoded default
-
-    Returns:
-        str: The runtime image tag to use for job containers
-    """
-    env_tag = os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR)
-    if env_tag:
-        logging.debug(f"Using runtime image tag from environment: {env_tag}")
-        return env_tag
-
-    # Fall back to default
-    logging.debug(f"Using default runtime image tag: {constants.DEFAULT_IMAGE_TAG}")
-    return constants.DEFAULT_IMAGE_TAG
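Note that the spec-generation helpers above, including the modified RFC 7386 JSON Merge Patch implemented by merge_patch and _merge_lists_of_dicts, were removed from spec_utils.py outright in 1.22.0 rather than rewritten in place. For readers who relied on the override semantics, the following is a condensed sketch of the documented behavior only: None values delete keys, empty nested dicts are pruned, and lists of dicts merge on a "name" key. The None-valued removal case for list entries, the display_name plumbing, and all logging are deliberately elided.

from typing import Any

def merge_patch(base: Any, patch: Any) -> Any:
    # Lists of dicts merge element-wise by "name"; anything unusable overwrites.
    if isinstance(base, list) and isinstance(patch, list) and all(isinstance(v, dict) for v in base + patch):
        if any("name" not in d for d in base + patch):
            return patch  # missing merge key: fall back to overwrite
        merged = {d["name"]: d for d in base}
        if len(merged) != len(base):
            return patch  # duplicate merge keys: fall back to overwrite
        for d in patch:
            merged[d["name"]] = merge_patch(merged[d["name"]], d) if d["name"] in merged else d
        return list(merged.values())
    if not isinstance(base, dict) or not isinstance(patch, dict) or not patch:
        return patch  # scalars, type mismatches, and empty patch dicts overwrite
    result = dict(base)  # shallow copy, as in the removed original
    for key, value in patch.items():
        if value is None:
            result.pop(key, None)  # RFC 7386: a None value removes the key
            continue
        merged = merge_patch(result.get(key), value)
        if isinstance(merged, dict) and not merged:
            result.pop(key, None)  # documented deviation: prune empty nested dicts
        else:
            result[key] = merged
    return result

base = {"spec": {"containers": [{"name": "main", "env": {"A": "1"}}]}}
patch = {"spec": {"containers": [{"name": "main", "env": {"B": "2"}}]}}
assert merge_patch(base, patch) == {
    "spec": {"containers": [{"name": "main", "env": {"A": "1", "B": "2"}}]}
}

The container example mirrors how generate_spec_overrides used this helper: a custom_overrides patch merged into the generated container spec by container name rather than replacing the whole list.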
snowflake/ml/jobs/_utils/stage_utils.py
CHANGED

@@ -2,10 +2,13 @@ import os
 import re
 from os import PathLike
 from pathlib import Path, PurePath
-from typing import Union
+from typing import TYPE_CHECKING, Union

 from snowflake.ml._internal.utils import identifier

+if TYPE_CHECKING:
+    from snowflake.ml.jobs._utils import types
+
 PROTOCOL_NAME = "snow"
 _SNOWURL_PATH_RE = re.compile(
     rf"^(?:(?:{PROTOCOL_NAME}://)?"
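The TYPE_CHECKING guard added here imports types only during static analysis, so stage_utils can annotate its new resolve_path return type as "types.PayloadPath" without importing types at runtime, which is the usual way to avoid an import cycle. A minimal, self-contained sketch of the pattern, using a stdlib module as a stand-in for the guarded import:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never executed, so no runtime cycle or cost.
    from decimal import Decimal

def half(value: "Decimal") -> "Decimal":
    # Quoted annotations are forward references, so the name need not
    # be bound at runtime when the function is defined.
    return value / 2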
@@ -150,3 +153,21 @@ class StagePath:
             # the arg might be an absolute path, so we need to remove the leading '/'
             path = StagePath(f"{path.root}/{path._path.joinpath(arg).as_posix().lstrip('/')}")
         return path
+
+
+def resolve_path(path: Union[str, Path]) -> "types.PayloadPath":
+    """
+    Resolve a path to either a StagePath or a local Path.
+
+    Args:
+        path: A string or Path object representing a local or stage path.
+
+    Returns:
+        A StagePath if the input is a valid stage path, otherwise a local Path.
+    """
+    path_str = path.as_posix() if isinstance(path, Path) else path
+    try:
+        stage_path = StagePath(path_str)
+    except ValueError:
+        return Path(path_str)
+    return stage_path
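resolve_path gives callers a single entry point that yields either path flavor, which is how job.py's _transform_path (further down in this diff) dispatches between stage and local payloads. A brief usage sketch, assuming snowflake-ml-python 1.22.0 is installed; the input string is illustrative:

from pathlib import Path

from snowflake.ml.jobs._utils import stage_utils

def describe(raw: str) -> str:
    # StagePath is attempted first; a ValueError falls back to a local Path.
    path = stage_utils.resolve_path(raw)
    if isinstance(path, stage_utils.StagePath):
        return f"stage payload: {path.as_posix()}"
    return f"local payload: {Path(path).absolute().as_posix()}"

print(describe("src/train.py"))  # no stage scheme, so this resolves to a local Path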
snowflake/ml/jobs/_utils/types.py
CHANGED

@@ -1,7 +1,9 @@
 import os
 from dataclasses import dataclass, field
 from pathlib import PurePath
-from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable
+from typing import Literal, Optional, Protocol, Union, runtime_checkable
+
+from typing_extensions import Self

 JOB_STATUS = Literal[
     "PENDING",
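The new typing_extensions.Self import (an alias of typing.Self on Python 3.11+) lets the PayloadPath protocol below declare that parent, absolute, and joinpath return the implementing class rather than the protocol type, so a type checker sees Path methods returning Path and StagePath methods returning StagePath. A minimal sketch with a made-up class:

from typing_extensions import Self  # typing.Self on Python 3.11+

class Node:
    def clone(self) -> Self:
        # Self binds to the runtime subclass, not to Node.
        return type(self)()

class ChildNode(Node):
    pass

child: ChildNode = ChildNode().clone()  # type-checks: Self resolves to ChildNode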
@@ -19,6 +21,10 @@ JOB_STATUS = Literal[
 class PayloadPath(Protocol):
     """A protocol for path-like objects used in this module, covering methods from pathlib.Path and StagePath."""

+    @property
+    def parts(self) -> tuple[str, ...]:
+        ...
+
     @property
     def name(self) -> str:
         ...
@@ -32,7 +38,7 @@ class PayloadPath(Protocol):
|
|
|
32
38
|
...
|
|
33
39
|
|
|
34
40
|
@property
|
|
35
|
-
def parent(self) ->
|
|
41
|
+
def parent(self) -> Self:
|
|
36
42
|
...
|
|
37
43
|
|
|
38
44
|
@property
|
|
@@ -45,13 +51,16 @@ class PayloadPath(Protocol):
|
|
|
45
51
|
def is_file(self) -> bool:
|
|
46
52
|
...
|
|
47
53
|
|
|
54
|
+
def is_dir(self) -> bool:
|
|
55
|
+
...
|
|
56
|
+
|
|
48
57
|
def is_absolute(self) -> bool:
|
|
49
58
|
...
|
|
50
59
|
|
|
51
|
-
def absolute(self) ->
|
|
60
|
+
def absolute(self) -> Self:
|
|
52
61
|
...
|
|
53
62
|
|
|
54
|
-
def joinpath(self, *other: Union[str, os.PathLike[str]]) ->
|
|
63
|
+
def joinpath(self, *other: Union[str, os.PathLike[str]]) -> Self:
|
|
55
64
|
...
|
|
56
65
|
|
|
57
66
|
def as_posix(self) -> str:
|
|
@@ -79,9 +88,7 @@ class PayloadSpec:

     source_path: PayloadPath
     remote_relative_path: Optional[PurePath] = None
-
-    def __iter__(self) -> Iterator[Union[PayloadPath, Optional[PurePath]]]:
-        return iter((self.source_path, self.remote_relative_path))
+    compress: bool = False


 @dataclass(frozen=True)
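Because PayloadPath is a runtime_checkable Protocol, pathlib.Path and StagePath satisfy it structurally, with no shared base class; the new parts and is_dir members simply widen the required surface. A trimmed sketch of how such a protocol types a helper over either flavor (the protocol below is a stand-in with fewer members than the real PayloadPath):

from pathlib import Path
from typing import Protocol, runtime_checkable

@runtime_checkable
class Pathish(Protocol):
    # Stand-in subset of PayloadPath's interface.
    @property
    def name(self) -> str: ...
    def as_posix(self) -> str: ...
    def is_dir(self) -> bool: ...

def label(p: Pathish) -> str:
    # Any object with the right shape is accepted; no inheritance needed.
    return f"{p.name} ({'dir' if p.is_dir() else 'file'}) at {p.as_posix()}"

print(isinstance(Path("."), Pathish))  # True: runtime_checkable enables isinstance
print(label(Path(".")))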
snowflake/ml/jobs/job.py
CHANGED
@@ -13,13 +13,7 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
 from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils
-from snowflake.ml.jobs._utils import (
-    constants,
-    payload_utils,
-    query_helper,
-    stage_utils,
-    types,
-)
+from snowflake.ml.jobs._utils import constants, query_helper, stage_utils, types
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
@@ -131,7 +125,7 @@ class MLJob(Generic[T], SerializableSessionMixin):

     def _transform_path(self, path_str: str) -> str:
         """Transform a local path within the container to a stage path."""
-        path =
+        path = stage_utils.resolve_path(path_str)
         if isinstance(path, stage_utils.StagePath):
             # Stage paths need no transformation
             return path.as_posix()