snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
@@ -1,33 +1,53 @@
1
- from typing import List, TypedDict
1
+ from typing import Optional
2
2
 
3
- from typing_extensions import NotRequired, Required
3
+ from pydantic import BaseModel
4
4
 
5
5
 
6
- class ModelDict(TypedDict):
7
- name: Required[str]
8
- version: Required[str]
6
+ class Model(BaseModel):
7
+ name: str
8
+ version: str
9
9
 
10
10
 
11
- class ImageBuildDict(TypedDict):
12
- compute_pool: Required[str]
13
- image_repo: Required[str]
14
- force_rebuild: Required[bool]
15
- external_access_integrations: NotRequired[List[str]]
11
+ class ImageBuild(BaseModel):
12
+ compute_pool: str
13
+ image_repo: str
14
+ force_rebuild: bool
15
+ external_access_integrations: Optional[list[str]] = None
16
16
 
17
17
 
18
- class ServiceDict(TypedDict):
19
- name: Required[str]
20
- compute_pool: Required[str]
21
- ingress_enabled: Required[bool]
22
- max_instances: Required[int]
23
- cpu: NotRequired[str]
24
- memory: NotRequired[str]
25
- gpu: NotRequired[str]
26
- num_workers: NotRequired[int]
27
- max_batch_rows: NotRequired[int]
18
+ class Service(BaseModel):
19
+ name: str
20
+ compute_pool: str
21
+ ingress_enabled: bool
22
+ max_instances: int
23
+ cpu: Optional[str] = None
24
+ memory: Optional[str] = None
25
+ gpu: Optional[str] = None
26
+ num_workers: Optional[int] = None
27
+ max_batch_rows: Optional[int] = None
28
28
 
29
29
 
30
- class ModelDeploymentSpecDict(TypedDict):
31
- models: Required[List[ModelDict]]
32
- image_build: Required[ImageBuildDict]
33
- service: Required[ServiceDict]
30
+ class Job(BaseModel):
31
+ name: str
32
+ compute_pool: str
33
+ cpu: Optional[str] = None
34
+ memory: Optional[str] = None
35
+ gpu: Optional[str] = None
36
+ num_workers: Optional[int] = None
37
+ max_batch_rows: Optional[int] = None
38
+ warehouse: str
39
+ target_method: str
40
+ input_table_name: str
41
+ output_table_name: str
42
+
43
+
44
+ class ModelServiceDeploymentSpec(BaseModel):
45
+ models: list[Model]
46
+ image_build: ImageBuild
47
+ service: Service
48
+
49
+
50
+ class ModelJobDeploymentSpec(BaseModel):
51
+ models: list[Model]
52
+ image_build: ImageBuild
53
+ job: Job
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
@@ -24,8 +24,8 @@ class ModelSQLClient(_base._BaseSQLClient):
24
24
  schema_name: Optional[sql_identifier.SqlIdentifier],
25
25
  model_name: Optional[sql_identifier.SqlIdentifier] = None,
26
26
  validate_result: bool = True,
27
- statement_params: Optional[Dict[str, Any]] = None,
28
- ) -> List[row.Row]:
27
+ statement_params: Optional[dict[str, Any]] = None,
28
+ ) -> list[row.Row]:
29
29
  actual_database_name = database_name or self._database_name
30
30
  actual_schema_name = schema_name or self._schema_name
31
31
  fully_qualified_schema_name = ".".join([actual_database_name.identifier(), actual_schema_name.identifier()])
@@ -57,8 +57,8 @@ class ModelSQLClient(_base._BaseSQLClient):
57
57
  version_name: Optional[sql_identifier.SqlIdentifier] = None,
58
58
  validate_result: bool = True,
59
59
  check_model_details: bool = False,
60
- statement_params: Optional[Dict[str, Any]] = None,
61
- ) -> List[row.Row]:
60
+ statement_params: Optional[dict[str, Any]] = None,
61
+ ) -> list[row.Row]:
62
62
  like_sql = ""
