snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Mapping, Optional
1
+ from typing import Any, Mapping, Optional
2
2
 
3
3
  from snowflake import snowpark
4
4
  from snowflake.ml._internal.utils import (
@@ -15,7 +15,7 @@ MODEL_JSON_MODEL_NAME_FIELD = "model_name"
15
15
  MODEL_JSON_VERSION_NAME_FIELD = "version_name"
16
16
 
17
17
 
18
- def _build_sql_list_from_columns(columns: List[sql_identifier.SqlIdentifier]) -> str:
18
+ def _build_sql_list_from_columns(columns: list[sql_identifier.SqlIdentifier]) -> str:
19
19
  sql_list = ", ".join([f"'{column}'" for column in columns])
20
20
  return f"({sql_list})"
21
21
 
@@ -60,17 +60,17 @@ class ModelMonitorSQLClient:
60
60
  function_name: str,
61
61
  warehouse_name: sql_identifier.SqlIdentifier,
62
62
  timestamp_column: sql_identifier.SqlIdentifier,
63
- id_columns: List[sql_identifier.SqlIdentifier],
64
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
65
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
66
- actual_score_columns: List[sql_identifier.SqlIdentifier],
67
- actual_class_columns: List[sql_identifier.SqlIdentifier],
63
+ id_columns: list[sql_identifier.SqlIdentifier],
64
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
65
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
66
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
67
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
68
68
  refresh_interval: str,
69
69
  aggregation_window: str,
70
70
  baseline_database: Optional[sql_identifier.SqlIdentifier] = None,
71
71
  baseline_schema: Optional[sql_identifier.SqlIdentifier] = None,
72
72
  baseline: Optional[sql_identifier.SqlIdentifier] = None,
73
- statement_params: Optional[Dict[str, Any]] = None,
73
+ statement_params: Optional[dict[str, Any]] = None,
74
74
  ) -> None:
75
75
  baseline_sql = ""
76
76
  if baseline:
@@ -103,7 +103,7 @@ class ModelMonitorSQLClient:
103
103
  database_name: Optional[sql_identifier.SqlIdentifier] = None,
104
104
  schema_name: Optional[sql_identifier.SqlIdentifier] = None,
105
105
  monitor_name: sql_identifier.SqlIdentifier,
106
- statement_params: Optional[Dict[str, Any]] = None,
106
+ statement_params: Optional[dict[str, Any]] = None,
107
107
  ) -> None:
108
108
  search_database_name = database_name or self._database_name
109
109
  search_schema_name = schema_name or self._schema_name
@@ -116,8 +116,8 @@ class ModelMonitorSQLClient:
116
116
  def show_model_monitors(
117
117
  self,
118
118
  *,
119
- statement_params: Optional[Dict[str, Any]] = None,
120
- ) -> List[snowpark.Row]:
119
+ statement_params: Optional[dict[str, Any]] = None,
120
+ ) -> list[snowpark.Row]:
121
121
  fully_qualified_schema_name = ".".join([self._database_name.identifier(), self._schema_name.identifier()])
