snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
1
1
  import inspect
2
2
  import os
3
3
  import tempfile
4
- from typing import Any, Dict, List, Optional
4
+ from typing import Any, Optional
5
5
 
6
6
  import cloudpickle as cp
7
7
  import pandas as pd
@@ -41,13 +41,13 @@ _PROJECT = "ModelDevelopment"
41
41
 
42
42
 
43
43
  def get_data_iterator(
44
- file_paths: List[str],
44
+ file_paths: list[str],
45
45
  batch_size: int,
46
- input_cols: List[str],
47
- label_cols: List[str],
46
+ input_cols: list[str],
47
+ label_cols: list[str],
48
48
  sample_weight_col: Optional[str] = None,
49
49
  ) -> Any:
50
- from typing import List, Optional
50
+ from typing import Optional
51
51
 
52
52
  import xgboost
53
53
 
@@ -60,10 +60,10 @@ def get_data_iterator(
60
60
 
61
61
  def __init__(
62
62
  self,
63
- file_paths: List[str],
63
+ file_paths: list[str],
64
64
  batch_size: int,
65
- input_cols: List[str],
66
- label_cols: List[str],
65
+ input_cols: list[str],
66
+ label_cols: list[str],
67
67
  sample_weight_col: Optional[str] = None,
68
68
  ) -> None:
69
69
  """
@@ -151,10 +151,10 @@ def get_data_iterator(
151
151
 
152
152
  def train_xgboost_model(
153
153
  estimator: object,
154
- file_paths: List[str],
154
+ file_paths: list[str],
155
155
  batch_size: int,
156
- input_cols: List[str],
157
- label_cols: List[str],
156
+ input_cols: list[str],
157
+ label_cols: list[str],
158
158
  sample_weight_col: Optional[str] = None,
159
159
  ) -> object:
160
160
  """
@@ -247,8 +247,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
247
247
  estimator: object,
248
248
  dataset: DataFrame,
249
249
  session: Session,
250
- input_cols: List[str],
251
- label_cols: Optional[List[str]],
250
+ input_cols: list[str],
251
+ label_cols: Optional[list[str]],
252
252
  sample_weight_col: Optional[str],
253
253
  autogenerated: bool = False,
254
254
  subproject: str = "",
@@ -285,8 +285,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
285
285
  self,
286
286
  model_spec: ModelSpecifications,
287
287
  session: Session,
288
- statement_params: Dict[str, str],
289
- import_file_paths: List[str],
288
+ statement_params: dict[str, str],
289
+ import_file_paths: list[str],
290
290
  ) -> Any:
291
291
  fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
292
292
 
@@ -308,10 +308,10 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
308
308
  session: Session,
309
309
  dataset_stage_name: str,
310
310
  batch_size: int,
311
- input_cols: List[str],
312
- label_cols: List[str],
311
+ input_cols: list[str],
312
+ label_cols: list[str],
313
313
  sample_weight_col: Optional[str],
314
- statement_params: Dict[str, str],
314
+ statement_params: dict[str, str],
315
315
  ) -> str:
316
316
  import os
317
317
  import sys
@@ -365,7 +365,7 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
365
365
 
366
366
  return fit_wrapper_sproc
367
367
 
