snowflake-ml-python 1.13.0__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. snowflake/ml/_internal/platform_capabilities.py +9 -7
  2. snowflake/ml/_internal/utils/connection_params.py +5 -3
  3. snowflake/ml/_internal/utils/jwt_generator.py +3 -2
  4. snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
  6. snowflake/ml/experiment/_entities/__init__.py +2 -1
  7. snowflake/ml/experiment/_entities/run.py +0 -15
  8. snowflake/ml/experiment/_entities/run_metadata.py +3 -51
  9. snowflake/ml/experiment/experiment_tracking.py +8 -8
  10. snowflake/ml/jobs/_utils/constants.py +1 -1
  11. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +9 -7
  12. snowflake/ml/jobs/job.py +12 -4
  13. snowflake/ml/jobs/manager.py +34 -7
  14. snowflake/ml/lineage/lineage_node.py +0 -1
  15. snowflake/ml/model/__init__.py +2 -6
  16. snowflake/ml/model/_client/model/batch_inference_specs.py +0 -4
  17. snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
  18. snowflake/ml/model/_client/model/model_version_impl.py +25 -77
  19. snowflake/ml/model/_client/ops/model_ops.py +9 -2
  20. snowflake/ml/model/_client/ops/service_ops.py +82 -36
  21. snowflake/ml/model/_client/sql/service.py +29 -5
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
  23. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
  24. snowflake/ml/model/_packager/model_packager.py +4 -3
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +0 -1
  26. snowflake/ml/model/_signatures/utils.py +0 -21
  27. snowflake/ml/model/models/huggingface_pipeline.py +56 -21
  28. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +47 -3
  29. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  30. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  31. snowflake/ml/monitoring/model_monitor.py +30 -0
  32. snowflake/ml/registry/_manager/model_manager.py +1 -1
  33. snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -2
  34. snowflake/ml/utils/connection_params.py +5 -3
  35. snowflake/ml/version.py +1 -1
  36. {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/METADATA +51 -34
  37. {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/RECORD +40 -39
  38. {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/WHEEL +0 -0
  39. {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/licenses/LICENSE.txt +0 -0
  40. {snowflake_ml_python-1.13.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,8 @@
  import json
+ import logging
  from contextlib import contextmanager
  from typing import Any, Optional

- from absl import logging
  from packaging import version

  from snowflake.ml import version as snowml_version
@@ -13,6 +13,8 @@ from snowflake.snowpark import (
  session as snowpark_session,
  )

+ logger = logging.getLogger(__name__)
+
  LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
  INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"

@@ -60,12 +62,12 @@ class PlatformCapabilities:
  @classmethod # type: ignore[arg-type]
  @contextmanager
  def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None: # type: ignore[misc]
- logging.debug(f"Setting mock features: {features}")
+ logger.debug(f"Setting mock features: {features}")
  cls.set_mock_features(features)
  try:
  yield
  finally:
- logging.debug(f"Clearing mock features: {features}")
+ logger.debug(f"Clearing mock features: {features}")
  cls.clear_mock_features()

  def is_inlined_deployment_spec_enabled(self) -> bool:
@@ -98,7 +100,7 @@ class PlatformCapabilities:
  error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
  )
  except snowpark_exceptions.SnowparkSQLException as e:
- logging.debug(f"Failed to retrieve platform capabilities: {e}")
+ logger.debug(f"Failed to retrieve platform capabilities: {e}")
  # This can happen is server side is older than 9.2. That is fine.
  return {}

@@ -144,7 +146,7 @@ class PlatformCapabilities:

  value = self.features.get(feature_name)
  if value is None:
- logging.debug(f"Feature {feature_name} not found, returning large version number")
+ logger.debug(f"Feature {feature_name} not found, returning large version number")
  return large_version

  try:
@@ -152,7 +154,7 @@
  version_str = str(value)
  return version.Version(version_str)
  except (version.InvalidVersion, ValueError, TypeError) as e:
- logging.debug(
+ logger.debug(
  f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
  f"Returning large version number"
  )
@@ -171,7 +173,7 @@ class PlatformCapabilities:
  feature_version = self._get_version_feature(feature_name)

  result = current_version >= feature_version
- logging.debug(
+ logger.debug(
  f"Version comparison for feature {feature_name}: "
  f"current={current_version}, feature={feature_version}, enabled={result}"
  )
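
Note on the hunks above: the module now logs through a standard-library logger named after its module path instead of absl's global logger, so callers can tune verbosity with ordinary logging configuration. A minimal sketch, assuming the logger name follows the module path shown in the files-changed list:

    import logging

    # Raise or lower the level for the capability-probing messages shown above;
    # the logger name mirrors snowflake/ml/_internal/platform_capabilities.py.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger("snowflake.ml._internal.platform_capabilities").setLevel(logging.DEBUG)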
@@ -1,11 +1,13 @@
  import configparser
+ import logging
  import os
  from typing import Optional, Union

- from absl import logging
  from cryptography.hazmat import backends
  from cryptography.hazmat.primitives import serialization

+ logger = logging.getLogger(__name__)
+
  _DEFAULT_CONNECTION_FILE = "~/.snowsql/config"


@@ -106,7 +108,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
  """Loads the dictionary from snowsql config file."""
  snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
  if not os.path.exists(snowsql_config_file):
- logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
+ logger.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
  raise Exception("Snowflake SnowSQL config not found.")

  config = configparser.ConfigParser(inline_comment_prefixes="#")
@@ -122,7 +124,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
  # See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
  connection_name = "connections"

- logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
+ logger.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
  config.read(snowsql_config_file)
  conn_params = dict(config[connection_name])
  # Remap names to appropriate args in Python Connector API
@@ -110,15 +110,16 @@ class JWTGenerator:
  }

  # Regenerate the actual token
- token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
+ token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM) # type: ignore[arg-type]
  # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
  # If the token is a byte string, convert it to a string.
  if isinstance(token, bytes):
  token = token.decode("utf-8")
  self.token = token
+ public_key = self.private_key.public_key()
  logger.info(
  "Generated a JWT with the following payload: %s",
- jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
+ jwt.decode(self.token, key=public_key, algorithms=[JWTGenerator.ALGORITHM]), # type: ignore[arg-type]
  )

  return token
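
Note on the hunk above: the encode/decode calls now carry type-ignore hints and the public key is pulled into a local before the verification log. A self-contained sketch of the same PyJWT round trip with a cryptography key object; the RS256 algorithm and the claims are illustrative stand-ins for JWTGenerator.ALGORITHM and the real payload:

    import jwt
    from cryptography.hazmat.primitives.asymmetric import rsa

    # Generate a throwaway RSA key; the library loads the user's own key instead.
    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    token = jwt.encode({"iss": "ACCOUNT.USER", "sub": "ACCOUNT.USER"}, key=private_key, algorithm="RS256")
    # Verify with the matching public key, mirroring the logging call in the hunk.
    claims = jwt.decode(token, key=private_key.public_key(), algorithms=["RS256"])
    print(claims)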
@@ -1,10 +1,9 @@
+ import logging
  import os
  import shutil
  import tempfile
  from typing import Iterable, Union

- from absl.logging import logging
-
  logger = logging.getLogger(__name__)


@@ -76,17 +76,30 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
  self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()

- def modify_run(
+ def modify_run_add_metrics(
  self,
  *,
  experiment_name: sql_identifier.SqlIdentifier,
  run_name: sql_identifier.SqlIdentifier,
- run_metadata: str,
+ metrics: str,
  ) -> None:
  experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
  query_result_checker.SqlResultValidator(
  self._session,
- f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} SET METADATA=$${run_metadata}$$",
+ f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+ def modify_run_add_params(
+ self,
+ *,
+ experiment_name: sql_identifier.SqlIdentifier,
+ run_name: sql_identifier.SqlIdentifier,
+ params: str,
+ ) -> None:
+ experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+ query_result_checker.SqlResultValidator(
+ self._session,
+ f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()

  def put_artifact(
@@ -1,4 +1,5 @@
  from snowflake.ml.experiment._entities.experiment import Experiment
  from snowflake.ml.experiment._entities.run import Run
+ from snowflake.ml.experiment._entities.run_metadata import Metric, Param

- __all__ = ["Experiment", "Run"]
+ __all__ = ["Experiment", "Run", "Metric", "Param"]
@@ -1,11 +1,8 @@
- import json
  import types
  from typing import TYPE_CHECKING, Optional

  from snowflake.ml._internal.utils import sql_identifier
  from snowflake.ml.experiment import _experiment_info as experiment_info
- from snowflake.ml.experiment._client import experiment_tracking_sql_client
- from snowflake.ml.experiment._entities import run_metadata

  if TYPE_CHECKING:
  from snowflake.ml.experiment import experiment_tracking
@@ -41,18 +38,6 @@ class Run:
  if self._experiment_tracking._run is self:
  self._experiment_tracking.end_run()

- def _get_metadata(
- self,
- ) -> run_metadata.RunMetadata:
- runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
- experiment_name=self.experiment_name, like=str(self.name)
- )
- if not runs:
- raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
- return run_metadata.RunMetadata.from_dict(
- json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
- )
-
  def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
  return experiment_info.ExperimentInfo(
  fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
@@ -1,12 +1,4 @@
  import dataclasses
- import enum
- import typing
-
-
- class RunStatus(str, enum.Enum):
- UNKNOWN = "UNKNOWN"
- RUNNING = "RUNNING"
- FINISHED = "FINISHED"


  @dataclasses.dataclass
@@ -15,54 +7,14 @@ class Metric:
  value: float
  step: int

+ def to_dict(self) -> dict: # type: ignore[type-arg]
+ return dataclasses.asdict(self)
+

  @dataclasses.dataclass
  class Param:
  name: str
  value: str

-
- @dataclasses.dataclass
- class RunMetadata:
- status: RunStatus
- metrics: list[Metric]
- parameters: list[Param]
-
- @classmethod
- def from_dict(
- cls,
- metadata: dict, # type: ignore[type-arg]
- ) -> "RunMetadata":
- return RunMetadata(
- status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
- metrics=[Metric(**m) for m in metadata.get("metrics", [])],
- parameters=[Param(**p) for p in metadata.get("parameters", [])],
- )
-
  def to_dict(self) -> dict: # type: ignore[type-arg]
  return dataclasses.asdict(self)
-
- def set_metric(
- self,
- key: str,
- value: float,
- step: int,
- ) -> None:
- for metric in self.metrics:
- if metric.name == key and metric.step == step:
- metric.value = value
- break
- else:
- self.metrics.append(Metric(name=key, value=value, step=step))
-
- def set_param(
- self,
- key: str,
- value: typing.Any,
- ) -> None:
- for parameter in self.parameters:
- if parameter.name == key:
- parameter.value = str(value)
- break
- else:
- self.parameters.append(Param(name=key, value=str(value)))
@@ -261,13 +261,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
  step: The step of the metrics. Defaults to 0.
  """
  run = self._get_or_start_run()
- metadata = run._get_metadata()
+ metrics_list = []
  for key, value in metrics.items():
- metadata.set_metric(key, value, step)
- self._sql_client.modify_run(
+ metrics_list.append(entities.Metric(key, value, step))
+ self._sql_client.modify_run_add_metrics(
  experiment_name=run.experiment_name,
  run_name=run.name,
- run_metadata=json.dumps(metadata.to_dict()),
+ metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
  )

  def log_param(
@@ -296,13 +296,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
  to string.
  """
  run = self._get_or_start_run()
- metadata = run._get_metadata()
+ params_list = []
  for key, value in params.items():
- metadata.set_param(key, value)
- self._sql_client.modify_run(
+ params_list.append(entities.Param(key, str(value)))
+ self._sql_client.modify_run_add_params(
  experiment_name=run.experiment_name,
  run_name=run.name,
- run_metadata=json.dumps(metadata.to_dict()),
+ params=json.dumps([param.to_dict() for param in params_list]),
  )

  def log_artifact(
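
Note on the two hunks above: metric and parameter logging no longer reads back and rewrites the entire run metadata; only the new entries are serialized and sent through the ADD METRICS / ADD PARAMETERS statements introduced in the SQL client. A minimal sketch of the client-side serialization, with illustrative values:

    import json

    from snowflake.ml.experiment import _entities as entities

    # Build the delta that modify_run_add_metrics() embeds in the ADD METRICS clause.
    metrics = {"loss": 0.42, "accuracy": 0.91}
    step = 5
    metrics_list = [entities.Metric(name, value, step) for name, value in metrics.items()]
    payload = json.dumps([m.to_dict() for m in metrics_list])
    # payload == '[{"name": "loss", "value": 0.42, "step": 5}, {"name": "accuracy", "value": 0.91, "step": 5}]'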
@@ -25,7 +25,7 @@ RESULT_PATH_DEFAULT_VALUE = f"{OUTPUT_STAGE_SUBPATH}/mljob_result.pkl"
  DEFAULT_IMAGE_REPO = "/snowflake/images/snowflake_images"
  DEFAULT_IMAGE_CPU = "st_plat/runtime/x86/runtime_image/snowbooks"
  DEFAULT_IMAGE_GPU = "st_plat/runtime/x86/generic_gpu/runtime_image/snowbooks"
- DEFAULT_IMAGE_TAG = "1.6.2"
+ DEFAULT_IMAGE_TAG = "1.8.0"
  DEFAULT_ENTRYPOINT_PATH = "func.py"

  # Percent of container memory to allocate for /dev/shm volume
@@ -234,12 +234,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
  if payload_dir and payload_dir not in sys.path:
  sys.path.insert(0, payload_dir)

- # Create a Snowpark session before running the script
- # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
- config = SnowflakeLoginOptions()
- config["client_session_keep_alive"] = "True"
- session = Session.builder.configs(config).create() # noqa: F841
-
  try:

  if main_func:
@@ -266,7 +260,6 @@ def run_script(script_path: str, *script_args: Any, main_func: Optional[str] = N
  finally:
  # Restore original sys.argv
  sys.argv = original_argv
- session.close()


  def main(script_path: str, *script_args: Any, script_main_func: Optional[str] = None) -> ExecutionResult:
@@ -297,6 +290,12 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
  except ModuleNotFoundError:
  warnings.warn("Ray is not installed, skipping Ray initialization", ImportWarning, stacklevel=1)

+ # Create a Snowpark session before starting
+ # Session can be retrieved from using snowflake.snowpark.context.get_active_session()
+ config = SnowflakeLoginOptions()
+ config["client_session_keep_alive"] = "True"
+ session = Session.builder.configs(config).create() # noqa: F841
+
  try:
  # Wait for minimum required instances if specified
  min_instances_str = os.environ.get(MIN_INSTANCES_ENV_VAR) or "1"
@@ -352,6 +351,9 @@ def main(script_path: str, *script_args: Any, script_main_func: Optional[str] =
  f"Failed to serialize JSON result to {result_json_path}: {json_exc}", RuntimeWarning, stacklevel=1
  )

+ # Close the session after serializing the result
+ session.close()
+

  if __name__ == "__main__":
  # Parse command line arguments
snowflake/ml/jobs/job.py CHANGED
@@ -83,6 +83,8 @@ class MLJob(Generic[T], SerializableSessionMixin):
  def _container_spec(self) -> dict[str, Any]:
  """Get the job's main container spec."""
  containers = self._service_spec["spec"]["containers"]
+ if len(containers) == 1:
+ return cast(dict[str, Any], containers[0])
  try:
  container_spec = next(c for c in containers if c["name"] == constants.DEFAULT_CONTAINER_NAME)
  except StopIteration:
@@ -163,7 +165,7 @@
  Returns:
  The job's execution logs.
  """
- logs = _get_logs(self._session, self.id, limit, instance_id, verbose)
+ logs = _get_logs(self._session, self.id, limit, instance_id, self._container_spec["name"], verbose)
  assert isinstance(logs, str) # mypy
  if as_list:
  return logs.splitlines()
@@ -281,7 +283,12 @@ def _get_service_spec(session: snowpark.Session, job_id: str) -> dict[str, Any]:

  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["job_id", "limit", "instance_id"])
  def _get_logs(
- session: snowpark.Session, job_id: str, limit: int = -1, instance_id: Optional[int] = None, verbose: bool = True
+ session: snowpark.Session,
+ job_id: str,
+ limit: int = -1,
+ instance_id: Optional[int] = None,
+ container_name: str = constants.DEFAULT_CONTAINER_NAME,
+ verbose: bool = True,
  ) -> str:
  """
  Retrieve the job's execution logs.
@@ -291,6 +298,7 @@ def _get_logs(
  limit: The maximum number of lines to return. Negative values are treated as no limit.
  session: The Snowpark session to use. If none specified, uses active session.
  instance_id: Optional instance ID to get logs from a specific instance.
+ container_name: The container name to get logs from a specific container.
  verbose: Whether to return the full log or just the portion between START and END messages.

  Returns:
@@ -311,7 +319,7 @@ def _get_logs(
  params: list[Any] = [
  job_id,
  0 if instance_id is None else instance_id,
- constants.DEFAULT_CONTAINER_NAME,
+ container_name,
  ]
  if limit > 0:
  params.append(limit)
@@ -337,7 +345,7 @@ def _get_logs(
  job_id,
  limit=limit,
  instance_id=instance_id if instance_id else 0,
- container_name=constants.DEFAULT_CONTAINER_NAME,
+ container_name=container_name,
  )
  full_log = os.linesep.join(row[0] for row in logs)

@@ -232,6 +232,7 @@ def submit_file(
  enable_metrics (bool): Whether to enable metrics publishing for the job.
  query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
  spec_overrides (dict): A dictionary of overrides for the service spec.
+ imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.

  Returns:
  An object representing the submitted job.
@@ -286,6 +287,7 @@ def submit_directory(
  enable_metrics (bool): Whether to enable metrics publishing for the job.
  query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
  spec_overrides (dict): A dictionary of overrides for the service spec.
+ imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.

  Returns:
  An object representing the submitted job.
@@ -341,6 +343,7 @@ def submit_from_stage(
  enable_metrics (bool): Whether to enable metrics publishing for the job.
  query_warehouse (str): The query warehouse to use. Defaults to session warehouse.
  spec_overrides (dict): A dictionary of overrides for the service spec.
+ imports (list[Union[tuple[str, str], tuple[str]]]): A list of additional payloads used in the job.

  Returns:
  An object representing the submitted job.
@@ -404,6 +407,8 @@ def _submit_job(
  "num_instances", # deprecated
  "target_instances",
  "min_instances",
+ "enable_metrics",
+ "query_warehouse",
  ],
  )
  def _submit_job(
@@ -447,6 +452,13 @@ def _submit_job(
  )
  target_instances = max(target_instances, kwargs.pop("num_instances"))

+ imports = None
+ if "additional_payloads" in kwargs:
+ logger.warning(
+ "'additional_payloads' is deprecated and will be removed in a future release. Use 'imports' instead."
+ )
+ imports = kwargs.pop("additional_payloads")
+
  # Use kwargs for less common optional parameters
  database = kwargs.pop("database", None)
  schema = kwargs.pop("schema", None)
@@ -457,10 +469,7 @@ def _submit_job(
  spec_overrides = kwargs.pop("spec_overrides", None)
  enable_metrics = kwargs.pop("enable_metrics", True)
  query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
- additional_payloads = kwargs.pop("additional_payloads", None)
-
- if additional_payloads:
- logger.warning("'additional_payloads' is in private preview since 1.9.1. Do not use it in production.")
+ imports = kwargs.pop("imports", None) or imports

  # Warn if there are unknown kwargs
  if kwargs:
@@ -492,7 +501,7 @@ def _submit_job(
  try:
  # Upload payload
  uploaded_payload = payload_utils.JobPayload(
- source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
+ source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=imports
  ).upload(session, stage_path)
  except snowpark.exceptions.SnowparkSQLException as e:
  if e.sql_error_code == 90106:
@@ -501,6 +510,22 @@ def _submit_job(
  )
  raise

+ # FIXME: Temporary patches, remove this after v1 is deprecated
+ if target_instances > 1:
+ default_spec_overrides = {
+ "spec": {
+ "endpoints": [
+ {"name": "ray-dashboard-endpoint", "port": 12003, "protocol": "TCP"},
+ ]
+ },
+ }
+ if spec_overrides:
+ spec_overrides = spec_utils.merge_patch(
+ default_spec_overrides, spec_overrides, display_name="spec_overrides"
+ )
+ else:
+ spec_overrides = default_spec_overrides
+
  if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
  # Add default env vars (extracted from spec_utils.generate_service_spec)
  combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
@@ -668,8 +693,10 @@ def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session:
  session = session or get_active_session()
  except snowpark.exceptions.SnowparkSessionException as e:
  if "More than one active session" in e.message:
- raise RuntimeError("Please specify the session as a parameter in API call")
+ raise RuntimeError(
+ "More than one active session is found. Please specify the session explicitly as a parameter"
+ ) from None
  if "No default Session is found" in e.message:
- raise RuntimeError("Please create a session before API call")
+ raise RuntimeError("No active session is found. Please create a session") from None
  raise
  return session
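
Note on the hunks above: the private-preview additional_payloads keyword is deprecated in favor of the documented imports argument, and multi-instance jobs now get a default ray-dashboard-endpoint patched into their spec. A hedged usage sketch; the tuple semantics for imports are inferred from the list[Union[tuple[str, str], tuple[str]]] annotation, and the compute pool, stage, and warehouse names are placeholders:

    from snowflake.ml import jobs

    # Each entry is a (source, target) tuple or a single-element (source,) tuple.
    extra_files = [("src/utils.py", "utils.py"), ("configs/",)]

    job = jobs.submit_file(
        "train.py",
        "MY_COMPUTE_POOL",
        stage_name="payload_stage",   # assumed required argument, not shown in this diff
        imports=extra_files,          # replaces the deprecated additional_payloads
        query_warehouse="MY_WAREHOUSE",
    )
    print(job.get_logs(limit=100))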
@@ -83,7 +83,6 @@ class LineageNode(mixins.SerializableSessionMixin):
  raise NotImplementedError()

  @telemetry.send_api_usage_telemetry(project=_PROJECT)
- @snowpark._internal.utils.private_preview(version="1.5.3")
  def lineage(
  self,
  direction: Literal["upstream", "downstream"] = "downstream",
@@ -1,10 +1,6 @@
- from snowflake.ml.model._client.model.batch_inference_specs import (
- InputSpec,
- JobSpec,
- OutputSpec,
- )
+ from snowflake.ml.model._client.model.batch_inference_specs import JobSpec, OutputSpec
  from snowflake.ml.model._client.model.model_impl import Model
  from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
  from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel

- __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
+ __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "JobSpec", "OutputSpec"]
@@ -3,10 +3,6 @@ from typing import Optional, Union
  from pydantic import BaseModel


- class InputSpec(BaseModel):
- stage_location: str
-
-
  class OutputSpec(BaseModel):
  stage_location: str

@@ -0,0 +1,55 @@
+ from typing import Any, Optional, Union
+
+ from snowflake.ml.model._client.ops import service_ops
+
+
+ def _get_inference_engine_args(
+ experimental_options: Optional[dict[str, Any]],
+ ) -> Optional[service_ops.InferenceEngineArgs]:
+
+ if not experimental_options:
+ return None
+
+ if "inference_engine" not in experimental_options:
+ raise ValueError("inference_engine is required in experimental_options")
+
+ return service_ops.InferenceEngineArgs(
+ inference_engine=experimental_options["inference_engine"],
+ inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
+ )
+
+
+ def _enrich_inference_engine_args(
+ inference_engine_args: service_ops.InferenceEngineArgs,
+ gpu_requests: Optional[Union[str, int]] = None,
+ ) -> Optional[service_ops.InferenceEngineArgs]:
+ """Enrich inference engine args with model path and tensor parallelism settings.
+
+ Args:
+ inference_engine_args: The original inference engine args
+ gpu_requests: The number of GPUs requested
+
+ Returns:
+ Enriched inference engine args
+
+ Raises:
+ ValueError: Invalid gpu_requests
+ """
+ if inference_engine_args.inference_engine_args_override is None:
+ inference_engine_args.inference_engine_args_override = []
+
+ gpu_count = None
+
+ # Set tensor-parallelism if gpu_requests is specified
+ if gpu_requests is not None:
+ # assert gpu_requests is a string or an integer before casting to int
+ try:
+ gpu_count = int(gpu_requests)
+ if gpu_count > 0:
+ inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
+ else:
+ raise ValueError(f"GPU count must be greater than 0, got {gpu_count}")
+ except ValueError:
+ raise ValueError(f"Invalid gpu_requests: {gpu_requests} with type {type(gpu_requests).__name__}")
+
+ return inference_engine_args
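
Note on the new file above: these helpers translate the experimental_options passed at deployment time into InferenceEngineArgs for service_ops. A hedged sketch of how they compose; the engine name and extra flag are illustrative values, not options defined by this diff:

    from snowflake.ml.model._client.model import inference_engine_utils

    args = inference_engine_utils._get_inference_engine_args(
        {
            "inference_engine": "vllm",  # illustrative engine name
            "inference_engine_args_override": ["--max-model-len=4096"],  # illustrative flag
        }
    )
    # With four GPUs requested, the override list gains "--tensor-parallel-size=4".
    args = inference_engine_utils._enrich_inference_engine_args(args, gpu_requests=4)
    print(args.inference_engine_args_override)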