122
122
  return (
123
123
  query_result_checker.SqlResultValidator(
@@ -135,7 +135,7 @@ class ModelMonitorSQLClient:
135
135
  database_name: Optional[sql_identifier.SqlIdentifier] = None,
136
136
  schema_name: Optional[sql_identifier.SqlIdentifier] = None,
137
137
  monitor_name: sql_identifier.SqlIdentifier,
138
- statement_params: Optional[Dict[str, Any]] = None,
138
+ statement_params: Optional[dict[str, Any]] = None,
139
139
  ) -> bool:
140
140
  search_database_name = database_name or self._database_name
141
141
  search_schema_name = schema_name or self._schema_name
@@ -153,7 +153,7 @@ class ModelMonitorSQLClient:
153
153
  def validate_monitor_warehouse(
154
154
  self,
155
155
  warehouse_name: sql_identifier.SqlIdentifier,
156
- statement_params: Optional[Dict[str, Any]] = None,
156
+ statement_params: Optional[dict[str, Any]] = None,
157
157
  ) -> None:
158
158
  """Validate warehouse provided for monitoring exists.
159
159
 
@@ -177,11 +177,11 @@ class ModelMonitorSQLClient:
177
177
  *,
178
178
  source_column_schema: Mapping[str, types.DataType],
179
179
  timestamp_column: sql_identifier.SqlIdentifier,
180
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
181
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
182
- actual_score_columns: List[sql_identifier.SqlIdentifier],
183
- actual_class_columns: List[sql_identifier.SqlIdentifier],
184
- id_columns: List[sql_identifier.SqlIdentifier],
180
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
181
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
182
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
183
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
184
+ id_columns: list[sql_identifier.SqlIdentifier],
185
185
  ) -> None:
186
186
  """Ensures all columns exist in the source table.
187
187
 
@@ -221,11 +221,11 @@ class ModelMonitorSQLClient:
221
221
  source_schema: Optional[sql_identifier.SqlIdentifier],
222
222
  source: sql_identifier.SqlIdentifier,
223
223
  timestamp_column: sql_identifier.SqlIdentifier,
224
- prediction_score_columns: List[sql_identifier.SqlIdentifier],
225
- prediction_class_columns: List[sql_identifier.SqlIdentifier],
226
- actual_score_columns: List[sql_identifier.SqlIdentifier],
227
- actual_class_columns: List[sql_identifier.SqlIdentifier],
228
- id_columns: List[sql_identifier.SqlIdentifier],
224
+ prediction_score_columns: list[sql_identifier.SqlIdentifier],
225
+ prediction_class_columns: list[sql_identifier.SqlIdentifier],
226
+ actual_score_columns: list[sql_identifier.SqlIdentifier],
227
+ actual_class_columns: list[sql_identifier.SqlIdentifier],
228
+ id_columns: list[sql_identifier.SqlIdentifier],
229
229
  ) -> None:
230
230
  source_database = source_database or self._database_name
231
231
  source_schema = source_schema or self._schema_name
@@ -250,7 +250,7 @@ class ModelMonitorSQLClient:
250
250
  self,
251
251
  operation: str,
252
252
  monitor_name: sql_identifier.SqlIdentifier,
253
- statement_params: Optional[Dict[str, Any]] = None,
253
+ statement_params: Optional[dict[str, Any]] = None,
254
254
  ) -> None:
255
255
  if operation not in {"SUSPEND", "RESUME"}:
256
256
  raise ValueError(f"Operation {operation} not supported for altering Dynamic Tables")
@@ -263,7 +263,7 @@ class ModelMonitorSQLClient:
263
263
  def suspend_monitor(
264
264
  self,
265
265
  monitor_name: sql_identifier.SqlIdentifier,
266
- statement_params: Optional[Dict[str, Any]] = None,
266
+ statement_params: Optional[dict[str, Any]] = None,
267
267
  ) -> None:
268
268
  self._alter_monitor(
269
269
  operation="SUSPEND",
@@ -274,7 +274,7 @@ class ModelMonitorSQLClient:
274
274
  def resume_monitor(
275
275
  self,
276
276
  monitor_name: sql_identifier.SqlIdentifier,
277
- statement_params: Optional[Dict[str, Any]] = None,
277
+ statement_params: Optional[dict[str, Any]] = None,
278
278
  ) -> None:
279
279
  self._alter_monitor(
280
280
  operation="RESUME",
@@ -1,5 +1,5 @@
1
1
  import json
2
- from typing import Any, Dict, List, Optional
2
+ from typing import Any, Optional
3
3
 
4
4
  from snowflake import snowpark
5
5
  from snowflake.ml._internal.utils import sql_identifier
@@ -20,7 +20,7 @@ class ModelMonitorManager:
20
20
  database_name: sql_identifier.SqlIdentifier,
21
21
  schema_name: sql_identifier.SqlIdentifier,
22
22
  *,
23
- statement_params: Optional[Dict[str, Any]] = None,
23
+ statement_params: Optional[dict[str, Any]] = None,
24
24
  ) -> None:
25
25
  """
26
26
  Opens a ModelMonitorManager for a given database and schema.
@@ -64,7 +64,7 @@ class ModelMonitorManager:
64
64
  f"Found: {existing_target_methods}."
65
65
  )
66
66
 
67
- def _build_column_list_from_input(self, columns: Optional[List[str]]) -> List[sql_identifier.SqlIdentifier]:
67
+ def _build_column_list_from_input(self, columns: Optional[list[str]]) -> list[sql_identifier.SqlIdentifier]:
68
68
  return [sql_identifier.SqlIdentifier(column_name) for column_name in columns] if columns else []
69
69
 
70
70
  def add_monitor(
@@ -172,7 +172,7 @@ class ModelMonitorManager:
172
172
  """
173
173
  rows = self._model_monitor_client.show_model_monitors(statement_params=self.statement_params)
174
174
 
175
- def model_match_fn(model_details: Dict[str, str]) -> bool:
175
+ def model_match_fn(model_details: dict[str, str]) -> bool:
176
176
  return (
177
177
  model_details[model_monitor_sql_client.MODEL_JSON_MODEL_NAME_FIELD] == model_version.model_name
178
178
  and model_details[model_monitor_sql_client.MODEL_JSON_VERSION_NAME_FIELD] == model_version.version_name
@@ -215,7 +215,7 @@ class ModelMonitorManager:
215
215
  name=monitor_name_id,
216
216
  )
217
217
 
218
- def show_model_monitors(self) -> List[snowpark.Row]:
218
+ def show_model_monitors(self) -> list[snowpark.Row]:
219
219
  """Show all model monitors in the registry.
220
220
 
221
221
  Returns:
@@ -1,5 +1,5 @@
1
1
  from dataclasses import dataclass
2
- from typing import List, Optional
2
+ from typing import Optional
3
3
 
4
4
  from snowflake.ml.model._client.model import model_version_impl
5
5
 
@@ -14,20 +14,20 @@ class ModelMonitorSourceConfig:
14
14
  timestamp_column: str
15
15
  """Name of column in the source containing timestamp."""
16
16
 
17
- id_columns: List[str]
17
+ id_columns: list[str]
18
18
  """List of columns in the source containing unique identifiers."""
19
19
 
20
- prediction_score_columns: Optional[List[str]] = None
20
+ prediction_score_columns: Optional[list[str]] = None
21
21
  """List of columns in the source containing prediction scores.
22
22
  Can be regression scores for regression models and probability scores for classification models."""
23
23
 
24
- prediction_class_columns: Optional[List[str]] = None
24
+ prediction_class_columns: Optional[list[str]] = None
25
25
  """List of columns in the source containing prediction classes for classification models."""
26
26
 
27
- actual_score_columns: Optional[List[str]] = None
27
+ actual_score_columns: Optional[list[str]] = None
28
28
  """List of columns in the source containing actual scores."""
29
29
 
30
- actual_class_columns: Optional[List[str]] = None
30
+ actual_class_columns: Optional[list[str]] = None
31
31
  """List of columns in the source containing actual classes for classification models."""
32
32
 
33
33
  baseline: Optional[str] = None
@@ -1,10 +1,10 @@
1
1
  from types import ModuleType
2
- from typing import Any, Dict, List, Optional, Tuple, Union
2
+ from typing import Any, Optional, Union
3
3
 
4
4
  import pandas as pd
5
5
  from absl.logging import logging
6
6
 
7
- from snowflake.ml._internal import platform_capabilities, telemetry
7
+ from snowflake.ml._internal import env, platform_capabilities, telemetry
8
8
  from snowflake.ml._internal.exceptions import error_codes, exceptions
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
10
  from snowflake.ml._internal.utils import sql_identifier
@@ -43,20 +43,21 @@ class ModelManager:
43
43
  model_name: str,
44
44
  version_name: Optional[str] = None,
45
45
  comment: Optional[str] = None,
46
- metrics: Optional[Dict[str, Any]] = None,
47
- conda_dependencies: Optional[List[str]] = None,
48
- pip_requirements: Optional[List[str]] = None,
49
- artifact_repository_map: Optional[Dict[str, str]] = None,
50
- target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
46
+ metrics: Optional[dict[str, Any]] = None,
47
+ conda_dependencies: Optional[list[str]] = None,
48
+ pip_requirements: Optional[list[str]] = None,
49
+ artifact_repository_map: Optional[dict[str, str]] = None,
50
+ resource_constraint: Optional[dict[str, str]] = None,
51
+ target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
51
52
  python_version: Optional[str] = None,
52
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
53
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
53
54
  sample_input_data: Optional[model_types.SupportedDataType] = None,
54
- user_files: Optional[Dict[str, List[str]]] = None,
55
- code_paths: Optional[List[str]] = None,
56
- ext_modules: Optional[List[ModuleType]] = None,
55
+ user_files: Optional[dict[str, list[str]]] = None,
56
+ code_paths: Optional[list[str]] = None,
57
+ ext_modules: Optional[list[ModuleType]] = None,
57
58
  task: model_types.Task = model_types.Task.UNKNOWN,
58
59
  options: Optional[model_types.ModelSaveOption] = None,
59
- statement_params: Optional[Dict[str, Any]] = None,
60
+ statement_params: Optional[dict[str, Any]] = None,
60
61
  ) -> model_version_impl.ModelVersion:
61
62
 
62
63
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -129,6 +130,7 @@ class ModelManager:
129
130
  conda_dependencies=conda_dependencies,
130
131
  pip_requirements=pip_requirements,
131
132
  artifact_repository_map=artifact_repository_map,
133
+ resource_constraint=resource_constraint,
132
134
  target_platforms=target_platforms,
133
135
  python_version=python_version,
134
136
  signatures=signatures,
@@ -148,20 +150,21 @@ class ModelManager:
148
150
  model_name: str,
149
151
  version_name: str,
150
152
  comment: Optional[str] = None,
151
- metrics: Optional[Dict[str, Any]] = None,
152
- conda_dependencies: Optional[List[str]] = None,
153
- pip_requirements: Optional[List[str]] = None,
154
- artifact_repository_map: Optional[Dict[str, str]] = None,
155
- target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
153
+ metrics: Optional[dict[str, Any]] = None,
154
+ conda_dependencies: Optional[list[str]] = None,
155
+ pip_requirements: Optional[list[str]] = None,
156
+ artifact_repository_map: Optional[dict[str, str]] = None,
157
+ resource_constraint: Optional[dict[str, str]] = None,
158
+ target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
156
159
  python_version: Optional[str] = None,
157
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
160
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
158
161
  sample_input_data: Optional[model_types.SupportedDataType] = None,
159
- user_files: Optional[Dict[str, List[str]]] = None,
160
- code_paths: Optional[List[str]] = None,
161
- ext_modules: Optional[List[ModuleType]] = None,
162
+ user_files: Optional[dict[str, list[str]]] = None,
163
+ code_paths: Optional[list[str]] = None,
164
+ ext_modules: Optional[list[ModuleType]] = None,
162
165
  task: model_types.Task = model_types.Task.UNKNOWN,
163
166
  options: Optional[model_types.ModelSaveOption] = None,
164
- statement_params: Optional[Dict[str, Any]] = None,
167
+ statement_params: Optional[dict[str, Any]] = None,
165
168
  ) -> model_version_impl.ModelVersion:
166
169
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
167
170
  version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -208,6 +211,14 @@ class ModelManager:
208
211
  if target_platforms:
209
212
  # Convert any string target platforms to TargetPlatform objects
210
213
  platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
214
+ else:
215
+ # Default the target platform to SPCS if not specified when running in ML runtime
216
+ if env.IN_ML_RUNTIME:
217
+ logger.info(
218
+ "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
219
+ 'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
220
+ )
221
+ platforms = [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
211
222
 
212
223
  if artifact_repository_map:
213
224
  for channel, artifact_repository_name in artifact_repository_map.items():
@@ -223,8 +234,17 @@ class ModelManager:
223
234
 
224
235
  logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
225
236
 
237
+ # Extract save_location from options if present
238
+ save_location = None
239
+ if options and "save_location" in options:
240
+ save_location = options.get("save_location")
241
+ logger.info(f"Model will be saved to local directory: {save_location}")
242
+
226
243
  mc = model_composer.ModelComposer(
227
- self._model_ops._session, stage_path=stage_path, statement_params=statement_params
244
+ self._model_ops._session,
245
+ stage_path=stage_path,
246
+ statement_params=statement_params,
247
+ save_location=save_location,
228
248
  )
229
249
  model_metadata: model_meta.ModelMetadata = mc.save(
230
250
  name=model_name_id.resolved(),
@@ -234,6 +254,7 @@ class ModelManager:
234
254
  conda_dependencies=conda_dependencies,
235
255
  pip_requirements=pip_requirements,
236
256
  artifact_repository_map=artifact_repository_map,
257
+ resource_constraint=resource_constraint,
237
258
  target_platforms=platforms,
238
259
  python_version=python_version,
239
260
  user_files=user_files,
@@ -295,7 +316,7 @@ class ModelManager:
295
316
  self,
296
317
  model_name: str,
297
318
  *,
298
- statement_params: Optional[Dict[str, Any]] = None,
319
+ statement_params: Optional[dict[str, Any]] = None,
299
320
  ) -> model_impl.Model:
300
321
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
301
322
  if self._model_ops.validate_existence(
@@ -323,8 +344,8 @@ class ModelManager:
323
344
  def models(
324
345
  self,
325
346
  *,
326
- statement_params: Optional[Dict[str, Any]] = None,
327
- ) -> List[model_impl.Model]:
347
+ statement_params: Optional[dict[str, Any]] = None,
348
+ ) -> list[model_impl.Model]:
328
349
  model_names = self._model_ops.list_models_or_versions(
329
350
  database_name=None,
330
351
  schema_name=None,
@@ -342,7 +363,7 @@ class ModelManager:
342
363
  def show_models(
343
364
  self,
344
365
  *,
345
- statement_params: Optional[Dict[str, Any]] = None,
366
+ statement_params: Optional[dict[str, Any]] = None,
346
367
  ) -> pd.DataFrame:
347
368
  rows = self._model_ops.show_models_or_versions(
348
369
  database_name=None,
@@ -355,7 +376,7 @@ class ModelManager:
355
376
  self,
356
377
  model_name: str,
357
378
  *,
358
- statement_params: Optional[Dict[str, Any]] = None,
379
+ statement_params: Optional[dict[str, Any]] = None,
359
380
  ) -> None:
360
381
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
361
382
 
@@ -368,7 +389,7 @@ class ModelManager:
368
389
 
369
390
  def _parse_fully_qualified_name(
370
391
  self, model_name: str
371
- ) -> Tuple[
392
+ ) -> tuple[
372
393
  Optional[sql_identifier.SqlIdentifier], Optional[sql_identifier.SqlIdentifier], sql_identifier.SqlIdentifier
373
394
  ]:
374
395
  try:
@@ -1,6 +1,6 @@
1
1
  import warnings
2
2
  from types import ModuleType
3
- from typing import Any, Dict, List, Optional, Union, overload
3
+ from typing import Any, Optional, Union, overload
4
4
 
5
5
  import pandas as pd
6
6
 
@@ -36,7 +36,7 @@ class Registry:
36
36
  *,
37
37
  database_name: Optional[str] = None,
38
38
  schema_name: Optional[str] = None,
39
- options: Optional[Dict[str, Any]] = None,
39
+ options: Optional[dict[str, Any]] = None,
40
40
  ) -> None:
41
41
  """Opens a registry within a pre-created Snowflake schema.
42
42
 
@@ -75,7 +75,9 @@ class Registry:
75
75
  )
