snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
- snowflake/cortex/_classify_text.py +3 -3
- snowflake/cortex/_complete.py +23 -24
- snowflake/cortex/_embed_text_1024.py +4 -4
- snowflake/cortex/_embed_text_768.py +4 -4
- snowflake/cortex/_finetune.py +8 -8
- snowflake/cortex/_util.py +8 -12
- snowflake/ml/_internal/env.py +4 -3
- snowflake/ml/_internal/env_utils.py +63 -34
- snowflake/ml/_internal/file_utils.py +10 -21
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
- snowflake/ml/_internal/init_utils.py +2 -3
- snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
- snowflake/ml/_internal/platform_capabilities.py +6 -6
- snowflake/ml/_internal/telemetry.py +39 -52
- snowflake/ml/_internal/type_utils.py +3 -3
- snowflake/ml/_internal/utils/db_utils.py +2 -2
- snowflake/ml/_internal/utils/identifier.py +8 -8
- snowflake/ml/_internal/utils/import_utils.py +2 -2
- snowflake/ml/_internal/utils/parallelize.py +7 -7
- snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
- snowflake/ml/_internal/utils/query_result_checker.py +4 -4
- snowflake/ml/_internal/utils/snowflake_env.py +28 -6
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
- snowflake/ml/_internal/utils/sql_identifier.py +3 -3
- snowflake/ml/_internal/utils/table_manager.py +9 -9
- snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
- snowflake/ml/data/data_connector.py +15 -36
- snowflake/ml/data/data_ingestor.py +4 -15
- snowflake/ml/data/data_source.py +2 -2
- snowflake/ml/data/ingestor_utils.py +3 -3
- snowflake/ml/data/torch_utils.py +5 -5
- snowflake/ml/dataset/dataset.py +11 -11
- snowflake/ml/dataset/dataset_metadata.py +8 -8
- snowflake/ml/dataset/dataset_reader.py +7 -7
- snowflake/ml/feature_store/__init__.py +1 -1
- snowflake/ml/feature_store/access_manager.py +7 -7
- snowflake/ml/feature_store/entity.py +6 -6
- snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
- snowflake/ml/feature_store/examples/example_helper.py +16 -16
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
- snowflake/ml/feature_store/feature_store.py +52 -64
- snowflake/ml/feature_store/feature_view.py +24 -24
- snowflake/ml/fileset/embedded_stage_fs.py +5 -5
- snowflake/ml/fileset/fileset.py +5 -5
- snowflake/ml/fileset/sfcfs.py +13 -13
- snowflake/ml/fileset/stage_fs.py +15 -15
- snowflake/ml/jobs/_utils/interop_utils.py +10 -10
- snowflake/ml/jobs/_utils/payload_utils.py +6 -16
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
- snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
- snowflake/ml/jobs/_utils/spec_utils.py +17 -28
- snowflake/ml/jobs/_utils/types.py +2 -2
- snowflake/ml/jobs/decorators.py +4 -5
- snowflake/ml/jobs/job.py +24 -14
- snowflake/ml/jobs/manager.py +37 -41
- snowflake/ml/lineage/lineage_node.py +5 -5
- snowflake/ml/model/_client/model/model_impl.py +3 -3
- snowflake/ml/model/_client/model/model_version_impl.py +103 -35
- snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
- snowflake/ml/model/_client/ops/model_ops.py +41 -41
- snowflake/ml/model/_client/ops/service_ops.py +199 -26
- snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
- snowflake/ml/model/_client/sql/model.py +8 -8
- snowflake/ml/model/_client/sql/model_version.py +26 -26
- snowflake/ml/model/_client/sql/service.py +13 -13
- snowflake/ml/model/_client/sql/stage.py +2 -2
- snowflake/ml/model/_client/sql/tag.py +6 -6
- snowflake/ml/model/_model_composer/model_composer.py +17 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
- snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
- snowflake/ml/model/_packager/model_env/model_env.py +28 -25
- snowflake/ml/model/_packager/model_handler.py +4 -4
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
- snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
- snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
- snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
- snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
- snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
- snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
- snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
- snowflake/ml/model/_packager/model_packager.py +11 -9
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/_signatures/core.py +16 -24
- snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
- snowflake/ml/model/_signatures/utils.py +6 -6
- snowflake/ml/model/custom_model.py +8 -8
- snowflake/ml/model/model_signature.py +9 -20
- snowflake/ml/model/models/huggingface_pipeline.py +7 -4
- snowflake/ml/model/type_hints.py +3 -3
- snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
- snowflake/ml/modeling/_internal/model_specifications.py +8 -10
- snowflake/ml/modeling/_internal/model_trainer.py +5 -5
- snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
- snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
- snowflake/ml/modeling/framework/_utils.py +10 -10
- snowflake/ml/modeling/framework/base.py +32 -32
- snowflake/ml/modeling/impute/__init__.py +1 -1
- snowflake/ml/modeling/impute/simple_imputer.py +5 -5
- snowflake/ml/modeling/metrics/__init__.py +1 -1
- snowflake/ml/modeling/metrics/classification.py +39 -39
- snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
- snowflake/ml/modeling/metrics/ranking.py +7 -7
- snowflake/ml/modeling/metrics/regression.py +13 -13
- snowflake/ml/modeling/model_selection/__init__.py +1 -1
- snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
- snowflake/ml/modeling/pipeline/__init__.py +1 -1
- snowflake/ml/modeling/pipeline/pipeline.py +18 -18
- snowflake/ml/modeling/preprocessing/__init__.py +1 -1
- snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
- snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
- snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
- snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
- snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
- snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
- snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
- snowflake/ml/registry/_manager/model_manager.py +33 -31
- snowflake/ml/registry/registry.py +29 -22
- snowflake/ml/utils/authentication.py +2 -2
- snowflake/ml/utils/connection_params.py +5 -5
- snowflake/ml/utils/sparse.py +5 -4
- snowflake/ml/utils/sql_client.py +1 -2
- snowflake/ml/version.py +2 -1
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
- snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
- snowflake/ml/modeling/_internal/constants.py +0 -2
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/impute/simple_imputer.py:

```diff
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 import copy
 import warnings
-from typing import Any, …
+from typing import Any, Iterable, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -25,7 +25,7 @@ STRATEGY_TO_STATE_DICT = {
     "most_frequent": _utils.BasicStatistics.MODE,
 }
 
-SNOWFLAKE_DATATYPE_TO_NUMPY_DTYPE_MAP: Dict[Type[T.DataType], npt.DTypeLike] = {
+SNOWFLAKE_DATATYPE_TO_NUMPY_DTYPE_MAP: dict[type[T.DataType], npt.DTypeLike] = {
     T.ByteType: np.dtype("int8"),
     T.ShortType: np.dtype("int16"),
     T.IntegerType: np.dtype("int32"),
@@ -164,7 +164,7 @@ class SimpleImputer(base.BaseTransformer):
 
         self.fill_value = fill_value
        self.missing_values = missing_values
-        self.statistics_: Dict[str, Any] = {}
+        self.statistics_: dict[str, Any] = {}
         # TODO(hayu): [SNOW-752265] Support SimpleImputer keep_empty_features.
         # Add back when `keep_empty_features` is supported.
         # self.keep_empty_features = keep_empty_features
@@ -195,7 +195,7 @@ class SimpleImputer(base.BaseTransformer):
         del self.feature_names_in_
         del self._sklearn_fit_dtype
 
-    def _get_dataset_input_col_datatypes(self, dataset: snowpark.DataFrame) -> Dict[str, T.DataType]:
+    def _get_dataset_input_col_datatypes(self, dataset: snowpark.DataFrame) -> dict[str, T.DataType]:
         """
         Checks that the input columns are all the same datatype category(except for most_frequent strategy) and
         returns the datatype.
@@ -211,7 +211,7 @@ class SimpleImputer(base.BaseTransformer):
             supported.
         """
 
-        def check_type_consistency(col_types: Dict[str, T.DataType]) -> None:
+        def check_type_consistency(col_types: dict[str, T.DataType]) -> None:
             is_numeric_type = None
             for col_name, col_type in col_types.items():
                 if is_numeric_type is None:
```
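The change running through this file, and nearly every hunk below, is mechanical: the `typing.Dict`/`List`/`Tuple`/`Set` aliases are replaced with the builtin generics accepted since Python 3.9 (PEP 585). A minimal sketch of the pattern, illustrative rather than taken from the package:

```python
from typing import Any, Optional, Union  # Dict/List/Tuple no longer imported

# Before: def summarize(col_types: Dict[str, Any]) -> Tuple[int, List[str]]: ...
def summarize(
    col_types: dict[str, Any], labels: Optional[Union[str, list[str]]] = None
) -> tuple[int, list[str]]:
    """Builtin dict/list/tuple work directly as generic annotations on 3.9+."""
    names = [labels] if isinstance(labels, str) else list(labels or [])
    return len(col_types), names


print(summarize({"AGE": "int"}, labels="LABEL"))  # (1, ['LABEL'])
```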
snowflake/ml/modeling/metrics/__init__.py:

```diff
@@ -5,7 +5,7 @@ import cloudpickle
 from snowflake.ml._internal import init_utils
 from snowflake.ml._internal.utils import result
 
-pkg_dir = os.path.dirname(…
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_functions = init_utils.fetch_functions_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_functions.items():
```
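For context, this `__init__.py` builds its public namespace dynamically. A simplified sketch of the surrounding re-export loop; the binding step is an assumption, since the loop body is not shown in the hunk:

```python
import os

from snowflake.ml._internal import init_utils

pkg_dir = os.path.dirname(__file__)  # the fixed line: directory containing this __init__.py
pkg_name = __name__

# The internal helper scans the package's modules and returns {name: function}.
exportable_functions = init_utils.fetch_functions_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
for k, v in exportable_functions.items():
    globals()[k] = v  # assumed: re-export each discovered function at package level
```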
snowflake/ml/modeling/metrics/classification.py:

```diff
@@ -2,7 +2,7 @@ import inspect
 import json
 import math
 import warnings
-from typing import Any, …
+from typing import Any, Iterable, Optional, Union
 
 import cloudpickle
 import numpy as np
@@ -32,8 +32,8 @@ _SUBPROJECT = "Metrics"
 def accuracy_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     normalize: bool = True,
     sample_weight_col_name: Optional[str] = None,
 ) -> float:
@@ -221,7 +221,7 @@ def confusion_matrix(
     return cm
 
 
-def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_params: Dict[str, Any]) -> str:
+def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_params: dict[str, Any]) -> str:
     """Registers confusion matrix computation UDTF in Snowflake and returns the name of the UDTF.
 
     Args:
@@ -247,7 +247,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_params: Dict[str, Any]) -> str:
             # Number of labels.
             self._n_label = 0
 
-        def process(self, input_row: List[float], n_label: int) -> None:
+        def process(self, input_row: list[float], n_label: int) -> None:
             """Computes confusion matrix.
 
             Args:
@@ -270,7 +270,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_params: Dict[str, Any]) -> str:
                 self.update_confusion_matrix()
                 self._cur_count = 0
 
-        def end_partition(self) -> Iterable[Tuple[bytes, str]]:
+        def end_partition(self) -> Iterable[tuple[bytes, str]]:
             # 3. Compute sum and dot_prod for the remaining rows in the batch.
             if self._cur_count > 0:
                 self.update_confusion_matrix()
@@ -313,8 +313,8 @@
 def f1_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
     average: Optional[str] = "binary",
@@ -406,8 +406,8 @@ def f1_score(
 def fbeta_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     beta: float,
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
@@ -501,8 +501,8 @@ def fbeta_score(
 def log_loss(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     eps: Union[float, str] = "auto",
     normalize: bool = True,
     sample_weight_col_name: Optional[str] = None,
@@ -625,7 +625,7 @@ def log_loss(
 def _register_log_loss_computer(
     *,
     session: snowpark.Session,
-    statement_params: Dict[str, Any],
+    statement_params: dict[str, Any],
     labels: Optional[npt.ArrayLike] = None,
 ) -> str:
     """Registers log loss computation UDTF in Snowflake and returns the name of the UDTF.
@@ -644,16 +644,16 @@ def _register_log_loss_computer(
     class LogLossComputer:
         def __init__(self) -> None:
             self._labels = labels
-            self._y_true: List[List[int]] = []
-            self._y_pred: List[List[float]] = []
-            self._sample_weight: List[float] = []
+            self._y_true: list[list[int]] = []
+            self._y_pred: list[list[float]] = []
+            self._sample_weight: list[float] = []
 
-        def process(self, y_true: List[int], y_pred: List[float], sample_weight: float) -> None:
+        def process(self, y_true: list[int], y_pred: list[float], sample_weight: float) -> None:
             self._y_true.append(y_true)
             self._y_pred.append(y_pred)
             self._sample_weight.append(sample_weight)
 
-        def end_partition(self) -> Iterable[Tuple[float]]:
+        def end_partition(self) -> Iterable[tuple[float]]:
             res = metrics.log_loss(
                 self._y_true,
                 self._y_pred,
@@ -685,18 +685,18 @@
 def precision_recall_fscore_support(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     beta: float = 1.0,
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
     average: Optional[str] = None,
-    warn_for: Union[Tuple[str, ...], Set[str]] = ("precision", "recall", "f-score"),
+    warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
 ) -> Union[
-    Tuple[float, float, float, None],
-    Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
+    tuple[float, float, float, None],
+    tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
 ]:
     """
     Compute precision, recall, F-measure and support for each class.
@@ -854,8 +854,8 @@ def precision_recall_fscore_support(
     result_object = result.deserialize(session, precision_recall_fscore_support_anon_sproc(session, **kwargs))
 
     res: Union[
-        Tuple[float, float, float, None],
-        Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
+        tuple[float, float, float, None],
+        tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
     ] = result_object[:4]
     warning = result_object[-1]
     if warning:
@@ -1039,18 +1039,18 @@ def _register_multilabel_confusion_matrix_computer(
         def __init__(self) -> None:
             self._labels = labels
             self._samplewise = samplewise
-            self._y_true: List[List[int]] = []
-            self._y_pred: List[List[int]] = []
-            self._sample_weight: List[float] = []
+            self._y_true: list[list[int]] = []
+            self._y_pred: list[list[int]] = []
+            self._sample_weight: list[float] = []
 
-        def process(self, y_true: List[int], y_pred: List[int], sample_weight: float) -> None:
+        def process(self, y_true: list[int], y_pred: list[int], sample_weight: float) -> None:
             self._y_true.append(y_true)
             self._y_pred.append(y_pred)
             self._sample_weight.append(sample_weight)
 
         def end_partition(
             self,
-        ) -> Iterable[Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]]:
+        ) -> Iterable[tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]]:
             MCM = metrics.multilabel_confusion_matrix(
                 self._y_true,
                 self._y_pred,
@@ -1093,8 +1093,8 @@
 def _binary_precision_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     pos_label: Union[str, int] = 1,
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
@@ -1166,8 +1166,8 @@
 def precision_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
     average: Optional[str] = "binary",
@@ -1264,8 +1264,8 @@ def precision_score(
 def recall_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
     average: Optional[str] = "binary",
@@ -1376,9 +1376,9 @@ def _sum_array_col(df: snowpark.DataFrame, col_name: str) -> snowpark.DataFrame:
 
 
 def _check_binary_labels(
-    labels: List[Any],
+    labels: list[Any],
     pos_label: Union[str, int] = 1,
-) -> List[Any]:
+) -> list[Any]:
     """Validation associated with binary average labels.
 
     Args:
@@ -1411,7 +1411,7 @@ def _prf_divide(
     metric: str,
     modifier: str,
     average: Optional[str] = None,
-    warn_for: Union[Tuple[str, ...], Set[str]] = ("precision", "recall", "f-score"),
+    warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
     zero_division: Union[str, int] = "warn",
 ) -> npt.NDArray[np.float_]:
     """Performs division and handles divide-by-zero.
```
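The signatures above are public entry points of `snowflake.ml.modeling.metrics`, which compute the metric server-side over a Snowpark DataFrame. A usage sketch, with the keyword arguments as shown in the hunks; the DataFrame and column names are hypothetical:

```python
from snowflake.ml.modeling.metrics import accuracy_score, f1_score

# `df` is assumed to be an existing snowpark.DataFrame with label and prediction columns.
acc = accuracy_score(df=df, y_true_col_names="LABEL", y_pred_col_names="PREDICTION")
f1 = f1_score(df=df, y_true_col_names="LABEL", y_pred_col_names="PREDICTION", average="binary")
print(acc, f1)
```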
snowflake/ml/modeling/metrics/metrics_utils.py:

```diff
@@ -1,6 +1,6 @@
 import math
 import warnings
-from typing import Any, Collection, …
+from typing import Any, Collection, Iterable, Optional, Union
 
 import cloudpickle
 import numpy as np
@@ -18,7 +18,7 @@ INDEX = "INDEX"
 BATCH_SIZE = 1000
 
 
-def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, Any]) -> str:
+def register_accumulator_udtf(*, session: Session, statement_params: dict[str, Any]) -> str:
     """Registers accumulator UDTF in Snowflake and returns the name of the UDTF.
 
     Args:
@@ -47,7 +47,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, Any]) -> str:
             else:
                 self._accumulated_row = self._accumulated_row + row
 
-        def end_partition(self) -> Iterable[Tuple[bytes]]:
+        def end_partition(self) -> Iterable[tuple[bytes]]:
             yield (cloudpickle.dumps(self._accumulated_row),)
 
     accumulator = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE_FUNCTION)
@@ -68,7 +68,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, Any]) -> str:
     return accumulator
 
 
-def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dict[str, Any]) -> str:
+def register_sharded_dot_sum_computer(*, session: Session, statement_params: dict[str, Any]) -> str:
     """Registers dot and sum computation UDTF in Snowflake and returns the name of the UDTF.
 
     Args:
@@ -110,7 +110,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dict[str, Any]) -> str:
             # Square root of count - ddof
             self._sqrt_count_d = -1.0
 
-        def process(self, input_row: List[float], count: int, ddof: int) -> None:
+        def process(self, input_row: list[float], count: int, ddof: int) -> None:
             """Computes sum and dot product.
 
             Args:
@@ -138,7 +138,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dict[str, Any]) -> str:
                 self.accumulate_batch_sum_and_dot_prod()
                 self._cur_count = 0
 
-        def end_partition(self) -> Iterable[Tuple[bytes, str]]:
+        def end_partition(self) -> Iterable[tuple[bytes, str]]:
             # 3. Compute sum and dot_prod for the remaining rows in the batch.
             if self._cur_count > 0:
                 self.accumulate_batch_sum_and_dot_prod()
@@ -185,7 +185,7 @@
 
 def validate_and_return_dataframe_and_columns(
     *, df: snowpark.DataFrame, columns: Optional[Collection[str]] = None
-) -> Tuple[snowpark.DataFrame, Collection[str]]:
+) -> tuple[snowpark.DataFrame, Collection[str]]:
     """Validates that the columns are all numeric and returns a dataframe with those columns.
 
     Args:
@@ -212,8 +212,8 @@ def validate_and_return_dataframe_and_columns(
 
 
 def check_label_columns(
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
 ) -> None:
     """Check y true and y pred columns.
 
@@ -238,7 +238,7 @@ def check_label_columns(
     )
 
 
-def flatten_cols(cols: List[Optional[Union[str, List[str]]]]) -> List[str]:
+def flatten_cols(cols: list[Optional[Union[str, list[str]]]]) -> list[str]:
     res = []
     for col in cols:
         if isinstance(col, str):
@@ -251,7 +251,7 @@ def flatten_cols(cols: List[Optional[Union[str, List[str]]]]) -> List[str]:
 def unique_labels(
     *,
     df: snowpark.DataFrame,
-    columns: List[snowpark.Column],
+    columns: list[snowpark.Column],
 ) -> snowpark.DataFrame:
     """Extract indexed ordered unique labels as a dataframe.
 
@@ -311,7 +311,7 @@ def weighted_sum(
     sample_score_column: snowpark.Column,
     sample_weight_column: Optional[snowpark.Column] = None,
     normalize: bool = False,
-    statement_params: Dict[str, str],
+    statement_params: dict[str, str],
 ) -> float:
     """Weighted sum of the sample score column.
 
```
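The annotated `process`/`end_partition` methods follow Snowpark's UDTF handler protocol: `process` is invoked once per input row, and `end_partition` yields the result tuples for the partition. A self-contained toy handler in the same shape, not the package's own UDTF, which is registered internally (via `session.udtf.register` with an output schema):

```python
from typing import Iterable


class SumAccumulator:
    """Toy partition accumulator mirroring the handler shape in the hunks above."""

    def __init__(self) -> None:
        self._total = 0.0

    def process(self, input_row: list[float]) -> None:
        # Called once per row: fold the row into local state, emit nothing yet.
        self._total += sum(input_row)

    def end_partition(self) -> Iterable[tuple[float]]:
        # Called once per partition: yield the accumulated result as a row tuple.
        yield (self._total,)
```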
snowflake/ml/modeling/metrics/ranking.py:

```diff
@@ -1,4 +1,4 @@
-from typing import …
+from typing import Optional, Union
 
 import cloudpickle
 import numpy as np
@@ -26,7 +26,7 @@ def precision_recall_curve(
     probas_pred_col_name: str,
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
-) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+) -> tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
     """
     Compute precision-recall pairs for different probability thresholds.
 
@@ -125,7 +125,7 @@ def precision_recall_curve(
 
     kwargs = telemetry.get_sproc_statement_params_kwargs(precision_recall_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session, **kwargs))
-    res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
+    res: tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
     return res
 
 
@@ -133,8 +133,8 @@ def precision_recall_curve(
 def roc_auc_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_score_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_score_col_names: Union[str, list[str]],
     average: Optional[str] = "macro",
     sample_weight_col_name: Optional[str] = None,
     max_fpr: Optional[float] = None,
@@ -289,7 +289,7 @@ def roc_curve(
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
     drop_intermediate: bool = True,
-) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+) -> tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
     """
     Compute Receiver operating characteristic (ROC).
 
@@ -380,6 +380,6 @@ def roc_curve(
     kwargs = telemetry.get_sproc_statement_params_kwargs(roc_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, roc_curve_anon_sproc(session, **kwargs))
 
-    res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
+    res: tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
 
     return res
```
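The curve functions now annotate their results as builtin tuples of NumPy arrays, and calling code unpacks them as before. A sketch; the DataFrame and the `y_true_col_name` parameter name are assumptions, since only part of the signature appears in the hunks:

```python
from snowflake.ml.modeling.metrics import precision_recall_curve

# `df` is assumed to be an existing snowpark.DataFrame.
precision, recall, thresholds = precision_recall_curve(
    df=df,
    y_true_col_name="LABEL",             # assumed parameter name
    probas_pred_col_name="PROBABILITY",  # parameter confirmed by the hunk above
)
```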
snowflake/ml/modeling/metrics/regression.py:

```diff
@@ -1,5 +1,5 @@
 import inspect
-from typing import …
+from typing import Optional, Union
 
 import cloudpickle
 import numpy as np
@@ -25,8 +25,8 @@ _MULTIOUTPUT_RAW_VALUES = "raw_values"
 def d2_absolute_error_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
 ) -> Union[float, npt.NDArray[np.float_]]:
@@ -119,8 +119,8 @@ def d2_absolute_error_score(
 def d2_pinball_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     alpha: float = 0.5,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
@@ -219,8 +219,8 @@ def d2_pinball_score(
 def explained_variance_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     force_finite: bool = True,
@@ -334,8 +334,8 @@ def explained_variance_score(
 def mean_absolute_error(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
 ) -> Union[float, npt.NDArray[np.float_]]:
@@ -407,8 +407,8 @@ def mean_absolute_error(
 def mean_absolute_percentage_error(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
 ) -> Union[float, npt.NDArray[np.float_]]:
@@ -490,8 +490,8 @@ def mean_absolute_percentage_error(
 def mean_squared_error(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     squared: bool = True,
```
snowflake/ml/modeling/model_selection/__init__.py:

```diff
@@ -2,7 +2,7 @@ import os
 
 from snowflake.ml._internal import init_utils
 
-pkg_dir = os.path.dirname(…
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_classes.items():
```
snowflake/ml/modeling/model_selection/grid_search_cv.py:

```diff
@@ -2,7 +2,7 @@
 # This code is auto-generated using the sklearn_wrapper_template.py_template template.
 # Do not modify the auto-generated code(except automatic reformatting by precommit hooks).
 #
-from typing import Any, …
+from typing import Any, Iterable, Optional, Union
 
 import cloudpickle as cp
 import numpy as np
@@ -244,7 +244,7 @@ class GridSearchCV(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        deps: Set[str] = {
+        deps: set[str] = {
            f"numpy=={np.__version__}",
            f"scikit-learn=={sklearn.__version__}",
            f"cloudpickle=={cp.__version__}",
@@ -268,7 +268,7 @@ class GridSearchCV(BaseTransformer):
         self._sklearn_object: Any = sklearn.model_selection.GridSearchCV(
             **cleaned_up_init_args,
         )
-        self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
+        self._model_signature_dict: Optional[dict[str, ModelSignature]] = None
         self.set_input_cols(input_cols)
         self.set_output_cols(output_cols)
         self.set_label_cols(label_cols)
@@ -281,7 +281,7 @@ class GridSearchCV(BaseTransformer):
         self._class_name = GridSearchCV.__class__.__name__
         self._subproject = _SUBPROJECT
 
-    def _get_active_columns(self) -> List[str]:
+    def _get_active_columns(self) -> list[str]:
         """ "Get the list of columns that are relevant to the transformer."""
         selected_cols = (
             self.input_cols + self.label_cols + ([self.sample_weight_col] if self.sample_weight_col is not None else [])
@@ -805,7 +805,7 @@ class GridSearchCV(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object
 
-    def _get_dependencies(self) -> List[str]:
+    def _get_dependencies(self) -> list[str]:
         return self._deps
 
     def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
@@ -820,7 +820,7 @@ class GridSearchCV(BaseTransformer):
                 use_snowflake_identifiers=True,
             )
         )
-        outputs: List[BaseFeatureSpec] = []
+        outputs: list[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
             assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
@@ -863,7 +863,7 @@ class GridSearchCV(BaseTransformer):
             self._model_signature_dict[method] = signature
 
     @property
-    def model_signatures(self) -> Dict[str, ModelSignature]:
+    def model_signatures(self) -> dict[str, ModelSignature]:
         """Returns model signature of current class.
 
         Raises:
```
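The `deps` set pins the versions of the estimator's dependencies that are installed on the client, presumably so the server-side execution environment can be made to match. The construction is plain Python:

```python
import cloudpickle as cp
import numpy as np
import sklearn

# Version pins built from the modules actually imported on the client.
deps: set[str] = {
    f"numpy=={np.__version__}",
    f"scikit-learn=={sklearn.__version__}",
    f"cloudpickle=={cp.__version__}",
}
print(sorted(deps))  # e.g. ['cloudpickle==2.2.1', 'numpy==1.26.4', 'scikit-learn==1.5.2']
```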
snowflake/ml/modeling/model_selection/randomized_search_cv.py:

```diff
@@ -1,4 +1,4 @@
-from typing import Any, …
+from typing import Any, Iterable, Optional, Union
 
 import cloudpickle as cp
 import numpy as np
@@ -254,7 +254,7 @@ class RandomizedSearchCV(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        deps: Set[str] = {
+        deps: set[str] = {
            f"numpy=={np.__version__}",
            f"scikit-learn=={sklearn.__version__}",
            f"cloudpickle=={cp.__version__}",
@@ -280,7 +280,7 @@ class RandomizedSearchCV(BaseTransformer):
         self._sklearn_object: Any = sklearn.model_selection.RandomizedSearchCV(
             **cleaned_up_init_args,
         )
-        self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
+        self._model_signature_dict: Optional[dict[str, ModelSignature]] = None
         self.set_input_cols(input_cols)
         self.set_output_cols(output_cols)
         self.set_label_cols(label_cols)
@@ -294,7 +294,7 @@ class RandomizedSearchCV(BaseTransformer):
         self._class_name = RandomizedSearchCV.__class__.__name__
         self._subproject = _SUBPROJECT
 
-    def _get_active_columns(self) -> List[str]:
+    def _get_active_columns(self) -> list[str]:
         """ "Get the list of columns that are relevant to the transformer."""
         selected_cols = (
             self.input_cols + self.label_cols + ([self.sample_weight_col] if self.sample_weight_col is not None else [])
@@ -820,7 +820,7 @@ class RandomizedSearchCV(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object
 
-    def _get_dependencies(self) -> List[str]:
+    def _get_dependencies(self) -> list[str]:
         return self._deps
 
     def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
@@ -835,7 +835,7 @@ class RandomizedSearchCV(BaseTransformer):
                 use_snowflake_identifiers=True,
             )
         )
-        outputs: List[BaseFeatureSpec] = []
+        outputs: list[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
             assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
@@ -878,7 +878,7 @@ class RandomizedSearchCV(BaseTransformer):
             self._model_signature_dict[method] = signature
 
     @property
-    def model_signatures(self) -> Dict[str, ModelSignature]:
+    def model_signatures(self) -> dict[str, ModelSignature]:
         """Returns model signature of current class.
 
         Raises:
```
snowflake/ml/modeling/pipeline/__init__.py:

```diff
@@ -2,7 +2,7 @@ import os
 
 from snowflake.ml._internal import init_utils
 
-pkg_dir = os.path.dirname(…
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_classes.items():
```