snowflake-ml-python 1.7.4__py3-none-any.whl → 1.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +58 -3
- snowflake/ml/_internal/env_utils.py +64 -21
- snowflake/ml/_internal/file_utils.py +18 -4
- snowflake/ml/_internal/platform_capabilities.py +3 -0
- snowflake/ml/_internal/relax_version_strategy.py +16 -0
- snowflake/ml/_internal/telemetry.py +25 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +1 -1
- snowflake/ml/feature_store/feature_store.py +18 -0
- snowflake/ml/feature_store/feature_view.py +46 -1
- snowflake/ml/fileset/fileset.py +0 -1
- snowflake/ml/jobs/_utils/constants.py +31 -1
- snowflake/ml/jobs/_utils/payload_utils.py +232 -72
- snowflake/ml/jobs/_utils/spec_utils.py +78 -38
- snowflake/ml/jobs/decorators.py +8 -25
- snowflake/ml/jobs/job.py +4 -4
- snowflake/ml/jobs/manager.py +5 -0
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/model_ops.py +107 -14
- snowflake/ml/model/_client/ops/service_ops.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec.py +7 -3
- snowflake/ml/model/_client/sql/model_version.py +58 -0
- snowflake/ml/model/_client/sql/service.py +8 -2
- snowflake/ml/model/_model_composer/model_composer.py +50 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +4 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +0 -1
- snowflake/ml/model/_packager/model_env/model_env.py +49 -29
- snowflake/ml/model/_packager/model_handlers/_utils.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +44 -24
- snowflake/ml/model/_packager/model_handlers/keras.py +226 -0
- snowflake/ml/model/_packager/model_handlers/pytorch.py +51 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +25 -3
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +73 -21
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -72
- snowflake/ml/model/_packager/model_handlers/torchscript.py +49 -20
- snowflake/ml/model/_packager/model_handlers/xgboost.py +2 -2
- snowflake/ml/model/_packager/model_handlers_migrator/pytorch_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2023_12_01.py +48 -0
- snowflake/ml/model/_packager/model_handlers_migrator/tensorflow_migrator_2025_01_01.py +19 -0
- snowflake/ml/model/_packager/model_handlers_migrator/torchscript_migrator_2023_12_01.py +20 -0
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +6 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +16 -0
- snowflake/ml/model/_packager/model_packager.py +3 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -2
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +8 -1
- snowflake/ml/model/_packager/model_task/model_task_utils.py +5 -1
- snowflake/ml/model/_signatures/builtins_handler.py +20 -9
- snowflake/ml/model/_signatures/core.py +54 -33
- snowflake/ml/model/_signatures/dmatrix_handler.py +98 -0
- snowflake/ml/model/_signatures/numpy_handler.py +12 -20
- snowflake/ml/model/_signatures/pandas_handler.py +28 -37
- snowflake/ml/model/_signatures/pytorch_handler.py +57 -41
- snowflake/ml/model/_signatures/snowpark_handler.py +0 -12
- snowflake/ml/model/_signatures/tensorflow_handler.py +61 -67
- snowflake/ml/model/_signatures/utils.py +120 -8
- snowflake/ml/model/custom_model.py +13 -4
- snowflake/ml/model/model_signature.py +39 -13
- snowflake/ml/model/type_hints.py +28 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +14 -1
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/pipeline/pipeline.py +18 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +2 -2
- snowflake/ml/registry/_manager/model_manager.py +55 -7
- snowflake/ml/registry/registry.py +52 -4
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/METADATA +336 -27
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/RECORD +73 -66
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info/licenses}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.4.dist-info → snowflake_ml_python-1.8.0.dist-info}/top_level.txt +0 -0
@@ -98,6 +98,7 @@ def precision_recall_curve(
|
|
98
98
|
packages=[
|
99
99
|
f"cloudpickle=={cloudpickle.__version__}",
|
100
100
|
f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
|
101
|
+
f"numpy=={np.__version__}",
|
101
102
|
"snowflake-snowpark-python",
|
102
103
|
],
|
103
104
|
statement_params=statement_params,
|
@@ -245,6 +246,7 @@ def roc_auc_score(
|
|
245
246
|
packages=[
|
246
247
|
f"cloudpickle=={cloudpickle.__version__}",
|
247
248
|
f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
|
249
|
+
f"numpy=={np.__version__}",
|
248
250
|
"snowflake-snowpark-python",
|
249
251
|
],
|
250
252
|
statement_params=statement_params,
|
@@ -348,6 +350,7 @@ def roc_curve(
|
|
348
350
|
packages=[
|
349
351
|
f"cloudpickle=={cloudpickle.__version__}",
|
350
352
|
f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
|
353
|
+
f"numpy=={np.__version__}",
|
351
354
|
"snowflake-snowpark-python",
|
352
355
|
],
|
353
356
|
statement_params=statement_params,
|
@@ -83,6 +83,7 @@ def d2_absolute_error_score(
|
|
83
83
|
packages=[
|
84
84
|
f"cloudpickle=={cloudpickle.__version__}",
|
85
85
|
f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
|
86
|
+
f"numpy=={np.__version__}",
|
86
87
|
"snowflake-snowpark-python",
|
87
88
|
],
|
88
89
|
statement_params=statement_params,
|
@@ -180,6 +181,7 @@ def d2_pinball_score(
|
|
180
181
|
packages=[
|
181
182
|
f"cloudpickle=={cloudpickle.__version__}",
|
182
183
|
f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
|
184
|
+
f"numpy=={np.__version__}",
|
183
185
|
"snowflake-snowpark-python",
|
184
186
|
],
|
185
187
|
statement_params=statement_params,
|
@@ -295,6 +297,7 @@ def explained_variance_score(
|
|
295
297
|
packages=[
|
296
298
|
f"cloudpickle=={cloudpickle.__version__}",
|
297
299
|
f"scikit-learn=={sklearn_release[0]}.{sklearn_release[1]}.*",
|
300
|
+
f"numpy=={np.__version__}",
|
298
301
|
"snowflake-snowpark-python",
|
299
302
|
],
|
300
303
|
statement_params=statement_params,
|
@@ -854,6 +854,7 @@ class Pipeline(base.BaseTransformer):
|
|
854
854
|
# Create a fitted sklearn pipeline object by translating each non-estimator step in pipeline with
|
855
855
|
# a fitted column transformer.
|
856
856
|
sksteps = []
|
857
|
+
i = 0
|
857
858
|
for i, (name, trans) in enumerate(self._get_transformers()):
|
858
859
|
if isinstance(trans, base.BaseTransformer):
|
859
860
|
trans = self._construct_fitted_column_transformer_object(
|
@@ -899,7 +900,23 @@ class Pipeline(base.BaseTransformer):
|
|
899
900
|
if estimator_step:
|
900
901
|
estimator_signatures = estimator_step[1].model_signatures
|
901
902
|
for method, signature in estimator_signatures.items():
|
902
|
-
|
903
|
+
# Add the inferred input signature to the model signature dictionary for each method
|
904
|
+
self._model_signature_dict[method] = ModelSignature(
|
905
|
+
inputs=inputs_signature,
|
906
|
+
outputs=(
|
907
|
+
# If _drop_input_cols is True, do not include any input columns in the output signature
|
908
|
+
[]
|
909
|
+
if self._drop_input_cols
|
910
|
+
else [
|
911
|
+
# Include input columns in the output signature if they are not already present
|
912
|
+
# Those already present means they are overwritten by the output of the estimator
|
913
|
+
spec
|
914
|
+
for spec in inputs_signature
|
915
|
+
if spec.name not in [_spec.name for _spec in signature.outputs]
|
916
|
+
]
|
917
|
+
)
|
918
|
+
+ signature.outputs, # Append the existing output signature
|
919
|
+
)
|
903
920
|
|
904
921
|
@property
|
905
922
|
def model_signatures(self) -> Dict[str, ModelSignature]:
|
@@ -341,7 +341,7 @@ class KBinsDiscretizer(base.BaseTransformer):
|
|
341
341
|
is_permanent=False,
|
342
342
|
name=udf_name,
|
343
343
|
replace=True,
|
344
|
-
packages=["numpy"],
|
344
|
+
packages=[f"numpy=={np.__version__}"],
|
345
345
|
session=dataset._session,
|
346
346
|
statement_params=telemetry.get_statement_params(base.PROJECT, base.SUBPROJECT, self.__class__.__name__),
|
347
347
|
)
|
@@ -337,7 +337,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
337
337
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
338
338
|
|
339
339
|
if isinstance(dataset, DataFrame):
|
340
|
-
expected_type_inferred = ""
|
340
|
+
expected_type_inferred = "float"
|
341
341
|
# when it is classifier, infer the datatype from label columns
|
342
342
|
if expected_type_inferred == "" and 'predict' in self.model_signatures:
|
343
343
|
# Batch inference takes a single expected output column type. Use the first columns type for now.
|
@@ -415,7 +415,7 @@ class PolynomialFeatures(BaseTransformer):
|
|
415
415
|
# are specific to the type of dataset used.
|
416
416
|
transform_kwargs: BatchInferenceKwargsTypedDict = dict()
|
417
417
|
if isinstance(dataset, DataFrame):
|
418
|
-
expected_dtype = ""
|
418
|
+
expected_dtype = "float"
|
419
419
|
if False: # is child of _BaseHeterogeneousEnsemble
|
420
420
|
# transform() method of HeterogeneousEnsemble estimators return responses of varying shapes
|
421
421
|
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes) (and everything in between)
|
@@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|
4
4
|
import pandas as pd
|
5
5
|
from absl.logging import logging
|
6
6
|
|
7
|
-
from snowflake.ml._internal import telemetry
|
7
|
+
from snowflake.ml._internal import platform_capabilities, telemetry
|
8
8
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
9
9
|
from snowflake.ml._internal.human_readable_id import hrid_generator
|
10
10
|
from snowflake.ml._internal.utils import sql_identifier
|
@@ -13,7 +13,7 @@ from snowflake.ml.model._client.model import model_impl, model_version_impl
|
|
13
13
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
15
15
|
from snowflake.ml.model._packager.model_meta import model_meta
|
16
|
-
from snowflake.snowpark import session
|
16
|
+
from snowflake.snowpark import exceptions as snowpark_exceptions, session
|
17
17
|
|
18
18
|
logger = logging.getLogger(__name__)
|
19
19
|
|
@@ -46,6 +46,7 @@ class ModelManager:
|
|
46
46
|
metrics: Optional[Dict[str, Any]] = None,
|
47
47
|
conda_dependencies: Optional[List[str]] = None,
|
48
48
|
pip_requirements: Optional[List[str]] = None,
|
49
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
49
50
|
target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
|
50
51
|
python_version: Optional[str] = None,
|
51
52
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
@@ -127,6 +128,7 @@ class ModelManager:
|
|
127
128
|
metrics=metrics,
|
128
129
|
conda_dependencies=conda_dependencies,
|
129
130
|
pip_requirements=pip_requirements,
|
131
|
+
artifact_repository_map=artifact_repository_map,
|
130
132
|
target_platforms=target_platforms,
|
131
133
|
python_version=python_version,
|
132
134
|
signatures=signatures,
|
@@ -149,6 +151,7 @@ class ModelManager:
|
|
149
151
|
metrics: Optional[Dict[str, Any]] = None,
|
150
152
|
conda_dependencies: Optional[List[str]] = None,
|
151
153
|
pip_requirements: Optional[List[str]] = None,
|
154
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
152
155
|
target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
|
153
156
|
python_version: Optional[str] = None,
|
154
157
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
@@ -163,11 +166,42 @@ class ModelManager:
|
|
163
166
|
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
164
167
|
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
165
168
|
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
169
|
+
use_live_commit = platform_capabilities.PlatformCapabilities.get_instance().is_live_commit_enabled()
|
170
|
+
if use_live_commit:
|
171
|
+
logger.info("Using live commit model version")
|
172
|
+
else:
|
173
|
+
logger.info("Using non-live commit model version")
|
174
|
+
|
175
|
+
if use_live_commit:
|
176
|
+
# This step creates the live model version, and the files can be written directly to the stage
|
177
|
+
# after this.
|
178
|
+
try:
|
179
|
+
self._model_ops.add_or_create_live_version(
|
180
|
+
database_name=database_name_id,
|
181
|
+
schema_name=schema_name_id,
|
182
|
+
model_name=model_name_id,
|
183
|
+
version_name=version_name_id,
|
184
|
+
statement_params=statement_params,
|
185
|
+
)
|
186
|
+
except (AssertionError, snowpark_exceptions.SnowparkSQLException) as e:
|
187
|
+
logger.info(f"Failed to create live model version: {e}, falling back to regular model version creation")
|
188
|
+
use_live_commit = False
|
189
|
+
|
190
|
+
if use_live_commit:
|
191
|
+
# using model version's stage path to write files directly to the stage
|
192
|
+
stage_path = self._model_ops.get_model_version_stage_path(
|
193
|
+
database_name=database_name_id,
|
194
|
+
schema_name=schema_name_id,
|
195
|
+
model_name=model_name_id,
|
196
|
+
version_name=version_name_id,
|
197
|
+
)
|
198
|
+
else:
|
199
|
+
# using a temp path to write files and then upload to the model version's stage
|
200
|
+
stage_path = self._model_ops.prepare_model_temp_stage_path(
|
201
|
+
database_name=database_name_id,
|
202
|
+
schema_name=schema_name_id,
|
203
|
+
statement_params=statement_params,
|
204
|
+
)
|
171
205
|
|
172
206
|
platforms = None
|
173
207
|
# User specified target platforms are defaulted to None and will not show up in the generated manifest.
|
@@ -175,6 +209,18 @@ class ModelManager:
|
|
175
209
|
# Convert any string target platforms to TargetPlatform objects
|
176
210
|
platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
|
177
211
|
|
212
|
+
if artifact_repository_map:
|
213
|
+
for channel, artifact_repository_name in artifact_repository_map.items():
|
214
|
+
db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
|
215
|
+
|
216
|
+
artifact_repository_map[channel] = sql_identifier.get_fully_qualified_name(
|
217
|
+
db_id,
|
218
|
+
schema_id,
|
219
|
+
repo_id,
|
220
|
+
self._database_name,
|
221
|
+
self._schema_name,
|
222
|
+
)
|
223
|
+
|
178
224
|
logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
|
179
225
|
|
180
226
|
mc = model_composer.ModelComposer(
|
@@ -187,6 +233,7 @@ class ModelManager:
|
|
187
233
|
sample_input_data=sample_input_data,
|
188
234
|
conda_dependencies=conda_dependencies,
|
189
235
|
pip_requirements=pip_requirements,
|
236
|
+
artifact_repository_map=artifact_repository_map,
|
190
237
|
target_platforms=platforms,
|
191
238
|
python_version=python_version,
|
192
239
|
user_files=user_files,
|
@@ -211,6 +258,7 @@ class ModelManager:
|
|
211
258
|
model_name=model_name_id,
|
212
259
|
version_name=version_name_id,
|
213
260
|
statement_params=statement_params,
|
261
|
+
use_live_commit=use_live_commit,
|
214
262
|
)
|
215
263
|
|
216
264
|
mv = model_version_impl.ModelVersion._ref(
|
@@ -78,7 +78,7 @@ class Registry:
|
|
78
78
|
session, database_name=self._database_name, schema_name=self._schema_name
|
79
79
|
)
|
80
80
|
|
81
|
-
self.enable_monitoring = options.get("enable_monitoring",
|
81
|
+
self.enable_monitoring = options.get("enable_monitoring", True) if options else True
|
82
82
|
if self.enable_monitoring:
|
83
83
|
monitor_statement_params = telemetry.get_statement_params(
|
84
84
|
project=telemetry.TelemetryProject.MLOPS.value,
|
@@ -108,6 +108,7 @@ class Registry:
|
|
108
108
|
metrics: Optional[Dict[str, Any]] = None,
|
109
109
|
conda_dependencies: Optional[List[str]] = None,
|
110
110
|
pip_requirements: Optional[List[str]] = None,
|
111
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
111
112
|
target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
|
112
113
|
python_version: Optional[str] = None,
|
113
114
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
@@ -140,6 +141,13 @@ class Registry:
|
|
140
141
|
See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
|
141
142
|
Models with pip requirements specified will not be executable in Snowflake Warehouse where all
|
142
143
|
dependencies must be retrieved from Snowflake Anaconda Channel.
|
144
|
+
artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
|
145
|
+
repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
|
146
|
+
Note: This feature is currently in Private Preview; please contact your Snowflake account team
|
147
|
+
to enable it.
|
148
|
+
Format: {channel_name: artifact_repository_name}, where:
|
149
|
+
- channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
|
150
|
+
- artifact_repository_name: The name or URL of the repository to fetch packages from.
|
143
151
|
target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
|
144
152
|
{"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
|
145
153
|
python_version: Python version in which the model is run. Defaults to None.
|
@@ -162,8 +170,12 @@ class Registry:
|
|
162
170
|
- relax_version: Whether to relax the version constraints of the dependencies when running in the
|
163
171
|
Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
|
164
172
|
- function_type: Set the method function type globally. To set method function types individually see
|
165
|
-
|
166
|
-
- method_options: Per-method saving options
|
173
|
+
function_type in model_options.
|
174
|
+
- method_options: Per-method saving options. This dictionary has method names as keys and dictionary
|
175
|
+
values with the desired options.
|
176
|
+
|
177
|
+
The following are the available method options:
|
178
|
+
|
167
179
|
- case_sensitive: Indicates whether the method and its signature should be case sensitive.
|
168
180
|
This means when you refer the method in the SQL, you need to double quote it.
|
169
181
|
This will be helpful if you need case to tell apart your methods or features, or you have
|
@@ -206,6 +218,7 @@ class Registry:
|
|
206
218
|
"metrics",
|
207
219
|
"conda_dependencies",
|
208
220
|
"pip_requirements",
|
221
|
+
"artifact_repository_map",
|
209
222
|
"target_platforms",
|
210
223
|
"python_version",
|
211
224
|
"signatures",
|
@@ -221,6 +234,7 @@ class Registry:
|
|
221
234
|
metrics: Optional[Dict[str, Any]] = None,
|
222
235
|
conda_dependencies: Optional[List[str]] = None,
|
223
236
|
pip_requirements: Optional[List[str]] = None,
|
237
|
+
artifact_repository_map: Optional[Dict[str, str]] = None,
|
224
238
|
target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
|
225
239
|
python_version: Optional[str] = None,
|
226
240
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
@@ -255,6 +269,13 @@ class Registry:
|
|
255
269
|
See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
|
256
270
|
Models with pip requirements specified will not be executable in Snowflake Warehouse where all
|
257
271
|
dependencies must be retrieved from Snowflake Anaconda Channel.
|
272
|
+
artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
|
273
|
+
repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
|
274
|
+
Note: This feature is currently in Private Preview; please contact your Snowflake account team to
|
275
|
+
enable it.
|
276
|
+
Format: {channel_name: artifact_repository_name}, where:
|
277
|
+
- channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
|
278
|
+
- artifact_repository_name: The name or URL of the repository to fetch packages from.
|
258
279
|
target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
|
259
280
|
{"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
|
260
281
|
python_version: Python version in which the model is run. Defaults to None.
|
@@ -283,7 +304,11 @@ class Registry:
|
|
283
304
|
Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
|
284
305
|
- function_type: Set the method function type globally. To set method function types individually see
|
285
306
|
function_type in model_options.
|
286
|
-
- method_options: Per-method saving options
|
307
|
+
- method_options: Per-method saving options. This dictionary has method names as keys and dictionary
|
308
|
+
values with the desired options. See the example below.
|
309
|
+
|
310
|
+
The following are the available method options:
|
311
|
+
|
287
312
|
- case_sensitive: Indicates whether the method and its signature should be case sensitive.
|
288
313
|
This means when you refer the method in the SQL, you need to double quote it.
|
289
314
|
This will be helpful if you need case to tell apart your methods or features, or you have
|
@@ -294,6 +319,28 @@ class Registry:
|
|
294
319
|
|
295
320
|
Returns:
|
296
321
|
ModelVersion: ModelVersion object corresponding to the model just logged.
|
322
|
+
|
323
|
+
Example::
|
324
|
+
|
325
|
+
from snowflake.ml.registry import Registry
|
326
|
+
|
327
|
+
# create a session
|
328
|
+
session = ...
|
329
|
+
|
330
|
+
registry = Registry(session=session)
|
331
|
+
|
332
|
+
# Define `method_options` for each inference method if needed.
|
333
|
+
method_options={
|
334
|
+
"predict": {
|
335
|
+
"case_sensitive": True
|
336
|
+
}
|
337
|
+
}
|
338
|
+
|
339
|
+
registry.log_model(
|
340
|
+
model=model,
|
341
|
+
model_name="my_model",
|
342
|
+
method_options=method_options,
|
343
|
+
)
|
297
344
|
"""
|
298
345
|
statement_params = telemetry.get_statement_params(
|
299
346
|
project=_TELEMETRY_PROJECT,
|
@@ -315,6 +362,7 @@ class Registry:
|
|
315
362
|
metrics=metrics,
|
316
363
|
conda_dependencies=conda_dependencies,
|
317
364
|
pip_requirements=pip_requirements,
|
365
|
+
artifact_repository_map=artifact_repository_map,
|
318
366
|
target_platforms=target_platforms,
|
319
367
|
python_version=python_version,
|
320
368
|
signatures=signatures,
|
snowflake/ml/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
VERSION="1.7.4"
|
1
|
+
VERSION="1.8.0"
|