snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/ml/_internal/platform_capabilities.py +13 -11
  3. snowflake/ml/_internal/utils/identifier.py +2 -2
  4. snowflake/ml/jobs/_utils/constants.py +1 -1
  5. snowflake/ml/jobs/_utils/payload_utils.py +39 -30
  6. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  7. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +1 -1
  8. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  9. snowflake/ml/jobs/decorators.py +6 -0
  10. snowflake/ml/jobs/job.py +63 -16
  11. snowflake/ml/jobs/manager.py +50 -16
  12. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  13. snowflake/ml/model/_client/ops/service_ops.py +26 -14
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +340 -170
  15. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
  16. snowflake/ml/model/_client/sql/service.py +4 -13
  17. snowflake/ml/model/_model_composer/model_composer.py +41 -18
  18. snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
  19. snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
  20. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
  21. snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
  22. snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
  23. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  24. snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
  25. snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
  26. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +4 -4
  28. snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
  29. snowflake/ml/model/custom_model.py +17 -4
  30. snowflake/ml/model/model_signature.py +3 -3
  31. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  32. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  33. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  34. snowflake/ml/modeling/cluster/birch.py +9 -1
  35. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  36. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  37. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  38. snowflake/ml/modeling/cluster/k_means.py +9 -1
  39. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  40. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  41. snowflake/ml/modeling/cluster/optics.py +9 -1
  42. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  43. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  44. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  45. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  46. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  47. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  48. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  49. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  51. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  52. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  53. snowflake/ml/modeling/covariance/oas.py +9 -1
  54. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  55. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  56. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  57. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  58. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  59. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  60. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  62. snowflake/ml/modeling/decomposition/pca.py +9 -1
  63. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  64. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  65. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  66. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  67. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  69. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  70. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  71. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  73. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  77. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  78. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  79. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  80. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  81. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  82. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  83. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  84. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  85. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  86. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  87. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  88. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  89. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  90. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  91. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  93. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  94. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  95. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  96. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  97. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  98. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  99. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  100. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  104. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  106. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  108. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  109. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  110. snowflake/ml/modeling/linear_model/lars.py +9 -1
  111. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  112. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  113. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  114. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  117. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  118. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  120. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  122. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  124. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  125. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  127. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  128. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  129. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  130. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  131. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  133. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  134. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  135. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  136. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  137. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  138. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  139. snowflake/ml/modeling/manifold/isomap.py +9 -1
  140. snowflake/ml/modeling/manifold/mds.py +9 -1
  141. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  142. snowflake/ml/modeling/manifold/tsne.py +9 -1
  143. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  144. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  145. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  146. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  147. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  148. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  149. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  150. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  151. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  152. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  153. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  155. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  156. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  157. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  158. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  159. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  160. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  162. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  163. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  164. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  165. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  166. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  167. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  168. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  169. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  170. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  171. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  172. snowflake/ml/modeling/svm/svc.py +9 -1
  173. snowflake/ml/modeling/svm/svr.py +9 -1
  174. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  175. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  176. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  177. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  178. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  179. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  180. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  182. snowflake/ml/monitoring/explain_visualize.py +286 -0
  183. snowflake/ml/registry/_manager/model_manager.py +23 -2
  184. snowflake/ml/registry/registry.py +10 -9
  185. snowflake/ml/version.py +1 -1
  186. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +40 -8
  187. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/RECORD +190 -189
  188. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  189. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  190. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
+ import json
1
2
  import pathlib
2
- from typing import Any, Optional, Union, overload
3
+ from typing import Any, Optional, Union
3
4
 
4
5
  import yaml
5
6
 
@@ -18,126 +19,78 @@ class ModelDeploymentSpec:
18
19
 
19
20
  def __init__(self, workspace_path: Optional[pathlib.Path] = None) -> None:
20
21
  self.workspace_path = workspace_path
22
+ self._models: list[model_deployment_spec_schema.Model] = []
23
+ self._image_build: Optional[model_deployment_spec_schema.ImageBuild] = None
24
+ self._service: Optional[model_deployment_spec_schema.Service] = None
25
+ self._job: Optional[model_deployment_spec_schema.Job] = None
26
+ self._model_loggings: Optional[list[model_deployment_spec_schema.ModelLogging]] = None
27
+ self._inference_spec: dict[str, Any] = {} # Common inference spec for service/job
21
28
 
