snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. snowflake/ml/_internal/telemetry.py +6 -9
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/_internal/utils/identifier.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +61 -0
  5. snowflake/ml/jobs/__init__.py +2 -0
  6. snowflake/ml/jobs/_utils/constants.py +3 -2
  7. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  8. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  9. snowflake/ml/jobs/_utils/payload_utils.py +89 -40
  10. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  11. snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
  13. snowflake/ml/jobs/_utils/spec_utils.py +29 -5
  14. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  15. snowflake/ml/jobs/_utils/types.py +5 -1
  16. snowflake/ml/jobs/decorators.py +20 -28
  17. snowflake/ml/jobs/job.py +197 -61
  18. snowflake/ml/jobs/manager.py +253 -121
  19. snowflake/ml/model/_client/model/model_impl.py +58 -0
  20. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  21. snowflake/ml/model/_client/ops/model_ops.py +18 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +23 -6
  23. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  24. snowflake/ml/model/_client/sql/service.py +68 -20
  25. snowflake/ml/model/_client/sql/stage.py +5 -2
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  27. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  28. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  31. snowflake/ml/model/_signatures/core.py +24 -0
  32. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  33. snowflake/ml/model/target_platform.py +11 -0
  34. snowflake/ml/model/task.py +9 -0
  35. snowflake/ml/model/type_hints.py +5 -13
  36. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  37. snowflake/ml/monitoring/explain_visualize.py +2 -2
  38. snowflake/ml/monitoring/model_monitor.py +0 -4
  39. snowflake/ml/registry/_manager/model_manager.py +30 -15
  40. snowflake/ml/registry/registry.py +144 -47
  41. snowflake/ml/utils/connection_params.py +1 -1
  42. snowflake/ml/utils/html_utils.py +263 -0
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
  45. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
  46. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  47. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  48. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
@@ -38,6 +38,96 @@ class ModelVersion(lineage_node.LineageNode):
38
38
    def __init__(self) -> None:
        """Disallow direct construction of ModelVersion instances.

        Raises:
            RuntimeError: Always; obtain instances via `version` on a model object instead.
        """
        raise RuntimeError("ModelVersion's initializer is not meant to be used. Use `version` from model instead.")
40
40
 
41
+ def _repr_html_(self) -> str:
42
+ """Generate an HTML representation of the model version.
43
+
44
+ Returns:
45
+ str: HTML string containing formatted model version details.
46
+ """
47
+ from snowflake.ml.utils import html_utils
48
+
49
+ # Get task
50
+ try:
51
+ task = self.get_model_task().value
52
+ except Exception:
53
+ task = (
54
+ html_utils.create_error_message("Not available")
55
+ .replace('<em style="color: #888; font-style: italic;">', "")
56
+ .replace("</em>", "")
57
+ )
58
+
59
+ # Get functions info for display
60
+ try:
61
+ functions = self.show_functions()
62
+ if not functions:
63
+ functions_html = html_utils.create_error_message("No functions available")
64
+ else:
65
+ functions_list = []
66
+ for func in functions:
67
+ try:
68
+ sig_html = func["signature"]._repr_html_()
69
+ except Exception:
70
+ # Fallback to simple display if can't display signature
71
+ sig_html = f"<pre style='margin: 5px 0;'>{func['signature']}</pre>"
72
+
73
+ function_content = f"""
74
+ <div style="margin: 5px 0;">
75
+ <strong>Target Method:</strong> {func['target_method']}
76
+ </div>
77
+ <div style="margin: 5px 0;">
78
+ <strong>Function Type:</strong> {func.get('target_method_function_type', 'N/A')}
79
+ </div>
80
+ <div style="margin: 5px 0;">
81
+ <strong>Partitioned:</strong> {func.get('is_partitioned', False)}
82
+ </div>
83
+ <div style="margin: 10px 0;">
84
+ <strong>Signature:</strong>
85
+ {sig_html}
86
+ </div>
87
+ """
88
+
89
+ functions_list.append(
90
+ html_utils.create_collapsible_section(
91
+ title=func["name"], content=function_content, open_by_default=False
92
+ )
93
+ )
94
+ functions_html = "".join(functions_list)
95
+ except Exception:
96
+ functions_html = html_utils.create_error_message("Error retrieving functions")
97
+
98
+ # Get metrics for display
99
+ try:
100
+ metrics = self.show_metrics()
101
+ if not metrics:
102
+ metrics_html = html_utils.create_error_message("No metrics available")
103
+ else:
104
+ metrics_html = ""
105
+ for metric_name, value in metrics.items():
106
+ metrics_html += html_utils.create_metric_item(metric_name, value)
107
+ except Exception:
108
+ metrics_html = html_utils.create_error_message("Error retrieving metrics")
109
+
110
+ # Create main content sections
111
+ main_info = html_utils.create_grid_section(
112
+ [
113
+ ("Model Name", self.model_name),
114
+ ("Version", f'<strong style="color: #28a745;">{self.version_name}</strong>'),
115
+ ("Full Name", self.fully_qualified_model_name),
116
+ ("Description", self.description),
117
+ ("Task", task),
118
+ ]
119
+ )
120
+
121
+ functions_section = html_utils.create_section_header("Functions") + html_utils.create_content_section(
122
+ functions_html
123
+ )
124
+
125
+ metrics_section = html_utils.create_section_header("Metrics") + html_utils.create_content_section(metrics_html)
126
+
127
+ content = main_info + functions_section + metrics_section
128
+
129
+ return html_utils.create_base_container("Model Version Details", content)
130
+
41
131
  @classmethod
