snowflake-ml-python 1.8.2__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +23 -24
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +6 -6
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +15 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +7 -7
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/interop_utils.py +10 -10
  58. snowflake/ml/jobs/_utils/payload_utils.py +6 -16
  59. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +7 -4
  60. snowflake/ml/jobs/_utils/scripts/signal_workers.py +8 -8
  61. snowflake/ml/jobs/_utils/spec_utils.py +17 -28
  62. snowflake/ml/jobs/_utils/types.py +2 -2
  63. snowflake/ml/jobs/decorators.py +4 -5
  64. snowflake/ml/jobs/job.py +24 -14
  65. snowflake/ml/jobs/manager.py +37 -41
  66. snowflake/ml/lineage/lineage_node.py +5 -5
  67. snowflake/ml/model/_client/model/model_impl.py +3 -3
  68. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  69. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  70. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  71. snowflake/ml/model/_client/ops/service_ops.py +199 -26
  72. snowflake/ml/model/_client/service/model_deployment_spec.py +171 -47
  73. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  74. snowflake/ml/model/_client/sql/model.py +8 -8
  75. snowflake/ml/model/_client/sql/model_version.py +26 -26
  76. snowflake/ml/model/_client/sql/service.py +13 -13
  77. snowflake/ml/model/_client/sql/stage.py +2 -2
  78. snowflake/ml/model/_client/sql/tag.py +6 -6
  79. snowflake/ml/model/_model_composer/model_composer.py +17 -14
  80. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  81. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  82. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  83. snowflake/ml/model/_packager/model_env/model_env.py +28 -25
  84. snowflake/ml/model/_packager/model_handler.py +4 -4
  85. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  86. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  87. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  88. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  89. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  90. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  91. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  92. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  93. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  94. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  95. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  96. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  99. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  100. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  101. snowflake/ml/model/_packager/model_meta/model_meta.py +37 -37
  102. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  103. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  104. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  105. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  106. snowflake/ml/model/_packager/model_packager.py +11 -9
  107. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  108. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  109. snowflake/ml/model/_signatures/core.py +16 -24
  110. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  111. snowflake/ml/model/_signatures/utils.py +6 -6
  112. snowflake/ml/model/custom_model.py +8 -8
  113. snowflake/ml/model/model_signature.py +9 -20
  114. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  115. snowflake/ml/model/type_hints.py +3 -3
  116. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  117. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  118. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  119. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  120. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  121. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  122. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  123. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  124. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  125. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  126. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  127. snowflake/ml/modeling/framework/_utils.py +10 -10
  128. snowflake/ml/modeling/framework/base.py +32 -32
  129. snowflake/ml/modeling/impute/__init__.py +1 -1
  130. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  131. snowflake/ml/modeling/metrics/__init__.py +1 -1
  132. snowflake/ml/modeling/metrics/classification.py +39 -39
  133. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  134. snowflake/ml/modeling/metrics/ranking.py +7 -7
  135. snowflake/ml/modeling/metrics/regression.py +13 -13
  136. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  137. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  138. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  139. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  140. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  141. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  142. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  143. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  144. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  145. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  146. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  147. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  148. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  149. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  150. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  151. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  152. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  153. snowflake/ml/registry/_manager/model_manager.py +33 -31
  154. snowflake/ml/registry/registry.py +29 -22
  155. snowflake/ml/utils/authentication.py +2 -2
  156. snowflake/ml/utils/connection_params.py +5 -5
  157. snowflake/ml/utils/sparse.py +5 -4
  158. snowflake/ml/utils/sql_client.py +1 -2
  159. snowflake/ml/version.py +2 -1
  160. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +16 -7
  161. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +164 -166
  162. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  163. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  164. snowflake/ml/modeling/_internal/constants.py +0 -2
  165. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  166. {snowflake_ml_python-1.8.2.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Mapping, Optional
1
+ from typing import Any, Mapping, Optional
2
2
 
3
3
  from snowflake import snowpark
4
4
  from snowflake.ml._internal.utils import (
@@ -15,7 +15,7 @@ MODEL_JSON_MODEL_NAME_FIELD = "model_name"
15
15
  MODEL_JSON_VERSION_NAME_FIELD = "version_name"
16
16
 
17
17
 
18
- def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
18
+ def _build_sql_list_from_columns(columns: list[sql_identifier.SqlIdentifier]) -> str:
19
19
  sql_list = ", ".join([f"'{column}'" for column in columns])
20
20
  return f"({sql_list})"
21
21
 
@@ -60,17 +60,17 @@ class ModelMonitorSQLClient:
60
60
  function_name: str,
61
61
  warehouse_name: sql_identifier.SqlIdentifier,
62
62
  timestamp_column: sql_identifier.SqlIdentifier,
63
- id_columns: List[sql_identifier.SqlIdentifier],
64
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
65
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
66
- actual_score_columns: List[sql_identifier.SqlIdentifier],
67
- actual_class_columns: List[sql_identifier.SqlIdentifier],
63
+ id_columns: list[sql_identifier.SqlIdentifier],
64
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
65
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
66
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
67
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
68
68
  refresh_interval: str,
69
69
  aggregation_window: str,
70
70
  baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
71
71
  baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
72
72
  baseline: Optional[sql_identifier.SqlIdentifier] = None,
73
- statement_params: Optional[Dict[str, Any]] = None,
73
+ statement_params: Optional[dict[str, Any]] = None,
74
74
  ) -> None:
75
75
  baseline_sql = ""
76
76
  if baseline:
@@ -103,7 +103,7 @@ class ModelMonitorSQLClient:
103
103
  database_name: Optional[sql_identifier.SqlIdentifier] = None,
104
104
  schema_name: Optional[sql_identifier.SqlIdentifier] = None,
105
105
  monitor_name: sql_identifier.SqlIdentifier,
106
- statement_params: Optional[Dict[str, Any]] = None,
106
+ statement_params: Optional[dict[str, Any]] = None,
107
107
  ) -> None:
