snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,15 @@
1
+ import enum
2
+ import json
1
3
  import textwrap
2
4
  from typing import Any, Dict, List, Optional, Tuple
3
5
 
6
+ from packaging import version
7
+
8
+ from snowflake import snowpark
4
9
  from snowflake.ml._internal.utils import (
5
10
  identifier,
6
11
  query_result_checker,
12
+ snowflake_env,
7
13
  sql_identifier,
8
14
  )
9
15
  from snowflake.ml.model._client.sql import _base
@@ -11,6 +17,17 @@ from snowflake.snowpark import dataframe, functions as F, types as spt
11
17
  from snowflake.snowpark._internal import utils as snowpark_utils
12
18
 
13
19
 
20
+ class ServiceStatus(enum.Enum):
21
+ UNKNOWN = "UNKNOWN" # status is unknown because we have not received enough data from K8s yet.
22
+ PENDING = "PENDING" # resource set is being created, can't be used yet
23
+ READY = "READY" # resource set has been deployed.
24
+ DELETING = "DELETING" # resource set is being deleted
25
+ FAILED = "FAILED" # resource set has failed and cannot be used anymore
26
+ DONE = "DONE" # resource set has finished running
27
+ NOT_FOUND = "NOT_FOUND" # not found or deleted
28
+ INTERNAL_ERROR = "INTERNAL_ERROR" # there was an internal service error.
29
+
30
+
14
31
  class ServiceSQLClient(_base._BaseSQLClient):
15
32
  def build_model_container(
16
33
  self,
@@ -30,20 +47,21 @@ class ServiceSQLClient(_base._BaseSQLClient):
30
47
  ) -> None:
31
48
  actual_image_repo_database = image_repo_database_name or self._database_name
32
49
  actual_image_repo_schema = image_repo_schema_name or self._schema_name
33
- fq_model_name = self.fully_qualified_object_name(database_name, schema_name, model_name)
34
- fq_image_repo_name = "/" + "/".join(
35
- [
36
- actual_image_repo_database.identifier(),
37
- actual_image_repo_schema.identifier(),
38
- image_repo_name.identifier(),
39
- ]
50
+ actual_model_database = database_name or self._database_name
51
+ actual_model_schema = schema_name or self._schema_name
52
+ fq_model_name = self.fully_qualified_object_name(actual_model_database, actual_model_schema, model_name)
53
+ fq_image_repo_name = identifier.get_schema_level_object_identifier(
54
+ actual_image_repo_database.identifier(),
55
+ actual_image_repo_schema.identifier(),
56
+ image_repo_name.identifier(),
40
57
  )
41
- is_gpu = gpu is not None
58
+ is_gpu_str = "TRUE" if gpu else "FALSE"
59
+ force_rebuild_str = "TRUE" if force_rebuild else "FALSE"
42
60
  query_result_checker.SqlResultValidator(
43
61
  self._session,
44
62
  (
45
63
  f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
46
- f" '{fq_image_repo_name}', '{is_gpu}', '{force_rebuild}', '', '{external_access_integration}')"
64
+ f" '{fq_image_repo_name}', '{is_gpu_str}', '{force_rebuild_str}', '', '{external_access_integration}')"
47
65
  ),
48
66
  statement_params=statement_params,
49
67
  ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -54,12 +72,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
54
72
  stage_path: str,
55
73
  model_deployment_spec_file_rel_path: str,
56
74
  statement_params: Optional[Dict[str, Any]] = None,
57
- ) -> None:
58
- query_result_checker.SqlResultValidator(
59
- self._session,
60
- f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')",
61
- statement_params=statement_params,
62
- ).has_dimensions(expected_rows=1, expected_cols=1).validate()
75
+ ) -> Tuple[str, snowpark.AsyncJob]:
76
+ async_job = self._session.sql(
77
+ f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
78
+ ).collect(block=False, statement_params=statement_params)
79
+ assert isinstance(async_job, snowpark.AsyncJob)
80
+ return async_job.query_id, async_job
63
81
 
64
82
  def invoke_function_method(
65
83
  self,
@@ -74,12 +92,13 @@ class ServiceSQLClient(_base._BaseSQLClient):
74
92
  statement_params: Optional[Dict[str, Any]] = None,
75
93
  ) -> dataframe.DataFrame:
76
94
  with_statements = []
95
+ actual_database_name = database_name or self._database_name
96
+ actual_schema_name = schema_name or self._schema_name
97
+
77
98
  if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
78
- INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
99
+ INTERMEDIATE_TABLE_NAME = ServiceSQLClient.get_tmp_name_with_prefix("SNOWPARK_ML_MODEL_INFERENCE_INPUT")
79
100
  with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
80
101
  else:
81
- actual_database_name = database_name or self._database_name
82
- actual_schema_name = schema_name or self._schema_name
83
102
  tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
84
103
  INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
85
104
  actual_database_name.identifier(),
@@ -93,7 +112,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
93
112
  statement_params=statement_params,
94
113
  )
95
114
 
96
- INTERMEDIATE_OBJ_NAME = "TMP_RESULT"
115
+ INTERMEDIATE_OBJ_NAME = ServiceSQLClient.get_tmp_name_with_prefix("TMP_RESULT")
97
116
 
98
117
  with_sql = f"WITH {','.join(with_statements)}" if with_statements else ""
99
118
  args_sql_list = []
@@ -101,10 +120,26 @@ class ServiceSQLClient(_base._BaseSQLClient):
101
120
  args_sql_list.append(input_arg_value)
102
121
  args_sql = ", ".join(args_sql_list)
103
122
 
123
+ if snowflake_env.get_current_snowflake_version(
124
+ self._session, statement_params=statement_params
125
+ ) >= version.parse("8.39.0"):
126
+ fully_qualified_service_name = self.fully_qualified_object_name(
127
+ actual_database_name, actual_schema_name, service_name
128
+ )
129
+ fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
130
+
131
+ else:
132
+ function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
133
+ fully_qualified_function_name = identifier.get_schema_level_object_identifier(
134
+ actual_database_name.identifier(),
135
+ actual_schema_name.identifier(),
136
+ function_name,
137
+ )
138
+
104
139
  sql = textwrap.dedent(
105
140
  f"""{with_sql}
106
141
  SELECT *,
107
- {service_name.identifier()}_{method_name.identifier()}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
142
+ {fully_qualified_function_name}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
108
143
  FROM {INTERMEDIATE_TABLE_NAME}"""
109
144
  )
110
145
 
@@ -127,3 +162,69 @@ class ServiceSQLClient(_base._BaseSQLClient):
127
162
  output_df._statement_params = statement_params # type: ignore[assignment]
128
163
 
129
164
  return output_df
165
+
166
+ def get_service_logs(
167
+ self,
168
+ *,
169
+ database_name: Optional[sql_identifier.SqlIdentifier],
170
+ schema_name: Optional[sql_identifier.SqlIdentifier],
171
+ service_name: sql_identifier.SqlIdentifier,
172
+ instance_id: str = "0",
173
+ container_name: str,
174
+ statement_params: Optional[Dict[str, Any]] = None,
175
+ ) -> str:
176
+ system_func = "SYSTEM$GET_SERVICE_LOGS"
177
+ rows = (
178
+ query_result_checker.SqlResultValidator(
179
+ self._session,
180
+ (
181
+ f"CALL {system_func}("
182
+ f"'{self.fully_qualified_object_name(database_name, schema_name, service_name)}', '{instance_id}', "
183
+ f"'{container_name}')"
184
+ ),
185
+ statement_params=statement_params,
186
+ )
187
+ .has_dimensions(expected_rows=1, expected_cols=1)
188
+ .validate()
189
+ )
190
+ return str(rows[0][system_func])
191
+
192
+ def get_service_status(
193
+ self,
194
+ *,
195
+ database_name: Optional[sql_identifier.SqlIdentifier],
196
+ schema_name: Optional[sql_identifier.SqlIdentifier],
197
+ service_name: sql_identifier.SqlIdentifier,
198
+ include_message: bool = False,
199
+ statement_params: Optional[Dict[str, Any]] = None,
200
+ ) -> Tuple[ServiceStatus, Optional[str]]:
201
+ system_func = "SYSTEM$GET_SERVICE_STATUS"
202
+ rows = (
203
+ query_result_checker.SqlResultValidator(
204
+ self._session,
205
+ f"CALL {system_func}('{self.fully_qualified_object_name(database_name, schema_name, service_name)}')",
206
+ statement_params=statement_params,
207
+ )
208
+ .has_dimensions(expected_rows=1, expected_cols=1)
209
+ .validate()
210
+ )
211
+ metadata = json.loads(rows[0][system_func])[0]
212
+ if metadata and metadata["status"]:
213
+ service_status = ServiceStatus(metadata["status"])
214
+ message = metadata["message"] if include_message else None
215
+ return service_status, message
216
+ return ServiceStatus.UNKNOWN, None
217
+
218
+ def drop_service(
219
+ self,
220
+ *,
221
+ database_name: Optional[sql_identifier.SqlIdentifier],
222
+ schema_name: Optional[sql_identifier.SqlIdentifier],
223
+ service_name: sql_identifier.SqlIdentifier,
224
+ statement_params: Optional[Dict[str, Any]] = None,
225
+ ) -> None:
226
+ query_result_checker.SqlResultValidator(
227
+ self._session,
228
+ f"DROP SERVICE {self.fully_qualified_object_name(database_name, schema_name, service_name)}",
229
+ statement_params=statement_params,
230
+ ).has_dimensions(expected_rows=1, expected_cols=1).validate()
@@ -1,14 +1,11 @@
1
- import glob
2
1
  import pathlib
