snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -1,5 +1,5 @@
1
1
  import time
2
- from typing import Any, Dict, Generic, List, Optional, TypeVar, cast
2
+ from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
3
3
 
4
4
  import yaml
5
5
 
@@ -18,11 +18,11 @@ class MLJob(Generic[T]):
18
18
  def __init__(
19
19
  self,
20
20
  id: str,
21
- service_spec: Optional[Dict[str, Any]] = None,
21
+ service_spec: Optional[dict[str, Any]] = None,
22
22
  session: Optional[snowpark.Session] = None,
23
23
  ) -> None:
24
24
  self._id = id
25
- self._service_spec_cached: Optional[Dict[str, Any]] = service_spec
25
+ self._service_spec_cached: Optional[dict[str, Any]] = service_spec
26
26
  self._session = session or sp_context.get_active_session()
27
27
 
28
28
  self._status: types.JOB_STATUS = "PENDING"
@@ -42,18 +42,18 @@ class MLJob(Generic[T]):
42
42
  return self._status
43
43
 
44
44
  @property
45
- def _service_spec(self) -> Dict[str, Any]:
45
+ def _service_spec(self) -> dict[str, Any]:
46
46
  """Get the job's service spec."""
47
47
  if not self._service_spec_cached:
48
48
  self._service_spec_cached = _get_service_spec(self._session, self.id)
49
49
  return self._service_spec_cached
50
50
 
51
51
  @property
52
- def _container_spec(self) -> Dict[str, Any]:
52
+ def _container_spec(self) -> dict[str, Any]:
53
53
  """Get the job's main container spec."""
54
54
  containers = self._service_spec["spec"]["containers"]
55
55
  container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
56
- return cast(Dict[str, Any], container_spec)
56
+ return cast(dict[str, Any], container_spec)
57
57
 
58
58
  @property
59
59
  def _stage_path(self) -> str:
@@ -70,8 +70,17 @@ class MLJob(Generic[T]):
70
70
  raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
71
71
  return f"{self._stage_path}/{result_path}"
72
72
 
73
- @snowpark._internal.utils.private_preview(version="1.7.4")
74
- def get_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> str:
73
+ @overload
74
+ def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[True]) -> list[str]:
75
+ ...
76
+
77
+ @overload
78
+ def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[False] = False) -> str:
79
+ ...
80
+
81
+ def get_logs(
82
+ self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: bool = False
83
+ ) -> Union[str, list[str]]:
75
84
  """
76
85
  Return the job's execution logs.
77
86
 
@@ -79,15 +88,17 @@ class MLJob(Generic[T]):
79
88
  limit: The maximum number of lines to return. Negative values are treated as no limit.
80
89
  instance_id: Optional instance ID to get logs from a specific instance.
81
90
  If not provided, returns logs from the head node.
91
+ as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
82
92
 
83
93
  Returns:
84
94
  The job's execution logs.
85
95
  """
86
96
  logs = _get_logs(self._session, self.id, limit, instance_id)
87
97
  assert isinstance(logs, str) # mypy
98
+ if as_list:
99
+ return logs.splitlines()
88
100
  return logs
89
101
 
90
- @snowpark._internal.utils.private_preview(version="1.7.4")
91
102
  def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
92
103
  """
93
104
  Display the job's execution logs.
@@ -97,9 +108,8 @@ class MLJob(Generic[T]):
97
108
  instance_id: Optional instance ID to get logs from a specific instance.
98
109
  If not provided, displays logs from the head node.
99
110
  """
100
- print(self.get_logs(limit, instance_id)) # noqa: T201: we need to print here.
111
+ print(self.get_logs(limit, instance_id, as_list=False)) # noqa: T201: we need to print here.
101
112
 
102
- @snowpark._internal.utils.private_preview(version="1.7.4")
103
113
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
104
114
  def wait(self, timeout: float = -1) -> types.JOB_STATUS:
105
115
  """
@@ -167,10 +177,10 @@ def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[in
167
177
 
168
178
 
169
179
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
170
- def _get_service_spec(session: snowpark.Session, job_id: str) -> Dict[str, Any]:
180
+ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
171
181
  """Retrieve job execution service spec."""
172
182
  (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=[job_id]).collect()
173
- return cast(Dict[str, Any], yaml.safe_load(row["spec"]))
183
+ return cast(dict[str, Any], yaml.safe_load(row["spec"]))
174
184
 
175
185
 
176
186
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
@@ -192,7 +202,7 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_
192
202
  instance_id = _get_head_instance_id(session, job_id)
193
203
 
194
204
  # Assemble params: [job_id, instance_id, container_name, (optional) limit]
195
- params: List[Any] = [
205
+ params: list[Any] = [
196
206
  job_id,
197
207
  0 if instance_id is None else instance_id,
198
208
  constants.DEFAULT_CONTAINER_NAME,
@@ -1,16 +1,7 @@
1
+ import logging
1
2
  import pathlib
2
3
  import textwrap
3
- from typing import (
4
- Any,
5
- Callable,
6
- Dict,
7
- List,
8
- Literal,
9
- Optional,
10
- TypeVar,
11
- Union,
12
- overload,
13
- )
4
+ from typing import Any, Callable, Literal, Optional, TypeVar, Union, overload
14
5
  from uuid import uuid4
15
6
 
16
7
  import yaml
@@ -23,13 +14,14 @@ from snowflake.ml.jobs._utils import payload_utils, spec_utils
23
14
  from snowflake.snowpark.context import get_active_session
24
15
  from snowflake.snowpark.exceptions import SnowparkSQLException
25
16
 
17
+ logger = logging.getLogger(__name__)
18
+
26
19
  _PROJECT = "MLJob"
27
20
  JOB_ID_PREFIX = "MLJOB_"
28
21
 
29
22
  T = TypeVar("T")
30
23
 
31
24
 
32
- @snowpark._internal.utils.private_preview(version="1.7.4")
33
25
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
34
26
  def list_jobs(
35
27
  limit: int = 10,
@@ -69,7 +61,6 @@ def list_jobs(
69
61
  return df
70
62
 
71
63
 
72
- @snowpark._internal.utils.private_preview(version="1.7.4")
73
64
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
74
65
  def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
75
66
  """Retrieve a job service from the backend."""
@@ -93,7 +84,6 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
93
84
  raise
94
85
 
95
86
 
96
- @snowpark._internal.utils.private_preview(version="1.7.4")
97
87
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
98
88
  def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
99
89
  """Delete a job service from the backend. Status and logs will be lost."""
@@ -106,19 +96,18 @@ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Sessio
106
96
  session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
107
97
 
108
98
 
109
- @snowpark._internal.utils.private_preview(version="1.7.4")
110
99
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
111
100
  def submit_file(
112
101
  file_path: str,
113
102
  compute_pool: str,
114
103
  *,
115
104
  stage_name: str,
116
- args: Optional[List[str]] = None,
117
- env_vars: Optional[Dict[str, str]] = None,
118
- pip_requirements: Optional[List[str]] = None,
119
- external_access_integrations: Optional[List[str]] = None,
105
+ args: Optional[list[str]] = None,
106
+ env_vars: Optional[dict[str, str]] = None,
107
+ pip_requirements: Optional[list[str]] = None,
108
+ external_access_integrations: Optional[list[str]] = None,
120
109
  query_warehouse: Optional[str] = None,
121
- spec_overrides: Optional[Dict[str, Any]] = None,
110
+ spec_overrides: Optional[dict[str, Any]] = None,
122
111
  num_instances: Optional[int] = None,
123
112
  enable_metrics: bool = False,
124
113
  session: Optional[snowpark.Session] = None,
@@ -159,7 +148,6 @@ def submit_file(
159
148
  )
160
149
 
161
150
 
162
- @snowpark._internal.utils.private_preview(version="1.7.4")
163
151
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
164
152
  def submit_directory(
165
153
  dir_path: str,
@@ -167,12 +155,12 @@ def submit_directory(
167
155
  *,
168
156
  entrypoint: str,
169
157
  stage_name: str,
170
- args: Optional[List[str]] = None,
171
- env_vars: Optional[Dict[str, str]] = None,
172
- pip_requirements: Optional[List[str]] = None,
173
- external_access_integrations: Optional[List[str]] = None,
158
+ args: Optional[list[str]] = None,
159
+ env_vars: Optional[dict[str, str]] = None,
160
+ pip_requirements: Optional[list[str]] = None,
161
+ external_access_integrations: Optional[list[str]] = None,
174
162
  query_warehouse: Optional[str] = None,
175
- spec_overrides: Optional[Dict[str, Any]] = None,
163
+ spec_overrides: Optional[dict[str, Any]] = None,
176
164
  num_instances: Optional[int] = None,
177
165
  enable_metrics: bool = False,
178
166
  session: Optional[snowpark.Session] = None,
@@ -222,12 +210,12 @@ def _submit_job(
222
210
  *,
223
211
  stage_name: str,
224
212
  entrypoint: Optional[str] = None,
225
- args: Optional[List[str]] = None,
226
- env_vars: Optional[Dict[str, str]] = None,
227
- pip_requirements: Optional[List[str]] = None,
228
- external_access_integrations: Optional[List[str]] = None,
213
+ args: Optional[list[str]] = None,
214
+ env_vars: Optional[dict[str, str]] = None,
215
+ pip_requirements: Optional[list[str]] = None,
216
+ external_access_integrations: Optional[list[str]] = None,
229
217
  query_warehouse: Optional[str] = None,
230
- spec_overrides: Optional[Dict[str, Any]] = None,
218
+ spec_overrides: Optional[dict[str, Any]] = None,
231
219
  num_instances: Optional[int] = None,
232
220
  enable_metrics: bool = False,
233
221
  session: Optional[snowpark.Session] = None,
@@ -242,12 +230,12 @@ def _submit_job(
242
230
  *,
243
231
  stage_name: str,
244
232
  entrypoint: Optional[str] = None,
245
- args: Optional[List[str]] = None,
246
- env_vars: Optional[Dict[str, str]] = None,
247
- pip_requirements: Optional[List[str]] = None,
248
- external_access_integrations: Optional[List[str]] = None,
233
+ args: Optional[list[str]] = None,
234
+ env_vars: Optional[dict[str, str]] = None,
235
+ pip_requirements: Optional[list[str]] = None,
236
+ external_access_integrations: Optional[list[str]] = None,
249
237
  query_warehouse: Optional[str] = None,
250
- spec_overrides: Optional[Dict[str, Any]] = None,
238
+ spec_overrides: Optional[dict[str, Any]] = None,
251
239
  num_instances: Optional[int] = None,
252
240
  enable_metrics: bool = False,
253
241
  session: Optional[snowpark.Session] = None,
@@ -263,6 +251,8 @@ def _submit_job(
263
251
  # TODO: Log lengths of args, env_vars, and spec_overrides values
264
252
  "pip_requirements",
265
253
  "external_access_integrations",
254
+ "num_instances",
255
+ "enable_metrics",
266
256
  ],
267
257
  )
268
258
  def _submit_job(
@@ -271,12 +261,12 @@ def _submit_job(
271
261
  *,
272
262
  stage_name: str,
273
263
  entrypoint: Optional[str] = None,
274
- args: Optional[List[str]] = None,
275
- env_vars: Optional[Dict[str, str]] = None,
276
- pip_requirements: Optional[List[str]] = None,
277
- external_access_integrations: Optional[List[str]] = None,
264
+ args: Optional[list[str]] = None,
265
+ env_vars: Optional[dict[str, str]] = None,
266
+ pip_requirements: Optional[list[str]] = None,
267
+ external_access_integrations: Optional[list[str]] = None,
278
268
  query_warehouse: Optional[str] = None,
279
- spec_overrides: Optional[Dict[str, Any]] = None,
269
+ spec_overrides: Optional[dict[str, Any]] = None,
280
270
  num_instances: Optional[int] = None,
281
271
  enable_metrics: bool = False,
282
272
  session: Optional[snowpark.Session] = None,
@@ -305,6 +295,12 @@ def _submit_job(
305
295
  Raises:
306
296
  RuntimeError: If required Snowflake features are not enabled.
307
297
  """
