snowflake-ml-python 1.8.1__py3-none-any.whl → 1.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (170)
  1. snowflake/cortex/_classify_text.py +3 -3
  2. snowflake/cortex/_complete.py +64 -31
  3. snowflake/cortex/_embed_text_1024.py +4 -4
  4. snowflake/cortex/_embed_text_768.py +4 -4
  5. snowflake/cortex/_finetune.py +8 -8
  6. snowflake/cortex/_util.py +8 -12
  7. snowflake/ml/_internal/env.py +4 -3
  8. snowflake/ml/_internal/env_utils.py +63 -34
  9. snowflake/ml/_internal/file_utils.py +10 -21
  10. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +5 -7
  11. snowflake/ml/_internal/init_utils.py +2 -3
  12. snowflake/ml/_internal/lineage/lineage_utils.py +6 -6
  13. snowflake/ml/_internal/platform_capabilities.py +41 -5
  14. snowflake/ml/_internal/telemetry.py +39 -52
  15. snowflake/ml/_internal/type_utils.py +3 -3
  16. snowflake/ml/_internal/utils/db_utils.py +2 -2
  17. snowflake/ml/_internal/utils/identifier.py +8 -8
  18. snowflake/ml/_internal/utils/import_utils.py +2 -2
  19. snowflake/ml/_internal/utils/parallelize.py +7 -7
  20. snowflake/ml/_internal/utils/pkg_version_utils.py +11 -11
  21. snowflake/ml/_internal/utils/query_result_checker.py +4 -4
  22. snowflake/ml/_internal/utils/snowflake_env.py +28 -6
  23. snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +2 -2
  24. snowflake/ml/_internal/utils/sql_identifier.py +3 -3
  25. snowflake/ml/_internal/utils/table_manager.py +9 -9
  26. snowflake/ml/data/_internal/arrow_ingestor.py +7 -7
  27. snowflake/ml/data/data_connector.py +40 -36
  28. snowflake/ml/data/data_ingestor.py +4 -15
  29. snowflake/ml/data/data_source.py +2 -2
  30. snowflake/ml/data/ingestor_utils.py +3 -3
  31. snowflake/ml/data/torch_utils.py +5 -5
  32. snowflake/ml/dataset/dataset.py +11 -11
  33. snowflake/ml/dataset/dataset_metadata.py +8 -8
  34. snowflake/ml/dataset/dataset_reader.py +12 -8
  35. snowflake/ml/feature_store/__init__.py +1 -1
  36. snowflake/ml/feature_store/access_manager.py +7 -7
  37. snowflake/ml/feature_store/entity.py +6 -6
  38. snowflake/ml/feature_store/examples/airline_features/entities.py +1 -3
  39. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +1 -3
  40. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +1 -3
  41. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +1 -3
  42. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +1 -3
  43. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +1 -3
  44. snowflake/ml/feature_store/examples/example_helper.py +16 -16
  45. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +1 -3
  46. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +1 -3
  47. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +1 -3
  48. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +1 -3
  49. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +1 -3
  50. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +1 -3
  51. snowflake/ml/feature_store/feature_store.py +52 -64
  52. snowflake/ml/feature_store/feature_view.py +24 -24
  53. snowflake/ml/fileset/embedded_stage_fs.py +5 -5
  54. snowflake/ml/fileset/fileset.py +5 -5
  55. snowflake/ml/fileset/sfcfs.py +13 -13
  56. snowflake/ml/fileset/stage_fs.py +15 -15
  57. snowflake/ml/jobs/_utils/constants.py +2 -4
  58. snowflake/ml/jobs/_utils/interop_utils.py +442 -0
  59. snowflake/ml/jobs/_utils/payload_utils.py +86 -62
  60. snowflake/ml/jobs/_utils/scripts/constants.py +4 -0
  61. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +136 -0
  62. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +181 -0
  63. snowflake/ml/jobs/_utils/scripts/signal_workers.py +203 -0
  64. snowflake/ml/jobs/_utils/scripts/worker_shutdown_listener.py +242 -0
  65. snowflake/ml/jobs/_utils/spec_utils.py +22 -36
  66. snowflake/ml/jobs/_utils/types.py +8 -2
  67. snowflake/ml/jobs/decorators.py +7 -8
  68. snowflake/ml/jobs/job.py +158 -26
  69. snowflake/ml/jobs/manager.py +78 -30
  70. snowflake/ml/lineage/lineage_node.py +5 -5
  71. snowflake/ml/model/_client/model/model_impl.py +3 -3
  72. snowflake/ml/model/_client/model/model_version_impl.py +103 -35
  73. snowflake/ml/model/_client/ops/metadata_ops.py +7 -7
  74. snowflake/ml/model/_client/ops/model_ops.py +41 -41
  75. snowflake/ml/model/_client/ops/service_ops.py +230 -50
  76. snowflake/ml/model/_client/service/model_deployment_spec.py +175 -48
  77. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +44 -24
  78. snowflake/ml/model/_client/sql/model.py +8 -8
  79. snowflake/ml/model/_client/sql/model_version.py +26 -26
  80. snowflake/ml/model/_client/sql/service.py +22 -18
  81. snowflake/ml/model/_client/sql/stage.py +2 -2
  82. snowflake/ml/model/_client/sql/tag.py +6 -6
  83. snowflake/ml/model/_model_composer/model_composer.py +46 -25
  84. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -16
  85. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +14 -13
  86. snowflake/ml/model/_model_composer/model_method/model_method.py +3 -3
  87. snowflake/ml/model/_packager/model_env/model_env.py +35 -26
  88. snowflake/ml/model/_packager/model_handler.py +4 -4
  89. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  90. snowflake/ml/model/_packager/model_handlers/_utils.py +15 -3
  91. snowflake/ml/model/_packager/model_handlers/catboost.py +5 -5
  92. snowflake/ml/model/_packager/model_handlers/custom.py +8 -4
  93. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +7 -21
  94. snowflake/ml/model/_packager/model_handlers/keras.py +4 -4
  95. snowflake/ml/model/_packager/model_handlers/lightgbm.py +4 -14
  96. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -3
  97. snowflake/ml/model/_packager/model_handlers/pytorch.py +4 -4
  98. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +5 -5
  99. snowflake/ml/model/_packager/model_handlers/sklearn.py +5 -6
  100. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +3 -3
  101. snowflake/ml/model/_packager/model_handlers/tensorflow.py +4 -4
  102. snowflake/ml/model/_packager/model_handlers/torchscript.py +4 -4
  103. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -15
  104. snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -2
  105. snowflake/ml/model/_packager/model_meta/model_meta.py +42 -37
  106. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +13 -11
  107. snowflake/ml/model/_packager/model_meta_migrator/base_migrator.py +3 -3
  108. snowflake/ml/model/_packager/model_meta_migrator/migrator_plans.py +3 -3
  109. snowflake/ml/model/_packager/model_meta_migrator/migrator_v1.py +4 -4
  110. snowflake/ml/model/_packager/model_packager.py +12 -8
  111. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +32 -1
  112. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  113. snowflake/ml/model/_signatures/core.py +16 -24
  114. snowflake/ml/model/_signatures/dmatrix_handler.py +2 -2
  115. snowflake/ml/model/_signatures/utils.py +6 -6
  116. snowflake/ml/model/custom_model.py +8 -8
  117. snowflake/ml/model/model_signature.py +9 -20
  118. snowflake/ml/model/models/huggingface_pipeline.py +7 -4
  119. snowflake/ml/model/type_hints.py +5 -3
  120. snowflake/ml/modeling/_internal/estimator_utils.py +7 -7
  121. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +6 -6
  122. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +7 -7
  123. snowflake/ml/modeling/_internal/model_specifications.py +8 -10
  124. snowflake/ml/modeling/_internal/model_trainer.py +5 -5
  125. snowflake/ml/modeling/_internal/model_trainer_builder.py +6 -6
  126. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +30 -30
  127. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +13 -13
  128. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +31 -31
  129. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +19 -19
  130. snowflake/ml/modeling/_internal/transformer_protocols.py +17 -17
  131. snowflake/ml/modeling/framework/_utils.py +10 -10
  132. snowflake/ml/modeling/framework/base.py +32 -32
  133. snowflake/ml/modeling/impute/__init__.py +1 -1
  134. snowflake/ml/modeling/impute/simple_imputer.py +5 -5
  135. snowflake/ml/modeling/metrics/__init__.py +1 -1
  136. snowflake/ml/modeling/metrics/classification.py +39 -39
  137. snowflake/ml/modeling/metrics/metrics_utils.py +12 -12
  138. snowflake/ml/modeling/metrics/ranking.py +7 -7
  139. snowflake/ml/modeling/metrics/regression.py +13 -13
  140. snowflake/ml/modeling/model_selection/__init__.py +1 -1
  141. snowflake/ml/modeling/model_selection/grid_search_cv.py +7 -7
  142. snowflake/ml/modeling/model_selection/randomized_search_cv.py +7 -7
  143. snowflake/ml/modeling/pipeline/__init__.py +1 -1
  144. snowflake/ml/modeling/pipeline/pipeline.py +18 -18
  145. snowflake/ml/modeling/preprocessing/__init__.py +1 -1
  146. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +13 -13
  147. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +4 -4
  148. snowflake/ml/modeling/preprocessing/min_max_scaler.py +8 -8
  149. snowflake/ml/modeling/preprocessing/normalizer.py +0 -1
  150. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +28 -28
  151. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +9 -9
  152. snowflake/ml/modeling/preprocessing/robust_scaler.py +7 -7
  153. snowflake/ml/modeling/preprocessing/standard_scaler.py +5 -5
  154. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +26 -26
  155. snowflake/ml/monitoring/_manager/model_monitor_manager.py +5 -5
  156. snowflake/ml/monitoring/entities/model_monitor_config.py +6 -6
  157. snowflake/ml/registry/_manager/model_manager.py +50 -29
  158. snowflake/ml/registry/registry.py +34 -23
  159. snowflake/ml/utils/authentication.py +2 -2
  160. snowflake/ml/utils/connection_params.py +5 -5
  161. snowflake/ml/utils/sparse.py +5 -4
  162. snowflake/ml/utils/sql_client.py +1 -2
  163. snowflake/ml/version.py +2 -1
  164. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/METADATA +46 -6
  165. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/RECORD +168 -164
  166. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/WHEEL +1 -1
  167. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +0 -1
  168. snowflake/ml/modeling/_internal/constants.py +0 -2
  169. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  170. {snowflake_ml_python-1.8.1.dist-info → snowflake_ml_python-1.8.3.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/ops/service_ops.py

@@ -6,14 +6,16 @@ import re
 import tempfile
 import threading
 import time
-from typing import Any, Dict, List, Optional, Tuple, Union, cast
+from typing import Any, Optional, Union, cast
 
 from snowflake import snowpark
-from snowflake.ml._internal import file_utils
-from snowflake.ml._internal.utils import service_logger, sql_identifier
+from snowflake.ml._internal import file_utils, platform_capabilities as pc
+from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
+from snowflake.ml.model import model_signature, type_hints
 from snowflake.ml.model._client.service import model_deployment_spec
 from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
-from snowflake.snowpark import async_job, exceptions, row, session
+from snowflake.ml.model._signatures import snowpark_handler
+from snowflake.snowpark import async_job, dataframe, exceptions, row, session
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -57,30 +59,30 @@ class ServiceOperator:
         self._session = session
         self._database_name = database_name
         self._schema_name = schema_name
-        self._workspace = tempfile.TemporaryDirectory()
         self._service_client = service_sql.ServiceSQLClient(
             session,
             database_name=database_name,
             schema_name=schema_name,
         )
-        self._stage_client = stage_sql.StageSQLClient(
-            session,
-            database_name=database_name,
-            schema_name=schema_name,
-        )
-        self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
-            workspace_path=pathlib.Path(self._workspace.name)
-        )
+        if pc.PlatformCapabilities.get_instance().is_inlined_deployment_spec_enabled():
+            self._workspace = None
+            self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec()
+        else:
+            self._workspace = tempfile.TemporaryDirectory()
+            self._stage_client = stage_sql.StageSQLClient(
+                session,
+                database_name=database_name,
+                schema_name=schema_name,
+            )
+            self._model_deployment_spec = model_deployment_spec.ModelDeploymentSpec(
+                workspace_path=pathlib.Path(self._workspace.name)
+            )
 
     def __eq__(self, __value: object) -> bool:
         if not isinstance(__value, ServiceOperator):
             return False
         return self._service_client == __value._service_client
 
-    @property
-    def workspace_path(self) -> pathlib.Path:
-        return pathlib.Path(self._workspace.name)
-
     def create_service(
         self,
         *,
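
The constructor change above gates workspace creation on a new platform capability: when the server accepts inlined deployment specs, ServiceOperator no longer creates a local TemporaryDirectory or a temp-stage client at all. A minimal sketch of that gate, assuming only the behavior visible in this hunk (the helper name choose_workspace is illustrative, not part of the package):

import pathlib
import tempfile
from typing import Optional


def choose_workspace(inlined_spec_supported: bool) -> Optional[tempfile.TemporaryDirectory]:
    # Mirrors the branch above: no workspace means the deployment spec is later
    # passed inline as a YAML string instead of being uploaded to a temp stage.
    if inlined_spec_supported:
        return None
    return tempfile.TemporaryDirectory()


workspace = choose_workspace(inlined_spec_supported=True)
workspace_path = pathlib.Path(workspace.name) if workspace else None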
@@ -104,9 +106,9 @@ class ServiceOperator:
         num_workers: Optional[int],
         max_batch_rows: Optional[int],
         force_rebuild: bool,
-        build_external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
+        build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
         block: bool,
-        statement_params: Optional[Dict[str, Any]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> Union[str, async_job.AsyncJob]:
 
         # Fall back to the registry's database and schema if not provided
@@ -119,19 +121,11 @@ class ServiceOperator:
 
         image_repo_database_name = image_repo_database_name or database_name or self._database_name
         image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
-        # create a temp stage
-        stage_name = sql_identifier.SqlIdentifier(
-            snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
-        )
-        self._stage_client.create_tmp_stage(
-            database_name=database_name,
-            schema_name=schema_name,
-            stage_name=stage_name,
-            statement_params=statement_params,
-        )
-        stage_path = self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)
-
-        self._model_deployment_spec.save(
+        if self._workspace:
+            stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
+        else:
+            stage_path = None
+        spec_yaml_str_or_path = self._model_deployment_spec.save(
             database_name=database_name,
             schema_name=schema_name,
             model_name=model_name,
@@ -140,7 +134,7 @@ class ServiceOperator:
             service_schema_name=service_schema_name,
             service_name=service_name,
             image_build_compute_pool_name=image_build_compute_pool_name,
-            service_compute_pool_name=service_compute_pool_name,
+            inference_compute_pool_name=service_compute_pool_name,
             image_repo_database_name=image_repo_database_name,
             image_repo_schema_name=image_repo_schema_name,
             image_repo_name=image_repo_name,
@@ -154,12 +148,14 @@ class ServiceOperator:
             force_rebuild=force_rebuild,
             external_access_integrations=build_external_access_integrations,
         )
-        file_utils.upload_directory_to_stage(
-            self._session,
-            local_path=self.workspace_path,
-            stage_path=pathlib.PurePosixPath(stage_path),
-            statement_params=statement_params,
-        )
+        if self._workspace:
+            assert stage_path is not None
+            file_utils.upload_directory_to_stage(
+                self._session,
+                local_path=pathlib.Path(self._workspace.name),
+                stage_path=pathlib.PurePosixPath(stage_path),
+                statement_params=statement_params,
+            )
 
         # check if the inference service is already running/suspended
         model_inference_service_exists = self._check_if_service_exists(
@@ -176,8 +172,11 @@ class ServiceOperator:
 
         # deploy the model service
         query_id, async_job = self._service_client.deploy_model(
-            stage_path=stage_path,
-            model_deployment_spec_file_rel_path=model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH,
+            stage_path=stage_path if self._workspace else None,
+            model_deployment_spec_file_rel_path=(
+                model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
+            ),
+            model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
             statement_params=statement_params,
         )
 
@@ -203,7 +202,7 @@ class ServiceOperator:
         if block:
             log_thread.join()
 
-            res = cast(str, cast(List[row.Row], async_job.result())[0][0])
+            res = cast(str, cast(list[row.Row], async_job.result())[0][0])
             module_logger.info(f"Inference service {service_name} deployment complete: {res}")
             return res
         else:
@@ -212,10 +211,10 @@ class ServiceOperator:
     def _start_service_log_streaming(
         self,
         async_job: snowpark.AsyncJob,
-        services: List[ServiceLogInfo],
+        services: list[ServiceLogInfo],
         model_inference_service_exists: bool,
         force_rebuild: bool,
-        statement_params: Optional[Dict[str, Any]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> threading.Thread:
         """Start the service log streaming in a separate thread."""
         log_thread = threading.Thread(
@@ -234,14 +233,14 @@ class ServiceOperator:
     def _stream_service_logs(
         self,
         async_job: snowpark.AsyncJob,
-        services: List[ServiceLogInfo],
+        services: list[ServiceLogInfo],
         model_inference_service_exists: bool,
        force_rebuild: bool,
-        statement_params: Optional[Dict[str, Any]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
         """Stream service logs while the async job is running."""
 
-        def fetch_logs(service: ServiceLogInfo, offset: int) -> Tuple[str, int]:
+        def fetch_logs(service: ServiceLogInfo, offset: int) -> tuple[str, int]:
             service_logs = self._service_client.get_service_logs(
                 database_name=service.database_name,
                 schema_name=service.schema_name,
@@ -386,7 +385,7 @@ class ServiceOperator:
         service_logger: logging.Logger,
         service: ServiceLogInfo,
         offset: int,
-        statement_params: Optional[Dict[str, Any]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> None:
         """Fetch service logs after the async job is done to ensure no logs are missed."""
         try:
@@ -418,8 +417,8 @@ class ServiceOperator:
         database_name: Optional[sql_identifier.SqlIdentifier],
         schema_name: Optional[sql_identifier.SqlIdentifier],
         service_name: sql_identifier.SqlIdentifier,
-        service_status_list_if_exists: Optional[List[service_sql.ServiceStatus]] = None,
-        statement_params: Optional[Dict[str, Any]] = None,
+        service_status_list_if_exists: Optional[list[service_sql.ServiceStatus]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> bool:
         if service_status_list_if_exists is None:
             service_status_list_if_exists = [
@@ -441,3 +440,184 @@ class ServiceOperator:
             return any(service_status == status for status in service_status_list_if_exists)
         except exceptions.SnowparkSQLException:
             return False
+
+    def invoke_job_method(
+        self,
+        target_method: str,
+        signature: model_signature.ModelSignature,
+        X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        job_database_name: Optional[sql_identifier.SqlIdentifier],
+        job_schema_name: Optional[sql_identifier.SqlIdentifier],
+        job_name: sql_identifier.SqlIdentifier,
+        compute_pool_name: sql_identifier.SqlIdentifier,
+        warehouse_name: sql_identifier.SqlIdentifier,
+        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
+        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
+        image_repo_name: sql_identifier.SqlIdentifier,
+        output_table_database_name: Optional[sql_identifier.SqlIdentifier],
+        output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
+        output_table_name: sql_identifier.SqlIdentifier,
+        cpu_requests: Optional[str],
+        memory_requests: Optional[str],
+        gpu_requests: Optional[Union[int, str]],
+        num_workers: Optional[int],
+        max_batch_rows: Optional[int],
+        force_rebuild: bool,
+        build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
+        # fall back to the registry's database and schema if not provided
+        database_name = database_name or self._database_name
+        schema_name = schema_name or self._schema_name
+
+        # fall back to the model's database and schema if not provided then to the registry's database and schema
+        job_database_name = job_database_name or database_name or self._database_name
+        job_schema_name = job_schema_name or schema_name or self._schema_name
+
+        image_repo_database_name = image_repo_database_name or database_name or self._database_name
+        image_repo_schema_name = image_repo_schema_name or schema_name or self._schema_name
+
+        input_table_database_name = job_database_name
+        input_table_schema_name = job_schema_name
+        output_table_database_name = output_table_database_name or database_name or self._database_name
+        output_table_schema_name = output_table_schema_name or schema_name or self._schema_name
+
+        if self._workspace:
+            stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
+        else:
+            stage_path = None
+
+        # validate and prepare input
+        if not isinstance(X, dataframe.DataFrame):
+            keep_order = True
+            output_with_input_features = False
+            df = model_signature._convert_and_validate_local_data(X, signature.inputs)
+            s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
+                self._session, df, keep_order=keep_order, features=signature.inputs
+            )
+        else:
+            keep_order = False
+            output_with_input_features = True
+            s_df = X
+
+        # only write the index and feature input columns
+        cols = [snowpark_handler._KEEP_ORDER_COL_NAME] if snowpark_handler._KEEP_ORDER_COL_NAME in s_df.columns else []
+        cols += [
+            sql_identifier.SqlIdentifier(feature.name, case_sensitive=True).identifier() for feature in signature.inputs
+        ]
+        s_df = s_df.select(cols)
+        original_cols = s_df.columns
+
+        # input/output tables
+        fq_output_table_name = identifier.get_schema_level_object_identifier(
+            output_table_database_name.identifier(),
+            output_table_schema_name.identifier(),
+            output_table_name.identifier(),
+        )
+        tmp_input_table_id = sql_identifier.SqlIdentifier(
+            snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
+        )
+        fq_tmp_input_table_name = identifier.get_schema_level_object_identifier(
+            job_database_name.identifier(),
+            job_schema_name.identifier(),
+            tmp_input_table_id.identifier(),
+        )
+        s_df.write.save_as_table(
+            table_name=fq_tmp_input_table_name,
+            mode="errorifexists",
+            statement_params=statement_params,
+        )
+
+        try:
+            # save the spec
+            spec_yaml_str_or_path = self._model_deployment_spec.save(
+                database_name=database_name,
+                schema_name=schema_name,
+                model_name=model_name,
+                version_name=version_name,
+                job_database_name=job_database_name,
+                job_schema_name=job_schema_name,
+                job_name=job_name,
+                image_build_compute_pool_name=compute_pool_name,
+                inference_compute_pool_name=compute_pool_name,
+                image_repo_database_name=image_repo_database_name,
+                image_repo_schema_name=image_repo_schema_name,
+                image_repo_name=image_repo_name,
+                cpu=cpu_requests,
+                memory=memory_requests,
+                gpu=gpu_requests,
+                num_workers=num_workers,
+                max_batch_rows=max_batch_rows,
+                force_rebuild=force_rebuild,
+                external_access_integrations=build_external_access_integrations,
+                warehouse=warehouse_name,
+                target_method=target_method,
+                input_table_database_name=input_table_database_name,
+                input_table_schema_name=input_table_schema_name,
+                input_table_name=tmp_input_table_id,
+                output_table_database_name=output_table_database_name,
+                output_table_schema_name=output_table_schema_name,
+                output_table_name=output_table_name,
+            )
+            if self._workspace:
+                assert stage_path is not None
+                file_utils.upload_directory_to_stage(
+                    self._session,
+                    local_path=pathlib.Path(self._workspace.name),
+                    stage_path=pathlib.PurePosixPath(stage_path),
+                    statement_params=statement_params,
+                )
+
+            # deploy the job
+            query_id, async_job = self._service_client.deploy_model(
+                stage_path=stage_path if self._workspace else None,
+                model_deployment_spec_file_rel_path=(
+                    model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
+                ),
+                model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
+                statement_params=statement_params,
+            )
+
+            while not async_job.is_done():
+                time.sleep(5)
+        finally:
+            self._session.table(fq_tmp_input_table_name).drop_table()
+
+        # handle the output
+        df_res = self._session.table(fq_output_table_name)
+        if keep_order:
+            df_res = df_res.sort(
+                snowpark_handler._KEEP_ORDER_COL_NAME,
+                ascending=True,
+            )
+            df_res = df_res.drop(snowpark_handler._KEEP_ORDER_COL_NAME)
+
+        if not output_with_input_features:
+            df_res = df_res.drop(*original_cols)
+
+        # get final result
+        if not isinstance(X, dataframe.DataFrame):
+            return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(df_res, features=signature.outputs)
+        else:
+            return df_res
+
+    def _create_temp_stage(
+        self,
+        database_name: Optional[sql_identifier.SqlIdentifier],
+        schema_name: Optional[sql_identifier.SqlIdentifier],
+        statement_params: Optional[dict[str, Any]] = None,
+    ) -> str:
+        stage_name = sql_identifier.SqlIdentifier(
+            snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.STAGE)
+        )
+        self._stage_client.create_tmp_stage(
+            database_name=database_name,
+            schema_name=schema_name,
+            stage_name=stage_name,
+            statement_params=statement_params,
+        )
+        return self._stage_client.fully_qualified_object_name(database_name, schema_name, stage_name)  # stage path
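
The new invoke_job_method wires batch inference through a table-to-table job: the input rows are written to a temporary table, the job is submitted through the same deploy_model call, the operator polls every five seconds until the async job finishes, and the results are read back from the output table. A hedged sketch of that flow under the assumptions visible in this hunk; run_batch_inference and submit_job are illustrative placeholders, not package or Snowpark APIs:

import time
from typing import Callable

from snowflake.snowpark import Session, async_job, dataframe


def run_batch_inference(
    session: Session,
    input_df: dataframe.DataFrame,
    submit_job: Callable[[str, str], async_job.AsyncJob],
    input_table: str,
    output_table: str,
) -> dataframe.DataFrame:
    # Materialize the input, as invoke_job_method does with save_as_table.
    input_df.write.save_as_table(table_name=input_table, mode="errorifexists")
    try:
        job = submit_job(input_table, output_table)
        while not job.is_done():
            time.sleep(5)  # same polling interval as the diff
    finally:
        session.table(input_table).drop_table()  # temp input is always cleaned up
    return session.table(output_table)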
snowflake/ml/model/_client/service/model_deployment_spec.py

@@ -1,5 +1,5 @@
 import pathlib
-from typing import List, Optional, Union
+from typing import Any, Optional, Union, overload
 
 import yaml
 
@@ -16,9 +16,10 @@ class ModelDeploymentSpec:
 
     DEPLOY_SPEC_FILE_REL_PATH = "deploy.yml"
 
-    def __init__(self, workspace_path: pathlib.Path) -> None:
+    def __init__(self, workspace_path: Optional[pathlib.Path] = None) -> None:
         self.workspace_path = workspace_path
 
+    @overload
     def save(
         self,
         *,
@@ -26,88 +27,214 @@ class ModelDeploymentSpec:
         schema_name: sql_identifier.SqlIdentifier,
         model_name: sql_identifier.SqlIdentifier,
         version_name: sql_identifier.SqlIdentifier,
-        service_database_name: Optional[sql_identifier.SqlIdentifier],
-        service_schema_name: Optional[sql_identifier.SqlIdentifier],
+        service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
+        service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
         service_name: sql_identifier.SqlIdentifier,
+        inference_compute_pool_name: sql_identifier.SqlIdentifier,
         image_build_compute_pool_name: sql_identifier.SqlIdentifier,
-        service_compute_pool_name: sql_identifier.SqlIdentifier,
         image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
         image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
         image_repo_name: sql_identifier.SqlIdentifier,
+        cpu: Optional[str],
+        memory: Optional[str],
+        gpu: Optional[Union[str, int]],
+        num_workers: Optional[int],
+        max_batch_rows: Optional[int],
+        force_rebuild: bool,
+        external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
+        # service spec
         ingress_enabled: bool,
         max_instances: int,
+    ) -> str:
+        ...
+
+    @overload
+    def save(
+        self,
+        *,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
+        job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        job_name: sql_identifier.SqlIdentifier,
+        inference_compute_pool_name: sql_identifier.SqlIdentifier,
+        image_build_compute_pool_name: sql_identifier.SqlIdentifier,
+        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
+        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
+        image_repo_name: sql_identifier.SqlIdentifier,
         cpu: Optional[str],
         memory: Optional[str],
         gpu: Optional[Union[str, int]],
         num_workers: Optional[int],
         max_batch_rows: Optional[int],
         force_rebuild: bool,
-        external_access_integrations: Optional[List[sql_identifier.SqlIdentifier]],
-    ) -> None:
+        external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
+        # job spec
+        warehouse: sql_identifier.SqlIdentifier,
+        target_method: str,
+        input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
+        input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        input_table_name: sql_identifier.SqlIdentifier,
+        output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
+        output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        output_table_name: sql_identifier.SqlIdentifier,
+    ) -> str:
+        ...
+
+    def save(
+        self,
+        *,
+        database_name: sql_identifier.SqlIdentifier,
+        schema_name: sql_identifier.SqlIdentifier,
+        model_name: sql_identifier.SqlIdentifier,
+        version_name: sql_identifier.SqlIdentifier,
+        service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
+        service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        service_name: Optional[sql_identifier.SqlIdentifier] = None,
+        job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
+        job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        job_name: Optional[sql_identifier.SqlIdentifier] = None,
+        inference_compute_pool_name: sql_identifier.SqlIdentifier,
+        image_build_compute_pool_name: sql_identifier.SqlIdentifier,
+        image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
+        image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
+        image_repo_name: sql_identifier.SqlIdentifier,
+        cpu: Optional[str],
+        memory: Optional[str],
+        gpu: Optional[Union[str, int]],
+        num_workers: Optional[int],
+        max_batch_rows: Optional[int],
+        force_rebuild: bool,
+        external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
+        # service spec
+        ingress_enabled: Optional[bool] = None,
+        max_instances: Optional[int] = None,
+        # job spec
+        warehouse: Optional[sql_identifier.SqlIdentifier] = None,
+        target_method: Optional[str] = None,
+        input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
+        input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        input_table_name: Optional[sql_identifier.SqlIdentifier] = None,
+        output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
+        output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
+        output_table_name: Optional[sql_identifier.SqlIdentifier] = None,
+    ) -> str:
         # create the deployment spec
         # models spec
         fq_model_name = identifier.get_schema_level_object_identifier(
             database_name.identifier(), schema_name.identifier(), model_name.identifier()
         )
-        model_dict = model_deployment_spec_schema.ModelDict(name=fq_model_name, version=version_name.identifier())
+        model = model_deployment_spec_schema.Model(name=fq_model_name, version=version_name.identifier())
 
         # image_build spec
         saved_image_repo_database = image_repo_database_name or database_name
         saved_image_repo_schema = image_repo_schema_name or schema_name
         fq_image_repo_name = identifier.get_schema_level_object_identifier(
-            saved_image_repo_database.identifier(), saved_image_repo_schema.identifier(), image_repo_name.identifier()
+            db=saved_image_repo_database.identifier(),
+            schema=saved_image_repo_schema.identifier(),
+            object_name=image_repo_name.identifier(),
         )
-        image_build_dict: model_deployment_spec_schema.ImageBuildDict = {
-            "compute_pool": image_build_compute_pool_name.identifier(),
-            "image_repo": fq_image_repo_name,
-            "force_rebuild": force_rebuild,
-        }
-        if external_access_integrations is not None:
-            image_build_dict["external_access_integrations"] = [
-                eai.identifier() for eai in external_access_integrations
-            ]
 
-        # service spec
-        saved_service_database = service_database_name or database_name
-        saved_service_schema = service_schema_name or schema_name
-        fq_service_name = identifier.get_schema_level_object_identifier(
-            saved_service_database.identifier(), saved_service_schema.identifier(), service_name.identifier()
+        image_build = model_deployment_spec_schema.ImageBuild(
+            compute_pool=image_build_compute_pool_name.identifier(),
+            image_repo=fq_image_repo_name,
+            force_rebuild=force_rebuild,
+            external_access_integrations=(
+                [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
+            ),
         )
-        service_dict = model_deployment_spec_schema.ServiceDict(
-            name=fq_service_name,
-            compute_pool=service_compute_pool_name.identifier(),
-            ingress_enabled=ingress_enabled,
-            max_instances=max_instances,
-        )
-        if cpu:
-            service_dict["cpu"] = cpu
 
+        # universal base inference spec in service and job
+        base_inference_spec: dict[str, Any] = {}
+        if cpu:
+            base_inference_spec["cpu"] = cpu
         if memory:
-            service_dict["memory"] = memory
-
+            base_inference_spec["memory"] = memory
         if gpu:
             if isinstance(gpu, int):
                 gpu_str = str(gpu)
             else:
                 gpu_str = gpu
-            service_dict["gpu"] = gpu_str
-
+            base_inference_spec["gpu"] = gpu_str
         if num_workers:
-            service_dict["num_workers"] = num_workers
-
+            base_inference_spec["num_workers"] = num_workers
         if max_batch_rows:
-            service_dict["max_batch_rows"] = max_batch_rows
-
-        # model deployment spec
-        model_deployment_spec_dict = model_deployment_spec_schema.ModelDeploymentSpecDict(
-            models=[model_dict],
-            image_build=image_build_dict,
-            service=service_dict,
-        )
+            base_inference_spec["max_batch_rows"] = max_batch_rows
+
+        if service_name:  # service spec
+            assert ingress_enabled, "ingress_enabled is required for service spec"
+            assert max_instances, "max_instances is required for service spec"
+            saved_service_database = service_database_name or database_name
+            saved_service_schema = service_schema_name or schema_name
+            fq_service_name = identifier.get_schema_level_object_identifier(
+                saved_service_database.identifier(), saved_service_schema.identifier(), service_name.identifier()
+            )
+            service = model_deployment_spec_schema.Service(
+                name=fq_service_name,
+                compute_pool=inference_compute_pool_name.identifier(),
+                ingress_enabled=ingress_enabled,
+                max_instances=max_instances,
+                **base_inference_spec,
+            )
+
+            # model deployment spec
+            model_deployment_spec: Union[
+                model_deployment_spec_schema.ModelServiceDeploymentSpec,
+                model_deployment_spec_schema.ModelJobDeploymentSpec,
+            ] = model_deployment_spec_schema.ModelServiceDeploymentSpec(
+                models=[model],
+                image_build=image_build,
+                service=service,
+            )
+        else:  # job spec
+            assert job_name, "job_name is required for job spec"
+            assert warehouse, "warehouse is required for job spec"
+            assert target_method, "target_method is required for job spec"
+            assert input_table_name, "input_table_name is required for job spec"
+            assert output_table_name, "output_table_name is required for job spec"
+            saved_job_database = job_database_name or database_name
+            saved_job_schema = job_schema_name or schema_name
+            input_table_database_name = input_table_database_name or database_name
+            input_table_schema_name = input_table_schema_name or schema_name
+            output_table_database_name = output_table_database_name or database_name
+            output_table_schema_name = output_table_schema_name or schema_name
+            fq_job_name = identifier.get_schema_level_object_identifier(
+                saved_job_database.identifier(), saved_job_schema.identifier(), job_name.identifier()
+            )
+            fq_input_table_name = identifier.get_schema_level_object_identifier(
+                input_table_database_name.identifier(),
+                input_table_schema_name.identifier(),
+                input_table_name.identifier(),
+            )
+            fq_output_table_name = identifier.get_schema_level_object_identifier(
+                output_table_database_name.identifier(),
+                output_table_schema_name.identifier(),
+                output_table_name.identifier(),
+            )
+            job = model_deployment_spec_schema.Job(
+                name=fq_job_name,
+                compute_pool=inference_compute_pool_name.identifier(),
+                warehouse=warehouse.identifier(),
+                target_method=target_method,
+                input_table_name=fq_input_table_name,
+                output_table_name=fq_output_table_name,
+                **base_inference_spec,
+            )
+
+            # model deployment spec
+            model_deployment_spec = model_deployment_spec_schema.ModelJobDeploymentSpec(
+                models=[model],
+                image_build=image_build,
+                job=job,
+            )
+
+        if self.workspace_path is None:
+            return yaml.safe_dump(model_deployment_spec.model_dump(exclude_none=True))
 
         # save the yaml
         file_path = self.workspace_path / self.DEPLOY_SPEC_FILE_REL_PATH
         with file_path.open("w", encoding="utf-8") as f:
-            # Anchors are not supported in the server, avoid that.
-            yaml.SafeDumper.ignore_aliases = lambda *args: True  # type: ignore[method-assign]
-            yaml.safe_dump(model_deployment_spec_dict, f)
+            yaml.safe_dump(model_deployment_spec.model_dump(exclude_none=True), f)
+        return str(file_path.resolve())
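
Taken together with the service_ops.py changes, save() now has a dual contract: it builds schema objects instead of TypedDicts (the model_dump(exclude_none=True) call suggests pydantic-style models) and either returns the YAML document inline when no workspace_path is set, or writes deploy.yml into the workspace and returns its resolved path. A small sketch of that contract, assuming a pydantic BaseModel as a stand-in for the real schema classes:

import pathlib
from typing import Optional

import yaml
from pydantic import BaseModel


class DemoSpec(BaseModel):  # stand-in for ModelServiceDeploymentSpec / ModelJobDeploymentSpec
    name: str
    compute_pool: str
    gpu: Optional[str] = None


def save_spec(spec: DemoSpec, workspace_path: Optional[pathlib.Path] = None) -> str:
    payload = spec.model_dump(exclude_none=True)  # optional fields are dropped from the YAML
    if workspace_path is None:
        return yaml.safe_dump(payload)  # inlined mode: the YAML string goes straight to the server
    file_path = workspace_path / "deploy.yml"
    with file_path.open("w", encoding="utf-8") as f:
        yaml.safe_dump(payload, f)
    return str(file_path.resolve())  # workspace mode: path of the spec file to upload to a stage


print(save_spec(DemoSpec(name="DB.SCHEMA.MY_SERVICE", compute_pool="MY_POOL")))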