42
132
  def _ref(
43
133
  cls,
@@ -643,14 +643,17 @@ class ModelOperator:
643
643
  # TODO(sdas): Figure out a better way to filter out MODEL_BUILD_ services server side.
644
644
  fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
645
645
 
646
- result = []
647
-
646
+ result: list[ServiceInfo] = []
648
647
  for fully_qualified_service_name in fully_qualified_service_names:
649
648
  ingress_url: Optional[str] = None
650
649
  db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
651
- service_status, _ = self._service_client.get_service_status(
650
+ statuses = self._service_client.get_service_container_statuses(
652
651
  database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
653
652
  )
653
+ if len(statuses) == 0:
654
+ return result
655
+
656
+ service_status = statuses[0].service_status
654
657
  for res_row in self._service_client.show_endpoints(
655
658
  database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
656
659
  ):
@@ -952,7 +955,7 @@ class ModelOperator:
952
955
  output_with_input_features = False
953
956
  df = model_signature._convert_and_validate_local_data(X, signature.inputs, strict=strict_input_validation)
954
957
  s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
955
- self._session, df, keep_order=keep_order, features=signature.inputs
958
+ self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
956
959
  )
957
960
  else:
958
961
  keep_order = False
@@ -966,9 +969,16 @@ class ModelOperator:
966
969
 
967
970
  # Compose input and output names
968
971
  input_args = []
972
+ quoted_identifiers_ignore_case = (
973
+ snowpark_handler.SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
974
+ self._session, statement_params
975
+ )
976
+ )
977
+
969
978
  for input_feature in signature.inputs:
970
979
  col_name = identifier_rule.get_sql_identifier_from_feature(input_feature.name)
971
-
980
+ if quoted_identifiers_ignore_case:
981
+ col_name = sql_identifier.SqlIdentifier(input_feature.name.upper(), case_sensitive=True)
972
982
  input_args.append(col_name)
973
983
 
974
984
  returns = []
@@ -1048,7 +1058,9 @@ class ModelOperator:
1048
1058
 
1049
1059
  # Get final result
1050
1060
  if not isinstance(X, dataframe.DataFrame):
1051
- return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
1061
+ return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
1062
+ df_res, features=signature.outputs, statement_params=statement_params
1063
+ )
1052
1064
  else:
1053
1065
  return df_res
1054
1066
 
@@ -325,13 +325,14 @@ class ServiceOperator:
325
325
  )