76
76
 
77
77
  self._model_manager = model_manager.ModelManager(
78
- session, database_name=self._database_name, schema_name=self._schema_name
78
+ session,
79
+ database_name=self._database_name,
80
+ schema_name=self._schema_name,
79
81
  )
80
82
 
81
83
  self.enable_monitoring = options.get("enable_monitoring", True) if options else True
@@ -105,17 +107,18 @@ class Registry:
105
107
  model_name: str,
106
108
  version_name: Optional[str] = None,
107
109
  comment: Optional[str] = None,
108
- metrics: Optional[Dict[str, Any]] = None,
109
- conda_dependencies: Optional[List[str]] = None,
110
- pip_requirements: Optional[List[str]] = None,
111
- artifact_repository_map: Optional[Dict[str, str]] = None,
112
- target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
110
+ metrics: Optional[dict[str, Any]] = None,
111
+ conda_dependencies: Optional[list[str]] = None,
112
+ pip_requirements: Optional[list[str]] = None,
113
+ artifact_repository_map: Optional[dict[str, str]] = None,
114
+ resource_constraint: Optional[dict[str, str]] = None,
115
+ target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
113
116
  python_version: Optional[str] = None,
114
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
117
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
115
118
  sample_input_data: Optional[model_types.SupportedDataType] = None,