22
- @overload
23
- def save(
24
- self,
25
- *,
26
- database_name: sql_identifier.SqlIdentifier,
27
- schema_name: sql_identifier.SqlIdentifier,
28
- model_name: sql_identifier.SqlIdentifier,
29
- version_name: sql_identifier.SqlIdentifier,
30
- service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
31
- service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
32
- service_name: sql_identifier.SqlIdentifier,
33
- inference_compute_pool_name: sql_identifier.SqlIdentifier,
34
- image_build_compute_pool_name: sql_identifier.SqlIdentifier,
35
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
36
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
37
- image_repo_name: sql_identifier.SqlIdentifier,
38
- cpu: Optional[str],
39
- memory: Optional[str],
40
- gpu: Optional[Union[str, int]],
41
- num_workers: Optional[int],
42
- max_batch_rows: Optional[int],
43
- force_rebuild: bool,
44
- external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
45
- # service spec
46
- ingress_enabled: bool,
47
- max_instances: int,
48
- ) -> str:
49
- ...
50
-
51
- @overload
52
- def save(
53
- self,
54
- *,
55
- database_name: sql_identifier.SqlIdentifier,
56
- schema_name: sql_identifier.SqlIdentifier,
57
- model_name: sql_identifier.SqlIdentifier,
58
- version_name: sql_identifier.SqlIdentifier,
59
- job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
60
- job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
61
- job_name: sql_identifier.SqlIdentifier,
62
- inference_compute_pool_name: sql_identifier.SqlIdentifier,
63
- image_build_compute_pool_name: sql_identifier.SqlIdentifier,
64
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
65
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
66
- image_repo_name: sql_identifier.SqlIdentifier,
67
- cpu: Optional[str],
68
- memory: Optional[str],
69
- gpu: Optional[Union[str, int]],
70
- num_workers: Optional[int],
71
- max_batch_rows: Optional[int],
72
- force_rebuild: bool,
73
- external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
74
- # job spec
75
- warehouse: sql_identifier.SqlIdentifier,
76
- target_method: str,
77
- input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
78
- input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
79
- input_table_name: sql_identifier.SqlIdentifier,
80
- output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
81
- output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
82
- output_table_name: sql_identifier.SqlIdentifier,
83
- ) -> str:
84
- ...
29
+ self.database: Optional[sql_identifier.SqlIdentifier] = None
30
+ self.schema: Optional[sql_identifier.SqlIdentifier] = None
85
31
 
86
- def save(
32
+ def add_model_spec(
87
33
  self,
88
- *,
89
34
  database_name: sql_identifier.SqlIdentifier,
90
35
  schema_name: sql_identifier.SqlIdentifier,
91
36
  model_name: sql_identifier.SqlIdentifier,
92
37
  version_name: sql_identifier.SqlIdentifier,
93
- service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
94
- service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
95
- service_name: Optional[sql_identifier.SqlIdentifier] = None,
96
- job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
97
- job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
98
- job_name: Optional[sql_identifier.SqlIdentifier] = None,
99
- inference_compute_pool_name: sql_identifier.SqlIdentifier,
100
- image_build_compute_pool_name: sql_identifier.SqlIdentifier,
101
- image_repo_database_name: Optional[sql_identifier.SqlIdentifier],
102
- image_repo_schema_name: Optional[sql_identifier.SqlIdentifier],
103
- image_repo_name: sql_identifier.SqlIdentifier,
104
- cpu: Optional[str],
105
- memory: Optional[str],
106
- gpu: Optional[Union[str, int]],
107
- num_workers: Optional[int],
108
- max_batch_rows: Optional[int],
109
- force_rebuild: bool,
110
- external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
111
- # service spec
112
- ingress_enabled: Optional[bool] = None,
113
- max_instances: Optional[int] = None,
114
- # job spec
115
- warehouse: Optional[sql_identifier.SqlIdentifier] = None,
116
- target_method: Optional[str] = None,
117
- input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
118
- input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
119
- input_table_name: Optional[sql_identifier.SqlIdentifier] = None,
120
- output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
121
- output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
122
- output_table_name: Optional[sql_identifier.SqlIdentifier] = None,
123
- ) -> str:
124
- # create the deployment spec
125
- # models spec
38
+ ) -> "ModelDeploymentSpec":
39
+ """Add model specification to the deployment spec.
40
+
41
+ Args:
42
+ database_name: Database name containing the model.
43
+ schema_name: Schema name containing the model.
44
+ model_name: Name of the model.
45
+ version_name: Version of the model.
46
+
47
+ Returns:
48
+ Self for chaining.
49
+ """
126
50
  fq_model_name = identifier.get_schema_level_object_identifier(
127
51
  database_name.identifier(), schema_name.identifier(), model_name.identifier()
128
52
  )