108
108
  search_database_name = database_name or self._database_name
109
109
  search_schema_name = schema_name or self._schema_name
@@ -116,8 +116,8 @@ class ModelMonitorSQLClient:
116
116
  def show_model_monitors(
117
117
  self,
118
118
  *,
119
- statement_params: Optional[Dict[str, Any]] = None,
120
- ) -> List[snowpark.Row]:
119
+ statement_params: Optional[dict[str, Any]] = None,
120
+ ) -> list[snowpark.Row]:
121
121
  fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
122
122
  return (
123
123
  query_result_checker.SqlResultValidator(
@@ -135,7 +135,7 @@ class ModelMonitorSQLClient:
135
135
  database_name: Optional[sql_identifier.SqlIdentifier] = None,
136
136
  schema_name: Optional[sql_identifier.SqlIdentifier] = None,
137
137
  monitor_name: sql_identifier.SqlIdentifier,
138
- statement_params: Optional[Dict[str, Any]] = None,
138
+ statement_params: Optional[dict[str, Any]] = None,
139
139
  ) -> bool:
140
140
  search_database_name = database_name or self._database_name
141
141
  search_schema_name = schema_name or self._schema_name
@@ -153,7 +153,7 @@ class ModelMonitorSQLClient:
153
153
  def validate_monitor_warehouse(
154
154
  self,
155
155
  warehouse_name: sql_identifier.SqlIdentifier,
156
- statement_params: Optional[Dict[str, Any]] = None,
156
+ statement_params: Optional[dict[str, Any]] = None,
157
157
  ) -> None:
158
158
  """Validate warehouse provided for monitoring exists.
159
159
 
@@ -177,11 +177,11 @@ class ModelMonitorSQLClient:
177
177
  *,
178
178
  source_column_schema: Mapping[str, types.DataType],
179
179
  timestamp_column: sql_identifier.SqlIdentifier,
180
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
181
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
182
- actual_score_columns: List[sql_identifier.SqlIdentifier],
183
- actual_class_columns: List[sql_identifier.SqlIdentifier],
184
- id_columns: List[sql_identifier.SqlIdentifier],
180
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
181
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
182
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
183
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
184
+ id_columns: list[sql_identifier.SqlIdentifier],
185
185
  ) -> None:
186
186
  """Ensures all columns exist in the source table.
187
187
 
@@ -221,11 +221,11 @@ class ModelMonitorSQLClient:
221
221
  source_schema: Optional[sql_identifier.SqlIdentifier],
222
222
  source: sql_identifier.SqlIdentifier,
223
223
  timestamp_column: sql_identifier.SqlIdentifier,
224
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
225
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
226
- actual_score_columns: List[sql_identifier.SqlIdentifier],
227
- actual_class_columns: List[sql_identifier.SqlIdentifier],
228
- id_columns: List[sql_identifier.SqlIdentifier],
224
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
225
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
226
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
227
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
228
+ id_columns: list[sql_identifier.SqlIdentifier],
229
229
  ) -> None:
230
230
  source_database = source_database or self._database_name
231
231
  source_schema = source_schema or self._schema_name
@@ -250,7 +250,7 @@ class ModelMonitorSQLClient:
250
250
  self,
251
251
  operation: str,
252
252
  monitor_name: sql_identifier.SqlIdentifier,
253
- statement_params: Optional[Dict[str, Any]] = None,
253
+ statement_params: Optional[dict[str, Any]] = None,
254
254
  ) -> None:
255
255
  if operation not in {"SUSPEND", "RESUME"}:
256
256
  raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
@@ -263,7 +263,7 @@ class ModelMonitorSQLClient:
263
263
  def suspend_monitor(
264
264
  self,
265
265
  monitor_name: sql_identifier.SqlIdentifier,
266
- statement_params: Optional[Dict[str, Any]] = None,
266
+ statement_params: Optional[dict[str, Any]] = None,
267
267
  ) -> None:
268
268
  self._alter_monitor(
269
269
  operation="SUSPEND",
@@ -274,7 +274,7 @@ class ModelMonitorSQLClient:
274
274
  def resume_monitor(
275
275
  self,
276
276
  monitor_name: sql_identifier.SqlIdentifier,
277
- statement_params: Optional[Dict[str, Any]] = None,
277
+ statement_params: Optional[dict[str, Any]] = None,
278
278
  ) -> None:
279
279
  self._alter_monitor(
280
280
  operation="RESUME",
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Any, Dict, List, Optional
2
+ from typing import Any, Optional
3
3
 
4
4
  from snowflake import snowpark
5
5
  from snowflake.ml._internal.utils import sql_identifier
@@ -20,7 +20,7 @@ class ModelMonitorManager:
20
20
  database_name: sql_identifier.SqlIdentifier,
21
21
  schema_name: sql_identifier.SqlIdentifier,
22
22
  *,
23
- statement_params: Optional[Dict[str, Any]] = None,
23
+ statement_params: Optional[dict[str, Any]] = None,
24
24
  ) -> None:
25
25
  """
26
26
  Opens a ModelMonitorManager for a given database and schema.
@@ -64,7 +64,7 @@ class ModelMonitorManager:
64
64
  f"Found: {existing_target_methods}."
65
65
  )
66
66
 
67
- def _build_column_list_from_input(self, columns: Optional[List[str]]) -> List[sql_identifier.SqlIdentifier]:
67
+ def _build_column_list_from_input(self, columns: Optional[list[str]]) -> list[sql_identifier.SqlIdentifier]:
68
68
  return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []
69
69
 
70
70
  def add_monitor(
@@ -172,7 +172,7 @@ class ModelMonitorManager:
172
172
  """
173
173
  rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
174
174
 
175
- def model_match_fn(model_details: Dict[str, str]) -> bool:
175
+ def model_match_fn(model_details: dict[str, str]) -> bool:
176
176
  return (
177
177
  model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
178
178
  and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
@@ -215,7 +215,7 @@ class ModelMonitorManager:
215
215
  name=monitor_name_id,
216
216
  )
217
217
 
218
- def show_model_monitors(self) -> List[snowpark.Row]:
218
+ def show_model_monitors(self) -> list[snowpark.Row]:
219
219
  """Show all model monitors in the registry.
220
220
 
221
221
  Returns:
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import List, Optional
2
+ from typing import Optional
3
3
 
4
4
  from snowflake.ml.model._client.model import model_version_impl
5
5
 
@@ -14,20 +14,20 @@ class ModelMonitorSourceConfig:
14
14
  timestamp_column: str
15
15
  """Name of column in the source containing timestamp."""
16
16
 
17
- id_columns: List[str]
17
+ id_columns: list[str]
18
18
  """List of columns in the source containing unique identifiers."""
19
19
 
20
- prediction_score_columns: Optional[List[str]] = None
20
+ prediction_score_columns: Optional[list[str]] = None
21
21
  """List of columns in the source containing prediction scores.
22
22
  Can be regression scores for regression models and probability scores for classification models."""
23
23
 
24
- prediction_class_columns: Optional[List[str]] = None
24
+ prediction_class_columns: Optional[list[str]] = None
25
25
  """List of columns in the source containing prediction classes for classification models."""
26
26
 
27
- actual_score_columns: Optional[List[str]] = None
27
+ actual_score_columns: Optional[list[str]] = None
28
28
  """List of columns in the source containing actual scores."""
29
29
 
30
- actual_class_columns: Optional[List[str]] = None
30
+ actual_class_columns: Optional[list[str]] = None
31
31
  """List of columns in the source containing actual classes for classification models."""
32
32
 
33
33
  baseline: Optional[str] = None
@@ -1,11 +1,10 @@
1
- import os
2
1
  from types import ModuleType
3
- from typing import Any, Dict, List, Optional, Tuple, Union
2
+ from typing import Any, Optional, Union
4
3
 
5
4
  import pandas as pd
6
5
  from absl.logging import logging
7
6
 
8
- from snowflake.ml._internal import platform_capabilities, telemetry
7
+ from snowflake.ml._internal import env, platform_capabilities, telemetry
9
8
  from snowflake.ml._internal.exceptions import error_codes, exceptions
10
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
11
10
  from snowflake.ml._internal.utils import sql_identifier
@@ -14,7 +13,6 @@ from snowflake.ml.model._client.model import model_impl, model_version_impl
14
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
15
14
  from snowflake.ml.model._model_composer import model_composer
16
15
  from snowflake.ml.model._packager.model_meta import model_meta
17
- from snowflake.ml.modeling._internal import constants
18
16
  from snowflake.snowpark import exceptions as snowpark_exceptions, session
19
17
 
20
18
  logger = logging.getLogger(__name__)
@@ -45,20 +43,21 @@ class ModelManager:
45
43
  model_name: str,
46
44
  version_name: Optional[str] = None,
47
45
  comment: Optional[str] = None,
48
- metrics: Optional[Dict[str, Any]] = None,
49
- conda_dependencies: Optional[List[str]] = None,
50
- pip_requirements: Optional[List[str]] = None,
51
- artifact_repository_map: Optional[Dict[str, str]] = None,
52
- target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
46
+ metrics: Optional[dict[str, Any]] = None,
47
+ conda_dependencies: Optional[list[str]] = None,
48
+ pip_requirements: Optional[list[str]] = None,
49
+ artifact_repository_map: Optional[dict[str, str]] = None,
50
+ resource_constraint: Optional[dict[str, str]] = None,
51
+ target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
53
52
  python_version: Optional[str] = None,
54
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
53
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
55
54
  sample_input_data: Optional[model_types.SupportedDataType] = None,
56
- user_files: Optional[Dict[str, List[str]]] = None,
57
- code_paths: Optional[List[str]] = None,
58
- ext_modules: Optional[List[ModuleType]] = None,
55
+ user_files: Optional[dict[str, list[str]]] = None,
56
+ code_paths: Optional[list[str]] = None,
57
+ ext_modules: Optional[list[ModuleType]] = None,
59
58
  task: model_types.Task = model_types.Task.UNKNOWN,
60
59
  options: Optional[model_types.ModelSaveOption] = None,
61
- statement_params: Optional[Dict[str, Any]] = None,
60
+ statement_params: Optional[dict[str, Any]] = None,
62
61
  ) -> model_version_impl.ModelVersion:
63
62
 
64
63
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -131,6 +130,7 @@ class ModelManager:
131
130
  conda_dependencies=conda_dependencies,
132
131
  pip_requirements=pip_requirements,
133
132
  artifact_repository_map=artifact_repository_map,
133
+ resource_constraint=resource_constraint,
134
134
  target_platforms=target_platforms,
135
135
  python_version=python_version,
136
136
  signatures=signatures,
@@ -150,20 +150,21 @@ class ModelManager:
150
150
  model_name: str,
151
151
  version_name: str,
152
152
  comment: Optional[str] = None,
153
- metrics: Optional[Dict[str, Any]] = None,
154
- conda_dependencies: Optional[List[str]] = None,
155
- pip_requirements: Optional[List[str]] = None,
156
- artifact_repository_map: Optional[Dict[str, str]] = None,
157
- target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
153
+ metrics: Optional[dict[str, Any]] = None,
154
+ conda_dependencies: Optional[list[str]] = None,
155
+ pip_requirements: Optional[list[str]] = None,
156
+ artifact_repository_map: Optional[dict[str, str]] = None,
157
+ resource_constraint: Optional[dict[str, str]] = None,
158
+ target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
158
159
  python_version: Optional[str] = None,
159
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
160
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
160
161
  sample_input_data: Optional[model_types.SupportedDataType] = None,
161
- user_files: Optional[Dict[str, List[str]]] = None,
162
- code_paths: Optional[List[str]] = None,
163
- ext_modules: Optional[List[ModuleType]] = None,
162
+ user_files: Optional[dict[str, list[str]]] = None,
163
+ code_paths: Optional[list[str]] = None,
164
+ ext_modules: Optional[list[ModuleType]] = None,
164
165
  task: model_types.Task = model_types.Task.UNKNOWN,
165
166
  options: Optional[model_types.ModelSaveOption] = None,
166
- statement_params: Optional[Dict[str, Any]] = None,
167
+ statement_params: Optional[dict[str, Any]] = None,
167
168
  ) -> model_version_impl.ModelVersion:
168
169
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
169
170
  version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -212,7 +213,7 @@ class ModelManager:
212
213
  platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
213
214
  else:
214
215
  # Default the target platform to SPCS if not specified when running in ML runtime
215
- if os.getenv(constants.IN_ML_RUNTIME_ENV_VAR):
216
+ if env.IN_ML_RUNTIME:
216
217
  logger.info(
217
218
  "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
218
219
  'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
@@ -253,6 +254,7 @@ class ModelManager:
253
254
  conda_dependencies=conda_dependencies,
254
255
  pip_requirements=pip_requirements,
255
256
  artifact_repository_map=artifact_repository_map,
257
+ resource_constraint=resource_constraint,
256
258
  target_platforms=platforms,
257
259
  python_version=python_version,
258
260
  user_files=user_files,
@@ -314,7 +316,7 @@ class ModelManager:
314
316
  self,
315
317
  model_name: str,
316
318
  *,
317
- statement_params: Optional[Dict[str, Any]] = None,
319
+ statement_params: Optional[dict[str, Any]] = None,
318
320
  ) -> model_impl.Model:
319
321
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
320
322
  if self._model_ops.validate_existence(
@@ -342,8 +344,8 @@ class ModelManager:
342
344
  def models(
343
345
  self,
344
346
  *,
345
- statement_params: Optional[Dict[str, Any]] = None,
346
- ) -> List[model_impl.Model]:
347
+ statement_params: Optional[dict[str, Any]] = None,
348
+ ) -> list[model_impl.Model]:
347
349
  model_names = self._model_ops.list_models_or_versions(
348
350
  database_name=None,
349
351
  schema_name=None,
@@ -361,7 +363,7 @@ class ModelManager:
361
363
  def show_models(
362
364
  self,
363
365
  *,
364
- statement_params: Optional[Dict[str, Any]] = None,
366
+ statement_params: Optional[dict[str, Any]] = None,
365
367
  ) -> pd.DataFrame:
366
368
  rows = self._model_ops.show_models_or_versions(
367
369
  database_name=None,
@@ -374,7 +376,7 @@ class ModelManager:
374
376
  self,
375
377
  model_name: str,
376
378
  *,
377
- statement_params: Optional[Dict[str, Any]] = None,
379
+ statement_params: Optional[dict[str, Any]] = None,
378
380
  ) -> None:
379
381
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
380
382
 
@@ -387,7 +389,7 @@ class ModelManager:
387
389
 
388
390
  def _parse_fully_qualified_name(
389
391
  self, model_name: str
390
- ) -> Tuple[
392
+ ) -> tuple[
391
393
  Optional[sql_identifier.SqlIdentifier], Optional[sql_identifier.SqlIdentifier], sql_identifier.SqlIdentifier
392
394
  ]:
393
395
  try:
@@ -1,6 +1,6 @@
1
1
  import warnings
2
2
  from types import ModuleType
3
- from typing import Any, Dict, List, Optional, Union, overload
3
+ from typing import Any, Optional, Union, overload
4
4
 
5
5
  import pandas as pd
6
6
 
@@ -36,7 +36,7 @@ class Registry:
36
36
  *,
37
37
  database_name: Optional[str] = None,
38
38
  schema_name: Optional[str] = None,
39
- options: Optional[Dict[str, Any]] = None,
39
+ options: Optional[dict[str, Any]] = None,
40
40
  ) -> None:
41
41
  """Opens a registry within a pre-created Snowflake schema.
42
42
 
@@ -107,17 +107,18 @@ class Registry:
107
107
  model_name: str,
108
108
  version_name: Optional[str] = None,
109
109
  comment: Optional[str] = None,
110
- metrics: Optional[Dict[str, Any]] = None,
111
- conda_dependencies: Optional[List[str]] = None,
112
- pip_requirements: Optional[List[str]] = None,
113
- artifact_repository_map: Optional[Dict[str, str]] = None,
114
- target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
110
+ metrics: Optional[dict[str, Any]] = None,
111
+ conda_dependencies: Optional[list[str]] = None,
112
+ pip_requirements: Optional[list[str]] = None,
113
+ artifact_repository_map: Optional[dict[str, str]] = None,
114
+ resource_constraint: Optional[dict[str, str]] = None,
115
+ target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
115
116
  python_version: Optional[str] = None,
116
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
117
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
117
118
  sample_input_data: Optional[model_types.SupportedDataType] = None,
118
- user_files: Optional[Dict[str, List[str]]] = None,
119
- code_paths: Optional[List[str]] = None,
120
- ext_modules: Optional[List[ModuleType]] = None,
119
+ user_files: Optional[dict[str, list[str]]] = None,
120
+ code_paths: Optional[list[str]] = None,
121
+ ext_modules: Optional[list[ModuleType]] = None,
121
122
  task: model_types.Task = model_types.Task.UNKNOWN,
122
123
  options: Optional[model_types.ModelSaveOption] = None,
123
124
  ) -> ModelVersion:
@@ -152,6 +153,7 @@ class Registry:
152
153
  Format: {channel_name: artifact_repository_name}, where:
153
154
  - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
154
155
  - artifact_repository_name: The name or URL of the repository to fetch packages from.
156
+ resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
155
157
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
156
158
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
157
159
  python_version: Python version in which the model is run. Defaults to None.
@@ -232,6 +234,7 @@ class Registry:
232
234
  "conda_dependencies",
233
235
  "pip_requirements",
234
236
  "artifact_repository_map",
237
+ "resource_constraint",
235
238
  "target_platforms",
236
239
  "python_version",
237
240
  "signatures",
@@ -244,17 +247,18 @@ class Registry:
244
247
  model_name: str,
245
248
  version_name: Optional[str] = None,
246
249
  comment: Optional[str] = None,
247
- metrics: Optional[Dict[str, Any]] = None,
248
- conda_dependencies: Optional[List[str]] = None,
249
- pip_requirements: Optional[List[str]] = None,
250
- artifact_repository_map: Optional[Dict[str, str]] = None,
251
- target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
250
+ metrics: Optional[dict[str, Any]] = None,
251
+ conda_dependencies: Optional[list[str]] = None,
252
+ pip_requirements: Optional[list[str]] = None,
253
+ artifact_repository_map: Optional[dict[str, str]] = None,
254
+ resource_constraint: Optional[dict[str, str]] = None,
255
+ target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
252
256
  python_version: Optional[str] = None,
