snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -3,7 +3,7 @@ import inspect
3
3
  import os
4
4
  import posixpath
5
5
  import sys
6
- from typing import Any, Dict, List, Optional
6
+ from typing import Any, Optional
7
7
  from uuid import uuid4
8
8
 
9
9
  import cloudpickle as cp
@@ -73,10 +73,10 @@ class SnowparkTransformHandlers:
73
73
  def batch_inference(
74
74
  self,
75
75
  inference_method: str,
76
- input_cols: List[str],
77
- expected_output_cols: List[str],
76
+ input_cols: list[str],
77
+ expected_output_cols: list[str],
78
78
  session: Session,
79
- dependencies: List[str],
79
+ dependencies: list[str],
80
80
  drop_input_cols: Optional[bool] = False,
81
81
  expected_output_cols_type: Optional[str] = "",
82
82
  *args: Any,
@@ -229,11 +229,11 @@ class SnowparkTransformHandlers:
229
229
 
230
230
  def score(
231
231
  self,
232
- input_cols: List[str],
233
- label_cols: List[str],
232
+ input_cols: list[str],
233
+ label_cols: list[str],
234
234
  session: Session,
235
- dependencies: List[str],
236
- score_sproc_imports: List[str],
235
+ dependencies: list[str],
236
+ score_sproc_imports: list[str],
237
237
  sample_weight_col: Optional[str] = None,
238
238
  *args: Any,
239
239
  **kwargs: Any,
@@ -308,12 +308,12 @@ class SnowparkTransformHandlers:
308
308
  )
309
309
  def score_wrapper_sproc(
310
310
  session: Session,
311
- sql_queries: List[str],
311
+ sql_queries: list[str],
312
312
  stage_score_file_name: str,
313
- input_cols: List[str],
314
- label_cols: List[str],
313
+ input_cols: list[str],
314
+ label_cols: list[str],
315
315
  sample_weight_col: Optional[str],
316
- score_statement_params: Dict[str, str],
316
+ score_statement_params: dict[str, str],
317
317
  ) -> float:
318
318
  import inspect
319
319
  import os
@@ -382,7 +382,7 @@ class SnowparkTransformHandlers:
382
382
 
383
383
  return score
384
384
 
385
- def _get_validated_snowpark_dependencies(self, session: Session, dependencies: List[str]) -> List[str]:
385
+ def _get_validated_snowpark_dependencies(self, session: Session, dependencies: list[str]) -> list[str]:
386
386
  """A helper function to validate dependencies and return the available packages that exists
387
387
  in the snowflake anaconda channel
388
388
 
@@ -2,7 +2,7 @@ import importlib
2
2
  import inspect
3
3
  import os
4
4
  import posixpath
5
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
5
+ from typing import Any, Callable, Optional, Union
6
6
 
7
7
  import cloudpickle as cp
8
8
  import pandas as pd
@@ -55,8 +55,8 @@ class SnowparkModelTrainer:
55
55
  estimator: object,
56
56
  dataset: DataFrame,
57
57
  session: Session,
58
- input_cols: List[str],
59
- label_cols: Optional[List[str]],
58
+ input_cols: list[str],
59
+ label_cols: Optional[list[str]],
60
60
  sample_weight_col: Optional[str],
61
61
  autogenerated: bool = False,
62
62
  subproject: str = "",
@@ -84,7 +84,7 @@ class SnowparkModelTrainer:
84
84
  self._subproject = subproject
85
85
  self._class_name = estimator.__class__.__name__
86
86
 
87
- def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object:
87
+ def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: dict[str, str]) -> object:
88
88
  """
89
89
  Downloads the serialized model from a stage location and unpickles it.
90
90
 
@@ -112,7 +112,7 @@ class SnowparkModelTrainer:
112
112
  def _build_fit_wrapper_sproc(
113
113
  self,
114
114
  model_spec: ModelSpecifications,
115
- ) -> Callable[[Any, List[str], str, List[str], List[str], Optional[str], Dict[str, str]], str]:
115
+ ) -> Callable[[Any, list[str], str, list[str], list[str], Optional[str], dict[str, str]], str]:
116
116
  """
117
117
  Constructs and returns a python stored procedure function to be used for training model.
118
118
 
@@ -129,12 +129,12 @@ class SnowparkModelTrainer:
129
129
 
130
130
  def fit_wrapper_function(
131
131
  session: Session,
132
- sql_queries: List[str],
132
+ sql_queries: list[str],
133
133
  temp_stage_name: str,
134
- input_cols: List[str],
135
- label_cols: List[str],
134
+ input_cols: list[str],
135
+ label_cols: list[str],
136
136
  sample_weight_col: Optional[str],
137
- statement_params: Dict[str, str],
137
+ statement_params: dict[str, str],
138
138
  ) -> str:
139
139
  import inspect
140
140
  import os
@@ -218,7 +218,7 @@ class SnowparkModelTrainer:
218
218
 
219
219
  return fit_wrapper_function
220
220
 
221
- def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
221
+ def _get_fit_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
222
222
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
223
223
 
224
224
  fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -243,7 +243,7 @@ class SnowparkModelTrainer:
243
243
  def _build_fit_predict_wrapper_sproc(
244
244
  self,
245
245
  model_spec: ModelSpecifications,
246
- ) -> Callable[[Session, List[str], str, List[str], Dict[str, str], bool, List[str], str], str]:
246
+ ) -> Callable[[Session, list[str], str, list[str], dict[str, str], bool, list[str], str], str]:
247
247
  """