116
- user_files: Optional[Dict[str, List[str]]] = None,
117
- code_paths: Optional[List[str]] = None,
118
- ext_modules: Optional[List[ModuleType]] = None,
119
+ user_files: Optional[dict[str, list[str]]] = None,
120
+ code_paths: Optional[list[str]] = None,
121
+ ext_modules: Optional[list[ModuleType]] = None,
119
122
  task: model_types.Task = model_types.Task.UNKNOWN,
120
123
  options: Optional[model_types.ModelSaveOption] = None,
121
124
  ) -> ModelVersion:
@@ -150,6 +153,7 @@ class Registry:
150
153
  Format: {channel_name: artifact_repository_name}, where:
151
154
  - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
152
155
  - artifact_repository_name: The name or URL of the repository to fetch packages from.
156
+ resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
153
157
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
154
158
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
155
159
  python_version: Python version in which the model is run. Defaults to None.
@@ -181,6 +185,7 @@ class Registry:
181
185
  - target_methods: List of target methods to register when logging the model.
182
186
  This option is not used in MLFlow models. Defaults to None, in which case the model handler's
183
187
  default target methods will be used.
188
+ - save_location: Location to save the model and metadata.
184
189
  - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
185
190
  values with the desired options.
186
191
 
@@ -229,6 +234,7 @@ class Registry:
229
234
  "conda_dependencies",
