snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff compares two publicly available versions of the package as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
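The bulk of the hunks below apply one mechanical change: annotations built from typing.List, typing.Dict, and typing.Tuple are rewritten to the built-in generics list, dict, and tuple, which Python accepts at runtime from version 3.9 onward (PEP 585), while Optional and Union continue to come from typing. A minimal before/after sketch, using hypothetical function names that are not taken from snowflake-ml-python, illustrating the pattern:

# Hypothetical example, not from the package, showing the annotation
# migration applied throughout the hunks below.
from typing import Dict, List, Optional, Tuple  # old-style aliases, deprecated since Python 3.9


def score_before(input_cols: List[str], params: Dict[str, str]) -> Tuple[float, List[str]]:
    return 0.0, input_cols


# After: PEP 585 built-in generics; Optional (and Union) still come from typing.
def score_after(input_cols: list[str], params: dict[str, str], label_col: Optional[str] = None) -> tuple[float, list[str]]:
    return 0.0, input_cols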
snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py
@@ -1,5 +1,5 @@
  import inspect
- from typing import Any, List, Optional
+ from typing import Any, Optional

  import pandas as pd

@@ -38,9 +38,9 @@ class PandasTransformHandlers:
  def batch_inference(
  self,
  inference_method: str,
- input_cols: List[str],
- expected_output_cols: List[str],
- snowpark_input_cols: Optional[List[str]] = None,
+ input_cols: list[str],
+ expected_output_cols: list[str],
+ snowpark_input_cols: Optional[list[str]] = None,
  drop_input_cols: Optional[bool] = False,
  *args: Any,
  **kwargs: Any,
@@ -147,8 +147,8 @@ class PandasTransformHandlers:

  def score(
  self,
- input_cols: List[str],
- label_cols: List[str],
+ input_cols: list[str],
+ label_cols: list[str],
  sample_weight_col: Optional[str],
  *args: Any,
  **kwargs: Any,
snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py
@@ -1,5 +1,5 @@
  import inspect
- from typing import List, Optional, Tuple
+ from typing import Optional

  import pandas as pd

@@ -15,8 +15,8 @@ class PandasModelTrainer:
  self,
  estimator: object,
  dataset: pd.DataFrame,
- input_cols: List[str],
- label_cols: Optional[List[str]],
+ input_cols: list[str],
+ label_cols: Optional[list[str]],
  sample_weight_col: Optional[str],
  ) -> None:
  """
@@ -57,10 +57,10 @@ class PandasModelTrainer:

  def train_fit_predict(
  self,
- expected_output_cols_list: List[str],
+ expected_output_cols_list: list[str],
  drop_input_cols: Optional[bool] = False,
  example_output_pd_df: Optional[pd.DataFrame] = None,
- ) -> Tuple[pd.DataFrame, object]:
+ ) -> tuple[pd.DataFrame, object]:
  """Trains the model using specified features and target columns from the dataset.
  This API is different from fit itself because it would also provide the predict
  output.
@@ -92,9 +92,9 @@ class PandasModelTrainer:

  def train_fit_transform(
  self,
- expected_output_cols_list: List[str],
+ expected_output_cols_list: list[str],
  drop_input_cols: Optional[bool] = False,
- ) -> Tuple[pd.DataFrame, object]:
+ ) -> tuple[pd.DataFrame, object]:
  """Trains the model using specified features and target columns from the dataset.
  This API is different from fit itself because it would also provide the transform
  output.
snowflake/ml/modeling/_internal/model_specifications.py
@@ -1,5 +1,3 @@
- from typing import List
-
  import cloudpickle as cp
  import numpy as np

@@ -11,7 +9,7 @@ class ModelSpecifications:
  A dataclass to define model based specifications like required imports, and package dependencies for Sproc/Udfs.
  """

- def __init__(self, imports: List[str], pkgDependencies: List[str]) -> None:
+ def __init__(self, imports: list[str], pkgDependencies: list[str]) -> None:
  self.imports = imports
  self.pkgDependencies = pkgDependencies

@@ -20,7 +18,7 @@ class SKLearnModelSpecifications(ModelSpecifications):
  def __init__(self) -> None:
  import sklearn

- imports: List[str] = ["sklearn"]
+ imports: list[str] = ["sklearn"]
  # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
  pkgDependencies = [
  f"numpy=={np.__version__}",
@@ -56,8 +54,8 @@ class XGBoostModelSpecifications(ModelSpecifications):
  import sklearn
  import xgboost

- imports: List[str] = ["xgboost"]
- pkgDependencies: List[str] = [
+ imports: list[str] = ["xgboost"]
+ pkgDependencies: list[str] = [
  f"numpy=={np.__version__}",
  f"scikit-learn=={sklearn.__version__}",
  f"xgboost=={xgboost.__version__}",
@@ -71,8 +69,8 @@ class LightGBMModelSpecifications(ModelSpecifications):
  import lightgbm
  import sklearn

- imports: List[str] = ["lightgbm"]
- pkgDependencies: List[str] = [
+ imports: list[str] = ["lightgbm"]
+ pkgDependencies: list[str] = [
  f"numpy=={np.__version__}",
  f"scikit-learn=={sklearn.__version__}",
  f"lightgbm=={lightgbm.__version__}",
@@ -86,8 +84,8 @@ class SklearnModelSelectionModelSpecifications(ModelSpecifications):
  import sklearn
  import xgboost

- imports: List[str] = ["sklearn", "xgboost"]
- pkgDependencies: List[str] = [
+ imports: list[str] = ["sklearn", "xgboost"]
+ pkgDependencies: list[str] = [
  f"numpy=={np.__version__}",
  f"scikit-learn=={sklearn.__version__}",
  f"cloudpickle=={cp.__version__}",
snowflake/ml/modeling/_internal/model_trainer.py
@@ -1,4 +1,4 @@
- from typing import List, Optional, Protocol, Tuple, Union
+ from typing import Optional, Protocol, Union

  import pandas as pd

@@ -18,15 +18,15 @@ class ModelTrainer(Protocol):

  def train_fit_predict(
  self,
- expected_output_cols_list: List[str],
+ expected_output_cols_list: list[str],
  drop_input_cols: Optional[bool] = False,
  example_output_pd_df: Optional[pd.DataFrame] = None,
- ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
+ ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
  raise NotImplementedError

  def train_fit_transform(
  self,
- expected_output_cols_list: List[str],
+ expected_output_cols_list: list[str],
  drop_input_cols: Optional[bool] = False,
- ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
+ ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
  raise NotImplementedError
snowflake/ml/modeling/_internal/model_trainer_builder.py
@@ -1,4 +1,4 @@
- from typing import List, Optional, Union
+ from typing import Optional, Union

  import pandas as pd
  from sklearn import model_selection
@@ -71,8 +71,8 @@ class ModelTrainerBuilder:
  cls,
  estimator: object,
  dataset: Union[DataFrame, pd.DataFrame],
- input_cols: Optional[List[str]] = None,
- label_cols: Optional[List[str]] = None,
+ input_cols: Optional[list[str]] = None,
+ label_cols: Optional[list[str]] = None,
  sample_weight_col: Optional[str] = None,
  autogenerated: bool = False,
  subproject: str = "",
@@ -130,7 +130,7 @@ class ModelTrainerBuilder:
  cls,
  estimator: object,
  dataset: Union[DataFrame, pd.DataFrame],
- input_cols: List[str],
+ input_cols: list[str],
  autogenerated: bool = False,
  subproject: str = "",
  ) -> ModelTrainer:
@@ -169,8 +169,8 @@ class ModelTrainerBuilder:
  cls,
  estimator: object,
  dataset: Union[DataFrame, pd.DataFrame],
- input_cols: List[str],
- label_cols: Optional[List[str]] = None,
+ input_cols: list[str],
+ label_cols: Optional[list[str]] = None,
  sample_weight_col: Optional[str] = None,
  autogenerated: bool = False,
  subproject: str = "",
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
@@ -5,7 +5,7 @@ import os
  import posixpath
  import sys
  import uuid
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import Any, Optional, Union

  import cloudpickle as cp
  import numpy as np
@@ -50,11 +50,11 @@ _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}
  def construct_cv_results(
  estimator: Union[GridSearchCV, RandomizedSearchCV],
  n_split: int,
- param_grid: List[Dict[str, Any]],
- cv_results_raw_hex: List[Row],
+ param_grid: list[dict[str, Any]],
+ cv_results_raw_hex: list[Row],
  cross_validator_indices_length: int,
  parameter_grid_length: int,
- ) -> Tuple[bool, Dict[str, Any]]:
+ ) -> tuple[bool, dict[str, Any]]:
  """Construct the cross validation result from the UDF. Because we accelerate the process
  by the number of cross validation number, and the combination of parameter grids.
  Therefore, we need to stick them back together instead of returning the raw result
@@ -158,11 +158,11 @@ def construct_cv_results(
  def construct_cv_results_memory_efficient_version(
  estimator: Union[GridSearchCV, RandomizedSearchCV],
  n_split: int,
- param_grid: List[Dict[str, Any]],
- cv_results_raw_hex: List[Row],
+ param_grid: list[dict[str, Any]],
+ cv_results_raw_hex: list[Row],
  cross_validator_indices_length: int,
  parameter_grid_length: int,
- ) -> Tuple[Any, Dict[str, Any]]:
+ ) -> tuple[Any, dict[str, Any]]:
  """Construct the cross validation result from the UDF.
  The output is a raw dictionary generated by _fit_and_score, encoded into hex binary.
  This function need to decode the string and then call _format_result to stick them back together
@@ -210,7 +210,7 @@ def construct_cv_results_memory_efficient_version(
  # because original SearchCV is ranked by parameter first and cv second,
  # to make the memory efficient, we implemented by fitting on cv first and parameter second
  # when retrieving the results back, the ordering should revert back to remain the same result as original SearchCV
- def generate_the_order_by_parameter_index(all_combination_length: int) -> List[int]:
+ def generate_the_order_by_parameter_index(all_combination_length: int) -> list[int]:
  pattern = []
  for i in range(all_combination_length):
  if i % parameter_grid_length == 0:
@@ -221,7 +221,7 @@ def construct_cv_results_memory_efficient_version(
  pattern.append(j)
  return pattern

- def rerank_array(original_array: List[Any], pattern: List[int]) -> List[Any]:
+ def rerank_array(original_array: list[Any], pattern: list[int]) -> list[Any]:
  reranked_array = []
  for index in pattern:
  reranked_array.append(original_array[index])
@@ -251,8 +251,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  estimator: object,
  dataset: DataFrame,
  session: Session,
- input_cols: List[str],
- label_cols: Optional[List[str]],
+ input_cols: list[str],
+ label_cols: Optional[list[str]],
  sample_weight_col: Optional[str],
  autogenerated: bool = False,
  subproject: str = "",
@@ -289,10 +289,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  dataset: DataFrame,
  session: Session,
  estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
- dependencies: List[str],
- udf_imports: List[str],
- input_cols: List[str],
- label_cols: Optional[List[str]],
+ dependencies: list[str],
+ udf_imports: list[str],
+ input_cols: list[str],
+ label_cols: Optional[list[str]],
  sample_weight_col: Optional[str],
  ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
  from itertools import product
@@ -382,10 +382,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  )
  def _distributed_search(
  session: Session,
- imports: List[str],
+ imports: list[str],
  stage_estimator_file_name: str,
- input_cols: List[str],
- label_cols: Optional[List[str]],
+ input_cols: list[str],
+ label_cols: Optional[list[str]],
  ) -> str:
  import os
  import time
@@ -455,12 +455,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  assert estimator is not None

  @cachetools.cached(cache={})
- def _load_data_into_udf() -> Tuple[
- Dict[str, pd.DataFrame],
+ def _load_data_into_udf() -> tuple[
+ dict[str, pd.DataFrame],
  Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
  pd.DataFrame,
  int,
- List[Dict[str, Any]],
+ list[dict[str, Any]],
  ]:
  import pyarrow.parquet as pq

@@ -512,7 +512,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  self.data_length = data_length
  self.params_to_evaluate = params_to_evaluate

- def process(self, params_idx: int, cv_idx: int) -> Iterator[Tuple[str]]:
+ def process(self, params_idx: int, cv_idx: int) -> Iterator[tuple[str]]:
  # Assign parameter to GridSearchCV
  if hasattr(estimator, "param_grid"):
  self.estimator.param_grid = self.params_to_evaluate[params_idx]
@@ -699,10 +699,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  dataset: DataFrame,
  session: Session,
  estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
- dependencies: List[str],
- udf_imports: List[str],
- input_cols: List[str],
- label_cols: Optional[List[str]],
+ dependencies: list[str],
+ udf_imports: list[str],
+ input_cols: list[str],
+ label_cols: Optional[list[str]],
  sample_weight_col: Optional[str],
  ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
  from itertools import product
@@ -727,7 +727,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  # Create a temp file and dump the estimator to that file.
  estimator_file_name = temp_file_utils.get_temp_file_path()
  params_to_evaluate = list(param_grid)
- CONSTANTS: Dict[str, Any] = dict()
+ CONSTANTS: dict[str, Any] = dict()
  CONSTANTS["dataset_snowpark_cols"] = dataset.columns
  CONSTANTS["n_candidates"] = len(params_to_evaluate)
  CONSTANTS["_N_JOBS"] = estimator.n_jobs
@@ -791,10 +791,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
  )
  def _distributed_search(
  session: Session,
- imports: List[str],
+ imports: list[str],
  stage_estimator_file_name: str,
- input_cols: List[str],
- label_cols: Optional[List[str]],
+ input_cols: list[str],
+ label_cols: Optional[list[str]],
  ) -> str:
  import os
  import time
snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py
@@ -3,7 +3,7 @@ import inspect
  import os
  import posixpath
  import sys
- from typing import Any, Dict, List, Optional
+ from typing import Any, Optional
  from uuid import uuid4

  import cloudpickle as cp
@@ -73,10 +73,10 @@ class SnowparkTransformHandlers:
  def batch_inference(
  self,
  inference_method: str,
- input_cols: List[str],
- expected_output_cols: List[str],
+ input_cols: list[str],
+ expected_output_cols: list[str],
  session: Session,
- dependencies: List[str],
+ dependencies: list[str],
  drop_input_cols: Optional[bool] = False,
  expected_output_cols_type: Optional[str] = "",
  *args: Any,
@@ -229,11 +229,11 @@ class SnowparkTransformHandlers:

  def score(
  self,
- input_cols: List[str],
- label_cols: List[str],
+ input_cols: list[str],
+ label_cols: list[str],
  session: Session,
- dependencies: List[str],
- score_sproc_imports: List[str],
+ dependencies: list[str],
+ score_sproc_imports: list[str],
  sample_weight_col: Optional[str] = None,
  *args: Any,
  **kwargs: Any,
@@ -308,12 +308,12 @@ class SnowparkTransformHandlers:
  )
  def score_wrapper_sproc(
  session: Session,
- sql_queries: List[str],
+ sql_queries: list[str],
  stage_score_file_name: str,
- input_cols: List[str],
- label_cols: List[str],
+ input_cols: list[str],
+ label_cols: list[str],
  sample_weight_col: Optional[str],
- score_statement_params: Dict[str, str],
+ score_statement_params: dict[str, str],
  ) -> float:
  import inspect
  import os
@@ -382,7 +382,7 @@ class SnowparkTransformHandlers:

  return score

- def _get_validated_snowpark_dependencies(self, session: Session, dependencies: List[str]) -> List[str]:
+ def _get_validated_snowpark_dependencies(self, session: Session, dependencies: list[str]) -> list[str]:
  """A helper function to validate dependencies and return the available packages that exists
  in the snowflake anaconda channel

snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py
@@ -2,7 +2,7 @@ import importlib
  import inspect
  import os
  import posixpath
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import Any, Callable, Optional, Union

  import cloudpickle as cp
  import pandas as pd
@@ -55,8 +55,8 @@ class SnowparkModelTrainer:
  estimator: object,
  dataset: DataFrame,
  session: Session,
- input_cols: List[str],
- label_cols: Optional[List[str]],
+ input_cols: list[str],
+ label_cols: Optional[list[str]],
  sample_weight_col: Optional[str],
  autogenerated: bool = False,
  subproject: str = "",
@@ -84,7 +84,7 @@ class SnowparkModelTrainer:
  self._subproject = subproject
  self._class_name = estimator.__class__.__name__

- def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: Dict[str, str]) -> object:
+ def _fetch_model_from_stage(self, dir_path: str, file_name: str, statement_params: dict[str, str]) -> object:
  """
  Downloads the serialized model from a stage location and unpickles it.

@@ -112,7 +112,7 @@ class SnowparkModelTrainer:
  def _build_fit_wrapper_sproc(
  self,
  model_spec: ModelSpecifications,
- ) -> Callable[[Any, List[str], str, List[str], List[str], Optional[str], Dict[str, str]], str]:
+ ) -> Callable[[Any, list[str], str, list[str], list[str], Optional[str], dict[str, str]], str]:
  """
  Constructs and returns a python stored procedure function to be used for training model.

@@ -129,12 +129,12 @@ class SnowparkModelTrainer:

  def fit_wrapper_function(
  session: Session,
- sql_queries: List[str],
+ sql_queries: list[str],
  temp_stage_name: str,
- input_cols: List[str],
- label_cols: List[str],
+ input_cols: list[str],
+ label_cols: list[str],
  sample_weight_col: Optional[str],
- statement_params: Dict[str, str],
+ statement_params: dict[str, str],
  ) -> str:
  import inspect
  import os
@@ -218,7 +218,7 @@ class SnowparkModelTrainer:

  return fit_wrapper_function

- def _get_fit_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
+ def _get_fit_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)

  fit_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -243,7 +243,7 @@ class SnowparkModelTrainer:
  def _build_fit_predict_wrapper_sproc(
  self,
  model_spec: ModelSpecifications,
- ) -> Callable[[Session, List[str], str, List[str], Dict[str, str], bool, List[str], str], str]:
+ ) -> Callable[[Session, list[str], str, list[str], dict[str, str], bool, list[str], str], str]:
  """
  Constructs and returns a python stored procedure function to be used for training model.

@@ -258,12 +258,12 @@ class SnowparkModelTrainer:

  def fit_predict_wrapper_function(
  session: Session,
- sql_queries: List[str],
+ sql_queries: list[str],
  temp_stage_name: str,
- input_cols: List[str],
- statement_params: Dict[str, str],
+ input_cols: list[str],
+ statement_params: dict[str, str],
  drop_input_cols: bool,
- expected_output_cols_list: List[str],
+ expected_output_cols_list: list[str],
  fit_predict_result_name: str,
  ) -> str:
  import os
@@ -346,14 +346,14 @@ class SnowparkModelTrainer:
  ) -> Callable[
  [
  Session,
- List[str],
+ list[str],
  str,
- List[str],
- Optional[List[str]],
+ list[str],
+ Optional[list[str]],
  Optional[str],
- Dict[str, str],
+ dict[str, str],
  bool,
- List[str],
+ list[str],
  str,
  ],
  str,
@@ -372,14 +372,14 @@ class SnowparkModelTrainer:

  def fit_transform_wrapper_function(
  session: Session,
- sql_queries: List[str],
+ sql_queries: list[str],
  temp_stage_name: str,
- input_cols: List[str],
- label_cols: Optional[List[str]],
+ input_cols: list[str],
+ label_cols: Optional[list[str]],
  sample_weight_col: Optional[str],
- statement_params: Dict[str, str],
+ statement_params: dict[str, str],
  drop_input_cols: bool,
- expected_output_cols_list: List[str],
+ expected_output_cols_list: list[str],
  fit_transform_result_name: str,
  ) -> str:
  import os
@@ -473,7 +473,7 @@ class SnowparkModelTrainer:

  return fit_transform_wrapper_function

- def _get_fit_predict_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
+ def _get_fit_predict_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)

  fit_predict_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -495,7 +495,7 @@ class SnowparkModelTrainer:

  return fit_predict_wrapper_sproc

- def _get_fit_transform_wrapper_sproc(self, statement_params: Dict[str, str], anonymous: bool) -> StoredProcedure:
+ def _get_fit_transform_wrapper_sproc(self, statement_params: dict[str, str], anonymous: bool) -> StoredProcedure:
  model_spec = ModelSpecificationsBuilder.build(model=self.estimator)

  fit_transform_sproc_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.PROCEDURE)
@@ -586,10 +586,10 @@ class SnowparkModelTrainer:

  def train_fit_predict(
  self,
- expected_output_cols_list: List[str],
+ expected_output_cols_list: list[str],
  drop_input_cols: Optional[bool] = False,
  example_output_pd_df: Optional[pd.DataFrame] = None,
- ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
+ ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
  """Trains the model by pushing down the compute into Snowflake using stored procedures.
  This API is different from fit itself because it would also provide the predict
  output.
@@ -682,9 +682,9 @@ class SnowparkModelTrainer:

  def train_fit_transform(
  self,
- expected_output_cols_list: List[str],
+ expected_output_cols_list: list[str],
  drop_input_cols: Optional[bool] = False,
- ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
+ ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
  """Trains the model by pushing down the compute into Snowflake using stored procedures.
  This API is different from fit itself because it would also provide the transform
  output.