248
248
  Constructs and returns a python stored procedure function to be used for training model.
249
249
 
@@ -258,12 +258,12 @@ class SnowparkModelTrainer:
258
258
 
259
259
  def fit_predict_wrapper_function(
260
260
  session: Session,
261
- sql_queries: List[str],
261
+ sql_queries: list[str],
262
262
  temp_stage_name: str,
263
- input_cols: List[str],
264
- statement_params: Dict[str, str],
263
+ input_cols: list[str],
264
+ statement_params: dict[str, str],
265
265
  drop_input_cols: bool,
266
- expected_output_cols_list: List[str],
266
+ expected_output_cols_list: list[str],
267
267
  fit_predict_result_name: str,
268
268
  ) -> str:
269
269
  import os
@@ -346,14 +346,14 @@ class SnowparkModelTrainer:
346
346
  ) -> Callable[
347
347
  [
348
348
  Session,
349
- List[str],
349
+ list[str],
350
350
  str,
351
- List[str],
352
- Optional[List[str]],
351
+ list[str],
352
+ Optional[list[str]],
353
353
  Optional[str],
354
- Dict[str, str],
354
+ dict[str, str],
355
355
  bool,
356
- List[str],
356
+ list[str],
357
357
  str,
358
358
  ],
359
359
  str,
@@ -372,14 +372,14 @@ class SnowparkModelTrainer:
372
372
 
373
373
  def fit_transform_wrapper_function(
374
374
  session: Session,
375
- sql_queries: List[str],
375
+ sql_queries: list[str],
376
376
  temp_stage_name: str,
377
- input_cols: List[str],
378
- label_cols: Optional[List[str]],
377
+ input_cols: list[str],
378
+ label_cols: Optional[list[str]],
379
379
  sample_weight_col: Optional[str],
380
- statement_params: Dict[str, str],
380
+ statement_params: dict[str, str],
381
381
  drop_input_cols: bool,
382
- expected_output_cols_list: List[str],
382
+ expected_output_cols_list: list[str],
383
383
  fit_transform_result_name: str,
384
384
  ) -> str:
385
385
  import os
@@ -473,7 +473,7 @@ class SnowparkModelTrainer:
473
473
 
474
474
  return fit_transform_wrapper_function
475
475
 
476
- def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
476
+ def _get_fit_predict_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
477
477
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
478
478
 
479
479
  fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -495,7 +495,7 @@ class SnowparkModelTrainer:
495
495
 
496
496
  return fit_predict_wrapper_sproc
497
497
 
498
- def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
498
+ def _get_fit_transform_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
499
499
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
500
500
 
501
501
  fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -586,10 +586,10 @@ class SnowparkModelTrainer:
586
586
 
587
587
  def train_fit_predict(
588
588
  self,
589
- expected_output_cols_list: List[str],
589
+ expected_output_cols_list: list[str],
590
590
  drop_input_cols: Optional[bool] = False,
591
591
  example_output_pd_df: Optional[pd.DataFrame] = None,
592
- ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
592
+ ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
593
593
  """Trains the model by pushing down the compute into Snowflake using stored procedures.
594
594
  This API is different from fit itself because it would also provide the predict
595
595
  output.
@@ -682,9 +682,9 @@ class SnowparkModelTrainer:
682
682
 
683
683
  def train_fit_transform(
684
684
  self,
685
- expected_output_cols_list: List[str],
685
+ expected_output_cols_list: list[str],
686
686
  drop_input_cols: Optional[bool] = False,
687
- ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
687
+ ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
688
688
  """Trains the model by pushing down the compute into Snowflake using stored procedures.
689
689
  This API is different from fit itself because it would also provide the transform
690
690
  output.
@@ -1,7 +1,7 @@
1
1
  import inspect
2
2
  import os
3
3
  import tempfile
4
- from typing import Any, Dict, List, Optional
4
+ from typing import Any, Optional
5
5
 
6
6
  import cloudpickle as cp
7
7
  import pandas as pd
@@ -41,13 +41,13 @@ _PROJECT = "ModelDevelopment"
41
41
 
42
42
 
43
43
  def get_data_iterator(
44
- file_paths: List[str],
44
+ file_paths: list[str],
45
45
  batch_size: int,
46
- input_cols: List[str],
47
- label_cols: List[str],
46
+ input_cols: list[str],
47
+ label_cols: list[str],
48
48
  sample_weight_col: Optional[str] = None,
49
49
  ) -> Any:
50
- from typing import List, Optional
50
+ from typing import Optional
51
51
 
52
52
  import xgboost
53
53
 
@@ -60,10 +60,10 @@ def get_data_iterator(
60
60
 
61
61
  def __init__(
62
62
  self,
63
- file_paths: List[str],
63
+ file_paths: list[str],
64
64
  batch_size: int,
65
- input_cols: List[str],
66
- label_cols: List[str],
65
+ input_cols: list[str],
66
+ label_cols: list[str],
67
67
  sample_weight_col: Optional[str] = None,
68
68
  ) -> None:
69
69
  """
@@ -151,10 +151,10 @@ def get_data_iterator(
151
151
 
152
152
  def train_xgboost_model(
153
153
  estimator: object,
154
- file_paths: List[str],
154
+ file_paths: list[str],
155
155
  batch_size: int,
156
- input_cols: List[str],
157
- label_cols: List[str],
156
+ input_cols: list[str],
157
+ label_cols: list[str],
158
158
  sample_weight_col: Optional[str] = None,
159
159
  ) -> object:
160
160
  """
@@ -247,8 +247,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
247
247
  estimator: object,
248
248
  dataset: DataFrame,
249
249
  session: Session,
250
- input_cols: List[str],
251
- label_cols: Optional[List[str]],
250
+ input_cols: list[str],
251
+ label_cols: Optional[list[str]],
252
252
  sample_weight_col: Optional[str],
253
253
  autogenerated: bool = False,
254
254
  subproject: str = "",
@@ -285,8 +285,8 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
285
285
  self,
286
286
  model_spec: ModelSpecifications,
287
287
  session: Session,
288
- statement_params: Dict[str, str],
289
- import_file_paths: List[str],
288
+ statement_params: dict[str, str],
289
+ import_file_paths: list[str],
290
290
  ) -> Any:
291
291
  fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
292
292
 
@@ -308,10 +308,10 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
308
308
  session: Session,
309
309
  dataset_stage_name: str,
310
310
  batch_size: int,
311
- input_cols: List[str],
312
- label_cols: List[str],
311
+ input_cols: list[str],
312
+ label_cols: list[str],
313
313
  sample_weight_col: Optional[str],
314
- statement_params: Dict[str, str],
314
+ statement_params: dict[str, str],
315
315
  ) -> str:
316
316
  import os
317
317
  import sys
@@ -365,7 +365,7 @@ class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
365
365
 
366
366
  return fit_wrapper_sproc
367
367
 
368
- def _write_training_data_to_stage(self, dataset_stage_name: str) -> List[str]:
368
+ def _write_training_data_to_stage(self, dataset_stage_name: str) -> list[str]:
369
369
  """
370
370
  Materializes the training to the specified stage and returns the list of stage file paths.
371
371
 
@@ -1,4 +1,4 @@
1
- from typing import Any, List, Optional, Protocol, TypedDict, Union
1
+ from typing import Any, Optional, Protocol, TypedDict, Union
2
2
 
3
3
  import pandas as pd
4
4
 
@@ -29,9 +29,9 @@ class LocalModelTransformHandlers(Protocol):
29
29
  def batch_inference(
30
30
  self,
31
31
  inference_method: str,
32
- input_cols: List[str],
33
- expected_output_cols: List[str],
34
- snowpark_input_cols: Optional[List[str]],
32
+ input_cols: list[str],
33
+ expected_output_cols: list[str],
34
+ snowpark_input_cols: Optional[list[str]],
35
35
  drop_input_cols: Optional[bool] = False,
36
36
  *args: Any,
37
37
  **kwargs: Any,
@@ -57,8 +57,8 @@ class LocalModelTransformHandlers(Protocol):
57
57
 
58
58
  def score(
59
59
  self,
60
- input_cols: List[str],
61
- label_cols: List[str],
60
+ input_cols: list[str],
61
+ label_cols: list[str],
62
62
  sample_weight_col: Optional[str],
63
63
  *args: Any,
64
64
  **kwargs: Any,
@@ -105,10 +105,10 @@ class RemoteModelTransformHandlers(Protocol):
105
105
  def batch_inference(
106
106
  self,
107
107
  inference_method: str,
108
- input_cols: List[str],
109
- expected_output_cols: List[str],
108
+ input_cols: list[str],
109
+ expected_output_cols: list[str],
110
110
  session: snowpark.Session,
111
- dependencies: List[str],
111
+ dependencies: list[str],
112
112
  drop_input_cols: Optional[bool] = False,
113
113
  expected_output_cols_type: Optional[str] = "",
114
114
  *args: Any,
@@ -137,11 +137,11 @@ class RemoteModelTransformHandlers(Protocol):
137
137
 
138
138
  def score(
139
139
  self,
140
- input_cols: List[str],
141
- label_cols: List[str],
140
+ input_cols: list[str],
141
+ label_cols: list[str],
142
142
  session: snowpark.Session,
143
- dependencies: List[str],
144
- score_sproc_imports: List[str],
143
+ dependencies: list[str],
144
+ score_sproc_imports: list[str],
145
145
  sample_weight_col: Optional[str] = None,
146
146
  *args: Any,
147
147
  **kwargs: Any,
@@ -173,10 +173,10 @@ ModelTransformHandlers = Union[LocalModelTransformHandlers, RemoteModelTransform
173
173
  class BatchInferenceKwargsTypedDict(TypedDict, total=False):
174
174
  """A typed dict specifying all possible optional keyword args accepted by batch_inference() methods."""
175
175
 
176
- snowpark_input_cols: Optional[List[str]]
176
+ snowpark_input_cols: Optional[list[str]]
177
177
  drop_input_cols: Optional[bool]
178
178
  session: snowpark.Session
179
- dependencies: List[str]
179
+ dependencies: list[str]
180
180
  expected_output_cols_type: str
181
181
  n_neighbors: Optional[int]
182
182
  return_distance: bool
@@ -186,5 +186,5 @@ class ScoreKwargsTypedDict(TypedDict, total=False):
186
186
  """A typed dict specifying all possible optional keyword args accepted by score() methods."""
187
187
 
188
188
  session: snowpark.Session
189
- dependencies: List[str]
190
- score_sproc_imports: List[str]
189
+ dependencies: list[str]
190
+ score_sproc_imports: list[str]
@@ -3,7 +3,7 @@
3
3
  import inspect
4
4
  import warnings
5
5
  from enum import Enum
6
- from typing import Any, Callable, Dict, Iterable, Optional, Union
6
+ from typing import Any, Callable, Iterable, Optional, Union
7
7
 
8
8
  import numpy as np
9
9
  import sklearn
@@ -62,7 +62,7 @@ class BasicStatistics(str, Enum):
62
62
  MODE = "mode"
63
63
 
64
64
 
65
- def get_default_args(func: Callable[..., None]) -> Dict[str, Any]:
65
+ def get_default_args(func: Callable[..., None]) -> dict[str, Any]:
66
66
  signature = inspect.signature(func)
67
67
  return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
68
68
 
@@ -72,16 +72,16 @@ def generate_value_with_prefix(prefix: str) -> str:
72
72
 
73
73
 
74
74
  def get_filtered_valid_sklearn_args(
75
- args: Dict[str, Any],
76
- default_sklearn_args: Dict[str, Any],
75
+ args: dict[str, Any],
76
+ default_sklearn_args: dict[str, Any],
77
77
  sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
78
78
  sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
79
79
  snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
80
- sklearn_added_keyword_to_version_dict: Optional[Dict[str, str]] = None,
81
- sklearn_added_kwarg_value_to_version_dict: Optional[Dict[str, Dict[str, str]]] = None,
82
- sklearn_deprecated_keyword_to_version_dict: Optional[Dict[str, str]] = None,
83
- sklearn_removed_keyword_to_version_dict: Optional[Dict[str, str]] = None,
84
- ) -> Dict[str, Any]:
80
+ sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
81
+ sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
82
+ sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
83
+ sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
84
+ ) -> dict[str, Any]:
85
85
  """
86
86
  Get valid sklearn keyword arguments with non-default values.
87
87
 
@@ -241,7 +241,7 @@ def to_native_format(obj: Any) -> Any:
241
241
  return obj.to_sklearn()
242
242
 
243
243
 
244
- def table_exists(session: snowpark.Session, table_name: str, statement_params: Dict[str, Any]) -> bool:
244
+ def table_exists(session: snowpark.Session, table_name: str, statement_params: dict[str, Any]) -> bool:
245
245
  try:
246
246
  session.table(table_name).limit(0).collect(statement_params=statement_params)
247
247
  return True
@@ -2,7 +2,7 @@
2
2
  import inspect
3
3
  from abc import abstractmethod
4
4
  from datetime import datetime
5
- from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
5
+ from typing import Any, Iterable, Mapping, Optional, Union, overload
6
6
 
7
7
  import numpy as np
8
8
  import numpy.typing as npt
@@ -28,9 +28,9 @@ SKLEARN_SUPERVISED_ESTIMATORS = ["regressor", "classifier"]
28
28
  SKLEARN_SINGLE_OUTPUT_ESTIMATORS = ["DensityEstimator", "clusterer", "outlier_detector"]
29
29
 
30
30
 
31
- def _process_cols(cols: Optional[Union[str, Iterable[str]]]) -> List[str]:
31
+ def _process_cols(cols: Optional[Union[str, Iterable[str]]]) -> list[str]:
32
32
  """Convert cols to a list."""
33
- col_list: List[str] = []
33
+ col_list: list[str] = []
34
34
  if cols is None:
35
35
  return col_list
36
36
  elif type(cols) is list:
@@ -55,10 +55,10 @@ class Base:
55
55
  passthrough_cols: List columns not to be used or modified by the estimator/transformers.
56
56
  These columns will be passed through all the estimator/transformer operations without any modifications.
57
57
  """
58
- self.input_cols: List[str] = []
59
- self.output_cols: List[str] = []
60
- self.label_cols: List[str] = []
61
- self.passthrough_cols: List[str] = []
58
+ self.input_cols: list[str] = []
59
+ self.output_cols: list[str] = []
60
+ self.label_cols: list[str] = []
61
+ self.passthrough_cols: list[str] = []
62
62
 
63
63
  def _create_unfitted_sklearn_object(self) -> Any:
64
64
  raise NotImplementedError()
@@ -66,7 +66,7 @@ class Base:
66
66
  def _create_sklearn_object(self) -> Any:
67
67
  raise NotImplementedError()
68
68
 
69
- def get_input_cols(self) -> List[str]:
69
+ def get_input_cols(self) -> list[str]:
70
70
  """
71
71
  Input columns getter.
72
72
 
@@ -88,7 +88,7 @@ class Base:
88
88
  self.input_cols = _process_cols(input_cols)
89
89
  return self
90
90
 
91
- def get_output_cols(self) -> List[str]:
91
+ def get_output_cols(self) -> list[str]:
92
92
  """
93
93
  Output columns getter.
94
94
 
@@ -110,7 +110,7 @@ class Base:
110
110
  self.output_cols = _process_cols(output_cols)
111
111
  return self
112
112
 
113
- def get_label_cols(self) -> List[str]:
113
+ def get_label_cols(self) -> list[str]:
114
114
  """
115
115
  Label column getter.
116
116
 
@@ -132,7 +132,7 @@ class Base:
132
132
  self.label_cols = _process_cols(label_cols)
133
133
  return self
134
134
 
135
- def get_passthrough_cols(self) -> List[str]:
135
+ def get_passthrough_cols(self) -> list[str]:
136
136
  """
137
137
  Passthrough columns getter.
138
138
 
@@ -215,7 +215,7 @@ class Base:
215
215
  )
216
216
 
217
217
  @classmethod
218
- def _get_param_names(cls) -> List[str]:
218
+ def _get_param_names(cls) -> list[str]:
219
219
  """Get parameter names for the transformer"""
220
220
  # fetch the constructor or the original constructor before
221
221
  # deprecation wrapping if any
@@ -244,7 +244,7 @@ class Base:
244
244
  # Extract and sort argument names excluding 'self'
245
245
  return sorted(p.name for p in parameters)
246
246
 
247
- def get_params(self, deep: bool = True) -> Dict[str, Any]:
247
+ def get_params(self, deep: bool = True) -> dict[str, Any]:
248
248
  """
249
249
  Get the snowflake-ml parameters for this transformer.
250
250
 
@@ -255,7 +255,7 @@ class Base:
255
255
  Returns:
256
256
  Parameter names mapped to their values.
257
257
  """
258
- out: Dict[str, Any] = dict()
258
+ out: dict[str, Any] = dict()
259
259
  for key in self._get_param_names():
260
260
  if hasattr(self, key):
261
261
  value = getattr(self, key)
@@ -320,11 +320,11 @@ class Base:
320
320
  sklearn_initial_keywords: Optional[Union[str, Iterable[str]]] = None,
321
321
  sklearn_unused_keywords: Optional[Union[str, Iterable[str]]] = None,
322
322
  snowml_only_keywords: Optional[Union[str, Iterable[str]]] = None,
323
- sklearn_added_keyword_to_version_dict: Optional[Dict[str, str]] = None,
324
- sklearn_added_kwarg_value_to_version_dict: Optional[Dict[str, Dict[str, str]]] = None,
325
- sklearn_deprecated_keyword_to_version_dict: Optional[Dict[str, str]] = None,
326
- sklearn_removed_keyword_to_version_dict: Optional[Dict[str, str]] = None,
327
- ) -> Dict[str, Any]:
323
+ sklearn_added_keyword_to_version_dict: Optional[dict[str, str]] = None,
324
+ sklearn_added_kwarg_value_to_version_dict: Optional[dict[str, dict[str, str]]] = None,
325
+ sklearn_deprecated_keyword_to_version_dict: Optional[dict[str, str]] = None,
326
+ sklearn_removed_keyword_to_version_dict: Optional[dict[str, str]] = None,
327
+ ) -> dict[str, Any]:
328
328
  """
329
329
  Get sklearn keyword arguments.
330
330
 
@@ -350,7 +350,7 @@ class Base:
350
350
  """
351
351
  default_sklearn_args = _utils.get_default_args(default_sklearn_obj.__class__.__init__)
352
352
  given_args = self.get_params()
353
- sklearn_args: Dict[str, Any] = _utils.get_filtered_valid_sklearn_args(
353
+ sklearn_args: dict[str, Any] = _utils.get_filtered_valid_sklearn_args(
354
354
  args=given_args,
355
355
  default_sklearn_args=default_sklearn_args,
356
356
  sklearn_initial_keywords=sklearn_initial_keywords,
@@ -368,8 +368,8 @@ class BaseEstimator(Base):
368
368
  def __init__(
369
369
  self,
370
370
  *,
371
- file_names: Optional[List[str]] = None,
372
- custom_states: Optional[List[str]] = None,
371
+ file_names: Optional[list[str]] = None,
372
+ custom_states: Optional[list[str]] = None,
373
373
  sample_weight_col: Optional[str] = None,
374
374
  ) -> None:
375
375
  """
@@ -418,7 +418,7 @@ class BaseEstimator(Base):
418
418
  self.sample_weight_col = sample_weight_col
419
419
  return self
420
420
 
421
- def _get_dependencies(self) -> List[str]:
421
+ def _get_dependencies(self) -> list[str]:
422
422
  """
423
423
  Return the list of conda dependencies required to work with the object.
424
424
 
@@ -458,8 +458,8 @@ class BaseEstimator(Base):
458
458
  return dataset[self.input_cols]
459
459
 
460
460
  def _compute(
461
- self, dataset: snowpark.DataFrame, cols: List[str], states: List[str]
462
- ) -> Dict[str, Dict[str, Union[int, float, str]]]:
461
+ self, dataset: snowpark.DataFrame, cols: list[str], states: list[str]
462
+ ) -> dict[str, dict[str, Union[int, float, str]]]:
463
463
  """
464
464
  Compute required states of the columns.
465
465
 
@@ -474,7 +474,7 @@ class BaseEstimator(Base):
474
474
  A dict of {column_name: {state: value}} of each column.
475
475
  """
476
476
 
477
- def _compute_on_partition(df: snowpark.DataFrame, cols_subset: List[str]) -> snowpark.DataFrame:
477
+ def _compute_on_partition(df: snowpark.DataFrame, cols_subset: list[str]) -> snowpark.DataFrame:
478
478
  """Returns a DataFrame with the desired computation on the specified column subset."""
479
479
  exprs = []
480
480
  sql_prefix = "SQL>>>"
@@ -499,7 +499,7 @@ class BaseEstimator(Base):
499
499
  statement_params=telemetry.get_statement_params(PROJECT, SUBPROJECT, self.__class__.__name__),
500
500
  )
501
501
 
502
- computed_dict: Dict[str, Dict[str, Union[int, float, str]]] = {}
502
+ computed_dict: dict[str, dict[str, Union[int, float, str]]] = {}
503
503
  for idx, val in enumerate(_results[0]):
504
504
  col_name = cols[idx // len(states)]
505
505
  if col_name not in computed_dict:
@@ -516,8 +516,8 @@ class BaseTransformer(BaseEstimator):
516
516
  self,
517
517
  *,
518
518
  drop_input_cols: Optional[bool] = False,
519
- file_names: Optional[List[str]] = None,
520
- custom_states: Optional[List[str]] = None,
519
+ file_names: Optional[list[str]] = None,
520
+ custom_states: Optional[list[str]] = None,
521
521
  sample_weight_col: Optional[str] = None,
522
522
  ) -> None:
523
523
  """Base class for all transformers."""
@@ -551,7 +551,7 @@ class BaseTransformer(BaseEstimator):
551
551
  ),
552
552
  )
553
553
 
554
- def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> List[str]:
554
+ def _infer_input_cols(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> list[str]:
555
555
  """
556
556
  Infer input_cols from the dataset. Input column are all columns in the input dataset that are not
557
557
  designated as label, passthrough, or sample weight columns.
@@ -569,7 +569,7 @@ class BaseTransformer(BaseEstimator):
569
569
  ]
570
570
  return cols
571
571
 
572
- def _infer_output_cols(self) -> List[str]:
572
+ def _infer_output_cols(self) -> list[str]:
573
573
  """Infer output column names from based on the estimator.
574
574
 
575
575
  Returns:
@@ -624,7 +624,7 @@ class BaseTransformer(BaseEstimator):
624
624
  cols = self._infer_output_cols()
625
625
  self.set_output_cols(output_cols=cols)
626
626
 
627
- def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
627
+ def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[list[str]] = None) -> list[str]:
628
628
  """Returns the list of output columns for predict_proba(), decision_function(), etc.. functions.
629
629
  Returns a list with output_cols_prefix as the only element if the estimator is not a classifier.
630
630
 
@@ -2,7 +2,7 @@ import os
2
2
 
3
3
  from snowflake.ml._internal import init_utils
4
4
 
5
- pkg_dir = os.path.dirname(os.path.abspath(__file__))
5
+ pkg_dir = os.path.dirname(__file__)
6
6
  pkg_name = __name__
7
7
  exportable_classes = init_utils.fetch_classes_from_modules_in_pkg_dir(pkg_dir=pkg_dir, pkg_name=pkg_name)
8
8
  for k, v in exportable_classes.items():