63
63
  if version_name:
64
64
  like_sql = f" LIKE '{version_name.resolved()}'"
@@ -90,7 +90,7 @@ class ModelSQLClient(_base._BaseSQLClient):
90
90
  schema_name: Optional[sql_identifier.SqlIdentifier],
91
91
  model_name: sql_identifier.SqlIdentifier,
92
92
  comment: str,
93
- statement_params: Optional[Dict[str, Any]] = None,
93
+ statement_params: Optional[dict[str, Any]] = None,
94
94
  ) -> None:
95
95
  query_result_checker.SqlResultValidator(
96
96
  self._session,
@@ -107,7 +107,7 @@ class ModelSQLClient(_base._BaseSQLClient):
107
107
  database_name: Optional[sql_identifier.SqlIdentifier],
108
108
  schema_name: Optional[sql_identifier.SqlIdentifier],
109
109
  model_name: sql_identifier.SqlIdentifier,
110
- statement_params: Optional[Dict[str, Any]] = None,
110
+ statement_params: Optional[dict[str, Any]] = None,
111
111
  ) -> None:
112
112
  query_result_checker.SqlResultValidator(
113
113
  self._session,
@@ -124,7 +124,7 @@ class ModelSQLClient(_base._BaseSQLClient):
124
124
  new_model_db: Optional[sql_identifier.SqlIdentifier],
125
125
  new_model_schema: Optional[sql_identifier.SqlIdentifier],
126
126
  new_model_name: sql_identifier.SqlIdentifier,
127
- statement_params: Optional[Dict[str, Any]] = None,
127
+ statement_params: Optional[dict[str, Any]] = None,
128
128
  ) -> None:
129
129
  # Use registry's database and schema if a non fully qualified new model name is provided.
130
130
  new_fully_qualified_name = self.fully_qualified_object_name(new_model_db, new_model_schema, new_model_name)
@@ -1,7 +1,7 @@
1
1
  import json
2
2
  import pathlib
3
3
  import textwrap
4
- from typing import Any, Dict, List, Optional, Tuple
4
+ from typing import Any, Optional
5
5
  from urllib.parse import ParseResult
6
6
 
7
7
  from snowflake.ml._internal.utils import (
@@ -34,7 +34,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
34
34
  model_name: sql_identifier.SqlIdentifier,
35
35
  version_name: sql_identifier.SqlIdentifier,
36
36
  stage_path: str,
37
- statement_params: Optional[Dict[str, Any]] = None,
37
+ statement_params: Optional[dict[str, Any]] = None,
38
38
  ) -> None:
39
39
  query_result_checker.SqlResultValidator(
40
40
  self._session,
@@ -56,7 +56,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
56
56
  schema_name: Optional[sql_identifier.SqlIdentifier],
57
57
  model_name: sql_identifier.SqlIdentifier,
58
58
  version_name: sql_identifier.SqlIdentifier,
59
- statement_params: Optional[Dict[str, Any]] = None,
59
+ statement_params: Optional[dict[str, Any]] = None,
60
60
  ) -> None:
61
61
  fq_source_model_name = self.fully_qualified_object_name(
62
62
  source_database_name, source_schema_name, source_model_name
@@ -78,7 +78,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
78
78
  schema_name: Optional[sql_identifier.SqlIdentifier],
79
79
  model_name: sql_identifier.SqlIdentifier,
80
80
  version_name: sql_identifier.SqlIdentifier,
81
- statement_params: Optional[Dict[str, Any]] = None,
81
+ statement_params: Optional[dict[str, Any]] = None,
82
82
  ) -> None:
83
83
  sql = (
84
84
  f"CREATE MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
@@ -97,7 +97,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
97
97
  schema_name: Optional[sql_identifier.SqlIdentifier],
98
98
  model_name: sql_identifier.SqlIdentifier,
99
99
  version_name: sql_identifier.SqlIdentifier,
100
- statement_params: Optional[Dict[str, Any]] = None,
100
+ statement_params: Optional[dict[str, Any]] = None,
101
101
  ) -> None:
102
102
  sql = (
103
103
  f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
@@ -116,7 +116,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
116
116
  schema_name: Optional[sql_identifier.SqlIdentifier],
117
117
  model_name: sql_identifier.SqlIdentifier,
118
118
  version_name: sql_identifier.SqlIdentifier,
119
- statement_params: Optional[Dict[str, Any]] = None,
119
+ statement_params: Optional[dict[str, Any]] = None,
120
120
  ) -> None:
121
121
  sql = (
122
122
  f"ALTER MODEL {self.fully_qualified_object_name(database_name, schema_name, model_name)}"
@@ -138,7 +138,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
138
138
  model_name: sql_identifier.SqlIdentifier,
139
139
  version_name: sql_identifier.SqlIdentifier,
140
140
  stage_path: str,
141
- statement_params: Optional[Dict[str, Any]] = None,
141
+ statement_params: Optional[dict[str, Any]] = None,
142
142
  ) -> None:
143
143
  query_result_checker.SqlResultValidator(
144
144
  self._session,
@@ -160,7 +160,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
160
160
  schema_name: Optional[sql_identifier.SqlIdentifier],
161
161
  model_name: sql_identifier.SqlIdentifier,
162
162
  version_name: sql_identifier.SqlIdentifier,
163
- statement_params: Optional[Dict[str, Any]] = None,
163
+ statement_params: Optional[dict[str, Any]] = None,
164
164
  ) -> None:
165
165
  fq_source_model_name = self.fully_qualified_object_name(
166
166
  source_database_name, source_schema_name, source_model_name
@@ -182,7 +182,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
182
182
  schema_name: Optional[sql_identifier.SqlIdentifier],
183
183
  model_name: sql_identifier.SqlIdentifier,
184
184
  version_name: sql_identifier.SqlIdentifier,
185
- statement_params: Optional[Dict[str, Any]] = None,
185
+ statement_params: Optional[dict[str, Any]] = None,
186
186
  ) -> None:
187
187
  query_result_checker.SqlResultValidator(
188
188
  self._session,
@@ -201,7 +201,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
201
201
  model_name: sql_identifier.SqlIdentifier,
202
202
  version_name: sql_identifier.SqlIdentifier,
203
203
  alias_name: sql_identifier.SqlIdentifier,
204
- statement_params: Optional[Dict[str, Any]] = None,
204
+ statement_params: Optional[dict[str, Any]] = None,
205
205
  ) -> None:
206
206
  query_result_checker.SqlResultValidator(
207
207
  self._session,
@@ -219,7 +219,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
219
219
  schema_name: Optional[sql_identifier.SqlIdentifier],
220
220
  model_name: sql_identifier.SqlIdentifier,
221
221
  version_or_alias_name: sql_identifier.SqlIdentifier,
222
- statement_params: Optional[Dict[str, Any]] = None,
222
+ statement_params: Optional[dict[str, Any]] = None,
223
223
  ) -> None:
224
224
  query_result_checker.SqlResultValidator(
225
225
  self._session,
@@ -239,8 +239,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
239
239
  version_name: sql_identifier.SqlIdentifier,
240
240
  file_path: pathlib.PurePosixPath,
241
241
  is_dir: bool = False,
242
- statement_params: Optional[Dict[str, Any]] = None,
243
- ) -> List[row.Row]:
242
+ statement_params: Optional[dict[str, Any]] = None,
243
+ ) -> list[row.Row]:
244
244
  # Workaround for snowURL bug.
245
245
  trailing_slash = "/" if is_dir else ""
246
246
 
