snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +44 -3
- snowflake/ml/_internal/platform_capabilities.py +52 -2
- snowflake/ml/_internal/type_utils.py +1 -1
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +71 -0
- snowflake/ml/_internal/utils/service_logger.py +4 -2
- snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
- snowflake/ml/data/data_connector.py +43 -2
- snowflake/ml/data/data_ingestor.py +8 -0
- snowflake/ml/data/torch_utils.py +1 -1
- snowflake/ml/dataset/dataset.py +3 -2
- snowflake/ml/dataset/dataset_reader.py +22 -6
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
- snowflake/ml/experiment/_entities/__init__.py +4 -0
- snowflake/ml/experiment/_entities/experiment.py +10 -0
- snowflake/ml/experiment/_entities/run.py +62 -0
- snowflake/ml/experiment/_entities/run_metadata.py +68 -0
- snowflake/ml/experiment/_experiment_info.py +63 -0
- snowflake/ml/experiment/experiment_tracking.py +319 -0
- snowflake/ml/jobs/_utils/constants.py +1 -1
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +5 -3
- snowflake/ml/jobs/_utils/query_helper.py +20 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
- snowflake/ml/jobs/_utils/spec_utils.py +21 -4
- snowflake/ml/jobs/decorators.py +18 -25
- snowflake/ml/jobs/job.py +137 -37
- snowflake/ml/jobs/manager.py +228 -153
- snowflake/ml/lineage/lineage_node.py +2 -2
- snowflake/ml/model/_client/model/model_version_impl.py +16 -4
- snowflake/ml/model/_client/ops/model_ops.py +12 -3
- snowflake/ml/model/_client/ops/service_ops.py +324 -138
- snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
- snowflake/ml/model/_model_composer/model_composer.py +6 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/event_handler.py +117 -0
- snowflake/ml/model/model_signature.py +9 -9
- snowflake/ml/model/models/huggingface_pipeline.py +170 -1
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/framework/base.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +14 -14
- snowflake/ml/modeling/metrics/correlation.py +19 -8
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/modeling/metrics/ranking.py +6 -6
- snowflake/ml/modeling/metrics/regression.py +9 -9
- snowflake/ml/monitoring/explain_visualize.py +12 -5
- snowflake/ml/registry/_manager/model_manager.py +47 -15
- snowflake/ml/registry/registry.py +109 -64
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/event_handler.py
ADDED
@@ -0,0 +1,117 @@
+import os
+import sys
+from typing import Any, Optional
+
+
+class _TqdmStatusContext:
+    """A tqdm-based context manager for status updates."""
+
+    def __init__(self, label: str, tqdm_module: Any, total: Optional[int] = None) -> None:
+        self._label = label
+        self._tqdm = tqdm_module
+        self._total = total or 1
+
+    def __enter__(self) -> "_TqdmStatusContext":
+        self._progress_bar = self._tqdm.tqdm(desc=self._label, file=sys.stdout, total=self._total, leave=True)
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self._progress_bar.close()
+
+    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+        """Update the status by updating the tqdm description."""
+        if state == "complete":
+            self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
+            self._progress_bar.set_description(label)
+        else:
+            self._progress_bar.set_description(f"{self._label}: {label}")
+
+    def increment(self, n: int = 1) -> None:
+        """Increment the progress bar."""
+        self._progress_bar.update(n)
+
+
+class _StreamlitStatusContext:
+    """A streamlit-based context manager for status updates with progress bar support."""
+
+    def __init__(self, label: str, streamlit_module: Any, total: Optional[int] = None) -> None:
+        self._label = label
+        self._streamlit = streamlit_module
+        self._total = total
+        self._current = 0
+        self._progress_bar = None
+
+    def __enter__(self) -> "_StreamlitStatusContext":
+        self._status_container = self._streamlit.status(self._label, state="running", expanded=True)
+        if self._total is not None:
+            with self._status_container:
+                self._progress_bar = self._streamlit.progress(0, text=f"0/{self._total}")
+        return self
+
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+        self._status_container.update(state="complete")
+
+    def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+        """Update the status label."""
+        if state != "complete":
+            label = f"{self._label}: {label}"
+        self._status_container.update(label=label, state=state, expanded=expanded)
+        if self._progress_bar is not None:
+            self._progress_bar.progress(
+                self._current / self._total if self._total > 0 else 0,
+                text=f"{label} - {self._current}/{self._total}",
+            )
+
+    def increment(self, n: int = 1) -> None:
+        """Increment the progress."""
+        if self._total is not None:
+            self._current = min(self._current + n, self._total)
+            if self._progress_bar is not None:
+                progress_value = self._current / self._total if self._total > 0 else 0
+                self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
+
+
+class ModelEventHandler:
+    """Event handler for model operations with streamlit-aware status updates."""
+
+    def __init__(self) -> None:
+        self._streamlit = None
+
+        # Try streamlit first
+        try:
+            import streamlit as st
+
+            if st.runtime.exists():
+                USE_STREAMLIT_WIDGETS = os.getenv("USE_STREAMLIT_WIDGETS", "1") == "1"
+                if USE_STREAMLIT_WIDGETS:
+                    self._streamlit = st
+        except ImportError:
+            pass
+
+        import tqdm
+
+        self._tqdm = tqdm
+
+    def update(self, message: str) -> None:
+        """Write a message using streamlit if available, otherwise use tqdm."""
+        if self._streamlit is not None:
+            self._streamlit.write(message)
+        else:
+            self._tqdm.tqdm.write(message)
+
+    def status(self, label: str, *, state: str = "running", expanded: bool = True, total: Optional[int] = None) -> Any:
+        """Context manager that provides status updates with optional enhanced display capabilities.
+
+        Args:
+            label: The status label
+            state: The initial state ("running", "complete", "error")
+            expanded: Whether to show expanded view (streamlit only)
+            total: Total number of steps for progress tracking (optional)
+
+        Returns:
+            Status context (Streamlit or Tqdm)
+        """
+        if self._streamlit is not None:
+            return _StreamlitStatusContext(label, self._streamlit, total)
+        else:
+            return _TqdmStatusContext(label, self._tqdm, total)
snowflake/ml/model/model_signature.py
CHANGED
@@ -16,7 +16,7 @@ from snowflake.ml._internal.exceptions import (
     exceptions as snowml_exceptions,
 )
 from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
-from snowflake.ml.model import type_hints
+from snowflake.ml.model import type_hints
 from snowflake.ml.model._signatures import (
     base_handler,
     builtins_handler,
@@ -55,9 +55,9 @@ _MODEL_TELEMETRY_SUBPROJECT = "ModelSignature"
 
 
 def _truncate_data(
-    data:
+    data: type_hints.SupportedDataType,
     length: Optional[int] = 100,
-) ->
+) -> type_hints.SupportedDataType:
     for handler in _ALL_DATA_HANDLERS:
         if handler.can_handle(data):
             # If length is None, return the original data
@@ -89,7 +89,7 @@ def _truncate_data(
 
 
 def _infer_signature(
-    data:
+    data: type_hints.SupportedLocalDataType, role: Literal["input", "output"], use_snowflake_identifiers: bool = False
 ) -> Sequence[core.BaseFeatureSpec]:
     """Infer the inputs/outputs signature given a data that could be dataframe, numpy array or list.
     Dispatching is used to separate logic for different types.
@@ -142,7 +142,7 @@ def _rename_signature_with_snowflake_identifiers(
 
 
 def _validate_array_or_series_type(
-    arr: Union[
+    arr: Union[type_hints._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False
 ) -> bool:
     original_dtype = arr.dtype
     dtype = arr.dtype
@@ -649,7 +649,7 @@ def _validate_snowpark_type_feature(
 
 
 def _convert_local_data_to_df(
-    data:
+    data: type_hints.SupportedLocalDataType, ensure_serializable: bool = False
 ) -> pd.DataFrame:
     """Convert local data to pandas DataFrame or Snowpark DataFrame
 
@@ -679,7 +679,7 @@ def _convert_local_data_to_df(
 
 
 def _convert_and_validate_local_data(
-    data:
+    data: type_hints.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec], strict: bool = False
 ) -> pd.DataFrame:
     """Validate the data with features in model signature and convert to DataFrame
 
@@ -703,8 +703,8 @@ def _convert_and_validate_local_data(
     subproject=_MODEL_TELEMETRY_SUBPROJECT,
 )
 def infer_signature(
-    input_data:
-    output_data:
+    input_data: type_hints.SupportedLocalDataType,
+    output_data: type_hints.SupportedLocalDataType,
     input_feature_names: Optional[list[str]] = None,
     output_feature_names: Optional[list[str]] = None,
     input_data_limit: Optional[int] = 100,
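These hunks update the parameter and return annotations in model_signature.py to reference type_hints.SupportedDataType / SupportedLocalDataType (the removed lines are shown truncated by the diff viewer, so the prior spelling is not visible here). The public infer_signature entry point keeps its behavior; a small hedged usage sketch, with toy DataFrames as placeholders:

# Hedged sketch: infer_signature is the public function whose annotations changed above;
# the DataFrames and column names are illustrative only.
import pandas as pd
from snowflake.ml.model import model_signature

input_df = pd.DataFrame({"feature_a": [1.0, 2.0, 3.0], "feature_b": [10, 20, 30]})
output_df = pd.DataFrame({"prediction": [0, 1, 0]})

sig = model_signature.infer_signature(input_data=input_df, output_data=output_df)
print(sig.inputs)   # inferred feature specs for feature_a / feature_b
print(sig.outputs)  # inferred feature spec for prediction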
snowflake/ml/model/models/huggingface_pipeline.py
CHANGED
@@ -1,8 +1,22 @@
+import logging
 import warnings
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 from packaging import version
 
+from snowflake import snowpark
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.human_readable_id import hrid_generator
+from snowflake.ml._internal.utils import sql_identifier
+from snowflake.ml.model._client.ops import service_ops
+from snowflake.snowpark import async_job, session
+
+logger = logging.getLogger(__name__)
+
+
+_TELEMETRY_PROJECT = "MLOps"
+_TELEMETRY_SUBPROJECT = "ModelManagement"
+
 
 class HuggingFacePipelineModel:
     def __init__(
@@ -214,4 +228,159 @@ class HuggingFacePipelineModel:
         self.token = token
         self.trust_remote_code = trust_remote_code
         self.model_kwargs = model_kwargs
+        self.tokenizer = tokenizer
         self.__dict__.update(kwargs)
+
+    @telemetry.send_api_usage_telemetry(
+        project=_TELEMETRY_PROJECT,
+        subproject=_TELEMETRY_SUBPROJECT,
+        func_params_to_log=[
+            "service_name",
+            "image_build_compute_pool",
+            "service_compute_pool",
+            "image_repo",
+            "gpu_requests",
+            "num_workers",
+            "max_batch_rows",
+        ],
+    )
+    @snowpark._internal.utils.private_preview(version="1.9.1")
+    def create_service(
+        self,
+        *,
+        session: session.Session,
+        # registry.log_model parameters
+        model_name: str,
+        version_name: Optional[str] = None,
+        pip_requirements: Optional[list[str]] = None,
+        conda_dependencies: Optional[list[str]] = None,
+        comment: Optional[str] = None,
+        # model_version_impl.create_service parameters
+        service_name: str,
+        service_compute_pool: str,
+        image_repo: str,
+        image_build_compute_pool: Optional[str] = None,
+        ingress_enabled: bool = False,
+        max_instances: int = 1,
+        cpu_requests: Optional[str] = None,
+        memory_requests: Optional[str] = None,
+        gpu_requests: Optional[Union[str, int]] = None,
+        num_workers: Optional[int] = None,
+        max_batch_rows: Optional[int] = None,
+        force_rebuild: bool = False,
+        build_external_access_integrations: Optional[list[str]] = None,
+        block: bool = True,
+    ) -> Union[str, async_job.AsyncJob]:
+        """Logs a Hugging Face model and creates a service in Snowflake.
+
+        Args:
+            session: The Snowflake session object.
+            model_name: The name of the model in Snowflake.
+            version_name: The version name of the model. Defaults to None.
+            pip_requirements: Pip requirements for the model. Defaults to None.
+            conda_dependencies: Conda dependencies for the model. Defaults to None.
+            comment: Comment for the model. Defaults to None.
+            service_name: The name of the service to create.
+            service_compute_pool: The compute pool for the service.
+            image_repo: The name of the image repository.
+            image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
+                the service compute pool if None.
+            ingress_enabled: Whether ingress is enabled. Defaults to False.
+            max_instances: Maximum number of instances. Defaults to 1.
+            cpu_requests: CPU requests configuration. Defaults to None.
+            memory_requests: Memory requests configuration. Defaults to None.
+            gpu_requests: GPU requests configuration. Defaults to None.
+            num_workers: Number of workers. Defaults to None.
+            max_batch_rows: Maximum batch rows. Defaults to None.
+            force_rebuild: Whether to force rebuild the image. Defaults to False.
+            build_external_access_integrations: External access integrations for building the image. Defaults to None.
+            block: Whether to block the operation. Defaults to True.
+
+        Raises:
+            ValueError: if database and schema name is not provided and session doesn't have a
+                database and schema name.
+
+        Returns:
+            The service ID or an async job object.
+
+        .. # noqa: DAR003
+        """
+        statement_params = telemetry.get_statement_params(
+            project=_TELEMETRY_PROJECT,
+            subproject=_TELEMETRY_SUBPROJECT,
+        )
+
+        database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
+        session_database_name = session.get_current_database()
+        session_schema_name = session.get_current_schema()
+        if database_name_id is None:
+            if session_database_name is None:
+                raise ValueError("Either database needs to be provided or needs to be available in session.")
+            database_name_id = sql_identifier.SqlIdentifier(session_database_name)
+        if schema_name_id is None:
+            if session_schema_name is None:
+                raise ValueError("Either schema needs to be provided or needs to be available in session.")
+            schema_name_id = sql_identifier.SqlIdentifier(session_schema_name)
+
+        if version_name is None:
+            name_generator = hrid_generator.HRID16()
+            version_name = name_generator.generate()[1]
+
+        service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
+        image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
+
+        service_operator = service_ops.ServiceOperator(
+            session=session,
+            database_name=database_name_id,
+            schema_name=schema_name_id,
+        )
+        logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
+
+        return service_operator.create_service(
+            database_name=database_name_id,
+            schema_name=schema_name_id,
+            model_name=model_name_id,
+            version_name=sql_identifier.SqlIdentifier(version_name),
+            service_database_name=service_db_id,
+            service_schema_name=service_schema_id,
+            service_name=service_id,
+            image_build_compute_pool_name=(
+                sql_identifier.SqlIdentifier(image_build_compute_pool)
+                if image_build_compute_pool
+                else sql_identifier.SqlIdentifier(service_compute_pool)
+            ),
+            service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+            image_repo_database_name=image_repo_db_id,
+            image_repo_schema_name=image_repo_schema_id,
+            image_repo_name=image_repo_id,
+            ingress_enabled=ingress_enabled,
+            max_instances=max_instances,
+            cpu_requests=cpu_requests,
+            memory_requests=memory_requests,
+            gpu_requests=gpu_requests,
+            num_workers=num_workers,
+            max_batch_rows=max_batch_rows,
+            force_rebuild=force_rebuild,
+            build_external_access_integrations=(
+                None
+                if build_external_access_integrations is None
+                else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+            ),
+            block=block,
+            statement_params=statement_params,
+            # hf model
+            hf_model_args=service_ops.HFModelArgs(
+                hf_model_name=self.model,
+                hf_task=self.task,
+                hf_tokenizer=self.tokenizer,
+                hf_revision=self.revision,
+                hf_token=self.token,
+                hf_trust_remote_code=bool(self.trust_remote_code),
+                hf_model_kwargs=self.model_kwargs,
+                pip_requirements=pip_requirements,
+                conda_dependencies=conda_dependencies,
+                comment=comment,
+                # TODO: remove warehouse in the next release
+                warehouse=session.get_current_warehouse(),
+            ),
+        )
snowflake/ml/model/target_platform.py
ADDED
@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class TargetPlatform(Enum):
+    WAREHOUSE = "WAREHOUSE"
+    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+
+
+WAREHOUSE_ONLY = [TargetPlatform.WAREHOUSE]
+SNOWPARK_CONTAINER_SERVICES_ONLY = [TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
+BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES = [TargetPlatform.WAREHOUSE, TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
snowflake/ml/model/task.py
ADDED
@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class Task(Enum):
+    UNKNOWN = "UNKNOWN"
+    TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
+    TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
+    TABULAR_REGRESSION = "TABULAR_REGRESSION"
+    TABULAR_RANKING = "TABULAR_RANKING"
snowflake/ml/model/type_hints.py
CHANGED
@@ -1,10 +1,12 @@
 # mypy: disable-error-code="import"
-from enum import Enum
 from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
 
 import numpy.typing as npt
 from typing_extensions import NotRequired
 
+from snowflake.ml.model.target_platform import TargetPlatform
+from snowflake.ml.model.task import Task
+
 if TYPE_CHECKING:
     import catboost
     import keras
@@ -321,17 +323,7 @@ ModelLoadOption = Union[
 ]
 
 
-
-    UNKNOWN = "UNKNOWN"
-    TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
-    TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
-    TABULAR_REGRESSION = "TABULAR_REGRESSION"
-    TABULAR_RANKING = "TABULAR_RANKING"
-
-
-class TargetPlatform(Enum):
-    WAREHOUSE = "WAREHOUSE"
-    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+SupportedTargetPlatformType = Union[TargetPlatform, str]
 
 
-
+__all__ = ["TargetPlatform", "Task"]
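The net effect of the three hunks above (target_platform.py, task.py, type_hints.py) is that the Task and TargetPlatform enums now live in their own modules while remaining importable from type_hints. A small check of that re-export, based only on the imports shown in the diff:

# The identity assertions follow directly from the new module-level imports in type_hints.py.
from snowflake.ml.model import type_hints
from snowflake.ml.model.target_platform import TargetPlatform, WAREHOUSE_ONLY
from snowflake.ml.model.task import Task

assert type_hints.TargetPlatform is TargetPlatform  # re-exported via __all__
assert type_hints.Task is Task
print(WAREHOUSE_ONLY)  # [<TargetPlatform.WAREHOUSE: 'WAREHOUSE'>]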
snowflake/ml/modeling/framework/base.py
CHANGED
@@ -698,7 +698,7 @@ class BaseTransformer(BaseEstimator):
         self,
         attribute: Optional[Mapping[str, Union[int, float, str, Iterable[Union[int, float, str]]]]],
         dtype: Optional[type] = None,
-    ) -> Optional[npt.NDArray[Union[np.int_, np.
+    ) -> Optional[npt.NDArray[Union[np.int_, np.float64, np.str_]]]:
         """
         Convert the attribute from dict to ndarray based on the order of `self.input_cols`.
 
snowflake/ml/modeling/metrics/classification.py
CHANGED
@@ -96,7 +96,7 @@ def confusion_matrix(
     labels: Optional[npt.ArrayLike] = None,
     sample_weight_col_name: Optional[str] = None,
     normalize: Optional[str] = None,
-) -> Union[npt.NDArray[np.int_], npt.NDArray[np.
+) -> Union[npt.NDArray[np.int_], npt.NDArray[np.float64]]:
     """
     Compute confusion matrix to evaluate the accuracy of a classification.
 
@@ -320,7 +320,7 @@ def f1_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the F1 score, also known as balanced F-score or F-measure.
 
@@ -414,7 +414,7 @@ def fbeta_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the F-beta score.
 
@@ -696,7 +696,7 @@ def precision_recall_fscore_support(
     zero_division: Union[str, int] = "warn",
 ) -> Union[
     tuple[float, float, float, None],
-    tuple[npt.NDArray[np.
+    tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]],
 ]:
     """
     Compute precision, recall, F-measure and support for each class.
@@ -855,7 +855,7 @@ def precision_recall_fscore_support(
 
     res: Union[
         tuple[float, float, float, None],
-        tuple[npt.NDArray[np.
+        tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]],
     ] = result_object[:4]
     warning = result_object[-1]
     if warning:
@@ -1050,7 +1050,7 @@ def _register_multilabel_confusion_matrix_computer(
 
         def end_partition(
             self,
-        ) -> Iterable[tuple[npt.NDArray[np.
+        ) -> Iterable[tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]]:
             MCM = metrics.multilabel_confusion_matrix(
                 self._y_true,
                 self._y_pred,
@@ -1098,7 +1098,7 @@ def _binary_precision_score(
     pos_label: Union[str, int] = 1,
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
 
     statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)
 
@@ -1173,7 +1173,7 @@ def precision_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the precision.
 
@@ -1271,7 +1271,7 @@ def recall_score(
     average: Optional[str] = "binary",
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute the recall.
 
@@ -1406,14 +1406,14 @@ def _check_binary_labels(
 
 
 def _prf_divide(
-    numerator: npt.NDArray[np.
-    denominator: npt.NDArray[np.
+    numerator: npt.NDArray[np.float64],
+    denominator: npt.NDArray[np.float64],
     metric: str,
     modifier: str,
     average: Optional[str] = None,
     warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
     zero_division: Union[str, int] = "warn",
-) -> npt.NDArray[np.
+) -> npt.NDArray[np.float64]:
     """Performs division and handles divide-by-zero.
 
     On zero-division, sets the corresponding result elements equal to
@@ -1436,7 +1436,7 @@ def _prf_divide(
             "warn", this acts as 0, but warnings are also raised.
 
     Returns:
-        npt.NDArray[np.
+        npt.NDArray[np.float64]: Result of the division, an array of floats.
     """
     mask = denominator == 0.0
     denominator = denominator.copy()
@@ -1522,7 +1522,7 @@ def _check_zero_division(zero_division: Union[int, float, str]) -> float:
         return np.nan
 
 
-def _nanaverage(a: npt.NDArray[np.
+def _nanaverage(a: npt.NDArray[np.float64], weights: Optional[npt.ArrayLike] = None) -> Any:
     """Compute the weighted average, ignoring NaNs.
 
     Args:
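All of the classification.py hunks apply one mechanical pattern: parameter and return annotations are rewritten as fully spelled-out npt.NDArray[np.float64] types (the removed spellings are truncated by the viewer, so only the new form is visible). For illustration, a stand-in function annotated in the same style; it is not library code:

# Stand-in example of the npt.NDArray[np.float64] annotation style used throughout these hunks.
from typing import Union
import numpy as np
import numpy.typing as npt

def normalized_scores(raw: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
    total = float(raw.sum())
    return raw / total if total != 0 else 0.0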
snowflake/ml/modeling/metrics/correlation.py
CHANGED
@@ -26,7 +26,7 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
     The below steps explain how correlation matrix is computed in a distributed way:
     Let n = # of rows in the dataframe; sqrt_n = sqrt(n); X, Y are 2 columns in the dataframe
     Correlation(X, Y) = numerator/denominator where
-    numerator = dot(X/sqrt_n, Y/sqrt_n) - sum(X/n)*sum(
+    numerator = dot(X/sqrt_n, Y/sqrt_n) - sum(X/n)*sum(Y/n)
     denominator = std_dev(X)*std_dev(Y)
     std_dev(X) = sqrt(dot(X/sqrt_n, X/sqrt_n) - sum(X/n)*sum(X/n))
 
@@ -74,27 +74,38 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
     # Pushing this to a udtf requires creating a temp udtf which takes about 20 secs, so it doesn't make sense
     # to have this in a udtf.
     n_cols = len(columns)
-
-
+    column_means = np.zeros(n_cols)
+    mean_of_squares = np.zeros(n_cols)
     dot_prod = np.zeros((n_cols, n_cols))
     # Get sum, dot_prod and squared sum array from the results.
     for i in range(len(results)):
         x = results[i]
         if x[1] == "sum_by_count":
-
+            column_means = cloudpickle.loads(x[0])
         else:
             row = int(x[1].strip("row_"))
             dot_prod[row, :] = cloudpickle.loads(x[0])
-
+            mean_of_squares[row] = dot_prod[row, row]
 
     # sum(X/n)*sum(Y/n) is computed for all combinations of X,Y (columns in the dataframe)
-    exey_arr = np.einsum("t,m->tm",
+    exey_arr = np.einsum("t,m->tm", column_means, column_means, optimize="optimal")
     numerator_matrix = dot_prod - exey_arr
 
     # standard deviation for all columns in the dataframe
-
+    variance_arr = mean_of_squares - np.einsum("i, i -> i", column_means, column_means, optimize="optimal")
+    # ensure non-negative values from potential precision issues where variance might be slightly negative
+    variance_arr = np.maximum(variance_arr, 0)
+    stddev_arr = np.sqrt(variance_arr)
     # std_dev(X)*std_dev(Y) is computed for all combinations of X,Y (columns in the dataframe)
     denominator_matrix = np.einsum("t,m->tm", stddev_arr, stddev_arr, optimize="optimal")
-
+
+    # Use np.divide to handle NaN cases
+    corr_res = np.divide(
+        numerator_matrix,
+        denominator_matrix,
+        out=np.full_like(numerator_matrix, np.nan),
+        where=(denominator_matrix != 0),
+    )
+
     correlation_matrix = pd.DataFrame(corr_res, columns=columns, index=columns)
     return correlation_matrix
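The rewritten block fills in the distributed formula from the docstring: per-column means and mean-of-squares are accumulated from the UDTF shards, then combined into numerator and denominator matrices, with np.divide guarding against zero standard deviations. A local, single-array sketch of the same arithmetic (the random data is illustrative, and the variable names only loosely mirror the implementation); it should agree with np.corrcoef:

# numerator   = E[X*Y] - E[X]*E[Y]
# denominator = std(X) * std(Y), with population std sqrt(E[X^2] - E[X]^2)
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 3))          # 1000 rows, 3 columns
n, n_cols = data.shape
sqrt_n = np.sqrt(n)

column_means = data.sum(axis=0) / n                 # sum(X/n) per column
dot_prod = (data / sqrt_n).T @ (data / sqrt_n)      # dot(X/sqrt_n, Y/sqrt_n) for all pairs
mean_of_squares = np.diag(dot_prod)                 # dot(X/sqrt_n, X/sqrt_n)

numerator_matrix = dot_prod - np.outer(column_means, column_means)
variance_arr = np.maximum(mean_of_squares - column_means**2, 0)   # clamp tiny negatives
denominator_matrix = np.outer(np.sqrt(variance_arr), np.sqrt(variance_arr))
corr = np.divide(numerator_matrix, denominator_matrix,
                 out=np.full_like(numerator_matrix, np.nan),
                 where=denominator_matrix != 0)

assert np.allclose(corr, np.corrcoef(data, rowvar=False))  # matches numpy's Pearson correlation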
snowflake/ml/modeling/metrics/metrics_utils.py
CHANGED
@@ -60,6 +60,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: dict[str, A
         ),
         input_types=[T.BinaryType()],
         packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+        imports=[],  # Prevents unnecessary import resolution.
         name=accumulator,
         is_permanent=False,
         replace=True,
@@ -175,6 +176,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: dic
         ),
         input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
         packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+        imports=[],  # Prevents unnecessary import resolution.
         name=sharded_dot_and_sum_computer,
         is_permanent=False,
         replace=True,
snowflake/ml/modeling/metrics/ranking.py
CHANGED
@@ -26,7 +26,7 @@ def precision_recall_curve(
     probas_pred_col_name: str,
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
-) -> tuple[npt.NDArray[np.
+) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]:
     """
     Compute precision-recall pairs for different probability thresholds.
 
@@ -125,7 +125,7 @@ def precision_recall_curve(
 
     kwargs = telemetry.get_sproc_statement_params_kwargs(precision_recall_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session, **kwargs))
-    res: tuple[npt.NDArray[np.
+    res: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] = result_object
     return res
 
 
@@ -140,7 +140,7 @@ def roc_auc_score(
     max_fpr: Optional[float] = None,
     multi_class: str = "raise",
     labels: Optional[npt.ArrayLike] = None,
-) -> Union[float, npt.NDArray[np.
+) -> Union[float, npt.NDArray[np.float64]]:
     """
     Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
     from prediction scores.
@@ -276,7 +276,7 @@ def roc_auc_score(
 
     kwargs = telemetry.get_sproc_statement_params_kwargs(roc_auc_score_anon_sproc, statement_params)
     result_object = result.deserialize(session, roc_auc_score_anon_sproc(session, **kwargs))
-    auc: Union[float, npt.NDArray[np.
+    auc: Union[float, npt.NDArray[np.float64]] = result_object
     return auc
 
 
@@ -289,7 +289,7 @@ def roc_curve(
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
    drop_intermediate: bool = True,
-) -> tuple[npt.NDArray[np.
+) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]:
     """
     Compute Receiver operating characteristic (ROC).
 
@@ -380,6 +380,6 @@ def roc_curve(
     kwargs = telemetry.get_sproc_statement_params_kwargs(roc_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, roc_curve_anon_sproc(session, **kwargs))
 
-    res: tuple[npt.NDArray[np.
+    res: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] = result_object
 
     return res