253
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
257
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
254
258
  sample_input_data: Optional[model_types.SupportedDataType] = None,
255
- user_files: Optional[Dict[str, List[str]]] = None,
256
- code_paths: Optional[List[str]] = None,
257
- ext_modules: Optional[List[ModuleType]] = None,
259
+ user_files: Optional[dict[str, list[str]]] = None,
260
+ code_paths: Optional[list[str]] = None,
261
+ ext_modules: Optional[list[ModuleType]] = None,
258
262
  task: model_types.Task = model_types.Task.UNKNOWN,
259
263
  options: Optional[model_types.ModelSaveOption] = None,
260
264
  ) -> ModelVersion:
@@ -289,6 +293,7 @@ class Registry:
289
293
  Format: {channel_name: artifact_repository_name}, where:
290
294
  - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
291
295
  - artifact_repository_name: The name or URL of the repository to fetch packages from.
296
+ resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
292
297
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
293
298
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
294
299
  python_version: Python version in which the model is run. Defaults to None.
@@ -373,6 +378,7 @@ class Registry:
373
378
  conda_dependencies,
374
379
  pip_requirements,
375
380
  artifact_repository_map,
381
+ resource_constraint,
376
382
  target_platforms,
377
383
  python_version,
378
384
  signatures,
@@ -407,6 +413,7 @@ class Registry:
407
413
  conda_dependencies=conda_dependencies,