@@ -276,7 +276,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
276
276
  version_name: sql_identifier.SqlIdentifier,
277
277
  file_path: pathlib.PurePosixPath,
278
278
  target_path: pathlib.Path,
279
- statement_params: Optional[Dict[str, Any]] = None,
279
+ statement_params: Optional[dict[str, Any]] = None,
280
280
  ) -> pathlib.Path:
281
281
  stage_location = pathlib.PurePosixPath(
282
282
  self.fully_qualified_object_name(database_name, schema_name, model_name),
@@ -310,8 +310,8 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
310
310
  schema_name: Optional[sql_identifier.SqlIdentifier],
311
311
  model_name: sql_identifier.SqlIdentifier,
312
312
  version_name: sql_identifier.SqlIdentifier,
313
- statement_params: Optional[Dict[str, Any]] = None,
314
- ) -> List[row.Row]:
313
+ statement_params: Optional[dict[str, Any]] = None,
314
+ ) -> list[row.Row]:
315
315
  res = query_result_checker.SqlResultValidator(
316
316
  self._session,
317
317
  (
@@ -331,7 +331,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
331
331
  model_name: sql_identifier.SqlIdentifier,
332
332
  version_name: sql_identifier.SqlIdentifier,
333
333
  comment: str,
334
- statement_params: Optional[Dict[str, Any]] = None,
334
+ statement_params: Optional[dict[str, Any]] = None,
335
335
  ) -> None:
336
336
  query_result_checker.SqlResultValidator(
337
337
  self._session,
@@ -351,9 +351,9 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
351
351
  version_name: sql_identifier.SqlIdentifier,
352
352
  method_name: sql_identifier.SqlIdentifier,
353
353
  input_df: dataframe.DataFrame,
354
- input_args: List[sql_identifier.SqlIdentifier],
355
- returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
356
- statement_params: Optional[Dict[str, Any]] = None,
354
+ input_args: list[sql_identifier.SqlIdentifier],
355
+ returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
356
+ statement_params: Optional[dict[str, Any]] = None,
357
357
  ) -> dataframe.DataFrame:
358
358
  with_statements = []
359
359
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
@@ -433,10 +433,10 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
433
433
  version_name: sql_identifier.SqlIdentifier,
434
434
  method_name: sql_identifier.SqlIdentifier,
435
435
  input_df: dataframe.DataFrame,
436
- input_args: List[sql_identifier.SqlIdentifier],
437
- returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
436
+ input_args: list[sql_identifier.SqlIdentifier],
437
+ returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
438
438
  partition_column: Optional[sql_identifier.SqlIdentifier],
439
- statement_params: Optional[Dict[str, Any]] = None,
439
+ statement_params: Optional[dict[str, Any]] = None,
440
440
  is_partitioned: bool = True,
441
441
  ) -> dataframe.DataFrame:
442
442
  with_statements = []
@@ -529,13 +529,13 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
529
529
 
530
530
  def set_metadata(
531
531
  self,
532
- metadata_dict: Dict[str, Any],
532
+ metadata_dict: dict[str, Any],
533
533
  *,
534
534
  database_name: Optional[sql_identifier.SqlIdentifier],
535
535
  schema_name: Optional[sql_identifier.SqlIdentifier],
536
536
  model_name: sql_identifier.SqlIdentifier,
537
537
  version_name: sql_identifier.SqlIdentifier,
538
- statement_params: Optional[Dict[str, Any]] = None,
538
+ statement_params: Optional[dict[str, Any]] = None,
539
539
  ) -> None:
540
540
  json_metadata = json.dumps(metadata_dict)
541
541
  query_result_checker.SqlResultValidator(
@@ -554,7 +554,7 @@ class ModelVersionSQLClient(_base._BaseSQLClient):
554
554
  schema_name: Optional[sql_identifier.SqlIdentifier],
555
555
  model_name: sql_identifier.SqlIdentifier,
556
556
  version_name: sql_identifier.SqlIdentifier,
557
- statement_params: Optional[Dict[str, Any]] = None,
557
+ statement_params: Optional[dict[str, Any]] = None,
558
558
  ) -> None:
559
559
  query_result_checker.SqlResultValidator(
560
560
  self._session,
@@ -1,7 +1,7 @@
1
1
  import enum
2
2
  import json
3
3
  import textwrap
4
- from typing import Any, Dict, List, Optional, Tuple, Union
4
+ from typing import Any, Optional, Union
5
5
 
6
6
  from snowflake import snowpark
7
7
  from snowflake.ml._internal import platform_capabilities
@@ -47,7 +47,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
47
47
  gpu: Optional[Union[str, int]],
48
48
  force_rebuild: bool,
49
49
  external_access_integration: sql_identifier.SqlIdentifier,
50
- statement_params: Optional[Dict[str, Any]] = None,
50
+ statement_params: Optional[dict[str, Any]] = None,
51
51
  ) -> None:
52
52
  actual_image_repo_database = image_repo_database_name or self._database_name
53
53
  actual_image_repo_schema = image_repo_schema_name or self._schema_name
@@ -73,13 +73,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
73
73
  def deploy_model(
74
74
  self,
75
75
  *,
76
- stage_path: str,
77
- model_deployment_spec_file_rel_path: str,
78
- statement_params: Optional[Dict[str, Any]] = None,
79
- ) -> Tuple[str, snowpark.AsyncJob]:
80
- async_job = self._session.sql(
81
- f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
82
- ).collect(block=False, statement_params=statement_params)
76
+ stage_path: Optional[str] = None,
77
+ model_deployment_spec_yaml_str: Optional[str] = None,
78
+ model_deployment_spec_file_rel_path: Optional[str] = None,
79
+ statement_params: Optional[dict[str, Any]] = None,
80
+ ) -> tuple[str, snowpark.AsyncJob]:
81
+ assert model_deployment_spec_yaml_str or model_deployment_spec_file_rel_path
82
+ if model_deployment_spec_yaml_str:
83
+ sql_str = f"CALL SYSTEM$DEPLOY_MODEL('{model_deployment_spec_yaml_str}')"
84
+ else:
85
+ sql_str = f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
86
+ async_job = self._session.sql(sql_str).collect(block=False, statement_params=statement_params)
83
87
  assert isinstance(async_job, snowpark.AsyncJob)
84
88
  return async_job.query_id, async_job
85
89
 
@@ -91,9 +95,9 @@ class ServiceSQLClient(_base._BaseSQLClient):
91
95
  service_name: sql_identifier.SqlIdentifier,
92
96
  method_name: sql_identifier.SqlIdentifier,
93
97
  input_df: dataframe.DataFrame,
94
- input_args: List[sql_identifier.SqlIdentifier],
95
- returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
96
- statement_params: Optional[Dict[str, Any]] = None,
98
+ input_args: list[sql_identifier.SqlIdentifier],
99
+ returns: list[tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
100
+ statement_params: Optional[dict[str, Any]] = None,
97
101
  ) -> dataframe.DataFrame:
98
102
  with_statements = []
99
103
  actual_database_name = database_name or self._database_name
@@ -177,7 +181,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
177
181
  service_name: sql_identifier.SqlIdentifier,
178
182
  instance_id: str = "0",
179
183
  container_name: str,
180
- statement_params: Optional[Dict[str, Any]] = None,
184
+ statement_params: Optional[dict[str, Any]] = None,
181
185
  ) -> str:
182
186
  system_func = "SYSTEM$GET_SERVICE_LOGS"