368
- def _write_training_data_to_stage(self, dataset_stage_name: str) -> List[str]:
368
+ def _write_training_data_to_stage(self, dataset_stage_name: str) -> list[str]:
369
369
  """
370
370
  Materializes the training to the specified stage and returns the list of stage file paths.
371
371
 
@@ -1,4 +1,4 @@
1
- from typing import Any, List, Optional, Protocol, TypedDict, Union
1
+ from typing import Any, Optional, Protocol, TypedDict, Union
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -29,9 +29,9 @@ class LocalModelTransformHandlers(Protocol):
29
29
  def batch_inference(
30
30
  self,
31
31
  inference_method: str,
32
- input_cols: List[str],
33
- expected_output_cols: List[str],
34
- snowpark_input_cols: Optional[List[str]],
32
+ input_cols: list[str],
33
+ expected_output_cols: list[str],
34
+ snowpark_input_cols: Optional[list[str]],
35
35
  drop_input_cols: Optional[bool] = False,
36
36
  *args: Any,
37
37
  **kwargs: Any,
@@ -57,8 +57,8 @@ class LocalModelTransformHandlers(Protocol):
57
57
 
58
58
  def score(
59
59
  self,
60
- input_cols: List[str],
61
- label_cols: List[str],
60
+ input_cols: list[str],
61
+ label_cols: list[str],
62
62
  sample_weight_col: Optional[str],
63
63
  *args: Any,
64
64
  **kwargs: Any,
@@ -105,10 +105,10 @@ class RemoteModelTransformHandlers(Protocol):
105
105
  def batch_inference(
106
106
  self,
107
107
  inference_method: str,
108
- input_cols: List[str],
109
- expected_output_cols: List[str],
108
+ input_cols: list[str],
109
+ expected_output_cols: list[str],
110
110
  session: snowpark.Session,
111
- dependencies: List[str],
111
+ dependencies: list[str],
112
112
  drop_input_cols: Optional[bool] = False,
113
113
  expected_output_cols_type: Optional[str] = "",
114
114
  *args: Any,
@@ -137,11 +137,11 @@ class RemoteModelTransformHandlers(Protocol):
137
137
 
138
138
  def score(
139
139
  self,
140
- input_cols: List[str],
141
- label_cols: List[str],
140
+ input_cols: list[str],
141
+ label_cols: list[str],
142
142
  session: snowpark.Session,
143
- dependencies: List[str],
144
- score_sproc_imports: List[str],
143
+ dependencies: list[str],
144
+ score_sproc_imports: list[str],
145
145
  sample_weight_col: Optional[str] = None,
146
146
  *args: Any,
147
147
  **kwargs: Any,
@@ -173,10 +173,10 @@ ModelTransformHandlers = Union[LocalModelTransformHandlers, RemoteModelTransform
173
173
  class BatchInferenceKwargsTypedDict(TypedDict, total=False):
174
174
  """A typed dict specifying all possible optional keyword args accepted by batch_inference() methods."""
175
175
 
176
- snowpark_input_cols: Optional[List[str]]
176
+ snowpark_input_cols: Optional[list[str]]
177
177
  drop_input_cols: Optional[bool]
178
178
  session: snowpark.Session
179
- dependencies: List[str]
179
+ dependencies: list[str]
180
180
  expected_output_cols_type: str
181
181
  n_neighbors: Optional[int]
182
182
  return_distance: bool
@@ -186,5 +186,5 @@ class ScoreKwargsTypedDict(TypedDict, total=False):
186
186
  """A typed dict specifying all possible optional keyword args accepted by score() methods."""
187
187
 
188
188
  session: snowpark.Session
189
- dependencies: List[str]
190
- score_sproc_imports: List[str]
189
+ dependencies: list[str]
190
+ score_sproc_imports: list[str]
@@ -3,7 +3,7 @@
3
3
  import inspect
4
4
  import warnings
5
5
  from enum import Enum
6
- from typing import Any, Callable, Dict, Iterable, Optional, Union
6
+ from typing import Any, Callable, Iterable, Optional, Union
7
7
 
8
8
  import numpy as np
9
9
  import sklearn
@@ -62,7 +62,7 @@ class BasicStatistics(str, Enum):
62
62
  MODE = "mode"
63
63
 
64
64
 
65
- def get_default_args(func: Callable[..., None]) -> Dict[str, Any]:
65
+ def get_default_args(func: Callable[..., None]) -> dict[str, Any]:
66
66
  signature = inspect.signature(func)
67
67
  return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
68
68
 
@@ -72,16 +72,16 @@ def generate_value_with_prefix(prefix: str) -> str:
72
72
 
73
73
 
74
74
  def get_filtered_valid_sklearn_args(
75
- args: Dict[str, Any],
76
- default_sklearn_args: Dict[str, Any],
75
+ args: dict[str, Any],
76
+ default_sklearn_args: dict[str, Any],
77
77
  sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
78
78
  sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
79
79
  snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
80
- sklearn_added_keyword_to_version_dict: Optional[Dict[str, str]] = None,
81
- sklearn_added_kwarg_value_to_version_dict: Optional[Dict[str, Dict[str, str]]] = None,
82
- sklearn_deprecated_keyword_to_version_dict: Optional[Dict[str, str]] = None,
83
- sklearn_removed_keyword_to_version_dict: Optional[Dict[str, str]] = None,
84
- ) -> Dict[str, Any]:
80
+ sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
81
+ sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
82
+ sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
83
+ sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
84
+ ) -> dict[str, Any]:
85
85
  """
86
86
  Get valid sklearn keyword arguments with non-default values.
87
87
 
@@ -241,7 +241,7 @@ def to_native_format(obj: Any) -> Any:
241
241
  return obj.to_sklearn()
242
242
 
243
243
 
