snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff compares the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (166)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0

snowflake/ml/model/custom_model.py
@@ -1,6 +1,6 @@
 import functools
 import inspect
-from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Union
+from typing import Any, Callable, Coroutine, Generator, Optional, Union
 
 import anyio
 import pandas as pd
@@ -78,7 +78,7 @@ class ModelRef:
             return MethodRef(self, method_name)
         raise AttributeError(f"Method {method_name} not found in model {self._name}.")
 
-    def __getstate__(self) -> Dict[str, Any]:
+    def __getstate__(self) -> dict[str, Any]:
        state = self.__dict__.copy()
        del state["_model"]
        return state
@@ -113,8 +113,8 @@ class ModelContext:
     def __init__(
         self,
         *,
-        artifacts: Optional[Union[Dict[str, str], str, model_types.SupportedModelType]] = None,
-        models: Optional[Union[Dict[str, model_types.SupportedModelType], str, model_types.SupportedModelType]] = None,
+        artifacts: Optional[Union[dict[str, str], str, model_types.SupportedModelType]] = None,
+        models: Optional[Union[dict[str, model_types.SupportedModelType], str, model_types.SupportedModelType]] = None,
         **kwargs: Optional[Union[str, model_types.SupportedModelType]],
     ) -> None:
         """Initialize the model context.
@@ -130,8 +130,8 @@ class ModelContext:
            ValueError: Raised when the model name is duplicated.
        """
 
-        self.artifacts: Dict[str, str] = dict()
-        self.model_refs: Dict[str, ModelRef] = dict()
+        self.artifacts: dict[str, str] = dict()
+        self.model_refs: dict[str, ModelRef] = dict()
 
        # In case that artifacts is a dictionary, assume the original usage,
        # which is to pass in a dictionary of artifacts.
@@ -185,7 +185,7 @@ class ModelContext:
        return self.model_refs[name]
 
    def __getitem__(self, key: str) -> Union[str, ModelRef]:
-        combined: Dict[str, Union[str, ModelRef]] = {**self.artifacts, **self.model_refs}
+        combined: dict[str, Union[str, ModelRef]] = {**self.artifacts, **self.model_refs}
        if key not in combined:
            raise KeyError(f"Key {key} not found in the kwargs, current available keys are: {combined.keys()}")
        return combined[key]
@@ -226,7 +226,7 @@ class CustomModel:
        else:
            raise TypeError("A non-method inference API function is not supported.")
 
-    def _get_partitioned_infer_methods(self) -> List[str]:
+    def _get_partitioned_infer_methods(self) -> list[str]:
        """Returns all methods in CLS with `partitioned_inference_api` as the outermost decorator."""
        rv = []
        for cls_method_str in dir(self):
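
Nearly every hunk in this release applies the same mechanical migration: the capitalized typing aliases (Dict, List, Tuple, Type, Set) become the builtin generics standardized by PEP 585, which are subscriptable on Python 3.9 and later. A minimal before/after sketch of the pattern (the function itself is illustrative, not from the package):

    # Before: typing aliases, the only spelling available on Python 3.8.
    from typing import Dict, List, Optional

    def mean_scores_legacy(scores: Dict[str, List[int]]) -> Optional[float]:
        flat = [v for vals in scores.values() for v in vals]
        return sum(flat) / len(flat) if flat else None

    # After: builtin generics (PEP 585), valid on Python 3.9+.
    from typing import Optional

    def mean_scores(scores: dict[str, list[int]]) -> Optional[float]:
        flat = [v for vals in scores.values() for v in vals]
        return sum(flat) / len(flat) if flat else None

Note that Optional and Union survive the trimmed imports: the `X | None` alternative (PEP 604) requires Python 3.10, presumably still above this package's minimum supported version.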

snowflake/ml/model/model_signature.py
@@ -1,18 +1,7 @@
 import enum
 import json
 import warnings
-from typing import (
-    Any,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    Union,
-    cast,
-)
+from typing import Any, Literal, Optional, Sequence, Union, cast
 
 import numpy as np
 import pandas as pd
@@ -30,7 +19,7 @@ from snowflake.ml._internal.utils import formatting, identifier, sql_identifier
 from snowflake.ml.model import type_hints as model_types
 from snowflake.ml.model._signatures import (
     base_handler,
-    builtins_handler as builtins_handler,
+    builtins_handler,
     core,
     dmatrix_handler,
     numpy_handler,
@@ -48,7 +37,7 @@ FeatureGroupSpec = core.FeatureGroupSpec
 ModelSignature = core.ModelSignature
 
 
-_LOCAL_DATA_HANDLERS: List[Type[base_handler.BaseDataHandler[Any]]] = [
+_LOCAL_DATA_HANDLERS: list[type[base_handler.BaseDataHandler[Any]]] = [
     pandas_handler.PandasDataFrameHandler,
     numpy_handler.NumpyArrayHandler,
     builtins_handler.ListOfBuiltinHandler,
@@ -414,7 +403,7 @@ class SnowparkIdentifierRule(enum.Enum):
 
 def _get_dataframe_values_range(
     df: snowflake.snowpark.DataFrame,
-) -> Dict[str, Union[Tuple[int, int], Tuple[float, float]]]:
+) -> dict[str, Union[tuple[int, int], tuple[float, float]]]:
     columns = [
         F.array_construct(F.min(field.name), F.max(field.name)).as_(field.name)
         for field in df.schema.fields
@@ -429,7 +418,7 @@ def _get_dataframe_values_range(
         original_exception=ValueError(f"Unable to get the value range of fields {df.columns}"),
     )
     return cast(
-        Dict[str, Union[Tuple[int, int], Tuple[float, float]]],
+        dict[str, Union[tuple[int, int], tuple[float, float]]],
         {
             sql_identifier.SqlIdentifier(k, case_sensitive=True).identifier(): (json.loads(v)[0], json.loads(v)[1])
             for k, v in res[0].as_dict().items()
@@ -456,7 +445,7 @@ def _validate_snowpark_data(
         - inferred: signature `a` - Snowpark DF `"a"`, use `get_inferred_name`
         - normalized: signature `a` - Snowpark DF `A`, use `resolve_identifier`
     """
-    errors: Dict[SnowparkIdentifierRule, List[Exception]] = {
+    errors: dict[SnowparkIdentifierRule, list[Exception]] = {
         SnowparkIdentifierRule.INFERRED: [],
         SnowparkIdentifierRule.NORMALIZED: [],
     }
@@ -549,7 +538,7 @@ def _validate_snowpark_type_feature(
     field: spt.StructField,
     ft_type: DataType,
     ft_name: str,
-    value_range: Optional[Union[Tuple[int, int], Tuple[float, float]]],
+    value_range: Optional[Union[tuple[int, int], tuple[float, float]]],
     strict: bool = False,
 ) -> None:
     field_data_type = field.datatype
@@ -716,8 +705,8 @@ def _convert_and_validate_local_data(
 def infer_signature(
     input_data: model_types.SupportedLocalDataType,
     output_data: model_types.SupportedLocalDataType,
-    input_feature_names: Optional[List[str]] = None,
-    output_feature_names: Optional[List[str]] = None,
+    input_feature_names: Optional[list[str]] = None,
+    output_feature_names: Optional[list[str]] = None,
     input_data_limit: Optional[int] = 100,
     output_data_limit: Optional[int] = 100,
 ) -> core.ModelSignature:
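
infer_signature, whose signature is modernized above, is the public entry point for deriving a ModelSignature from sample data. A minimal usage sketch against the parameters shown in the hunk (the sample frames are illustrative; judging by the parameter names, the *_data_limit arguments cap how many rows are inspected, defaulting to 100):

    import pandas as pd

    from snowflake.ml.model.model_signature import infer_signature

    input_df = pd.DataFrame({"sepal_len": [5.1, 4.9], "sepal_wid": [3.5, 3.0]})
    output_df = pd.DataFrame({"label": [0, 1]})

    sig = infer_signature(
        input_data=input_df,
        output_data=output_df,
        input_feature_names=["sepal_len", "sepal_wid"],  # optional overrides
        output_feature_names=["label"],
    )
    print(sig)  # a core.ModelSignature describing input and output features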

snowflake/ml/model/models/huggingface_pipeline.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from packaging import version
 
@@ -13,7 +13,7 @@ class HuggingFacePipelineModel:
         revision: Optional[str] = None,
         token: Optional[str] = None,
         trust_remote_code: Optional[bool] = None,
-        model_kwargs: Optional[Dict[str, Any]] = None,
+        model_kwargs: Optional[dict[str, Any]] = None,
         **kwargs: Any,
     ) -> None:
         """
@@ -65,6 +65,7 @@ class HuggingFacePipelineModel:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.",
                FutureWarning,
+                stacklevel=2,
            )
            if token is not None:
                raise ValueError(
@@ -183,7 +184,8 @@ class HuggingFacePipelineModel:
            warnings.warn(
                f"No model was supplied, defaulted to {model} and revision"
                f" {revision} ({transformers.pipelines.HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{model}).\n"
-                "Using a pipeline without specifying a model name and revision in production is not recommended."
+                "Using a pipeline without specifying a model name and revision in production is not recommended.",
+                stacklevel=2,
            )
        if config is None and isinstance(model, str):
            config_obj = transformers.AutoConfig.from_pretrained(
@@ -200,7 +202,8 @@ class HuggingFacePipelineModel:
        if kwargs.get("device", None) is not None:
            warnings.warn(
                "Both `device` and `device_map` are specified. `device` will override `device_map`. You"
-                " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`."
+                " will most likely encounter unexpected behavior. Please remove `device` and keep `device_map`.",
+                stacklevel=2,
            )
 
        # ==== End pipeline logic from transformers ====
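
Besides the Dict-to-dict swap, each warnings.warn call in this file gains stacklevel=2, which attributes the warning to the caller rather than to the library frame that raised it. A standalone illustration of the effect:

    import warnings

    def deprecated_api() -> int:
        # With the default stacklevel=1 the warning would point at this line
        # inside the library; stacklevel=2 points at the caller instead.
        warnings.warn("deprecated; use new_api() instead", FutureWarning, stacklevel=2)
        return 42

    value = deprecated_api()  # the FutureWarning is reported against this line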

snowflake/ml/model/type_hints.py
@@ -1,6 +1,6 @@
 # mypy: disable-error-code="import"
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, Literal, Sequence, TypedDict, TypeVar, Union
+from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
 
 import numpy.typing as npt
 from typing_extensions import NotRequired
@@ -32,7 +32,7 @@ _SupportedBuiltins = Union[
     bool,
     str,
     bytes,
-    Dict[str, Union["_SupportedBuiltins", "_SupportedBuiltinsList"]],
+    dict[str, Union["_SupportedBuiltins", "_SupportedBuiltinsList"]],
     "_SupportedBuiltinsList",
 ]
 _SupportedNumpyDtype = Union[
@@ -153,7 +153,7 @@ class BaseModelSaveOption(TypedDict):
     embed_local_ml_library: NotRequired[bool]
     relax_version: NotRequired[bool]
     function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
-    method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
+    method_options: NotRequired[dict[str, ModelMethodSaveOptions]]
     enable_explainability: NotRequired[bool]
     save_location: NotRequired[str]
 
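The TypedDict/NotRequired machinery is untouched; only the inner Dict becomes dict. A self-contained sketch of the same pattern (field names are shortened stand-ins for the BaseModelSaveOption fields above):

    from typing import Literal

    from typing_extensions import NotRequired, TypedDict

    class MethodOptions(TypedDict):
        max_batch_size: NotRequired[int]

    class SaveOptions(TypedDict):
        # NotRequired marks keys that callers may omit entirely.
        relax_version: NotRequired[bool]
        function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
        method_options: NotRequired[dict[str, MethodOptions]]

    opts: SaveOptions = {"relax_version": True}  # omitted keys are fine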

snowflake/ml/modeling/_internal/estimator_utils.py
@@ -1,7 +1,7 @@
 import inspect
 import numbers
 import os
-from typing import Any, Callable, Dict, List, Set, Tuple
+from typing import Any, Callable
 
 import cloudpickle as cp
 import numpy as np
@@ -16,7 +16,7 @@ from snowflake.snowpark import Session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
-def validate_sklearn_args(args: Dict[str, Tuple[Any, Any, bool]], klass: type) -> Dict[str, Any]:
+def validate_sklearn_args(args: dict[str, tuple[Any, Any, bool]], klass: type) -> dict[str, Any]:
     """Validate if all the keyword args are supported by current version of SKLearn/XGBoost object.
 
     Args:
@@ -71,7 +71,7 @@ def transform_snowml_obj_to_sklearn_obj(obj: Any) -> Any:
     return obj
 
 
-def gather_dependencies(obj: Any) -> Set[str]:
+def gather_dependencies(obj: Any) -> set[str]:
     """Gathers dependencies from the SnowML Estimator and Transformer objects.
 
     Args:
@@ -82,7 +82,7 @@ def gather_dependencies(obj: Any) -> Set[str]:
     """
 
     if isinstance(obj, list) or isinstance(obj, tuple):
-        deps: Set[str] = set()
+        deps: set[str] = set()
        for elem in obj:
            deps = deps | set(gather_dependencies(elem))
        return deps
@@ -167,8 +167,8 @@ def get_module_name(model: object) -> str:
 
 
 def handle_inference_result(
-    inference_res: Any, output_cols: List[str], inference_method: str, within_udf: bool = False
-) -> Tuple[npt.NDArray[Any], List[str]]:
+    inference_res: Any, output_cols: list[str], inference_method: str, within_udf: bool = False
+) -> tuple[npt.NDArray[Any], list[str]]:
     if isinstance(inference_res, list) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
         # In case of multioutput estimators, predict_proba, decision_function etc., functions return a list of
         # ndarrays. We need to concatenate them.
@@ -248,7 +248,7 @@ def create_temp_stage(session: Session) -> str:
 
 
 def upload_model_to_stage(
-    stage_name: str, estimator: object, session: Session, statement_params: Dict[str, str]
+    stage_name: str, estimator: object, session: Session, statement_params: dict[str, str]
 ) -> str:
     """Util method to pickle and upload the model to a temp Snowflake stage.
 
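gather_dependencies recurses through lists and tuples of estimators and unions per-object dependency sets, as the @@ -82,7 hunk shows. A runnable sketch of the same recursion shape (the _deps attribute is a hypothetical stand-in for how the real objects expose their dependencies):

    from typing import Any

    def gather_deps(obj: Any) -> set[str]:
        # Containers: union the dependencies of every element.
        if isinstance(obj, (list, tuple)):
            deps: set[str] = set()
            for elem in obj:
                deps = deps | gather_deps(elem)
            return deps
        # Leaves: read the object's own dependency list (illustrative).
        return set(getattr(obj, "_deps", []))

    class FakeTransformer:
        def __init__(self, deps: list[str]) -> None:
            self._deps = deps

    pipeline = [FakeTransformer(["numpy==1.26.4"]), (FakeTransformer(["scikit-learn>=1.4"]),)]
    print(gather_deps(pipeline))  # {'numpy==1.26.4', 'scikit-learn>=1.4'}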

snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Any, List, Optional
+from typing import Any, Optional
 
 import pandas as pd
 
@@ -38,9 +38,9 @@ class PandasTransformHandlers:
     def batch_inference(
         self,
         inference_method: str,
-        input_cols: List[str],
-        expected_output_cols: List[str],
-        snowpark_input_cols: Optional[List[str]] = None,
+        input_cols: list[str],
+        expected_output_cols: list[str],
+        snowpark_input_cols: Optional[list[str]] = None,
         drop_input_cols: Optional[bool] = False,
         *args: Any,
         **kwargs: Any,
@@ -147,8 +147,8 @@ class PandasTransformHandlers:
 
     def score(
         self,
-        input_cols: List[str],
-        label_cols: List[str],
+        input_cols: list[str],
+        label_cols: list[str],
         sample_weight_col: Optional[str],
         *args: Any,
         **kwargs: Any,

snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import List, Optional, Tuple
+from typing import Optional
 
 import pandas as pd
 
@@ -15,8 +15,8 @@ class PandasModelTrainer:
         self,
         estimator: object,
         dataset: pd.DataFrame,
-        input_cols: List[str],
-        label_cols: Optional[List[str]],
+        input_cols: list[str],
+        label_cols: Optional[list[str]],
         sample_weight_col: Optional[str],
     ) -> None:
         """
@@ -57,10 +57,10 @@ class PandasModelTrainer:
 
     def train_fit_predict(
         self,
-        expected_output_cols_list: List[str],
+        expected_output_cols_list: list[str],
         drop_input_cols: Optional[bool] = False,
         example_output_pd_df: Optional[pd.DataFrame] = None,
-    ) -> Tuple[pd.DataFrame, object]:
+    ) -> tuple[pd.DataFrame, object]:
         """Trains the model using specified features and target columns from the dataset.
         This API is different from fit itself because it would also provide the predict
         output.
@@ -92,9 +92,9 @@ class PandasModelTrainer:
 
     def train_fit_transform(
         self,
-        expected_output_cols_list: List[str],
+        expected_output_cols_list: list[str],
         drop_input_cols: Optional[bool] = False,
-    ) -> Tuple[pd.DataFrame, object]:
+    ) -> tuple[pd.DataFrame, object]:
         """Trains the model using specified features and target columns from the dataset.
         This API is different from fit itself because it would also provide the transform
         output.

snowflake/ml/modeling/_internal/model_specifications.py
@@ -1,5 +1,3 @@
-from typing import List
-
 import cloudpickle as cp
 import numpy as np
 
@@ -11,7 +9,7 @@ class ModelSpecifications:
     A dataclass to define model based specifications like required imports, and package dependencies for Sproc/Udfs.
     """
 
-    def __init__(self, imports: List[str], pkgDependencies: List[str]) -> None:
+    def __init__(self, imports: list[str], pkgDependencies: list[str]) -> None:
         self.imports = imports
         self.pkgDependencies = pkgDependencies
 
@@ -20,7 +18,7 @@ class SKLearnModelSpecifications(ModelSpecifications):
     def __init__(self) -> None:
         import sklearn
 
-        imports: List[str] = ["sklearn"]
+        imports: list[str] = ["sklearn"]
         # TODO(snandamuri): Replace cloudpickle with joblib after latest version of joblib is added to snowflake conda.
         pkgDependencies = [
             f"numpy=={np.__version__}",
@@ -56,8 +54,8 @@ class XGBoostModelSpecifications(ModelSpecifications):
         import sklearn
         import xgboost
 
-        imports: List[str] = ["xgboost"]
-        pkgDependencies: List[str] = [
+        imports: list[str] = ["xgboost"]
+        pkgDependencies: list[str] = [
             f"numpy=={np.__version__}",
             f"scikit-learn=={sklearn.__version__}",
             f"xgboost=={xgboost.__version__}",
@@ -71,8 +69,8 @@ class LightGBMModelSpecifications(ModelSpecifications):
         import lightgbm
         import sklearn
 
-        imports: List[str] = ["lightgbm"]
-        pkgDependencies: List[str] = [
+        imports: list[str] = ["lightgbm"]
+        pkgDependencies: list[str] = [
             f"numpy=={np.__version__}",
             f"scikit-learn=={sklearn.__version__}",
             f"lightgbm=={lightgbm.__version__}",
@@ -86,8 +84,8 @@ class SklearnModelSelectionModelSpecifications(ModelSpecifications):
         import sklearn
         import xgboost
 
-        imports: List[str] = ["sklearn", "xgboost"]
-        pkgDependencies: List[str] = [
+        imports: list[str] = ["sklearn", "xgboost"]
+        pkgDependencies: list[str] = [
             f"numpy=={np.__version__}",
             f"scikit-learn=={sklearn.__version__}",
             f"cloudpickle=={cp.__version__}",
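
Each ModelSpecifications subclass pins its sproc/UDF package dependencies to the versions installed on the client, so the server-side environment mirrors the one the estimator was built in. The pattern in isolation:

    import cloudpickle as cp
    import numpy as np
    import sklearn

    # f-strings over __version__ turn the local environment into exact pins.
    pkg_dependencies: list[str] = [
        f"numpy=={np.__version__}",
        f"scikit-learn=={sklearn.__version__}",
        f"cloudpickle=={cp.__version__}",
    ]
    print(pkg_dependencies)  # e.g. ['numpy==1.26.4', 'scikit-learn==1.5.2', ...]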

snowflake/ml/modeling/_internal/model_trainer.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Protocol, Tuple, Union
+from typing import Optional, Protocol, Union
 
 import pandas as pd
 
@@ -18,15 +18,15 @@ class ModelTrainer(Protocol):
 
     def train_fit_predict(
         self,
-        expected_output_cols_list: List[str],
+        expected_output_cols_list: list[str],
         drop_input_cols: Optional[bool] = False,
         example_output_pd_df: Optional[pd.DataFrame] = None,
-    ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
+    ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
         raise NotImplementedError
 
     def train_fit_transform(
         self,
-        expected_output_cols_list: List[str],
+        expected_output_cols_list: list[str],
         drop_input_cols: Optional[bool] = False,
-    ) -> Tuple[Union[DataFrame, pd.DataFrame], object]:
+    ) -> tuple[Union[DataFrame, pd.DataFrame], object]:
         raise NotImplementedError
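
ModelTrainer is a typing.Protocol, so concrete trainers such as PandasModelTrainer satisfy it structurally, with no inheritance required. A minimal sketch of that relationship (names are illustrative):

    from typing import Optional, Protocol

    class Trainer(Protocol):
        def train_fit_predict(
            self,
            expected_output_cols_list: list[str],
            drop_input_cols: Optional[bool] = False,
        ) -> tuple[object, object]: ...

    class LocalTrainer:  # no base class needed; matching methods suffice
        def train_fit_predict(
            self,
            expected_output_cols_list: list[str],
            drop_input_cols: Optional[bool] = False,
        ) -> tuple[object, object]:
            return ("predictions", "fitted_estimator")

    def run(trainer: Trainer) -> tuple[object, object]:
        return trainer.train_fit_predict(["OUTPUT"])

    run(LocalTrainer())  # type-checks: LocalTrainer structurally matches Trainer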

snowflake/ml/modeling/_internal/model_trainer_builder.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import pandas as pd
 from sklearn import model_selection
@@ -71,8 +71,8 @@ class ModelTrainerBuilder:
         cls,
         estimator: object,
         dataset: Union[DataFrame, pd.DataFrame],
-        input_cols: Optional[List[str]] = None,
-        label_cols: Optional[List[str]] = None,
+        input_cols: Optional[list[str]] = None,
+        label_cols: Optional[list[str]] = None,
         sample_weight_col: Optional[str] = None,
         autogenerated: bool = False,
         subproject: str = "",
@@ -130,7 +130,7 @@ class ModelTrainerBuilder:
         cls,
         estimator: object,
         dataset: Union[DataFrame, pd.DataFrame],
-        input_cols: List[str],
+        input_cols: list[str],
         autogenerated: bool = False,
         subproject: str = "",
     ) -> ModelTrainer:
@@ -169,8 +169,8 @@ class ModelTrainerBuilder:
         cls,
         estimator: object,
         dataset: Union[DataFrame, pd.DataFrame],
-        input_cols: List[str],
-        label_cols: Optional[List[str]] = None,
+        input_cols: list[str],
+        label_cols: Optional[list[str]] = None,
         sample_weight_col: Optional[str] = None,
         autogenerated: bool = False,
         subproject: str = "",

snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
@@ -5,7 +5,7 @@ import os
 import posixpath
 import sys
 import uuid
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Optional, Union
 
 import cloudpickle as cp
 import numpy as np
@@ -50,11 +50,11 @@ _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}
 def construct_cv_results(
     estimator: Union[GridSearchCV, RandomizedSearchCV],
     n_split: int,
-    param_grid: List[Dict[str, Any]],
-    cv_results_raw_hex: List[Row],
+    param_grid: list[dict[str, Any]],
+    cv_results_raw_hex: list[Row],
     cross_validator_indices_length: int,
     parameter_grid_length: int,
-) -> Tuple[bool, Dict[str, Any]]:
+) -> tuple[bool, dict[str, Any]]:
     """Construct the cross validation result from the UDF. Because we accelerate the process
     by the number of cross validation number, and the combination of parameter grids.
     Therefore, we need to stick them back together instead of returning the raw result
@@ -158,11 +158,11 @@ def construct_cv_results(
 def construct_cv_results_memory_efficient_version(
     estimator: Union[GridSearchCV, RandomizedSearchCV],
     n_split: int,
-    param_grid: List[Dict[str, Any]],
-    cv_results_raw_hex: List[Row],
+    param_grid: list[dict[str, Any]],
+    cv_results_raw_hex: list[Row],
     cross_validator_indices_length: int,
     parameter_grid_length: int,
-) -> Tuple[Any, Dict[str, Any]]:
+) -> tuple[Any, dict[str, Any]]:
     """Construct the cross validation result from the UDF.
     The output is a raw dictionary generated by _fit_and_score, encoded into hex binary.
     This function need to decode the string and then call _format_result to stick them back together
@@ -210,7 +210,7 @@ def construct_cv_results_memory_efficient_version(
     # because original SearchCV is ranked by parameter first and cv second,
     # to make the memory efficient, we implemented by fitting on cv first and parameter second
     # when retrieving the results back, the ordering should revert back to remain the same result as original SearchCV
-    def generate_the_order_by_parameter_index(all_combination_length: int) -> List[int]:
+    def generate_the_order_by_parameter_index(all_combination_length: int) -> list[int]:
         pattern = []
         for i in range(all_combination_length):
             if i % parameter_grid_length == 0:
@@ -221,7 +221,7 @@ construct_cv_results_memory_efficient_version(
                 pattern.append(j)
         return pattern
 
-    def rerank_array(original_array: List[Any], pattern: List[int]) -> List[Any]:
+    def rerank_array(original_array: list[Any], pattern: list[int]) -> list[Any]:
         reranked_array = []
         for index in pattern:
             reranked_array.append(original_array[index])
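
These two helpers undo the cv-first/parameter-second fitting order: results arrive cv-major, while SearchCV expects parameter-major. A simplified, runnable equivalent of the reordering with three parameter settings and two folds (not the library code verbatim):

    # cv-major layout: index = cv_idx * n_params + param_idx
    n_params, n_cv = 3, 2

    def parameter_major_pattern(n_params: int, n_cv: int) -> list[int]:
        return [cv * n_params + p for p in range(n_params) for cv in range(n_cv)]

    def rerank(original: list[str], pattern: list[int]) -> list[str]:
        return [original[i] for i in pattern]

    cv_major = ["cv0-p0", "cv0-p1", "cv0-p2", "cv1-p0", "cv1-p1", "cv1-p2"]
    print(rerank(cv_major, parameter_major_pattern(n_params, n_cv)))
    # ['cv0-p0', 'cv1-p0', 'cv0-p1', 'cv1-p1', 'cv0-p2', 'cv1-p2']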

snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py (continued)
@@ -251,8 +251,8 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         estimator: object,
         dataset: DataFrame,
         session: Session,
-        input_cols: List[str],
-        label_cols: Optional[List[str]],
+        input_cols: list[str],
+        label_cols: Optional[list[str]],
         sample_weight_col: Optional[str],
         autogenerated: bool = False,
         subproject: str = "",
@@ -289,10 +289,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         dataset: DataFrame,
         session: Session,
         estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
-        dependencies: List[str],
-        udf_imports: List[str],
-        input_cols: List[str],
-        label_cols: Optional[List[str]],
+        dependencies: list[str],
+        udf_imports: list[str],
+        input_cols: list[str],
+        label_cols: Optional[list[str]],
         sample_weight_col: Optional[str],
     ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
         from itertools import product
@@ -382,10 +382,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         )
         def _distributed_search(
             session: Session,
-            imports: List[str],
+            imports: list[str],
             stage_estimator_file_name: str,
-            input_cols: List[str],
-            label_cols: Optional[List[str]],
+            input_cols: list[str],
+            label_cols: Optional[list[str]],
         ) -> str:
             import os
             import time
@@ -455,12 +455,12 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
            assert estimator is not None
 
            @cachetools.cached(cache={})
-            def _load_data_into_udf() -> Tuple[
-                Dict[str, pd.DataFrame],
+            def _load_data_into_udf() -> tuple[
+                dict[str, pd.DataFrame],
                Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
                pd.DataFrame,
                int,
-                List[Dict[str, Any]],
+                list[dict[str, Any]],
            ]:
                import pyarrow.parquet as pq
 
@@ -512,7 +512,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
                self.data_length = data_length
                self.params_to_evaluate = params_to_evaluate
 
-            def process(self, params_idx: int, cv_idx: int) -> Iterator[Tuple[str]]:
+            def process(self, params_idx: int, cv_idx: int) -> Iterator[tuple[str]]:
                # Assign parameter to GridSearchCV
                if hasattr(estimator, "param_grid"):
                    self.estimator.param_grid = self.params_to_evaluate[params_idx]
@@ -699,10 +699,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         dataset: DataFrame,
         session: Session,
         estimator: Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV],
-        dependencies: List[str],
-        udf_imports: List[str],
-        input_cols: List[str],
-        label_cols: Optional[List[str]],
+        dependencies: list[str],
+        udf_imports: list[str],
+        input_cols: list[str],
+        label_cols: Optional[list[str]],
         sample_weight_col: Optional[str],
     ) -> Union[model_selection.GridSearchCV, model_selection.RandomizedSearchCV]:
         from itertools import product
@@ -727,7 +727,7 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         # Create a temp file and dump the estimator to that file.
         estimator_file_name = temp_file_utils.get_temp_file_path()
         params_to_evaluate = list(param_grid)
-        CONSTANTS: Dict[str, Any] = dict()
+        CONSTANTS: dict[str, Any] = dict()
         CONSTANTS["dataset_snowpark_cols"] = dataset.columns
         CONSTANTS["n_candidates"] = len(params_to_evaluate)
         CONSTANTS["_N_JOBS"] = estimator.n_jobs
@@ -791,10 +791,10 @@ class DistributedHPOTrainer(SnowparkModelTrainer):
         )
         def _distributed_search(
             session: Session,
-            imports: List[str],
+            imports: list[str],
             stage_estimator_file_name: str,
-            input_cols: List[str],
-            label_cols: Optional[List[str]],
+            input_cols: list[str],
+            label_cols: Optional[list[str]],
         ) -> str:
             import os
             import time