183
187
  rows = (
@@ -202,8 +206,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
202
206
  schema_name: Optional[sql_identifier.SqlIdentifier],
203
207
  service_name: sql_identifier.SqlIdentifier,
204
208
  include_message: bool = False,
205
- statement_params: Optional[Dict[str, Any]] = None,
206
- ) -> Tuple[ServiceStatus, Optional[str]]:
209
+ statement_params: Optional[dict[str, Any]] = None,
210
+ ) -> tuple[ServiceStatus, Optional[str]]:
207
211
  system_func = "SYSTEM$GET_SERVICE_STATUS"
208
212
  rows = (
209
213
  query_result_checker.SqlResultValidator(
@@ -227,7 +231,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
227
231
  database_name: Optional[sql_identifier.SqlIdentifier],
228
232
  schema_name: Optional[sql_identifier.SqlIdentifier],
229
233
  service_name: sql_identifier.SqlIdentifier,
230
- statement_params: Optional[Dict[str, Any]] = None,
234
+ statement_params: Optional[dict[str, Any]] = None,
231
235
  ) -> None:
232
236
  query_result_checker.SqlResultValidator(
233
237
  self._session,
@@ -241,8 +245,8 @@ class ServiceSQLClient(_base._BaseSQLClient):
241
245
  database_name: Optional[sql_identifier.SqlIdentifier],
242
246
  schema_name: Optional[sql_identifier.SqlIdentifier],
243
247
  service_name: sql_identifier.SqlIdentifier,
244
- statement_params: Optional[Dict[str, Any]] = None,
245
- ) -> List[row.Row]:
248
+ statement_params: Optional[dict[str, Any]] = None,
249
+ ) -> list[row.Row]:
246
250
  fully_qualified_service_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
247
251
  res = (
248
252
  query_result_checker.SqlResultValidator(
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
@@ -11,7 +11,7 @@ class StageSQLClient(_base._BaseSQLClient):
11
11
  database_name: Optional[sql_identifier.SqlIdentifier],
12
12
  schema_name: Optional[sql_identifier.SqlIdentifier],
13
13
  stage_name: sql_identifier.SqlIdentifier,
14
- statement_params: Optional[Dict[str, Any]] = None,
14
+ statement_params: Optional[dict[str, Any]] = None,
15
15
  ) -> None:
16
16
  query_result_checker.SqlResultValidator(
17
17
  self._session,
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Optional
1
+ from typing import Any, Optional
2
2
 
3
3
  from snowflake.ml._internal.utils import query_result_checker, sql_identifier
4
4
  from snowflake.ml.model._client.sql import _base
@@ -16,7 +16,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
16
16
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
17
17
  tag_name: sql_identifier.SqlIdentifier,
18
18
  tag_value: str,
19
- statement_params: Optional[Dict[str, Any]] = None,
19
+ statement_params: Optional[dict[str, Any]] = None,
20
20
  ) -> None:
21
21
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
22
22
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -35,7 +35,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
35
35
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
36
36
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
37
37
  tag_name: sql_identifier.SqlIdentifier,
38
- statement_params: Optional[Dict[str, Any]] = None,
38
+ statement_params: Optional[dict[str, Any]] = None,
39
39
  ) -> None:
40
40
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
41
41
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -54,7 +54,7 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
54
54
  tag_database_name: Optional[sql_identifier.SqlIdentifier],
55
55
  tag_schema_name: Optional[sql_identifier.SqlIdentifier],
56
56
  tag_name: sql_identifier.SqlIdentifier,
57
- statement_params: Optional[Dict[str, Any]] = None,
57
+ statement_params: Optional[dict[str, Any]] = None,
58
58
  ) -> row.Row:
59
59
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
60
60
  fq_tag_name = self.fully_qualified_object_name(tag_database_name, tag_schema_name, tag_name)
@@ -75,8 +75,8 @@ class ModuleTagSQLClient(_base._BaseSQLClient):
75
75
  database_name: Optional[sql_identifier.SqlIdentifier],
76
76
  schema_name: Optional[sql_identifier.SqlIdentifier],
77
77
  model_name: sql_identifier.SqlIdentifier,