244
- def table_exists(session: snowpark.Session, table_name: str, statement_params: Dict[str, Any]) -> bool:
244
+ def table_exists(session: snowpark.Session, table_name: str, statement_params: dict[str, Any]) -> bool:
245
245
  try:
246
246
  session.table(table_name).limit(0).collect(statement_params=statement_params)
247
247
  return True
@@ -2,7 +2,7 @@
2
2
  import inspect
3
3
  from abc import abstractmethod
4
4
  from datetime import datetime
5
- from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
5
+ from typing import Any, Iterable, Mapping, Optional, Union, overload
6
6
 
7
7
  import numpy as np
8
8
  import numpy.typing as npt
@@ -28,9 +28,9 @@ SKLEARN_SUPERVISED_ESTIMATORS = ["regressor", "classifier"]
28
28
  SKLEARN_SINGLE_OUTPUT_ESTIMATORS = ["DensityEstimator", "clusterer", "outlier_detector"]
29
29
 
30
30
 
31
- def _process_cols(cols: Optional[Union[str, Iterable[str]]]) -> List[str]:
31
+ def _process_cols(cols: Optional[Union[str, Iterable[str]]]) -> list[str]:
32
32
  """Convert cols to a list."""
33
- col_list: List[str] = []
33
+ col_list: list[str] = []
34
34
  if cols is None:
35
35
  return col_list
36
36
  elif type(cols) is list:
@@ -55,10 +55,10 @@ class Base:
55
55
  passthrough_cols: List columns not to be used or modified by the estimator/transformers.
56
56
  These columns will be passed through all the estimator/transformer operations without any modifications.
57
57
  """
58
- self.input_cols: List[str] = []
59
- self.output_cols: List[str] = []
60
- self.label_cols: List[str] = []
61
- self.passthrough_cols: List[str] = []
58
+ self.input_cols: list[str] = []
59
+ self.output_cols: list[str] = []
60
+ self.label_cols: list[str] = []
61
+ self.passthrough_cols: list[str] = []
62
62
 
63
63
  def _create_unfitted_sklearn_object(self) -> Any:
64
64
  raise NotImplementedError()
@@ -66,7 +66,7 @@ class Base:
66
66
  def _create_sklearn_object(self) -> Any:
67
67
  raise NotImplementedError()
68
68
 
69
- def get_input_cols(self) -> List[str]:
69
+ def get_input_cols(self) -> list[str]:
70
70
  """
71
71
  Input columns getter.
72
72
 
@@ -88,7 +88,7 @@ class Base:
88
88
  self.input_cols = _process_cols(input_cols)
89
89
  return self
90
90
 
91
- def get_output_cols(self) -> List[str]:
91
+ def get_output_cols(self) -> list[str]:
92
92
  """
93
93
  Output columns getter.
94
94
 
@@ -110,7 +110,7 @@ class Base:
110
110
  self.output_cols = _process_cols(output_cols)
111
111
  return self
112
112
 
113
- def get_label_cols(self) -> List[str]:
113
+ def get_label_cols(self) -> list[str]:
114
114
  """
115
115
  Label column getter.
116
116
 
@@ -132,7 +132,7 @@ class Base:
132
132
  self.label_cols = _process_cols(label_cols)
133
133
  return self
134
134
 
135
- def get_passthrough_cols(self) -> List[str]:
135
+ def get_passthrough_cols(self) -> list[str]:
136
136
  """
137
137
  Passthrough columns getter.
138
138
 
@@ -215,7 +215,7 @@ class Base:
215
215
  )
216
216
 
217
217
  @classmethod
218
- def _get_param_names(cls) -> List[str]:
218
+ def _get_param_names(cls) -> list[str]:
219
219
  """Get parameter names for the transformer"""
220
220
  # fetch the constructor or the original constructor before
221
221
  # deprecation wrapping if any
@@ -244,7 +244,7 @@ class Base:
244
244
  # Extract and sort argument names excluding 'self'
245
245
  return sorted(p.name for p in parameters)
246
246
 
247
- def get_params(self, deep: bool = True) -> Dict[str, Any]:
247
+ def get_params(self, deep: bool = True) -> dict[str, Any]:
248
248
  """
249
249
  Get the snowflake-ml parameters for this transformer.
250
250
 
@@ -255,7 +255,7 @@ class Base:
255
255
  Returns:
256
256
  Parameter names mapped to their values.