408
414
  pip_requirements=pip_requirements,
409
415
  artifact_repository_map=artifact_repository_map,
416
+ resource_constraint=resource_constraint,
410
417
  target_platforms=target_platforms,
411
418
  python_version=python_version,
412
419
  signatures=signatures,
@@ -442,7 +449,7 @@ class Registry:
442
449
  project=_TELEMETRY_PROJECT,
443
450
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
444
451
  )
445
- def models(self) -> List[Model]:
452
+ def models(self) -> list[Model]:
446
453
  """Get all models in the schema where the registry is opened.
447
454
 
448
455
  Returns:
@@ -568,7 +575,7 @@ class Registry:
568
575
  subproject=telemetry.TelemetrySubProject.MONITORING.value,
569
576
  )
570
577
  @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
571
- def show_model_monitors(self) -> List[snowpark.Row]:
578
+ def show_model_monitors(self) -> list[snowpark.Row]:
572
579
  """Show all model monitors in the registry.
573
580
 
574
581
  Returns:
@@ -1,7 +1,7 @@
1
1
  import http
2
2
  import logging
3
3
  from datetime import timedelta
4
- from typing import Dict, Optional
4
+ from typing import Optional
5
5
 
6
6
  import requests
7
7
  from cryptography.hazmat.primitives.asymmetric import types
