snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those package versions as they appear in the registry.
Files changed (65)
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/mixins.py +71 -0
  6. snowflake/ml/_internal/utils/service_logger.py +4 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
  8. snowflake/ml/data/data_connector.py +43 -2
  9. snowflake/ml/data/data_ingestor.py +8 -0
  10. snowflake/ml/data/torch_utils.py +1 -1
  11. snowflake/ml/dataset/dataset.py +3 -2
  12. snowflake/ml/dataset/dataset_reader.py +22 -6
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/experiment_tracking.py +319 -0
  20. snowflake/ml/jobs/_utils/constants.py +1 -1
  21. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +5 -3
  23. snowflake/ml/jobs/_utils/query_helper.py +20 -0
  24. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
  25. snowflake/ml/jobs/_utils/spec_utils.py +21 -4
  26. snowflake/ml/jobs/decorators.py +18 -25
  27. snowflake/ml/jobs/job.py +137 -37
  28. snowflake/ml/jobs/manager.py +228 -153
  29. snowflake/ml/lineage/lineage_node.py +2 -2
  30. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  31. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  32. snowflake/ml/model/_client/ops/service_ops.py +324 -138
  33. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  34. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  35. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  38. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  39. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  40. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  45. snowflake/ml/model/event_handler.py +117 -0
  46. snowflake/ml/model/model_signature.py +9 -9
  47. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  48. snowflake/ml/model/target_platform.py +11 -0
  49. snowflake/ml/model/task.py +9 -0
  50. snowflake/ml/model/type_hints.py +5 -13
  51. snowflake/ml/modeling/framework/base.py +1 -1
  52. snowflake/ml/modeling/metrics/classification.py +14 -14
  53. snowflake/ml/modeling/metrics/correlation.py +19 -8
  54. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  55. snowflake/ml/modeling/metrics/ranking.py +6 -6
  56. snowflake/ml/modeling/metrics/regression.py +9 -9
  57. snowflake/ml/monitoring/explain_visualize.py +12 -5
  58. snowflake/ml/registry/_manager/model_manager.py +47 -15
  59. snowflake/ml/registry/registry.py +109 -64
  60. snowflake/ml/version.py +1 -1
  61. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
  62. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
  63. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
  64. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/event_handler.py
@@ -0,0 +1,117 @@
+ import os
+ import sys
+ from typing import Any, Optional
+
+
+ class _TqdmStatusContext:
+     """A tqdm-based context manager for status updates."""
+
+     def __init__(self, label: str, tqdm_module: Any, total: Optional[int] = None) -> None:
+         self._label = label
+         self._tqdm = tqdm_module
+         self._total = total or 1
+
+     def __enter__(self) -> "_TqdmStatusContext":
+         self._progress_bar = self._tqdm.tqdm(desc=self._label, file=sys.stdout, total=self._total, leave=True)
+         return self
+
+     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+         self._progress_bar.close()
+
+     def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+         """Update the status by updating the tqdm description."""
+         if state == "complete":
+             self._progress_bar.update(self._progress_bar.total - self._progress_bar.n)
+             self._progress_bar.set_description(label)
+         else:
+             self._progress_bar.set_description(f"{self._label}: {label}")
+
+     def increment(self, n: int = 1) -> None:
+         """Increment the progress bar."""
+         self._progress_bar.update(n)
+
+
+ class _StreamlitStatusContext:
+     """A streamlit-based context manager for status updates with progress bar support."""
+
+     def __init__(self, label: str, streamlit_module: Any, total: Optional[int] = None) -> None:
+         self._label = label
+         self._streamlit = streamlit_module
+         self._total = total
+         self._current = 0
+         self._progress_bar = None
+
+     def __enter__(self) -> "_StreamlitStatusContext":
+         self._status_container = self._streamlit.status(self._label, state="running", expanded=True)
+         if self._total is not None:
+             with self._status_container:
+                 self._progress_bar = self._streamlit.progress(0, text=f"0/{self._total}")
+         return self
+
+     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+         self._status_container.update(state="complete")
+
+     def update(self, label: str, *, state: str = "running", expanded: bool = True) -> None:
+         """Update the status label."""
+         if state != "complete":
+             label = f"{self._label}: {label}"
+         self._status_container.update(label=label, state=state, expanded=expanded)
+         if self._progress_bar is not None:
+             self._progress_bar.progress(
+                 self._current / self._total if self._total > 0 else 0,
+                 text=f"{label} - {self._current}/{self._total}",
+             )
+
+     def increment(self, n: int = 1) -> None:
+         """Increment the progress."""
+         if self._total is not None:
+             self._current = min(self._current + n, self._total)
+             if self._progress_bar is not None:
+                 progress_value = self._current / self._total if self._total > 0 else 0
+                 self._progress_bar.progress(progress_value, text=f"{self._current}/{self._total}")
+
+
+ class ModelEventHandler:
+     """Event handler for model operations with streamlit-aware status updates."""
+
+     def __init__(self) -> None:
+         self._streamlit = None
+
+         # Try streamlit first
+         try:
+             import streamlit as st
+
+             if st.runtime.exists():
+                 USE_STREAMLIT_WIDGETS = os.getenv("USE_STREAMLIT_WIDGETS", "1") == "1"
+                 if USE_STREAMLIT_WIDGETS:
+                     self._streamlit = st
+         except ImportError:
+             pass
+
+         import tqdm
+
+         self._tqdm = tqdm
+
+     def update(self, message: str) -> None:
+         """Write a message using streamlit if available, otherwise use tqdm."""
+         if self._streamlit is not None:
+             self._streamlit.write(message)
+         else:
+             self._tqdm.tqdm.write(message)
+
+     def status(self, label: str, *, state: str = "running", expanded: bool = True, total: Optional[int] = None) -> Any:
+         """Context manager that provides status updates with optional enhanced display capabilities.
+
+         Args:
+             label: The status label
+             state: The initial state ("running", "complete", "error")
+             expanded: Whether to show expanded view (streamlit only)
+             total: Total number of steps for progress tracking (optional)
+
+         Returns:
+             Status context (Streamlit or Tqdm)
+         """
+         if self._streamlit is not None:
+             return _StreamlitStatusContext(label, self._streamlit, total)
+         else:
+             return _TqdmStatusContext(label, self._tqdm, total)
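For orientation, a minimal usage sketch of the new event handler (the handler and context classes are the ones added above; the deployment steps are invented for illustration and are not taken from the package's own callers):

    # Hypothetical caller of ModelEventHandler as defined in the hunk above.
    handler = ModelEventHandler()
    handler.update("Starting model deployment")

    # status() returns a _StreamlitStatusContext inside a Streamlit app,
    # otherwise a _TqdmStatusContext; both expose update() and increment().
    with handler.status("Deploying model", total=3) as status:
        status.update("uploading artifacts")
        status.increment()
        status.update("building image")
        status.increment()
        status.update("creating service")
        status.increment()
        status.update("done", state="complete")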
snowflake/ml/model/model_signature.py
@@ -16,7 +16,7 @@ from snowflake.ml._internal.exceptions import (
      exceptions as snowml_exceptions,
  )
  from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
- from snowflake.ml.model import type_hints as model_types
+ from snowflake.ml.model import type_hints
  from snowflake.ml.model._signatures import (
      base_handler,
      builtins_handler,
@@ -55,9 +55,9 @@ _MODEL_TELEMETRY_SUBPROJECT = "ModelSignature"


  def _truncate_data(
-     data: model_types.SupportedDataType,
+     data: type_hints.SupportedDataType,
      length: Optional[int] = 100,
- ) -> model_types.SupportedDataType:
+ ) -> type_hints.SupportedDataType:
      for handler in _ALL_DATA_HANDLERS:
          if handler.can_handle(data):
              # If length is None, return the original data
@@ -89,7 +89,7 @@ def _truncate_data(


  def _infer_signature(
-     data: model_types.SupportedLocalDataType, role: Literal["input", "output"], use_snowflake_identifiers: bool = False
+     data: type_hints.SupportedLocalDataType, role: Literal["input", "output"], use_snowflake_identifiers: bool = False
  ) -> Sequence[core.BaseFeatureSpec]:
      """Infer the inputs/outputs signature given a data that could be dataframe, numpy array or list.
      Dispatching is used to separate logic for different types.
@@ -142,7 +142,7 @@ def _rename_signature_with_snowflake_identifiers(


  def _validate_array_or_series_type(
-     arr: Union[model_types._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False
+     arr: Union[type_hints._SupportedNumpyArray, pd.Series], feature_type: core.DataType, strict: bool = False
  ) -> bool:
      original_dtype = arr.dtype
      dtype = arr.dtype
@@ -649,7 +649,7 @@ def _validate_snowpark_type_feature(


  def _convert_local_data_to_df(
-     data: model_types.SupportedLocalDataType, ensure_serializable: bool = False
+     data: type_hints.SupportedLocalDataType, ensure_serializable: bool = False
  ) -> pd.DataFrame:
      """Convert local data to pandas DataFrame or Snowpark DataFrame

@@ -679,7 +679,7 @@


  def _convert_and_validate_local_data(
-     data: model_types.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec], strict: bool = False
+     data: type_hints.SupportedLocalDataType, features: Sequence[core.BaseFeatureSpec], strict: bool = False
  ) -> pd.DataFrame:
      """Validate the data with features in model signature and convert to DataFrame

@@ -703,8 +703,8 @@
      subproject=_MODEL_TELEMETRY_SUBPROJECT,
  )
  def infer_signature(
-     input_data: model_types.SupportedLocalDataType,
-     output_data: model_types.SupportedLocalDataType,
+     input_data: type_hints.SupportedLocalDataType,
+     output_data: type_hints.SupportedLocalDataType,
      input_feature_names: Optional[list[str]] = None,
      output_feature_names: Optional[list[str]] = None,
      input_data_limit: Optional[int] = 100,
snowflake/ml/model/models/huggingface_pipeline.py
@@ -1,8 +1,22 @@
+ import logging
  import warnings
- from typing import Any, Optional
+ from typing import Any, Optional, Union

  from packaging import version

+ from snowflake import snowpark
+ from snowflake.ml._internal import telemetry
+ from snowflake.ml._internal.human_readable_id import hrid_generator
+ from snowflake.ml._internal.utils import sql_identifier
+ from snowflake.ml.model._client.ops import service_ops
+ from snowflake.snowpark import async_job, session
+
+ logger = logging.getLogger(__name__)
+
+
+ _TELEMETRY_PROJECT = "MLOps"
+ _TELEMETRY_SUBPROJECT = "ModelManagement"
+

  class HuggingFacePipelineModel:
      def __init__(
@@ -214,4 +228,159 @@ class HuggingFacePipelineModel:
          self.token = token
          self.trust_remote_code = trust_remote_code
          self.model_kwargs = model_kwargs
+         self.tokenizer = tokenizer
          self.__dict__.update(kwargs)
+
+     @telemetry.send_api_usage_telemetry(
+         project=_TELEMETRY_PROJECT,
+         subproject=_TELEMETRY_SUBPROJECT,
+         func_params_to_log=[
+             "service_name",
+             "image_build_compute_pool",
+             "service_compute_pool",
+             "image_repo",
+             "gpu_requests",
+             "num_workers",
+             "max_batch_rows",
+         ],
+     )
+     @snowpark._internal.utils.private_preview(version="1.9.1")
+     def create_service(
+         self,
+         *,
+         session: session.Session,
+         # registry.log_model parameters
+         model_name: str,
+         version_name: Optional[str] = None,
+         pip_requirements: Optional[list[str]] = None,
+         conda_dependencies: Optional[list[str]] = None,
+         comment: Optional[str] = None,
+         # model_version_impl.create_service parameters
+         service_name: str,
+         service_compute_pool: str,
+         image_repo: str,
+         image_build_compute_pool: Optional[str] = None,
+         ingress_enabled: bool = False,
+         max_instances: int = 1,
+         cpu_requests: Optional[str] = None,
+         memory_requests: Optional[str] = None,
+         gpu_requests: Optional[Union[str, int]] = None,
+         num_workers: Optional[int] = None,
+         max_batch_rows: Optional[int] = None,
+         force_rebuild: bool = False,
+         build_external_access_integrations: Optional[list[str]] = None,
+         block: bool = True,
+     ) -> Union[str, async_job.AsyncJob]:
+         """Logs a Hugging Face model and creates a service in Snowflake.
+
+         Args:
+             session: The Snowflake session object.
+             model_name: The name of the model in Snowflake.
+             version_name: The version name of the model. Defaults to None.
+             pip_requirements: Pip requirements for the model. Defaults to None.
+             conda_dependencies: Conda dependencies for the model. Defaults to None.
+             comment: Comment for the model. Defaults to None.
+             service_name: The name of the service to create.
+             service_compute_pool: The compute pool for the service.
+             image_repo: The name of the image repository.
+             image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
+                 the service compute pool if None.
+             ingress_enabled: Whether ingress is enabled. Defaults to False.
+             max_instances: Maximum number of instances. Defaults to 1.
+             cpu_requests: CPU requests configuration. Defaults to None.
+             memory_requests: Memory requests configuration. Defaults to None.
+             gpu_requests: GPU requests configuration. Defaults to None.
+             num_workers: Number of workers. Defaults to None.
+             max_batch_rows: Maximum batch rows. Defaults to None.
+             force_rebuild: Whether to force rebuild the image. Defaults to False.
+             build_external_access_integrations: External access integrations for building the image. Defaults to None.
+             block: Whether to block the operation. Defaults to True.
+
+         Raises:
+             ValueError: if database and schema name is not provided and session doesn't have a
+                 database and schema name.
+
+         Returns:
+             The service ID or an async job object.
+
+         .. # noqa: DAR003
+         """
+         statement_params = telemetry.get_statement_params(
+             project=_TELEMETRY_PROJECT,
+             subproject=_TELEMETRY_SUBPROJECT,
+         )
+
+         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
+         session_database_name = session.get_current_database()
+         session_schema_name = session.get_current_schema()
+         if database_name_id is None:
+             if session_database_name is None:
+                 raise ValueError("Either database needs to be provided or needs to be available in session.")
+             database_name_id = sql_identifier.SqlIdentifier(session_database_name)
+         if schema_name_id is None:
+             if session_schema_name is None:
+                 raise ValueError("Either schema needs to be provided or needs to be available in session.")
+             schema_name_id = sql_identifier.SqlIdentifier(session_schema_name)
+
+         if version_name is None:
+             name_generator = hrid_generator.HRID16()
+             version_name = name_generator.generate()[1]
+
+         service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
+         image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
+
+         service_operator = service_ops.ServiceOperator(
+             session=session,
+             database_name=database_name_id,
+             schema_name=schema_name_id,
+         )
+         logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
+
+         return service_operator.create_service(
+             database_name=database_name_id,
+             schema_name=schema_name_id,
+             model_name=model_name_id,
+             version_name=sql_identifier.SqlIdentifier(version_name),
+             service_database_name=service_db_id,
+             service_schema_name=service_schema_id,
+             service_name=service_id,
+             image_build_compute_pool_name=(
+                 sql_identifier.SqlIdentifier(image_build_compute_pool)
+                 if image_build_compute_pool
+                 else sql_identifier.SqlIdentifier(service_compute_pool)
+             ),
+             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
+             image_repo_database_name=image_repo_db_id,
+             image_repo_schema_name=image_repo_schema_id,
+             image_repo_name=image_repo_id,
+             ingress_enabled=ingress_enabled,
+             max_instances=max_instances,
+             cpu_requests=cpu_requests,
+             memory_requests=memory_requests,
+             gpu_requests=gpu_requests,
+             num_workers=num_workers,
+             max_batch_rows=max_batch_rows,
+             force_rebuild=force_rebuild,
+             build_external_access_integrations=(
+                 None
+                 if build_external_access_integrations is None
+                 else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
+             ),
+             block=block,
+             statement_params=statement_params,
+             # hf model
+             hf_model_args=service_ops.HFModelArgs(
+                 hf_model_name=self.model,
+                 hf_task=self.task,
+                 hf_tokenizer=self.tokenizer,
+                 hf_revision=self.revision,
+                 hf_token=self.token,
+                 hf_trust_remote_code=bool(self.trust_remote_code),
+                 hf_model_kwargs=self.model_kwargs,
+                 pip_requirements=pip_requirements,
+                 conda_dependencies=conda_dependencies,
+                 comment=comment,
+                 # TODO: remove warehouse in the next release
+                 warehouse=session.get_current_warehouse(),
+             ),
+         )
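The docstring above describes the new one-shot "log and deploy" path for Hugging Face models. A hedged sketch of a call site follows; the constructor arguments are assumed to mirror the existing HuggingFacePipelineModel signature, and every identifier (connection parameters, database, schema, compute pool, image repo) is a placeholder rather than something taken from this diff:

    # Hypothetical invocation sketched from the create_service() signature above.
    from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
    from snowflake.snowpark import Session

    session = Session.builder.configs(connection_parameters).create()  # placeholder config dict

    # Constructor arguments assumed from the attributes referenced in the diff (task, model, ...).
    pipe = HuggingFacePipelineModel(task="text-generation", model="gpt2")
    result = pipe.create_service(
        session=session,
        model_name="MY_DB.MY_SCHEMA.GPT2",
        service_name="GPT2_SERVICE",
        service_compute_pool="GPU_POOL",
        image_repo="MY_DB.MY_SCHEMA.IMAGE_REPO",
        gpu_requests="1",
        ingress_enabled=True,
        block=True,  # with block=False an AsyncJob is returned instead
    )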
snowflake/ml/model/target_platform.py
@@ -0,0 +1,11 @@
+ from enum import Enum
+
+
+ class TargetPlatform(Enum):
+     WAREHOUSE = "WAREHOUSE"
+     SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+
+
+ WAREHOUSE_ONLY = [TargetPlatform.WAREHOUSE]
+ SNOWPARK_CONTAINER_SERVICES_ONLY = [TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
+ BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES = [TargetPlatform.WAREHOUSE, TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
snowflake/ml/model/task.py
@@ -0,0 +1,9 @@
+ from enum import Enum
+
+
+ class Task(Enum):
+     UNKNOWN = "UNKNOWN"
+     TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
+     TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
+     TABULAR_REGRESSION = "TABULAR_REGRESSION"
+     TABULAR_RANKING = "TABULAR_RANKING"
snowflake/ml/model/type_hints.py
@@ -1,10 +1,12 @@
  # mypy: disable-error-code="import"
- from enum import Enum
  from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union

  import numpy.typing as npt
  from typing_extensions import NotRequired

+ from snowflake.ml.model.target_platform import TargetPlatform
+ from snowflake.ml.model.task import Task
+
  if TYPE_CHECKING:
      import catboost
      import keras
@@ -321,17 +323,7 @@ ModelLoadOption = Union[
  ]


- class Task(Enum):
-     UNKNOWN = "UNKNOWN"
-     TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
-     TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
-     TABULAR_REGRESSION = "TABULAR_REGRESSION"
-     TABULAR_RANKING = "TABULAR_RANKING"
-
-
- class TargetPlatform(Enum):
-     WAREHOUSE = "WAREHOUSE"
-     SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+ SupportedTargetPlatformType = Union[TargetPlatform, str]


- SupportedTargetPlatformType = Union[TargetPlatform, str]
+ __all__ = ["TargetPlatform", "Task"]
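Read together with the new task.py and target_platform.py files above, these hunks move the Task and TargetPlatform enums out of type_hints.py while re-importing and re-exporting them, so the old access path should keep resolving. A small check of that reading (an inference from the diff, not package documentation):

    # The enums now live in snowflake/ml/model/task.py and target_platform.py,
    # but remain reachable through the old module.
    from snowflake.ml.model import type_hints

    assert type_hints.Task.TABULAR_REGRESSION.value == "TABULAR_REGRESSION"
    assert type_hints.TargetPlatform.WAREHOUSE.value == "WAREHOUSE"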
snowflake/ml/modeling/framework/base.py
@@ -698,7 +698,7 @@ class BaseTransformer(BaseEstimator):
          self,
          attribute: Optional[Mapping[str, Union[int, float, str, Iterable[Union[int, float, str]]]]],
          dtype: Optional[type] = None,
-     ) -> Optional[npt.NDArray[Union[np.int_, np.float_, np.str_]]]:
+     ) -> Optional[npt.NDArray[Union[np.int_, np.float64, np.str_]]]:
          """
          Convert the attribute from dict to ndarray based on the order of `self.input_cols`.

snowflake/ml/modeling/metrics/classification.py
@@ -96,7 +96,7 @@ def confusion_matrix(
      labels: Optional[npt.ArrayLike] = None,
      sample_weight_col_name: Optional[str] = None,
      normalize: Optional[str] = None,
- ) -> Union[npt.NDArray[np.int_], npt.NDArray[np.float_]]:
+ ) -> Union[npt.NDArray[np.int_], npt.NDArray[np.float64]]:
      """
      Compute confusion matrix to evaluate the accuracy of a classification.

@@ -320,7 +320,7 @@ def f1_score(
      average: Optional[str] = "binary",
      sample_weight_col_name: Optional[str] = None,
      zero_division: Union[str, int] = "warn",
- ) -> Union[float, npt.NDArray[np.float_]]:
+ ) -> Union[float, npt.NDArray[np.float64]]:
      """
      Compute the F1 score, also known as balanced F-score or F-measure.

@@ -414,7 +414,7 @@ def fbeta_score(
      average: Optional[str] = "binary",
      sample_weight_col_name: Optional[str] = None,
      zero_division: Union[str, int] = "warn",
- ) -> Union[float, npt.NDArray[np.float_]]:
+ ) -> Union[float, npt.NDArray[np.float64]]:
      """
      Compute the F-beta score.

@@ -696,7 +696,7 @@ def precision_recall_fscore_support(
      zero_division: Union[str, int] = "warn",
  ) -> Union[
      tuple[float, float, float, None],
-     tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
+     tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]],
  ]:
      """
      Compute precision, recall, F-measure and support for each class.
@@ -855,7 +855,7 @@ def precision_recall_fscore_support(

      res: Union[
          tuple[float, float, float, None],
-         tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
+         tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]],
      ] = result_object[:4]
      warning = result_object[-1]
      if warning:
@@ -1050,7 +1050,7 @@ def _register_multilabel_confusion_matrix_computer(

          def end_partition(
              self,
-         ) -> Iterable[tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]]:
+         ) -> Iterable[tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]]:
              MCM = metrics.multilabel_confusion_matrix(
                  self._y_true,
                  self._y_pred,
@@ -1098,7 +1098,7 @@ def _binary_precision_score(
      pos_label: Union[str, int] = 1,
      sample_weight_col_name: Optional[str] = None,
      zero_division: Union[str, int] = "warn",
- ) -> Union[float, npt.NDArray[np.float_]]:
+ ) -> Union[float, npt.NDArray[np.float64]]:

      statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)

@@ -1173,7 +1173,7 @@ def precision_score(
      average: Optional[str] = "binary",
      sample_weight_col_name: Optional[str] = None,
      zero_division: Union[str, int] = "warn",
- ) -> Union[float, npt.NDArray[np.float_]]:
+ ) -> Union[float, npt.NDArray[np.float64]]:
      """
      Compute the precision.

@@ -1271,7 +1271,7 @@ def recall_score(
      average: Optional[str] = "binary",
      sample_weight_col_name: Optional[str] = None,
      zero_division: Union[str, int] = "warn",
- ) -> Union[float, npt.NDArray[np.float_]]:
+ ) -> Union[float, npt.NDArray[np.float64]]:
      """
      Compute the recall.

@@ -1406,14 +1406,14 @@ def _check_binary_labels(


  def _prf_divide(
-     numerator: npt.NDArray[np.float_],
-     denominator: npt.NDArray[np.float_],
+     numerator: npt.NDArray[np.float64],
+     denominator: npt.NDArray[np.float64],
      metric: str,
      modifier: str,
      average: Optional[str] = None,
      warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
      zero_division: Union[str, int] = "warn",
- ) -> npt.NDArray[np.float_]:
+ ) -> npt.NDArray[np.float64]:
      """Performs division and handles divide-by-zero.

      On zero-division, sets the corresponding result elements equal to
@@ -1436,7 +1436,7 @@ def _prf_divide(
              "warn", this acts as 0, but warnings are also raised.

      Returns:
-         npt.NDArray[np.float_]: Result of the division, an array of floats.
+         npt.NDArray[np.float64]: Result of the division, an array of floats.
      """
      mask = denominator == 0.0
      denominator = denominator.copy()
@@ -1522,7 +1522,7 @@ def _check_zero_division(zero_division: Union[int, float, str]) -> float:
          return np.nan


- def _nanaverage(a: npt.NDArray[np.float_], weights: Optional[npt.ArrayLike] = None) -> Any:
+ def _nanaverage(a: npt.NDArray[np.float64], weights: Optional[npt.ArrayLike] = None) -> Any:
      """Compute the weighted average, ignoring NaNs.

      Args:
snowflake/ml/modeling/metrics/correlation.py
@@ -26,7 +26,7 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
      The below steps explain how correlation matrix is computed in a distributed way:
      Let n = # of rows in the dataframe; sqrt_n = sqrt(n); X, Y are 2 columns in the dataframe
      Correlation(X, Y) = numerator/denominator where
-     numerator = dot(X/sqrt_n, Y/sqrt_n) - sum(X/n)*sum(X/n)
+     numerator = dot(X/sqrt_n, Y/sqrt_n) - sum(X/n)*sum(Y/n)
      denominator = std_dev(X)*std_dev(Y)
      std_dev(X) = sqrt(dot(X/sqrt_n, X/sqrt_n) - sum(X/n)*sum(X/n))

@@ -74,27 +74,38 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
      # Pushing this to a udtf requires creating a temp udtf which takes about 20 secs, so it doesn't make sense
      # to have this in a udtf.
      n_cols = len(columns)
-     sum_arr = np.zeros(n_cols)
-     squared_sum_arr = np.zeros(n_cols)
+     column_means = np.zeros(n_cols)
+     mean_of_squares = np.zeros(n_cols)
      dot_prod = np.zeros((n_cols, n_cols))
      # Get sum, dot_prod and squared sum array from the results.
      for i in range(len(results)):
          x = results[i]
          if x[1] == "sum_by_count":
-             sum_arr = cloudpickle.loads(x[0])
+             column_means = cloudpickle.loads(x[0])
          else:
              row = int(x[1].strip("row_"))
              dot_prod[row, :] = cloudpickle.loads(x[0])
-             squared_sum_arr[row] = dot_prod[row, row]
+             mean_of_squares[row] = dot_prod[row, row]

      # sum(X/n)*sum(Y/n) is computed for all combinations of X,Y (columns in the dataframe)
-     exey_arr = np.einsum("t,m->tm", sum_arr, sum_arr, optimize="optimal")
+     exey_arr = np.einsum("t,m->tm", column_means, column_means, optimize="optimal")
      numerator_matrix = dot_prod - exey_arr

      # standard deviation for all columns in the dataframe
-     stddev_arr = np.sqrt(squared_sum_arr - np.einsum("i, i -> i", sum_arr, sum_arr, optimize="optimal"))
+     variance_arr = mean_of_squares - np.einsum("i, i -> i", column_means, column_means, optimize="optimal")
+     # ensure non-negative values from potential precision issues where variance might be slightly negative
+     variance_arr = np.maximum(variance_arr, 0)
+     stddev_arr = np.sqrt(variance_arr)
      # std_dev(X)*std_dev(Y) is computed for all combinations of X,Y (columns in the dataframe)
      denominator_matrix = np.einsum("t,m->tm", stddev_arr, stddev_arr, optimize="optimal")
-     corr_res = numerator_matrix / denominator_matrix
+
+     # Use np.divide to handle NaN cases
+     corr_res = np.divide(
+         numerator_matrix,
+         denominator_matrix,
+         out=np.full_like(numerator_matrix, np.nan),
+         where=(denominator_matrix != 0),
+     )
+
      correlation_matrix = pd.DataFrame(corr_res, columns=columns, index=columns)
      return correlation_matrix
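The docstring fix above corrects the numerator to sum(X/n)*sum(Y/n), and the new code clamps tiny negative variances and guards zero denominators. A small self-contained numpy check of the corrected formula against numpy's own corrcoef (illustrative only, not code from the package):

    import numpy as np

    rng = np.random.default_rng(0)
    data = rng.normal(size=(1000, 3))
    n = data.shape[0]
    sqrt_n = np.sqrt(n)

    scaled = data / sqrt_n                      # X/sqrt_n for every column
    dot_prod = scaled.T @ scaled                # dot(X/sqrt_n, Y/sqrt_n) = E[XY]
    col_means = data.sum(axis=0) / n            # sum(X/n) = E[X]

    numerator = dot_prod - np.outer(col_means, col_means)   # E[XY] - E[X]E[Y]
    variance = np.diag(dot_prod) - col_means**2
    stddev = np.sqrt(np.maximum(variance, 0))                # clamp tiny negative values, as in the new code
    denominator = np.outer(stddev, stddev)

    corr = np.divide(numerator, denominator, out=np.full_like(numerator, np.nan), where=denominator != 0)
    assert np.allclose(corr, np.corrcoef(data, rowvar=False))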
snowflake/ml/modeling/metrics/metrics_utils.py
@@ -60,6 +60,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: dict[str, A
          ),
          input_types=[T.BinaryType()],
          packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+         imports=[], # Prevents unnecessary import resolution.
          name=accumulator,
          is_permanent=False,
          replace=True,
@@ -175,6 +176,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: dic
          ),
          input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
          packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+         imports=[], # Prevents unnecessary import resolution.
          name=sharded_dot_and_sum_computer,
          is_permanent=False,
          replace=True,
snowflake/ml/modeling/metrics/ranking.py
@@ -26,7 +26,7 @@ def precision_recall_curve(
      probas_pred_col_name: str,
      pos_label: Optional[Union[str, int]] = None,
      sample_weight_col_name: Optional[str] = None,
- ) -> tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+ ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]:
      """
      Compute precision-recall pairs for different probability thresholds.

@@ -125,7 +125,7 @@ def precision_recall_curve(

      kwargs = telemetry.get_sproc_statement_params_kwargs(precision_recall_curve_anon_sproc, statement_params)
      result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session, **kwargs))
-     res: tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
+     res: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] = result_object
      return res


@@ -140,7 +140,7 @@ def roc_auc_score(
      max_fpr: Optional[float] = None,
      multi_class: str = "raise",
      labels: Optional[npt.ArrayLike] = None,
- ) -> Union[float, npt.NDArray[np.float_]]:
+ ) -> Union[float, npt.NDArray[np.float64]]:
      """
      Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
      from prediction scores.
@@ -276,7 +276,7 @@ def roc_auc_score(

      kwargs = telemetry.get_sproc_statement_params_kwargs(roc_auc_score_anon_sproc, statement_params)
      result_object = result.deserialize(session, roc_auc_score_anon_sproc(session, **kwargs))
-     auc: Union[float, npt.NDArray[np.float_]] = result_object
+     auc: Union[float, npt.NDArray[np.float64]] = result_object
      return auc


@@ -289,7 +289,7 @@ def roc_curve(
      pos_label: Optional[Union[str, int]] = None,
      sample_weight_col_name: Optional[str] = None,
      drop_intermediate: bool = True,
- ) -> tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+ ) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]]:
      """
      Compute Receiver operating characteristic (ROC).

@@ -380,6 +380,6 @@ def roc_curve(
      kwargs = telemetry.get_sproc_statement_params_kwargs(roc_curve_anon_sproc, statement_params)
      result_object = result.deserialize(session, roc_curve_anon_sproc(session, **kwargs))

-     res: tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
+     res: tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], npt.NDArray[np.float64]] = result_object

      return res