326
326
  continue
327
327
 
328
- service_status, message = self._service_client.get_service_status(
328
+ statuses = self._service_client.get_service_container_statuses(
329
329
  database_name=service_log_meta.service.database_name,
330
330
  schema_name=service_log_meta.service.schema_name,
331
331
  service_name=service_log_meta.service.service_name,
332
332
  include_message=True,
333
333
  statement_params=statement_params,
334
334
  )
335
+ service_status = statuses[0].service_status
335
336
  if (service_status != service_sql.ServiceStatus.RUNNING) or (
336
337
  service_status != service_log_meta.service_status
337
338
  ):
@@ -341,7 +342,19 @@ class ServiceOperator:
341
342
  f"{service_log_meta.service.display_service_name} is "
342
343
  f"{service_log_meta.service_status.value}."
343
344
  )
344
- module_logger.info(f"Service message: {message}")
345
+ for status in statuses:
346
+ if status.instance_id is not None:
347
+ instance_status, container_status = None, None
348
+ if status.instance_status is not None:
349
+ instance_status = status.instance_status.value
350
+ if status.container_status is not None:
351
+ container_status = status.container_status.value
352
+ module_logger.info(
353
+ f"Instance[{status.instance_id}]: "
354
+ f"instance status: {instance_status}, "
355
+ f"container status: {container_status}, "
356
+ f"message: {status.message}"
357
+ )
345
358
 