298
+ # Display warning about PrPr parameters
299
+ if num_instances is not None:
300
+ logger.warning(
301
+ "_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
302
+ )
303
+
308
304
  session = session or get_active_session()
309
305
  job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
310
306
  stage_name = "@" + stage_name.lstrip("@").rstrip("/")
@@ -1,6 +1,6 @@
1
1
  import json
2
2
  from datetime import datetime
3
- from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Set, Type, Union
3
+ from typing import TYPE_CHECKING, Literal, Optional, Union
4
4
 
5
5
  from snowflake import snowpark
6
6
  from snowflake.ml._internal import telemetry
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
12
12
  from snowflake.ml.model._client.model import model_version_impl
13
13
 
14
14
  _PROJECT = "LINEAGE"
15
- DOMAIN_LINEAGE_REGISTRY: Dict[str, Type["LineageNode"]] = {}
15
+ DOMAIN_LINEAGE_REGISTRY: dict[str, type["LineageNode"]] = {}
16
16
 
17
17
 
18
18
  class LineageNode:
@@ -87,8 +87,8 @@ class LineageNode:
87
87
  def lineage(
88
88
  self,
89
89
  direction: Literal["upstream", "downstream"] = "downstream",
90
- domain_filter: Optional[Set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
91
- ) -> List[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
90
+ domain_filter: Optional[set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
91
+ ) -> list[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
92
92
  """
93
93
  Retrieves the lineage nodes connected to this node.
94
94
 
@@ -109,7 +109,7 @@ class LineageNode:
109
109
  if domain_filter is not None:
110
110
  domain_filter = {d.lower() for d in domain_filter} # type: ignore[misc]
111
111
 
112
- lineage_nodes: List["LineageNode"] = []
112
+ lineage_nodes: list["LineageNode"] = []
113
113
  for row in df.collect():
114
114
  lineage_object = (
115
115
  json.loads(row["TARGET_OBJECT"])
@@ -1,4 +1,4 @@
1
- from typing import Dict, List, Optional, Union
1
+ from typing import Optional, Union
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -224,7 +224,7 @@ class Model:
224
224
  project=_TELEMETRY_PROJECT,
225
225
  subproject=_TELEMETRY_SUBPROJECT,
226
226
  )
227
- def versions(self) -> List[model_version_impl.ModelVersion]:
227
+ def versions(self) -> list[model_version_impl.ModelVersion]:
228
228
  """Get all versions in the model.
229
229
 
230
230
  Returns:
@@ -298,7 +298,7 @@ class Model:
298
298
  project=_TELEMETRY_PROJECT,
299
299
  subproject=_TELEMETRY_SUBPROJECT,
300
300
  )
301
- def show_tags(self) -> Dict[str, str]:
301
+ def show_tags(self) -> dict[str, str]:
302
302
  """Get a dictionary showing the tag and its value attached to the model.
303
303
 
304
304
  Returns:
@@ -2,10 +2,11 @@ import enum
2
2
  import pathlib
3
3
  import tempfile
4
4
  import warnings
5
- from typing import Any, Callable, Dict, List, Optional, Union, overload
5
+ from typing import Any, Callable, Optional, Union, overload
6
6
 
7
7
  import pandas as pd
8
8
 
9
+ from snowflake import snowpark
9
10
  from snowflake.ml._internal import telemetry
10
11
  from snowflake.ml._internal.utils import sql_identifier
11
12
  from snowflake.ml.lineage import lineage_node
@@ -32,7 +33,7 @@ class ModelVersion(lineage_node.LineageNode):
32
33
  _service_ops: service_ops.ServiceOperator
33
34
  _model_name: sql_identifier.SqlIdentifier
34
35
  _version_name: sql_identifier.SqlIdentifier
35
- _functions: List[model_manifest_schema.ModelFunctionInfo]
36
+ _functions: list[model_manifest_schema.ModelFunctionInfo]
36
37
 
37
38
  def __init__(self) -> None:
38
39
  raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
@@ -152,7 +153,7 @@ class ModelVersion(lineage_node.LineageNode):
152
153
  project=_TELEMETRY_PROJECT,
153
154
  subproject=_TELEMETRY_SUBPROJECT,
154
155
  )
155
- def show_metrics(self) -> Dict[str, Any]:
156
+ def show_metrics(self) -> dict[str, Any]:
156
157
  """Show all metrics logged with the model version.
157
158
 
158
159
  Returns:
@@ -293,7 +294,7 @@ class ModelVersion(lineage_node.LineageNode):
293
294
  statement_params=statement_params,
294
295
  )
295
296
 
296
- def _get_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
297
+ def _get_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
297
298
  statement_params = telemetry.get_statement_params(
298
299
  project=_TELEMETRY_PROJECT,
299
300
  subproject=_TELEMETRY_SUBPROJECT,
@@ -327,7 +328,7 @@ class ModelVersion(lineage_node.LineageNode):
327
328
  project=_TELEMETRY_PROJECT,
328
329
  subproject=_TELEMETRY_SUBPROJECT,
329
330
  )
330
- def show_functions(self) -> List[model_manifest_schema.ModelFunctionInfo]:
331
+ def show_functions(self) -> list[model_manifest_schema.ModelFunctionInfo]:
331
332
  """Show all functions information in a model version that is callable.
332
333
 
333
334
  Returns:
@@ -405,11 +406,6 @@ class ModelVersion(lineage_node.LineageNode):
405
406
  strict_input_validation: Enable stricter validation for the input data. This will result value range based
406
407
  type validation to make sure your input data won't overflow when providing to the model.
407
408
 
408
- Raises:
409
- ValueError: When no method with the corresponding name is available.
410
- ValueError: When there are more than 1 target methods available in the model but no function name specified.
411
- ValueError: When the partition column is not a valid Snowflake identifier.
412
-
413
409
  Returns:
414
410
  The prediction data. It would be the same type dataframe as your input.
415
411
  """
@@ -422,29 +418,7 @@ class ModelVersion(lineage_node.LineageNode):
422
418
  # Partition column must be a valid identifier
423
419
  partition_column = sql_identifier.SqlIdentifier(partition_column)
424
420
 
425
- functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions
426
-
427
- if function_name:
428
- req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
429
- find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
430
- lambda method: method["name"] == req_method_name
431
- )
432
- target_function_info = next(
433
- filter(find_method, functions),
434
- None,
435
- )
436
- if target_function_info is None:
437
- raise ValueError(
438
- f"There is no method with name {function_name} available in the model"
439
- f" {self.fully_qualified_model_name} version {self.version_name}"
440
- )
441
- elif len(functions) != 1:
442
- raise ValueError(
443
- f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
444
- f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
445
- )
446
- else:
447
- target_function_info = functions[0]
421
+ target_function_info = self._get_function_info(function_name=function_name)
448
422
 
