snowflake-ml-python 1.9.0__py3-none-any.whl → 1.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +54 -42
- snowflake/ml/_internal/utils/service_logger.py +105 -3
- snowflake/ml/data/_internal/arrow_ingestor.py +15 -2
- snowflake/ml/data/data_connector.py +13 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +2 -1
- snowflake/ml/dataset/dataset_reader.py +14 -4
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/callback.py +121 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +156 -54
- snowflake/ml/jobs/_utils/query_helper.py +16 -5
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +130 -23
- snowflake/ml/jobs/_utils/spec_utils.py +23 -8
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +70 -75
- snowflake/ml/jobs/manager.py +59 -31
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/service_ops.py +336 -137
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -1
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +17 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +11 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +32 -15
- snowflake/ml/registry/registry.py +48 -80
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/METADATA +107 -5
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/RECORD +62 -52
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.0.dist-info → snowflake_ml_python-1.9.2.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/model/models/huggingface_pipeline.py
+++ b/snowflake/ml/model/models/huggingface_pipeline.py
@@ -1,8 +1,22 @@
+import logging
 import warnings
-from typing import Any, Optional
+from typing import Any, Optional, Union

 from packaging import version

+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.human_readable_id import hrid_generator
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model._client.ops import service_ops
+from snowflake.snowpark import async_job, session
+
+logger = logging.getLogger(__name__)
+
+
+_TELEMETRY_PROJECT = "MLOps"
+_TELEMETRY_SUBPROJECT = "ModelManagement"
+

 class HuggingFacePipelineModel:
     def __init__(
@@ -214,4 +228,159 @@ class HuggingFacePipelineModel:
         self.token = token
         self.trust_remote_code = trust_remote_code
         self.model_kwargs = model_kwargs
+        self.tokenizer = tokenizer
         self.__dict__.update(kwargs)
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+        func_params_to_log=[
+            "service_name",
+            "image_build_compute_pool",
+            "service_compute_pool",
+            "image_repo",
+            "gpu_requests",
+            "num_workers",
+            "max_batch_rows",
+        ],
+    )
+    @snowpark._internal.utils.private_preview(version="1.9.1")
+    def create_service(
+        self,
+        *,
+        session: session.Session,
+        # registry.log_model parameters
+        model_name: str,
+        version_name: Optional[str] = None,
+        pip_requirements: Optional[list[str]] = None,
+        conda_dependencies: Optional[list[str]] = None,
+        comment: Optional[str] = None,
+        # model_version_impl.create_service parameters
+        service_name: str,
+        service_compute_pool: str,
+        image_repo: str,
+        image_build_compute_pool: Optional[str] = None,
+        ingress_enabled: bool = False,
+        max_instances: int = 1,
+        cpu_requests: Optional[str] = None,
+        memory_requests: Optional[str] = None,
+        gpu_requests: Optional[Union[str, int]] = None,
+        num_workers: Optional[int] = None,
+        max_batch_rows: Optional[int] = None,
+        force_rebuild: bool = False,
+        build_external_access_integrations: Optional[list[str]] = None,
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
+        """Logs a Hugging Face model and creates a service in Snowflake.
+
+        Args:
+            session: The Snowflake session object.
+            model_name: The name of the model in Snowflake.
+            version_name: The version name of the model. Defaults to None.
+            pip_requirements: Pip requirements for the model. Defaults to None.
+            conda_dependencies: Conda dependencies for the model. Defaults to None.
+            comment: Comment for the model. Defaults to None.
+            service_name: The name of the service to create.
+            service_compute_pool: The compute pool for the service.
+            image_repo: The name of the image repository.
+            image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
+                the service compute pool if None.
+            ingress_enabled: Whether ingress is enabled. Defaults to False.
+            max_instances: Maximum number of instances. Defaults to 1.
+            cpu_requests: CPU requests configuration. Defaults to None.
+            memory_requests: Memory requests configuration. Defaults to None.
+            gpu_requests: GPU requests configuration. Defaults to None.
+            num_workers: Number of workers. Defaults to None.
+            max_batch_rows: Maximum batch rows. Defaults to None.
+            force_rebuild: Whether to force rebuild the image. Defaults to False.
+            build_external_access_integrations: External access integrations for building the image. Defaults to None.
+            block: Whether to block the operation. Defaults to True.
+
+        Raises:
+            ValueError: if database and schema name is not provided and session doesn't have a
+                database and schema name.
+
+        Returns:
+            The service ID or an async job object.
+
+        .. # noqa: DAR003
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+
+        database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
+        session_database_name = session.get_current_database()
+        session_schema_name = session.get_current_schema()
+        if database_name_id is None:
+            if session_database_name is None:
+                raise ValueError("Either database needs to be provided or needs to be available in session.")
+            database_name_id = sql_identifier.SqlIdentifier(session_database_name)
+        if schema_name_id is None:
+            if session_schema_name is None:
+                raise ValueError("Either schema needs to be provided or needs to be available in session.")
+            schema_name_id = sql_identifier.SqlIdentifier(session_schema_name)
+
+        if version_name is None:
+            name_generator = hrid_generator.HRID16()
+            version_name = name_generator.generate()[1]
+
+        service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
+        image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
+
+        service_operator = service_ops.ServiceOperator(
+            session=session,
+            database_name=database_name_id,
+            schema_name=schema_name_id,
+        )
+        logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
+
+        return service_operator.create_service(
+            database_name=database_name_id,
+            schema_name=schema_name_id,
+            model_name=model_name_id,
+            version_name=sql_identifier.SqlIdentifier(version_name),
+            service_database_name=service_db_id,
+            service_schema_name=service_schema_id,
+            service_name=service_id,
+            image_build_compute_pool_name=(
+                sql_identifier.SqlIdentifier(image_build_compute_pool)
+                if image_build_compute_pool
+                else sql_identifier.SqlIdentifier(service_compute_pool)
+            ),
+            service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+            image_repo_database_name=image_repo_db_id,
+            image_repo_schema_name=image_repo_schema_id,
+            image_repo_name=image_repo_id,
+            ingress_enabled=ingress_enabled,
+            max_instances=max_instances,
+            cpu_requests=cpu_requests,
+            memory_requests=memory_requests,
+            gpu_requests=gpu_requests,
+            num_workers=num_workers,
+            max_batch_rows=max_batch_rows,
+            force_rebuild=force_rebuild,
+            build_external_access_integrations=(
+                None
+                if build_external_access_integrations is None
+                else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+            ),
+            block=block,
+            statement_params=statement_params,
+            # hf model
+            hf_model_args=service_ops.HFModelArgs(
+                hf_model_name=self.model,
+                hf_task=self.task,
+                hf_tokenizer=self.tokenizer,
+                hf_revision=self.revision,
+                hf_token=self.token,
+                hf_trust_remote_code=bool(self.trust_remote_code),
+                hf_model_kwargs=self.model_kwargs,
+                pip_requirements=pip_requirements,
+                conda_dependencies=conda_dependencies,
+                comment=comment,
+                # TODO: remove warehouse in the next release
+                warehouse=session.get_current_warehouse(),
+            ),
+        )
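To make the new API above concrete, here is a minimal usage sketch. It is not taken from the package documentation: the connection parameters, model/service/repo names, and compute pool are placeholders, and `create_service` is marked private preview in 1.9.1, so the exact surface may change.

```python
# Hypothetical usage sketch of the new HuggingFacePipelineModel.create_service API
# (private preview in 1.9.1). All names below are placeholders, not values from this diff.
from snowflake.snowpark import Session
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel

connection_parameters = {"account": "<account>", "user": "<user>", "password": "<password>"}
session = Session.builder.configs(connection_parameters).create()

pipe = HuggingFacePipelineModel(task="text-generation", model="openai-community/gpt2")

# Logs the model (a version name is auto-generated when omitted) and deploys it as a service.
result = pipe.create_service(
    session=session,
    model_name="MY_DB.MY_SCHEMA.GPT2_MODEL",
    service_name="GPT2_SERVICE",
    service_compute_pool="MY_GPU_POOL",
    image_repo="MY_DB.MY_SCHEMA.MY_IMAGE_REPO",
    gpu_requests=1,
    block=False,  # returns an AsyncJob instead of waiting for the service to come up
)
```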
--- a/snowflake/ml/modeling/framework/base.py
+++ b/snowflake/ml/modeling/framework/base.py
@@ -698,7 +698,7 @@ class BaseTransformer(BaseEstimator):
         self,
         attribute: Optional[Mapping[str, Union[int, float, str, Iterable[Union[int, float, str]]]]],
         dtype: Optional[type] = None,
-    ) -> Optional[npt.NDArray[Union[np.int_, np.
+    ) -> Optional[npt.NDArray[Union[np.int_, np.float64, np.str_]]]:
         """
         Convert the attribute from dict to ndarray based on the order of `self.input_cols`.

--- a/snowflake/ml/modeling/metrics/classification.py
+++ b/snowflake/ml/modeling/metrics/classification.py
@@ -96,7 +96,7 @@ def confusion_matrix(
     labels: Optional[npt.ArrayLike] = None,
     sample_weight_col_name: Optional[str] = None,
     normalize: Optional[str] = None,
-) -> Union[npt.NDArray[np.int_], npt.NDArray[np.
+) -> Union[npt.NDArray[np.int_], npt.NDArray[np.float64]]:
     """
     Compute confusion matrix to evaluate the accuracy of a classification.

@@ -320,7 +320,7 @@ def f1_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the F1 score, also known as balanced F-score or F-measure.

@@ -414,7 +414,7 @@ def fbeta_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the F-beta score.

@@ -696,7 +696,7 @@ def precision_recall_fscore_support(
     zero_division: Union[str, int] = "warn",
 ) -> Union[
     tuple[float, float, float, None],
-    tuple[npt.NDArray[np.
+    tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]],
 ]:
     """
     Compute precision, recall, F-measure and support for each class.
@@ -855,7 +855,7 @@ def precision_recall_fscore_support(

     res: Union[
         tuple[float, float, float, None],
-        tuple[npt.NDArray[np.
+        tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]],
     ] = result_object[:4]
     warning = result_object[-1]
     if warning:
@@ -1050,7 +1050,7 @@ def _register_multilabel_confusion_matrix_computer(

         def end_partition(
             self,
-        ) -> Iterable[tuple[npt.NDArray[np.
+        ) -> Iterable[tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]]:
             MCM = metrics.multilabel_confusion_matrix(
                 self._y_true,
                 self._y_pred,
@@ -1098,7 +1098,7 @@ def _binary_precision_score(
     pos_label: Union[str, int] = 1,
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:

     statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)

@@ -1173,7 +1173,7 @@ def precision_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the precision.

@@ -1271,7 +1271,7 @@ def recall_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the recall.

@@ -1406,14 +1406,14 @@ def _check_binary_labels(


 def _prf_divide(
-    numerator: npt.NDArray[np.
-    denominator: npt.NDArray[np.
+    numerator: npt.NDArray[np.float64],
+    denominator: npt.NDArray[np.float64],
     metric: str,
     modifier: str,
     average: Optional[str] = None,
     warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
     zero_division: Union[str, int] = "warn",
-) -> npt.NDArray[np.
+) -> npt.NDArray[np.float64]:
     """Performs division and handles divide-by-zero.

     On zero-division, sets the corresponding result elements equal to
@@ -1436,7 +1436,7 @@ def _prf_divide(
         "warn", this acts as 0, but warnings are also raised.

     Returns:
-        npt.NDArray[np.
+        npt.NDArray[np.float64]: Result of the division, an array of floats.
     """
     mask = denominator == 0.0
     denominator = denominator.copy()
@@ -1522,7 +1522,7 @@ def _check_zero_division(zero_division: Union[int, float, str]) -> float:
         return np.nan


-def _nanaverage(a: npt.NDArray[np.
+def _nanaverage(a: npt.NDArray[np.float64], weights: Optional[npt.ArrayLike] = None) -> Any:
     """Compute the weighted average, ignoring NaNs.

     Args:
--- a/snowflake/ml/modeling/metrics/correlation.py
+++ b/snowflake/ml/modeling/metrics/correlation.py
@@ -26,7 +26,7 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
     The below steps explain how correlation matrix is computed in a distributed way:
     Let n = # of rows in the dataframe; sqrt_n = sqrt(n); X, Y are 2 columns in the dataframe
     Correlation(X, Y) = numerator/denominator where
-    numerator = dot(X/sqrt_n, Y/sqrt_n) - sum(X/n)*sum(
+    numerator = dot(X/sqrt_n, Y/sqrt_n) - sum(X/n)*sum(Y/n)
     denominator = std_dev(X)*std_dev(Y)
     std_dev(X) = sqrt(dot(X/sqrt_n, X/sqrt_n) - sum(X/n)*sum(X/n))

@@ -74,27 +74,38 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
     # Pushing this to a udtf requires creating a temp udtf which takes about 20 secs, so it doesn't make sense
     # to have this in a udtf.
     n_cols = len(columns)
-
-
+    column_means = np.zeros(n_cols)
+    mean_of_squares = np.zeros(n_cols)
     dot_prod = np.zeros((n_cols, n_cols))
     # Get sum, dot_prod and squared sum array from the results.
     for i in range(len(results)):
         x = results[i]
         if x[1] == "sum_by_count":
-
+            column_means = cloudpickle.loads(x[0])
         else:
             row = int(x[1].strip("row_"))
             dot_prod[row, :] = cloudpickle.loads(x[0])
-
+            mean_of_squares[row] = dot_prod[row, row]

     # sum(X/n)*sum(Y/n) is computed for all combinations of X,Y (columns in the dataframe)
-    exey_arr = np.einsum("t,m->tm",
+    exey_arr = np.einsum("t,m->tm", column_means, column_means, optimize="optimal")
     numerator_matrix = dot_prod - exey_arr

     # standard deviation for all columns in the dataframe
-
+    variance_arr = mean_of_squares - np.einsum("i, i -> i", column_means, column_means, optimize="optimal")
+    # ensure non-negative values from potential precision issues where variance might be slightly negative
+    variance_arr = np.maximum(variance_arr, 0)
+    stddev_arr = np.sqrt(variance_arr)
     # std_dev(X)*std_dev(Y) is computed for all combinations of X,Y (columns in the dataframe)
     denominator_matrix = np.einsum("t,m->tm", stddev_arr, stddev_arr, optimize="optimal")
-
+
+    # Use np.divide to handle NaN cases
+    corr_res = np.divide(
+        numerator_matrix,
+        denominator_matrix,
+        out=np.full_like(numerator_matrix, np.nan),
+        where=(denominator_matrix != 0),
+    )
+
     correlation_matrix = pd.DataFrame(corr_res, columns=columns, index=columns)
     return correlation_matrix
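As a sanity check on the decomposition described in the docstring hunk above, the following standalone NumPy sketch (not part of the package) applies the same numerator/denominator formula and the new zero-denominator guard to a local array and compares the result with `np.corrcoef`.

```python
# Standalone sketch of the correlation decomposition used above; not part of snowflake-ml-python.
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 3))        # rows x columns, stands in for the Snowpark DataFrame
n = data.shape[0]
sqrt_n = np.sqrt(n)

column_means = data.sum(axis=0) / n                  # sum(X/n) per column
dot_prod = (data / sqrt_n).T @ (data / sqrt_n)       # dot(X/sqrt_n, Y/sqrt_n) for all column pairs
mean_of_squares = np.diag(dot_prod)                  # dot(X/sqrt_n, X/sqrt_n) per column

numerator_matrix = dot_prod - np.outer(column_means, column_means)
variance_arr = np.maximum(mean_of_squares - column_means**2, 0)   # clamp tiny negative values
denominator_matrix = np.outer(np.sqrt(variance_arr), np.sqrt(variance_arr))

# Guarded division: NaN where a column has zero variance, mirroring the np.divide(where=...) change.
corr_res = np.divide(
    numerator_matrix,
    denominator_matrix,
    out=np.full_like(numerator_matrix, np.nan),
    where=(denominator_matrix != 0),
)

assert np.allclose(corr_res, np.corrcoef(data, rowvar=False))
```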
--- a/snowflake/ml/modeling/metrics/ranking.py
+++ b/snowflake/ml/modeling/metrics/ranking.py
@@ -26,7 +26,7 @@ def precision_recall_curve(
     probas_pred_col_name: str,
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
-) -> tuple[npt.NDArray[np.
+) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]:
     """
     Compute precision-recall pairs for different probability thresholds.

@@ -125,7 +125,7 @@ def precision_recall_curve(

     kwargs = telemetry.get_sproc_statement_params_kwargs(precision_recall_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session, **kwargs))
-    res: tuple[npt.NDArray[np.
+    res: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] = result_object
     return res


@@ -140,7 +140,7 @@ def roc_auc_score(
     max_fpr: Optional[float] = None,
     multi_class: str = "raise",
     labels: Optional[npt.ArrayLike] = None,
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
     from prediction scores.
@@ -276,7 +276,7 @@ def roc_auc_score(

     kwargs = telemetry.get_sproc_statement_params_kwargs(roc_auc_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, roc_auc_score_anon_sproc(session, **kwargs))
-    auc: Union[float, npt.NDArray[np.
+    auc: Union[float, npt.NDArray[np.float64]] = result_object
     return auc


@@ -289,7 +289,7 @@ def roc_curve(
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
     drop_intermediate: bool = True,
-) -> tuple[npt.NDArray[np.
+) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]:
     """
     Compute Receiver operating characteristic (ROC).

@@ -380,6 +380,6 @@ def roc_curve(
     kwargs = telemetry.get_sproc_statement_params_kwargs(roc_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, roc_curve_anon_sproc(session, **kwargs))

-    res: tuple[npt.NDArray[np.
+    res: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] = result_object

     return res
--- a/snowflake/ml/modeling/metrics/regression.py
+++ b/snowflake/ml/modeling/metrics/regression.py
@@ -29,7 +29,7 @@ def d2_absolute_error_score(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     :math:`D^2` regression score function, \
     fraction of absolute error explained.
@@ -111,7 +111,7 @@ def d2_absolute_error_score(

     kwargs = telemetry.get_sproc_statement_params_kwargs(d2_absolute_error_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session, **kwargs))
-    score: Union[float, npt.NDArray[np.
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -124,7 +124,7 @@ def d2_pinball_score(
     sample_weight_col_name: Optional[str] = None,
     alpha: float = 0.5,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     :math:`D^2` regression score function, fraction of pinball loss explained.

@@ -211,7 +211,7 @@ def d2_pinball_score(
     kwargs = telemetry.get_sproc_statement_params_kwargs(d2_pinball_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session, **kwargs))

-    score: Union[float, npt.NDArray[np.
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -224,7 +224,7 @@ def explained_variance_score(
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     force_finite: bool = True,
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Explained variance regression score function.

@@ -326,7 +326,7 @@ def explained_variance_score(

     kwargs = telemetry.get_sproc_statement_params_kwargs(explained_variance_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, explained_variance_score_anon_sproc(session, **kwargs))
-    score: Union[float, npt.NDArray[np.
+    score: Union[float, npt.NDArray[np.float64]] = result_object
     return score


@@ -338,7 +338,7 @@ def mean_absolute_error(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean absolute error regression loss.

@@ -411,7 +411,7 @@ def mean_absolute_percentage_error(
     y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean absolute percentage error (MAPE) regression loss.

@@ -495,7 +495,7 @@ def mean_squared_error(
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     squared: bool = True,
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Mean squared error regression loss.

--- a/snowflake/ml/monitoring/explain_visualize.py
+++ b/snowflake/ml/monitoring/explain_visualize.py
@@ -264,6 +264,7 @@ def plot_force(
 def plot_influence_sensitivity(
     shap_values: type_hints.SupportedDataType,
     feature_values: type_hints.SupportedDataType,
+    infer_is_categorical: bool = True,
     figsize: tuple[float, float] = DEFAULT_FIGSIZE,
 ) -> Any:
     """
@@ -274,6 +275,8 @@ def plot_influence_sensitivity(
     Args:
         shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
         feature_values: pandas Series or 2D array containing the feature values for the same feature
+        infer_is_categorical: If True, the function will infer if the feature is categorical
+            based on the number of unique values.
         figsize: tuple of (width, height) for the plot

     Returns:
@@ -294,7 +297,7 @@ def plot_influence_sensitivity(
     elif feature_values_df.shape[0] != shap_values_df.shape[0]:
         raise ValueError("Feature values and SHAP values must have the same number of rows.")

-    scatter = _create_scatter_plot(feature_values, shap_values, figsize)
+    scatter = _create_scatter_plot(feature_values, shap_values, infer_is_categorical, figsize)
     return st.altair_chart(scatter) if use_streamlit else scatter


@@ -322,11 +325,13 @@ def _prepare_feature_values_for_streamlit(
     return feature_values, shap_values, st


-def _create_scatter_plot(
+def _create_scatter_plot(
+    feature_values: pd.Series, shap_values: pd.Series, infer_is_categorical: bool, figsize: tuple[float, float]
+) -> alt.Chart:
     unique_vals = np.sort(np.unique(feature_values.values))
     max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
     points_per_value = len(feature_values.values) / len(unique_vals)
-    is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10
+    is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10 if infer_is_categorical else False

     kwargs = (
         {
@@ -403,9 +408,11 @@ def plot_violin(
         .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
         .mark_area(orient="vertical")
         .encode(
-            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=
+            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=False),
             x=alt.X("shap_value:Q", title="SHAP Value"),
-            row=alt.Row(
+            row=alt.Row(
+                "feature_name:N", sort=column_sort_order, header=alt.Header(labelAngle=0, labelAlign="left")
+            ).spacing(0),
             color=alt.Color("feature_name:N", legend=None),
             tooltip=["feature_name", "shap_value"],
         )
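The `infer_is_categorical` flag added above gates a point-density heuristic inside `_create_scatter_plot`. A standalone sketch of that heuristic follows; the function name is illustrative, and only the threshold logic comes from the diff.

```python
# Standalone sketch of the categorical-inference heuristic gated by infer_is_categorical.
# looks_categorical is an illustrative name; the threshold logic mirrors _create_scatter_plot above.
import numpy as np
import pandas as pd


def looks_categorical(feature_values: pd.Series, infer_is_categorical: bool = True) -> bool:
    if not infer_is_categorical:
        return False  # callers can force a continuous (numeric) scatter encoding
    unique_vals = np.sort(np.unique(feature_values.values))
    # Largest number of rows sharing one value, and the average number of rows per unique value.
    max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
    points_per_value = len(feature_values.values) / len(unique_vals)
    return float(max(max_points_per_unique_value, points_per_value)) > 10


print(looks_categorical(pd.Series(["a", "b", "a"] * 20)))                       # True: few distinct values
print(looks_categorical(pd.Series(np.random.default_rng(0).normal(size=50))))  # False: mostly unique floats
```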
--- a/snowflake/ml/registry/_manager/model_manager.py
+++ b/snowflake/ml/registry/_manager/model_manager.py
@@ -1,5 +1,5 @@
 from types import ModuleType
-from typing import Any, Optional,
+from typing import TYPE_CHECKING, Any, Optional, Union

 import pandas as pd
 from absl.logging import logging
@@ -17,15 +17,10 @@ from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import exceptions as snowpark_exceptions, session
 from snowflake.snowpark._internal import utils as snowpark_utils

-
-
+if TYPE_CHECKING:
+    from snowflake.ml.experiment._experiment_info import ExperimentInfo

-
-    """Protocol defining the interface for event handlers used during model operations."""
-
-    def update(self, message: str) -> None:
-        """Update with a progress message."""
-        ...
+logger = logging.getLogger(__name__)


 class ModelManager:
@@ -66,9 +61,10 @@ class ModelManager:
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
         task: type_hints.Task = task.Task.UNKNOWN,
+        experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
-
+        progress_status: Optional[Any] = None,
     ) -> model_version_impl.ModelVersion:

         database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -150,9 +146,10 @@ class ModelManager:
             code_paths=code_paths,
             ext_modules=ext_modules,
             task=task,
+            experiment_info=experiment_info,
             options=options,
             statement_params=statement_params,
-
+            progress_status=progress_status,
         )

     def _log_model(
@@ -175,9 +172,10 @@ class ModelManager:
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
         task: type_hints.Task = task.Task.UNKNOWN,
+        experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
-
+        progress_status: Optional[Any] = None,
     ) -> model_version_impl.ModelVersion:
         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
         version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -265,7 +263,9 @@ class ModelManager:
         )

         logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
-
+        if progress_status:
+            progress_status.update("packaging model...")
+            progress_status.increment()

         # Extract save_location from options if present
         save_location = None
@@ -279,6 +279,11 @@ class ModelManager:
             statement_params=statement_params,
             save_location=save_location,
         )
+
+        if progress_status:
+            progress_status.update("creating model manifest...")
+            progress_status.increment()
+
         model_metadata: model_meta.ModelMetadata = mc.save(
             name=model_name_id.resolved(),
             model=model,
@@ -295,7 +300,12 @@ class ModelManager:
             ext_modules=ext_modules,
             options=options,
             task=task,
+            experiment_info=experiment_info,
         )
+
+        if progress_status:
+            progress_status.update("uploading model files...")
+            progress_status.increment()
         statement_params = telemetry.add_statement_params_custom_tags(
             statement_params, model_metadata.telemetry_metadata()
         )
@@ -304,7 +314,9 @@ class ModelManager:
         )

         logger.info("Start creating MODEL object for you in the Snowflake.")
-
+        if progress_status:
+            progress_status.update("creating model object in Snowflake...")
+            progress_status.increment()

         self._model_ops.create_from_stage(
             composed_model=mc,
@@ -331,6 +343,10 @@ class ModelManager:
             version_name=version_name_id,
         )

+        if progress_status:
+            progress_status.update("setting model metadata...")
+            progress_status.increment()
+
         if comment:
             mv.comment = comment

@@ -344,7 +360,8 @@ class ModelManager:
             statement_params=statement_params,
         )

-
+        if progress_status:
+            progress_status.update("model logged successfully!")

         return mv

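The `progress_status` parameter threaded through `log_model`/`_log_model` above is typed `Optional[Any]` and only needs to expose `update()` and `increment()`. Below is a minimal sketch of an object that would satisfy those calls; the class itself is illustrative, and the package ships its own progress handling (see the new `snowflake/ml/model/event_handler.py` in the file list).

```python
# Minimal sketch of an object satisfying the progress_status calls made in _log_model above.
# SimpleProgressStatus is illustrative; snowflake-ml-python provides its own implementation.
import logging

logger = logging.getLogger(__name__)


class SimpleProgressStatus:
    """Counts completed steps and logs each progress message."""

    def __init__(self, total_steps: int) -> None:
        self.total_steps = total_steps
        self.completed_steps = 0

    def update(self, message: str) -> None:
        logger.info("[%s/%s] %s", self.completed_steps, self.total_steps, message)

    def increment(self) -> None:
        self.completed_steps += 1
```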