257
257
  """
258
- out: Dict[str, Any] = dict()
258
+ out: dict[str, Any] = dict()
259
259
  for key in self._get_param_names():
260
260
  if hasattr(self, key):
261
261
  value = getattr(self, key)
@@ -320,11 +320,11 @@ class Base:
320
320
  sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
321
321
  sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
322
322
  snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
323
- sklearn_added_keyword_to_version_dict: Optional[Dict[str, str]] = None,
324
- sklearn_added_kwarg_value_to_version_dict: Optional[Dict[str, Dict[str, str]]] = None,
325
- sklearn_deprecated_keyword_to_version_dict: Optional[Dict[str, str]] = None,
326
- sklearn_removed_keyword_to_version_dict: Optional[Dict[str, str]] = None,
327
- ) -> Dict[str, Any]:
323
+ sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
324
+ sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
325
+ sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
326
+ sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
327
+ ) -> dict[str, Any]:
328
328
  """
329
329
  Get sklearn keyword arguments.
330
330
 
@@ -350,7 +350,7 @@ class Base:
350
350
  """
351
351
  default_sklearn_args = _utils.get_default_args(default_sklearn_obj.__class__.__init__)
352
352
  given_args = self.get_params()
353
- sklearn_args: Dict[str, Any] = _utils.get_filtered_valid_sklearn_args(
353
+ sklearn_args: dict[str, Any] = _utils.get_filtered_valid_sklearn_args(
354
354
  args=given_args,
355
355
  default_sklearn_args=default_sklearn_args,
356
356
  sklearn_initial_keywords=sklearn_initial_keywords,
@@ -368,8 +368,8 @@ class BaseEstimator(Base):
368
368
  def __init__(
369
369
  self,
370
370
  *,
371
- file_names: Optional[List[str]] = None,
372
- custom_states: Optional[List[str]] = None,
371
+ file_names: Optional[list[str]] = None,
372
+ custom_states: Optional[list[str]] = None,
373
373
  sample_weight_col: Optional[str] = None,
374
374
  ) -> None:
375
375
  """
@@ -418,7 +418,7 @@ class BaseEstimator(Base):
418
418
  self.sample_weight_col = sample_weight_col
419
419
  return self
420
420
 
421
- def _get_dependencies(self) -> List[str]:
421
+ def _get_dependencies(self) -> list[str]:
422
422
  """
423
423
  Return the list of conda dependencies required to work with the object.
424
424
 
@@ -458,8 +458,8 @@ class BaseEstimator(Base):
458
458
  return dataset[self.input_cols]
459
459
 
460
460
  def _compute(
461
- self, dataset: snowpark.DataFrame, cols: List[str], states: List[str]
462
- ) -> Dict[str, Dict[str, Union[int, float, str]]]:
461
+ self, dataset: snowpark.DataFrame, cols: list[str], states: list[str]
462
+ ) -> dict[str, dict[str, Union[int, float, str]]]:
463
463
  """
464
464
  Compute required states of the columns.
465
465
 
@@ -474,7 +474,7 @@ class BaseEstimator(Base):
474
474
  A dict of {column_name: {state: value}} of each column.
475
475
  """
476
476
 
477
- def _compute_on_partition(df: snowpark.DataFrame, cols_subset: List[str]) -> snowpark.DataFrame:
477
+ def _compute_on_partition(df: snowpark.DataFrame, cols_subset: list[str]) -> snowpark.DataFrame:
478
478
  """Returns a DataFrame with the desired computation on the specified column subset."""
479
479
  exprs = []
480
480
  sql_prefix = "SQL>>>"
@@ -499,7 +499,7 @@ class BaseEstimator(Base):
499
499
  statement_params=telemetry.get_statement_params(PROJECT, SUBPROJECT, self.__class__.__name__),
500
500
  )
501
501
 
502
- computed_dict: Dict[str, Dict[str, Union[int, float, str]]] = {}
502
+ computed_dict: dict[str, dict[str, Union[int, float, str]]] = {}
503
503
  for idx, val in enumerate(_results[0]):
504
504
  col_name = cols[idx // len(states)]
505
505
  if col_name not in computed_dict:
@@ -516,8 +516,8 @@ class BaseTransformer(BaseEstimator):
516
516
  self,
517
517
  *,
518
518
  drop_input_cols: Optional[bool] = False,
519
- file_names: Optional[List[str]] = None,
520
- custom_states: Optional[List[str]] = None,
519
+ file_names: Optional[list[str]] = None,
520
+ custom_states: Optional[list[str]] = None,
521
521
  sample_weight_col: Optional[str] = None,
522
522
  ) -> None:
523
523
  """Base class for all transformers."""
@@ -551,7 +551,7 @@ class BaseTransformer(BaseEstimator):
551
551
  ),
552
552
  )
553
553
 
554
- def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> List[str]:
554
+ def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> list[str]:
555
555
  """