3
2
  import tempfile
4
3
  import uuid
5
- import zipfile
6
4
  from types import ModuleType
7
5
  from typing import Any, Dict, List, Optional
8
6
 
9
7
  from absl import logging
10
8
  from packaging import requirements
11
- from typing_extensions import deprecated
12
9
 
13
10
  from snowflake import snowpark
14
11
  from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
@@ -92,6 +89,7 @@ class ModelComposer:
92
89
  python_version: Optional[str] = None,
93
90
  ext_modules: Optional[List[ModuleType]] = None,
94
91
  code_paths: Optional[List[str]] = None,
92
+ task: model_types.Task = model_types.Task.UNKNOWN,
95
93
  options: Optional[model_types.ModelSaveOption] = None,
96
94
  ) -> model_meta.ModelMetadata:
97
95
  if not options:
@@ -120,24 +118,20 @@ class ModelComposer:
120
118
  python_version=python_version,
121
119
  ext_modules=ext_modules,
122
120
  code_paths=code_paths,
121
+ task=task,
123
122
  options=options,
124
123
  )
125
124
  assert self.packager.meta is not None
126
125
 
127
- if not options.get("_legacy_save", False):
128
- # Keep both loose files and zipped file.
129
- # TODO(SNOW-726678): Remove once import a directory is possible.
130
- file_utils.copytree(
131
- str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
132
- )
133
- self.manifest.save(
134
- model_meta=self.packager.meta,
135
- model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
136
- options=options,
137
- data_sources=self._get_data_sources(model, sample_input_data),
138
- )
139
- else:
140
- file_utils.make_archive(self.model_local_path, str(self._packager_workspace_path))
126
+ file_utils.copytree(
127
+ str(self._packager_workspace_path), str(self.workspace_path / ModelComposer.MODEL_DIR_REL_PATH)
128
+ )
129
+ self.manifest.save(
130
+ model_meta=self.packager.meta,
131
+ model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
132
+ options=options,
133
+ data_sources=self._get_data_sources(model, sample_input_data),
134
+ )
141
135
 
142
136
  file_utils.upload_directory_to_stage(
143
137
  self.session,
@@ -147,28 +141,6 @@ class ModelComposer:
147
141
  )
148
142
  return model_metadata
149
143
 
150
- @deprecated("Only used by PrPr model registry. Use static method version of load instead.")
151
- def legacy_load(
152
- self,
153
- *,
154
- meta_only: bool = False,
155
- options: Optional[model_types.ModelLoadOption] = None,
156
- ) -> None:
157
- file_utils.download_directory_from_stage(
158
- self.session,
159
- stage_path=self.stage_path,
160
- local_path=self.workspace_path,
161
- statement_params=self._statement_params,
162
- )
163
-
164
- # TODO (Server-side Model Rollout): Remove this section.
165
- model_zip_path = pathlib.Path(glob.glob(str(self.workspace_path / "*.zip"))[0])
166
- self.model_file_rel_path = str(model_zip_path.relative_to(self.workspace_path))
167
-
168
- with zipfile.ZipFile(self.model_local_path, mode="r", compression=zipfile.ZIP_DEFLATED) as zf:
169
- zf.extractall(path=self._packager_workspace_path)
170
- self.packager.load(meta_only=meta_only, options=options)
171
-
172
144
  @staticmethod
