snowflake-ml-python 1.8.0__py3-none-any.whl → 1.8.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (40)
  1. snowflake/cortex/_complete.py +44 -10
  2. snowflake/ml/_internal/platform_capabilities.py +39 -3
  3. snowflake/ml/data/data_connector.py +25 -0
  4. snowflake/ml/dataset/dataset_reader.py +5 -1
  5. snowflake/ml/jobs/_utils/constants.py +3 -5
  6. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  7. snowflake/ml/jobs/_utils/payload_utils.py +81 -47
  8. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  9. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +178 -0
  11. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  12. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  13. snowflake/ml/jobs/_utils/spec_utils.py +27 -8
  14. snowflake/ml/jobs/_utils/types.py +6 -0
  15. snowflake/ml/jobs/decorators.py +10 -6
  16. snowflake/ml/jobs/job.py +145 -23
  17. snowflake/ml/jobs/manager.py +79 -12
  18. snowflake/ml/model/_client/ops/model_ops.py +6 -3
  19. snowflake/ml/model/_client/ops/service_ops.py +57 -39
  20. snowflake/ml/model/_client/service/model_deployment_spec.py +7 -4
  21. snowflake/ml/model/_client/sql/service.py +11 -5
  22. snowflake/ml/model/_model_composer/model_composer.py +29 -11
  23. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +1 -2
  24. snowflake/ml/model/_packager/model_env/model_env.py +8 -2
  25. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -4
  26. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +1 -1
  27. snowflake/ml/model/_packager/model_meta/model_meta.py +6 -1
  28. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  29. snowflake/ml/model/_packager/model_packager.py +2 -0
  30. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
  31. snowflake/ml/model/type_hints.py +2 -0
  32. snowflake/ml/modeling/_internal/estimator_utils.py +5 -1
  33. snowflake/ml/registry/_manager/model_manager.py +20 -1
  34. snowflake/ml/registry/registry.py +46 -2
  35. snowflake/ml/version.py +1 -1
  36. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/METADATA +55 -4
  37. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/RECORD +40 -34
  38. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/WHEEL +1 -1
  39. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/licenses/LICENSE.txt +0 -0
  40. {snowflake_ml_python-1.8.0.dist-info → snowflake_ml_python-1.8.2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,16 @@
  import pathlib
  import textwrap
- from typing import Any, Callable, Dict, List, Literal, Optional, Union
+ from typing import (
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ TypeVar,
+ Union,
+ overload,
+ )
  from uuid import uuid4

  import yaml
@@ -16,6 +26,8 @@ from snowflake.snowpark.exceptions import SnowparkSQLException
  _PROJECT = "MLJob"
  JOB_ID_PREFIX = "MLJOB_"

+ T = TypeVar("T")
+

  @snowpark._internal.utils.private_preview(version="1.7.4")
  @telemetry.send_api_usage_telemetry(project=_PROJECT, func_params_to_log=["limit", "scope"])
@@ -59,7 +71,7 @@ def list_jobs(

  @snowpark._internal.utils.private_preview(version="1.7.4")
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
- def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob:
+ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
  """Retrieve a job service from the backend."""
  session = session or get_active_session()

@@ -71,7 +83,8 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob

  try:
  # Validate that job exists by doing a status check
- job = jb.MLJob(job_id, session=session)
+ # FIXME: Retrieve return path
+ job = jb.MLJob[Any](job_id, session=session)
  _ = job.status
  return job
  except SnowparkSQLException as e:
@@ -82,7 +95,7 @@ def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob

  @snowpark._internal.utils.private_preview(version="1.7.4")
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
- def delete_job(job: Union[str, jb.MLJob], session: Optional[snowpark.Session] = None) -> None:
+ def delete_job(job: Union[str, jb.MLJob[Any]], session: Optional[snowpark.Session] = None) -> None:
  """Delete a job service from the backend. Status and logs will be lost."""
  if isinstance(job, jb.MLJob):
  job_id = job.id
@@ -106,8 +119,10 @@ def submit_file(
  external_access_integrations: Optional[List[str]] = None,
  query_warehouse: Optional[str] = None,
  spec_overrides: Optional[Dict[str, Any]] = None,
+ num_instances: Optional[int] = None,
+ enable_metrics: bool = False,
  session: Optional[snowpark.Session] = None,
- ) -> jb.MLJob:
+ ) -> jb.MLJob[None]:
  """
  Submit a Python file as a job to the compute pool.

@@ -121,6 +136,8 @@ def submit_file(
  external_access_integrations: A list of external access integrations.
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
  spec_overrides: Custom service specification overrides to apply.
+ num_instances: The number of instances to use for the job. If none specified, single node job is created.
+ enable_metrics: Whether to enable metrics publishing for the job.
  session: The Snowpark session to use. If none specified, uses active session.

  Returns:
@@ -136,6 +153,8 @@ def submit_file(
  external_access_integrations=external_access_integrations,
  query_warehouse=query_warehouse,
  spec_overrides=spec_overrides,
+ num_instances=num_instances,
+ enable_metrics=enable_metrics,
  session=session,
  )

@@ -154,8 +173,10 @@ def submit_directory(
  external_access_integrations: Optional[List[str]] = None,
  query_warehouse: Optional[str] = None,
  spec_overrides: Optional[Dict[str, Any]] = None,
+ num_instances: Optional[int] = None,
+ enable_metrics: bool = False,
  session: Optional[snowpark.Session] = None,
- ) -> jb.MLJob:
+ ) -> jb.MLJob[None]:
  """
  Submit a directory containing Python script(s) as a job to the compute pool.

@@ -170,6 +191,8 @@ def submit_directory(
  external_access_integrations: A list of external access integrations.
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
  spec_overrides: Custom service specification overrides to apply.
+ num_instances: The number of instances to use for the job. If none specified, single node job is created.
+ enable_metrics: Whether to enable metrics publishing for the job.
  session: The Snowpark session to use. If none specified, uses active session.

  Returns:
@@ -186,10 +209,52 @@ def submit_directory(
  external_access_integrations=external_access_integrations,
  query_warehouse=query_warehouse,
  spec_overrides=spec_overrides,
+ num_instances=num_instances,
+ enable_metrics=enable_metrics,
  session=session,
  )


+ @overload
+ def _submit_job(
+ source: str,
+ compute_pool: str,
+ *,
+ stage_name: str,
+ entrypoint: Optional[str] = None,
+ args: Optional[List[str]] = None,
+ env_vars: Optional[Dict[str, str]] = None,
+ pip_requirements: Optional[List[str]] = None,
+ external_access_integrations: Optional[List[str]] = None,
+ query_warehouse: Optional[str] = None,
+ spec_overrides: Optional[Dict[str, Any]] = None,
+ num_instances: Optional[int] = None,
+ enable_metrics: bool = False,
+ session: Optional[snowpark.Session] = None,
+ ) -> jb.MLJob[None]:
+ ...
+
+
+ @overload
+ def _submit_job(
+ source: Callable[..., T],
+ compute_pool: str,
+ *,
+ stage_name: str,
+ entrypoint: Optional[str] = None,
+ args: Optional[List[str]] = None,
+ env_vars: Optional[Dict[str, str]] = None,
+ pip_requirements: Optional[List[str]] = None,
+ external_access_integrations: Optional[List[str]] = None,
+ query_warehouse: Optional[str] = None,
+ spec_overrides: Optional[Dict[str, Any]] = None,
+ num_instances: Optional[int] = None,
+ enable_metrics: bool = False,
+ session: Optional[snowpark.Session] = None,
+ ) -> jb.MLJob[T]:
+ ...
+
+
  @telemetry.send_api_usage_telemetry(
  project=_PROJECT,
  func_params_to_log=[
@@ -201,7 +266,7 @@ def submit_directory(
  ],
  )
  def _submit_job(
- source: Union[str, Callable[..., Any]],
+ source: Union[str, Callable[..., T]],
  compute_pool: str,
  *,
  stage_name: str,
@@ -212,9 +277,10 @@ def _submit_job(
  external_access_integrations: Optional[List[str]] = None,
  query_warehouse: Optional[str] = None,
  spec_overrides: Optional[Dict[str, Any]] = None,
- session: Optional[snowpark.Session] = None,
  num_instances: Optional[int] = None,
- ) -> jb.MLJob:
+ enable_metrics: bool = False,
+ session: Optional[snowpark.Session] = None,
+ ) -> jb.MLJob[T]:
  """
  Submit a job to the compute pool.

@@ -229,8 +295,9 @@ def _submit_job(
  external_access_integrations: A list of external access integrations.
  query_warehouse: The query warehouse to use. Defaults to session warehouse.
  spec_overrides: Custom service specification overrides to apply.
- session: The Snowpark session to use. If none specified, uses active session.
  num_instances: The number of instances to use for the job. If none specified, single node job is created.
+ enable_metrics: Whether to enable metrics publishing for the job.
+ session: The Snowpark session to use. If none specified, uses active session.

  Returns:
  An object representing the submitted job.
@@ -257,6 +324,7 @@ def _submit_job(
  payload=uploaded_payload,
  args=args,
  num_instances=num_instances,
+ enable_metrics=enable_metrics,
  )
  spec_overrides = spec_utils.generate_spec_overrides(
  environment_vars=env_vars,
@@ -299,5 +367,4 @@ def _submit_job(
  ) from e
  raise

- # TODO: Wrap snowflake.core.service.JobService object
- return jb.MLJob(job_id, session=session)
+ return jb.MLJob(job_id, service_spec=spec, session=session)
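Note: the hunks above add `num_instances` and `enable_metrics` to the job submission API and make `MLJob` generic over the job's return type; the two `_submit_job` overloads let a `Callable` payload produce `MLJob[T]` while a file or directory payload produces `MLJob[None]`. A minimal usage sketch, assuming `submit_file` keeps the same positional layout as `_submit_job` shown above (the script, compute pool, and stage names are placeholders; the jobs API is in private preview):

    from snowflake.ml import jobs

    # Placeholders: "train.py", "MY_COMPUTE_POOL", and "payload_stage" are not
    # real objects; substitute your own script, compute pool, and stage.
    job = jobs.submit_file(
        "train.py",
        "MY_COMPUTE_POOL",
        stage_name="payload_stage",
        num_instances=2,      # request a multi-node job; omit for single node
        enable_metrics=True,  # turn on metrics publishing for the job
    )
    print(job.id, job.status)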
@@ -789,14 +789,17 @@ class ModelOperator:
  version_name: sql_identifier.SqlIdentifier,
  statement_params: Optional[Dict[str, Any]] = None,
  ) -> type_hints.Task:
- model_spec = self._fetch_model_spec(
+ model_version = self._model_client.show_versions(
  database_name=database_name,
  schema_name=schema_name,
  model_name=model_name,
  version_name=version_name,
+ validate_result=True,
  statement_params=statement_params,
- )
- task_val = model_spec.get("task", type_hints.Task.UNKNOWN.value)
+ )[0]
+
+ model_attributes = json.loads(model_version.model_attributes)
+ task_val = model_attributes.get("task", type_hints.Task.UNKNOWN.value)
  return type_hints.Task(task_val)

  def get_functions(
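The task lookup above now reads the `model_attributes` JSON returned by SHOW VERSIONS instead of the packaged model spec. A condensed sketch of the parsing (the row object and helper name are illustrative, not the module's API):

    import json

    from snowflake.ml.model import type_hints

    def task_from_version_row(model_version_row) -> type_hints.Task:
        # model_attributes is a JSON string, e.g. '{"task": "TABULAR_REGRESSION"}'.
        model_attributes = json.loads(model_version_row.model_attributes)
        task_val = model_attributes.get("task", type_hints.Task.UNKNOWN.value)
        return type_hints.Task(task_val)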
@@ -9,7 +9,7 @@ import time
  from typing import Any, Dict, List, Optional, Tuple, Union, cast

  from snowflake import snowpark
- from snowflake.ml._internal import file_utils
+ from snowflake.ml._internal import file_utils, platform_capabilities as pc
  from snowflake.ml._internal.utils import service_logger, sql_identifier
  from snowflake.ml.model._client.service import model_deployment_spec
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
@@ -57,30 +57,30 @@ class ServiceOperator:
  self._session = session
  self._database_name = database_name
  self._schema_name = schema_name
- self._workspace = tempfile.TemporaryDirectory()
  self._service_client = service_sql.ServiceSQLClient(
  session,
  database_name=database_name,
  schema_name=schema_name,
  )
- self._stage_client = stage_sql.StageSQLClient(
- session,
- database_name=database_name,
- schema_name=schema_name,
- )
- self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
- workspace_path=pathlib.Path(self._workspace.name)
- )
+ if pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled():
+ self._workspace = None
+ self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec()
+ else:
+ self._workspace = tempfile.TemporaryDirectory()
+ self._stage_client = stage_sql.StageSQLClient(
+ session,
+ database_name=database_name,
+ schema_name=schema_name,
+ )
+ self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
+ workspace_path=pathlib.Path(self._workspace.name)
+ )

  def __eq__(self, __value: object) -> bool:
  if not isinstance(__value, ServiceOperator):
  return False
  return self._service_client == __value._service_client

- @property
- def workspace_path(self) -> pathlib.Path:
- return pathlib.Path(self._workspace.name)
-
  def create_service(
  self,
  *,
@@ -119,19 +119,21 @@ class ServiceOperator:

  image_repo_database_name = image_repo_database_name or database_name or self._database_name
  image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
- # create a temp stage
- stage_name = sql_identifier.SqlIdentifier(
- snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
- )
- self._stage_client.create_tmp_stage(
- database_name=database_name,
- schema_name=schema_name,
- stage_name=stage_name,
- statement_params=statement_params,
- )
- stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
-
- self._model_deployment_spec.save(
+ if self._workspace:
+ # create a temp stage
+ stage_name = sql_identifier.SqlIdentifier(
+ snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
+ )
+ self._stage_client.create_tmp_stage(
+ database_name=database_name,
+ schema_name=schema_name,
+ stage_name=stage_name,
+ statement_params=statement_params,
+ )
+ stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
+ else:
+ stage_path = None
+ spec_yaml_str_or_path = self._model_deployment_spec.save(
  database_name=database_name,
  schema_name=schema_name,
  model_name=model_name,
@@ -154,26 +156,35 @@ class ServiceOperator:
  force_rebuild=force_rebuild,
  external_access_integrations=build_external_access_integrations,
  )
- file_utils.upload_directory_to_stage(
- self._session,
- local_path=self.workspace_path,
- stage_path=pathlib.PurePosixPath(stage_path),
- statement_params=statement_params,
- )
+ if self._workspace:
+ assert stage_path is not None
+ file_utils.upload_directory_to_stage(
+ self._session,
+ local_path=pathlib.Path(self._workspace.name),
+ stage_path=pathlib.PurePosixPath(stage_path),
+ statement_params=statement_params,
+ )

- # check if the inference service is already running
+ # check if the inference service is already running/suspended
  model_inference_service_exists = self._check_if_service_exists(
  database_name=service_database_name,
  schema_name=service_schema_name,
  service_name=service_name,
- service_status_list_if_exists=[service_sql.ServiceStatus.READY],
+ service_status_list_if_exists=[
+ service_sql.ServiceStatus.READY,
+ service_sql.ServiceStatus.SUSPENDING,
+ service_sql.ServiceStatus.SUSPENDED,
+ ],
  statement_params=statement_params,
  )

  # deploy the model service
  query_id, async_job = self._service_client.deploy_model(
- stage_path=stage_path,
- model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH,
+ stage_path=stage_path if self._workspace else None,
+ model_deployment_spec_file_rel_path=(
+ model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
+ ),
+ model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
  statement_params=statement_params,
  )

@@ -309,7 +320,10 @@ class ServiceOperator:
  set_service_log_metadata_to_model_inference(
  service_log_meta,
  model_inference_service,
- "Model Inference image build is not rebuilding the image and using previously built image.",
+ (
+ "Model Inference image build is not rebuilding the image, but using a previously built "
+ "image."
+ ),
  )
  continue

@@ -366,7 +380,9 @@ class ServiceOperator:
  time.sleep(5)

  if model_inference_service_exists:
- module_logger.info(f"Inference service {model_inference_service.display_service_name} is already RUNNING.")
+ module_logger.info(
+ f"Inference service {model_inference_service.display_service_name} has already been deployed."
+ )
  else:
  self._finalize_logs(
  service_log_meta.service_logger, service_log_meta.service, service_log_meta.log_offset, statement_params
@@ -416,6 +432,8 @@ class ServiceOperator:
  service_status_list_if_exists = [
  service_sql.ServiceStatus.PENDING,
  service_sql.ServiceStatus.READY,
+ service_sql.ServiceStatus.SUSPENDING,
+ service_sql.ServiceStatus.SUSPENDED,
  service_sql.ServiceStatus.DONE,
  service_sql.ServiceStatus.FAILED,
  ]
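The constructor and create_service changes above hinge on a single capability check: when the inlined-deployment-spec capability is enabled, ServiceOperator skips the local workspace and the temp stage entirely and passes the spec inline. A condensed sketch of that gate, taken from the hunks above (not a standalone API):

    import pathlib
    import tempfile

    from snowflake.ml._internal import platform_capabilities as pc
    from snowflake.ml.model._client.service import model_deployment_spec

    # Inline mode: no local workspace and no temp stage; the spec is later passed
    # to SYSTEM$DEPLOY_MODEL as a YAML string. Otherwise keep the old
    # write-to-workspace-and-upload-to-stage path.
    if pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled():
        workspace = None
        spec = model_deployment_spec.ModelDeploymentSpec()
    else:
        workspace = tempfile.TemporaryDirectory()
        spec = model_deployment_spec.ModelDeploymentSpec(
            workspace_path=pathlib.Path(workspace.name)
        )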
@@ -16,7 +16,7 @@ class ModelDeploymentSpec:

  DEPLOY_SPEC_FILE_REL_PATH = "deploy.yml"

- def __init__(self, workspace_path: pathlib.Path) -> None:
+ def __init__(self, workspace_path: Optional[pathlib.Path] = None) -> None:
  self.workspace_path = workspace_path

  def save(
@@ -43,7 +43,7 @@ class ModelDeploymentSpec:
  max_batch_rows: Optional[int],
  force_rebuild: bool,
  external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
- ) -> None:
+ ) -> str:
  # create the deployment spec
  # models spec
  fq_model_name = identifier.get_schema_level_object_identifier(
@@ -105,9 +105,12 @@ class ModelDeploymentSpec:
  service=service_dict,
  )

+ # Anchors are not supported in the server, avoid that.
+ yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
+ if self.workspace_path is None:
+ return yaml.safe_dump(model_deployment_spec_dict)
  # save the yaml
  file_path = self.workspace_path / self.DEPLOY_SPEC_FILE_REL_PATH
  with file_path.open("w", encoding="utf-8") as f:
- # Anchors are not supported in the server, avoid that.
- yaml.SafeDumper.ignore_aliases = lambda *args: True # type: ignore[method-assign]
  yaml.safe_dump(model_deployment_spec_dict, f)
+ return str(file_path.resolve())
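save() now returns either the spec as a YAML string (no workspace_path) or the absolute path of the written deploy.yml, and the anchor-avoidance line moved up so it applies to both branches. A small self-contained illustration of why that line is needed (generic PyYAML behavior, not this module's API):

    import yaml

    # The server does not accept YAML anchors/aliases, so aliasing is disabled
    # before dumping the deployment spec.
    yaml.SafeDumper.ignore_aliases = lambda *args: True  # type: ignore[method-assign]

    shared = {"num_workers": 1}
    doc = {"build": shared, "serve": shared}

    # With no stream argument, safe_dump returns the YAML text -- the same form
    # save() returns in inline mode.
    print(yaml.safe_dump(doc))  # no "&id001"/"*id001" aliases in the output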
@@ -20,6 +20,8 @@ class ServiceStatus(enum.Enum):
  UNKNOWN = "UNKNOWN" # status is unknown because we have not received enough data from K8s yet.
  PENDING = "PENDING" # resource set is being created, can't be used yet
  READY = "READY" # resource set has been deployed.
+ SUSPENDING = "SUSPENDING" # the service is set to suspended but the resource set is still in deleting state
+ SUSPENDED = "SUSPENDED" # the service is suspended and the resource set is deleted
  DELETING = "DELETING" # resource set is being deleted
  FAILED = "FAILED" # resource set has failed and cannot be used anymore
  DONE = "DONE" # resource set has finished running
@@ -71,13 +73,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
  def deploy_model(
  self,
  *,
- stage_path: str,
- model_deployment_spec_file_rel_path: str,
+ stage_path: Optional[str] = None,
+ model_deployment_spec_yaml_str: Optional[str] = None,
+ model_deployment_spec_file_rel_path: Optional[str] = None,
  statement_params: Optional[Dict[str, Any]] = None,
  ) -> Tuple[str, snowpark.AsyncJob]:
- async_job = self._session.sql(
- f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
- ).collect(block=False, statement_params=statement_params)
+ assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
+ if model_deployment_spec_yaml_str:
+ sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
+ else:
+ sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
+ async_job = self._session.sql(sql_str).collect(block=False, statement_params=statement_params)
  assert isinstance(async_job, snowpark.AsyncJob)
  return async_job.query_id, async_job
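deploy_model now accepts either an inline YAML string or a stage path plus spec file name. A sketch of the statement construction, mirroring the branch above (quoting/escaping of the YAML payload is not addressed here; the helper name is illustrative):

    from typing import Optional

    def build_deploy_model_sql(
        stage_path: Optional[str] = None,
        model_deployment_spec_yaml_str: Optional[str] = None,
        model_deployment_spec_file_rel_path: Optional[str] = None,
    ) -> str:
        # Prefer the inline YAML string; otherwise reference the spec file
        # that was uploaded to a stage.
        assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
        if model_deployment_spec_yaml_str:
            return f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
        return f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"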
@@ -44,6 +44,7 @@ class ModelComposer:
  stage_path: str,
  *,
  statement_params: Optional[Dict[str, Any]] = None,
+ save_location: Optional[str] = None,
  ) -> None:
  self.session = session
  self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
@@ -54,10 +55,29 @@ class ModelComposer:
  # The stage path is a user stage path
  self.stage_path = pathlib.PurePosixPath(stage_path)

- self._workspace = tempfile.TemporaryDirectory()
- self._packager_workspace = tempfile.TemporaryDirectory()
+ # Set up workspace based on save_location if provided, otherwise use temporary directory
+ self.save_location = save_location
+ if save_location:
+ # Use the save_location directory directly
+ self._workspace_path = pathlib.Path(save_location)
+ self._workspace_path.mkdir(exist_ok=True)
+ # ensure that the directory is empty
+ if any(self._workspace_path.iterdir()):
+ raise ValueError(f"The directory {self._workspace_path} is not empty.")
+ self._workspace = None
+
+ self._packager_workspace_path = self._workspace_path / ModelComposer.MODEL_DIR_REL_PATH
+ self._packager_workspace_path.mkdir(exist_ok=True)
+ self._packager_workspace = None
+ else:
+ # Use a temporary directory
+ self._workspace = tempfile.TemporaryDirectory()
+ self._workspace_path = pathlib.Path(self._workspace.name)
+
+ self._packager_workspace_path = self._workspace_path / ModelComposer.MODEL_DIR_REL_PATH
+ self._packager_workspace_path.mkdir(exist_ok=True)

- self.packager = model_packager.ModelPackager(local_dir_path=str(self._packager_workspace_path))
+ self.packager = model_packager.ModelPackager(local_dir_path=str(self.packager_workspace_path))
  self.manifest = model_manifest.ModelManifest(workspace_path=self.workspace_path)

  self.model_file_rel_path = f"model-{uuid.uuid4().hex}.zip"
@@ -65,16 +85,16 @@ class ModelComposer:
  self._statement_params = statement_params

  def __del__(self) -> None:
- self._workspace.cleanup()
- self._packager_workspace.cleanup()
+ if self._workspace:
+ self._workspace.cleanup()

  @property
  def workspace_path(self) -> pathlib.Path:
- return pathlib.Path(self._workspace.name)
+ return self._workspace_path

  @property
- def _packager_workspace_path(self) -> pathlib.Path:
- return pathlib.Path(self._packager_workspace.name)
+ def packager_workspace_path(self) -> pathlib.Path:
+ return self._packager_workspace_path

  @property
  def model_stage_path(self) -> str:
@@ -167,6 +187,7 @@ class ModelComposer:
  conda_dependencies=conda_dependencies,
  pip_requirements=pip_requirements,
  artifact_repository_map=artifact_repository_map,
+ target_platforms=target_platforms,
  python_version=python_version,
  ext_modules=ext_modules,
  code_paths=code_paths,
@@ -175,9 +196,6 @@ class ModelComposer:
  )
  assert self.packager.meta is not None

- file_utils.copytree(
- str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
- )
  self.manifest.save(
  model_meta=self.packager.meta,
  model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
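ModelComposer can now write its artifacts to a caller-supplied save_location instead of a TemporaryDirectory; the directory is created if missing and must be empty. A condensed sketch of the workspace selection (the helper name is illustrative):

    import pathlib
    import tempfile
    from typing import Optional, Tuple

    def resolve_workspace(
        save_location: Optional[str] = None,
    ) -> Tuple[Optional[tempfile.TemporaryDirectory], pathlib.Path]:
        if save_location:
            # Use the caller's directory directly; it must be empty (or not yet exist).
            workspace_path = pathlib.Path(save_location)
            workspace_path.mkdir(exist_ok=True)
            if any(workspace_path.iterdir()):
                raise ValueError(f"The directory {workspace_path} is not empty.")
            return None, workspace_path
        # Default: a temporary directory that is cleaned up in __del__.
        workspace = tempfile.TemporaryDirectory()
        return workspace, pathlib.Path(workspace.name)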
@@ -36,7 +36,6 @@ class ModelManifest:
  """

  MANIFEST_FILE_REL_PATH = "MANIFEST.yml"
- _ENABLE_USER_FILES = False
  _DEFAULT_RUNTIME_NAME = "python_runtime"

  def __init__(self, workspace_path: pathlib.Path) -> None:
@@ -149,7 +148,7 @@ class ModelManifest:
  ],
  )

- if self._ENABLE_USER_FILES:
+ if self.user_files:
  manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]

  lineage_sources = self._extract_lineage_info(data_sources)
@@ -29,11 +29,13 @@ class ModelEnv:
  self,
  conda_env_rel_path: Optional[str] = None,
  pip_requirements_rel_path: Optional[str] = None,
+ prefer_pip: bool = False,
  ) -> None:
  if conda_env_rel_path is None:
  conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
  if pip_requirements_rel_path is None:
  pip_requirements_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_PIP_REQUIREMENTS_FILENAME)
+ self.prefer_pip: bool = prefer_pip
  self.conda_env_rel_path = pathlib.PurePosixPath(pathlib.Path(conda_env_rel_path).as_posix())
  self.pip_requirements_rel_path = pathlib.PurePosixPath(pathlib.Path(pip_requirements_rel_path).as_posix())
  self.artifact_repository_map: Optional[Dict[str, str]] = None
@@ -113,7 +115,11 @@ class ModelEnv:
  if snowpark_ml_version:
  self._snowpark_ml_version = version.parse(snowpark_ml_version)

- def include_if_absent(self, pkgs: List[ModelDependency], check_local_version: bool = False) -> None:
+ def include_if_absent(
+ self,
+ pkgs: List[ModelDependency],
+ check_local_version: bool = False,
+ ) -> None:
  """Append requirements into model env if absent. Depending on the environment, requirements may be added
  to either the pip requirements or conda dependencies.

@@ -121,7 +127,7 @@ class ModelEnv:
  pkgs: A list of ModelDependency namedtuple to be appended.
  check_local_version: Flag to indicate if it is required to pin to local version. Defaults to False.
  """
- if self.pip_requirements and not self.conda_dependencies and pkgs:
+ if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
  pip_pkg_reqs: List[str] = []
  warnings.warn(
  (
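With prefer_pip, include_if_absent routes new requirements to the pip list even before any pip requirement exists, as long as there are no conda dependencies. The updated guard, condensed into a standalone check:

    def routes_to_pip(pip_requirements, conda_dependencies, prefer_pip, pkgs) -> bool:
        # Mirrors the condition in include_if_absent above.
        return bool((pip_requirements or prefer_pip) and not conda_dependencies and pkgs)

    assert routes_to_pip([], [], prefer_pip=True, pkgs=["xgboost"])
    assert not routes_to_pip([], ["xgboost==2.0"], prefer_pip=True, pkgs=["xgboost"])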
@@ -57,6 +57,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
  "predict_proba",
  "predict_log_proba",
  "decision_function",
+ "score_samples",
  ]
  EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]

@@ -74,10 +75,6 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
  and (
  not type_utils.LazyType("lightgbm.LGBMModel").isinstance(model)
  ) # LGBMModel is actually a BaseEstimator
- and any(
- (hasattr(model, method) and callable(getattr(model, method, None)))
- for method in cls.DEFAULT_TARGET_METHODS
- )
  )

  @classmethod
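Adding "score_samples" to the default target methods matters for estimators such as anomaly detectors. An illustration (not the handler's actual selection code) of which candidate methods a fitted IsolationForest exposes:

    from sklearn.ensemble import IsolationForest

    candidate_methods = ["predict", "predict_proba", "decision_function", "score_samples"]

    model = IsolationForest().fit([[0.0], [0.1], [0.2]])
    available = [m for m in candidate_methods if callable(getattr(model, m, None))]
    print(available)  # ['predict', 'decision_function', 'score_samples']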
@@ -1 +1 @@
- REQUIREMENTS = ['cloudpickle>=2.0.0']
+ REQUIREMENTS = ['cloudpickle>=2.0.0,<3']
@@ -49,6 +49,7 @@ def create_model_metadata(
  conda_dependencies: Optional[List[str]] = None,
  pip_requirements: Optional[List[str]] = None,
  artifact_repository_map: Optional[Dict[str, str]] = None,
+ target_platforms: Optional[List[model_types.TargetPlatform]] = None,
  python_version: Optional[str] = None,
  task: model_types.Task = model_types.Task.UNKNOWN,
  **kwargs: Any,
@@ -69,6 +70,7 @@
  conda_dependencies: List of conda requirements for running the model. Defaults to None.
  pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
  artifact_repository_map: A dict mapping from package channel to artifact repository name.
+ target_platforms: List of target platforms to run the model.
  python_version: A string of python version where model is run. Used for user override. If specified as None,
  current version would be captured. Defaults to None.
  task: The task of the Model Version. It is an enum class Task with values TABULAR_REGRESSION,
@@ -101,12 +103,14 @@
  else:
  raise ValueError("`snowflake.ml` is imported via a way that embedding local ML library is not supported.")

+ prefer_pip = target_platforms == [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
  env = _create_env_for_model_metadata(
  conda_dependencies=conda_dependencies,
  pip_requirements=pip_requirements,
  artifact_repository_map=artifact_repository_map,
  python_version=python_version,
  embed_local_ml_library=embed_local_ml_library,
+ prefer_pip=prefer_pip,
  )

  if embed_local_ml_library:
@@ -157,8 +161,9 @@ def _create_env_for_model_metadata(
  artifact_repository_map: Optional[Dict[str, str]] = None,
  python_version: Optional[str] = None,
  embed_local_ml_library: bool = False,
+ prefer_pip: bool = False,
  ) -> model_env.ModelEnv:
- env = model_env.ModelEnv()
+ env = model_env.ModelEnv(prefer_pip=prefer_pip)

  # Mypy doesn't like getter and setter have different types. See python/mypy #3004
  env.conda_dependencies = conda_dependencies # type: ignore[assignment]
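The pip preference is derived solely from the declared target platforms. Condensed from the hunk above:

    from snowflake.ml.model import type_hints as model_types

    def should_prefer_pip(target_platforms) -> bool:
        # Pip is preferred only when the model targets Snowpark Container
        # Services exclusively (no warehouse target).
        return target_platforms == [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]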
@@ -82,6 +82,7 @@ class SentenceTransformersModelBlobOptions(BaseModelBlobOptions):

  ModelBlobOptions = Union[
  BaseModelBlobOptions,
+ CatBoostModelBlobOptions,
  HuggingFacePipelineModelBlobOptions,
  MLFlowModelBlobOptions,
  XgboostModelBlobOptions,