556
556
  Infer input_cols from the dataset. Input column are all columns in the input dataset that are not
557
557
  designated as label, passthrough, or sample weight columns.
@@ -569,7 +569,7 @@ class BaseTransformer(BaseEstimator):
569
569
  ]
570
570
  return cols
571
571
 
572
- def _infer_output_cols(self) -> List[str]:
572
+ def _infer_output_cols(self) -> list[str]:
573
573
  """Infer output column names from based on the estimator.
574
574
 
575
575
  Returns:
@@ -624,7 +624,7 @@ class BaseTransformer(BaseEstimator):
624
624
  cols = self._infer_output_cols()
625
625
  self.set_output_cols(output_cols=cols)
626
626
 
627
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
627
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[list[str]] = None) -> list[str]:
628
628
  """Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
629
629
  Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
630
630
 
@@ -2,7 +2,7 @@ import os
2
2
 
3
3
  from snowflake.ml._internal import init_utils
4
4
 
5
- pkg_dir = os.path.dirname(os.path.abspath(__file__))
5
+ pkg_dir = os.path.dirname(__file__)
6
6
  pkg_name = __name__
7
7
  exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
8
8
  for k, v in exportable_classes.items():
@@ -1,7 +1,7 @@
1
1
  #!/usr/bin/env python3
2
2
  import copy
3
3
  import warnings
4
- from typing import Any, Dict, Iterable, Optional, Type, Union
4
+ from typing import Any, Iterable, Optional, Union
5
5
 
6
6
  import numpy as np
7
7
  import numpy.typing as npt
@@ -25,7 +25,7 @@ STRATEGY_TO_STATE_DICT = {
25
25
  "most_frequent": _utils.BasicStatistics.MODE,
26
26
  }
27
27
 
28
- SNOWFLAKE_DATATYPE_TO_NUMPY_DTYPE_MAP: Dict[Type[T.DataType], npt.DTypeLike] = {
28
+ SNOWFLAKE_DATATYPE_TO_NUMPY_DTYPE_MAP: dict[type[T.DataType], npt.DTypeLike] = {
29
29
  T.ByteType: np.dtype("int8"),
30
30
  T.ShortType: np.dtype("int16"),
31
31
  T.IntegerType: np.dtype("int32"),
@@ -164,7 +164,7 @@ class SimpleImputer(base.BaseTransformer):
164
164
 
165
165
  self.fill_value = fill_value
166
166
  self.missing_values = missing_values
167
- self.statistics_: Dict[str, Any] = {}
167
+ self.statistics_: dict[str, Any] = {}
168
168
  # TODO(hayu): [SNOW-752265] Support SimpleImputer keep_empty_features.
169
169
  # Add back when `keep_empty_features` is supported.
170
170
  # self.keep_empty_features = keep_empty_features
@@ -195,7 +195,7 @@ class SimpleImputer(base.BaseTransformer):
195
195
  del self.feature_names_in_
196
196
  del self._sklearn_fit_dtype
197
197
 
198
- def _get_dataset_input_col_datatypes(self, dataset: snowpark.DataFrame) -> Dict[str, T.DataType]:
198
+ def _get_dataset_input_col_datatypes(self, dataset: snowpark.DataFrame) -> dict[str, T.DataType]:
199
199
  """
200
200
  Checks that the input columns are all the same datatype category(except for most_frequent strategy) and
201
201
  returns the datatype.
@@ -211,7 +211,7 @@ class SimpleImputer(base.BaseTransformer):
211
211
  supported.
212
212
  """
213
213
 
214
- def check_type_consistency(col_types: Dict[str, T.DataType]) -> None:
214
+ def check_type_consistency(col_types: dict[str, T.DataType]) -> None:
215
215
  is_numeric_type = None
216
216
  for col_name, col_type in col_types.items():
217
217
  if is_numeric_type is None:
@@ -5,7 +5,7 @@ import cloudpickle
5
5
  from snowflake.ml._internal import init_utils
6
6
  from snowflake.ml._internal.utils import result
7
7
 
8
- pkg_dir = os.path.dirname(os.path.abspath(__file__))
8
+ pkg_dir = os.path.dirname(__file__)
9
9
  pkg_name = __name__
10
10
  exportable_functions = init_utils.fetch_functions_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
11
11
  for k, v in exportable_functions.items():