346
359
  new_logs, new_offset = fetch_logs(
347
360
  service_log_meta.service,
@@ -353,13 +366,14 @@ class ServiceOperator:
353
366
 
354
367
  # check if model build service is done
355
368
  if not service_log_meta.is_model_build_service_done:
356
- service_status, _ = self._service_client.get_service_status(
369
+ statuses = self._service_client.get_service_container_statuses(
357
370
  database_name=model_build_service.database_name,
358
371
  schema_name=model_build_service.schema_name,
359
372
  service_name=model_build_service.service_name,
360
373
  include_message=False,
361
374
  statement_params=statement_params,
362
375
  )
376
+ service_status = statuses[0].service_status
363
377
 
364
378
  if service_status == service_sql.ServiceStatus.DONE:
365
379
  set_service_log_metadata_to_model_inference(
@@ -436,13 +450,14 @@ class ServiceOperator:
436
450
  service_sql.ServiceStatus.FAILED,
437
451
  ]
438
452
  try:
439
- service_status, _ = self._service_client.get_service_status(
453
+ statuses = self._service_client.get_service_container_statuses(
440
454
  database_name=database_name,
441
455
  schema_name=schema_name,
442
456
  service_name=service_name,
443
457
  include_message=False,
444
458
  statement_params=statement_params,
445
459
  )
460
+ service_status = statuses[0].service_status
446
461
  return any(service_status == status for status in service_status_list_if_exists)
447
462
  except exceptions.SnowparkSQLException:
448
463
  return False
@@ -503,7 +518,7 @@ class ServiceOperator:
503
518
  output_with_input_features = False
504
519
  df = model_signature._convert_and_validate_local_data(X, signature.inputs)
505
520
  s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
506
- self._session, df, keep_order=keep_order, features=signature.inputs
521
+ self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
507
522
  )
508
523
  else:
509
524
  keep_order = False
@@ -615,7 +630,9 @@ class ServiceOperator:
615
630
 
616
631
  # get final result
617
632
  if not isinstance(X, dataframe.DataFrame):
618
- return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
633
+ return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
634
+ df_res, features=signature.outputs, statement_params=statement_params
635
+ )
619
636
  else:
620
637
  return df_res
621
638
 
@@ -2,6 +2,8 @@ from typing import Optional
2
2
 
3
3
  from pydantic import BaseModel
4
4
 
5
+ BaseModel.model_config["protected_namespaces"] = ()
6
+
5
7
 
6
8
  class Model(BaseModel):
7
9
  name: str
@@ -1,4 +1,6 @@
1
+ import dataclasses
1
2
  import enum
3
+ import logging
2
4
  import textwrap
3
5
  from typing import Any, Optional, Union
4
6
 
@@ -13,26 +15,59 @@ from snowflake.ml.model._model_composer.model_method import constants
13
15
  from snowflake.snowpark import dataframe, functions as F, row, types as spt
14
16
  from snowflake.snowpark._internal import utils as snowpark_utils
15
17
 
18
+ logger = logging.getLogger(__name__)
19
+
16
20
 
17
- # The enum comes from https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service#output
18
- # except UNKNOWN
19
21
class ServiceStatus(enum.Enum):
    """Service-level status values, as returned in the `service_status` column
    of SHOW SERVICE CONTAINERS IN SERVICE (see get_service_container_statuses)."""

    PENDING = "PENDING"  # resource set is being created, can't be used yet
    RUNNING = "RUNNING"
    FAILED = "FAILED"  # resource set has failed and cannot be used anymore
    DONE = "DONE"  # resource set has finished running
    SUSPENDING = "SUSPENDING"  # service is set to suspended but the resource set is still deleting
    SUSPENDED = "SUSPENDED"  # service is suspended and the resource set is deleted
    DELETING = "DELETING"  # resource set is being deleted
    DELETED = "DELETED"
    INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error
31
+
32
+
33
class InstanceStatus(enum.Enum):
    """Per-instance status values, as returned in the `instance_status` column
    of SHOW SERVICE CONTAINERS IN SERVICE."""

    PENDING = "PENDING"
    READY = "READY"
    FAILED = "FAILED"
    TERMINATING = "TERMINATING"
    SUCCEEDED = "SUCCEEDED"
39
+
40
+
41
class ContainerStatus(enum.Enum):
    """Per-container status values, as returned in the `status` column
    of SHOW SERVICE CONTAINERS IN SERVICE."""

    PENDING = "PENDING"
    READY = "READY"
    DONE = "DONE"
    FAILED = "FAILED"
    UNKNOWN = "UNKNOWN"
47
+
48
+
49
@dataclasses.dataclass
class ServiceStatusInfo:
    """
    Class containing information about service container status.
    Reference: https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service
    """

    # Overall service status (always present in the query result).
    service_status: ServiceStatus
    # Instance/container-level details; None when absent in the query row.
    instance_id: Optional[int] = None
    instance_status: Optional[InstanceStatus] = None
    container_status: Optional[ContainerStatus] = None
    # Populated only when the caller requested include_message=True.
    message: Optional[str] = None
30
61
 
31
62
 
32
63
  class ServiceSQLClient(_base._BaseSQLClient):
33
64
  MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
34
65
  MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
35
66
  SERVICE_STATUS = "service_status"
67
+ INSTANCE_ID = "instance_id"
68
+ INSTANCE_STATUS = "instance_status"
69
+ CONTAINER_STATUS = "status"
70
+ MESSAGE = "message"
36
71
 
37
72
  def build_model_container(
38
73
  self,
@@ -81,6 +116,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
81
116
  ) -> tuple[str, snowpark.AsyncJob]:
82
117
  assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
83
118
  if model_deployment_spec_yaml_str:
119
+ model_deployment_spec_yaml_str = snowpark_utils.escape_single_quotes(
120
+ model_deployment_spec_yaml_str
121
+ ) # type: ignore[no-untyped-call]
122
+ logger.info(f"Deploying model with spec={model_deployment_spec_yaml_str}")
84
123
  sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
85
124
  else:
86
125
  sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
@@ -192,7 +231,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
192
231
  )
193
232
  return str(rows[0][system_func])
194
233
 
195
- def get_service_status(
234
+ def get_service_container_statuses(
196
235
  self,
197
236
  *,
198
237
  database_name: Optional[sql_identifier.SqlIdentifier],
@@ -200,18 +239,27 @@ class ServiceSQLClient(_base._BaseSQLClient):
200
239
  service_name: sql_identifier.SqlIdentifier,
201
240
  include_message: bool = False,
202
241
  statement_params: Optional[dict[str, Any]] = None,
203
- ) -> tuple[ServiceStatus, Optional[str]]:
242
+ ) -> list[ServiceStatusInfo]:
204
243
  fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