449
423
  if service_name:
450
424
  database_name_id, schema_name_id, service_name_id = sql_identifier.parse_fully_qualified_name(service_name)
@@ -475,6 +449,33 @@ class ModelVersion(lineage_node.LineageNode):
475
449
  is_partitioned=target_function_info["is_partitioned"],
476
450
  )
477
451
 
452
+ def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
453
+ functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
454
+
455
+ if function_name:
456
+ req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
457
+ find_method: Callable[[model_manifest_schema.ModelFunctionInfo], bool] = (
458
+ lambda method: method["name"] == req_method_name
459
+ )
460
+ target_function_info = next(
461
+ filter(find_method, functions),
462
+ None,
463
+ )
464
+ if target_function_info is None:
465
+ raise ValueError(
466
+ f"There is no method with name {function_name} available in the model"
467
+ f" {self.fully_qualified_model_name} version {self.version_name}"
468
+ )
469
+ elif len(functions) != 1:
470
+ raise ValueError(
471
+ f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
472
+ f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
473
+ )
474
+ else:
475
+ target_function_info = functions[0]
476
+
477
+ return target_function_info
478
+
478
479
  @telemetry.send_api_usage_telemetry(
479
480
  project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
480
481
  )
@@ -684,7 +685,7 @@ class ModelVersion(lineage_node.LineageNode):
684
685
  num_workers: Optional[int] = None,
