snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65):
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/mixins.py +71 -0
  6. snowflake/ml/_internal/utils/service_logger.py +4 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
  8. snowflake/ml/data/data_connector.py +43 -2
  9. snowflake/ml/data/data_ingestor.py +8 -0
  10. snowflake/ml/data/torch_utils.py +1 -1
  11. snowflake/ml/dataset/dataset.py +3 -2
  12. snowflake/ml/dataset/dataset_reader.py +22 -6
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/experiment_tracking.py +319 -0
  20. snowflake/ml/jobs/_utils/constants.py +1 -1
  21. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +5 -3
  23. snowflake/ml/jobs/_utils/query_helper.py +20 -0
  24. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
  25. snowflake/ml/jobs/_utils/spec_utils.py +21 -4
  26. snowflake/ml/jobs/decorators.py +18 -25
  27. snowflake/ml/jobs/job.py +137 -37
  28. snowflake/ml/jobs/manager.py +228 -153
  29. snowflake/ml/lineage/lineage_node.py +2 -2
  30. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  31. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  32. snowflake/ml/model/_client/ops/service_ops.py +324 -138
  33. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  34. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  35. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  38. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  39. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  40. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  45. snowflake/ml/model/event_handler.py +117 -0
  46. snowflake/ml/model/model_signature.py +9 -9
  47. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  48. snowflake/ml/model/target_platform.py +11 -0
  49. snowflake/ml/model/task.py +9 -0
  50. snowflake/ml/model/type_hints.py +5 -13
  51. snowflake/ml/modeling/framework/base.py +1 -1
  52. snowflake/ml/modeling/metrics/classification.py +14 -14
  53. snowflake/ml/modeling/metrics/correlation.py +19 -8
  54. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  55. snowflake/ml/modeling/metrics/ranking.py +6 -6
  56. snowflake/ml/modeling/metrics/regression.py +9 -9
  57. snowflake/ml/monitoring/explain_visualize.py +12 -5
  58. snowflake/ml/registry/_manager/model_manager.py +47 -15
  59. snowflake/ml/registry/registry.py +109 -64
  60. snowflake/ml/version.py +1 -1
  61. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
  62. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
  63. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
  64. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,8 @@ from typing import Optional
2
2
 
3
3
  from pydantic import BaseModel
4
4
 
5
+ BaseModel.model_config["protected_namespaces"] = ()
6
+
5
7
 
6
8
  class Model(BaseModel):
7
9
  name: str
@@ -53,7 +55,7 @@ class HuggingFaceModel(BaseModel):
53
55
  hf_model_name: str
54
56
  task: Optional[str] = None
55
57
  tokenizer: Optional[str] = None
56
- hf_token: Optional[str] = None
58
+ token: Optional[str] = None
57
59
  trust_remote_code: Optional[bool] = False
58
60
  revision: Optional[str] = None
59
61
  hf_model_kwargs: Optional[str] = "{}"
@@ -3,7 +3,7 @@ import tempfile
3
3
  import uuid
4
4
  import warnings
5
5
  from types import ModuleType
6
- from typing import Any, Optional, Union
6
+ from typing import TYPE_CHECKING, Any, Optional, Union
7
7
  from urllib import parse
8
8
 
9
9
  from absl import logging
@@ -21,6 +21,9 @@ from snowflake.ml.model._packager.model_meta import model_meta
21
21
  from snowflake.snowpark import Session
22
22
  from snowflake.snowpark._internal import utils as snowpark_utils
23
23
 
24
+ if TYPE_CHECKING:
25
+ from snowflake.ml.experiment._experiment_info import ExperimentInfo
26
+
24
27
 
25
28
  class ModelComposer:
26
29
  """Top-level class to construct contents in a MODEL object in SQL.
@@ -136,6 +139,7 @@ class ModelComposer:
136
139
  ext_modules: Optional[list[ModuleType]] = None,
137
140
  code_paths: Optional[list[str]] = None,
138
141
  task: model_types.Task = model_types.Task.UNKNOWN,
142
+ experiment_info: Optional["ExperimentInfo"] = None,
139
143
  options: Optional[model_types.ModelSaveOption] = None,
140
144
  ) -> model_meta.ModelMetadata:
141
145
  # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
@@ -230,6 +234,7 @@ class ModelComposer:
230
234
  options=options,
231
235
  user_files=user_files,
232
236
  data_sources=self._get_data_sources(model, sample_input_data),
237
+ experiment_info=experiment_info,
233
238
  target_platforms=target_platforms,
234
239
  )
235
240
 
@@ -2,11 +2,12 @@ import collections
2
2
  import logging
3
3
  import pathlib
4
4
  import warnings
5
- from typing import Optional, cast
5
+ from typing import TYPE_CHECKING, Optional, cast
6
6
 
7
7
  import yaml
8
8
 
9
9
  from snowflake.ml._internal import env_utils
10
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
10
11
  from snowflake.ml.data import data_source
11
12
  from snowflake.ml.model import type_hints
12
13
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -22,6 +23,9 @@ from snowflake.ml.model._packager.model_meta import (
22
23
  )
23
24
  from snowflake.ml.model._packager.model_runtime import model_runtime
24
25
 
26
+ if TYPE_CHECKING:
27
+ from snowflake.ml.experiment._experiment_info import ExperimentInfo
28
+
25
29
  logger = logging.getLogger(__name__)
26
30
 
27
31
 
@@ -48,22 +52,50 @@ class ModelManifest:
48
52
  user_files: Optional[dict[str, list[str]]] = None,
49
53
  options: Optional[type_hints.ModelSaveOption] = None,
50
54
  data_sources: Optional[list[data_source.DataSource]] = None,
55
+ experiment_info: Optional["ExperimentInfo"] = None,
51
56
  target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
52
57
  ) -> None:
53
58
  if options is None:
54
59
  options = {}
55
60
 
61
+ has_pip_requirements = len(model_meta.env.pip_requirements) > 0
62
+ only_spcs = (
63
+ target_platforms
64
+ and len(target_platforms) == 1
65
+ and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
66
+ )
67
+
56
68
  if "relax_version" not in options:
57
- warnings.warn(
58
- (
59
- "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed"
60
- " from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, "
61
- "reproducibility, etc., set `options={'relax_version': False}` when logging the model."
62
- ),
63
- category=UserWarning,
64
- stacklevel=2,
65
- )
66
- relax_version = options.get("relax_version", True)
69
+ if has_pip_requirements or only_spcs:
70
+ logger.info(
71
+ "Setting `relax_version=False` as this model will run in Snowpark Container Services "
72
+ "or in Warehouse with a specified artifact_repository_map where exact version "
73
+ " specifications will be honored."
74
+ )
75
+ relax_version = False
76
+ else:
77
+ warnings.warn(
78
+ (
79
+ "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
80
+ " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
81
+ " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
82
+ ),
83
+ category=UserWarning,
84
+ stacklevel=2,
85
+ )
86
+ relax_version = True
87
+ options["relax_version"] = relax_version
88
+ else:
89
+ relax_version = options.get("relax_version", True)
90
+ if relax_version and (has_pip_requirements or only_spcs):
91
+ raise exceptions.SnowflakeMLException(
92
+ error_code=error_codes.INVALID_ARGUMENT,
93
+ original_exception=ValueError(
94
+ "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
95
+ "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
96
+ "targeting only Snowpark Container Services."
97
+ ),
98
+ )
67
99
 
68
100
  runtime_to_use = model_runtime.ModelRuntime(
69
101
  name=self._DEFAULT_RUNTIME_NAME,
@@ -155,7 +187,7 @@ class ModelManifest:
155
187
  if self.user_files:
156
188
  manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
157
189
 
158
- lineage_sources = self._extract_lineage_info(data_sources)
190
+ lineage_sources = self._extract_lineage_info(data_sources, experiment_info)
159
191
  if lineage_sources:
160
192
  manifest_dict["lineage_sources"] = lineage_sources
161
193
 
@@ -182,7 +214,9 @@ class ModelManifest:
182
214
  return res
183
215
 
184
216
  def _extract_lineage_info(
185
- self, data_sources: Optional[list[data_source.DataSource]]
217
+ self,
218
+ data_sources: Optional[list[data_source.DataSource]],
219
+ experiment_info: Optional["ExperimentInfo"],
186
220
  ) -> list[model_manifest_schema.LineageSourceDict]:
187
221
  result = []
188
222
  if data_sources:
@@ -201,4 +235,12 @@ class ModelManifest:
201
235
  type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
202
236
  )
203
237
  )
238
+ if experiment_info:
239
+ result.append(
240
+ model_manifest_schema.LineageSourceDict(
241
+ type=model_manifest_schema.LineageSourceTypes.EXPERIMENT.value,
242
+ entity=experiment_info.fully_qualified_name,
243
+ version=experiment_info.run_name,
244
+ )
245
+ )
204
246
  return result
@@ -83,6 +83,7 @@ class SnowparkMLDataDict(TypedDict):
83
83
  class LineageSourceTypes(enum.Enum):
84
84
  DATASET = "DATASET"
85
85
  QUERY = "QUERY"
86
+ EXPERIMENT = "EXPERIMENT"
86
87
 
87
88
 
88
89
  class LineageSourceDict(TypedDict):
@@ -9,6 +9,7 @@ from packaging import requirements, version
9
9
 
10
10
  from snowflake.ml import version as snowml_version
11
11
  from snowflake.ml._internal import env as snowml_env, env_utils
12
+ from snowflake.ml.model import type_hints as model_types
12
13
  from snowflake.ml.model._packager.model_meta import model_meta_schema
13
14
 
14
15
  # requirement: Full version requirement where name is conda package name.
@@ -30,6 +31,7 @@ class ModelEnv:
30
31
  conda_env_rel_path: Optional[str] = None,
31
32
  pip_requirements_rel_path: Optional[str] = None,
32
33
  prefer_pip: bool = False,
34
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
33
35
  ) -> None:
34
36
  if conda_env_rel_path is None:
35
37
  conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
@@ -45,6 +47,8 @@ class ModelEnv:
45
47
  self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
46
48
  self._cuda_version: Optional[version.Version] = None
47
49
  self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
50
+ self._target_platforms = target_platforms
51
+ self._warnings_shown: set[str] = set()
48
52
 
49
53
  @property
50
54
  def conda_dependencies(self) -> list[str]:
@@ -116,6 +120,17 @@ class ModelEnv:
116
120
  if snowpark_ml_version:
117
121
  self._snowpark_ml_version = version.parse(snowpark_ml_version)
118
122
 
123
+ @property
124
+ def targets_warehouse(self) -> bool:
125
+ """Returns True if warehouse is a target platform."""
126
+ return self._target_platforms is None or model_types.TargetPlatform.WAREHOUSE in self._target_platforms
127
+
128
+ def _warn_once(self, message: str, stacklevel: int = 2) -> None:
129
+ """Show warning only once per ModelEnv instance."""
130
+ if message not in self._warnings_shown:
131
+ warnings.warn(message, category=UserWarning, stacklevel=stacklevel)
132
+ self._warnings_shown.add(message)
133
+
119
134
  def include_if_absent(
120
135
  self,
121
136
  pkgs: list[ModelDependency],
@@ -130,14 +145,14 @@ class ModelEnv:
130
145
  """
