snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
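Nearly every hunk below applies one of two mechanical cleanups: the deprecated typing aliases (Dict, List, Tuple, Set, Type) are replaced with the builtin generics standardized by PEP 585, which are subscriptable at runtime on Python 3.9+, and the redundant os.path.abspath wrapper is dropped from the pkg_dir computation in the modeling __init__ modules. A minimal sketch of the typing migration, using a hypothetical helper (bounds is illustrative, not a function from this package):

from typing import Optional, Union


def bounds(scores: dict[str, list[float]], fallback: Optional[float] = None) -> Union[tuple[float, float], None]:
    # dict/list/tuple are subscriptable on Python 3.9+ (PEP 585), so the Dict/List/Tuple
    # aliases can be dropped; Optional and Union still live in typing, which is why those
    # imports survive throughout this diff.
    flat = [v for values in scores.values() for v in values]
    if not flat:
        return None if fallback is None else (fallback, fallback)
    return (min(flat), max(flat))

The annotations behave identically to the typing aliases; only the spelling changes, which is why the hunks touch signatures and annotations without altering any logic.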
--- a/snowflake/ml/modeling/impute/simple_imputer.py
+++ b/snowflake/ml/modeling/impute/simple_imputer.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 import copy
 import warnings
-from typing import Any, Dict, Iterable, Optional, Type, Union
+from typing import Any, Iterable, Optional, Union
 
 import numpy as np
 import numpy.typing as npt
@@ -25,7 +25,7 @@ STRATEGY_TO_STATE_DICT = {
     "most_frequent": _utils.BasicStatistics.MODE,
 }
 
-SNOWFLAKE_DATATYPE_TO_NUMPY_DTYPE_MAP: Dict[Type[T.DataType], npt.DTypeLike] = {
+SNOWFLAKE_DATATYPE_TO_NUMPY_DTYPE_MAP: dict[type[T.DataType], npt.DTypeLike] = {
    T.ByteType: np.dtype("int8"),
    T.ShortType: np.dtype("int16"),
    T.IntegerType: np.dtype("int32"),
@@ -164,7 +164,7 @@ class SimpleImputer(base.BaseTransformer):
 
         self.fill_value = fill_value
         self.missing_values = missing_values
-        self.statistics_: Dict[str, Any] = {}
+        self.statistics_: dict[str, Any] = {}
         # TODO(hayu): [SNOW-752265] Support SimpleImputer keep_empty_features.
         # Add back when `keep_empty_features` is supported.
         # self.keep_empty_features = keep_empty_features
@@ -195,7 +195,7 @@ class SimpleImputer(base.BaseTransformer):
         del self.feature_names_in_
         del self._sklearn_fit_dtype
 
-    def _get_dataset_input_col_datatypes(self, dataset: snowpark.DataFrame) -> Dict[str, T.DataType]:
+    def _get_dataset_input_col_datatypes(self, dataset: snowpark.DataFrame) -> dict[str, T.DataType]:
         """
         Checks that the input columns are all the same datatype category(except for most_frequent strategy) and
         returns the datatype.
@@ -211,7 +211,7 @@ class SimpleImputer(base.BaseTransformer):
                 supported.
         """
 
-        def check_type_consistency(col_types: Dict[str, T.DataType]) -> None:
+        def check_type_consistency(col_types: dict[str, T.DataType]) -> None:
             is_numeric_type = None
             for col_name, col_type in col_types.items():
                 if is_numeric_type is None:
--- a/snowflake/ml/modeling/metrics/__init__.py
+++ b/snowflake/ml/modeling/metrics/__init__.py
@@ -5,7 +5,7 @@ import cloudpickle
 from snowflake.ml._internal import init_utils
 from snowflake.ml._internal.utils import result
 
-pkg_dir = os.path.dirname(os.path.abspath(__file__))
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_functions = init_utils.fetch_functions_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_functions.items():
--- a/snowflake/ml/modeling/metrics/classification.py
+++ b/snowflake/ml/modeling/metrics/classification.py
@@ -2,7 +2,7 @@ import inspect
 import json
 import math
 import warnings
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Iterable, Optional, Union
 
 import cloudpickle
 import numpy as np
@@ -32,8 +32,8 @@ _SUBPROJECT = "Metrics"
 def accuracy_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     normalize: bool = True,
     sample_weight_col_name: Optional[str] = None,
 ) -> float:
@@ -221,7 +221,7 @@ def confusion_matrix(
     return cm
 
 
-def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_params: Dict[str, Any]) -> str:
+def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_params: dict[str, Any]) -> str:
    """Registers confusion matrix computation UDTF in Snowflake and returns the name of the UDTF.
 
    Args:
@@ -247,7 +247,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
             # Number of labels.
             self._n_label = 0
 
-        def process(self, input_row: List[float], n_label: int) -> None:
+        def process(self, input_row: list[float], n_label: int) -> None:
             """Computes confusion matrix.
 
             Args:
@@ -270,7 +270,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
                 self.update_confusion_matrix()
                 self._cur_count = 0
 
-        def end_partition(self) -> Iterable[Tuple[bytes, str]]:
+        def end_partition(self) -> Iterable[tuple[bytes, str]]:
             # 3. Compute sum and dot_prod for the remaining rows in the batch.
             if self._cur_count > 0:
                 self.update_confusion_matrix()
@@ -313,8 +313,8 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
 def f1_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
     average: Optional[str] = "binary",
@@ -406,8 +406,8 @@ def f1_score(
 def fbeta_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     beta: float,
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
@@ -501,8 +501,8 @@ def fbeta_score(
 def log_loss(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     eps: Union[float, str] = "auto",
     normalize: bool = True,
     sample_weight_col_name: Optional[str] = None,
@@ -625,7 +625,7 @@ def log_loss(
 def _register_log_loss_computer(
     *,
     session: snowpark.Session,
-    statement_params: Dict[str, Any],
+    statement_params: dict[str, Any],
     labels: Optional[npt.ArrayLike] = None,
 ) -> str:
     """Registers log loss computation UDTF in Snowflake and returns the name of the UDTF.
@@ -644,16 +644,16 @@ def _register_log_loss_computer(
     class LogLossComputer:
         def __init__(self) -> None:
             self._labels = labels
-            self._y_true: List[List[int]] = []
-            self._y_pred: List[List[float]] = []
-            self._sample_weight: List[float] = []
+            self._y_true: list[list[int]] = []
+            self._y_pred: list[list[float]] = []
+            self._sample_weight: list[float] = []
 
-        def process(self, y_true: List[int], y_pred: List[float], sample_weight: float) -> None:
+        def process(self, y_true: list[int], y_pred: list[float], sample_weight: float) -> None:
             self._y_true.append(y_true)
             self._y_pred.append(y_pred)
             self._sample_weight.append(sample_weight)
 
-        def end_partition(self) -> Iterable[Tuple[float]]:
+        def end_partition(self) -> Iterable[tuple[float]]:
             res = metrics.log_loss(
                 self._y_true,
                 self._y_pred,
@@ -685,18 +685,18 @@ def _register_log_loss_computer(
 def precision_recall_fscore_support(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     beta: float = 1.0,
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
     average: Optional[str] = None,
-    warn_for: Union[Tuple[str, ...], Set[str]] = ("precision", "recall", "f-score"),
+    warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
 ) -> Union[
-    Tuple[float, float, float, None],
-    Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
+    tuple[float, float, float, None],
+    tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
 ]:
     """
     Compute precision, recall, F-measure and support for each class.
@@ -854,8 +854,8 @@ def precision_recall_fscore_support(
     result_object = result.deserialize(session, precision_recall_fscore_support_anon_sproc(session, **kwargs))
 
     res: Union[
-        Tuple[float, float, float, None],
-        Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
+        tuple[float, float, float, None],
+        tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]],
     ] = result_object[:4]
     warning = result_object[-1]
     if warning:
@@ -1039,18 +1039,18 @@ def _register_multilabel_confusion_matrix_computer(
         def __init__(self) -> None:
             self._labels = labels
             self._samplewise = samplewise
-            self._y_true: List[List[int]] = []
-            self._y_pred: List[List[int]] = []
-            self._sample_weight: List[float] = []
+            self._y_true: list[list[int]] = []
+            self._y_pred: list[list[int]] = []
+            self._sample_weight: list[float] = []
 
-        def process(self, y_true: List[int], y_pred: List[int], sample_weight: float) -> None:
+        def process(self, y_true: list[int], y_pred: list[int], sample_weight: float) -> None:
             self._y_true.append(y_true)
             self._y_pred.append(y_pred)
             self._sample_weight.append(sample_weight)
 
         def end_partition(
             self,
-        ) -> Iterable[Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]]:
+        ) -> Iterable[tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]]:
             MCM = metrics.multilabel_confusion_matrix(
                 self._y_true,
                 self._y_pred,
@@ -1093,8 +1093,8 @@ def _register_multilabel_confusion_matrix_computer(
 def _binary_precision_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     pos_label: Union[str, int] = 1,
     sample_weight_col_name: Optional[str] = None,
     zero_division: Union[str, int] = "warn",
@@ -1166,8 +1166,8 @@ def _binary_precision_score(
 def precision_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
     average: Optional[str] = "binary",
@@ -1264,8 +1264,8 @@ def precision_score(
 def recall_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     labels: Optional[npt.ArrayLike] = None,
     pos_label: Union[str, int] = 1,
     average: Optional[str] = "binary",
@@ -1376,9 +1376,9 @@ def _sum_array_col(df: snowpark.DataFrame, col_name: str) -> snowpark.DataFrame:
 
 
 def _check_binary_labels(
-    labels: List[Any],
+    labels: list[Any],
     pos_label: Union[str, int] = 1,
-) -> List[Any]:
+) -> list[Any]:
    """Validation associated with binary average labels.
 
    Args:
@@ -1411,7 +1411,7 @@ def _prf_divide(
     metric: str,
     modifier: str,
     average: Optional[str] = None,
-    warn_for: Union[Tuple[str, ...], Set[str]] = ("precision", "recall", "f-score"),
+    warn_for: Union[tuple[str, ...], set[str]] = ("precision", "recall", "f-score"),
     zero_division: Union[str, int] = "warn",
 ) -> npt.NDArray[np.float_]:
     """Performs division and handles divide-by-zero.
--- a/snowflake/ml/modeling/metrics/metrics_utils.py
+++ b/snowflake/ml/modeling/metrics/metrics_utils.py
@@ -1,6 +1,6 @@
 import math
 import warnings
-from typing import Any, Collection, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Collection, Iterable, Optional, Union
 
 import cloudpickle
 import numpy as np
@@ -18,7 +18,7 @@ INDEX = "INDEX"
 BATCH_SIZE = 1000
 
 
-def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, Any]) -> str:
+def register_accumulator_udtf(*, session: Session, statement_params: dict[str, Any]) -> str:
     """Registers accumulator UDTF in Snowflake and returns the name of the UDTF.
 
     Args:
@@ -47,7 +47,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, A
             else:
                 self._accumulated_row = self._accumulated_row + row
 
-        def end_partition(self) -> Iterable[Tuple[bytes]]:
+        def end_partition(self) -> Iterable[tuple[bytes]]:
             yield (cloudpickle.dumps(self._accumulated_row),)
 
     accumulator = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE_FUNCTION)
@@ -68,7 +68,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, A
     return accumulator
 
 
-def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dict[str, Any]) -> str:
+def register_sharded_dot_sum_computer(*, session: Session, statement_params: dict[str, Any]) -> str:
     """Registers dot and sum computation UDTF in Snowflake and returns the name of the UDTF.
 
     Args:
@@ -110,7 +110,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
             # Square root of count - ddof
             self._sqrt_count_d = -1.0
 
-        def process(self, input_row: List[float], count: int, ddof: int) -> None:
+        def process(self, input_row: list[float], count: int, ddof: int) -> None:
             """Computes sum and dot product.
 
             Args:
@@ -138,7 +138,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
                 self.accumulate_batch_sum_and_dot_prod()
                 self._cur_count = 0
 
-        def end_partition(self) -> Iterable[Tuple[bytes, str]]:
+        def end_partition(self) -> Iterable[tuple[bytes, str]]:
             # 3. Compute sum and dot_prod for the remaining rows in the batch.
             if self._cur_count > 0:
                 self.accumulate_batch_sum_and_dot_prod()
@@ -185,7 +185,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
 
 def validate_and_return_dataframe_and_columns(
     *, df: snowpark.DataFrame, columns: Optional[Collection[str]] = None
-) -> Tuple[snowpark.DataFrame, Collection[str]]:
+) -> tuple[snowpark.DataFrame, Collection[str]]:
     """Validates that the columns are all numeric and returns a dataframe with those columns.
 
     Args:
@@ -212,8 +212,8 @@ def validate_and_return_dataframe_and_columns(
 
 
 def check_label_columns(
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
 ) -> None:
     """Check y true and y pred columns.
 
@@ -238,7 +238,7 @@ def check_label_columns(
         )
 
 
-def flatten_cols(cols: List[Optional[Union[str, List[str]]]]) -> List[str]:
+def flatten_cols(cols: list[Optional[Union[str, list[str]]]]) -> list[str]:
     res = []
     for col in cols:
         if isinstance(col, str):
@@ -251,7 +251,7 @@ def flatten_cols(cols: List[Optional[Union[str, List[str]]]]) -> List[str]:
 def unique_labels(
     *,
     df: snowpark.DataFrame,
-    columns: List[snowpark.Column],
+    columns: list[snowpark.Column],
 ) -> snowpark.DataFrame:
     """Extract indexed ordered unique labels as a dataframe.
 
@@ -311,7 +311,7 @@ def weighted_sum(
     sample_score_column: snowpark.Column,
     sample_weight_column: Optional[snowpark.Column] = None,
     normalize: bool = False,
-    statement_params: Dict[str, str],
+    statement_params: dict[str, str],
 ) -> float:
     """Weighted sum of the sample score column.
 
--- a/snowflake/ml/modeling/metrics/ranking.py
+++ b/snowflake/ml/modeling/metrics/ranking.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import cloudpickle
 import numpy as np
@@ -26,7 +26,7 @@ def precision_recall_curve(
     probas_pred_col_name: str,
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
-) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+) -> tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
     """
     Compute precision-recall pairs for different probability thresholds.
 
@@ -125,7 +125,7 @@ def precision_recall_curve(
 
     kwargs = telemetry.get_sproc_statement_params_kwargs(precision_recall_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, precision_recall_curve_anon_sproc(session, **kwargs))
-    res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
+    res: tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
     return res
 
 
@@ -133,8 +133,8 @@ def precision_recall_curve(
 def roc_auc_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_score_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_score_col_names: Union[str, list[str]],
     average: Optional[str] = "macro",
     sample_weight_col_name: Optional[str] = None,
     max_fpr: Optional[float] = None,
@@ -289,7 +289,7 @@ def roc_curve(
     pos_label: Optional[Union[str, int]] = None,
     sample_weight_col_name: Optional[str] = None,
     drop_intermediate: bool = True,
-) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
+) -> tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]]:
     """
     Compute Receiver operating characteristic (ROC).
 
@@ -380,6 +380,6 @@ def roc_curve(
     kwargs = telemetry.get_sproc_statement_params_kwargs(roc_curve_anon_sproc, statement_params)
     result_object = result.deserialize(session, roc_curve_anon_sproc(session, **kwargs))
 
-    res: Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
+    res: tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], npt.NDArray[np.float_]] = result_object
 
     return res
--- a/snowflake/ml/modeling/metrics/regression.py
+++ b/snowflake/ml/modeling/metrics/regression.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import cloudpickle
 import numpy as np
@@ -25,8 +25,8 @@ _MULTIOUTPUT_RAW_VALUES = "raw_values"
 def d2_absolute_error_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
 ) -> Union[float, npt.NDArray[np.float_]]:
@@ -119,8 +119,8 @@ def d2_absolute_error_score(
 def d2_pinball_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     alpha: float = 0.5,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
@@ -219,8 +219,8 @@ def d2_pinball_score(
 def explained_variance_score(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     force_finite: bool = True,
@@ -334,8 +334,8 @@ def explained_variance_score(
 def mean_absolute_error(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
 ) -> Union[float, npt.NDArray[np.float_]]:
@@ -407,8 +407,8 @@ def mean_absolute_error(
 def mean_absolute_percentage_error(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
 ) -> Union[float, npt.NDArray[np.float_]]:
@@ -490,8 +490,8 @@ def mean_absolute_percentage_error(
 def mean_squared_error(
     *,
     df: snowpark.DataFrame,
-    y_true_col_names: Union[str, List[str]],
-    y_pred_col_names: Union[str, List[str]],
+    y_true_col_names: Union[str, list[str]],
+    y_pred_col_names: Union[str, list[str]],
     sample_weight_col_name: Optional[str] = None,
     multioutput: Union[str, npt.ArrayLike] = "uniform_average",
     squared: bool = True,
--- a/snowflake/ml/modeling/model_selection/__init__.py
+++ b/snowflake/ml/modeling/model_selection/__init__.py
@@ -2,7 +2,7 @@ import os
 
 from snowflake.ml._internal import init_utils
 
-pkg_dir = os.path.dirname(os.path.abspath(__file__))
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_classes.items():
--- a/snowflake/ml/modeling/model_selection/grid_search_cv.py
+++ b/snowflake/ml/modeling/model_selection/grid_search_cv.py
@@ -2,7 +2,7 @@
 # This code is auto-generated using the sklearn_wrapper_template.py_template template.
 # Do not modify the auto-generated code(except automatic reformatting by precommit hooks).
 #
-from typing import Any, Dict, Iterable, List, Optional, Set, Union
+from typing import Any, Iterable, Optional, Union
 
 import cloudpickle as cp
 import numpy as np
@@ -244,7 +244,7 @@ class GridSearchCV(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        deps: Set[str] = {
+        deps: set[str] = {
             f"numpy=={np.__version__}",
             f"scikit-learn=={sklearn.__version__}",
             f"cloudpickle=={cp.__version__}",
@@ -268,7 +268,7 @@ class GridSearchCV(BaseTransformer):
         self._sklearn_object: Any = sklearn.model_selection.GridSearchCV(
             **cleaned_up_init_args,
         )
-        self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
+        self._model_signature_dict: Optional[dict[str, ModelSignature]] = None
         self.set_input_cols(input_cols)
         self.set_output_cols(output_cols)
         self.set_label_cols(label_cols)
@@ -281,7 +281,7 @@ class GridSearchCV(BaseTransformer):
         self._class_name = GridSearchCV.__class__.__name__
         self._subproject = _SUBPROJECT
 
-    def _get_active_columns(self) -> List[str]:
+    def _get_active_columns(self) -> list[str]:
         """ "Get the list of columns that are relevant to the transformer."""
         selected_cols = (
             self.input_cols + self.label_cols + ([self.sample_weight_col] if self.sample_weight_col is not None else [])
@@ -805,7 +805,7 @@ class GridSearchCV(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object
 
-    def _get_dependencies(self) -> List[str]:
+    def _get_dependencies(self) -> list[str]:
         return self._deps
 
     def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
@@ -820,7 +820,7 @@ class GridSearchCV(BaseTransformer):
                 use_snowflake_identifiers=True,
             )
         )
-        outputs: List[BaseFeatureSpec] = []
+        outputs: list[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
             assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
@@ -863,7 +863,7 @@ class GridSearchCV(BaseTransformer):
             self._model_signature_dict[method] = signature
 
     @property
-    def model_signatures(self) -> Dict[str, ModelSignature]:
+    def model_signatures(self) -> dict[str, ModelSignature]:
         """Returns model signature of current class.
 
         Raises:
--- a/snowflake/ml/modeling/model_selection/randomized_search_cv.py
+++ b/snowflake/ml/modeling/model_selection/randomized_search_cv.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Iterable, List, Optional, Set, Union
+from typing import Any, Iterable, Optional, Union
 
 import cloudpickle as cp
 import numpy as np
@@ -254,7 +254,7 @@ class RandomizedSearchCV(BaseTransformer):
         sample_weight_col: Optional[str] = None,
     ) -> None:
         super().__init__()
-        deps: Set[str] = {
+        deps: set[str] = {
             f"numpy=={np.__version__}",
             f"scikit-learn=={sklearn.__version__}",
             f"cloudpickle=={cp.__version__}",
@@ -280,7 +280,7 @@ class RandomizedSearchCV(BaseTransformer):
         self._sklearn_object: Any = sklearn.model_selection.RandomizedSearchCV(
             **cleaned_up_init_args,
         )
-        self._model_signature_dict: Optional[Dict[str, ModelSignature]] = None
+        self._model_signature_dict: Optional[dict[str, ModelSignature]] = None
         self.set_input_cols(input_cols)
         self.set_output_cols(output_cols)
         self.set_label_cols(label_cols)
@@ -294,7 +294,7 @@ class RandomizedSearchCV(BaseTransformer):
         self._class_name = RandomizedSearchCV.__class__.__name__
         self._subproject = _SUBPROJECT
 
-    def _get_active_columns(self) -> List[str]:
+    def _get_active_columns(self) -> list[str]:
         """ "Get the list of columns that are relevant to the transformer."""
         selected_cols = (
             self.input_cols + self.label_cols + ([self.sample_weight_col] if self.sample_weight_col is not None else [])
@@ -820,7 +820,7 @@ class RandomizedSearchCV(BaseTransformer):
         assert self._sklearn_object is not None
         return self._sklearn_object
 
-    def _get_dependencies(self) -> List[str]:
+    def _get_dependencies(self) -> list[str]:
         return self._deps
 
     def _generate_model_signatures(self, dataset: Union[DataFrame, pd.DataFrame]) -> None:
@@ -835,7 +835,7 @@ class RandomizedSearchCV(BaseTransformer):
                 use_snowflake_identifiers=True,
             )
         )
-        outputs: List[BaseFeatureSpec] = []
+        outputs: list[BaseFeatureSpec] = []
         if hasattr(self, "predict"):
             # keep mypy happy
             assert self._sklearn_object is not None and hasattr(self._sklearn_object, "_estimator_type")
@@ -878,7 +878,7 @@ class RandomizedSearchCV(BaseTransformer):
             self._model_signature_dict[method] = signature
 
     @property
-    def model_signatures(self) -> Dict[str, ModelSignature]:
+    def model_signatures(self) -> dict[str, ModelSignature]:
         """Returns model signature of current class.
 
         Raises:
--- a/snowflake/ml/modeling/pipeline/__init__.py
+++ b/snowflake/ml/modeling/pipeline/__init__.py
@@ -2,7 +2,7 @@ import os
 
 from snowflake.ml._internal import init_utils
 
-pkg_dir = os.path.dirname(os.path.abspath(__file__))
+pkg_dir = os.path.dirname(__file__)
 pkg_name = __name__
 exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
 for k, v in exportable_classes.items():