snowflake-ml-python 1.20.0__py3-none-any.whl → 1.22.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. snowflake/ml/_internal/platform_capabilities.py +36 -0
  2. snowflake/ml/_internal/utils/url.py +42 -0
  3. snowflake/ml/data/_internal/arrow_ingestor.py +67 -2
  4. snowflake/ml/data/data_connector.py +103 -1
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +8 -2
  6. snowflake/ml/experiment/callback/__init__.py +0 -0
  7. snowflake/ml/experiment/callback/keras.py +25 -2
  8. snowflake/ml/experiment/callback/lightgbm.py +27 -2
  9. snowflake/ml/experiment/callback/xgboost.py +25 -2
  10. snowflake/ml/experiment/experiment_tracking.py +93 -3
  11. snowflake/ml/experiment/utils.py +6 -0
  12. snowflake/ml/feature_store/feature_view.py +34 -24
  13. snowflake/ml/jobs/_interop/protocols.py +3 -0
  14. snowflake/ml/jobs/_utils/constants.py +1 -0
  15. snowflake/ml/jobs/_utils/payload_utils.py +354 -356
  16. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +95 -8
  17. snowflake/ml/jobs/_utils/scripts/start_mlruntime.sh +92 -0
  18. snowflake/ml/jobs/_utils/scripts/startup.sh +112 -0
  19. snowflake/ml/jobs/_utils/spec_utils.py +1 -445
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -1
  21. snowflake/ml/jobs/_utils/types.py +14 -7
  22. snowflake/ml/jobs/job.py +2 -8
  23. snowflake/ml/jobs/manager.py +57 -135
  24. snowflake/ml/lineage/lineage_node.py +1 -1
  25. snowflake/ml/model/__init__.py +6 -0
  26. snowflake/ml/model/_client/model/batch_inference_specs.py +16 -1
  27. snowflake/ml/model/_client/model/model_version_impl.py +130 -14
  28. snowflake/ml/model/_client/ops/deployment_step.py +36 -0
  29. snowflake/ml/model/_client/ops/model_ops.py +93 -8
  30. snowflake/ml/model/_client/ops/service_ops.py +32 -52
  31. snowflake/ml/model/_client/service/import_model_spec_schema.py +23 -0
  32. snowflake/ml/model/_client/service/model_deployment_spec.py +12 -4
  33. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -0
  34. snowflake/ml/model/_client/sql/model_version.py +30 -6
  35. snowflake/ml/model/_client/sql/service.py +94 -5
  36. snowflake/ml/model/_model_composer/model_composer.py +1 -1
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -0
  38. snowflake/ml/model/_model_composer/model_method/model_method.py +61 -2
  39. snowflake/ml/model/_packager/model_handler.py +8 -2
  40. snowflake/ml/model/_packager/model_handlers/custom.py +52 -0
  41. snowflake/ml/model/_packager/model_handlers/{huggingface_pipeline.py → huggingface.py} +203 -76
  42. snowflake/ml/model/_packager/model_handlers/mlflow.py +6 -1
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +26 -1
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +40 -7
  45. snowflake/ml/model/_packager/model_packager.py +1 -1
  46. snowflake/ml/model/_signatures/core.py +390 -8
  47. snowflake/ml/model/_signatures/utils.py +13 -4
  48. snowflake/ml/model/code_path.py +104 -0
  49. snowflake/ml/model/compute_pool.py +2 -0
  50. snowflake/ml/model/custom_model.py +55 -13
  51. snowflake/ml/model/model_signature.py +13 -1
  52. snowflake/ml/model/models/huggingface.py +285 -0
  53. snowflake/ml/model/models/huggingface_pipeline.py +19 -208
  54. snowflake/ml/model/type_hints.py +7 -1
  55. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  56. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +12 -0
  57. snowflake/ml/monitoring/_manager/model_monitor_manager.py +12 -0
  58. snowflake/ml/monitoring/entities/model_monitor_config.py +5 -0
  59. snowflake/ml/registry/_manager/model_manager.py +230 -15
  60. snowflake/ml/registry/registry.py +4 -4
  61. snowflake/ml/utils/html_utils.py +67 -1
  62. snowflake/ml/version.py +1 -1
  63. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/METADATA +81 -7
  64. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/RECORD +67 -59
  65. snowflake/ml/jobs/_utils/runtime_env_utils.py +0 -63
  66. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/WHEEL +0 -0
  67. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/licenses/LICENSE.txt +0 -0
  68. {snowflake_ml_python-1.20.0.dist-info → snowflake_ml_python-1.22.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/spec_utils.py CHANGED
@@ -1,17 +1,6 @@
-import logging
-import os
-import re
-import sys
-from math import ceil
-from pathlib import PurePath
-from typing import Any, Literal, Optional, Union
-
 from snowflake import snowpark
 from snowflake.ml._internal.utils import snowflake_env
-from snowflake.ml.jobs._utils import constants, feature_flags, query_helper, types
-from snowflake.ml.jobs._utils.runtime_env_utils import RuntimeEnvironmentsDict
-
-_OCI_TAG_REGEX = re.compile("^[a-zA-Z0-9._-]{1,128}$")
+from snowflake.ml.jobs._utils import constants, query_helper, types
 
 
 def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
@@ -31,436 +20,3 @@ def _get_node_resources(session: snowpark.Session, compute_pool: str) -> types.ComputeResources:
         constants.COMMON_INSTANCE_FAMILIES.get(instance_family)
         or constants.CLOUD_INSTANCE_FAMILIES[cloud][instance_family]
     )
-
-
-def _get_runtime_image(session: snowpark.Session, target_hardware: Literal["CPU", "GPU"]) -> Optional[str]:
-    rows = query_helper.run_query(session, "CALL SYSTEM$NOTEBOOKS_FIND_LABELED_RUNTIMES()")
-    if not rows:
-        return None
-    try:
-        runtime_envs = RuntimeEnvironmentsDict.model_validate_json(rows[0][0])
-        spcs_container_runtimes = runtime_envs.get_spcs_container_runtimes()
-    except Exception as e:
-        logging.warning(f"Failed to parse runtime image name from {rows[0][0]}, error: {e}")
-        return None
-
-    selected_runtime = next(
-        (
-            runtime
-            for runtime in spcs_container_runtimes
-            if (
-                runtime.hardware_type.lower() == target_hardware.lower()
-                and runtime.python_version.major == sys.version_info.major
-                and runtime.python_version.minor == sys.version_info.minor
-            )
-        ),
-        None,
-    )
-    return selected_runtime.runtime_container_image if selected_runtime else None
-
-
-def _check_image_tag_valid(tag: Optional[str]) -> bool:
-    if tag is None:
-        return False
-
-    return _OCI_TAG_REGEX.fullmatch(tag) is not None
-
-
-def _get_image_spec(
-    session: snowpark.Session, compute_pool: str, runtime_environment: Optional[str] = None
-) -> types.ImageSpec:
-    """
-    Resolve the image specification (container image and resources) for the job.
-
-    Behavior:
-    - If `runtime_environment` is empty or the feature flag is disabled, use the
-      default image tag and image name.
-    - If `runtime_environment` is a valid image tag, use that tag with the default
-      repository/name.
-    - If `runtime_environment` is a full image URL, use it directly.
-    - If the feature flag is enabled and `runtime_environment` is not provided,
-      select an ML Runtime image matching the local Python major.minor version.
-    - When multiple inputs are provided, `runtime_environment` takes priority.
-
-    Args:
-        session: Snowflake session.
-        compute_pool: Compute pool used to infer CPU/GPU resources.
-        runtime_environment: Optional image tag or full image URL to override.
-
-    Returns:
-        Image spec including container image and resource requests/limits.
-    """
-    # Retrieve compute pool node resources
-    resources = _get_node_resources(session, compute_pool=compute_pool)
-    hardware = "GPU" if resources.gpu > 0 else "CPU"
-    image_tag = _get_runtime_image_tag()
-    image_repo = constants.DEFAULT_IMAGE_REPO
-    image_name = constants.DEFAULT_IMAGE_GPU if resources.gpu > 0 else constants.DEFAULT_IMAGE_CPU
-
-    # Use MLRuntime image
-    container_image = None
-    if runtime_environment:
-        if _check_image_tag_valid(runtime_environment):
-            image_tag = runtime_environment
-        else:
-            container_image = runtime_environment
-    elif feature_flags.FeatureFlags.ENABLE_RUNTIME_VERSIONS.is_enabled():
-        container_image = _get_runtime_image(session, hardware)  # type: ignore[arg-type]
-
-    container_image = container_image or f"{image_repo}/{image_name}:{image_tag}"
-    # TODO: Should each instance consume the entire pod?
-    return types.ImageSpec(
-        resource_requests=resources,
-        resource_limits=resources,
-        container_image=container_image,
-    )
-
-
-def generate_spec_overrides(
-    environment_vars: Optional[dict[str, str]] = None,
-    custom_overrides: Optional[dict[str, Any]] = None,
-) -> dict[str, Any]:
-    """
-    Generate a dictionary of service specification overrides.
-
-    Args:
-        environment_vars: Environment variables to set in the primary container.
-        custom_overrides: Custom service specification overrides.
-
-    Returns:
-        Resulting service specification patch dict. Empty if no overrides were supplied.
-    """
-    # Generate container level overrides
-    container_spec: dict[str, Any] = {
-        "name": constants.DEFAULT_CONTAINER_NAME,
-    }
-
-    if environment_vars:
-        # TODO: Validate environment variables
-        container_spec["env"] = environment_vars
-
-    # Build container override spec only if any overrides were supplied
-    spec = {}
-    if len(container_spec) > 1:
-        spec = {
-            "spec": {
-                "containers": [container_spec],
-            }
-        }
-
-    # Apply custom overrides
-    if custom_overrides:
-        spec = merge_patch(spec, custom_overrides, display_name="custom_overrides")
-
-    return spec
-
-
-def generate_service_spec(
-    session: snowpark.Session,
-    compute_pool: str,
-    payload: types.UploadedPayload,
-    args: Optional[list[str]] = None,
-    target_instances: int = 1,
-    min_instances: int = 1,
-    enable_metrics: bool = False,
-    runtime_environment: Optional[str] = None,
-) -> dict[str, Any]:
-    """
-    Generate a service specification for a job.
-
-    Args:
-        session: Snowflake session
-        compute_pool: Compute pool for job execution
-        payload: Uploaded job payload
-        args: Arguments to pass to entrypoint script
-        target_instances: Number of instances for multi-node job
-        enable_metrics: Enable platform metrics for the job
-        min_instances: Minimum number of instances required to start the job
-        runtime_environment: The runtime image to use. Only an image tag or a full image URL is supported.
-
-    Returns:
-        Job service specification
-    """
-    image_spec = _get_image_spec(session, compute_pool, runtime_environment)
-
-    # Set resource requests/limits, including nvidia.com/gpu quantity if applicable
-    resource_requests: dict[str, Union[str, int]] = {
-        "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
-        "memory": f"{image_spec.resource_limits.memory}Gi",
-    }
-    resource_limits: dict[str, Union[str, int]] = {
-        "cpu": f"{int(image_spec.resource_requests.cpu * 1000)}m",
-        "memory": f"{image_spec.resource_limits.memory}Gi",
-    }
-    if image_spec.resource_limits.gpu > 0:
-        resource_requests["nvidia.com/gpu"] = image_spec.resource_requests.gpu
-        resource_limits["nvidia.com/gpu"] = image_spec.resource_limits.gpu
-
-    # Add local volumes for ephemeral logs and artifacts
-    volumes: list[dict[str, Any]] = []
-    volume_mounts: list[dict[str, str]] = []
-    for volume_name, mount_path in [
-        ("system-logs", "/var/log/managedservices/system/mlrs"),
-        ("user-logs", "/var/log/managedservices/user/mlrs"),
-    ]:
-        volume_mounts.append(
-            {
-                "name": volume_name,
-                "mountPath": mount_path,
-            }
-        )
-        volumes.append(
-            {
-                "name": volume_name,
-                "source": "local",
-            }
-        )
-
-    # Mount 30% of memory limit as a memory-backed volume
-    memory_volume_size = min(
-        ceil(image_spec.resource_limits.memory * constants.MEMORY_VOLUME_SIZE),
-        image_spec.resource_requests.memory,
-    )
-    volume_mounts.append(
-        {
-            "name": constants.MEMORY_VOLUME_NAME,
-            "mountPath": "/dev/shm",
-        }
-    )
-    volumes.append(
-        {
-            "name": constants.MEMORY_VOLUME_NAME,
-            "source": "memory",
-            "size": f"{memory_volume_size}Gi",
-        }
-    )
-
-    # Mount payload as volume
-    stage_mount = PurePath(constants.STAGE_VOLUME_MOUNT_PATH)
-    volume_mounts.append(
-        {
-            "name": constants.STAGE_VOLUME_NAME,
-            "mountPath": stage_mount.as_posix(),
-        }
-    )
-    volumes.append(
-        {
-            "name": constants.STAGE_VOLUME_NAME,
-            "source": "stage",
-            "stageConfig": {
-                "name": payload.stage_path.as_posix(),
-                "resources": {
-                    "requests": {
-                        "memory": "0Gi",
-                        "cpu": "0",
-                    },
-                },
-            },
-        }
-    )
-
-    # TODO: Add hooks for endpoints for integration with TensorBoard etc
-
-    env_vars = payload.env_vars
-    endpoints: list[dict[str, Any]] = []
-
-    if target_instances > 1:
-        # Update environment variables for multi-node job
-        env_vars.update(constants.RAY_PORTS)
-        env_vars[constants.ENABLE_HEALTH_CHECKS_ENV_VAR] = constants.ENABLE_HEALTH_CHECKS
-        env_vars[constants.MIN_INSTANCES_ENV_VAR] = str(min_instances)
-
-        # Define Ray endpoints for intra-service instance communication
-        ray_endpoints: list[dict[str, Any]] = [
-            {"name": "ray-client-server-endpoint", "port": 10001, "protocol": "TCP"},
-            {"name": "ray-gcs-endpoint", "port": 12001, "protocol": "TCP"},
-            {"name": "ray-dashboard-grpc-endpoint", "port": 12002, "protocol": "TCP"},
-            {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
-            {"name": "ray-object-manager-endpoint", "port": 12011, "protocol": "TCP"},
-            {"name": "ray-node-manager-endpoint", "port": 12012, "protocol": "TCP"},
-            {"name": "ray-runtime-agent-endpoint", "port": 12013, "protocol": "TCP"},
-            {"name": "ray-dashboard-agent-grpc-endpoint", "port": 12014, "protocol": "TCP"},
-            {"name": "ephemeral-port-range", "portRange": "32768-60999", "protocol": "TCP"},
-            {"name": "ray-worker-port-range", "portRange": "12031-13000", "protocol": "TCP"},
-        ]
-        endpoints.extend(ray_endpoints)
-
-    metrics = []
-    if enable_metrics:
-        # https://docs.snowflake.com/en/developer-guide/snowpark-container-services/monitoring-services#label-spcs-available-platform-metrics
-        metrics = [
-            "system",
-            "status",
-            "network",
-            "storage",
-        ]
-
-    spec_dict: dict[str, Any] = {
-        "containers": [
-            {
-                "name": constants.DEFAULT_CONTAINER_NAME,
-                "image": image_spec.container_image,
-                "command": ["/usr/local/bin/_entrypoint.sh"],
-                "args": [
-                    (stage_mount.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
-                ]
-                + (args or []),
-                "env": env_vars,
-                "volumeMounts": volume_mounts,
-                "resources": {
-                    "requests": resource_requests,
-                    "limits": resource_limits,
-                },
-            },
-        ],
-        "volumes": volumes,
-    }
-
-    if target_instances > 1:
-        spec_dict.update(
-            {
-                "resourceManagement": {
-                    "controlPolicy": {
-                        "startupOrder": {
-                            "type": "FirstInstance",
-                        },
-                    },
-                },
-            }
-        )
-    if endpoints:
-        spec_dict["endpoints"] = endpoints
-    if metrics:
-        spec_dict.update(
-            {
-                "platformMonitor": {
-                    "metricConfig": {
-                        "groups": metrics,
-                    },
-                },
-            }
-        )
-
-    # Assemble into service specification dict
-    spec = {"spec": spec_dict}
-
-    return spec
-
-
-def merge_patch(base: Any, patch: Any, display_name: str = "") -> Any:
-    """
-    Implements a modified RFC 7386 JSON Merge Patch
-    https://datatracker.ietf.org/doc/html/rfc7386
-
-    Behavior differs from the RFC in the following ways:
-    1. Empty nested dictionaries resulting from the patch are treated as None and are pruned.
-    2. Attempts to merge lists of dicts using a merge key (default "name").
-       See _merge_lists_of_dicts for details on list merge behavior.
-
-    Args:
-        base: The base object to patch.
-        patch: The patch object.
-        display_name: The name of the patch object for logging purposes.
-
-    Returns:
-        The patched object.
-    """
-    if type(base) is not type(patch):
-        if base is not None:
-            logging.warning(f"Type mismatch while merging {display_name} (base={type(base)}, patch={type(patch)})")
-        return patch
-    elif isinstance(patch, list) and all(isinstance(v, dict) for v in base + patch):
-        # TODO: Should we prune empty lists?
-        return _merge_lists_of_dicts(base, patch, display_name=display_name)
-    elif not isinstance(patch, dict) or len(patch) == 0:
-        return patch
-
-    result = dict(base)  # Shallow copy
-    for key, value in patch.items():
-        if value is None:
-            result.pop(key, None)
-        else:
-            merge_result = merge_patch(result.get(key, None), value, display_name=f"{display_name}.{key}")
-            if isinstance(merge_result, dict) and len(merge_result) == 0:
-                result.pop(key, None)
-            else:
-                result[key] = merge_result
-
-    return result
-
-
-def _merge_lists_of_dicts(
-    base: list[dict[str, Any]],
-    patch: list[dict[str, Any]],
-    merge_key: str = "name",
-    display_name: str = "",
-) -> list[dict[str, Any]]:
-    """
-    Attempts to merge lists of dicts by matching on a merge key (default "name").
-    - If the merge key is missing, the behavior falls back to overwriting the list.
-    - If the merge key is present, list elements are matched on the merge key,
-      preserving any unmatched elements from the base list.
-    - Matched entries may be dropped in the following way(s):
-        1. The matching patch entry has a None key entry, e.g. { "name": "foo", None: None }.
-
-    Args:
-        base: The base list of dicts.
-        patch: The patch list of dicts.
-        merge_key: The key to use for merging.
-        display_name: The name of the patch object for logging purposes.
-
-    Returns:
-        The merged list of dicts if merging is successful, else the patch list.
-    """
-    if any(merge_key not in d for d in base + patch):
-        logging.warning(f"Missing merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
-        return patch
-
-    # Build mapping of merge key values to list elements for the base list
-    result = {d[merge_key]: d for d in base}
-    if len(result) != len(base):
-        logging.warning(f"Duplicate merge key {merge_key} in {display_name}. Falling back to overwrite behavior.")
-        return patch
-
-    # Apply patches
-    for d in patch:
-        key = d[merge_key]
-
-        # Removal case 1: `None` key in patch entry
-        if None in d:
-            result.pop(key, None)
-            continue
-
-        # Apply patch
-        if key in result:
-            d = merge_patch(
-                result[key],
-                d,
-                display_name=f"{display_name}[{merge_key}={d[merge_key]}]",
-            )
-            # TODO: Should we drop the item if the patch result is empty save for the merge key?
-            # Can check `d.keys() <= {merge_key}`
-        result[key] = d
-
-    return list(result.values())
-
-
-def _get_runtime_image_tag() -> str:
-    """
-    Detect the runtime image tag from the container environment.
-
-    Checks in order:
-    1. Environment variable MLRS_CONTAINER_IMAGE_TAG
-    2. Falls back to the hardcoded default
-
-    Returns:
-        str: The runtime image tag to use for job containers
-    """
-    env_tag = os.environ.get(constants.RUNTIME_IMAGE_TAG_ENV_VAR)
-    if env_tag:
-        logging.debug(f"Using runtime image tag from environment: {env_tag}")
-        return env_tag
-
-    # Fall back to default
-    logging.debug(f"Using default runtime image tag: {constants.DEFAULT_IMAGE_TAG}")
-    return constants.DEFAULT_IMAGE_TAG
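The removed merge_patch implemented RFC 7386 JSON Merge Patch with the two deviations listed in its docstring. For readers unfamiliar with the RFC's base semantics, below is a minimal, self-contained sketch of the plain RFC behavior; `json_merge_patch` is a hypothetical name for illustration only, and the removed helper's extras (empty-dict pruning and list-of-dict merging) are deliberately omitted here.

```python
from typing import Any


def json_merge_patch(base: Any, patch: Any) -> Any:
    """Plain RFC 7386: dicts merge recursively, None deletes, anything else overwrites."""
    if not isinstance(patch, dict):
        return patch
    result = dict(base) if isinstance(base, dict) else {}
    for key, value in patch.items():
        if value is None:
            result.pop(key, None)  # a None value in the patch deletes the key
        else:
            result[key] = json_merge_patch(result.get(key), value)
    return result


base = {"spec": {"containers": [{"name": "main"}], "restartPolicy": "Always"}}
patch = {"spec": {"restartPolicy": None, "endpoints": [{"name": "ray-gcs-endpoint"}]}}

# Plain RFC 7386 overwrites lists wholesale; the removed merge_patch went
# further and matched list elements on their "name" key via _merge_lists_of_dicts.
print(json_merge_patch(base, patch))
# -> {'spec': {'containers': [{'name': 'main'}], 'endpoints': [{'name': 'ray-gcs-endpoint'}]}}
```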
snowflake/ml/jobs/_utils/stage_utils.py CHANGED
@@ -2,10 +2,13 @@ import os
 import re
 from os import PathLike
 from pathlib import Path, PurePath
-from typing import Union
+from typing import TYPE_CHECKING, Union
 
 from snowflake.ml._internal.utils import identifier
 
+if TYPE_CHECKING:
+    from snowflake.ml.jobs._utils import types
+
 PROTOCOL_NAME = "snow"
 _SNOWURL_PATH_RE = re.compile(
     rf"^(?:(?:{PROTOCOL_NAME}://)?"
@@ -150,3 +153,21 @@ class StagePath:
             # the arg might be an absolute path, so we need to remove the leading '/'
             path = StagePath(f"{path.root}/{path._path.joinpath(arg).as_posix().lstrip('/')}")
         return path
+
+
+def resolve_path(path: Union[str, Path]) -> "types.PayloadPath":
+    """
+    Resolve a path to either a StagePath or a local Path.
+
+    Args:
+        path: A string or Path object representing a local or stage path.
+
+    Returns:
+        A StagePath if the input is a valid stage path, otherwise a local Path.
+    """
+    path_str = path.as_posix() if isinstance(path, Path) else path
+    try:
+        stage_path = StagePath(path_str)
+    except ValueError:
+        return Path(path_str)
+    return stage_path
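The new resolve_path helper gives callers a single entry point that yields either a StagePath or a plain pathlib.Path, both of which satisfy the PayloadPath protocol. A hedged usage sketch follows: the `describe` helper is hypothetical, and it assumes (per the docstring above) that a plain filesystem string fails StagePath validation and falls back to a local Path; the exact stage syntax accepted is governed by _SNOWURL_PATH_RE, which is truncated in the hunk above.

```python
from snowflake.ml.jobs._utils import stage_utils


def describe(path_str: str) -> str:
    # resolve_path tries StagePath first and falls back to a local Path
    # when StagePath's validation raises ValueError.
    resolved = stage_utils.resolve_path(path_str)
    if isinstance(resolved, stage_utils.StagePath):
        return f"stage path: {resolved.as_posix()}"
    return f"local path: {resolved.as_posix()}"


print(describe("/tmp/payload/train.py"))  # local path: /tmp/payload/train.py
```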
snowflake/ml/jobs/_utils/types.py CHANGED
@@ -1,7 +1,9 @@
 import os
 from dataclasses import dataclass, field
 from pathlib import PurePath
-from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable
+from typing import Literal, Optional, Protocol, Union, runtime_checkable
+
+from typing_extensions import Self
 
 JOB_STATUS = Literal[
     "PENDING",
@@ -19,6 +21,10 @@ JOB_STATUS = Literal[
 class PayloadPath(Protocol):
     """A protocol for path-like objects used in this module, covering methods from pathlib.Path and StagePath."""
 
+    @property
+    def parts(self) -> tuple[str, ...]:
+        ...
+
     @property
     def name(self) -> str:
         ...
@@ -32,7 +38,7 @@ class PayloadPath(Protocol):
         ...
 
     @property
-    def parent(self) -> "PayloadPath":
+    def parent(self) -> Self:
         ...
 
     @property
@@ -45,13 +51,16 @@ class PayloadPath(Protocol):
     def is_file(self) -> bool:
         ...
 
+    def is_dir(self) -> bool:
+        ...
+
     def is_absolute(self) -> bool:
         ...
 
-    def absolute(self) -> "PayloadPath":
+    def absolute(self) -> Self:
         ...
 
-    def joinpath(self, *other: Union[str, os.PathLike[str]]) -> "PayloadPath":
+    def joinpath(self, *other: Union[str, os.PathLike[str]]) -> Self:
         ...
 
     def as_posix(self) -> str:
@@ -79,9 +88,7 @@ class PayloadSpec:
 
     source_path: PayloadPath
     remote_relative_path: Optional[PurePath] = None
-
-    def __iter__(self) -> Iterator[Union[PayloadPath, Optional[PurePath]]]:
-        return iter((self.source_path, self.remote_relative_path))
+    compress: bool = False
 
 
 @dataclass(frozen=True)
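Two things change here: PayloadSpec loses its __iter__ (so tuple-unpacking a spec gives way to attribute access, including the new compress flag), and PayloadPath's self-referential return annotations move from the string "PayloadPath" to typing_extensions.Self, so Path.parent stays typed as Path and StagePath.parent as StagePath while both keep satisfying the protocol structurally. A minimal sketch of the Self pattern, with `_PathDemo` as a hypothetical stand-in rather than the real protocol:

```python
from pathlib import Path
from typing import Protocol, runtime_checkable

from typing_extensions import Self


@runtime_checkable
class _PathDemo(Protocol):  # hypothetical stand-in for PayloadPath
    @property
    def parent(self) -> Self:
        ...

    def is_dir(self) -> bool:
        ...


p: _PathDemo = Path("/tmp/jobs/payload")
# With Self, a type checker sees p.parent as Path (the implementing type),
# so chained calls like p.parent.parent keep their concrete type.
print(p.parent.parent, isinstance(p, _PathDemo))  # /tmp True
```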
snowflake/ml/jobs/job.py CHANGED
@@ -13,13 +13,7 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import identifier
 from snowflake.ml._internal.utils.mixins import SerializableSessionMixin
 from snowflake.ml.jobs._interop import results as interop_result, utils as interop_utils
-from snowflake.ml.jobs._utils import (
-    constants,
-    payload_utils,
-    query_helper,
-    stage_utils,
-    types,
-)
+from snowflake.ml.jobs._utils import constants, query_helper, stage_utils, types
 from snowflake.snowpark import Row, context as sp_context
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
@@ -131,7 +125,7 @@ class MLJob(Generic[T], SerializableSessionMixin):
 
     def _transform_path(self, path_str: str) -> str:
         """Transform a local path within the container to a stage path."""
-        path = payload_utils.resolve_path(path_str)
+        path = stage_utils.resolve_path(path_str)
         if isinstance(path, stage_utils.StagePath):
             # Stage paths need no transformation
             return path.as_posix()
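For context, the visible branch of _transform_path after the module swap reads roughly as below. The local-path handling is truncated in the hunk above, so this sketch elides it rather than guessing; `_transform_path_sketch` is a hypothetical free function, not the real method.

```python
from snowflake.ml.jobs._utils import stage_utils


def _transform_path_sketch(path_str: str) -> str:
    path = stage_utils.resolve_path(path_str)
    if isinstance(path, stage_utils.StagePath):
        # Stage paths need no transformation
        return path.as_posix()
    # Local container paths are mapped by logic not shown in the diff.
    raise NotImplementedError("local-path transformation elided; see hunk above")
```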