173
145
  def load(
174
146
  workspace_path: pathlib.Path,
@@ -1,5 +1,5 @@
1
1
  import collections
2
- import copy
2
+ import logging
3
3
  import pathlib
4
4
  import warnings
5
5
  from typing import List, Optional, cast
@@ -18,6 +18,9 @@ from snowflake.ml.model._packager.model_meta import (
18
18
  model_meta as model_meta_api,
19
19
  model_meta_schema,
20
20
  )
21
+ from snowflake.ml.model._packager.model_runtime import model_runtime
22
+
23
+ logger = logging.getLogger(__name__)
21
24
 
22
25
 
23
26
  class ModelManifest:
@@ -45,9 +48,30 @@ class ModelManifest:
45
48
  if options is None:
46
49
  options = {}
47
50
 
48
- runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
49
- runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
50
- runtime_to_use.imports.append(str(model_rel_path) + "/")
51
+ if "relax_version" not in options:
52
+ warnings.warn(
53
+ (
54
+ "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed"
55
+ " from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, "
56
+ "reproducibility, etc., set `options={'relax_version': False}` when logging the model."
57
+ ),
58
+ category=UserWarning,
59
+ stacklevel=2,
60
+ )
61
+ relax_version = options.get("relax_version", True)
62
+
63
+ runtime_to_use = model_runtime.ModelRuntime(
64
+ name=self._DEFAULT_RUNTIME_NAME,
65
+ env=model_meta.env,
66
+ imports=[str(model_rel_path) + "/"],
67
+ is_gpu=False,
68
+ is_warehouse=True,
69
+ )
70
+ if relax_version:
71
+ runtime_to_use.runtime_env.relax_version()
72
+ logger.info("Relaxing version constraints for dependencies in the model.")
73
+ logger.info(f"Conda dependencies: {runtime_to_use.runtime_env.conda_dependencies}")
74
+ logger.info(f"Pip requirements: {runtime_to_use.runtime_env.pip_requirements}")
51
75
  runtime_dict = runtime_to_use.save(
52
76
  self.workspace_path, default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
53
77
  )
@@ -78,13 +102,9 @@ class ModelManifest:
78
102
  )
79
103
 
80
104
  dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
81
- if options.get("include_pip_dependencies"):
82
- warnings.warn(
83
- "`include_pip_dependencies` specified as True: pip dependencies will be included and may not"
84
- "be warehouse-compabible. The model may need to be run in SPCS.",
85
- category=UserWarning,
86
- stacklevel=1,
87
- )
105
+
106
+ # We only want to include pip dependencies file if there are any pip requirements.
107
+ if len(model_meta.env.pip_requirements) > 0:
88
108
  dependencies["pip"] = runtime_dict["dependencies"]["pip"]
89
109
 
90
110
  manifest_dict = model_manifest_schema.ModelManifestDict(
@@ -21,7 +21,7 @@ _DEFAULT_PIP_REQUIREMENTS_FILENAME = "requirements.txt"
21
21
  # The default CUDA version is chosen based on the driver availability in SPCS.
22
22
  # If changing this version, we need also change the version of default PyTorch in HuggingFace pipeline handler to
23
23
  # make sure they are compatible.
24
- DEFAULT_CUDA_VERSION = "11.7"
24
+ DEFAULT_CUDA_VERSION = "11.8"
25
25
 
26
26
 
27
27
  class ModelEnv:
@@ -199,50 +199,16 @@ class ModelEnv:
199
199
  )
200
200
  if xgboost_spec:
201
201
  self.include_if_absent(
202
- [
203
- ModelDependency(
204
- requirement=f"conda-forge::py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost"
205
- )
206
- ],
202
+ [ModelDependency(requirement=f"py-xgboost-gpu{xgboost_spec.specifier}", pip_name="xgboost")],
207
203
  check_local_version=False,
208
204
  )
209
205
 