78
- statement_params: Optional[Dict[str, Any]] = None,
79
- ) -> List[row.Row]:
78
+ statement_params: Optional[dict[str, Any]] = None,
79
+ ) -> list[row.Row]:
80
80
  fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
81
81
  actual_database_name = database_name or self._database_name
82
82
  return (
@@ -3,13 +3,14 @@ import tempfile
3
3
  import uuid
4
4
  import warnings
5
5
  from types import ModuleType
6
- from typing import Any, Dict, List, Optional, Union
6
+ from typing import Any, Optional, Union
7
7
  from urllib import parse
8
8
 
9
9
  from absl import logging
10
10
  from packaging import requirements
11
11
 
12
12
  from snowflake import snowpark
13
+ from snowflake.ml import version as snowml_version
13
14
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
14
15
  from snowflake.ml._internal.lineage import lineage_utils
15
16
  from snowflake.ml.data import data_source
@@ -43,7 +44,8 @@ class ModelComposer:
43
44
  session: Session,
44
45
  stage_path: str,
45
46
  *,
46
- statement_params: Optional[Dict[str, Any]] = None,
47
+ statement_params: Optional[dict[str, Any]] = None,
48
+ save_location: Optional[str] = None,
47
49
  ) -> None:
48
50
  self.session = session
49
51
  self.stage_path: Union[pathlib.PurePosixPath, parse.ParseResult] = None # type: ignore[assignment]
@@ -54,10 +56,29 @@ class ModelComposer:
54
56
  # The stage path is a user stage path
55
57
  self.stage_path = pathlib.PurePosixPath(stage_path)
56
58
 
57
- self._workspace = tempfile.TemporaryDirectory()
58
- self._packager_workspace = tempfile.TemporaryDirectory()
59
+ # Set up workspace based on save_location if provided, otherwise use temporary directory
60
+ self.save_location = save_location
61
+ if save_location:
62
+ # Use the save_location directory directly
63
+ self._workspace_path = pathlib.Path(save_location)
64
+ self._workspace_path.mkdir(exist_ok=True)
65
+ # ensure that the directory is empty
66
+ if any(self._workspace_path.iterdir()):
67
+ raise ValueError(f"The directory {self._workspace_path} is not empty.")
68
+ self._workspace = None
69
+
70
+ self._packager_workspace_path = self._workspace_path / ModelComposer.MODEL_DIR_REL_PATH
71
+ self._packager_workspace_path.mkdir(exist_ok=True)
72
+ self._packager_workspace = None
73
+ else:
74
+ # Use a temporary directory
75
+ self._workspace = tempfile.TemporaryDirectory()
76
+ self._workspace_path = pathlib.Path(self._workspace.name)
77
+
78
+ self._packager_workspace_path = self._workspace_path / ModelComposer.MODEL_DIR_REL_PATH
79
+ self._packager_workspace_path.mkdir(exist_ok=True)
59
80
 
60
- self.packager = model_packager.ModelPackager(local_dir_path=str(self._packager_workspace_path))
81
+ self.packager = model_packager.ModelPackager(local_dir_path=str(self.packager_workspace_path))
61
82
  self.manifest = model_manifest.ModelManifest(workspace_path=self.workspace_path)
62
83
 
63
84
  self.model_file_rel_path = f"model-{uuid.uuid4().hex}.zip"
@@ -65,16 +86,16 @@ class ModelComposer:
65
86
  self._statement_params = statement_params
66
87
 
67
88
  def __del__(self) -> None:
68
- self._workspace.cleanup()
69
- self._packager_workspace.cleanup()
89
+ if self._workspace:
90
+ self._workspace.cleanup()
70
91
 
71
92
  @property
72
93
  def workspace_path(self) -> pathlib.Path:
73
- return pathlib.Path(self._workspace.name)
94
+ return self._workspace_path
74
95
 