205
244
  query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
206
245
  rows = self._session.sql(query).collect(statement_params=statement_params)
207
- if len(rows) == 0:
208
- return ServiceStatus.UNKNOWN, None
209
- row = rows[0]
210
- service_status = row[ServiceSQLClient.SERVICE_STATUS]
211
- message = row["message"] if include_message else None
212
- if not isinstance(service_status, ServiceStatus):
213
- return ServiceStatus.UNKNOWN, message
214
- return ServiceStatus(service_status), message
246
+ statuses = []
247
+ for r in rows:
248
+ instance_status, container_status = None, None
249
+ if r[ServiceSQLClient.INSTANCE_STATUS] is not None:
250
+ instance_status = InstanceStatus(r[ServiceSQLClient.INSTANCE_STATUS])
251
+ if r[ServiceSQLClient.CONTAINER_STATUS] is not None:
252
+ container_status = ContainerStatus(r[ServiceSQLClient.CONTAINER_STATUS])
253
+ statuses.append(
254
+ ServiceStatusInfo(
255
+ service_status=ServiceStatus(r[ServiceSQLClient.SERVICE_STATUS]),
256
+ instance_id=r[ServiceSQLClient.INSTANCE_ID],
257
+ instance_status=instance_status,
258
+ container_status=container_status,
259
+ message=r[ServiceSQLClient.MESSAGE] if include_message else None,
260
+ )
261
+ )
262
+ return statuses
215
263
 
