snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,7 @@ import enum
2
2
  import pathlib
3
3
  import tempfile
4
4
  import warnings
5
- from typing import Any, Callable, Dict, List, Optional, Union
5
+ from typing import Any, Callable, Dict, List, Optional, Union, overload
6
6
 
7
7
  import pandas as pd
8
8
 
@@ -10,7 +10,7 @@ from snowflake.ml._internal import telemetry
10
10
  from snowflake.ml._internal.utils import sql_identifier
11
11
  from snowflake.ml.lineage import lineage_node
12
12
  from snowflake.ml.model import type_hints as model_types
13
- from snowflake.ml.model._client.ops import metadata_ops, model_ops
13
+ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
15
15
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
16
16
  from snowflake.ml.model._packager.model_handlers import snowmlmodel
@@ -29,6 +29,7 @@ class ModelVersion(lineage_node.LineageNode):
29
29
  """Model Version Object representing a specific version of the model that could be run."""
30
30
 
31
31
  _model_ops: model_ops.ModelOperator
32
+ _service_ops: service_ops.ServiceOperator
32
33
  _model_name: sql_identifier.SqlIdentifier
33
34
  _version_name: sql_identifier.SqlIdentifier
34
35
  _functions: List[model_manifest_schema.ModelFunctionInfo]
@@ -41,11 +42,13 @@ class ModelVersion(lineage_node.LineageNode):
41
42
  cls,
42
43
  model_ops: model_ops.ModelOperator,
43
44
  *,
45
+ service_ops: service_ops.ServiceOperator,
44
46
  model_name: sql_identifier.SqlIdentifier,
45
47
  version_name: sql_identifier.SqlIdentifier,
46
48
  ) -> "ModelVersion":
47
49
  self: "ModelVersion" = object.__new__(cls)
48
50
  self._model_ops = model_ops
51
+ self._service_ops = service_ops
49
52
  self._model_name = model_name
50
53
  self._version_name = version_name
51
54
  self._functions = self._get_functions()
@@ -65,6 +68,7 @@ class ModelVersion(lineage_node.LineageNode):
65
68
  return False
66
69
  return (
67
70
  self._model_ops == __value._model_ops
71
+ and self._service_ops == __value._service_ops
68
72
  and self._model_name == __value._model_name
69
73
  and self._version_name == __value._version_name
70
74
  )
@@ -302,6 +306,23 @@ class ModelVersion(lineage_node.LineageNode):
302
306
  statement_params=statement_params,
303
307
  )
304
308
 