@@ -10,7 +10,7 @@ from requests import auth
10
10
  from snowflake.ml._internal.utils import jwt_generator
11
11
 
12
12
  logger = logging.getLogger(__name__)
13
- _JWT_TOKEN_CACHE: Dict[str, Dict[int, str]] = {}
13
+ _JWT_TOKEN_CACHE: dict[str, dict[int, str]] = {}
14
14
 
15
15
 
16
16
  def get_jwt_token_generator(
@@ -1,6 +1,6 @@
1
1
  import configparser
2
2
  import os
3
- from typing import Dict, Optional, Union
3
+ from typing import Optional, Union
4
4
 
5
5
  from absl import logging
6
6
  from cryptography.hazmat import backends
@@ -76,7 +76,7 @@ def _load_pem_to_der(private_key_path: str) -> bytes:
76
76
  )
77
77
 
78
78
 
79
- def _connection_properties_from_env() -> Dict[str, str]:
79
+ def _connection_properties_from_env() -> dict[str, str]:
80
80
  """Returns a dict with all possible login related env variables."""
81
81
  sf_conn_prop = {
82
82
  # Mandatory fields
@@ -104,7 +104,7 @@ def _connection_properties_from_env() -> Dict[str, str]:
104
104
  return sf_conn_prop
105
105
 
106
106
 
107
- def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> Dict[str, str]:
107
+ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
108
108
  """Loads the dictionary from snowsql config file."""
109
109
  snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
110
110
  if not os.path.exists(snowsql_config_file):
@@ -133,7 +133,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
133
133
 
134
134
 
135
135
  @snowpark._internal.utils.private_preview(version="0.2.0")
136
- def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> Dict[str, Union[str, bytes]]:
136
+ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
137
137
  """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
138
138
 
139
139
  NOTE: Token/Auth information is sideloaded in all cases above, if provided in following order:
@@ -164,7 +164,7 @@ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] =
164
164
  Raises:
165
165
  Exception: if none of config file and environment variable are present.
166
166
  """
167
- conn_prop: Dict[str, Union[str, bytes]] = {}
167
+ conn_prop: dict[str, Union[str, bytes]] = {}
168
168
  login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)