230
235
  "pip_requirements",
231
236
  "artifact_repository_map",
237
+ "resource_constraint",
232
238
  "target_platforms",
233
239
  "python_version",
234
240
  "signatures",
@@ -241,17 +247,18 @@ class Registry:
241
247
  model_name: str,
242
248
  version_name: Optional[str] = None,
243
249
  comment: Optional[str] = None,
244
- metrics: Optional[Dict[str, Any]] = None,
245
- conda_dependencies: Optional[List[str]] = None,
246
- pip_requirements: Optional[List[str]] = None,
247
- artifact_repository_map: Optional[Dict[str, str]] = None,
248
- target_platforms: Optional[List[model_types.SupportedTargetPlatformType]] = None,
250
+ metrics: Optional[dict[str, Any]] = None,
251
+ conda_dependencies: Optional[list[str]] = None,
252
+ pip_requirements: Optional[list[str]] = None,
253
+ artifact_repository_map: Optional[dict[str, str]] = None,
254
+ resource_constraint: Optional[dict[str, str]] = None,
255
+ target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
249
256
  python_version: Optional[str] = None,
250
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
257
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
251
258
  sample_input_data: Optional[model_types.SupportedDataType] = None,
252
- user_files: Optional[Dict[str, List[str]]] = None,
253
- code_paths: Optional[List[str]] = None,
254
- ext_modules: Optional[List[ModuleType]] = None,
259
+ user_files: Optional[dict[str, list[str]]] = None,
260
+ code_paths: Optional[list[str]] = None,
261
+ ext_modules: Optional[list[ModuleType]] = None,
255
262
  task: model_types.Task = model_types.Task.UNKNOWN,