685
686
  max_batch_rows: Optional[int] = None,
686
687
  force_rebuild: bool = False,
687
- build_external_access_integrations: Optional[List[str]] = None,
688
+ build_external_access_integrations: Optional[list[str]] = None,
688
689
  block: bool = True,
689
690
  ) -> Union[str, async_job.AsyncJob]:
690
691
  """Create an inference service with the given spec.
@@ -751,7 +752,7 @@ class ModelVersion(lineage_node.LineageNode):
751
752
  max_batch_rows: Optional[int] = None,
752
753
  force_rebuild: bool = False,
753
754
  build_external_access_integration: Optional[str] = None,
754
- build_external_access_integrations: Optional[List[str]] = None,
755
+ build_external_access_integrations: Optional[list[str]] = None,
755
756
  block: bool = True,
756
757
  ) -> Union[str, async_job.AsyncJob]:
757
758
  """Create an inference service with the given spec.
@@ -914,5 +915,72 @@ class ModelVersion(lineage_node.LineageNode):
914
915
  statement_params=statement_params,
915
916
  )
916
917
 
918
+ @snowpark._internal.utils.private_preview(version="1.8.3")
919
+ @telemetry.send_api_usage_telemetry(
920
+ project=_TELEMETRY_PROJECT,
921
+ subproject=_TELEMETRY_SUBPROJECT,
922
+ )
923
+ def run_job(
924
+ self,
925
+ X: Union[pd.DataFrame, "dataframe.DataFrame"],
926
+ *,
927
+ job_name: str,
928
+ compute_pool: str,
929
+ image_repo: str,
930
+ output_table_name: str,
931
+ function_name: Optional[str] = None,
932
+ cpu_requests: Optional[str] = None,
933
+ memory_requests: Optional[str] = None,
934
+ gpu_requests: Optional[Union[str, int]] = None,
935
+ num_workers: Optional[int] = None,
936
+ max_batch_rows: Optional[int] = None,
937
+ force_rebuild: bool = False,
938
+ build_external_access_integrations: Optional[list[str]] = None,
939
+ ) -> Union[pd.DataFrame, dataframe.DataFrame]:
940
+ statement_params = telemetry.get_statement_params(
941
+ project=_TELEMETRY_PROJECT,
942
+ subproject=_TELEMETRY_SUBPROJECT,
943
+ )
944
+ target_function_info = self._get_function_info(function_name=function_name)
945
+ job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
946
+ image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
947
+ output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
948
+ output_table_name
949
+ )
950
+ warehouse = self._service_ops._session.get_current_warehouse()
951
+ assert warehouse, "No active warehouse selected in the current session."
952
+ return self._service_ops.invoke_job_method(
953
+ target_method=target_function_info["target_method"],
954
+ signature=target_function_info["signature"],
955
+ X=X,
956
+ database_name=None,
957
+ schema_name=None,
958
+ model_name=self._model_name,
959
+ version_name=self._version_name,
960
+ job_database_name=job_db_id,
961
+ job_schema_name=job_schema_id,
962
+ job_name=job_id,
963
+ compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
964
+ warehouse_name=sql_identifier.SqlIdentifier(warehouse),
965
+ image_repo_database_name=image_repo_db_id,
966
+ image_repo_schema_name=image_repo_schema_id,
967
+ image_repo_name=image_repo_id,
968
+ output_table_database_name=output_table_db_id,
969
+ output_table_schema_name=output_table_schema_id,
970
+ output_table_name=output_table_id,
971
+ cpu_requests=cpu_requests,
972
+ memory_requests=memory_requests,
973
+ gpu_requests=gpu_requests,
974
+ num_workers=num_workers,
975
+ max_batch_rows=max_batch_rows,
976
+ force_rebuild=force_rebuild,
977
+ build_external_access_integrations=(
978
+ None
979
+ if build_external_access_integrations is None
980
+ else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
981
+ ),
982
+ statement_params=statement_params,
983
+ )
984
+
917
985
 