169
169
  # If login file exists, use this exclusively.
170
170
  if os.path.exists(login_file):
@@ -1,6 +1,6 @@
1
1
  import collections
2
2
  import json
3
- from typing import List, Optional
3
+ from typing import Optional
4
4
 
5
5
  import pandas as pd
6
6
  from pandas import arrays as pandas_arrays
@@ -9,7 +9,7 @@ from pandas.core.arrays import sparse as pandas_sparse
9
9
  from snowflake.snowpark import DataFrame
10
10
 
11
11
 
12
- def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: List[str]) -> Optional[pd.DataFrame]:
12
+ def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: list[str]) -> Optional[pd.DataFrame]:
13
13
  """Convert the pandas df into pandas df with multiple SparseArray columns."""
14
14
  num_rows = pandas_df.shape[0]
15
15
  if num_rows == 0:
@@ -52,8 +52,9 @@ def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: List[str]) ->
52
52
  return pandas_df
53
53
 
54
54
 
55
- def to_pandas_with_sparse(df: DataFrame, sparse_cols: List[str]) -> pd.DataFrame:
56
- """Load a Snowpark df with sparse columns represented in JSON strings into pandas df with multiple SparseArray columns.
55
+ def to_pandas_with_sparse(df: DataFrame, sparse_cols: list[str]) -> pd.DataFrame:
56
+ """Load a Snowpark df with sparse columns represented in JSON strings into pandas df with multiple SparseArray
57
+ columns.
57
58
 
58
59
  For example, for below input:
59
60
  ----------------------------------------------
@@ -1,5 +1,4 @@
1
1
  from enum import Enum
2
- from typing import Dict
3
2
 
4
3
 
5
4
  class CreationOption(Enum):
@@ -13,7 +12,7 @@ class CreationMode:
13
12
  self.if_not_exists = if_not_exists
14
13
  self.or_replace = or_replace
15
14
 
16
- def get_ddl_phrases(self) -> Dict[CreationOption, str]:
15
+ def get_ddl_phrases(self) -> dict[CreationOption, str]:
17
16
  if_not_exists_sql = " IF NOT EXISTS" if self.if_not_exists else ""
18
17
  or_replace_sql = " OR REPLACE" if self.or_replace else ""
19
18
  return {
snowflake/ml/version.py CHANGED
@@ -1 +1,2 @@
1
- VERSION="1.8.2"
1
+ # This is parsed by regex in conda recipe meta file. Make sure not to break it.
2
+ VERSION = "1.8.3"