53
+ if not self.database:
54
+ self.database = database_name
55
+ if not self.schema:
56
+ self.schema = schema_name
129
57
  model = model_deployment_spec_schema.Model(name=fq_model_name, version=version_name.identifier())
58
+ self._models.append(model)
59
+ return self
60
+
61
+ def add_image_build_spec(
62
+ self,
63
+ image_build_compute_pool_name: sql_identifier.SqlIdentifier,
64
+ image_repo_name: sql_identifier.SqlIdentifier,
65
+ image_repo_database_name: Optional[sql_identifier.SqlIdentifier] = None,
66
+ image_repo_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
67
+ force_rebuild: bool = False,
68
+ external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]] = None,
69
+ ) -> "ModelDeploymentSpec":
70
+ """Add image build specification to the deployment spec.
130
71
 
131
- # image_build spec
132
- saved_image_repo_database = image_repo_database_name or database_name
133
- saved_image_repo_schema = image_repo_schema_name or schema_name
72
+ Args:
73
+ image_build_compute_pool_name: Compute pool for image building.
74
+ image_repo_name: Name of the image repository.
75
+ image_repo_database_name: Database name for the image repository.
76
+ image_repo_schema_name: Schema name for the image repository.
77
+ force_rebuild: Whether to force rebuilding the image.
78
+ external_access_integrations: List of external access integrations.
79
+
80
+ Returns:
81
+ Self for chaining.
82
+ """
83
+ saved_image_repo_database = image_repo_database_name or self.database
84
+ saved_image_repo_schema = image_repo_schema_name or self.schema
85
+ assert saved_image_repo_database is not None
86
+ assert saved_image_repo_schema is not None
134
87
  fq_image_repo_name = identifier.get_schema_level_object_identifier(
135
88
  db=saved_image_repo_database.identifier(),
136
89
  schema=saved_image_repo_schema.identifier(),
137
90
  object_name=image_repo_name.identifier(),
138
91
  )
139
92
 
140
- image_build = model_deployment_spec_schema.ImageBuild(
93
+ self._image_build = model_deployment_spec_schema.ImageBuild(
141
94
  compute_pool=image_build_compute_pool_name.identifier(),
142
95
  image_repo=fq_image_repo_name,
143
96
  force_rebuild=force_rebuild,
@@ -145,96 +98,313 @@ class ModelDeploymentSpec:
145
98
  [eai.identifier() for eai in external_access_integrations] if external_access_integrations else None
146
99
  ),
147
100
  )
101
+ return self
148
102
 
149
- # universal base inference spec in service and job
150
- base_inference_spec: dict[str, Any] = {}
103
+ def _add_inference_spec(
104
+ self,
105
+ cpu: Optional[str],
106
+ memory: Optional[str],
107
+ gpu: Optional[Union[str, int]],
108
+ num_workers: Optional[int],
109
+ max_batch_rows: Optional[int],
110
+ ) -> None:
111
+ """Internal helper to store common inference specs."""
151
112
  if cpu:
152
- base_inference_spec["cpu"] = cpu
113
+ self._inference_spec["cpu"] = cpu
153
114
  if memory:
154
- base_inference_spec["memory"] = memory
115
+ self._inference_spec["memory"] = memory
155
116
  if gpu:
156
117
  if isinstance(gpu, int):
157
118
  gpu_str = str(gpu)
158
119
  else:
159
120
  gpu_str = gpu
160
- base_inference_spec["gpu"] = gpu_str
121
+ self._inference_spec["gpu"] = gpu_str
161
122
  if num_workers:
162
- base_inference_spec["num_workers"] = num_workers
123
+ self._inference_spec["num_workers"] = num_workers
163
124
  if max_batch_rows:
164
- base_inference_spec["max_batch_rows"] = max_batch_rows
165
-
166
- if service_name: # service spec
167
- assert ingress_enabled, "ingress_enabled is required for service spec"
168
- assert max_instances, "max_instances is required for service spec"
169
- saved_service_database = service_database_name or database_name
170
- saved_service_schema = service_schema_name or schema_name
171
- fq_service_name = identifier.get_schema_level_object_identifier(
172
- saved_service_database.identifier(), saved_service_schema.identifier(), service_name.identifier()
125
+ self._inference_spec["max_batch_rows"] = max_batch_rows
126
+
127
+ def add_service_spec(
128
+ self,
129
+ service_name: sql_identifier.SqlIdentifier,
130
+ inference_compute_pool_name: sql_identifier.SqlIdentifier,
131
+ service_database_name: Optional[sql_identifier.SqlIdentifier] = None,
132
+ service_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
133
+ ingress_enabled: bool = True,
134
+ max_instances: int = 1,
135
+ cpu: Optional[str] = None,
136
+ memory: Optional[str] = None,
137
+ gpu: Optional[Union[str, int]] = None,
138
+ num_workers: Optional[int] = None,
139
+ max_batch_rows: Optional[int] = None,
140
+ ) -> "ModelDeploymentSpec":
141
+ """Add service specification to the deployment spec.
142
+
143
+ Args:
144
+ service_name: Name of the service.
145
+ inference_compute_pool_name: Compute pool for inference.
146
+ service_database_name: Database name for the service.
147
+ service_schema_name: Schema name for the service.
148
+ ingress_enabled: Whether ingress is enabled.
149
+ max_instances: Maximum number of service instances.
150
+ cpu: CPU requirement.
151
+ memory: Memory requirement.
152
+ gpu: GPU requirement.
153
+ num_workers: Number of workers.
154
+ max_batch_rows: Maximum batch rows for inference.
155
+
156
+ Raises:
157
+ ValueError: If a job spec already exists.
158
+
159
+ Returns:
160
+ Self for chaining.
161
+ """
162
+ if self._job:
163
+ raise ValueError("Cannot add a service spec when a job spec already exists.")
164
+
165
+ saved_service_database = service_database_name or self.database
166
+ saved_service_schema = service_schema_name or self.schema
167
+ assert saved_service_database is not None
168
+ assert saved_service_schema is not None
169
+ fq_service_name = identifier.get_schema_level_object_identifier(
170
+ saved_service_database.identifier(), saved_service_schema.identifier(), service_name.identifier()
171
+ )
172
+
173
+ self._add_inference_spec(cpu, memory, gpu, num_workers, max_batch_rows)
174
+
175
+ self._service = model_deployment_spec_schema.Service(
176
+ name=fq_service_name,
177
+ compute_pool=inference_compute_pool_name.identifier(),
178
+ ingress_enabled=ingress_enabled,
179
+ max_instances=max_instances,
180
+ **self._inference_spec,
181
+ )
182
+ return self
183
+
184
+ def add_job_spec(
185
+ self,
186
+ job_name: sql_identifier.SqlIdentifier,
187
+ inference_compute_pool_name: sql_identifier.SqlIdentifier,
188
+ warehouse: sql_identifier.SqlIdentifier,
189
+ target_method: str,
190
+ input_table_name: sql_identifier.SqlIdentifier,
191
+ output_table_name: sql_identifier.SqlIdentifier,
192
+ job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
193
+ job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
194
+ input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
195
+ input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
196
+ output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
197
+ output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
198
+ cpu: Optional[str] = None,
199
+ memory: Optional[str] = None,
200
+ gpu: Optional[Union[str, int]] = None,
201
+ num_workers: Optional[int] = None,
202
+ max_batch_rows: Optional[int] = None,
203
+ ) -> "ModelDeploymentSpec":
204
+ """Add job specification to the deployment spec.
205
+
206
+ Args:
207
+ job_name: Name of the job.
208
+ inference_compute_pool_name: Compute pool for inference.
209
+ job_database_name: Database name for the job.
210
+ job_schema_name: Schema name for the job.
211
+ warehouse: Warehouse for the job.
212
+ target_method: Target method for inference.
213
+ input_table_name: Input table name.
214
+ output_table_name: Output table name.
215
+ input_table_database_name: Database for input table.
216
+ input_table_schema_name: Schema for input table.
217
+ output_table_database_name: Database for output table.
218
+ output_table_schema_name: Schema for output table.
219
+ cpu: CPU requirement.
220
+ memory: Memory requirement.
221
+ gpu: GPU requirement.
222
+ num_workers: Number of workers.
223
+ max_batch_rows: Maximum batch rows for inference.
224
+
225
+ Raises:
226
+ ValueError: If a service spec already exists.
227
+
228
+ Returns:
229
+ Self for chaining.
230
+ """
231
+ if self._service:
232
+ raise ValueError("Cannot add a job spec when a service spec already exists.")
233
+
234
+ saved_job_database = job_database_name or self.database
235
+ saved_job_schema = job_schema_name or self.schema
236
+ input_table_database_name = input_table_database_name or self.database
237
+ input_table_schema_name = input_table_schema_name or self.schema
238
+ output_table_database_name = output_table_database_name or self.database
239
+ output_table_schema_name = output_table_schema_name or self.schema
240
+
241
+ assert saved_job_database is not None
242
+ assert saved_job_schema is not None
243
+ assert input_table_database_name is not None
244
+ assert input_table_schema_name is not None
245
+ assert output_table_database_name is not None
246
+ assert output_table_schema_name is not None
247
+
248
+ fq_job_name = identifier.get_schema_level_object_identifier(
249
+ saved_job_database.identifier(), saved_job_schema.identifier(), job_name.identifier()
250
+ )
251
+ fq_input_table_name = identifier.get_schema_level_object_identifier(
252
+ input_table_database_name.identifier(),
253
+ input_table_schema_name.identifier(),
254
+ input_table_name.identifier(),
255
+ )
256
+ fq_output_table_name = identifier.get_schema_level_object_identifier(
257
+ output_table_database_name.identifier(),
258
+ output_table_schema_name.identifier(),
259
+ output_table_name.identifier(),
260
+ )
261
+
262
+ self._add_inference_spec(cpu, memory, gpu, num_workers, max_batch_rows)
263
+
264
+ self._job = model_deployment_spec_schema.Job(
265
+ name=fq_job_name,
266
+ compute_pool=inference_compute_pool_name.identifier(),
267
+ warehouse=warehouse.identifier(),
268
+ target_method=target_method,
269
+ input_table_name=fq_input_table_name,
270
+ output_table_name=fq_output_table_name,
271
+ **self._inference_spec,
272
+ )
273
+ return self
274
+
275
+ def add_hf_logger_spec(
276
+ self,
277
+ hf_model_name: str,
278
+ hf_task: Optional[str] = None,
279
+ hf_token: Optional[str] = None,
280
+ hf_tokenizer: Optional[str] = None,
281
+ hf_revision: Optional[str] = None,
282
+ hf_trust_remote_code: Optional[bool] = False,
283
+ pip_requirements: Optional[list[str]] = None,
284
+ conda_dependencies: Optional[list[str]] = None,
285
+ target_platforms: Optional[list[str]] = None,
286
+ comment: Optional[str] = None,
287
+ warehouse: Optional[str] = None,
288
+ **kwargs: Any,
289
+ ) -> "ModelDeploymentSpec":
290
+ """Add Hugging Face logger specification.
291
+
292
+ Args:
293
+ hf_model_name: Hugging Face model name.
294
+ hf_task: Hugging Face task.
295
+ hf_token: Hugging Face token.
296
+ hf_tokenizer: Hugging Face tokenizer.
297
+ hf_revision: Hugging Face model revision.
298
+ hf_trust_remote_code: Whether to trust remote code.
299
+ pip_requirements: List of pip requirements.
300
+ conda_dependencies: List of conda dependencies.
301
+ target_platforms: List of target platforms.
302
+ comment: Comment for the model.
303
+ warehouse: Warehouse used to log the model.
304
+ **kwargs: Additional Hugging Face model arguments.
305
+
306
+ Raises:
307
+ ValueError: If Hugging Face model name is missing when other HF parameters are provided.
308
+
309
+ Returns:
310
+ Self for chaining.
311
+ """
312
+ # Validation moved here from save
313
+ if (
314
+ any(
315
+ [
316
+ hf_task,
317
+ hf_token,
318
+ hf_tokenizer,
319
+ hf_revision,
320
+ hf_trust_remote_code,
321
+ pip_requirements,
322
+ ]
173
323
  )
174
- service = model_deployment_spec_schema.Service(
175
- name=fq_service_name,
176
- compute_pool=inference_compute_pool_name.identifier(),
177
- ingress_enabled=ingress_enabled,
178
- max_instances=max_instances,
179
- **base_inference_spec,
324
+ and not hf_model_name
325
+ ):
326
+ # This condition might be redundant now as hf_model_name is mandatory
327
+ raise ValueError("Hugging Face model name is required when using Hugging Face model deployment.")
328
+
329
+ log_model_args = model_deployment_spec_schema.LogModelArgs(
330
+ pip_requirements=pip_requirements,
331
+ conda_dependencies=conda_dependencies,
332
+ target_platforms=target_platforms,
333
+ comment=comment,
334
+ warehouse=warehouse,
335
+ )
336
+ hf_model = model_deployment_spec_schema.HuggingFaceModel(
337
+ hf_model_name=hf_model_name,
338
+ task=hf_task,
339
+ hf_token=hf_token,
340
+ tokenizer=hf_tokenizer,
341
+ trust_remote_code=hf_trust_remote_code,
342
+ revision=hf_revision,
343
+ hf_model_kwargs=json.dumps(kwargs),
344
+ )
345
+ model_logging = model_deployment_spec_schema.ModelLogging(
346
+ log_model_args=log_model_args,
347
+ hf_model=hf_model,
348
+ )
349
+ if self._model_loggings is None:
350
+ self._model_loggings = [model_logging]
351
+ else:
352
+ self._model_loggings.append(model_logging)
353
+ return self
354
+
355
+ def save(self) -> str:
356
+ """Constructs the final deployment spec from added components and saves it.
357
+
358
+ Raises:
359
+ ValueError: If required components are missing or conflicting specs are added.
360
+ RuntimeError: If no service or job spec is found despite validation.
361
+
362
+ Returns:
363
+ The path to the saved YAML file as a string, or the YAML content as a string
364
+ if workspace_path was not provided.
365
+ """
366
+ # Validations
367
+ if not self._models:
368
+ raise ValueError("Model specification is required. Call add_model_spec().")
369
+ if not self._image_build:
370
+ raise ValueError("Image build specification is required. Call add_image_build_spec().")
371
+ if not self._service and not self._job:
372
+ raise ValueError(
373
+ "Either service or job specification is required. Call add_service_spec() or add_job_spec()."
180
374
  )
375
+ if self._service and self._job:
376
+ # This case should be prevented by checks in add_service_spec/add_job_spec, but double-check
377
+ raise ValueError("Cannot have both service and job specifications.")
181
378
 
182
- # model deployment spec
379
+ # Construct the final spec object
380
+ if self._service:
183
381
  model_deployment_spec: Union[
184
382
  model_deployment_spec_schema.ModelServiceDeploymentSpec,
185
383
  model_deployment_spec_schema.ModelJobDeploymentSpec,
186
384
  ] = model_deployment_spec_schema.ModelServiceDeploymentSpec(
187
- models=[model],
188
- image_build=image_build,
189
- service=service,
190
- )
191
- else: # job spec
192
- assert job_name, "job_name is required for job spec"
193
- assert warehouse, "warehouse is required for job spec"
194
- assert target_method, "target_method is required for job spec"
195
- assert input_table_name, "input_table_name is required for job spec"
196
- assert output_table_name, "output_table_name is required for job spec"
197
- saved_job_database = job_database_name or database_name
198
- saved_job_schema = job_schema_name or schema_name
199
- input_table_database_name = input_table_database_name or database_name
200
- input_table_schema_name = input_table_schema_name or schema_name
201
- output_table_database_name = output_table_database_name or database_name
202
- output_table_schema_name = output_table_schema_name or schema_name
203
- fq_job_name = identifier.get_schema_level_object_identifier(
204
- saved_job_database.identifier(), saved_job_schema.identifier(), job_name.identifier()
205
- )
206
- fq_input_table_name = identifier.get_schema_level_object_identifier(
207
- input_table_database_name.identifier(),
208
- input_table_schema_name.identifier(),
209
- input_table_name.identifier(),
210
- )
211
- fq_output_table_name = identifier.get_schema_level_object_identifier(
212
- output_table_database_name.identifier(),
213
- output_table_schema_name.identifier(),
214
- output_table_name.identifier(),
385
+ models=self._models,
386
+ image_build=self._image_build,
387
+ service=self._service,
388
+ model_loggings=self._model_loggings,
215
389
  )
216
- job = model_deployment_spec_schema.Job(
217
- name=fq_job_name,
218
- compute_pool=inference_compute_pool_name.identifier(),
219
- warehouse=warehouse.identifier(),
220
- target_method=target_method,
221
- input_table_name=fq_input_table_name,
222
- output_table_name=fq_output_table_name,
223
- **base_inference_spec,
224
- )
225
-
226
- # model deployment spec
390
+ elif self._job:
227
391
  model_deployment_spec = model_deployment_spec_schema.ModelJobDeploymentSpec(
228
- models=[model],
229
- image_build=image_build,
230
- job=job,
392
+ models=self._models,
393
+ image_build=self._image_build,
394
+ job=self._job,
395
+ model_loggings=self._model_loggings,
231
396
  )
397
+ else:
398
+ # Should not happen due to earlier validation
399
+ raise RuntimeError("Internal error: No service or job spec found despite validation.")
400
+
401
+ # Serialize and save/return
402
+ yaml_content = model_deployment_spec.model_dump(exclude_none=True)
232
403
 
233
404
  if self.workspace_path is None:
234
- return yaml.safe_dump(model_deployment_spec.model_dump(exclude_none=True))
405
+ return yaml.safe_dump(yaml_content)
235
406
 
236
- # save the yaml
237
407
  file_path = self.workspace_path / self.DEPLOY_SPEC_FILE_REL_PATH
238
408
  with file_path.open("w", encoding="utf-8") as f:
239
- yaml.safe_dump(model_deployment_spec.model_dump(exclude_none=True), f)
409
+ yaml.safe_dump(yaml_content, f)
240
410
  return str(file_path.resolve())
@@ -41,13 +41,38 @@ class Job(BaseModel):
41
41
  output_table_name: str
42
42
 
43
43
 
44
+ class LogModelArgs(BaseModel):
45
+ pip_requirements: Optional[list[str]] = None
46
+ conda_dependencies: Optional[list[str]] = None
47
+ target_platforms: Optional[list[str]] = None
48
+ comment: Optional[str] = None
49
+ warehouse: Optional[str] = None
50
+
51
+
52
+ class HuggingFaceModel(BaseModel):
53
+ hf_model_name: str
54
+ task: Optional[str] = None
55
+ tokenizer: Optional[str] = None
56
+ hf_token: Optional[str] = None
57
+ trust_remote_code: Optional[bool] = False
58
+ revision: Optional[str] = None
59
+ hf_model_kwargs: Optional[str] = "{}"
60
+
61
+
62
+ class ModelLogging(BaseModel):
63
+ log_model_args: Optional[LogModelArgs] = None
64
+ hf_model: Optional[HuggingFaceModel] = None
65
+
66
+
44
67
  class ModelServiceDeploymentSpec(BaseModel):
45
68
  models: list[Model]
46
69
  image_build: ImageBuild
47
70
  service: Service
71
+ model_loggings: Optional[list[ModelLogging]] = None
48
72
 
49
73
 
50
74
  class ModelJobDeploymentSpec(BaseModel):
51
75
  models: list[Model]
52
76
  image_build: ImageBuild
53
77
  job: Job
78
+ model_loggings: Optional[list[ModelLogging]] = None
@@ -4,7 +4,6 @@ import textwrap
4
4
  from typing import Any, Optional, Union
5
5
 
6
6
  from snowflake import snowpark
7
- from snowflake.ml._internal import platform_capabilities
8
7
  from snowflake.ml._internal.utils import (
9
8
  identifier,
10
9
  query_result_checker,
@@ -133,18 +132,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
133
132
  input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
134
133
  args_sql = f"object_construct_keep_null({input_args_sql})"
135
134
 
136
- if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
137
- fully_qualified_service_name = self.fully_qualified_object_name(
138
- actual_database_name, actual_schema_name, service_name
139
- )
140
- fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
141
- else:
142
- function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
143
- fully_qualified_function_name = identifier.get_schema_level_object_identifier(
144
- actual_database_name.identifier(),
145
- actual_schema_name.identifier(),
146
- function_name,
147
- )
135
+ fully_qualified_service_name = self.fully_qualified_object_name(
136
+ actual_database_name, actual_schema_name, service_name
137
+ )
138
+ fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
148
139
 
149
140
  sql = textwrap.dedent(
150
141
  f"""{with_sql}