snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +54 -42
- snowflake/ml/_internal/utils/service_logger.py +105 -3
- snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
- snowflake/ml/data/data_connector.py +13 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +2 -1
- snowflake/ml/dataset/dataset_reader.py +14 -4
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +156 -54
- snowflake/ml/jobs/_utils/query_helper.py +16 -5
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
- snowflake/ml/jobs/_utils/spec_utils.py +23 -8
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +70 -75
- snowflake/ml/jobs/manager.py +59 -31
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/service_ops.py +336 -137
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +11 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +32 -15
- snowflake/ml/registry/registry.py +48 -80
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
snowflake/ml/registry/registry.py
CHANGED

@@ -1,5 +1,3 @@
-import logging
-import os
 import warnings
 from types import ModuleType
 from typing import Any, Optional, Union, overload

@@ -8,13 +6,15 @@ import pandas as pd

 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
-from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml._internal.utils import query_result_checker, sql_identifier
 from snowflake.ml.model import (
     Model,
     ModelVersion,
+    event_handler,
     model_signature,
+    target_platform,
     task,
-    type_hints
+    type_hints,
 )
 from snowflake.ml.model._client.model import model_version_impl
 from snowflake.ml.monitoring import model_monitor

@@ -32,52 +32,6 @@ _MODEL_MONITORING_DISABLED_ERROR = (
 )


-class _NullStatusContext:
-    """A fallback context manager that logs status updates."""
-
-    def __init__(self, label: str) -> None:
-        self._label = label
-
-    def __enter__(self) -> "_NullStatusContext":
-        logging.info(f"Starting: {self._label}")
-        return self
-
-    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
-        pass
-
-    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
-        """Update the status by logging the message."""
-        logging.info(f"Status update: {label} (state: {state})")
-
-
-class RegistryEventHandler:
-    def __init__(self) -> None:
-        try:
-            import streamlit as st
-
-            if not st.runtime.exists():
-                self._streamlit = None
-            else:
-                self._streamlit = st
-            USE_STREAMLIT_WIDGETS = os.getenv("USE_STREAMLIT_WIDGETS", "1") == "1"
-            if not USE_STREAMLIT_WIDGETS:
-                self._streamlit = None
-        except ImportError:
-            self._streamlit = None
-
-    def update(self, message: str) -> None:
-        """Write a message using streamlit if available, otherwise do nothing."""
-        if self._streamlit is not None:
-            self._streamlit.write(message)
-
-    def status(self, label: str, *, state: str = "running", expanded: bool = True) -> Any:
-        """Context manager that provides status updates with optional enhanced display capabilities."""
-        if self._streamlit is None:
-            return _NullStatusContext(label)
-        else:
-            return self._streamlit.status(label, state=state, expanded=expanded)
-
-
 class Registry:
     @telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT)
     def __init__(

@@ -124,20 +78,30 @@ class Registry:
             else sql_identifier.SqlIdentifier("PUBLIC")
         )

-
-
-
+        database_results = (
+            query_result_checker.SqlResultValidator(
+                session, f"""SHOW DATABASES LIKE '{self._database_name.resolved()}';"""
+            )
+            .has_column("name", allow_empty=True)
+            .validate()
+        )

-
+        db_names = [row["name"] for row in database_results]
+        if not self._database_name.resolved() in db_names:
             raise ValueError(f"Database {self._database_name} does not exist.")

-
-
-
-
-
+        schema_results = (
+            query_result_checker.SqlResultValidator(
+                session,
+                f"""SHOW SCHEMAS LIKE '{self._schema_name.resolved()}'
+                IN DATABASE {self._database_name.identifier()};""",
+            )
+            .has_column("name", allow_empty=True)
+            .validate()
+        )

-
+        schema_names = [row["name"] for row in schema_results]
+        if not self._schema_name.resolved() in schema_names:
             raise ValueError(f"Schema {self._schema_name} does not exist.")

         self._model_manager = model_manager.ModelManager(
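
The chained validator added above replaces the old existence checks with metadata-only `SHOW` commands, which run without an active warehouse. A rough plain-Snowpark equivalent of the check, for orientation only (it assumes an existing `snowflake.snowpark.Session` named `session` and a database named `ML`; the registry itself goes through `SqlResultValidator`):

```python
# Sketch: what the SqlResultValidator chain above effectively checks.
# SHOW DATABASES is a metadata command, so no running warehouse is required.
rows = session.sql("SHOW DATABASES LIKE 'ML'").collect()
existing = [row["name"] for row in rows]  # "name" is the column the validator asserts on
if "ML" not in existing:
    raise ValueError("Database ML does not exist.")
```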

@@ -168,7 +132,7 @@ class Registry:
     @overload
     def log_model(
         self,
-        model:
+        model: type_hints.SupportedModelType,
         *,
         model_name: str,
         version_name: Optional[str] = None,

@@ -178,15 +142,15 @@ class Registry:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[
+        target_platforms: Optional[list[Union[target_platform.TargetPlatform, str]]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task:
-        options: Optional[
+        task: task.Task = task.Task.UNKNOWN,
+        options: Optional[type_hints.ModelSaveOption] = None,
     ) -> ModelVersion:
         """
         Log a model with various parameters and metadata, or a ModelVersion object.

@@ -258,7 +222,8 @@ class Registry:
             - target_methods: List of target methods to register when logging the model.
                 This option is not used in MLFlow models. Defaults to None, in which case the model handler's
                 default target methods will be used.
-            - save_location:
+            - save_location: Local directory to save the the serialized model files first before
+                uploading to Snowflake. This is useful when default tmp directory is not writable.
             - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
                 values with the desired options.


@@ -315,7 +280,7 @@ class Registry:
     )
     def log_model(
         self,
-        model: Union[
+        model: Union[type_hints.SupportedModelType, ModelVersion],
         *,
         model_name: str,
         version_name: Optional[str] = None,

@@ -325,15 +290,15 @@ class Registry:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[
+        target_platforms: Optional[list[Union[target_platform.TargetPlatform, str]]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task:
-        options: Optional[
+        task: task.Task = task.Task.UNKNOWN,
+        options: Optional[type_hints.ModelSaveOption] = None,
     ) -> ModelVersion:
         """
         Log a model with various parameters and metadata, or a ModelVersion object.

@@ -474,7 +439,7 @@ class Registry:
                 raise ValueError(
                     "When calling log_model with a ModelVersion, only model_name and version_name may be specified."
                 )
-            if task is not
+            if task is not type_hints.Task.UNKNOWN:
                 raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")

         if pip_requirements and not artifact_repository_map and self._targets_warehouse(target_platforms):

@@ -486,8 +451,12 @@ class Registry:
                 stacklevel=1,
             )

-
-        with
+        registry_event_handler = event_handler.ModelEventHandler()
+        with registry_event_handler.status("Logging model", total=6) as status:
+            # Step 1: Validation and setup
+            status.update("validating model and dependencies...")
+            status.increment()
+
             # Perform the actual model logging
             try:
                 result = self._model_manager.log_model(

@@ -510,13 +479,12 @@ class Registry:
                     task=task,
                     options=options,
                     statement_params=statement_params,
-
+                    progress_status=status,
                 )
-                status.update(label="Model logged successfully
+                status.update(label="Model logged successfully.", state="complete", expanded=False)
                 return result
             except Exception as e:
-
-                status.update(label="Model logging failed!", state="error", expanded=False)
+                status.update(label="Model logging failed.", state="error", expanded=False)
                 raise e

     @telemetry.send_api_usage_telemetry(
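
The two hunks above introduce the progress reporting added in this release: `log_model` now creates a `ModelEventHandler` (defined in the new `snowflake/ml/model/event_handler.py`, +117 lines) and passes the status context down to `ModelManager.log_model` through the new `progress_status` argument. Isolated from the diff, the calling pattern is as follows (a sketch assembled only from the calls visible above):

```python
from snowflake.ml.model import event_handler

# Progress-reporting pattern used by Registry.log_model (sketch).
handler = event_handler.ModelEventHandler()
with handler.status("Logging model", total=6) as status:  # total = number of steps
    status.update("validating model and dependencies...")  # describe the current step
    status.increment()  # advance the progress counter by one step
    # ...remaining steps run inside ModelManager.log_model(progress_status=status)...
    status.update(label="Model logged successfully.", state="complete", expanded=False)
```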

@@ -696,10 +664,10 @@ class Registry:
         self._model_monitor_manager.delete_monitor(name)

     @staticmethod
-    def _targets_warehouse(target_platforms: Optional[list[
+    def _targets_warehouse(target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]) -> bool:
         """Returns True if warehouse is a target platform (None defaults to True)."""
         return (
             target_platforms is None
-            or
+            or type_hints.TargetPlatform.WAREHOUSE in target_platforms
             or "WAREHOUSE" in target_platforms
         )
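
With the overloads now fully typed, a call against the new `log_model` signature might look like the sketch below. It is illustrative only: `session` and the fitted estimator `clf` are assumed, `"WAREHOUSE"` relies on the string form accepted by `target_platforms`, and `save_location` is the option described in the 1.9.0 release notes further down.

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session, database_name="ML", schema_name="PUBLIC")
mv = reg.log_model(
    clf,                             # any type_hints.SupportedModelType
    model_name="MY_MODEL",
    version_name="V1",
    target_platforms=["WAREHOUSE"],  # strings or target_platform.TargetPlatform values
    options={"save_location": "/tmp/model_staging"},  # local staging dir before upload
)
```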
snowflake/ml/version.py
CHANGED

@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.9.0"
+VERSION = "1.9.2"

{snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowflake-ml-python
-Version: 1.9.0
+Version: 1.9.2
 Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
 Author-email: "Snowflake, Inc" <support@snowflake.com>
 License:

@@ -240,9 +240,10 @@ Requires-Dist: cloudpickle>=2.0.0
 Requires-Dist: cryptography
 Requires-Dist: fsspec[http]<2026,>=2024.6.1
 Requires-Dist: importlib_resources<7,>=6.1.1
-Requires-Dist: numpy<
+Requires-Dist: numpy<3,>=1.23
 Requires-Dist: packaging<25,>=20.9
 Requires-Dist: pandas<3,>=2.1.4
+Requires-Dist: platformdirs<5
 Requires-Dist: pyarrow
 Requires-Dist: pydantic<3,>=2.8.2
 Requires-Dist: pyjwt<3,>=2.0.0

@@ -257,6 +258,7 @@ Requires-Dist: snowflake-connector-python[pandas]<4,>=3.15.0
 Requires-Dist: snowflake-snowpark-python!=1.26.0,<2,>=1.17.0
 Requires-Dist: snowflake.core<2,>=1.0.2
 Requires-Dist: sqlparse<1,>=0.4
+Requires-Dist: tqdm<5
 Requires-Dist: typing-extensions<5,>=4.1.0
 Requires-Dist: xgboost<3,>=1.7.3
 Provides-Extra: all

@@ -272,7 +274,7 @@ Requires-Dist: tensorflow<3,>=2.17.0; extra == "all"
 Requires-Dist: tokenizers<1,>=0.15.1; extra == "all"
 Requires-Dist: torch<3,>=2.0.1; extra == "all"
 Requires-Dist: torchdata<1,>=0.4; extra == "all"
-Requires-Dist: transformers
+Requires-Dist: transformers!=4.51.3,<5,>=4.39.3; extra == "all"
 Provides-Extra: altair
 Requires-Dist: altair<6,>=5; extra == "altair"
 Provides-Extra: catboost

@@ -297,7 +299,7 @@ Requires-Dist: sentence-transformers<4,>=2.7.0; extra == "transformers"
 Requires-Dist: sentencepiece<0.2.0,>=0.1.95; extra == "transformers"
 Requires-Dist: tokenizers<1,>=0.15.1; extra == "transformers"
 Requires-Dist: torch<3,>=2.0.1; extra == "transformers"
-Requires-Dist: transformers
+Requires-Dist: transformers!=4.51.3,<5,>=4.39.3; extra == "transformers"
 Dynamic: license-file

 # Snowpark ML

@@ -408,6 +410,92 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio

 # Release History

+## 1.9.2
+
+### Bug Fixes
+
+- DataConnector: Fix `self._session` related errors inside Container Runtime.
+- Registry: Fix a bug when trying to pass `None` to array (`pd.dtype('O')`) in signature and pandas data handler.
+
+### New Features
+
+- Experiment Tracking (PrPr): Automatically log the model, metrics, and parameters while training
+  XGBoost and LightGBM models.
+
+  ```python
+  from snowflake.ml.experiment import ExperimentTracking
+  from snowflake.ml.experiment.callback import SnowflakeXgboostCallback, SnowflakeLightgbmCallback
+
+  exp = ExperimentTracking(session=sp_session, database_name="ML", schema_name="PUBLIC")
+
+  exp.set_experiment("MY_EXPERIMENT")
+
+  # XGBoost
+  callback = SnowflakeXgboostCallback(
+      exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig
+  )
+  model = XGBClassifier(callbacks=[callback])
+  with exp.start_run():
+      model.fit(X, y, eval_set=[(X_test, y_test)])
+
+  # LightGBM
+  callback = SnowflakeLightgbmCallback(
+      exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig
+  )
+  model = LGBMClassifier()
+  with exp.start_run():
+      model.fit(X, y, eval_set=[(X_test, y_test)], callbacks=[callback])
+  ```
+
+## 1.9.1 (07-18-2025)
+
+### Bug Fixes
+
+- Registry: Fix a bug when trying to set the PAD token the HuggingFace `text-generation` model had multiple EOS tokens.
+  The handler picks the first EOS token as PAD token now.
+
+### New Features
+
+- DataConnector: DataConnector objects can now be pickled
+- Dataset: Dataset objects can now be pickled
+- Registry (PrPr): Introducing `create_service` function in `snowflake/ml/model/models/huggingface_pipeline.py`
+  which creates a service to log a HF model and upon successful logging, an inference service is created.
+
+  ```python
+  from snowflake.ml.model.models import huggingface_pipeline
+
+  hf_model_ref = huggingface_pipeline.HuggingFacePipelineModel(
+      model="gpt2",
+      task="text-generation", # Optional
+  )
+
+
+  hf_model_ref.create_service(
+      session=session,
+      service_name="test_service",
+      service_compute_pool="test_compute_pool",
+      image_repo="test_repo",
+      ...
+  )
+  ```
+
+- Experiment Tracking (PrPr): New module for managing and tracking ML experiments in Snowflake.
+
+  ```python
+  from snowflake.ml.experiment import ExperimentTracking
+
+  exp = ExperimentTracking(session=sp_session, database_name="ML", schema_name="PUBLIC")
+
+  exp.set_experiment("MY_EXPERIMENT")
+
+  with exp.start_run():
+      exp.log_param("batch_size", 32)
+      exp.log_metrics("accuracy", 0.98, step=10)
+      exp.log_model(my_model, model_name="MY_MODEL")
+  ```
+
+- Registry: Added support for wide input (500+ features) for inference done using SPCS
+
 ## 1.9.0

 ### Bug Fixes

@@ -415,6 +503,19 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
 - Registry: Fixed bug causing snowpark to pandas dataframe conversion to fail when `QUOTED_IDENTIFIERS_IGNORE_CASE`
   parameter is enabled
 - Registry: Fixed duplicate UserWarning logs during model packaging
+- Registry: If the huggingface pipeline text-generation model doesn't contain a default chat template, a ChatML template
+  is assigned to the tokenizer.
+
+  ```shell
+  {% for message in messages %}
+  {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
+  {% endfor %}
+  {% if add_generation_prompt %}
+  {{ '<|im_start|>assistant\n' }}
+  {% endif %}"
+  ```
+
+- Registry: Fixed SQL queries during registry initialization that were forcing warehouse requirement

 ### Behavior Changes

@@ -524,7 +625,8 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
 - Pre-created Snowpark Session is now available inside job payloads using
   `snowflake.snowpark.context.get_active_session()`
 - Registry: Introducing `save_location` to `log_model` using the `options` argument.
-
+  Users can use the `save_location` option to specify a local directory where the model files and configuration are written.
+  This is useful when the default temporary directory has space limitations.

 ```python
 reg.log_model(