918
986
  lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Any, Dict, Optional, TypedDict
2
+ from typing import Any, Optional, TypedDict
3
3
 
4
4
  from typing_extensions import NotRequired
5
5
 
@@ -14,7 +14,7 @@ MODEL_VERSION_METADATA_SCHEMA_VERSION = "2024-01-01"
14
14
 
15
15
 
16
16
  class ModelVersionMetadataSchema(TypedDict):
17
- metrics: NotRequired[Dict[str, Any]]
17
+ metrics: NotRequired[dict[str, Any]]
18
18
 
19
19
 
20
20
  class MetadataOperator:
@@ -44,7 +44,7 @@ class MetadataOperator:
44
44
  )
45
45
 
46
46
  @staticmethod
47
- def _parse(metadata_dict: Dict[str, Any]) -> ModelVersionMetadataSchema:
47
+ def _parse(metadata_dict: dict[str, Any]) -> ModelVersionMetadataSchema:
48
48
  loaded_metadata_schema_version = metadata_dict.get("snowpark_ml_schema_version", None)
49
49
  if loaded_metadata_schema_version is None:
50
50
  return ModelVersionMetadataSchema(metrics={})
@@ -65,8 +65,8 @@ class MetadataOperator:
65
65
  schema_name: Optional[sql_identifier.SqlIdentifier],
