snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/job.py CHANGED
@@ -1,20 +1,32 @@
 import time
-from typing import Any, List, Optional, cast
+from typing import Any, Generic, Literal, Optional, TypeVar, Union, cast, overload
+
+import yaml
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml.jobs._utils import constants, types
+from snowflake.ml.jobs._utils import constants, interop_utils, types
 from snowflake.snowpark import context as sp_context
 
 _PROJECT = "MLJob"
 TERMINAL_JOB_STATUSES = {"FAILED", "DONE", "INTERNAL_ERROR"}
 
+T = TypeVar("T")
+
 
-class MLJob:
-    def __init__(self, id: str, session: Optional[snowpark.Session] = None) -> None:
+class MLJob(Generic[T]):
+    def __init__(
+        self,
+        id: str,
+        service_spec: Optional[dict[str, Any]] = None,
+        session: Optional[snowpark.Session] = None,
+    ) -> None:
         self._id = id
+        self._service_spec_cached: Optional[dict[str, Any]] = service_spec
         self._session = session or sp_context.get_active_session()
+
         self._status: types.JOB_STATUS = "PENDING"
+        self._result: Optional[interop_utils.ExecutionResult] = None
 
     @property
     def id(self) -> str:
@@ -29,33 +41,76 @@ class MLJob:
         self._status = _get_status(self._session, self.id)
         return self._status
 
-    @snowpark._internal.utils.private_preview(version="1.7.4")
-    def get_logs(self, limit: int = -1) -> str:
+    @property
+    def _service_spec(self) -> dict[str, Any]:
+        """Get the job's service spec."""
+        if not self._service_spec_cached:
+            self._service_spec_cached = _get_service_spec(self._session, self.id)
+        return self._service_spec_cached
+
+    @property
+    def _container_spec(self) -> dict[str, Any]:
+        """Get the job's main container spec."""
+        containers = self._service_spec["spec"]["containers"]
+        container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
+        return cast(dict[str, Any], container_spec)
+
+    @property
+    def _stage_path(self) -> str:
+        """Get the job's artifact storage stage location."""
+        volumes = self._service_spec["spec"]["volumes"]
+        stage_path = next(v for v in volumes if v["name"] == constants.STAGE_VOLUME_NAME)["source"]
+        return cast(str, stage_path)
+
+    @property
+    def _result_path(self) -> str:
+        """Get the job's result file location."""
+        result_path = self._container_spec["env"].get(constants.RESULT_PATH_ENV_VAR)
+        if result_path is None:
+            raise RuntimeError(f"Job {self.id} doesn't have a result path configured")
+        return f"{self._stage_path}/{result_path}"
+
+    @overload
+    def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[True]) -> list[str]:
+        ...
+
+    @overload
+    def get_logs(self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: Literal[False] = False) -> str:
+        ...
+
+    def get_logs(
+        self, limit: int = -1, instance_id: Optional[int] = None, *, as_list: bool = False
+    ) -> Union[str, list[str]]:
         """
         Return the job's execution logs.
 
         Args:
             limit: The maximum number of lines to return. Negative values are treated as no limit.
+            instance_id: Optional instance ID to get logs from a specific instance.
+                If not provided, returns logs from the head node.
+            as_list: If True, returns logs as a list of lines. Otherwise, returns logs as a single string.
 
         Returns:
             The job's execution logs.
         """
-        logs = _get_logs(self._session, self.id, limit)
+        logs = _get_logs(self._session, self.id, limit, instance_id)
         assert isinstance(logs, str)  # mypy
+        if as_list:
+            return logs.splitlines()
         return logs
 
-    @snowpark._internal.utils.private_preview(version="1.7.4")
-    def show_logs(self, limit: int = -1) -> None:
+    def show_logs(self, limit: int = -1, instance_id: Optional[int] = None) -> None:
         """
         Display the job's execution logs.
 
         Args:
             limit: The maximum number of lines to display. Negative values are treated as no limit.
+            instance_id: Optional instance ID to get logs from a specific instance.
+                If not provided, displays logs from the head node.
         """
-        print(self.get_logs(limit))  # noqa: T201: we need to print here.
+        print(self.get_logs(limit, instance_id, as_list=False))  # noqa: T201: we need to print here.
 
-    @snowpark._internal.utils.private_preview(version="1.7.4")
-    @telemetry.send_api_usage_telemetry(project=_PROJECT)
+    @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
     def wait(self, timeout: float = -1) -> types.JOB_STATUS:
         """
         Block until completion. Returns completion status.
@@ -78,20 +133,58 @@ class MLJob:
             delay = min(delay * 2, constants.JOB_POLL_MAX_DELAY_SECONDS)  # Exponential backoff
         return self.status
 
+    @snowpark._internal.utils.private_preview(version="1.8.2")
+    @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["timeout"])
+    def result(self, timeout: float = -1) -> T:
+        """
+        Block until completion. Returns job execution result.
+
+        Args:
+            timeout: The maximum time to wait in seconds. Negative values are treated as no timeout.
+
+        Returns:
+            T: The deserialized job result. # noqa: DAR401
+
+        Raises:
+            RuntimeError: If the job failed or if the job doesn't have a result to retrieve.
+            TimeoutError: If the job does not complete within the specified timeout. # noqa: DAR402
+        """
+        if self._result is None:
+            self.wait(timeout)
+            try:
+                self._result = interop_utils.fetch_result(self._session, self._result_path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to retrieve result for job (id={self.id})") from e
+
+        if self._result.success:
+            return cast(T, self._result.result)
+        raise RuntimeError(f"Job execution failed (id={self.id})") from self._result.exception
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "instance_id"])
+def _get_status(session: snowpark.Session, job_id: str, instance_id: Optional[int] = None) -> types.JOB_STATUS:
+    """Retrieve job or job instance execution status."""
+    if instance_id is not None:
+        # Get specific instance status
+        rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        for row in rows:
+            if row["instance_id"] == str(instance_id):
+                return cast(types.JOB_STATUS, row["status"])
+        raise ValueError(f"Instance {instance_id} not found in job {job_id}")
+    else:
+        (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+        return cast(types.JOB_STATUS, row["status"])
+
 
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
-def _get_status(session: snowpark.Session, job_id: str) -> types.JOB_STATUS:
-    """Retrieve job execution status."""
-    # TODO: snowflake-snowpark-python<1.24.0 shows spurious error messages on
-    #   `DESCRIBE` queries with bind variables
-    #   Switch to use bind variables instead of client side formatting after
-    #   updating to snowflake-snowpark-python>=1.24.0
-    (row,) = session.sql(f"DESCRIBE SERVICE {job_id}").collect()
-    return cast(types.JOB_STATUS, row["status"])
-
-
-@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit"])
-def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
+def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:
+    """Retrieve job execution service spec."""
+    (row,) = session.sql("DESCRIBE SERVICE IDENTIFIER(?)", params=[job_id]).collect()
+    return cast(dict[str, Any], yaml.safe_load(row["spec"]))
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
+def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None) -> str:
     """
     Retrieve the job's execution logs.
 
@@ -99,15 +192,54 @@ def _get_logs(session: snowpark.Session, job_id: str, limit: int = -1) -> str:
         job_id: The job ID.
         limit: The maximum number of lines to return. Negative values are treated as no limit.
         session: The Snowpark session to use. If none specified, uses active session.
+        instance_id: Optional instance ID to get logs from a specific instance.
 
     Returns:
         The job's execution logs.
     """
-    params: List[Any] = [job_id]
+    # If instance_id is not specified, try to get the head instance ID
+    if instance_id is None:
+        instance_id = _get_head_instance_id(session, job_id)
+
+    # Assemble params: [job_id, instance_id, container_name, (optional) limit]
+    params: list[Any] = [
+        job_id,
+        0 if instance_id is None else instance_id,
+        constants.DEFAULT_CONTAINER_NAME,
+    ]
     if limit > 0:
         params.append(limit)
+
     (row,) = session.sql(
-        f"SELECT SYSTEM$GET_SERVICE_LOGS(?, 0, '{constants.DEFAULT_CONTAINER_NAME}'{f', ?' if limit > 0 else ''})",
+        f"SELECT SYSTEM$GET_SERVICE_LOGS(?, ?, ?{f', ?' if limit > 0 else ''})",
        params=params,
    ).collect()
    return str(row[0])
+
+
+@telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id"])
+def _get_head_instance_id(session: snowpark.Session, job_id: str) -> Optional[int]:
+    """
+    Retrieve the head instance ID of a job.
+
+    Args:
+        session: The Snowpark session to use.
+        job_id: The job ID.
+
+    Returns:
+        The head instance ID of the job. Returns None if the head instance has not started yet.
+    """
+    rows = session.sql("SHOW SERVICE INSTANCES IN SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
+    if not rows:
+        return None
+
+    # Sort by start_time first, then by instance_id
+    sorted_instances = sorted(rows, key=lambda x: (x["start_time"], int(x["instance_id"])))
+    head_instance = sorted_instances[0]
+    if not head_instance["start_time"]:
+        # If head instance hasn't started yet, return None
+        return None
+    try:
+        return int(head_instance["instance_id"])
+    except (ValueError, TypeError):
+        return 0
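
For orientation, here is a minimal usage sketch of the MLJob surface introduced above. It assumes the package-level re-exports in snowflake.ml.jobs and an active Snowpark session; the job ID is hypothetical, and result() is marked private preview as of 1.8.2.

    from snowflake.ml import jobs

    job = jobs.get_job("MLJOB_EXAMPLE_ID")       # hypothetical job ID; uses the active session
    job.wait()                                   # block until a terminal status
    print(job.status)                            # e.g. "DONE" or "FAILED"
    job.show_logs(limit=100)                     # head-node logs by default
    worker_logs = job.get_logs(instance_id=1, as_list=True)  # new: per-instance logs as list[str]
    value = job.result()                         # deserialized return value, if the job wrote one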
snowflake/ml/jobs/manager.py CHANGED
@@ -1,6 +1,7 @@
+import logging
 import pathlib
 import textwrap
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Literal, Optional, TypeVar, Union, overload
 from uuid import uuid4
 
 import yaml
@@ -13,11 +14,14 @@ from snowflake.ml.jobs._utils import payload_utils, spec_utils
 from snowflake.snowpark.context import get_active_session
 from snowflake.snowpark.exceptions import SnowparkSQLException
 
+logger = logging.getLogger(__name__)
+
 _PROJECT = "MLJob"
 JOB_ID_PREFIX = "MLJOB_"
 
+T = TypeVar("T")
+
 
-@snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
 def list_jobs(
     limit: int = 10,
@@ -57,9 +61,8 @@ def list_jobs(
     return df
 
 
-@snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
-def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
+def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
     """Retrieve a job service from the backend."""
     session = session or get_active_session()
 
@@ -71,7 +74,8 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
 
     try:
         # Validate that job exists by doing a status check
-        job = jb.MLJob(job_id, session=session)
+        # FIXME: Retrieve return path
+        job = jb.MLJob[Any](job_id, session=session)
         _ = job.status
         return job
     except SnowparkSQLException as e:
@@ -80,9 +84,8 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob
         raise
 
 
-@snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
-def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
+def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
     """Delete a job service from the backend. Status and logs will be lost."""
     if isinstance(job, jb.MLJob):
         job_id = job.id
@@ -93,23 +96,22 @@ def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] =
     session.sql("DROP SERVICE IDENTIFIER(?)", params=(job_id,)).collect()
 
 
-@snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
 def submit_file(
     file_path: str,
     compute_pool: str,
     *,
     stage_name: str,
-    args: Optional[List[str]] = None,
-    env_vars: Optional[Dict[str, str]] = None,
-    pip_requirements: Optional[List[str]] = None,
-    external_access_integrations: Optional[List[str]] = None,
+    args: Optional[list[str]] = None,
+    env_vars: Optional[dict[str, str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[Dict[str, Any]] = None,
+    spec_overrides: Optional[dict[str, Any]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
-) -> jb.MLJob:
+) -> jb.MLJob[None]:
     """
     Submit a Python file as a job to the compute pool.
 
@@ -146,7 +148,6 @@ def submit_file(
     )
 
 
-@snowpark._internal.utils.private_preview(version="1.7.4")
 @telemetry.send_api_usage_telemetry(project=_PROJECT)
 def submit_directory(
     dir_path: str,
@@ -154,16 +155,16 @@ def submit_directory(
     *,
     entrypoint: str,
     stage_name: str,
-    args: Optional[List[str]] = None,
-    env_vars: Optional[Dict[str, str]] = None,
-    pip_requirements: Optional[List[str]] = None,
-    external_access_integrations: Optional[List[str]] = None,
+    args: Optional[list[str]] = None,
+    env_vars: Optional[dict[str, str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[Dict[str, Any]] = None,
+    spec_overrides: Optional[dict[str, Any]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
-) -> jb.MLJob:
+) -> jb.MLJob[None]:
     """
     Submit a directory containing Python script(s) as a job to the compute pool.
 
@@ -202,6 +203,46 @@ def submit_directory(
     )
 
 
+@overload
+def _submit_job(
+    source: str,
+    compute_pool: str,
+    *,
+    stage_name: str,
+    entrypoint: Optional[str] = None,
+    args: Optional[list[str]] = None,
+    env_vars: Optional[dict[str, str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[dict[str, Any]] = None,
+    num_instances: Optional[int] = None,
+    enable_metrics: bool = False,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob[None]:
+    ...
+
+
+@overload
+def _submit_job(
+    source: Callable[..., T],
+    compute_pool: str,
+    *,
+    stage_name: str,
+    entrypoint: Optional[str] = None,
+    args: Optional[list[str]] = None,
+    env_vars: Optional[dict[str, str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
+    query_warehouse: Optional[str] = None,
+    spec_overrides: Optional[dict[str, Any]] = None,
+    num_instances: Optional[int] = None,
+    enable_metrics: bool = False,
+    session: Optional[snowpark.Session] = None,
+) -> jb.MLJob[T]:
+    ...
+
+
 @telemetry.send_api_usage_telemetry(
     project=_PROJECT,
     func_params_to_log=[
@@ -210,24 +251,26 @@ def submit_directory(
         # TODO: Log lengths of args, env_vars, and spec_overrides values
         "pip_requirements",
         "external_access_integrations",
+        "num_instances",
+        "enable_metrics",
     ],
 )
 def _submit_job(
-    source: Union[str, Callable[..., Any]],
+    source: Union[str, Callable[..., T]],
     compute_pool: str,
     *,
     stage_name: str,
     entrypoint: Optional[str] = None,
-    args: Optional[List[str]] = None,
-    env_vars: Optional[Dict[str, str]] = None,
-    pip_requirements: Optional[List[str]] = None,
-    external_access_integrations: Optional[List[str]] = None,
+    args: Optional[list[str]] = None,
+    env_vars: Optional[dict[str, str]] = None,
+    pip_requirements: Optional[list[str]] = None,
+    external_access_integrations: Optional[list[str]] = None,
     query_warehouse: Optional[str] = None,
-    spec_overrides: Optional[Dict[str, Any]] = None,
+    spec_overrides: Optional[dict[str, Any]] = None,
     num_instances: Optional[int] = None,
     enable_metrics: bool = False,
     session: Optional[snowpark.Session] = None,
-) -> jb.MLJob:
+) -> jb.MLJob[T]:
     """
     Submit a job to the compute pool.
 
@@ -252,6 +295,12 @@ def _submit_job(
     Raises:
         RuntimeError: If required Snowflake features are not enabled.
     """
+    # Display warning about PrPr parameters
+    if num_instances is not None:
+        logger.warning(
+            "_submit_job() parameter 'num_instances' is in private preview since 1.8.2. Do not use it in production.",
+        )
+
     session = session or get_active_session()
     job_id = f"{JOB_ID_PREFIX}{str(uuid4()).replace('-', '_').upper()}"
     stage_name = "@" + stage_name.lstrip("@").rstrip("/")
@@ -314,5 +363,4 @@ def _submit_job(
             ) from e
         raise
 
-    # TODO: Wrap snowflake.core.service.JobService object
-    return jb.MLJob(job_id, session=session)
+    return jb.MLJob(job_id, service_spec=spec, session=session)
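
As a companion sketch, here is how a script might be submitted through the public helpers above. Pool, stage, and file names are hypothetical; submit_file is typed as MLJob[None] in 1.8.3, while the callable overload of _submit_job yields MLJob[T].

    from snowflake.ml import jobs

    job = jobs.submit_file(
        "train.py",                      # hypothetical local script
        "MY_COMPUTE_POOL",               # hypothetical compute pool
        stage_name="ml_job_payloads",    # hypothetical payload stage
        args=["--epochs", "5"],
        pip_requirements=["xgboost"],
        num_instances=2,                 # private preview since 1.8.2, per the warning above
    )
    job.wait()
    job.show_logs()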
snowflake/ml/lineage/lineage_node.py CHANGED
@@ -1,6 +1,6 @@
 import json
 from datetime import datetime
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Set, Type, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union
 
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
     from snowflake.ml.model._client.model import model_version_impl
 
 _PROJECT = "LINEAGE"
-DOMAIN_LINEAGE_REGISTRY: Dict[str, Type["LineageNode"]] = {}
+DOMAIN_LINEAGE_REGISTRY: dict[str, type["LineageNode"]] = {}
 
 
 class LineageNode:
@@ -87,8 +87,8 @@ class LineageNode:
     def lineage(
         self,
         direction: Literal["upstream", "downstream"] = "downstream",
-        domain_filter: Optional[Set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
-    ) -> List[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
+        domain_filter: Optional[set[Literal["feature_view", "dataset", "model", "table", "view"]]] = None,
+    ) -> list[Union["feature_view.FeatureView", "dataset.Dataset", "model_version_impl.ModelVersion", "LineageNode"]]:
         """
         Retrieves the lineage nodes connected to this node.
 
@@ -109,7 +109,7 @@ class LineageNode:
         if domain_filter is not None:
             domain_filter = {d.lower() for d in domain_filter}  # type: ignore[misc]
 
-        lineage_nodes: List["LineageNode"] = []
+        lineage_nodes: list["LineageNode"] = []
         for row in df.collect():
             lineage_object = (
                 json.loads(row["TARGET_OBJECT"])
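
The hunks above are a typing modernization (Set/List/Dict replaced by builtin generics); lineage() itself behaves the same. A hedged sketch of the annotated call, assuming an active session and hypothetical model and version names:

    from snowflake.ml.registry import Registry
    from snowflake.snowpark.context import get_active_session

    mv = Registry(get_active_session()).get_model("MY_MODEL").version("V1")
    upstream = mv.lineage(direction="upstream", domain_filter={"dataset", "feature_view"})
    for node in upstream:
        print(type(node).__name__)   # e.g. Dataset or FeatureView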
snowflake/ml/model/_client/model/model_impl.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
 
 import pandas as pd
 
@@ -224,7 +224,7 @@ class Model:
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
     )
-    def versions(self) -> List[model_version_impl.ModelVersion]:
+    def versions(self) -> list[model_version_impl.ModelVersion]:
         """Get all versions in the model.
 
         Returns:
@@ -298,7 +298,7 @@ class Model:
         project=_TELEMETRY_PROJECT,
         subproject=_TELEMETRY_SUBPROJECT,
    )
-    def show_tags(self) -> Dict[str, str]:
+    def show_tags(self) -> dict[str, str]:
         """Get a dictionary showing the tag and its value attached to the model.
 
         Returns:
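
This last change is likewise typing-only. A brief sketch of the two annotated methods, assuming an active session and a hypothetical model name:

    from snowflake.ml.registry import Registry
    from snowflake.snowpark.context import get_active_session

    m = Registry(get_active_session()).get_model("MY_MODEL")
    print([v.version_name for v in m.versions()])   # versions() now returns list[ModelVersion]
    print(m.show_tags())                            # show_tags() now returns dict[str, str]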