snowflake-ml-python 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/telemetry.py +3 -1
  3. snowflake/ml/_internal/utils/service_logger.py +26 -1
  4. snowflake/ml/experiment/_client/artifact.py +76 -0
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  6. snowflake/ml/experiment/experiment_tracking.py +113 -6
  7. snowflake/ml/feature_store/feature_store.py +1150 -131
  8. snowflake/ml/feature_store/feature_view.py +122 -0
  9. snowflake/ml/jobs/_utils/constants.py +8 -16
  10. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  11. snowflake/ml/jobs/_utils/payload_utils.py +19 -5
  12. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  13. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +23 -5
  14. snowflake/ml/jobs/_utils/spec_utils.py +4 -6
  15. snowflake/ml/jobs/_utils/types.py +2 -1
  16. snowflake/ml/jobs/job.py +38 -19
  17. snowflake/ml/jobs/manager.py +136 -19
  18. snowflake/ml/model/__init__.py +6 -1
  19. snowflake/ml/model/_client/model/batch_inference_specs.py +25 -0
  20. snowflake/ml/model/_client/model/model_version_impl.py +62 -65
  21. snowflake/ml/model/_client/ops/model_ops.py +42 -9
  22. snowflake/ml/model/_client/ops/service_ops.py +75 -154
  23. snowflake/ml/model/_client/service/model_deployment_spec.py +23 -37
  24. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +15 -4
  25. snowflake/ml/model/_client/sql/service.py +4 -0
  26. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +309 -22
  27. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -0
  29. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  30. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  31. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  32. snowflake/ml/model/_signatures/utils.py +4 -2
  33. snowflake/ml/model/models/huggingface_pipeline.py +23 -0
  34. snowflake/ml/model/openai_signatures.py +57 -0
  35. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  37. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  38. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  39. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  40. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  41. snowflake/ml/modeling/cluster/birch.py +1 -1
  42. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  43. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  44. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  45. snowflake/ml/modeling/cluster/k_means.py +1 -1
  46. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  47. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  48. snowflake/ml/modeling/cluster/optics.py +1 -1
  49. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  50. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  51. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  52. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  53. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  54. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  55. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  56. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  57. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  58. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  59. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  60. snowflake/ml/modeling/covariance/oas.py +1 -1
  61. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  62. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  63. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  64. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  65. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  66. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  67. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  68. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  69. snowflake/ml/modeling/decomposition/pca.py +1 -1
  70. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  71. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  72. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  73. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  74. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  75. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  76. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  77. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  78. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  79. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  80. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  81. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  82. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  83. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  84. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  85. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  88. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  89. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  90. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  91. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  92. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  93. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  94. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  95. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  96. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  97. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  98. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  99. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  100. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  101. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  102. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  103. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  104. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  105. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  106. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  107. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  108. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  109. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  110. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  111. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  112. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  113. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  114. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  115. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  117. snowflake/ml/modeling/linear_model/lars.py +1 -1
  118. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  119. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  120. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  122. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  123. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  124. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  125. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  126. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  127. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  128. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  129. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  130. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  131. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  132. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  133. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  134. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  135. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  136. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  137. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  138. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  139. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  140. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  141. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  142. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  143. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  144. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  145. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  146. snowflake/ml/modeling/manifold/isomap.py +1 -1
  147. snowflake/ml/modeling/manifold/mds.py +1 -1
  148. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  149. snowflake/ml/modeling/manifold/tsne.py +1 -1
  150. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  151. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  152. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  153. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  154. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  155. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  156. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  157. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  158. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  159. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  160. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  161. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  162. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  163. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  164. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  165. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  166. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  167. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  168. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  169. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  170. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  171. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  173. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  174. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  175. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  176. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  177. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  178. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  179. snowflake/ml/modeling/svm/svc.py +1 -1
  180. snowflake/ml/modeling/svm/svr.py +1 -1
  181. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  182. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  183. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  184. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  185. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  186. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  187. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  189. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  190. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  191. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  192. snowflake/ml/monitoring/model_monitor.py +26 -0
  193. snowflake/ml/version.py +1 -1
  194. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/METADATA +82 -5
  195. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/RECORD +198 -194
  196. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/WHEEL +0 -0
  197. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/licenses/LICENSE.txt +0 -0
  198. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/top_level.txt +0 -0
@@ -10,17 +10,13 @@ import time
10
10
  from typing import Any, Optional, Union, cast
