snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +71 -0
- snowflake/ml/_internal/utils/service_logger.py +4 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
- snowflake/ml/data/data_connector.py +43 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +3 -2
- snowflake/ml/dataset/dataset_reader.py +22 -6
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +5 -3
- snowflake/ml/jobs/_utils/query_helper.py +20 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
- snowflake/ml/jobs/_utils/spec_utils.py +21 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +137 -37
- snowflake/ml/jobs/manager.py +228 -153
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +324 -138
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +9 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +47 -15
- snowflake/ml/registry/registry.py +109 -64
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,8 @@ from typing import Optional
|
|
2
2
|
|
3
3
|
from pydantic import BaseModel
|
4
4
|
|
5
|
+
BaseModel.model_config["protected_namespaces"] = ()
|
6
|
+
|
5
7
|
|
6
8
|
class Model(BaseModel):
|
7
9
|
name: str
|
@@ -53,7 +55,7 @@ class HuggingFaceModel(BaseModel):
|
|
53
55
|
hf_model_name: str
|
54
56
|
task: Optional[str] = None
|
55
57
|
tokenizer: Optional[str] = None
|
56
|
-
|
58
|
+
token: Optional[str] = None
|
57
59
|
trust_remote_code: Optional[bool] = False
|
58
60
|
revision: Optional[str] = None
|
59
61
|
hf_model_kwargs: Optional[str] = "{}"
|
@@ -3,7 +3,7 @@ import tempfile
|
|
3
3
|
import uuid
|
4
4
|
import warnings
|
5
5
|
from types import ModuleType
|
6
|
-
from typing import Any, Optional, Union
|
6
|
+
from typing import TYPE_CHECKING, Any, Optional, Union
|
7
7
|
from urllib import parse
|
8
8
|
|
9
9
|
from absl import logging
|
@@ -21,6 +21,9 @@ from snowflake.ml.model._packager.model_meta import model_meta
|
|
21
21
|
from snowflake.snowpark import Session
|
22
22
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
23
23
|
|
24
|
+
if TYPE_CHECKING:
|
25
|
+
from snowflake.ml.experiment._experiment_info import ExperimentInfo
|
26
|
+
|
24
27
|
|
25
28
|
class ModelComposer:
|
26
29
|
"""Top-level class to construct contents in a MODEL object in SQL.
|
@@ -136,6 +139,7 @@ class ModelComposer:
|
|
136
139
|
ext_modules: Optional[list[ModuleType]] = None,
|
137
140
|
code_paths: Optional[list[str]] = None,
|
138
141
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
142
|
+
experiment_info: Optional["ExperimentInfo"] = None,
|
139
143
|
options: Optional[model_types.ModelSaveOption] = None,
|
140
144
|
) -> model_meta.ModelMetadata:
|
141
145
|
# set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
|
@@ -230,6 +234,7 @@ class ModelComposer:
|
|
230
234
|
options=options,
|
231
235
|
user_files=user_files,
|
232
236
|
data_sources=self._get_data_sources(model, sample_input_data),
|
237
|
+
experiment_info=experiment_info,
|
233
238
|
target_platforms=target_platforms,
|
234
239
|
)
|
235
240
|
|
@@ -2,11 +2,12 @@ import collections
|
|
2
2
|
import logging
|
3
3
|
import pathlib
|
4
4
|
import warnings
|
5
|
-
from typing import Optional, cast
|
5
|
+
from typing import TYPE_CHECKING, Optional, cast
|
6
6
|
|
7
7
|
import yaml
|
8
8
|
|
9
9
|
from snowflake.ml._internal import env_utils
|
10
|
+
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
10
11
|
from snowflake.ml.data import data_source
|
11
12
|
from snowflake.ml.model import type_hints
|
12
13
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
@@ -22,6 +23,9 @@ from snowflake.ml.model._packager.model_meta import (
|
|
22
23
|
)
|
23
24
|
from snowflake.ml.model._packager.model_runtime import model_runtime
|
24
25
|
|
26
|
+
if TYPE_CHECKING:
|
27
|
+
from snowflake.ml.experiment._experiment_info import ExperimentInfo
|
28
|
+
|
25
29
|
logger = logging.getLogger(__name__)
|
26
30
|
|
27
31
|
|
@@ -48,22 +52,50 @@ class ModelManifest:
|
|
48
52
|
user_files: Optional[dict[str, list[str]]] = None,
|
49
53
|
options: Optional[type_hints.ModelSaveOption] = None,
|
50
54
|
data_sources: Optional[list[data_source.DataSource]] = None,
|
55
|
+
experiment_info: Optional["ExperimentInfo"] = None,
|
51
56
|
target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
|
52
57
|
) -> None:
|
53
58
|
if options is None:
|
54
59
|
options = {}
|
55
60
|
|
61
|
+
has_pip_requirements = len(model_meta.env.pip_requirements) > 0
|
62
|
+
only_spcs = (
|
63
|
+
target_platforms
|
64
|
+
and len(target_platforms) == 1
|
65
|
+
and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
|
66
|
+
)
|
67
|
+
|
56
68
|
if "relax_version" not in options:
|
57
|
-
|
58
|
-
(
|
59
|
-
"`relax_version`
|
60
|
-
"
|
61
|
-
"
|
62
|
-
)
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
69
|
+
if has_pip_requirements or only_spcs:
|
70
|
+
logger.info(
|
71
|
+
"Setting `relax_version=False` as this model will run in Snowpark Container Services "
|
72
|
+
"or in Warehouse with a specified artifact_repository_map where exact version "
|
73
|
+
" specifications will be honored."
|
74
|
+
)
|
75
|
+
relax_version = False
|
76
|
+
else:
|
77
|
+
warnings.warn(
|
78
|
+
(
|
79
|
+
"`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
|
80
|
+
" relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
|
81
|
+
" reproducibility, etc., set `options={'relax_version': False}` when logging the model."
|
82
|
+
),
|
83
|
+
category=UserWarning,
|
84
|
+
stacklevel=2,
|
85
|
+
)
|
86
|
+
relax_version = True
|
87
|
+
options["relax_version"] = relax_version
|
88
|
+
else:
|
89
|
+
relax_version = options.get("relax_version", True)
|
90
|
+
if relax_version and (has_pip_requirements or only_spcs):
|
91
|
+
raise exceptions.SnowflakeMLException(
|
92
|
+
error_code=error_codes.INVALID_ARGUMENT,
|
93
|
+
original_exception=ValueError(
|
94
|
+
"Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
|
95
|
+
"Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
|
96
|
+
"targeting only Snowpark Container Services."
|
97
|
+
),
|
98
|
+
)
|
67
99
|
|
68
100
|
runtime_to_use = model_runtime.ModelRuntime(
|
69
101
|
name=self._DEFAULT_RUNTIME_NAME,
|
@@ -155,7 +187,7 @@ class ModelManifest:
|
|
155
187
|
if self.user_files:
|
156
188
|
manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
|
157
189
|
|
158
|
-
lineage_sources = self._extract_lineage_info(data_sources)
|
190
|
+
lineage_sources = self._extract_lineage_info(data_sources, experiment_info)
|
159
191
|
if lineage_sources:
|
160
192
|
manifest_dict["lineage_sources"] = lineage_sources
|
161
193
|
|
@@ -182,7 +214,9 @@ class ModelManifest:
|
|
182
214
|
return res
|
183
215
|
|
184
216
|
def _extract_lineage_info(
|
185
|
-
self,
|
217
|
+
self,
|
218
|
+
data_sources: Optional[list[data_source.DataSource]],
|
219
|
+
experiment_info: Optional["ExperimentInfo"],
|
186
220
|
) -> list[model_manifest_schema.LineageSourceDict]:
|
187
221
|
result = []
|
188
222
|
if data_sources:
|
@@ -201,4 +235,12 @@ class ModelManifest:
|
|
201
235
|
type=model_manifest_schema.LineageSourceTypes.QUERY.value, entity=source.sql
|
202
236
|
)
|
203
237
|
)
|
238
|
+
if experiment_info:
|
239
|
+
result.append(
|
240
|
+
model_manifest_schema.LineageSourceDict(
|
241
|
+
type=model_manifest_schema.LineageSourceTypes.EXPERIMENT.value,
|
242
|
+
entity=experiment_info.fully_qualified_name,
|
243
|
+
version=experiment_info.run_name,
|
244
|
+
)
|
245
|
+
)
|
204
246
|
return result
|
@@ -9,6 +9,7 @@ from packaging import requirements, version
|
|
9
9
|
|
10
10
|
from snowflake.ml import version as snowml_version
|
11
11
|
from snowflake.ml._internal import env as snowml_env, env_utils
|
12
|
+
from snowflake.ml.model import type_hints as model_types
|
12
13
|
from snowflake.ml.model._packager.model_meta import model_meta_schema
|
13
14
|
|
14
15
|
# requirement: Full version requirement where name is conda package name.
|
@@ -30,6 +31,7 @@ class ModelEnv:
|
|
30
31
|
conda_env_rel_path: Optional[str] = None,
|
31
32
|
pip_requirements_rel_path: Optional[str] = None,
|
32
33
|
prefer_pip: bool = False,
|
34
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
33
35
|
) -> None:
|
34
36
|
if conda_env_rel_path is None:
|
35
37
|
conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
|
@@ -45,6 +47,8 @@ class ModelEnv:
|
|
45
47
|
self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
|
46
48
|
self._cuda_version: Optional[version.Version] = None
|
47
49
|
self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
|
50
|
+
self._target_platforms = target_platforms
|
51
|
+
self._warnings_shown: set[str] = set()
|
48
52
|
|
49
53
|
@property
|
50
54
|
def conda_dependencies(self) -> list[str]:
|
@@ -116,6 +120,17 @@ class ModelEnv:
|
|
116
120
|
if snowpark_ml_version:
|
117
121
|
self._snowpark_ml_version = version.parse(snowpark_ml_version)
|
118
122
|
|
123
|
+
@property
|
124
|
+
def targets_warehouse(self) -> bool:
|
125
|
+
"""Returns True if warehouse is a target platform."""
|
126
|
+
return self._target_platforms is None or model_types.TargetPlatform.WAREHOUSE in self._target_platforms
|
127
|
+
|
128
|
+
def _warn_once(self, message: str, stacklevel: int = 2) -> None:
|
129
|
+
"""Show warning only once per ModelEnv instance."""
|
130
|
+
if message not in self._warnings_shown:
|
131
|
+
warnings.warn(message, category=UserWarning, stacklevel=stacklevel)
|
132
|
+
self._warnings_shown.add(message)
|
133
|
+
|
119
134
|
def include_if_absent(
|
120
135
|
self,
|
121
136
|
pkgs: list[ModelDependency],
|
@@ -130,14 +145,14 @@ class ModelEnv:
|
|
130
145
|
"""
|
131
146
|
if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
|
132
147
|
pip_pkg_reqs: list[str] = []
|
133
|
-
|
134
|
-
(
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
148
|
+
if self.targets_warehouse:
|
149
|
+
self._warn_once(
|
150
|
+
(
|
151
|
+
"Dependencies specified from pip requirements."
|
152
|
+
" This may prevent model deploying to Snowflake Warehouse."
|
153
|
+
),
|
154
|
+
stacklevel=2,
|
155
|
+
)
|
141
156
|
for conda_req_str, pip_name in pkgs:
|
142
157
|
_, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
|
143
158
|
pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
|
@@ -162,16 +177,15 @@ class ModelEnv:
|
|
162
177
|
req_to_add.name = conda_req.name
|
163
178
|
else:
|
164
179
|
req_to_add = conda_req
|
165
|
-
show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
|
180
|
+
show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
|
166
181
|
|
167
182
|
if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
|
168
183
|
if show_warning_message:
|
169
|
-
|
184
|
+
self._warn_once(
|
170
185
|
(
|
171
186
|
f"Basic dependency {req_to_add.name} specified from pip requirements."
|
172
187
|
" This may prevent model deploying to Snowflake Warehouse."
|
173
188
|
),
|
174
|
-
category=UserWarning,
|
175
189
|
stacklevel=2,
|
176
190
|
)
|
177
191
|
continue
|
@@ -182,12 +196,11 @@ class ModelEnv:
|
|
182
196
|
pass
|
183
197
|
except env_utils.DuplicateDependencyInMultipleChannelsError:
|
184
198
|
if show_warning_message:
|
185
|
-
|
199
|
+
self._warn_once(
|
186
200
|
(
|
187
201
|
f"Basic dependency {req_to_add.name} specified from non-Snowflake channel."
|
188
202
|
+ " This may prevent model deploying to Snowflake Warehouse."
|
189
203
|
),
|
190
|
-
category=UserWarning,
|
191
204
|
stacklevel=2,
|
192
205
|
)
|
193
206
|
|
@@ -272,22 +285,20 @@ class ModelEnv:
|
|
272
285
|
)
|
273
286
|
|
274
287
|
for channel, channel_dependencies in conda_dependencies_dict.items():
|
275
|
-
if channel != env_utils.DEFAULT_CHANNEL_NAME:
|
276
|
-
|
288
|
+
if channel != env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse:
|
289
|
+
self._warn_once(
|
277
290
|
(
|
278
291
|
"Found dependencies specified in the conda file from non-Snowflake channel."
|
279
292
|
" This may prevent model deploying to Snowflake Warehouse."
|
280
293
|
),
|
281
|
-
category=UserWarning,
|
282
294
|
stacklevel=2,
|
283
295
|
)
|
284
|
-
if len(channel_dependencies) == 0 and channel not in self._conda_dependencies:
|
285
|
-
|
296
|
+
if len(channel_dependencies) == 0 and channel not in self._conda_dependencies and self.targets_warehouse:
|
297
|
+
self._warn_once(
|
286
298
|
(
|
287
299
|
f"Found additional conda channel {channel} specified in the conda file."
|
288
300
|
" This may prevent model deploying to Snowflake Warehouse."
|
289
301
|
),
|
290
|
-
category=UserWarning,
|
291
302
|
stacklevel=2,
|
292
303
|
)
|
293
304
|
self._conda_dependencies[channel] = []
|
@@ -298,22 +309,20 @@ class ModelEnv:
|
|
298
309
|
except env_utils.DuplicateDependencyError:
|
299
310
|
pass
|
300
311
|
except env_utils.DuplicateDependencyInMultipleChannelsError:
|
301
|
-
|
312
|
+
self._warn_once(
|
302
313
|
(
|
303
314
|
f"Dependency {channel_dependency.name} appeared in multiple channels as conda dependency."
|
304
315
|
" This may be unintentional."
|
305
316
|
),
|
306
|
-
category=UserWarning,
|
307
317
|
stacklevel=2,
|
308
318
|
)
|
309
319
|
|
310
|
-
if pip_requirements_list:
|
311
|
-
|
320
|
+
if pip_requirements_list and self.targets_warehouse:
|
321
|
+
self._warn_once(
|
312
322
|
(
|
313
323
|
"Found dependencies specified as pip requirements."
|
314
324
|
" This may prevent model deploying to Snowflake Warehouse."
|
315
325
|
),
|
316
|
-
category=UserWarning,
|
317
326
|
stacklevel=2,
|
318
327
|
)
|
319
328
|
for pip_dependency in pip_requirements_list:
|
@@ -333,13 +342,12 @@ class ModelEnv:
|
|
333
342
|
def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
|
334
343
|
pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
|
335
344
|
|
336
|
-
if pip_requirements_list:
|
337
|
-
|
345
|
+
if pip_requirements_list and self.targets_warehouse:
|
346
|
+
self._warn_once(
|
338
347
|
(
|
339
348
|
"Found dependencies specified as pip requirements."
|
340
349
|
" This may prevent model deploying to Snowflake Warehouse."
|
341
350
|
),
|
342
|
-
category=UserWarning,
|
343
351
|
stacklevel=2,
|
344
352
|
)
|
345
353
|
for pip_dependency in pip_requirements_list:
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import json
|
2
|
+
import logging
|
2
3
|
import os
|
3
4
|
import warnings
|
4
5
|
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast, final
|
@@ -23,9 +24,13 @@ from snowflake.ml.model._signatures import utils as model_signature_utils
|
|
23
24
|
from snowflake.ml.model.models import huggingface_pipeline
|
24
25
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
25
26
|
|
27
|
+
logger = logging.getLogger(__name__)
|
28
|
+
|
26
29
|
if TYPE_CHECKING:
|
27
30
|
import transformers
|
28
31
|
|
32
|
+
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" # noqa: E501
|
33
|
+
|
29
34
|
|
30
35
|
def get_requirements_from_task(task: str, spcs_only: bool = False) -> list[model_env.ModelDependency]:
|
31
36
|
# Text
|
@@ -326,6 +331,23 @@ class HuggingFacePipelineHandler(
|
|
326
331
|
**device_config,
|
327
332
|
)
|
328
333
|
|
334
|
+
# If the task is text-generation, and the tokenizer does not have a chat_template,
|
335
|
+
# set the default chat template.
|
336
|
+
if (
|
337
|
+
hasattr(m, "task")
|
338
|
+
and m.task == "text-generation"
|
339
|
+
and hasattr(m.tokenizer, "chat_template")
|
340
|
+
and not m.tokenizer.chat_template
|
341
|
+
):
|
342
|
+
warnings.warn(
|
343
|
+
"The tokenizer does not have default chat_template. "
|
344
|
+
"Setting the chat_template to default ChatML template.",
|
345
|
+
UserWarning,
|
346
|
+
stacklevel=1,
|
347
|
+
)
|
348
|
+
logger.info(DEFAULT_CHAT_TEMPLATE)
|
349
|
+
m.tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
|
350
|
+
|
329
351
|
m.__dict__.update(pipeline_params)
|
330
352
|
|
331
353
|
else:
|
@@ -481,8 +503,25 @@ class HuggingFacePipelineHandler(
|
|
481
503
|
|
482
504
|
# To enable batch_size > 1 for LLM
|
483
505
|
# Pipe might not have tokenizer, but should always have a model, and model should always have a config.
|
484
|
-
if
|
485
|
-
pipe
|
506
|
+
if (
|
507
|
+
getattr(pipe, "tokenizer", None) is not None
|
508
|
+
and pipe.tokenizer.pad_token_id is None
|
509
|
+
and hasattr(pipe.model.config, "eos_token_id")
|
510
|
+
):
|
511
|
+
if isinstance(pipe.model.config.eos_token_id, int):
|
512
|
+
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
|
513
|
+
elif (
|
514
|
+
isinstance(pipe.model.config.eos_token_id, list)
|
515
|
+
and len(pipe.model.config.eos_token_id) > 0
|
516
|
+
and isinstance(pipe.model.config.eos_token_id[0], int)
|
517
|
+
):
|
518
|
+
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id[0]
|
519
|
+
else:
|
520
|
+
warnings.warn(
|
521
|
+
f"Unexpected type of eos_token_id: {type(pipe.model.config.eos_token_id)}. "
|
522
|
+
"Not setting pad_token_id to eos_token_id.",
|
523
|
+
stacklevel=2,
|
524
|
+
)
|
486
525
|
|
487
526
|
_HFPipelineModel = _create_custom_model(pipe, model_meta)
|
488
527
|
hg_pipe_model = _HFPipelineModel(custom_model.ModelContext())
|
@@ -167,7 +167,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
|
|
167
167
|
model_blob_metadata = model_blobs_metadata[name]
|
168
168
|
model_blob_filename = model_blob_metadata.path
|
169
169
|
with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
|
170
|
-
m = torch.load(
|
170
|
+
m = torch.load(
|
171
|
+
f,
|
172
|
+
map_location="cuda" if kwargs.get("use_gpu", False) else "cpu",
|
173
|
+
weights_only=False,
|
174
|
+
)
|
171
175
|
assert isinstance(m, torch.nn.Module)
|
172
176
|
|
173
177
|
return m
|
@@ -110,6 +110,7 @@ def create_model_metadata(
|
|
110
110
|
python_version=python_version,
|
111
111
|
embed_local_ml_library=embed_local_ml_library,
|
112
112
|
prefer_pip=prefer_pip,
|
113
|
+
target_platforms=target_platforms,
|
113
114
|
)
|
114
115
|
|
115
116
|
if embed_local_ml_library:
|
@@ -162,8 +163,9 @@ def _create_env_for_model_metadata(
|
|
162
163
|
python_version: Optional[str] = None,
|
163
164
|
embed_local_ml_library: bool = False,
|
164
165
|
prefer_pip: bool = False,
|
166
|
+
target_platforms: Optional[list[model_types.TargetPlatform]] = None,
|
165
167
|
) -> model_env.ModelEnv:
|
166
|
-
env = model_env.ModelEnv(prefer_pip=prefer_pip)
|
168
|
+
env = model_env.ModelEnv(prefer_pip=prefer_pip, target_platforms=target_platforms)
|
167
169
|
|
168
170
|
# Mypy doesn't like getter and setter have different types. See python/mypy #3004
|
169
171
|
env.conda_dependencies = conda_dependencies # type: ignore[assignment]
|
@@ -10,7 +10,7 @@ REQUIREMENTS = [
|
|
10
10
|
"cryptography",
|
11
11
|
"fsspec>=2024.6.1,<2026",
|
12
12
|
"importlib_resources>=6.1.1, <7",
|
13
|
-
"numpy>=1.23,<
|
13
|
+
"numpy>=1.23,<3",
|
14
14
|
"packaging>=20.9,<25",
|
15
15
|
"pandas>=2.1.4,<3",
|
16
16
|
"pyarrow",
|
@@ -28,6 +28,7 @@ REQUIREMENTS = [
|
|
28
28
|
"snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
|
29
29
|
"snowflake.core>=1.0.2,<2",
|
30
30
|
"sqlparse>=0.4,<1",
|
31
|
+
"tqdm<5",
|
31
32
|
"typing-extensions>=4.1.0,<5",
|
32
33
|
"xgboost>=1.7.3,<3",
|
33
34
|
]
|
@@ -98,9 +98,9 @@ class ModelRuntime:
|
|
98
98
|
dependencies=model_meta_schema.ModelRuntimeDependenciesDict(
|
99
99
|
conda=env_dict["conda"],
|
100
100
|
pip=env_dict["pip"],
|
101
|
-
artifact_repository_map=
|
102
|
-
|
103
|
-
|
101
|
+
artifact_repository_map=(
|
102
|
+
env_dict["artifact_repository_map"] if env_dict.get("artifact_repository_map") is not None else {}
|
103
|
+
),
|
104
104
|
),
|
105
105
|
resource_constraint=env_dict["resource_constraint"],
|
106
106
|
)
|
@@ -60,12 +60,19 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
60
60
|
data: snowflake.snowpark.DataFrame,
|
61
61
|
ensure_serializable: bool = True,
|
62
62
|
features: Optional[Sequence[core.BaseFeatureSpec]] = None,
|
63
|
+
statement_params: Optional[dict[str, Any]] = None,
|
63
64
|
) -> pd.DataFrame:
|
64
65
|
# This method do things on top of to_pandas, to make sure the local dataframe got is in correct shape.
|
65
66
|
dtype_map = {}
|
67
|
+
|
66
68
|
if features:
|
69
|
+
quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
|
70
|
+
data.session, statement_params
|
71
|
+
)
|
67
72
|
for feature in features:
|
68
|
-
|
73
|
+
feature_name = feature.name.upper() if quoted_identifiers_ignore_case else feature.name
|
74
|
+
dtype_map[feature_name] = feature.as_dtype()
|
75
|
+
|
69
76
|
df_local = data.to_pandas()
|
70
77
|
|
71
78
|
# This is because Array will become string (Even though the correct schema is set)
|
@@ -93,6 +100,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
93
100
|
df: pd.DataFrame,
|
94
101
|
keep_order: bool = False,
|
95
102
|
features: Optional[Sequence[core.BaseFeatureSpec]] = None,
|
103
|
+
statement_params: Optional[dict[str, Any]] = None,
|
96
104
|
) -> snowflake.snowpark.DataFrame:
|
97
105
|
# This method is necessary to create the Snowpark Dataframe in correct schema.
|
98
106
|
# However, in this case, the order could not be preserved. Thus, a _ID column has to be added,
|
@@ -100,6 +108,12 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
100
108
|
# Although in this case, the column with array type can get correct ARRAY type, however, the element
|
101
109
|
# type is not preserved, and will become string type. This affect the implementation of convert_from_df.
|
102
110
|
df = pandas_handler.PandasDataFrameHandler.convert_to_df(df)
|
111
|
+
quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
|
112
|
+
session, statement_params
|
113
|
+
)
|
114
|
+
if quoted_identifiers_ignore_case:
|
115
|
+
df.columns = [str(col).upper() for col in df.columns]
|
116
|
+
|
103
117
|
df_cols = df.columns
|
104
118
|
if df_cols.dtype != np.object_:
|
105
119
|
raise snowml_exceptions.SnowflakeMLException(
|
@@ -116,9 +130,47 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
|
|
116
130
|
column_names = []
|
117
131
|
columns = []
|
118
132
|
for feature in features:
|
119
|
-
|
120
|
-
|
133
|
+
feature_name = identifier.get_inferred_name(feature.name)
|
134
|
+
if quoted_identifiers_ignore_case:
|
135
|
+
feature_name = feature_name.upper()
|
136
|
+
column_names.append(feature_name)
|
137
|
+
columns.append(F.col(feature_name).cast(feature.as_snowpark_type()))
|
121
138
|
|
122
139
|
sp_df = sp_df.with_columns(column_names, columns)
|
123
140
|
|
124
141
|
return sp_df
|
142
|
+
|
143
|
+
@staticmethod
|
144
|
+
def _is_quoted_identifiers_ignore_case_enabled(
|
145
|
+
session: snowflake.snowpark.Session, statement_params: Optional[dict[str, Any]] = None
|
146
|
+
) -> bool:
|
147
|
+
"""
|
148
|
+
Check if QUOTED_IDENTIFIERS_IGNORE_CASE parameter is enabled.
|
149
|
+
|
150
|
+
Args:
|
151
|
+
session: Snowpark session to check parameter for
|
152
|
+
statement_params: Optional statement parameters to check first
|
153
|
+
|
154
|
+
Returns:
|
155
|
+
bool: True if QUOTED_IDENTIFIERS_IGNORE_CASE is enabled, False otherwise
|
156
|
+
Returns False if the parameter cannot be retrieved (e.g., in stored procedures)
|
157
|
+
"""
|
158
|
+
if statement_params is not None:
|
159
|
+
for key, value in statement_params.items():
|
160
|
+
if key.upper() == "QUOTED_IDENTIFIERS_IGNORE_CASE":
|
161
|
+
parameter_value = str(value)
|
162
|
+
return parameter_value.lower() == "true"
|
163
|
+
|
164
|
+
try:
|
165
|
+
result = session.sql(
|
166
|
+
"SHOW PARAMETERS LIKE 'QUOTED_IDENTIFIERS_IGNORE_CASE' IN SESSION",
|
167
|
+
_emit_ast=False,
|
168
|
+
).collect(_emit_ast=False)
|
169
|
+
|
170
|
+
parameter_value = str(result[0].value)
|
171
|
+
return parameter_value.lower() == "true"
|
172
|
+
|
173
|
+
except Exception:
|
174
|
+
# Parameter query can fail in certain environments (e.g., in stored procedures)
|
175
|
+
# In that case, assume default behavior (case-sensitive)
|
176
|
+
return False
|