216
264
  def drop_service(
217
265
  self,
@@ -12,9 +12,12 @@ class StageSQLClient(_base._BaseSQLClient):
12
12
  schema_name: Optional[sql_identifier.SqlIdentifier],
13
13
  stage_name: sql_identifier.SqlIdentifier,
14
14
  statement_params: Optional[dict[str, Any]] = None,
15
- ) -> None:
15
+ ) -> str:
16
+ fq_stage_name = self.fully_qualified_object_name(database_name, schema_name, stage_name)
16
17
  query_result_checker.SqlResultValidator(
17
18
  self._session,
18
- f"CREATE SCOPED TEMPORARY STAGE {self.fully_qualified_object_name(database_name, schema_name, stage_name)}",
19
+ f"CREATE SCOPED TEMPORARY STAGE {fq_stage_name}",
19
20
  statement_params=statement_params,
20
21
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
22
+
23
+ return fq_stage_name
@@ -7,6 +7,7 @@ from typing import Optional, cast
7
7
  import yaml
8
8
 
9
9
  from snowflake.ml._internal import env_utils
10
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
10
11
  from snowflake.ml.data import data_source
11
12
  from snowflake.ml.model import type_hints
12
13
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -53,17 +54,44 @@ class ModelManifest:
53
54
  if options is None:
54
55
  options = {}
55
56
 
57
+ has_pip_requirements = len(model_meta.env.pip_requirements) > 0
58
+ only_spcs = (
59
+ target_platforms
60
+ and len(target_platforms) == 1
61
+ and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
62
+ )
63
+
56
64
  if "relax_version" not in options:
57
- warnings.warn(
58
- (
59
- "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed"
60
- " from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, "
61
- "reproducibility, etc., set `options={'relax_version': False}` when logging the model."
62
- ),
63
- category=UserWarning,
64
- stacklevel=2,
65
- )
66
- relax_version = options.get("relax_version", True)
65
+ if has_pip_requirements or only_spcs:
66
+ logger.info(
67
+ "Setting `relax_version=False` as this model will run in Snowpark Container Services "
68
+ "or in Warehouse with a specified artifact_repository_map where exact version "
69
+ " specifications will be honored."
70
+ )
71
+ relax_version = False
72
+ else:
73
+ warnings.warn(
74
+ (
75
+ "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
76
+ " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
77
+ " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
78
+ ),
79
+ category=UserWarning,
80
+ stacklevel=2,
81
+ )
82
+ relax_version = True
83
+ options["relax_version"] = relax_version
84
+ else:
85
+ relax_version = options.get("relax_version", True)
86
+ if relax_version and (has_pip_requirements or only_spcs):
87
+ raise exceptions.SnowflakeMLException(
88
+ error_code=error_codes.INVALID_ARGUMENT,
89
+ original_exception=ValueError(
90
+ "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
91
+ "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
92
+ "targeting only Snowpark Container Services."
93
+ ),
94
+ )
67
95
 
68
96
  runtime_to_use = model_runtime.ModelRuntime(
69
97
  name=self._DEFAULT_RUNTIME_NAME,
@@ -9,6 +9,7 @@ from packaging import requirements, version
9
9
 
10
10
  from snowflake.ml import version as snowml_version
11
11
  from snowflake.ml._internal import env as snowml_env, env_utils
12
+ from snowflake.ml.model import type_hints as model_types
12
13
  from snowflake.ml.model._packager.model_meta import model_meta_schema
13
14
 
14
15
  # requirement: Full version requirement where name is conda package name.
@@ -30,6 +31,7 @@ class ModelEnv:
30
31
  conda_env_rel_path: Optional[str] = None,
31
32
  pip_requirements_rel_path: Optional[str] = None,
32
33
  prefer_pip: bool = False,
34
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
33
35
  ) -> None:
34
36
  if conda_env_rel_path is None:
35
37
  conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
@@ -45,6 +47,8 @@ class ModelEnv:
45
47
  self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
46
48
  self._cuda_version: Optional[version.Version] = None
47
49
  self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
50
+ self._target_platforms = target_platforms
51
+ self._warnings_shown: set[str] = set()
48
52
 
49
53
  @property
50
54
  def conda_dependencies(self) -> list[str]:
@@ -116,6 +120,17 @@ class ModelEnv:
116
120
  if snowpark_ml_version:
117
121
  self._snowpark_ml_version = version.parse(snowpark_ml_version)
118
122
 
123
+ @property
124
+ def targets_warehouse(self) -> bool:
125
+ """Returns True if warehouse is a target platform."""
126
+ return self._target_platforms is None or model_types.TargetPlatform.WAREHOUSE in self._target_platforms
127
+
128
+ def _warn_once(self, message: str, stacklevel: int = 2) -> None:
129
+ """Show warning only once per ModelEnv instance."""
130
+ if message not in self._warnings_shown:
131
+ warnings.warn(message, category=UserWarning, stacklevel=stacklevel)
132
+ self._warnings_shown.add(message)
133
+
119
134
  def include_if_absent(
120
135
  self,
121
136
  pkgs: list[ModelDependency],
@@ -130,14 +145,14 @@ class ModelEnv:
130
145
  """
131
146
  if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
132
147
  pip_pkg_reqs: list[str] = []
133
- warnings.warn(
134
- (
135
- "Dependencies specified from pip requirements."
136
- " This may prevent model deploying to Snowflake Warehouse."
137
- ),
138
- category=UserWarning,
139
- stacklevel=2,
140
- )
148
+ if self.targets_warehouse:
149
+ self._warn_once(
150
+ (
151
+ "Dependencies specified from pip requirements."
152
+ " This may prevent model deploying to Snowflake Warehouse."
153
+ ),
154
+ stacklevel=2,
155
+ )
141
156
  for conda_req_str, pip_name in pkgs:
142
157
  _, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
143
158
  pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
@@ -162,16 +177,15 @@ class ModelEnv:
162
177
  req_to_add.name = conda_req.name
163
178
  else:
164
179
  req_to_add = conda_req
165
- show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
180
+ show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
166
181
 
167
182
  if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
168
183
  if show_warning_message:
169
- warnings.warn(
184
+ self._warn_once(
170
185
  (
171
186
  f"Basic dependency {req_to_add.name} specified from pip requirements."
172
187
  " This may prevent model deploying to Snowflake Warehouse."
173
188
  ),
174
- category=UserWarning,
175
189
  stacklevel=2,
176
190
  )
177
191
  continue
@@ -182,12 +196,11 @@ class ModelEnv:
182
196
  pass
183
197
  except env_utils.DuplicateDependencyInMultipleChannelsError:
184
198
  if show_warning_message:
185
- warnings.warn(
199
+ self._warn_once(
186
200
  (
187
201
  f"Basic dependency {req_to_add.name} specified from non-Snowflake channel."
188
202
  + " This may prevent model deploying to Snowflake Warehouse."
189
203
  ),
190
- category=UserWarning,
191
204
  stacklevel=2,
192
205
  )
193
206
 
@@ -272,22 +285,20 @@ class ModelEnv:
272
285
  )
273
286
 
274
287
  for channel, channel_dependencies in conda_dependencies_dict.items():
275
- if channel != env_utils.DEFAULT_CHANNEL_NAME:
276
- warnings.warn(
288
+ if channel != env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse:
289
+ self._warn_once(
277
290
  (
278
291
  "Found dependencies specified in the conda file from non-Snowflake channel."
279
292
  " This may prevent model deploying to Snowflake Warehouse."
280
293
  ),
281
- category=UserWarning,
282
294
  stacklevel=2,
283
295
  )
284
- if len(channel_dependencies) == 0 and channel not in self._conda_dependencies:
285
- warnings.warn(
296
+ if len(channel_dependencies) == 0 and channel not in self._conda_dependencies and self.targets_warehouse:
297
+ self._warn_once(
286
298
  (
287
299
  f"Found additional conda channel {channel} specified in the conda file."
288
300
  " This may prevent model deploying to Snowflake Warehouse."
289
301
  ),
290
- category=UserWarning,
291
302
  stacklevel=2,
292
303
  )
293
304
  self._conda_dependencies[channel] = []
@@ -298,22 +309,20 @@ class ModelEnv:
298
309
  except env_utils.DuplicateDependencyError:
299
310
  pass
300
311
  except env_utils.DuplicateDependencyInMultipleChannelsError:
301
- warnings.warn(
312
+ self._warn_once(
302
313
  (
303
314
  f"Dependency {channel_dependency.name} appeared in multiple channels as conda dependency."
304
315
  " This may be unintentional."
305
316
  ),
306
- category=UserWarning,
307
317
  stacklevel=2,
308
318
  )
309
319
 
310
- if pip_requirements_list:
311
- warnings.warn(
320
+ if pip_requirements_list and self.targets_warehouse:
321
+ self._warn_once(
312
322
  (
313
323
  "Found dependencies specified as pip requirements."
314
324
  " This may prevent model deploying to Snowflake Warehouse."
315
325
  ),
316
- category=UserWarning,
317
326
  stacklevel=2,
318
327
  )
319
328
  for pip_dependency in pip_requirements_list:
@@ -333,13 +342,12 @@ class ModelEnv:
333
342
  def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
334
343
  pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
335
344
 
336
- if pip_requirements_list:
337
- warnings.warn(
345
+ if pip_requirements_list and self.targets_warehouse:
346
+ self._warn_once(
338
347
  (
339
348
  "Found dependencies specified as pip requirements."
340
349
  " This may prevent model deploying to Snowflake Warehouse."
341
350
  ),
342
- category=UserWarning,
343
351
  stacklevel=2,
344
352
  )
345
353
  for pip_dependency in pip_requirements_list:
@@ -167,7 +167,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
167
167
  model_blob_metadata = model_blobs_metadata[name]
168
168
  model_blob_filename = model_blob_metadata.path
169
169
  with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
170
- m = torch.load(f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu")
170
+ m = torch.load(
171
+ f,
172
+ map_location="cuda" if kwargs.get("use_gpu", False) else "cpu",
173
+ weights_only=False,
174
+ )
171
175
  assert isinstance(m, torch.nn.Module)
172
176
 
173
177
  return m