131
146
  if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
132
147
  pip_pkg_reqs: list[str] = []
133
- warnings.warn(
134
- (
135
- "Dependencies specified from pip requirements."
136
- " This may prevent model deploying to Snowflake Warehouse."
137
- ),
138
- category=UserWarning,
139
- stacklevel=2,
140
- )
148
+ if self.targets_warehouse:
149
+ self._warn_once(
150
+ (
151
+ "Dependencies specified from pip requirements."
152
+ " This may prevent model deploying to Snowflake Warehouse."
153
+ ),
154
+ stacklevel=2,
155
+ )
141
156
  for conda_req_str, pip_name in pkgs:
142
157
  _, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
143
158
  pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
@@ -162,16 +177,15 @@ class ModelEnv:
162
177
  req_to_add.name = conda_req.name
163
178
  else:
164
179
  req_to_add = conda_req
165
- show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
180
+ show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
166
181
 
167
182
  if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
168
183
  if show_warning_message:
169
- warnings.warn(
184
+ self._warn_once(
170
185
  (
171
186
  f"Basic dependency {req_to_add.name} specified from pip requirements."
172
187
  " This may prevent model deploying to Snowflake Warehouse."
173
188
  ),
174
- category=UserWarning,
175
189
  stacklevel=2,
176
190
  )
177
191
  continue
@@ -182,12 +196,11 @@ class ModelEnv:
182
196
  pass
183
197
  except env_utils.DuplicateDependencyInMultipleChannelsError:
184
198
  if show_warning_message:
185
- warnings.warn(
199
+ self._warn_once(
186
200
  (
187
201
  f"Basic dependency {req_to_add.name} specified from non-Snowflake channel."
188
202
  + " This may prevent model deploying to Snowflake Warehouse."
189
203
  ),
190
- category=UserWarning,
191
204
  stacklevel=2,
192
205
  )
193
206
 
@@ -272,22 +285,20 @@ class ModelEnv:
272
285
  )
273
286
 
274
287
  for channel, channel_dependencies in conda_dependencies_dict.items():
275
- if channel != env_utils.DEFAULT_CHANNEL_NAME:
276
- warnings.warn(
288
+ if channel != env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse:
289
+ self._warn_once(
277
290
  (
278
291
  "Found dependencies specified in the conda file from non-Snowflake channel."
279
292
  " This may prevent model deploying to Snowflake Warehouse."
280
293
  ),
281
- category=UserWarning,
282
294
  stacklevel=2,
283
295
  )
284
- if len(channel_dependencies) == 0 and channel not in self._conda_dependencies:
285
- warnings.warn(
296
+ if len(channel_dependencies) == 0 and channel not in self._conda_dependencies and self.targets_warehouse:
297
+ self._warn_once(
286
298
  (
287
299
  f"Found additional conda channel {channel} specified in the conda file."
288
300
  " This may prevent model deploying to Snowflake Warehouse."
289
301
  ),
290
- category=UserWarning,
291
302
  stacklevel=2,
292
303
  )
293
304
  self._conda_dependencies[channel] = []
@@ -298,22 +309,20 @@ class ModelEnv:
298
309
  except env_utils.DuplicateDependencyError:
299
310
  pass
300
311
  except env_utils.DuplicateDependencyInMultipleChannelsError:
301
- warnings.warn(
312
+ self._warn_once(
302
313
  (
303
314
  f"Dependency {channel_dependency.name} appeared in multiple channels as conda dependency."
304
315
  " This may be unintentional."
305
316
  ),
306
- category=UserWarning,
307
317
  stacklevel=2,
308
318
  )
309
319
 
310
- if pip_requirements_list:
311
- warnings.warn(
320
+ if pip_requirements_list and self.targets_warehouse:
321
+ self._warn_once(
312
322
  (
313
323
  "Found dependencies specified as pip requirements."
314
324
  " This may prevent model deploying to Snowflake Warehouse."
315
325
  ),
316
- category=UserWarning,
317
326
  stacklevel=2,
318
327
  )
319
328
  for pip_dependency in pip_requirements_list:
@@ -333,13 +342,12 @@ class ModelEnv:
333
342
  def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
334
343
  pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
335
344
 
336
- if pip_requirements_list:
337
- warnings.warn(
345
+ if pip_requirements_list and self.targets_warehouse:
346
+ self._warn_once(
338
347
  (
339
348
  "Found dependencies specified as pip requirements."
340
349
  " This may prevent model deploying to Snowflake Warehouse."
341
350
  ),
342
- category=UserWarning,
343
351
  stacklevel=2,
344
352
  )
345
353
  for pip_dependency in pip_requirements_list:
@@ -1,4 +1,5 @@
1
1
  import json
2
+ import logging
2
3
  import os
3
4
  import warnings
4
5
  from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
@@ -23,9 +24,13 @@ from snowflake.ml.model._signatures import utils as model_signature_utils
23
24
  from snowflake.ml.model.models import huggingface_pipeline
24
25
  from snowflake.snowpark._internal import utils as snowpark_utils
25
26
 
27
+ logger = logging.getLogger(__name__)
28
+
26
29
  if TYPE_CHECKING:
27
30
  import transformers
28
31
 
32
+ DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
33
+
29
34
 
30
35
  def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
31
36
  # Text
@@ -326,6 +331,23 @@ class HuggingFacePipelineHandler(
326
331
  **device_config,
327
332
  )
328
333
 
334
+ # If the task is text-generation, and the tokenizer does not have a chat_template,
335
+ # set the default chat template.
336
+ if (
337
+ hasattr(m, "task")
338
+ and m.task == "text-generation"
339
+ and hasattr(m.tokenizer, "chat_template")
340
+ and not m.tokenizer.chat_template
341
+ ):
342
+ warnings.warn(
343
+ "The tokenizer does not have default chat_template. "
344
+ "Setting the chat_template to default ChatML template.",
345
+ UserWarning,
346
+ stacklevel=1,
347
+ )
348
+ logger.info(DEFAULT_CHAT_TEMPLATE)
349
+ m.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
350
+
329
351
  m.__dict__.update(pipeline_params)
330
352
 
331
353
  else:
@@ -481,8 +503,25 @@ class HuggingFacePipelineHandler(
481
503
 
482
504
  # To enable batch_size > 1 for LLM
483
505
  # Pipe might not have tokenizer, but should always have a model, and model should always have a config.
484
- if getattr(pipe, "tokenizer", None) is not None and pipe.tokenizer.pad_token_id is None:
485
- pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
506
+ if (
507
+ getattr(pipe, "tokenizer", None) is not None
508
+ and pipe.tokenizer.pad_token_id is None
509
+ and hasattr(pipe.model.config, "eos_token_id")
510
+ ):
511
+ if isinstance(pipe.model.config.eos_token_id, int):
512
+ pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
513
+ elif (
514
+ isinstance(pipe.model.config.eos_token_id, list)
515
+ and len(pipe.model.config.eos_token_id) > 0
516
+ and isinstance(pipe.model.config.eos_token_id[0], int)
517
+ ):
518
+ pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id[0]
519
+ else:
520
+ warnings.warn(
521
+ f"Unexpected type of eos_token_id: {type(pipe.model.config.eos_token_id)}. "
522
+ "Not setting pad_token_id to eos_token_id.",
523
+ stacklevel=2,
524
+ )
486
525
 
487
526
  _HFPipelineModel = _create_custom_model(pipe, model_meta)
488
527
  hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
@@ -167,7 +167,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
167
167
  model_blob_metadata = model_blobs_metadata[name]
168
168
  model_blob_filename = model_blob_metadata.path
169
169
  with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
170
- m = torch.load(f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu")
170
+ m = torch.load(
171
+ f,
172
+ map_location="cuda" if kwargs.get("use_gpu", False) else "cpu",
173
+ weights_only=False,
174
+ )
171
175
  assert isinstance(m, torch.nn.Module)
172
176
 
173
177
  return m
@@ -110,6 +110,7 @@ def create_model_metadata(
110
110
  python_version=python_version,
111
111
  embed_local_ml_library=embed_local_ml_library,
112
112
  prefer_pip=prefer_pip,
113
+ target_platforms=target_platforms,
113
114
  )
114
115
 
115
116
  if embed_local_ml_library:
@@ -162,8 +163,9 @@ def _create_env_for_model_metadata(
162
163
  python_version: Optional[str] = None,
163
164
  embed_local_ml_library: bool = False,
164
165
  prefer_pip: bool = False,
166
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
165
167
  ) -> model_env.ModelEnv:
166
- env = model_env.ModelEnv(prefer_pip=prefer_pip)
168
+ env = model_env.ModelEnv(prefer_pip=prefer_pip, target_platforms=target_platforms)
167
169
 
168
170
  # Mypy doesn't like getter and setter have different types. See python/mypy #3004
169
171
  env.conda_dependencies = conda_dependencies # type: ignore[assignment]
@@ -10,7 +10,7 @@ REQUIREMENTS = [
10
10
  "cryptography",
11
11
  "fsspec>=2024.6.1,<2026",
12
12
  "importlib_resources>=6.1.1, <7",
13
- "numpy>=1.23,<2",
13
+ "numpy>=1.23,<3",
14
14
  "packaging>=20.9,<25",
15
15
  "pandas>=2.1.4,<3",
16
16
  "pyarrow",
@@ -28,6 +28,7 @@ REQUIREMENTS = [
28
28
  "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
29
29
  "snowflake.core>=1.0.2,<2",
30
30
  "sqlparse>=0.4,<1",
31
+ "tqdm<5",
31
32
  "typing-extensions>=4.1.0,<5",
32
33
  "xgboost>=1.7.3,<3",
33
34
  ]
@@ -98,9 +98,9 @@ class ModelRuntime:
98
98
  dependencies=model_meta_schema.ModelRuntimeDependenciesDict(
99
99
  conda=env_dict["conda"],
100
100
  pip=env_dict["pip"],
101
- artifact_repository_map=env_dict["artifact_repository_map"]
102
- if env_dict.get("artifact_repository_map") is not None
103
- else {},
101
+ artifact_repository_map=(
102
+ env_dict["artifact_repository_map"] if env_dict.get("artifact_repository_map") is not None else {}
103
+ ),
104
104
  ),
105
105
  resource_constraint=env_dict["resource_constraint"],
106
106
  )
@@ -60,12 +60,19 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
60
60
  data: snowflake.snowpark.DataFrame,
61
61
  ensure_serializable: bool = True,
62
62
  features: Optional[Sequence[core.BaseFeatureSpec]] = None,
63
+ statement_params: Optional[dict[str, Any]] = None,
63
64
  ) -> pd.DataFrame:
64
65
  # This method do things on top of to_pandas, to make sure the local dataframe got is in correct shape.
65
66
  dtype_map = {}
67
+
66
68
  if features:
69
+ quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
70
+ data.session, statement_params
71
+ )
67
72
  for feature in features:
68
- dtype_map[feature.name] = feature.as_dtype()
73
+ feature_name = feature.name.upper() if quoted_identifiers_ignore_case else feature.name
74
+ dtype_map[feature_name] = feature.as_dtype()
75
+
69
76
  df_local = data.to_pandas()
70
77
 
71
78
  # This is because Array will become string (Even though the correct schema is set)
@@ -93,6 +100,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
93
100
  df: pd.DataFrame,
94
101
  keep_order: bool = False,
95
102
  features: Optional[Sequence[core.BaseFeatureSpec]] = None,
103
+ statement_params: Optional[dict[str, Any]] = None,
96
104
  ) -> snowflake.snowpark.DataFrame:
97
105
  # This method is necessary to create the Snowpark Dataframe in correct schema.
98
106
  # However, in this case, the order could not be preserved. Thus, a _ID column has to be added,
@@ -100,6 +108,12 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
100
108
  # Although in this case, the column with array type can get correct ARRAY type, however, the element
101
109
  # type is not preserved, and will become string type. This affect the implementation of convert_from_df.
102
110
  df = pandas_handler.PandasDataFrameHandler.convert_to_df(df)
111
+ quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
112
+ session, statement_params
113
+ )
114
+ if quoted_identifiers_ignore_case:
115
+ df.columns = [str(col).upper() for col in df.columns]
116
+
103
117
  df_cols = df.columns
104
118
  if df_cols.dtype != np.object_:
105
119
  raise snowml_exceptions.SnowflakeMLException(
@@ -116,9 +130,47 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
116
130
  column_names = []
117
131
  columns = []
118
132
  for feature in features:
119
- column_names.append(identifier.get_inferred_name(feature.name))
120
- columns.append(F.col(identifier.get_inferred_name(feature.name)).cast(feature.as_snowpark_type()))
133
+ feature_name = identifier.get_inferred_name(feature.name)
134
+ if quoted_identifiers_ignore_case:
135
+ feature_name = feature_name.upper()
136
+ column_names.append(feature_name)
137
+ columns.append(F.col(feature_name).cast(feature.as_snowpark_type()))
121
138
 
122
139
  sp_df = sp_df.with_columns(column_names, columns)
123
140
 
124
141
  return sp_df
142
+
143
+ @staticmethod
144
+ def _is_quoted_identifiers_ignore_case_enabled(
145
+ session: snowflake.snowpark.Session, statement_params: Optional[dict[str, Any]] = None
146
+ ) -> bool:
147
+ """
148
+ Check if QUOTED_IDENTIFIERS_IGNORE_CASE parameter is enabled.
149
+
150
+ Args:
151
+ session: Snowpark session to check parameter for
152
+ statement_params: Optional statement parameters to check first
153
+
154
+ Returns:
155
+ bool: True if QUOTED_IDENTIFIERS_IGNORE_CASE is enabled, False otherwise
156
+ Returns False if the parameter cannot be retrieved (e.g., in stored procedures)
157
+ """
158
+ if statement_params is not None:
159
+ for key, value in statement_params.items():
160
+ if key.upper() == "QUOTED_IDENTIFIERS_IGNORE_CASE":
161
+ parameter_value = str(value)
162
+ return parameter_value.lower() == "true"
163
+
164
+ try:
165
+ result = session.sql(
166
+ "SHOW PARAMETERS LIKE 'QUOTED_IDENTIFIERS_IGNORE_CASE' IN SESSION",
167
+ _emit_ast=False,
168
+ ).collect(_emit_ast=False)
169
+
170
+ parameter_value = str(result[0].value)
171
+ return parameter_value.lower() == "true"
172
+
173
+ except Exception:
174
+ # Parameter query can fail in certain environments (e.g., in stored procedures)
175
+ # In that case, assume default behavior (case-sensitive)
176
+ return False