75
96
  @property
76
- def _packager_workspace_path(self) -> pathlib.Path:
77
- return pathlib.Path(self._packager_workspace.name)
97
+ def packager_workspace_path(self) -> pathlib.Path:
98
+ return self._packager_workspace_path
78
99
 
79
100
  @property
80
101
  def model_stage_path(self) -> str:
@@ -102,17 +123,18 @@ class ModelComposer:
102
123
  *,
103
124
  name: str,
104
125
  model: model_types.SupportedModelType,
105
- signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
126
+ signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
106
127
  sample_input_data: Optional[model_types.SupportedDataType] = None,
107
- metadata: Optional[Dict[str, str]] = None,
108
- conda_dependencies: Optional[List[str]] = None,
109
- pip_requirements: Optional[List[str]] = None,
110
- artifact_repository_map: Optional[Dict[str, str]] = None,
111
- target_platforms: Optional[List[model_types.TargetPlatform]] = None,
128
+ metadata: Optional[dict[str, str]] = None,
129
+ conda_dependencies: Optional[list[str]] = None,
130
+ pip_requirements: Optional[list[str]] = None,
131
+ artifact_repository_map: Optional[dict[str, str]] = None,
132
+ resource_constraint: Optional[dict[str, str]] = None,
133
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
112
134
  python_version: Optional[str] = None,
113
- user_files: Optional[Dict[str, List[str]]] = None,
114
- ext_modules: Optional[List[ModuleType]] = None,
115
- code_paths: Optional[List[str]] = None,
135
+ user_files: Optional[dict[str, list[str]]] = None,
136
+ ext_modules: Optional[list[ModuleType]] = None,
137
+ code_paths: Optional[list[str]] = None,
116
138
  task: model_types.Task = model_types.Task.UNKNOWN,
117
139
  options: Optional[model_types.ModelSaveOption] = None,
118
140
  ) -> model_meta.ModelMetadata:
@@ -146,14 +168,14 @@ class ModelComposer:
146
168
  if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
147
169
  snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
148
170
  self.session,
149
- reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_env.VERSION}")],
171
+ reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
150
172
  python_version=python_version or snowml_env.PYTHON_VERSION,
151
173
  statement_params=self._statement_params,
152
174
  ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
153
175
 
154
176
  if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
155
177
  logging.info(
156
- f"Local snowflake-ml-python library has version {snowml_env.VERSION},"
178
+ f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
157
179
  " which is not available in the Snowflake server, embedding local ML library automatically."
158
180
  )
159
181
  options["embed_local_ml_library"] = True
@@ -167,6 +189,8 @@ class ModelComposer:
167
189
  conda_dependencies=conda_dependencies,
168
190
  pip_requirements=pip_requirements,
169
191
  artifact_repository_map=artifact_repository_map,
192
+ resource_constraint=resource_constraint,
193
+ target_platforms=target_platforms,
170
194
  python_version=python_version,
171
195
  ext_modules=ext_modules,
172
196
  code_paths=code_paths,
@@ -175,9 +199,6 @@ class ModelComposer:
175
199
  )
176
200
  assert self.packager.meta is not None
177
201
 
178
- file_utils.copytree(
179
- str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
180
- )
181
202
  self.manifest.save(
182
203
  model_meta=self.packager.meta,
183
204
  model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
@@ -208,7 +229,7 @@ class ModelComposer:
208
229
 
209
230
  def _get_data_sources(
210
231
  self, model: model_types.SupportedModelType, sample_input_data: Optional[model_types.SupportedDataType] = None
211
- ) -> Optional[List[data_source.DataSource]]:
232
+ ) -> Optional[list[data_source.DataSource]]:
212
233
  data_sources = lineage_utils.get_data_sources(model)
213
234
  if not data_sources and sample_input_data is not None:
214
235
  data_sources = lineage_utils.get_data_sources(sample_input_data)