snowflake-ml-python 1.14.0__py3-none-any.whl → 1.15.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. snowflake/ml/_internal/platform_capabilities.py +9 -7
  2. snowflake/ml/_internal/utils/connection_params.py +5 -3
  3. snowflake/ml/_internal/utils/jwt_generator.py +3 -2
  4. snowflake/ml/_internal/utils/temp_file_utils.py +1 -2
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +16 -3
  6. snowflake/ml/experiment/_entities/__init__.py +2 -1
  7. snowflake/ml/experiment/_entities/run.py +0 -15
  8. snowflake/ml/experiment/_entities/run_metadata.py +3 -51
  9. snowflake/ml/experiment/experiment_tracking.py +8 -8
  10. snowflake/ml/model/__init__.py +2 -6
  11. snowflake/ml/model/_client/model/batch_inference_specs.py +0 -4
  12. snowflake/ml/model/_client/model/inference_engine_utils.py +55 -0
  13. snowflake/ml/model/_client/model/model_version_impl.py +25 -62
  14. snowflake/ml/model/_client/ops/service_ops.py +18 -2
  15. snowflake/ml/model/_client/sql/service.py +29 -2
  16. snowflake/ml/model/_packager/model_handlers/_utils.py +4 -2
  17. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -5
  18. snowflake/ml/model/_packager/model_packager.py +4 -3
  19. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +0 -1
  20. snowflake/ml/model/_signatures/utils.py +0 -21
  21. snowflake/ml/model/models/huggingface_pipeline.py +56 -21
  22. snowflake/ml/registry/_manager/model_manager.py +1 -1
  23. snowflake/ml/registry/_manager/model_parameter_reconciler.py +2 -2
  24. snowflake/ml/utils/connection_params.py +5 -3
  25. snowflake/ml/version.py +1 -1
  26. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/METADATA +42 -35
  27. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/RECORD +30 -29
  28. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/WHEEL +0 -0
  29. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/licenses/LICENSE.txt +0 -0
  30. {snowflake_ml_python-1.14.0.dist-info → snowflake_ml_python-1.15.0.dist-info}/top_level.txt +0 -0
snowflake/ml/_internal/platform_capabilities.py
@@ -1,8 +1,8 @@
 import json
+import logging
 from contextlib import contextmanager
 from typing import Any, Optional

-from absl import logging
 from packaging import version

 from snowflake.ml import version as snowml_version
@@ -13,6 +13,8 @@ from snowflake.snowpark import (
     session as snowpark_session,
 )

+logger = logging.getLogger(__name__)
+
 LIVE_COMMIT_PARAMETER = "ENABLE_LIVE_VERSION_IN_SDK"
 INLINE_DEPLOYMENT_SPEC_PARAMETER = "ENABLE_INLINE_DEPLOYMENT_SPEC_FROM_CLIENT_VERSION"

@@ -60,12 +62,12 @@ class PlatformCapabilities:
     @classmethod  # type: ignore[arg-type]
     @contextmanager
     def mock_features(cls, features: dict[str, Any] = _dummy_features) -> None:  # type: ignore[misc]
-        logging.debug(f"Setting mock features: {features}")
+        logger.debug(f"Setting mock features: {features}")
         cls.set_mock_features(features)
         try:
             yield
         finally:
-            logging.debug(f"Clearing mock features: {features}")
+            logger.debug(f"Clearing mock features: {features}")
             cls.clear_mock_features()

     def is_inlined_deployment_spec_enabled(self) -> bool:
@@ -98,7 +100,7 @@ class PlatformCapabilities:
                 error_code=error_codes.INTERNAL_SNOWML_ERROR, original_exception=RuntimeError(message)
             )
         except snowpark_exceptions.SnowparkSQLException as e:
-            logging.debug(f"Failed to retrieve platform capabilities: {e}")
+            logger.debug(f"Failed to retrieve platform capabilities: {e}")
             # This can happen if the server side is older than 9.2. That is fine.
             return {}

@@ -144,7 +146,7 @@ class PlatformCapabilities:

         value = self.features.get(feature_name)
         if value is None:
-            logging.debug(f"Feature {feature_name} not found, returning large version number")
+            logger.debug(f"Feature {feature_name} not found, returning large version number")
             return large_version

         try:
@@ -152,7 +154,7 @@ class PlatformCapabilities:
             version_str = str(value)
             return version.Version(version_str)
         except (version.InvalidVersion, ValueError, TypeError) as e:
-            logging.debug(
+            logger.debug(
                 f"Failed to parse version from feature {feature_name} with value '{value}': {e}. "
                 f"Returning large version number"
             )
@@ -171,7 +173,7 @@ class PlatformCapabilities:
         feature_version = self._get_version_feature(feature_name)

         result = current_version >= feature_version
-        logging.debug(
+        logger.debug(
             f"Version comparison for feature {feature_name}: "
             f"current={current_version}, feature={feature_version}, enabled={result}"
         )
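The change running through every hunk in this file (and repeated in the connection and temp-file utilities below) is the move from `from absl import logging` to the standard library, with one named logger per module. A minimal sketch of the resulting pattern; only the logger setup is from the diff, the helper function is hypothetical:

import logging

logger = logging.getLogger(__name__)  # module-scoped, e.g. "snowflake.ml._internal.platform_capabilities"

def lookup(features: dict, name: str) -> object:
    # hypothetical caller: per-module loggers let applications tune verbosity
    # for individual snowflake.ml modules via standard logging configuration
    value = features.get(name)
    if value is None:
        logger.debug("Feature %s not found", name)
    return value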
snowflake/ml/_internal/utils/connection_params.py
@@ -1,11 +1,13 @@
 import configparser
+import logging
 import os
 from typing import Optional, Union

-from absl import logging
 from cryptography.hazmat import backends
 from cryptography.hazmat.primitives import serialization

+logger = logging.getLogger(__name__)
+
 _DEFAULT_CONNECTION_FILE = "~/.snowsql/config"

@@ -106,7 +108,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
     """Loads the dictionary from snowsql config file."""
     snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
     if not os.path.exists(snowsql_config_file):
-        logging.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
+        logger.error(f"Connection name given but snowsql config file is not found at: {snowsql_config_file}")
         raise Exception("Snowflake SnowSQL config not found.")

     config = configparser.ConfigParser(inline_comment_prefixes="#")
@@ -122,7 +124,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
         # See https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
         connection_name = "connections"

-    logging.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
+    logger.info(f"Reading {snowsql_config_file} for connection parameters defined as {connection_name}")
     config.read(snowsql_config_file)
     conn_params = dict(config[connection_name])
     # Remap names to appropriate args in Python Connector API
snowflake/ml/_internal/utils/jwt_generator.py
@@ -110,15 +110,16 @@ class JWTGenerator:
         }

         # Regenerate the actual token
-        token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)
+        token = jwt.encode(payload, key=self.private_key, algorithm=JWTGenerator.ALGORITHM)  # type: ignore[arg-type]
         # If you are using a version of PyJWT prior to 2.0, jwt.encode returns a byte string instead of a string.
         # If the token is a byte string, convert it to a string.
         if isinstance(token, bytes):
             token = token.decode("utf-8")
         self.token = token
+        public_key = self.private_key.public_key()
         logger.info(
             "Generated a JWT with the following payload: %s",
-            jwt.decode(self.token, key=self.private_key.public_key(), algorithms=[JWTGenerator.ALGORITHM]),
+            jwt.decode(self.token, key=public_key, algorithms=[JWTGenerator.ALGORITHM]),  # type: ignore[arg-type]
         )

         return token
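The two new `# type: ignore[arg-type]` hints and the hoisted `public_key` variable are type-checker appeasement only; runtime behavior is the usual PyJWT round trip. A self-contained sketch, assuming RS256 (the algorithm Snowflake key-pair authentication uses) and a throwaway key:

import jwt  # PyJWT
from cryptography.hazmat.primitives.asymmetric import rsa

private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
token = jwt.encode({"iss": "MYORG-MYACCOUNT.MYUSER"}, key=private_key, algorithm="RS256")
# PyJWT >= 2.0 returns str; older versions return bytes, hence the decode("utf-8") branch above
claims = jwt.decode(token, key=private_key.public_key(), algorithms=["RS256"])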
snowflake/ml/_internal/utils/temp_file_utils.py
@@ -1,10 +1,9 @@
+import logging
 import os
 import shutil
 import tempfile
 from typing import Iterable, Union

-from absl.logging import logging
-
 logger = logging.getLogger(__name__)

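A side note on the removed line: `from absl.logging import logging` only ever worked because `absl.logging` itself does `import logging` internally, so the name resolved to the standard library module through absl's namespace. The new direct `import logging` drops that accidental dependency without changing behavior.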
snowflake/ml/experiment/_client/experiment_tracking_sql_client.py
@@ -76,17 +76,30 @@ class ExperimentTrackingSQLClient(_base._BaseSQLClient):
             self._session, f"ALTER EXPERIMENT {experiment_fqn} DROP RUN {run_name}"
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()

-    def modify_run(
+    def modify_run_add_metrics(
         self,
         *,
         experiment_name: sql_identifier.SqlIdentifier,
         run_name: sql_identifier.SqlIdentifier,
-        run_metadata: str,
+        metrics: str,
     ) -> None:
         experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
         query_result_checker.SqlResultValidator(
             self._session,
-            f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} SET METADATA=$${run_metadata}$$",
+            f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD METRICS=$${metrics}$$",
+        ).has_dimensions(expected_rows=1, expected_cols=1).validate()
+
+    def modify_run_add_params(
+        self,
+        *,
+        experiment_name: sql_identifier.SqlIdentifier,
+        run_name: sql_identifier.SqlIdentifier,
+        params: str,
+    ) -> None:
+        experiment_fqn = self.fully_qualified_object_name(self._database_name, self._schema_name, experiment_name)
+        query_result_checker.SqlResultValidator(
+            self._session,
+            f"ALTER EXPERIMENT {experiment_fqn} MODIFY RUN {run_name} ADD PARAMETERS=$${params}$$",
         ).has_dimensions(expected_rows=1, expected_cols=1).validate()

     def put_artifact(
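Where the old `modify_run` rewrote the entire metadata blob with `SET METADATA`, the two new methods append deltas. With illustrative names (an experiment `MY_DB.MY_SCHEMA.MY_EXP` and a run `MY_RUN`), the rendered f-strings look like:

ALTER EXPERIMENT MY_DB.MY_SCHEMA.MY_EXP MODIFY RUN MY_RUN ADD METRICS=$$[{"name": "loss", "value": 0.42, "step": 0}]$$
ALTER EXPERIMENT MY_DB.MY_SCHEMA.MY_EXP MODIFY RUN MY_RUN ADD PARAMETERS=$$[{"name": "learning_rate", "value": "0.001"}]$$

The `$$...$$` dollar-quoting keeps the JSON payload from needing quote escaping.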
snowflake/ml/experiment/_entities/__init__.py
@@ -1,4 +1,5 @@
 from snowflake.ml.experiment._entities.experiment import Experiment
 from snowflake.ml.experiment._entities.run import Run
+from snowflake.ml.experiment._entities.run_metadata import Metric, Param

-__all__ = ["Experiment", "Run"]
+__all__ = ["Experiment", "Run", "Metric", "Param"]
snowflake/ml/experiment/_entities/run.py
@@ -1,11 +1,8 @@
-import json
 import types
 from typing import TYPE_CHECKING, Optional

 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.experiment import _experiment_info as experiment_info
-from snowflake.ml.experiment._client import experiment_tracking_sql_client
-from snowflake.ml.experiment._entities import run_metadata

 if TYPE_CHECKING:
     from snowflake.ml.experiment import experiment_tracking
@@ -41,18 +38,6 @@ class Run:
         if self._experiment_tracking._run is self:
             self._experiment_tracking.end_run()

-    def _get_metadata(
-        self,
-    ) -> run_metadata.RunMetadata:
-        runs = self._experiment_tracking._sql_client.show_runs_in_experiment(
-            experiment_name=self.experiment_name, like=str(self.name)
-        )
-        if not runs:
-            raise RuntimeError(f"Run {self.name} not found in experiment {self.experiment_name}.")
-        return run_metadata.RunMetadata.from_dict(
-            json.loads(runs[0][experiment_tracking_sql_client.ExperimentTrackingSQLClient.RUN_METADATA_COL_NAME])
-        )
-
     def _get_experiment_info(self) -> experiment_info.ExperimentInfo:
         return experiment_info.ExperimentInfo(
             fully_qualified_name=self._experiment_tracking._sql_client.fully_qualified_object_name(
snowflake/ml/experiment/_entities/run_metadata.py
@@ -1,12 +1,4 @@
 import dataclasses
-import enum
-import typing
-
-
-class RunStatus(str, enum.Enum):
-    UNKNOWN = "UNKNOWN"
-    RUNNING = "RUNNING"
-    FINISHED = "FINISHED"


 @dataclasses.dataclass
@@ -15,54 +7,14 @@ class Metric:
     value: float
     step: int

+    def to_dict(self) -> dict:  # type: ignore[type-arg]
+        return dataclasses.asdict(self)
+

 @dataclasses.dataclass
 class Param:
     name: str
     value: str

-
-@dataclasses.dataclass
-class RunMetadata:
-    status: RunStatus
-    metrics: list[Metric]
-    parameters: list[Param]
-
-    @classmethod
-    def from_dict(
-        cls,
-        metadata: dict,  # type: ignore[type-arg]
-    ) -> "RunMetadata":
-        return RunMetadata(
-            status=RunStatus(metadata.get("status", RunStatus.UNKNOWN.value)),
-            metrics=[Metric(**m) for m in metadata.get("metrics", [])],
-            parameters=[Param(**p) for p in metadata.get("parameters", [])],
-        )
-
     def to_dict(self) -> dict:  # type: ignore[type-arg]
         return dataclasses.asdict(self)
-
-    def set_metric(
-        self,
-        key: str,
-        value: float,
-        step: int,
-    ) -> None:
-        for metric in self.metrics:
-            if metric.name == key and metric.step == step:
-                metric.value = value
-                break
-        else:
-            self.metrics.append(Metric(name=key, value=value, step=step))
-
-    def set_param(
-        self,
-        key: str,
-        value: typing.Any,
-    ) -> None:
-        for parameter in self.parameters:
-            if parameter.name == key:
-                parameter.value = str(value)
-                break
-        else:
-            self.parameters.append(Param(name=key, value=str(value)))
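With `RunStatus` and `RunMetadata` gone, what remains are two plain dataclasses whose `to_dict` feeds `json.dumps`. For example:

import json
from snowflake.ml.experiment._entities import Metric, Param  # newly exported above

print(json.dumps([Metric(name="loss", value=0.42, step=0).to_dict(),
                  Param(name="learning_rate", value="0.001").to_dict()]))
# [{"name": "loss", "value": 0.42, "step": 0}, {"name": "learning_rate", "value": "0.001"}]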
snowflake/ml/experiment/experiment_tracking.py
@@ -261,13 +261,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             step: The step of the metrics. Defaults to 0.
         """
         run = self._get_or_start_run()
-        metadata = run._get_metadata()
+        metrics_list = []
         for key, value in metrics.items():
-            metadata.set_metric(key, value, step)
-        self._sql_client.modify_run(
+            metrics_list.append(entities.Metric(key, value, step))
+        self._sql_client.modify_run_add_metrics(
             experiment_name=run.experiment_name,
             run_name=run.name,
-            run_metadata=json.dumps(metadata.to_dict()),
+            metrics=json.dumps([metric.to_dict() for metric in metrics_list]),
         )

     def log_param(
@@ -296,13 +296,13 @@ class ExperimentTracking(mixins.SerializableSessionMixin):
             to string.
         """
         run = self._get_or_start_run()
-        metadata = run._get_metadata()
+        params_list = []
         for key, value in params.items():
-            metadata.set_param(key, value)
-        self._sql_client.modify_run(
+            params_list.append(entities.Param(key, str(value)))
+        self._sql_client.modify_run_add_params(
             experiment_name=run.experiment_name,
             run_name=run.name,
-            run_metadata=json.dumps(metadata.to_dict()),
+            params=json.dumps([param.to_dict() for param in params_list]),
         )

     def log_artifact(
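Taken together with the SQL-client change, metric and parameter logging is now append-only: each call serializes just the new entries instead of reading the run's full metadata back (`run._get_metadata()` is gone) and rewriting it. A sketch of the payload one call builds:

import json
from snowflake.ml.experiment import _entities as entities  # assumed alias; the diff only shows `entities.`

metrics = {"loss": 0.42, "accuracy": 0.91}
step = 5
payload = json.dumps([entities.Metric(k, v, step).to_dict() for k, v in metrics.items()])
# payload becomes the $$...$$ body of ALTER EXPERIMENT ... MODIFY RUN ... ADD METRICS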
snowflake/ml/model/__init__.py
@@ -1,10 +1,6 @@
-from snowflake.ml.model._client.model.batch_inference_specs import (
-    InputSpec,
-    JobSpec,
-    OutputSpec,
-)
+from snowflake.ml.model._client.model.batch_inference_specs import JobSpec, OutputSpec
 from snowflake.ml.model._client.model.model_impl import Model
 from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
 from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel

-__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
+__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "JobSpec", "OutputSpec"]
snowflake/ml/model/_client/model/batch_inference_specs.py
@@ -3,10 +3,6 @@ from typing import Optional, Union
 from pydantic import BaseModel


-class InputSpec(BaseModel):
-    stage_location: str
-
-
 class OutputSpec(BaseModel):
     stage_location: str

snowflake/ml/model/_client/model/inference_engine_utils.py
@@ -0,0 +1,55 @@
+from typing import Any, Optional, Union
+
+from snowflake.ml.model._client.ops import service_ops
+
+
+def _get_inference_engine_args(
+    experimental_options: Optional[dict[str, Any]],
+) -> Optional[service_ops.InferenceEngineArgs]:
+
+    if not experimental_options:
+        return None
+
+    if "inference_engine" not in experimental_options:
+        raise ValueError("inference_engine is required in experimental_options")
+
+    return service_ops.InferenceEngineArgs(
+        inference_engine=experimental_options["inference_engine"],
+        inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
+    )
+
+
+def _enrich_inference_engine_args(
+    inference_engine_args: service_ops.InferenceEngineArgs,
+    gpu_requests: Optional[Union[str, int]] = None,
+) -> Optional[service_ops.InferenceEngineArgs]:
+    """Enrich inference engine args with model path and tensor parallelism settings.
+
+    Args:
+        inference_engine_args: The original inference engine args
+        gpu_requests: The number of GPUs requested
+
+    Returns:
+        Enriched inference engine args
+
+    Raises:
+        ValueError: Invalid gpu_requests
+    """
+    if inference_engine_args.inference_engine_args_override is None:
+        inference_engine_args.inference_engine_args_override = []
+
+    gpu_count = None
+
+    # Set tensor-parallelism if gpu_requests is specified
+    if gpu_requests is not None:
+        # assert gpu_requests is a string or an integer before casting to int
+        try:
+            gpu_count = int(gpu_requests)
+            if gpu_count > 0:
+                inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
+            else:
+                raise ValueError(f"GPU count must be greater than 0, got {gpu_count}")
+        except ValueError:
+            raise ValueError(f"Invalid gpu_requests: {gpu_requests} with type {type(gpu_requests).__name__}")
+
+    return inference_engine_args
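A sketch of how the relocated helpers compose, with "vllm" as an illustrative engine name (this diff does not pin the accepted values) and the vLLM-style --tensor-parallel-size flag shown above:

from snowflake.ml.model._client.model import inference_engine_utils

args = inference_engine_utils._get_inference_engine_args(
    {"inference_engine": "vllm"}  # illustrative; "inference_engine" is the only required key
)
args = inference_engine_utils._enrich_inference_engine_args(args, gpu_requests=4)
assert args.inference_engine_args_override == ["--tensor-parallel-size=4"]

Note that a non-positive count is also reported as "Invalid gpu_requests": the inner ValueError raised when gpu_count <= 0 is caught by the same `except ValueError` and re-raised with the generic message.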
snowflake/ml/model/_client/model/model_version_impl.py
@@ -12,7 +12,10 @@ from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.lineage import lineage_node
 from snowflake.ml.model import task, type_hints
-from snowflake.ml.model._client.model import batch_inference_specs
+from snowflake.ml.model._client.model import (
+    batch_inference_specs,
+    inference_engine_utils,
+)
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -22,6 +25,7 @@ from snowflake.snowpark import Session, async_job, dataframe
 _TELEMETRY_PROJECT = "MLOps"
 _TELEMETRY_SUBPROJECT = "ModelManagement"
 _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
+_BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"


 class ExportMode(enum.Enum):
@@ -553,7 +557,7 @@ class ModelVersion(lineage_node.LineageNode):
         self,
         *,
         compute_pool: str,
-        input_spec: batch_inference_specs.InputSpec,
+        input_spec: dataframe.DataFrame,
         output_spec: batch_inference_specs.OutputSpec,
         job_spec: Optional[batch_inference_specs.JobSpec] = None,
     ) -> jobs.MLJob[Any]:
@@ -569,6 +573,18 @@ class ModelVersion(lineage_node.LineageNode):
         if warehouse is None:
             raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")

+        # use a temporary folder in the output stage to store the intermediate output from the dataframe
+        output_stage_location = output_spec.stage_location
+        if not output_stage_location.endswith("/"):
+            output_stage_location += "/"
+        input_stage_location = f"{output_stage_location}{_BATCH_INFERENCE_TEMPORARY_FOLDER}/"
+
+        try:
+            input_spec.write.copy_into_location(location=input_stage_location, file_format_type="parquet", header=True)
+        # todo: be specific about the type of errors to provide better error messages.
+        except Exception as e:
+            raise RuntimeError(f"Failed to process input_spec: {e}")
+
         if job_spec.job_name is None:
             # Same as the MLJob ID generation logic with a different prefix
             job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
@@ -592,9 +608,9 @@ class ModelVersion(lineage_node.LineageNode):
             job_name=job_name,
             replicas=job_spec.replicas,
             # input and output
-            input_stage_location=input_spec.stage_location,
+            input_stage_location=input_stage_location,
             input_file_pattern="*",
-            output_stage_location=output_spec.stage_location,
+            output_stage_location=output_stage_location,
             completion_filename="_SUCCESS",
             # misc
             statement_params=statement_params,
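The caller-visible change in this method: `input_spec` is now a Snowpark DataFrame rather than a stage pointer, and the client materializes it as Parquet in a `_temporary/` folder nested inside the output stage before launching the job. The path arithmetic, reproduced standalone:

_BATCH_INFERENCE_TEMPORARY_FOLDER = "_temporary"

def staging_locations(output_stage_location: str) -> tuple[str, str]:
    # mirrors the hunk above: normalize the trailing slash, then nest the input folder
    if not output_stage_location.endswith("/"):
        output_stage_location += "/"
    input_stage_location = f"{output_stage_location}{_BATCH_INFERENCE_TEMPORARY_FOLDER}/"
    return input_stage_location, output_stage_location

# OutputSpec(stage_location="@my_db.my_schema.my_stage/results") (illustrative name) yields
# input:  @my_db.my_schema.my_stage/results/_temporary/
# output: @my_db.my_schema.my_stage/results/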
@@ -768,60 +784,6 @@ class ModelVersion(lineage_node.LineageNode):
             version_name=sql_identifier.SqlIdentifier(version),
         )

-    def _get_inference_engine_args(
-        self, experimental_options: Optional[dict[str, Any]]
-    ) -> Optional[service_ops.InferenceEngineArgs]:
-
-        if not experimental_options:
-            return None
-
-        if "inference_engine" not in experimental_options:
-            raise ValueError("inference_engine is required in experimental_options")
-
-        return service_ops.InferenceEngineArgs(
-            inference_engine=experimental_options["inference_engine"],
-            inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
-        )
-
-    def _enrich_inference_engine_args(
-        self,
-        inference_engine_args: service_ops.InferenceEngineArgs,
-        gpu_requests: Optional[Union[str, int]] = None,
-    ) -> Optional[service_ops.InferenceEngineArgs]:
-        """Enrich inference engine args with tensor parallelism settings.
-
-        Args:
-            inference_engine_args: The original inference engine args
-            gpu_requests: The number of GPUs requested
-
-        Returns:
-            Enriched inference engine args
-
-        Raises:
-            ValueError: Invalid gpu_requests
-        """
-        if inference_engine_args.inference_engine_args_override is None:
-            inference_engine_args.inference_engine_args_override = []
-
-        gpu_count = None
-
-        # Set tensor-parallelism if gpu_requests is specified
-        if gpu_requests is not None:
-            # assert gpu_requests is a string or an integer before casting to int
-            if isinstance(gpu_requests, str) or isinstance(gpu_requests, int):
-                try:
-                    gpu_count = int(gpu_requests)
-                except ValueError:
-                    raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
-
-            if gpu_count is not None:
-                if gpu_count > 0:
-                    inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
-                else:
-                    raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
-
-        return inference_engine_args
-
     def _check_huggingface_text_generation_model(
         self,
         statement_params: Optional[dict[str, Any]] = None,
@@ -1101,13 +1063,14 @@ class ModelVersion(lineage_node.LineageNode):
         if experimental_options:
             self._check_huggingface_text_generation_model(statement_params)

-        inference_engine_args: Optional[service_ops.InferenceEngineArgs] = self._get_inference_engine_args(
-            experimental_options
-        )
+        inference_engine_args = inference_engine_utils._get_inference_engine_args(experimental_options)

         # Enrich inference engine args if inference engine is specified
         if inference_engine_args is not None:
-            inference_engine_args = self._enrich_inference_engine_args(inference_engine_args, gpu_requests)
+            inference_engine_args = inference_engine_utils._enrich_inference_engine_args(
+                inference_engine_args,
+                gpu_requests,
+            )

         from snowflake.ml.model import event_handler
         from snowflake.snowpark import exceptions
snowflake/ml/model/_client/ops/service_ops.py
@@ -155,7 +155,8 @@ class ServiceOperator:
             database_name=database_name,
             schema_name=schema_name,
         )
-        if pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled():
+        self._use_inlined_deployment_spec = pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled()
+        if self._use_inlined_deployment_spec:
             self._workspace = None
             self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec()
         else:
@@ -264,7 +265,14 @@ class ServiceOperator:
             self._model_deployment_spec.add_hf_logger_spec(
                 hf_model_name=hf_model_args.hf_model_name,
                 hf_task=hf_model_args.hf_task,
-                hf_token=hf_model_args.hf_token,
+                hf_token=(
+                    # when using inlined deployment spec, we need to use QMARK_RESERVED_TOKEN
+                    # to avoid revealing the token while calling the SYSTEM$DEPLOY_MODEL function
+                    # noop if using file-based deployment spec or token is not provided
+                    service_sql.QMARK_RESERVED_TOKEN
+                    if hf_model_args.hf_token and self._use_inlined_deployment_spec
+                    else hf_model_args.hf_token
+                ),
                 hf_tokenizer=hf_model_args.hf_tokenizer,
                 hf_revision=hf_model_args.hf_revision,
                 hf_trust_remote_code=hf_model_args.hf_trust_remote_code,
@@ -320,6 +328,14 @@ class ServiceOperator:
                 model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
             ),
             model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
+            query_params=(
+                # when using inlined deployment spec, we need to add the token to the query params
+                # to avoid revealing the token while calling the SYSTEM$DEPLOY_MODEL function
+                # noop if using file-based deployment spec or token is not provided
+                [hf_model_args.hf_token]
+                if (self._use_inlined_deployment_spec and hf_model_args and hf_model_args.hf_token)
+                else []
+            ),
             statement_params=statement_params,
         )

snowflake/ml/model/_client/sql/service.py
@@ -1,8 +1,9 @@
+import contextlib
 import dataclasses
 import enum
 import logging
 import textwrap
-from typing import Any, Optional
+from typing import Any, Generator, Optional

 from snowflake import snowpark
 from snowflake.ml._internal.utils import (
@@ -17,6 +18,11 @@ from snowflake.snowpark._internal import utils as snowpark_utils
 logger = logging.getLogger(__name__)

+# Using this token instead of '?' to avoid escaping issues
+# After quotes are escaped, we replace this token with '|| ? ||'
+QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
+QMARK_PARAMETER_TOKEN = "'|| ? ||'"
+

 class ServiceStatus(enum.Enum):
     PENDING = "PENDING"
@@ -70,12 +76,26 @@ class ServiceSQLClient(_base._BaseSQLClient):
     CONTAINER_STATUS = "status"
     MESSAGE = "message"

+    @contextlib.contextmanager
+    def _qmark_paramstyle(self) -> Generator[None, None, None]:
+        """Context manager that temporarily changes paramstyle to qmark and restores original value on exit."""
+        if not hasattr(self._session, "_options"):
+            yield
+        else:
+            original_paramstyle = self._session._options["paramstyle"]
+            try:
+                self._session._options["paramstyle"] = "qmark"
+                yield
+            finally:
+                self._session._options["paramstyle"] = original_paramstyle
+
     def deploy_model(
         self,
         *,
         stage_path: Optional[str] = None,
         model_deployment_spec_yaml_str: Optional[str] = None,
         model_deployment_spec_file_rel_path: Optional[str] = None,
+        query_params: Optional[list[Any]] = None,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> tuple[str, snowpark.AsyncJob]:
         assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
@@ -83,11 +103,18 @@ class ServiceSQLClient(_base._BaseSQLClient):
             model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
                 model_deployment_spec_yaml_str
             )  # type: ignore[no-untyped-call]
+            model_deployment_spec_yaml_str = model_deployment_spec_yaml_str.replace(  # type: ignore[union-attr]
+                QMARK_RESERVED_TOKEN, QMARK_PARAMETER_TOKEN
+            )
             logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
             sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
         else:
             sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
-        async_job = self._session.sql(sql_str).collect(block=False, statement_params=statement_params)
+        with self._qmark_paramstyle():
+            async_job = self._session.sql(
+                sql_str,
+                params=query_params if query_params else None,
+            ).collect(block=False, statement_params=statement_params)
         assert isinstance(async_job, snowpark.AsyncJob)
         return async_job.query_id, async_job
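End to end, the three service-layer changes keep the HF token out of the query text: `service_ops` puts `QMARK_RESERVED_TOKEN` in the spec and passes the real token as `query_params`; `deploy_model` then splices the escaped YAML so the placeholder becomes a `?` bind parameter under qmark paramstyle. A minimal sketch of the string transformation (spec fragment illustrative):

QMARK_RESERVED_TOKEN = "<QMARK_RESERVED_TOKEN>"
QMARK_PARAMETER_TOKEN = "'|| ? ||'"

spec = "hf_token: <QMARK_RESERVED_TOKEN>"
spec = spec.replace(QMARK_RESERVED_TOKEN, QMARK_PARAMETER_TOKEN)
sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{spec}')"
# CALL SYSTEM$DEPLOY_MODEL('hf_token: '|| ? ||'')
# The outer literal is closed, concatenated with the bound parameter, and reopened,
# so the token supplied via params=[hf_token] never appears in the SQL text or query logs.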