snowflake-ml-python 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. snowflake/ml/_internal/utils/mixins.py +6 -4
  2. snowflake/ml/_internal/utils/service_logger.py +118 -4
  3. snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
  4. snowflake/ml/data/data_connector.py +4 -34
  5. snowflake/ml/dataset/dataset.py +1 -1
  6. snowflake/ml/dataset/dataset_reader.py +2 -8
  7. snowflake/ml/experiment/__init__.py +3 -0
  8. snowflake/ml/experiment/callback/lightgbm.py +55 -0
  9. snowflake/ml/experiment/callback/xgboost.py +63 -0
  10. snowflake/ml/experiment/utils.py +14 -0
  11. snowflake/ml/jobs/_utils/constants.py +15 -4
  12. snowflake/ml/jobs/_utils/payload_utils.py +159 -52
  13. snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
  14. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +126 -23
  15. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  16. snowflake/ml/jobs/_utils/stage_utils.py +30 -14
  17. snowflake/ml/jobs/_utils/types.py +64 -4
  18. snowflake/ml/jobs/job.py +22 -6
  19. snowflake/ml/jobs/manager.py +5 -3
  20. snowflake/ml/model/_client/model/model_version_impl.py +56 -48
  21. snowflake/ml/model/_client/ops/service_ops.py +194 -14
  22. snowflake/ml/model/_client/sql/service.py +1 -38
  23. snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
  24. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
  25. snowflake/ml/model/_signatures/pandas_handler.py +3 -0
  26. snowflake/ml/model/_signatures/utils.py +4 -0
  27. snowflake/ml/model/event_handler.py +87 -18
  28. snowflake/ml/model/model_signature.py +2 -0
  29. snowflake/ml/model/models/huggingface_pipeline.py +71 -49
  30. snowflake/ml/model/type_hints.py +26 -1
  31. snowflake/ml/registry/_manager/model_manager.py +30 -35
  32. snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
  33. snowflake/ml/registry/registry.py +0 -19
  34. snowflake/ml/version.py +1 -1
  35. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +542 -491
  36. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +39 -34
  37. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
  38. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
  39. {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/_utils/stage_utils.py CHANGED
@@ -14,7 +14,10 @@ _SNOWURL_PATH_RE = re.compile(
     r"(?P<path>versions(?:/(?P<version>[^/]+)(?:/(?P<relpath>.*))?)?)$"
 )
 
-_STAGEF_PATH_RE = re.compile(r"^@(?P<stage>~|%?\w+)(?:/(?P<relpath>[\w\-./]*))?$")
+# Break long regex into two main parts
+_STAGE_PATTERN = rf"~|%?(?:(?:{identifier._SF_IDENTIFIER}\.?){{,2}}{identifier._SF_IDENTIFIER})"
+_RELPATH_PATTERN = r"[\w\-./]*"
+_STAGEF_PATH_RE = re.compile(rf"^@(?P<stage>{_STAGE_PATTERN})(?:/(?P<relpath>{_RELPATH_PATTERN}))?$")
 
 
 class StagePath:
@@ -29,6 +32,14 @@ class StagePath:
         self._root = self._raw_path[0:start].rstrip("/") if relpath else self._raw_path.rstrip("/")
         self._path = Path(relpath or "")
 
+    @property
+    def parts(self) -> tuple[str, ...]:
+        return self._path.parts
+
+    @property
+    def name(self) -> str:
+        return self._path.name
+
     @property
     def parent(self) -> "StagePath":
         if self._path.parent == Path(""):
@@ -51,18 +62,28 @@ class StagePath:
         else:
             return f"{self.root}/{path}"
 
-    def is_relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> bool:
+    def is_relative_to(self, *other: Union[str, os.PathLike[str]]) -> bool:
+        if not other:
+            raise TypeError("is_relative_to() requires at least one argument")
+        # For now, we only support a single argument, like pathlib.Path in Python < 3.12
+        path = other[0]
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if stage_path.root == self.root:
             return self._path.is_relative_to(stage_path._path)
         else:
             return False
 
-    def relative_to(self, path: Union[str, PathLike[str], "StagePath"]) -> PurePath:
+    def relative_to(self, *other: Union[str, os.PathLike[str]]) -> PurePath:
+        if not other:
+            raise TypeError("relative_to() requires at least one argument")
+        if not self.is_relative_to(*other):
+            raise ValueError(f"{other} does not start with {self._raw_path}")
+        path = other[0]
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if self.root == stage_path.root:
             return self._path.relative_to(stage_path._path)
-        raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
+        else:
+            raise ValueError(f"{self._raw_path} does not start with {stage_path._raw_path}")
 
     def absolute(self) -> "StagePath":
         return self
@@ -88,6 +109,9 @@ class StagePath:
     def __str__(self) -> str:
         return self.as_posix()
 
+    def __repr__(self) -> str:
+        return f"StagePath('{self.as_posix()}')"
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, StagePath):
             raise NotImplementedError
@@ -96,24 +120,16 @@ class StagePath:
     def __fspath__(self) -> str:
         return self._compose_path(self._path)
 
-    def joinpath(self, *args: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+    def joinpath(self, *args: Union[str, PathLike[str]]) -> "StagePath":
         path = self
         for arg in args:
             path = path._make_child(arg)
         return path
 
-    def _make_child(self, path: Union[str, PathLike[str], "StagePath"]) -> "StagePath":
+    def _make_child(self, path: Union[str, PathLike[str]]) -> "StagePath":
         stage_path = path if isinstance(path, StagePath) else StagePath(os.fspath(path))
         if self.root == stage_path.root:
             child_path = self._path.joinpath(stage_path._path)
             return StagePath(self._compose_path(child_path))
         else:
             return stage_path
-
-
-def identify_stage_path(path: str) -> Union[StagePath, Path]:
-    try:
-        stage_path = StagePath(path)
-    except ValueError:
-        return Path(path)
-    return stage_path
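The net effect is that StagePath now mirrors more of the pathlib.Path surface: `parts`, `name`, and `__repr__` are new, and `is_relative_to`/`relative_to` take the same `*other` form as pathlib (honoring only the first argument, as pathlib did before Python 3.12). A minimal sketch of the new surface, assuming StagePath parses stage URIs of the form accepted by `_STAGEF_PATH_RE` above; exact outputs depend on how `__init__` splits the root from the relative path, which this diff does not show:

    from pathlib import PurePath

    from snowflake.ml.jobs._utils.stage_utils import StagePath

    script = StagePath("@my_stage/payload/src/main.py")
    print(script.name)   # expected "main.py" (new `name` property)
    print(script.parts)  # expected: parts of the path below the stage root (new `parts` property)
    print(repr(script))  # new __repr__, e.g. StagePath('@my_stage/payload/src/main.py')

    root = StagePath("@my_stage/payload")
    if script.is_relative_to(root):           # pathlib-style *other signature
        rel: PurePath = script.relative_to(root)
        print(rel.as_posix())                 # expected "src/main.py"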
snowflake/ml/jobs/_utils/types.py CHANGED
@@ -1,8 +1,7 @@
+import os
 from dataclasses import dataclass
 from pathlib import PurePath
-from typing import Literal, Optional, Union
-
-from snowflake.ml.jobs._utils import stage_utils
+from typing import Iterator, Literal, Optional, Protocol, Union, runtime_checkable
 
 JOB_STATUS = Literal[
     "PENDING",
@@ -15,9 +14,70 @@ JOB_STATUS = Literal[
 ]
 
 
+@runtime_checkable
+class PayloadPath(Protocol):
+    """A protocol for path-like objects used in this module, covering methods from pathlib.Path and StagePath."""
+
+    @property
+    def name(self) -> str:
+        ...
+
+    @property
+    def suffix(self) -> str:
+        ...
+
+    @property
+    def parent(self) -> "PayloadPath":
+        ...
+
+    def exists(self) -> bool:
+        ...
+
+    def is_file(self) -> bool:
+        ...
+
+    def is_absolute(self) -> bool:
+        ...
+
+    def absolute(self) -> "PayloadPath":
+        ...
+
+    def joinpath(self, *other: Union[str, os.PathLike[str]]) -> "PayloadPath":
+        ...
+
+    def as_posix(self) -> str:
+        ...
+
+    def is_relative_to(self, *other: Union[str, os.PathLike[str]]) -> bool:
+        ...
+
+    def relative_to(self, *other: Union[str, os.PathLike[str]]) -> PurePath:
+        ...
+
+    def __fspath__(self) -> str:
+        ...
+
+    def __str__(self) -> str:
+        ...
+
+    def __repr__(self) -> str:
+        ...
+
+
+@dataclass
+class PayloadSpec:
+    """Represents a payload item to be uploaded."""
+
+    source_path: PayloadPath
+    remote_relative_path: Optional[PurePath] = None
+
+    def __iter__(self) -> Iterator[Union[PayloadPath, Optional[PurePath]]]:
+        return iter((self.source_path, self.remote_relative_path))
+
+
 @dataclass(frozen=True)
 class PayloadEntrypoint:
-    file_path: Union[PurePath, stage_utils.StagePath]
+    file_path: PayloadPath
     main_func: Optional[str]
 
 
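Because `PayloadPath` is a runtime-checkable Protocol, `pathlib.Path` and `StagePath` can be treated uniformly wherever job payload paths are handled, and `PayloadSpec.__iter__` lets a spec unpack like a `(source, remote)` pair. A small illustrative sketch (not from the package) under those assumptions:

    from pathlib import Path, PurePath

    from snowflake.ml.jobs._utils.stage_utils import StagePath
    from snowflake.ml.jobs._utils.types import PayloadPath, PayloadSpec

    # runtime_checkable protocols only check member presence at isinstance() time.
    print(isinstance(Path("train.py"), PayloadPath))                 # expected True
    print(isinstance(StagePath("@my_stage/train.py"), PayloadPath))  # expected True, if StagePath exposes every member

    spec = PayloadSpec(Path("utils/helpers.py"), PurePath("utils"))
    source, remote = spec  # __iter__ makes the dataclass unpack like a 2-tuple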
snowflake/ml/jobs/job.py CHANGED
@@ -3,6 +3,7 @@ import logging
 import os
 import time
 from functools import cached_property
+from pathlib import Path
 from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
 
 import yaml
@@ -95,10 +96,24 @@ class MLJob(Generic[T], SerializableSessionMixin):
     @property
     def _result_path(self) -> str:
         """Get the job's result file location."""
-        result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
-        if result_path is None:
+        result_path_str = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
+        if result_path_str is None:
             raise RuntimeError(f"Job {self.name} doesn't have a result path configured")
-        return f"{self._stage_path}/{result_path}"
+        volume_mounts = self._container_spec["volumeMounts"]
+        stage_mount_str = next(v for v in volume_mounts if v.get("name") == constants.STAGE_VOLUME_NAME)["mountPath"]
+
+        result_path = Path(result_path_str)
+        stage_mount = Path(stage_mount_str)
+        try:
+            relative_path = result_path.relative_to(stage_mount)
+        except ValueError:
+            if result_path.is_absolute():
+                raise ValueError(
+                    f"Result path {result_path} is absolute, but should be relative to stage mount {stage_mount}"
+                )
+            relative_path = result_path
+
+        return f"{self._stage_path}/{relative_path.as_posix()}"
 
     @overload
     def get_logs(
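The rewritten `_result_path` resolves the configured result file against the stage volume mount rather than assuming the env var is already stage-relative: an absolute in-container path under the mount is reduced to its stage-relative suffix, an absolute path outside the mount is rejected, and a relative path is used as-is. The same logic distilled into plain pathlib (the helper name and mount path below are illustrative, not part of the library):

    from pathlib import Path

    def to_stage_relative(result_path_str: str, stage_mount_str: str) -> str:
        """Strip the stage mount prefix from a result path, mirroring the diff above."""
        result_path, stage_mount = Path(result_path_str), Path(stage_mount_str)
        try:
            relative = result_path.relative_to(stage_mount)
        except ValueError:
            if result_path.is_absolute():
                raise ValueError(f"{result_path} is absolute but not under {stage_mount}")
            relative = result_path
        return relative.as_posix()

    print(to_stage_relative("/mnt/job_stage/output/result.json", "/mnt/job_stage"))  # -> output/result.json
    print(to_stage_relative("output/result.json", "/mnt/job_stage"))                 # -> output/result.json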
@@ -181,7 +196,10 @@ class MLJob(Generic[T], SerializableSessionMixin):
         start_time = time.monotonic()
         warning_shown = False
         while (status := self.status) not in TERMINAL_JOB_STATUSES:
-            if status == "PENDING" and not warning_shown:
+            elapsed = time.monotonic() - start_time
+            if elapsed >= timeout >= 0:
+                raise TimeoutError(f"Job {self.name} did not complete within {timeout} seconds")
+            elif status == "PENDING" and not warning_shown and elapsed >= 2:  # Only show warning after 2s
                 pool_info = _get_compute_pool_info(self._session, self._compute_pool)
                 if (pool_info.max_nodes - pool_info.active_nodes) < self.min_instances:
                     logger.warning(
@@ -189,8 +207,6 @@ class MLJob(Generic[T], SerializableSessionMixin):
                         f"{self.min_instances} nodes required). Job execution may be delayed."
                     )
                     warning_shown = True
-            if timeout >= 0 and (elapsed := time.monotonic() - start_time) >= timeout:
-                raise TimeoutError(f"Job {self.name} did not complete within {elapsed} seconds")
             time.sleep(delay)
             delay = min(delay * 1.2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
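The polling loop now computes `elapsed` once per iteration, checks the timeout before doing anything else (raising with the configured `timeout` instead of the measured elapsed time), and only warns about compute-pool capacity after the job has been PENDING for at least 2 seconds. Reduced to a standalone sketch with stand-in statuses and constants:

    import time
    from typing import Callable

    def wait_until_done(
        get_status: Callable[[], str], timeout: float = -1, delay: float = 0.5, max_delay: float = 30.0
    ) -> str:
        """Poll get_status() with capped exponential backoff; timeout < 0 waits forever."""
        start_time = time.monotonic()
        warning_shown = False
        while (status := get_status()) not in ("DONE", "FAILED", "CANCELLED"):
            elapsed = time.monotonic() - start_time
            if elapsed >= timeout >= 0:  # timeout check happens first
                raise TimeoutError(f"did not complete within {timeout} seconds")
            elif status == "PENDING" and not warning_shown and elapsed >= 2:
                print("still pending after 2s; the compute pool may be short on nodes")
                warning_shown = True
            time.sleep(delay)
            delay = min(delay * 1.2, max_delay)  # exponential backoff, capped
        return status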
snowflake/ml/jobs/manager.py CHANGED
@@ -447,6 +447,10 @@ def _submit_job(
     spec_overrides = kwargs.pop("spec_overrides", None)
     enable_metrics = kwargs.pop("enable_metrics", True)
     query_warehouse = kwargs.pop("query_warehouse", None)
+    additional_payloads = kwargs.pop("additional_payloads", None)
+
+    if additional_payloads:
+        logger.warning("'additional_payloads' is in private preview since 1.9.1. Do not use it in production.")
 
     # Warn if there are unknown kwargs
     if kwargs:
@@ -477,9 +481,7 @@
 
     # Upload payload
     uploaded_payload = payload_utils.JobPayload(
-        source,
-        entrypoint=entrypoint,
-        pip_requirements=pip_requirements,
+        source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
     ).upload(session, stage_path)
 
     # Generate service spec
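`additional_payloads` flows from the public submission kwargs through `_submit_job` into `JobPayload`, so extra files can be staged alongside the main source; the warning marks it as private preview, and this diff does not show the accepted element type. A hypothetical construction, assuming elements may be `(source, remote_relative_path)` pairs mirroring `PayloadSpec`:

    # Hypothetical sketch: the argument shape for `additional_payloads` is assumed, not shown in this diff.
    from snowflake.ml.jobs._utils import payload_utils

    payload = payload_utils.JobPayload(
        "src/train.py",
        entrypoint="src/train.py",
        pip_requirements=["xgboost"],
        additional_payloads=[("src/utils/helpers.py", "utils")],
    )
    # payload.upload(session, stage_path) would then stage everything, as in _submit_job above.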
snowflake/ml/model/_client/model/model_version_impl.py CHANGED
@@ -1,5 +1,4 @@
 import enum
-import logging
 import pathlib
 import tempfile
 import warnings
@@ -881,6 +880,7 @@ class ModelVersion(lineage_node.LineageNode):
 
         Raises:
             ValueError: Illegal external access integration arguments.
+            exceptions.SnowparkSQLException: if service already exists.
 
         Returns:
             If `block=True`, return result information about service creation from server.
@@ -891,16 +891,6 @@ class ModelVersion(lineage_node.LineageNode):
             subproject=_TELEMETRY_SUBPROJECT,
         )
 
-        # Check root logger level and emit warning if needed
-        root_logger = logging.getLogger()
-        if root_logger.level in (logging.WARNING, logging.ERROR):
-            warnings.warn(
-                "Suppressing service logs. Set the log level to INFO if you would like "
-                "verbose service logs (e.g., logging.getLogger().setLevel(logging.INFO)).",
-                UserWarning,
-                stacklevel=2,
-            )
-
         if build_external_access_integration is not None:
             msg = (
                 "`build_external_access_integration` is deprecated. "
@@ -917,39 +907,60 @@ class ModelVersion(lineage_node.LineageNode):
 
         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
         image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
-        return self._service_ops.create_service(
-            database_name=None,
-            schema_name=None,
-            model_name=self._model_name,
-            version_name=self._version_name,
-            service_database_name=service_db_id,
-            service_schema_name=service_schema_id,
-            service_name=service_id,
-            image_build_compute_pool_name=(
-                sql_identifier.SqlIdentifier(image_build_compute_pool)
-                if image_build_compute_pool
-                else sql_identifier.SqlIdentifier(service_compute_pool)
-            ),
-            service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
-            image_repo_database_name=image_repo_db_id,
-            image_repo_schema_name=image_repo_schema_id,
-            image_repo_name=image_repo_id,
-            ingress_enabled=ingress_enabled,
-            max_instances=max_instances,
-            cpu_requests=cpu_requests,
-            memory_requests=memory_requests,
-            gpu_requests=gpu_requests,
-            num_workers=num_workers,
-            max_batch_rows=max_batch_rows,
-            force_rebuild=force_rebuild,
-            build_external_access_integrations=(
-                None
-                if build_external_access_integrations is None
-                else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
-            ),
-            block=block,
-            statement_params=statement_params,
-        )
+
+        from snowflake.ml.model import event_handler
+        from snowflake.snowpark import exceptions
+
+        model_event_handler = event_handler.ModelEventHandler()
+
+        with model_event_handler.status("Creating model inference service", total=6, block=block) as status:
+            try:
+                result = self._service_ops.create_service(
+                    database_name=None,
+                    schema_name=None,
+                    model_name=self._model_name,
+                    version_name=self._version_name,
+                    service_database_name=service_db_id,
+                    service_schema_name=service_schema_id,
+                    service_name=service_id,
+                    image_build_compute_pool_name=(
+                        sql_identifier.SqlIdentifier(image_build_compute_pool)
+                        if image_build_compute_pool
+                        else sql_identifier.SqlIdentifier(service_compute_pool)
+                    ),
+                    service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+                    image_repo=image_repo,
+                    ingress_enabled=ingress_enabled,
+                    max_instances=max_instances,
+                    cpu_requests=cpu_requests,
+                    memory_requests=memory_requests,
+                    gpu_requests=gpu_requests,
+                    num_workers=num_workers,
+                    max_batch_rows=max_batch_rows,
+                    force_rebuild=force_rebuild,
+                    build_external_access_integrations=(
+                        None
+                        if build_external_access_integrations is None
+                        else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+                    ),
+                    block=block,
+                    statement_params=statement_params,
+                    progress_status=status,
+                )
+                status.update(label="Model service created successfully", state="complete", expanded=False)
+                return result
+            except exceptions.SnowparkSQLException as e:
+                # Check if the error is because the service already exists
+                if "already exists" in str(e).lower() or "100132" in str(
+                    e
+                ):  # 100132 is Snowflake error code for object already exists
+                    status.update("service already exists")
+                    status.complete()
+                    status.update(label="Service already exists", state="error", expanded=False)
+                    raise
+                else:
+                    status.update(label="Service creation failed", state="error", expanded=False)
+                    raise
 
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
@@ -1045,7 +1056,6 @@ class ModelVersion(lineage_node.LineageNode):
         )
         target_function_info = self._get_function_info(function_name=function_name)
         job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
-        image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
         output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
             output_table_name
         )
@@ -1064,9 +1074,7 @@
             job_name=job_id,
             compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
             warehouse_name=sql_identifier.SqlIdentifier(warehouse),
-            image_repo_database_name=image_repo_db_id,
-            image_repo_schema_name=image_repo_schema_id,
-            image_repo_name=image_repo_id,
+            image_repo=image_repo,
             output_table_database_name=output_table_db_id,
             output_table_schema_name=output_table_schema_id,
             output_table_name=output_table_id,
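In 1.10.0, `create_service` reports progress through a `ModelEventHandler` status context (six tracked steps), passes `image_repo` through as a single argument instead of pre-parsing it, and surfaces `SnowparkSQLException` when the service already exists, as now noted in the docstring. A caller-side sketch; the parameter names mirror those forwarded above, but the full public signature is not reproduced in this diff:

    from snowflake.ml.model import ModelVersion
    from snowflake.snowpark import exceptions

    def deploy(mv: ModelVersion) -> None:
        """mv: a ModelVersion obtained from the registry."""
        try:
            mv.create_service(
                service_name="MY_DB.MY_SCHEMA.MY_SERVICE",
                service_compute_pool="MY_POOL",
                image_repo="MY_DB.MY_SCHEMA.MY_IMAGE_REPO",  # forwarded as-is in 1.10.0
                ingress_enabled=True,
                block=True,
            )
        except exceptions.SnowparkSQLException:
            # Documented in the Raises section above; error code 100132 indicates the service already exists.
            raise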