256
263
  options: Optional[model_types.ModelSaveOption] = None,
257
264
  ) -> ModelVersion:
@@ -286,6 +293,7 @@ class Registry:
286
293
  Format: {channel_name: artifact_repository_name}, where:
287
294
  - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
288
295
  - artifact_repository_name: The name or URL of the repository to fetch packages from.
296
+ resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
289
297
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
290
298
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
291
299
  python_version: Python version in which the model is run. Defaults to None.
@@ -317,6 +325,7 @@ class Registry:
317
325
  - target_methods: List of target methods to register when logging the model.
318
326
  This option is not used in MLFlow models. Defaults to None, in which case the model handler's
319
327
  default target methods will be used.
328
+ - save_location: Location to save the model and metadata.
320
329
  - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
321
330
  values with the desired options. See the example below.
322
331
 
@@ -369,6 +378,7 @@ class Registry:
369
378
  conda_dependencies,
370
379
  pip_requirements,
371
380
  artifact_repository_map,
381
+ resource_constraint,
372
382
  target_platforms,
373
383
  python_version,
374
384
  signatures,
@@ -403,6 +413,7 @@ class Registry:
403
413
  conda_dependencies=conda_dependencies,
404
414
  pip_requirements=pip_requirements,
405
415
  artifact_repository_map=artifact_repository_map,
416
+ resource_constraint=resource_constraint,
406
417
  target_platforms=target_platforms,
407
418
  python_version=python_version,
408
419
  signatures=signatures,
@@ -438,7 +449,7 @@ class Registry:
438
449
  project=_TELEMETRY_PROJECT,
439
450
  subproject=_MODEL_TELEMETRY_SUBPROJECT,
440
451
  )
441
- def models(self) -> List[Model]:
452
+ def models(self) -> list[Model]:
442
453
  """Get all models in the schema where the registry is opened.
443
454
 
444
455
  Returns:
@@ -564,7 +575,7 @@ class Registry:
564
575
  subproject=telemetry.TelemetrySubProject.MONITORING.value,
565
576
  )
566
577
  @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
567
- def show_model_monitors(self) -> List[snowpark.Row]:
578
+ def show_model_monitors(self) -> list[snowpark.Row]:
568
579
  """Show all model monitors in the registry.
569
580
 
570
581
  Returns:
@@ -1,7 +1,7 @@
1
1
  import http
2
2
  import logging
3
3
  from datetime import timedelta
4
- from typing import Dict, Optional
4
+ from typing import Optional
5
5
 
6
6
  import requests
7
7
  from cryptography.hazmat.primitives.asymmetric import types
@@ -10,7 +10,7 @@ from requests import auth
10
10
  from snowflake.ml._internal.utils import jwt_generator
11
11
 
12
12
  logger = logging.getLogger(__name__)
13
- _JWT_TOKEN_CACHE: Dict[str, Dict[int, str]] = {}
13
+ _JWT_TOKEN_CACHE: dict[str, dict[int, str]] = {}
14
14
 
15
15
 
16
16
  def get_jwt_token_generator(
@@ -1,6 +1,6 @@
1
1
  import configparser
2
2
  import os
3
- from typing import Dict, Optional, Union
3
+ from typing import Optional, Union
4
4
 
5
5
  from absl import logging
6
6
  from cryptography.hazmat import backends
@@ -76,7 +76,7 @@ def _load_pem_to_der(private_key_path: str) -> bytes:
76
76
  )
77
77
 
78
78
 
79
- def _connection_properties_from_env() -> Dict[str, str]:
79
+ def _connection_properties_from_env() -> dict[str, str]:
80
80
  """Returns a dict with all possible login related env variables."""