309
+ @telemetry.send_api_usage_telemetry(
310
+ project=_TELEMETRY_PROJECT,
311
+ subproject=_TELEMETRY_SUBPROJECT,
312
+ )
313
+ def get_model_objective(self) -> model_types.ModelObjective:
314
+ statement_params = telemetry.get_statement_params(
315
+ project=_TELEMETRY_PROJECT,
316
+ subproject=_TELEMETRY_SUBPROJECT,
317
+ )
318
+ return self._model_ops.get_model_objective(
319
+ database_name=None,
320
+ schema_name=None,
321
+ model_name=self._model_name,
322
+ version_name=self._version_name,
323
+ statement_params=statement_params,
324
+ )
325
+
305
326
  @telemetry.send_api_usage_telemetry(
306
327
  project=_TELEMETRY_PROJECT,
307
328
  subproject=_TELEMETRY_SUBPROJECT,
@@ -318,10 +339,7 @@ class ModelVersion(lineage_node.LineageNode):
318
339
  """
319
340
  return self._functions
320
341
 
321
- @telemetry.send_api_usage_telemetry(
322
- project=_TELEMETRY_PROJECT,
323
- subproject=_TELEMETRY_SUBPROJECT,
324
- )
342
+ @overload
325
343
  def run(
326
344
  self,
327
345
  X: Union[pd.DataFrame, dataframe.DataFrame],
@@ -339,6 +357,53 @@ class ModelVersion(lineage_node.LineageNode):
339
357
  partition_column: The partition column name to partition by.
340
358
  strict_input_validation: Enable stricter validation for the input data. This will result value range based
341
359
  type validation to make sure your input data won't overflow when providing to the model.
360
+ """
361
+ ...
362
+
363
+ @overload
364
+ def run(
365
+ self,
366
+ X: Union[pd.DataFrame, dataframe.DataFrame],
367
+ *,
368
+ service_name: str,
369
+ function_name: Optional[str] = None,
370
+ strict_input_validation: bool = False,
371
+ ) -> Union[pd.DataFrame, dataframe.DataFrame]:
372
+ """Invoke a method in a model version object via a service.
373
+
374
+ Args:
375
+ X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
376
+ service_name: The service name.
377
+ function_name: The function name to run. It is the name used to call a function in SQL.
378
+ strict_input_validation: Enable stricter validation for the input data. This will result value range based
379
+ type validation to make sure your input data won't overflow when providing to the model.
380
+ """
381
+ ...
382
+
383
+ @telemetry.send_api_usage_telemetry(
384
+ project=_TELEMETRY_PROJECT,
385
+ subproject=_TELEMETRY_SUBPROJECT,
386
+ func_params_to_log=["function_name", "service_name"],
387
+ )
388
+ def run(
389
+ self,
390
+ X: Union[pd.DataFrame, "dataframe.DataFrame"],
391
+ *,
392
+ service_name: Optional[str] = None,
393
+ function_name: Optional[str] = None,
394
+ partition_column: Optional[str] = None,
395
+ strict_input_validation: bool = False,
396
+ ) -> Union[pd.DataFrame, "dataframe.DataFrame"]:
397
+ """Invoke a method in a model version object via the warehouse or a service.
398
+
399
+ Args:
400
+ X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
401
+ service_name: The service name. If None, the function is invoked via the warehouse. Otherwise, the function
402
+ is invoked via the given service.
403
+ function_name: The function name to run. It is the name used to call a function in SQL.
404
+ partition_column: The partition column name to partition by.
405
+ strict_input_validation: Enable stricter validation for the input data. This will result value range based
406
+ type validation to make sure your input data won't overflow when providing to the model.
342
407
 
343
408
  Raises:
344
409
  ValueError: When no method with the corresponding name is available.
@@ -375,23 +440,37 @@ class ModelVersion(lineage_node.LineageNode):
375
440
  elif len(functions) != 1:
376
441
  raise ValueError(
377
442
  f"There are more than 1 target methods available in the model {self.fully_qualified_model_name}"
378
- f" version {self.version_name}. Please specify a `method_name` when calling the `run` method."
443
+ f" version {self.version_name}. Please specify a `function_name` when calling the `run` method."
379
444
  )
380
445
  else:
381
446
  target_function_info = functions[0]
382
- return self._model_ops.invoke_method(
383
- method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
384
- method_function_type=target_function_info["target_method_function_type"],
385
- signature=target_function_info["signature"],
386
- X=X,
387
- database_name=None,
388
- schema_name=None,
389
- model_name=self._model_name,
390
- version_name=self._version_name,
391
- strict_input_validation=strict_input_validation,
392
- partition_column=partition_column,
393
- statement_params=statement_params,
394
- )
447
+
448
+ if service_name:
449
+ return self._model_ops.invoke_method(
450
+ method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
451
+ signature=target_function_info["signature"],
452
+ X=X,
453
+ database_name=None,
454
+ schema_name=None,
455
+ service_name=sql_identifier.SqlIdentifier(service_name),
456
+ strict_input_validation=strict_input_validation,
457
+ statement_params=statement_params,
458
+ )
459
+ else:
460
+ return self._model_ops.invoke_method(
461
+ method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
462
+ method_function_type=target_function_info["target_method_function_type"],
463
+ signature=target_function_info["signature"],
464
+ X=X,
465
+ database_name=None,
466
+ schema_name=None,
467
+ model_name=self._model_name,
468
+ version_name=self._version_name,
469
+ strict_input_validation=strict_input_validation,
470
+ partition_column=partition_column,
471
+ statement_params=statement_params,
472
+ is_partitioned=target_function_info["is_partitioned"],
473
+ )
395
474
 
396
475
  @telemetry.send_api_usage_telemetry(
397
476
  project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
@@ -525,9 +604,96 @@ class ModelVersion(lineage_node.LineageNode):
525
604
  database_name=database_name_id,
526
605
  schema_name=schema_name_id,
527
606
  ),
607
+ service_ops=service_ops.ServiceOperator(
608
+ session,
609
+ database_name=database_name_id,
610
+ schema_name=schema_name_id,
611
+ ),
528
612
  model_name=model_name_id,
529
613
  version_name=sql_identifier.SqlIdentifier(version),
530
614
  )
531
615
 
616
+ @telemetry.send_api_usage_telemetry(
617
+ project=_TELEMETRY_PROJECT,
618
+ subproject=_TELEMETRY_SUBPROJECT,
619
+ func_params_to_log=[
620
+ "service_name",
621
+ "image_build_compute_pool",
622
+ "service_compute_pool",
623
+ "image_repo_database",
624
+ "image_repo_schema",
625
+ "image_repo",
626
+ "gpu_requests",
627
+ "num_workers",
628
+ ],
629
+ )
630
+ def create_service(
631
+ self,
632
+ *,
633
+ service_name: str,
634
+ image_build_compute_pool: Optional[str] = None,
635
+ service_compute_pool: str,
636
+ image_repo: str,
637
+ ingress_enabled: bool = False,
638
+ max_instances: int = 1,
639
+ gpu_requests: Optional[str] = None,
640
+ num_workers: Optional[int] = None,
641
+ force_rebuild: bool = False,
642
+ build_external_access_integration: str,
643
+ ) -> str:
644
+ """Create an inference service with the given spec.
645
+
646
+ Args:
647
+ service_name: The name of the service, can be fully qualified. If not fully qualified, the database or
648
+ schema of the model will be used.
649
+ image_build_compute_pool: The name of the compute pool used to build the model inference image. Use
650
+ the service compute pool if None.
651
+ service_compute_pool: The name of the compute pool used to run the inference service.
652
+ image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
653
+ or schema of the model will be used.
654
+ ingress_enabled: Whether to enable ingress.
655
+ max_instances: The maximum number of inference service instances to run.
656
+ gpu_requests: The gpu limit for GPU based inference. Can be integer, fractional or string values. Use CPU
657
+ if None.
658
+ num_workers: The number of workers (replicas of models) to run the inference service.
659
+ Auto determined if None.
660
+ force_rebuild: Whether to force a model inference image rebuild.
661
+ build_external_access_integration: The external access integration for image build.
662
+
663
+ Returns:
664
+ The service name.
665
+ """
666
+ statement_params = telemetry.get_statement_params(
667
+ project=_TELEMETRY_PROJECT,
668
+ subproject=_TELEMETRY_SUBPROJECT,
669
+ )
670
+ service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
671
+ image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
672
+ return self._service_ops.create_service(
673
+ database_name=None,
674
+ schema_name=None,
675
+ model_name=self._model_name,
676
+ version_name=self._version_name,
677
+ service_database_name=service_db_id,
678
+ service_schema_name=service_schema_id,
679
+ service_name=service_id,
680
+ image_build_compute_pool_name=(
681
+ sql_identifier.SqlIdentifier(image_build_compute_pool)
682
+ if image_build_compute_pool
683
+ else sql_identifier.SqlIdentifier(service_compute_pool)
684
+ ),
685
+ service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
686
+ image_repo_database_name=image_repo_db_id,
687
+ image_repo_schema_name=image_repo_schema_id,
688
+ image_repo_name=image_repo_id,
689
+ ingress_enabled=ingress_enabled,
690
+ max_instances=max_instances,
691
+ gpu_requests=gpu_requests,
692
+ num_workers=num_workers,
693
+ force_rebuild=force_rebuild,
694
+ build_external_access_integration=sql_identifier.SqlIdentifier(build_external_access_integration),
695
+ statement_params=statement_params,
696
+ )
697
+
532
698
 
533
699
  lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
@@ -2,7 +2,7 @@ import os
2
2
  import pathlib
3
3
  import tempfile
4
4
  import warnings
5
- from typing import Any, Dict, List, Literal, Optional, Union, cast
5
+ from typing import Any, Dict, List, Literal, Optional, Union, cast, overload
6
6
 
7
7
  import yaml
8
8
 
@@ -12,6 +12,7 @@ from snowflake.ml.model._client.ops import metadata_ops
12
12
  from snowflake.ml.model._client.sql import (
13
13
  model as model_sql,
14
14
  model_version as model_version_sql,
15
+ service as service_sql,
15
16
  stage as stage_sql,
16
17
  tag as tag_sql,
17
18
  )
@@ -21,7 +22,7 @@ from snowflake.ml.model._model_composer.model_manifest import (
21
22
  model_manifest_schema,
22
23
  )
23
24
  from snowflake.ml.model._packager.model_env import model_env
24
- from snowflake.ml.model._packager.model_meta import model_meta
25
+ from snowflake.ml.model._packager.model_meta import model_meta, model_meta_schema
25
26
  from snowflake.ml.model._packager.model_runtime import model_runtime
26
27
  from snowflake.ml.model._signatures import snowpark_handler
27
28
  from snowflake.snowpark import dataframe, row, session
@@ -60,6 +61,11 @@ class ModelOperator:
60
61
  database_name=database_name,
61
62
  schema_name=schema_name,
62
63
  )
64
+ self._service_client = service_sql.ServiceSQLClient(
65
+ session,
66
+ database_name=database_name,
67
+ schema_name=schema_name,
68
+ )
63
69
  self._metadata_ops = metadata_ops.MetadataOperator(
64
70
  session,
65
71
  database_name=database_name,
@@ -548,15 +554,14 @@ class ModelOperator:
548
554
  res[function_name] = target_method
549
555
  return res
550
556
 
551
- def get_functions(
557
+ def _fetch_model_spec(
552
558
  self,
553
- *,
554
559
  database_name: Optional[sql_identifier.SqlIdentifier],
555
560
  schema_name: Optional[sql_identifier.SqlIdentifier],
556
561
  model_name: sql_identifier.SqlIdentifier,
557
562
  version_name: sql_identifier.SqlIdentifier,
558
563
  statement_params: Optional[Dict[str, Any]] = None,
559
- ) -> List[model_manifest_schema.ModelFunctionInfo]:
564
+ ) -> model_meta_schema.ModelMetadataDict:
560
565
  raw_model_spec_res = self._model_client.show_versions(
561
566
  database_name=database_name,
562
567
  schema_name=schema_name,
@@ -567,6 +572,43 @@ class ModelOperator:
567
572
  )[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
568
573
  model_spec_dict = yaml.safe_load(raw_model_spec_res)
569
574
  model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
575
+ return model_spec
576
+
577
+ def get_model_objective(
578
+ self,
579
+ *,
580
+ database_name: Optional[sql_identifier.SqlIdentifier],
581
+ schema_name: Optional[sql_identifier.SqlIdentifier],
582
+ model_name: sql_identifier.SqlIdentifier,
583
+ version_name: sql_identifier.SqlIdentifier,
584
+ statement_params: Optional[Dict[str, Any]] = None,
585
+ ) -> type_hints.ModelObjective:
586
+ model_spec = self._fetch_model_spec(
587
+ database_name=database_name,
588
+ schema_name=schema_name,
589
+ model_name=model_name,
590
+ version_name=version_name,
591
+ statement_params=statement_params,
592
+ )
593
+ model_objective_val = model_spec.get("model_objective", type_hints.ModelObjective.UNKNOWN.value)
594
+ return type_hints.ModelObjective(model_objective_val)
595
+
596
+ def get_functions(
597
+ self,
598
+ *,
599
+ database_name: Optional[sql_identifier.SqlIdentifier],
600
+ schema_name: Optional[sql_identifier.SqlIdentifier],
601
+ model_name: sql_identifier.SqlIdentifier,
602
+ version_name: sql_identifier.SqlIdentifier,
603
+ statement_params: Optional[Dict[str, Any]] = None,
604
+ ) -> List[model_manifest_schema.ModelFunctionInfo]:
605
+ model_spec = self._fetch_model_spec(
606
+ database_name=database_name,
607
+ schema_name=schema_name,
608
+ model_name=model_name,
609
+ version_name=version_name,
610
+ statement_params=statement_params,
611
+ )
570
612
  show_functions_res = self._model_version_client.show_functions(
571
613
  database_name=database_name,
572
614
  schema_name=schema_name,
@@ -597,16 +639,38 @@ class ModelOperator:
597
639
  function_names, list(signatures.keys())
598
640
  )
599
641
 
600
- return [
601
- model_manifest_schema.ModelFunctionInfo(
602
- name=function_name.identifier(),
603
- target_method=function_name_mapping[function_name],
604
- target_method_function_type=function_type,
605
- signature=model_signature.ModelSignature.from_dict(signatures[function_name_mapping[function_name]]),
642
+ model_func_info = []
643
+
644
+ for function_name, function_type in function_names_and_types:
645
+
646
+ target_method = function_name_mapping[function_name]
647
+
648
+ is_partitioned = False
649
+ if function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
650
+ # better to set default True here because worse case it will be slow but not error out
651
+ is_partitioned = (
652
+ (
653
+ model_spec["function_properties"]
654
+ .get(target_method, {})
655
+ .get(model_meta_schema.FunctionProperties.PARTITIONED.value, True)
656
+ )
657
+ if "function_properties" in model_spec
658
+ else True
659
+ )
660
+
661
+ model_func_info.append(
662
+ model_manifest_schema.ModelFunctionInfo(
663
+ name=function_name.identifier(),
664
+ target_method=target_method,
665
+ target_method_function_type=function_type,
666
+ signature=model_signature.ModelSignature.from_dict(signatures[target_method]),
667
+ is_partitioned=is_partitioned,
668
+ )
606
669
  )
607
- for function_name, function_type in function_names_and_types
608
- ]
609
670
 
671
+ return model_func_info
672
+
673
+ @overload
610
674
  def invoke_method(
611
675
  self,
612
676
  *,
@@ -621,6 +685,41 @@ class ModelOperator:
621
685
  strict_input_validation: bool = False,
622
686
  partition_column: Optional[sql_identifier.SqlIdentifier] = None,
623
687
  statement_params: Optional[Dict[str, str]] = None,
688
+ is_partitioned: Optional[bool] = None,
689
+ ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
690
+ ...
691
+
692
+ @overload
693
+ def invoke_method(
694
+ self,
695
+ *,
696
+ method_name: sql_identifier.SqlIdentifier,
697
+ signature: model_signature.ModelSignature,
698
+ X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
699
+ database_name: Optional[sql_identifier.SqlIdentifier],
700
+ schema_name: Optional[sql_identifier.SqlIdentifier],
701
+ service_name: sql_identifier.SqlIdentifier,
702
+ strict_input_validation: bool = False,
703
+ statement_params: Optional[Dict[str, str]] = None,
704
+ ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
705
+ ...
706
+
707
+ def invoke_method(
708
+ self,
709
+ *,
710
+ method_name: sql_identifier.SqlIdentifier,
711
+ method_function_type: Optional[str] = None,
712
+ signature: model_signature.ModelSignature,
713
+ X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
714
+ database_name: Optional[sql_identifier.SqlIdentifier],
715
+ schema_name: Optional[sql_identifier.SqlIdentifier],
716
+ model_name: Optional[sql_identifier.SqlIdentifier] = None,
717
+ version_name: Optional[sql_identifier.SqlIdentifier] = None,
718
+ service_name: Optional[sql_identifier.SqlIdentifier] = None,
719
+ strict_input_validation: bool = False,
720
+ partition_column: Optional[sql_identifier.SqlIdentifier] = None,
721
+ statement_params: Optional[Dict[str, str]] = None,
722
+ is_partitioned: Optional[bool] = None,
624
723
  ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
625
724
  identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
626
725
 
@@ -657,31 +756,46 @@ class ModelOperator:
657
756
  if output_name in original_cols:
658
757
  original_cols.remove(output_name)
659
758
 
660
- if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
661
- df_res = self._model_version_client.invoke_function_method(
662
- method_name=method_name,
663
- input_df=s_df,
664
- input_args=input_args,
665
- returns=returns,
666
- database_name=database_name,
667
- schema_name=schema_name,
668
- model_name=model_name,
669
- version_name=version_name,
670
- statement_params=statement_params,
671
- )
672
- elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
673
- df_res = self._model_version_client.invoke_table_function_method(
759
+ if service_name:
760
+ df_res = self._service_client.invoke_function_method(
674
761
  method_name=method_name,
675
762
  input_df=s_df,
676
763
  input_args=input_args,
677
- partition_column=partition_column,
678
764
  returns=returns,
679
765
  database_name=database_name,
680
766
  schema_name=schema_name,
681
- model_name=model_name,
682
- version_name=version_name,
767
+ service_name=service_name,
683
768
  statement_params=statement_params,
684
769
  )
770
+ else:
771
+ assert model_name is not None
772
+ assert version_name is not None
773
+ if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
774
+ df_res = self._model_version_client.invoke_function_method(
775
+ method_name=method_name,
776
+ input_df=s_df,
777
+ input_args=input_args,
778
+ returns=returns,
779
+ database_name=database_name,
780
+ schema_name=schema_name,
781
+ model_name=model_name,
782
+ version_name=version_name,
783
+ statement_params=statement_params,
784
+ )
785
+ elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
786
+ df_res = self._model_version_client.invoke_table_function_method(
787
+ method_name=method_name,
788
+ input_df=s_df,
789
+ input_args=input_args,
790
+ partition_column=partition_column,
791
+ returns=returns,
792
+ database_name=database_name,
793
+ schema_name=schema_name,
794
+ model_name=model_name,
795
+ version_name=version_name,
796
+ statement_params=statement_params,
797
+ is_partitioned=is_partitioned or False,
798
+ )
685
799
 
686
800
  if keep_order:
687
801
  # if it's a partitioned table function, _ID will be null and we won't be able to sort.