210
- pytorch_spec = env_utils.find_dep_spec(
211
- self._conda_dependencies,
212
- self._pip_requirements,
213
- conda_pkg_name="pytorch",
214
- pip_pkg_name="torch",
215
- remove_spec=True,
216
- )
217
- pytorch_cuda_spec = env_utils.find_dep_spec(
218
- self._conda_dependencies,
219
- self._pip_requirements,
220
- conda_pkg_name="pytorch-cuda",
221
- remove_spec=False,
222
- )
223
- if pytorch_cuda_spec and not pytorch_cuda_spec.specifier.contains(self.cuda_version):
224
- raise ValueError(
225
- "The Pytorch-CUDA requirement you specified in your conda dependencies or pip requirements is"
226
- " conflicting with CUDA version required. Please do not specify Pytorch-CUDA dependency using conda"
227
- " dependencies or pip requirements."
228
- )
229
- if pytorch_spec:
230
- self.include_if_absent(
231
- [ModelDependency(requirement=f"pytorch::pytorch{pytorch_spec.specifier}", pip_name="torch")],
232
- check_local_version=False,
233
- )
234
- if not pytorch_cuda_spec:
235
- self.include_if_absent(
236
- [ModelDependency(requirement=f"pytorch::pytorch-cuda=={self.cuda_version}.*", pip_name="torch")],
237
- check_local_version=False,
238
- )
239
-
240
206
  tf_spec = env_utils.find_dep_spec(
241
207
  self._conda_dependencies, self._pip_requirements, conda_pkg_name="tensorflow", remove_spec=True
242
208
  )
243
209
  if tf_spec:
244
210
  self.include_if_absent(
245
- [ModelDependency(requirement=f"conda-forge::tensorflow-gpu{tf_spec.specifier}", pip_name="tensorflow")],
211
+ [ModelDependency(requirement=f"tensorflow-gpu{tf_spec.specifier}", pip_name="tensorflow")],
246
212
  check_local_version=False,
247
213
  )
248
214
 
@@ -252,7 +218,7 @@ class ModelEnv:
252
218
  if transformers_spec:
253
219
  self.include_if_absent(
254
220
  [
255
- ModelDependency(requirement="conda-forge::accelerate>=0.22.0", pip_name="accelerate"),
221
+ ModelDependency(requirement="accelerate>=0.22.0", pip_name="accelerate"),
256
222
  ModelDependency(requirement="scipy>=1.9", pip_name="scipy"),
257
223
  ],
258
224
  check_local_version=False,
@@ -1,20 +1,54 @@
1
1
  import json
2
- from typing import Any, Callable, Iterable, Optional, Sequence, cast
2
+ import os
3
+ import warnings
4
+ from typing import Any, Callable, Iterable, List, Optional, Sequence, cast
3
5
 
4
6
  import numpy as np
5
7
  import numpy.typing as npt
6
8
  import pandas as pd
9
+ from absl import logging
7
10
 
11
+ import snowflake.snowpark.dataframe as sp_df
12
+ from snowflake.ml._internal.utils import identifier
8
13
  from snowflake.ml.model import model_signature, type_hints as model_types
9
14
  from snowflake.ml.model._packager.model_meta import model_meta
10
- from snowflake.ml.model._signatures import snowpark_handler
15
+ from snowflake.ml.model._signatures import (
16
+ core,
17
+ snowpark_handler,
18
+ utils as model_signature_utils,
19
+ )
11
20
  from snowflake.snowpark import DataFrame as SnowparkDataFrame
12
21
 
22
+ EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT = 1000
23
+
24
+
25
+ class NumpyEncoder(json.JSONEncoder):
26
+ def default(self, obj: Any) -> Any:
27
+ if isinstance(obj, np.integer):
28
+ return int(obj)
29
+ if isinstance(obj, np.floating):
30
+ return float(obj)
31
+ if isinstance(obj, np.ndarray):
32
+ return obj.tolist()
33
+ return super().default(obj)
34
+
13
35
 
14
36
  def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool:
15
37
  return callable(getattr(model, method_name, None))
16
38
 
17
39
 
40
+ def get_truncated_sample_data(sample_input_data: model_types.SupportedDataType) -> model_types.SupportedLocalDataType:
41
+ trunc_sample_input = model_signature._truncate_data(sample_input_data)
42
+ local_sample_input: model_types.SupportedLocalDataType = None
43
+ if isinstance(sample_input_data, SnowparkDataFrame):
44
+ # Added because of Any from missing stubs.
45
+ trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
46
+ local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
47
+ else:
48
+ local_sample_input = trunc_sample_input
49
+ return local_sample_input
50
+
51
+
18
52
  def validate_signature(
19
53
  model: model_types.SupportedRequireSignatureModelType,
20
54
  model_meta: model_meta.ModelMetadata,
@@ -24,19 +58,23 @@ def validate_signature(
24
58
  ) -> model_meta.ModelMetadata:
25
59
  if model_meta.signatures:
26
60
  validate_target_methods(model, list(model_meta.signatures.keys()))
61
+ if sample_input_data is not None:
62
+ local_sample_input = get_truncated_sample_data(sample_input_data)
63
+ for target_method in model_meta.signatures.keys():
64
+
65
+ model_signature_inst = model_meta.signatures.get(target_method)
66
+ if model_signature_inst is not None:
67
+ # strict validation the input signature
68
+ model_signature._convert_and_validate_local_data(
69
+ local_sample_input, model_signature_inst._inputs, True
70
+ )
27
71
  return model_meta
28
72
 
29
73
  # In this case sample_input_data should be available, because of the check in save_model.
30
74
  assert (
31
75
  sample_input_data is not None
32
76
  ), "Model signature and sample input are None at the same time. This should not happen with local model."
33
- trunc_sample_input = model_signature._truncate_data(sample_input_data)
34
- if isinstance(sample_input_data, SnowparkDataFrame):
35
- # Added because of Any from missing stubs.
36
- trunc_sample_input = cast(SnowparkDataFrame, trunc_sample_input)
37
- local_sample_input = snowpark_handler.SnowparkDataFrameHandler.convert_to_df(trunc_sample_input)
38
- else:
39
- local_sample_input = trunc_sample_input
77
+ local_sample_input = get_truncated_sample_data(sample_input_data)
40
78
  for target_method in target_methods:
41
79
  predictions_df = get_prediction_fn(target_method, local_sample_input)
42
80
  sig = model_signature.infer_signature(local_sample_input, predictions_df)
@@ -45,24 +83,55 @@ def validate_signature(
45
83
  return model_meta
46
84
 
47
85
 
86
+ def get_input_signature(
87
+ model_meta: model_meta.ModelMetadata, target_method: Optional[str]
88
+ ) -> Sequence[core.BaseFeatureSpec]:
89
+ if target_method is None or target_method not in model_meta.signatures:
90
+ raise ValueError(f"Signature for target method {target_method} is missing or no method to explain.")
91
+ input_sig = model_meta.signatures[target_method].inputs
92
+ return input_sig
93
+
94
+
48
95
  def add_explain_method_signature(
49
96
  model_meta: model_meta.ModelMetadata,
50
97
  explain_method: str,
51
- target_method: str,
98
+ target_method: Optional[str],
52
99
  output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
53
100
  ) -> model_meta.ModelMetadata:
54
- if target_method not in model_meta.signatures:
55
- raise ValueError(f"Signature for target method {target_method} is missing")
56
- inputs = model_meta.signatures[target_method].inputs
101
+ inputs = get_input_signature(model_meta, target_method)
102
+ if model_meta.model_type == "snowml":
103
+ output_feature_names = [identifier.concat_names([spec.name, "_explanation"]) for spec in inputs]
104
+ else:
105
+ output_feature_names = [f"{spec.name}_explanation" for spec in inputs]
57
106
  model_meta.signatures[explain_method] = model_signature.ModelSignature(
58
107
  inputs=inputs,
59
108
  outputs=[
60
- model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs
109
+ model_signature.FeatureSpec(dtype=output_return_type, name=output_name)
110
+ for output_name in output_feature_names
61
111
  ],
62
112
  )
63
113
  return model_meta
64
114
 
65
115
 
116
+ def get_explainability_supported_background(
117
+ sample_input_data: Optional[model_types.SupportedDataType],
118
+ meta: model_meta.ModelMetadata,
119
+ explain_target_method: Optional[str],
120
+ ) -> pd.DataFrame:
121
+ if sample_input_data is None:
122
+ return None
123
+
124
+ if isinstance(sample_input_data, pd.DataFrame):
125
+ return sample_input_data
126
+ if isinstance(sample_input_data, sp_df.DataFrame):
127
+ return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
128
+
129
+ df = model_signature._convert_local_data_to_df(sample_input_data)
130
+ input_signature_for_explain = get_input_signature(meta, explain_target_method)
131
+ df_with_named_cols = model_signature_utils.rename_pandas_df(df, input_signature_for_explain)
132
+ return df_with_named_cols
133
+
134
+
66
135
  def get_target_methods(
67
136
  model: model_types.SupportedModelType,
68
137
  target_methods: Optional[Sequence[str]],
@@ -75,6 +144,23 @@ def get_target_methods(
75
144
  return target_methods
76
145
 
77
146
 
147
+ def save_background_data(
148
+ model_blobs_dir_path: str,
149
+ explain_artifact_dir: str,
150
+ bg_data_file_suffix: str,
151
+ model_name: str,
152
+ background_data: pd.DataFrame,
153
+ ) -> None:
154
+ data_blob_path = os.path.join(model_blobs_dir_path, explain_artifact_dir)
155
+ os.makedirs(data_blob_path, exist_ok=True)
156
+ with open(os.path.join(data_blob_path, model_name + bg_data_file_suffix), "wb") as f:
157
+ # saving only the truncated data
158
+ trunc_background_data = background_data.head(
159
+ min(len(background_data.index), EXPLAIN_BACKGROUND_DATA_ROWS_COUNT_LIMIT)
160
+ )
161
+ trunc_background_data.to_parquet(f)
162
+
163
+
78
164
  def validate_target_methods(model: model_types.SupportedModelType, target_methods: Iterable[str]) -> None:
79
165
  for method_name in target_methods:
80
166
  if not _is_callable(model, method_name):
@@ -93,23 +179,43 @@ def convert_explanations_to_2D_df(
93
179
  return pd.DataFrame(explanations)
94
180
 
95
181
  if hasattr(model, "classes_"):
96
- classes_list = [cl for cl in model.classes_] # type:ignore[union-attr]
182
+ classes_list = [str(cl) for cl in model.classes_] # type:ignore[union-attr]
97
183
  len_classes = len(classes_list)
98
184
  if explanations.shape[2] != len_classes:
99
185
  raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
100
186
  else:
101
- classes_list = [i for i in range(explanations.shape[2])]
102
- exp_2d = []
103
- # TODO (SNOW-1549044): Optimize this
104
- for row in explanations:
105
- col_list = []
106
- for column in row:
107
- class_explanations = {}
108
- for cl, cl_exp in zip(classes_list, column):
109
- if isinstance(cl, (int, np.integer)):
110
- cl = int(cl)
111
- class_explanations[cl] = cl_exp
112
- col_list.append(json.dumps(class_explanations))
113
- exp_2d.append(col_list)
187
+ classes_list = [str(i) for i in range(explanations.shape[2])]
188
+
189
+ def row_to_dict(row: npt.NDArray[Any]) -> npt.NDArray[Any]:
190
+ """Converts a single row to a dictionary."""
191
+ # convert to object or numpy creates strings of fixed length
192
+ return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
193
+
194
+ exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
114
195
 
115
196
  return pd.DataFrame(exp_2d)
197
+
198
+
199
+ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task: model_types.Task) -> model_types.Task:
200
+ if passed_model_task != model_types.Task.UNKNOWN and inferred_model_task != model_types.Task.UNKNOWN:
201
+ if passed_model_task != inferred_model_task:
202
+ warnings.warn(
203
+ f"Inferred Task: {inferred_model_task.name} is used as task for this model "
204
+ f"version and passed argument Task: {passed_model_task.name} is ignored",
205
+ category=UserWarning,
206
+ stacklevel=1,
207
+ )
208
+ return inferred_model_task
209
+ elif inferred_model_task != model_types.Task.UNKNOWN:
210
+ logging.info(f"Inferred Task: {inferred_model_task.name} is used as task for this model " f"version")
211
+ return inferred_model_task
212
+ return passed_model_task
213
+
214
+
215
+ def get_explain_target_method(
216
+ model_metadata: model_meta.ModelMetadata, target_methods_list: List[str]
217
+ ) -> Optional[str]:
218
+ for method in model_metadata.signatures.keys():
219
+ if method in target_methods_list:
220
+ return method
221
+ return None