66
66
  model_name: sql_identifier.SqlIdentifier,
67
67
  version_name: sql_identifier.SqlIdentifier,
68
- statement_params: Optional[Dict[str, Any]] = None,
69
- ) -> Dict[str, Any]:
68
+ statement_params: Optional[dict[str, Any]] = None,
69
+ ) -> dict[str, Any]:
70
70
  version_info_list = self._model_client.show_versions(
71
71
  database_name=database_name,
72
72
  schema_name=schema_name,
@@ -89,7 +89,7 @@ class MetadataOperator:
89
89
  schema_name: Optional[sql_identifier.SqlIdentifier],
90
90
  model_name: sql_identifier.SqlIdentifier,
91
91
  version_name: sql_identifier.SqlIdentifier,
92
- statement_params: Optional[Dict[str, Any]] = None,
92
+ statement_params: Optional[dict[str, Any]] = None,
93
93
  ) -> ModelVersionMetadataSchema:
94
94
  metadata_dict = self._get_current_metadata_dict(
95
95
  database_name=database_name,
@@ -108,7 +108,7 @@ class MetadataOperator:
108
108
  schema_name: Optional[sql_identifier.SqlIdentifier],
109
109
  model_name: sql_identifier.SqlIdentifier,
110
110
  version_name: sql_identifier.SqlIdentifier,
111
- statement_params: Optional[Dict[str, Any]] = None,
111
+ statement_params: Optional[dict[str, Any]] = None,
112
112
  ) -> None:
113
113
  metadata_dict = self._get_current_metadata_dict(
114
114
  database_name=database_name,