81
81
  sf_conn_prop = {
82
82
  # Mandatory fields
@@ -104,7 +104,7 @@ def _connection_properties_from_env() -> Dict[str, str]:
104
104
  return sf_conn_prop
105
105
 
106
106
 
107
- def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> Dict[str, str]:
107
+ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -> dict[str, str]:
108
108
  """Loads the dictionary from snowsql config file."""
109
109
  snowsql_config_file = login_file if login_file else os.path.expanduser(_DEFAULT_CONNECTION_FILE)
110
110
  if not os.path.exists(snowsql_config_file):
@@ -133,7 +133,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
133
133
 
134
134
 
135
135
  @snowpark._internal.utils.private_preview(version="0.2.0")
136
- def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> Dict[str, Union[str, bytes]]:
136
+ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
137
137
  """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
138
138
 
139
139
  NOTE: Token/Auth information is sideloaded in all cases above, if provided in following order:
@@ -164,7 +164,7 @@ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] =
164
164
  Raises:
165
165
  Exception: if none of config file and environment variable are present.
166
166
  """
167
- conn_prop: Dict[str, Union[str, bytes]] = {}
167
+ conn_prop: dict[str, Union[str, bytes]] = {}
168
168
  login_file = login_file or os.path.expanduser(_DEFAULT_CONNECTION_FILE)
169
169
  # If login file exists, use this exclusively.
170
170
  if os.path.exists(login_file):
@@ -1,6 +1,6 @@
1
1
  import collections
2
2
  import json
3
- from typing import List, Optional
3
+ from typing import Optional
4
4
 
5
5
  import pandas as pd
6
6
  from pandas import arrays as pandas_arrays
@@ -9,7 +9,7 @@ from pandas.core.arrays import sparse as pandas_sparse
9
9
  from snowflake.snowpark import DataFrame
10
10
 
11
11
 
12
- def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: List[str]) -> Optional[pd.DataFrame]:
12
+ def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: list[str]) -> Optional[pd.DataFrame]:
13
13
  """Convert the pandas df into pandas df with multiple SparseArray columns."""
14
14
  num_rows = pandas_df.shape[0]
15
15
  if num_rows == 0:
@@ -52,8 +52,9 @@ def _pandas_to_sparse_pandas(pandas_df: pd.DataFrame, sparse_cols: List[str]) ->
52
52
  return pandas_df
53
53
 
54
54
 
55
- def to_pandas_with_sparse(df: DataFrame, sparse_cols: List[str]) -> pd.DataFrame:
56
- """Load a Snowpark df with sparse columns represented in JSON strings into pandas df with multiple SparseArray columns.
55
+ def to_pandas_with_sparse(df: DataFrame, sparse_cols: list[str]) -> pd.DataFrame:
56
+ """Load a Snowpark df with sparse columns represented in JSON strings into pandas df with multiple SparseArray
57
+ columns.
57
58
 
58
59
  For example, for below input:
59
60
  ----------------------------------------------
@@ -1,5 +1,4 @@
1
1
  from enum import Enum
2
- from typing import Dict
3
2
 
4
3
 
5
4
  class CreationOption(Enum):
@@ -13,7 +12,7 @@ class CreationMode:
13
12
  self.if_not_exists = if_not_exists
14
13
  self.or_replace = or_replace
15
14
 
16
- def get_ddl_phrases(self) -> Dict[CreationOption, str]:
15
+ def get_ddl_phrases(self) -> dict[CreationOption, str]:
17
16
  if_not_exists_sql = " IF NOT EXISTS" if self.if_not_exists else ""
18
17
  or_replace_sql = " OR REPLACE" if self.or_replace else ""
19
18
  return {
snowflake/ml/version.py CHANGED
@@ -1 +1,2 @@
1
- VERSION="1.8.1"
1
+ # This is parsed by regex in conda recipe meta file. Make sure not to break it.
2
+ VERSION = "1.8.3"