11
11
 
12
12
  from snowflake import snowpark
13
+ from snowflake.ml import jobs
13
14
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
14
15
  from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
15
- from snowflake.ml.model import (
16
- inference_engine as inference_engine_module,
17
- model_signature,
18
- type_hints,
19
- )
16
+ from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
20
17
  from snowflake.ml.model._client.service import model_deployment_spec
21
18
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
22
- from snowflake.ml.model._signatures import snowpark_handler
23
- from snowflake.snowpark import async_job, dataframe, exceptions, row, session
19
+ from snowflake.snowpark import async_job, exceptions, row, session
24
20
  from snowflake.snowpark._internal import utils as snowpark_utils
25
21
 
26
22
  module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -866,174 +862,99 @@ class ServiceOperator:
866
862
  except exceptions.SnowparkSQLException:
867
863
  return False
868
864
 
869
- def invoke_job_method(
865
+ def invoke_batch_job_method(
870
866
  self,
871
- target_method: str,
872
- signature: model_signature.ModelSignature,
873
- X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
874
- database_name: Optional[sql_identifier.SqlIdentifier],
875
- schema_name: Optional[sql_identifier.SqlIdentifier],
867
+ *,
868
+ function_name: str,
876
869
  model_name: sql_identifier.SqlIdentifier,
877
870
  version_name: sql_identifier.SqlIdentifier,
878
- job_database_name: Optional[sql_identifier.SqlIdentifier],
879
- job_schema_name: Optional[sql_identifier.SqlIdentifier],
880
- job_name: sql_identifier.SqlIdentifier,
871
+ job_name: str,
881
872
  compute_pool_name: sql_identifier.SqlIdentifier,
882
- warehouse_name: sql_identifier.SqlIdentifier,
873
+ warehouse: sql_identifier.SqlIdentifier,
883
874
  image_repo_name: Optional[str],
884
- output_table_database_name: Optional[sql_identifier.SqlIdentifier],
885
- output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
886
- output_table_name: sql_identifier.SqlIdentifier,
887
- cpu_requests: Optional[str],
888
- memory_requests: Optional[str],
889
- gpu_requests: Optional[Union[int, str]],
875
+ input_stage_location: str,
876
+ input_file_pattern: str,
877
+ output_stage_location: str,
878
+ completion_filename: str,
879
+ force_rebuild: bool,
890
880
  num_workers: Optional[int],
891
881
  max_batch_rows: Optional[int],
892
- force_rebuild: bool,
893
- build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
882
+ cpu_requests: Optional[str],
883
+ memory_requests: Optional[str],
884
+ replicas: Optional[int],
894
885
  statement_params: Optional[dict[str, Any]] = None,
895
- ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
896
- # fall back to the registry's database and schema if not provided
897
- database_name = database_name or self._database_name
898
- schema_name = schema_name or self._schema_name
899
-
900
- # fall back to the model's database and schema if not provided then to the registry's database and schema
901
- job_database_name = job_database_name or database_name or self._database_name
902
- job_schema_name = job_schema_name or schema_name or self._schema_name
886
+ ) -> jobs.MLJob[Any]:
887
+ database_name = self._database_name
888
+ schema_name = self._schema_name
903
889
 
904
- image_repo_fqn = self._get_image_repo_fqn(image_repo_name, database_name, schema_name)
890
+ job_database_name, job_schema_name, job_name = sql_identifier.parse_fully_qualified_name(job_name)
891
+ job_database_name = job_database_name or database_name
892
+ job_schema_name = job_schema_name or schema_name
905
893
 
906
- input_table_database_name = job_database_name
907
- input_table_schema_name = job_schema_name
908
- output_table_database_name = output_table_database_name or database_name or self._database_name
909
- output_table_schema_name = output_table_schema_name or schema_name or self._schema_name
910
-
911
- if self._workspace:
912
- stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
913
- else:
914
- stage_path = None
894
+ self._model_deployment_spec.clear()
915
895
 
916
- # validate and prepare input
917
- if not isinstance(X, dataframe.DataFrame):
918
- keep_order = True
919
- output_with_input_features = False
920
- df = model_signature._convert_and_validate_local_data(X, signature.inputs)
921
- s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
922
- self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
923
- )
924
- else:
925
- keep_order = False
926
- output_with_input_features = True
927
- s_df = X
928
-
929
- # only write the index and feature input columns
930
- cols = [snowpark_handler._KEEP_ORDER_COL_NAME] if snowpark_handler._KEEP_ORDER_COL_NAME in s_df.columns else []
931
- cols += [
932
- sql_identifier.SqlIdentifier(feature.name, case_sensitive=True).identifier() for feature in signature.inputs
933
- ]
934
- s_df = s_df.select(cols)
935
- original_cols = s_df.columns
936
-
937
- # input/output tables
938
- fq_output_table_name = identifier.get_schema_level_object_identifier(
939
- output_table_database_name.identifier(),
940
- output_table_schema_name.identifier(),
941
- output_table_name.identifier(),
942
- )
943
- tmp_input_table_id = sql_identifier.SqlIdentifier(
944
- snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
945
- )
946
- fq_tmp_input_table_name = identifier.get_schema_level_object_identifier(
947
- job_database_name.identifier(),
948
- job_schema_name.identifier(),
949
- tmp_input_table_id.identifier(),
950
- )
951
- s_df.write.save_as_table(
952
- table_name=fq_tmp_input_table_name,
953
- mode="errorifexists",
954
- statement_params=statement_params,
896
+ self._model_deployment_spec.add_model_spec(
897
+ database_name=database_name,
898
+ schema_name=schema_name,
899
+ model_name=model_name,
900
+ version_name=version_name,
955
901
  )
956
902
 
957
- try:
958
- self._model_deployment_spec.clear()
959
- # save the spec
960
- self._model_deployment_spec.add_model_spec(
961
- database_name=database_name,
962
- schema_name=schema_name,
963
- model_name=model_name,
964
- version_name=version_name,
965
- )
966
- self._model_deployment_spec.add_job_spec(
967
- job_database_name=job_database_name,
968
- job_schema_name=job_schema_name,
969
- job_name=job_name,
970
- inference_compute_pool_name=compute_pool_name,
971
- cpu=cpu_requests,
972
- memory=memory_requests,
973
- gpu=gpu_requests,
974
- num_workers=num_workers,
975
- max_batch_rows=max_batch_rows,
976
- warehouse=warehouse_name,
977
- target_method=target_method,
978
- input_table_database_name=input_table_database_name,
979
- input_table_schema_name=input_table_schema_name,
980
- input_table_name=tmp_input_table_id,
981
- output_table_database_name=output_table_database_name,
982
- output_table_schema_name=output_table_schema_name,
983
- output_table_name=output_table_name,
984
- )
903
+ self._model_deployment_spec.add_job_spec(
904
+ job_database_name=job_database_name,
905
+ job_schema_name=job_schema_name,
906
+ job_name=job_name,
907
+ inference_compute_pool_name=compute_pool_name,
908
+ num_workers=num_workers,
909
+ max_batch_rows=max_batch_rows,
910
+ input_stage_location=input_stage_location,
911
+ input_file_pattern=input_file_pattern,
912
+ output_stage_location=output_stage_location,
913
+ completion_filename=completion_filename,
914
+ function_name=function_name,
915
+ warehouse=warehouse,
916
+ cpu=cpu_requests,
917
+ memory=memory_requests,
918
+ replicas=replicas,
919
+ )
985
920
 
986
- self._model_deployment_spec.add_image_build_spec(
987
- image_build_compute_pool_name=compute_pool_name,
988
- fully_qualified_image_repo_name=image_repo_fqn,
989
- force_rebuild=force_rebuild,
990
- external_access_integrations=build_external_access_integrations,
991
- )
921
+ self._model_deployment_spec.add_image_build_spec(
922
+ image_build_compute_pool_name=compute_pool_name,
923
+ fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
924
+ force_rebuild=force_rebuild,
925
+ )
992
926
 
993
- spec_yaml_str_or_path = self._model_deployment_spec.save()
994
- if self._workspace:
995
- assert stage_path is not None
996
- file_utils.upload_directory_to_stage(
997
- self._session,
998
- local_path=pathlib.Path(self._workspace.name),
999
- stage_path=pathlib.PurePosixPath(stage_path),
1000
- statement_params=statement_params,
1001
- )
927
+ spec_yaml_str_or_path = self._model_deployment_spec.save()
1002
928
 
1003
- # deploy the job
1004
- query_id, async_job = self._service_client.deploy_model(
1005
- stage_path=stage_path if self._workspace else None,
1006
- model_deployment_spec_file_rel_path=(
1007
- model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
1008
- ),
1009
- model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
929
+ if self._workspace:
930
+ module_logger.info("using workspace")
931
+ stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
932
+ file_utils.upload_directory_to_stage(
933
+ self._session,
934
+ local_path=pathlib.Path(self._workspace.name),
935
+ stage_path=pathlib.PurePosixPath(stage_path),
1010
936
  statement_params=statement_params,
1011
937
  )
938
+ else:
939
+ module_logger.info("not using workspace")
940
+ stage_path = None
1012
941
 
1013
- while not async_job.is_done():
1014
- time.sleep(5)
1015
- finally:
1016
- self._session.table(fq_tmp_input_table_name).drop_table()
1017
-
1018
- # handle the output
1019
- df_res = self._session.table(fq_output_table_name)
1020
- if keep_order:
1021
- df_res = df_res.sort(
1022
- snowpark_handler._KEEP_ORDER_COL_NAME,
1023
- ascending=True,
1024
- )
1025
- df_res = df_res.drop(snowpark_handler._KEEP_ORDER_COL_NAME)
942
+ _, async_job = self._service_client.deploy_model(
943
+ stage_path=stage_path if self._workspace else None,
944
+ model_deployment_spec_file_rel_path=(
945
+ model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
946
+ ),
947
+ model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
948
+ statement_params=statement_params,
949
+ )
1026
950
 
1027
- if not output_with_input_features:
1028
- df_res = df_res.drop(*original_cols)
951
+ # Block until the async job is done
952
+ async_job.result()
1029
953
 
1030
- # get final result
1031
- if not isinstance(X, dataframe.DataFrame):
1032
- return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
1033
- df_res, features=signature.outputs, statement_params=statement_params
1034
- )
1035
- else:
1036
- return df_res
954
+ return jobs.MLJob(
955
+ id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
956
+ session=self._session,
957
+ )
1037
958
 
1038
959
  def _create_temp_stage(
1039
960
  self,
@@ -194,42 +194,40 @@ class ModelDeploymentSpec:
194
194
  self,
195
195
  job_name: sql_identifier.SqlIdentifier,
196
196
  inference_compute_pool_name: sql_identifier.SqlIdentifier,
197
+ function_name: str,
198
+ input_stage_location: str,
199
+ output_stage_location: str,
200
+ completion_filename: str,
201
+ input_file_pattern: str,
197
202
  warehouse: sql_identifier.SqlIdentifier,
198
- target_method: str,
199
- input_table_name: sql_identifier.SqlIdentifier,
200
- output_table_name: sql_identifier.SqlIdentifier,
201
203
  job_database_name: Optional[sql_identifier.SqlIdentifier] = None,
202
204
  job_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
203
- input_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
204
- input_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
205
- output_table_database_name: Optional[sql_identifier.SqlIdentifier] = None,
206
- output_table_schema_name: Optional[sql_identifier.SqlIdentifier] = None,
207
205
  cpu: Optional[str] = None,
208
206
  memory: Optional[str] = None,
209
207
  gpu: Optional[Union[str, int]] = None,
210
208
  num_workers: Optional[int] = None,
211
209
  max_batch_rows: Optional[int] = None,
210
+ replicas: Optional[int] = None,
212
211
  ) -> "ModelDeploymentSpec":
213
212
  """Add job specification to the deployment spec.
214
213
 
215
214
  Args:
216
215
  job_name: Name of the job.
217
216
  inference_compute_pool_name: Compute pool for inference.
217
+ warehouse: Warehouse for the job.
218
+ function_name: Function name.
219
+ input_stage_location: Stage location for input data.
220
+ output_stage_location: Stage location for output data.
218
221
  job_database_name: Database name for the job.
219
222
  job_schema_name: Schema name for the job.
220
- warehouse: Warehouse for the job.
221
- target_method: Target method for inference.
222
- input_table_name: Input table name.
223
- output_table_name: Output table name.
224
- input_table_database_name: Database for input table.
225
- input_table_schema_name: Schema for input table.
226
- output_table_database_name: Database for output table.
227
- output_table_schema_name: Schema for output table.
223
+ input_file_pattern: Pattern for input files (optional).
224
+ completion_filename: Name of completion file (default: "completion.txt").
228
225
  cpu: CPU requirement.
229
226
  memory: Memory requirement.
230
227
  gpu: GPU requirement.
231
228
  num_workers: Number of workers.
232
229
  max_batch_rows: Maximum batch rows for inference.
230
+ replicas: Number of replicas.
233
231
 
234
232
  Raises:
235
233
  ValueError: If a service spec already exists.
@@ -242,41 +240,29 @@ class ModelDeploymentSpec:
242
240
 
243
241
  saved_job_database = job_database_name or self.database
244
242
  saved_job_schema = job_schema_name or self.schema
245
- input_table_database_name = input_table_database_name or self.database
246
- input_table_schema_name = input_table_schema_name or self.schema
247
- output_table_database_name = output_table_database_name or self.database
248
- output_table_schema_name = output_table_schema_name or self.schema
249
243
 
250
244
  assert saved_job_database is not None
251
245
  assert saved_job_schema is not None
252
- assert input_table_database_name is not None
253
- assert input_table_schema_name is not None
254
- assert output_table_database_name is not None
255
- assert output_table_schema_name is not None
256
246
 
257
247
  fq_job_name = identifier.get_schema_level_object_identifier(
258
248
  saved_job_database.identifier(), saved_job_schema.identifier(), job_name.identifier()
259
249
  )
260
- fq_input_table_name = identifier.get_schema_level_object_identifier(
261
- input_table_database_name.identifier(),
262
- input_table_schema_name.identifier(),
263
- input_table_name.identifier(),
264
- )
265
- fq_output_table_name = identifier.get_schema_level_object_identifier(
266
- output_table_database_name.identifier(),
267
- output_table_schema_name.identifier(),
268
- output_table_name.identifier(),
269
- )
270
250
 
271
251
  self._add_inference_spec(cpu, memory, gpu, num_workers, max_batch_rows)
272
252
 
273
253
  self._job = model_deployment_spec_schema.Job(
274
254
  name=fq_job_name,
275
255
  compute_pool=inference_compute_pool_name.identifier(),
276
- warehouse=warehouse.identifier(),
277
- target_method=target_method,
278
- input_table_name=fq_input_table_name,
279
- output_table_name=fq_output_table_name,
256
+ warehouse=warehouse.identifier() if warehouse else None,
257
+ function_name=function_name,
258
+ input=model_deployment_spec_schema.Input(
259
+ input_stage_location=input_stage_location, input_file_pattern=input_file_pattern
260
+ ),
261
+ output=model_deployment_spec_schema.Output(
262
+ output_stage_location=output_stage_location,
263
+ completion_filename=completion_filename,
264
+ ),
265
+ replicas=replicas,
280
266
  **self._inference_spec,
281
267
  )
282
268
  return self
@@ -35,6 +35,16 @@ class Service(BaseModel):
35
35
  inference_engine_spec: Optional[InferenceEngineSpec] = None
36
36
 
37
37
 
38
+ class Input(BaseModel):
39
+ input_stage_location: str
40
+ input_file_pattern: str
41
+
42
+
43
+ class Output(BaseModel):
44
+ output_stage_location: str
45
+ completion_filename: str
46
+
47
+
38
48
  class Job(BaseModel):
39
49
  name: str
40
50
  compute_pool: str
@@ -43,10 +53,11 @@ class Job(BaseModel):
43
53
  gpu: Optional[str] = None
44
54
  num_workers: Optional[int] = None
45
55
  max_batch_rows: Optional[int] = None
46
- warehouse: str
47
- target_method: str
48
- input_table_name: str
49
- output_table_name: str
56
+ warehouse: Optional[str] = None
57
+ function_name: str
58
+ input: Input
59
+ output: Output
60
+ replicas: Optional[int] = None
50
61
 
51
62
 
52
63
  class LogModelArgs(BaseModel):
@@ -63,6 +63,7 @@ class ServiceStatusInfo:
63
63
  class ServiceSQLClient(_base._BaseSQLClient):
64
64
  MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
65
65
  MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
66
+ MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME = "privatelink_ingress_url"
66
67
  SERVICE_STATUS = "service_status"
67
68
  INSTANCE_ID = "instance_id"
68
69
  INSTANCE_STATUS = "instance_status"
@@ -255,6 +256,9 @@ class ServiceSQLClient(_base._BaseSQLClient):
255
256
  )
256
257
  .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME, allow_empty=True)
257
258
  .has_column(ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME, allow_empty=True)
259
+ .has_column(
260
+ ServiceSQLClient.MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME, allow_empty=True
261
+ )
258
262
  )
259
263
 
260
264
  return res.validate()