snowflake-ml-python 1.10.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/utils/service_logger.py +26 -1
  3. snowflake/ml/experiment/_client/artifact.py +76 -0
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  5. snowflake/ml/experiment/callback/keras.py +63 -0
  6. snowflake/ml/experiment/callback/lightgbm.py +5 -1
  7. snowflake/ml/experiment/callback/xgboost.py +5 -1
  8. snowflake/ml/experiment/experiment_tracking.py +89 -4
  9. snowflake/ml/feature_store/feature_store.py +1150 -131
  10. snowflake/ml/feature_store/feature_view.py +122 -0
  11. snowflake/ml/jobs/_utils/__init__.py +0 -0
  12. snowflake/ml/jobs/_utils/constants.py +9 -14
  13. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  14. snowflake/ml/jobs/_utils/payload_utils.py +61 -19
  15. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  16. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  17. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  18. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +15 -7
  19. snowflake/ml/jobs/_utils/spec_utils.py +44 -13
  20. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  21. snowflake/ml/jobs/_utils/types.py +7 -8
  22. snowflake/ml/jobs/job.py +34 -18
  23. snowflake/ml/jobs/manager.py +107 -24
  24. snowflake/ml/model/__init__.py +6 -1
  25. snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
  26. snowflake/ml/model/_client/model/model_version_impl.py +225 -73
  27. snowflake/ml/model/_client/ops/service_ops.py +128 -174
  28. snowflake/ml/model/_client/service/model_deployment_spec.py +123 -64
  29. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -9
  30. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  31. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  32. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
  33. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  34. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  35. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  36. snowflake/ml/model/_signatures/utils.py +4 -2
  37. snowflake/ml/model/inference_engine.py +5 -0
  38. snowflake/ml/model/models/huggingface_pipeline.py +4 -3
  39. snowflake/ml/model/openai_signatures.py +57 -0
  40. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  41. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  42. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  43. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  44. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  45. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  46. snowflake/ml/modeling/cluster/birch.py +1 -1
  47. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  48. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  49. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  50. snowflake/ml/modeling/cluster/k_means.py +1 -1
  51. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  52. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  53. snowflake/ml/modeling/cluster/optics.py +1 -1
  54. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  55. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  56. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  57. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  58. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  59. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  60. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  61. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  62. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  63. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  64. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  65. snowflake/ml/modeling/covariance/oas.py +1 -1
  66. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  67. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  68. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  69. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  70. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  71. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  72. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  73. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  74. snowflake/ml/modeling/decomposition/pca.py +1 -1
  75. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  76. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  77. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  78. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  79. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  82. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  83. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  84. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  85. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  88. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  89. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  90. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  91. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  92. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  93. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  94. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  95. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  96. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  97. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  98. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  99. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  100. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  101. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  102. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  103. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  104. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  105. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  106. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  107. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  108. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  109. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  110. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  111. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  112. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  113. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  114. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  115. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  117. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  118. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  119. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  120. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  121. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  122. snowflake/ml/modeling/linear_model/lars.py +1 -1
  123. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  124. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  125. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  126. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  127. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  128. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  129. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  130. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  131. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  132. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  133. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  135. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  136. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  137. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  138. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  140. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  141. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  142. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  143. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  144. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  145. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  146. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  147. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  148. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  149. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  150. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  151. snowflake/ml/modeling/manifold/isomap.py +1 -1
  152. snowflake/ml/modeling/manifold/mds.py +1 -1
  153. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  154. snowflake/ml/modeling/manifold/tsne.py +1 -1
  155. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  156. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  157. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  158. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  159. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  160. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  161. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  162. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  163. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  164. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  165. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  166. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  167. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  168. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  169. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  170. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  171. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  172. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  173. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  174. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  175. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  176. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  177. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  178. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  179. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  180. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  181. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  182. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  183. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  184. snowflake/ml/modeling/svm/svc.py +1 -1
  185. snowflake/ml/modeling/svm/svr.py +1 -1
  186. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  187. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  188. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  189. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  190. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  191. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  192. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  193. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  194. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  195. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  196. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  197. snowflake/ml/monitoring/model_monitor.py +26 -0
  198. snowflake/ml/registry/_manager/model_manager.py +7 -35
  199. snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
  200. snowflake/ml/version.py +1 -1
  201. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +87 -7
  202. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +205 -197
  203. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
  204. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
  205. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
@@ -1,16 +1,18 @@
1
1
  import enum
2
2
  import pathlib
3
3
  import tempfile
4
+ import uuid
4
5
  import warnings
5
6
  from typing import Any, Callable, Optional, Union, overload
6
7
 
7
8
  import pandas as pd
8
9
 
9
- from snowflake import snowpark
10
+ from snowflake.ml import jobs
10
11
  from snowflake.ml._internal import telemetry
11
12
  from snowflake.ml._internal.utils import sql_identifier
12
13
  from snowflake.ml.lineage import lineage_node
13
14
  from snowflake.ml.model import task, type_hints
15
+ from snowflake.ml.model._client.model import batch_inference_specs
14
16
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
15
17
  from snowflake.ml.model._model_composer import model_composer
16
18
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -19,6 +21,7 @@ from snowflake.snowpark import Session, async_job, dataframe
19
21
 
20
22
  _TELEMETRY_PROJECT = "MLOps"
21
23
  _TELEMETRY_SUBPROJECT = "ModelManagement"
24
+ _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
22
25
 
23
26
 
24
27
  class ExportMode(enum.Enum):
@@ -539,6 +542,63 @@ class ModelVersion(lineage_node.LineageNode):
539
542
  is_partitioned=target_function_info["is_partitioned"],
540
543
  )
541
544
 
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT,
    subproject=_TELEMETRY_SUBPROJECT,
    func_params_to_log=["compute_pool"],
)
def _run_batch(
    self,
    *,
    compute_pool: str,
    input_spec: batch_inference_specs.InputSpec,
    output_spec: batch_inference_specs.OutputSpec,
    job_spec: Optional[batch_inference_specs.JobSpec] = None,
) -> jobs.MLJob[Any]:
    """Invoke batch inference for this model version as a job on a compute pool.

    Args:
        compute_pool: Name of the compute pool that runs the job.
        input_spec: Input stage location and file pattern to read from.
        output_spec: Output stage location and completion file name to write to.
        job_spec: Optional job settings; a default ``JobSpec()`` is used when omitted.

    Returns:
        The MLJob handle returned by ``invoke_batch_job_method``.

    Raises:
        ValueError: If no warehouse is given in the job spec and the session has no current warehouse.
    """
    statement_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )

    spec = batch_inference_specs.JobSpec() if job_spec is None else job_spec

    # A warehouse set on the spec wins; otherwise fall back to the session's current warehouse.
    warehouse = spec.warehouse or self._service_ops._session.get_current_warehouse()
    if warehouse is None:
        raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")

    # Same shape as MLJob ID generation (upper-cased UUID, dashes -> underscores), different prefix.
    job_name = spec.job_name
    if job_name is None:
        generated_suffix = str(uuid.uuid4()).upper().replace("-", "_")
        job_name = _BATCH_INFERENCE_JOB_ID_PREFIX + generated_suffix

    target_method = self._get_function_info(function_name=spec.function_name)["target_method"]

    return self._service_ops.invoke_batch_job_method(
        # Which model version to run.
        model_name=self._model_name,
        version_name=self._version_name,
        # Job configuration.
        function_name=target_method,
        compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
        force_rebuild=spec.force_rebuild,
        image_repo_name=spec.image_repo,
        num_workers=spec.num_workers,
        max_batch_rows=spec.max_batch_rows,
        warehouse=sql_identifier.SqlIdentifier(warehouse),
        cpu_requests=spec.cpu_requests,
        memory_requests=spec.memory_requests,
        job_name=job_name,
        # Data locations.
        input_stage_location=input_spec.input_stage_location,
        input_file_pattern=input_spec.input_file_pattern,
        output_stage_location=output_spec.output_stage_location,
        completion_filename=output_spec.completion_filename,
        statement_params=statement_params,
    )
542
602
  def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
543
603
  functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
544
604
 
@@ -707,6 +767,128 @@ class ModelVersion(lineage_node.LineageNode):
707
767
  version_name=sql_identifier.SqlIdentifier(version),
708
768
  )
709
769
 
770
+ def _get_inference_engine_args(
771
+ self, experimental_options: Optional[dict[str, Any]]
772
+ ) -> Optional[service_ops.InferenceEngineArgs]:
773
+
774
+ if not experimental_options:
775
+ return None
776
+
777
+ if "inference_engine" not in experimental_options:
778
+ raise ValueError("inference_engine is required in experimental_options")
779
+
780
+ return service_ops.InferenceEngineArgs(
781
+ inference_engine=experimental_options["inference_engine"],
782
+ inference_engine_args_override=experimental_options.get("inference_engine_args_override"),
783
+ )
784
+
785
+ def _enrich_inference_engine_args(
786
+ self,
787
+ inference_engine_args: service_ops.InferenceEngineArgs,
788
+ gpu_requests: Optional[Union[str, int]] = None,
789
+ ) -> Optional[service_ops.InferenceEngineArgs]:
790
+ """Enrich inference engine args with model path and tensor parallelism settings.
791
+
792
+ Args:
793
+ inference_engine_args: The original inference engine args
794
+ gpu_requests: The number of GPUs requested
795
+
796
+ Returns:
797
+ Enriched inference engine args
798
+
799
+ Raises:
800
+ ValueError: Invalid gpu_requests
801
+ """
802
+ if inference_engine_args.inference_engine_args_override is None:
803
+ inference_engine_args.inference_engine_args_override = []
804
+
805
+ # Get model stage path and strip off "snow://" prefix
806
+ model_stage_path = self._model_ops.get_model_version_stage_path(
807
+ database_name=None,
808
+ schema_name=None,
809
+ model_name=self._model_name,
810
+ version_name=self._version_name,
811
+ )
812
+
813
+ # Strip "snow://" prefix
814
+ if model_stage_path.startswith("snow://"):
815
+ model_stage_path = model_stage_path.replace("snow://", "", 1)
816
+
817
+ # Always overwrite the model key by appending
818
+ inference_engine_args.inference_engine_args_override.append(f"--model={model_stage_path}")
819
+
820
+ gpu_count = None
821
+
822
+ # Set tensor-parallelism if gpu_requests is specified
823
+ if gpu_requests is not None:
824
+ # assert gpu_requests is a string or an integer before casting to int
825
+ if isinstance(gpu_requests, str) or isinstance(gpu_requests, int):
826
+ try:
827
+ gpu_count = int(gpu_requests)
828
+ except ValueError:
829
+ raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
830
+
831
+ if gpu_count is not None:
832
+ if gpu_count > 0:
833
+ inference_engine_args.inference_engine_args_override.append(f"--tensor-parallel-size={gpu_count}")
834
+ else:
835
+ raise ValueError(f"Invalid gpu_requests: {gpu_requests}")
836
+
837
+ return inference_engine_args
838
+
839
+ def _check_huggingface_text_generation_model(
840
+ self,
841
+ statement_params: Optional[dict[str, Any]] = None,
842
+ ) -> None:
843
+ """Check if the model is a HuggingFace pipeline with text-generation task.
844
+
845
+ Args:
846
+ statement_params: Optional dictionary of statement parameters to include
847
+ in the SQL command to fetch model spec.
848
+
849
+ Raises:
850
+ ValueError: If the model is not a HuggingFace text-generation model.
851
+ """
852
+ # Fetch model spec
853
+ model_spec = self._model_ops._fetch_model_spec(
854
+ database_name=None,
855
+ schema_name=None,
856
+ model_name=self._model_name,
857
+ version_name=self._version_name,
858
+ statement_params=statement_params,
859
+ )
860
+
861
+ # Check if model_type is huggingface_pipeline
862
+ model_type = model_spec.get("model_type")
863
+ if model_type != "huggingface_pipeline":
864
+ raise ValueError(
865
+ f"Inference engine is only supported for HuggingFace text-generation models. "
866
+ f"Found model_type: {model_type}"
867
+ )
868
+
869
+ # Check if model supports text-generation task
870
+ # There should only be one model in the list because we don't support multiple models in a single model spec
871
+ models = model_spec.get("models", {})
872
+ is_text_generation = False
873
+ found_tasks: list[str] = []
874
+
875
+ # As long as the model supports text-generation task, we can use it
876
+ for _, model_info in models.items():
877
+ options = model_info.get("options", {})
878
+ task = options.get("task")
879
+ if task:
880
+ found_tasks.append(str(task))
881
+ if task == "text-generation":
882
+ is_text_generation = True
883
+ break
884
+
885
+ if not is_text_generation:
886
+ tasks_str = ", ".join(found_tasks)
887
+ found_tasks_str = (
888
+ f"Found task(s): {tasks_str} in model spec." if found_tasks else "No task found in model spec."
889
+ )
890
+ raise ValueError(f"Inference engine is only supported for task 'text-generation'. {found_tasks_str}")
891
+
710
892
  @overload
711
893
  def create_service(
712
894
  self,
@@ -714,7 +896,7 @@ class ModelVersion(lineage_node.LineageNode):
714
896
  service_name: str,
715
897
  image_build_compute_pool: Optional[str] = None,
716
898
  service_compute_pool: str,
717
- image_repo: str,
899
+ image_repo: Optional[str] = None,
718
900
  ingress_enabled: bool = False,
719
901
  max_instances: int = 1,
720
902
  cpu_requests: Optional[str] = None,
@@ -725,6 +907,7 @@ class ModelVersion(lineage_node.LineageNode):
725
907
  force_rebuild: bool = False,
726
908
  build_external_access_integration: Optional[str] = None,
727
909
  block: bool = True,
910
+ experimental_options: Optional[dict[str, Any]] = None,
728
911
  ) -> Union[str, async_job.AsyncJob]:
729
912
  """Create an inference service with the given spec.
730
913
 
@@ -735,7 +918,8 @@ class ModelVersion(lineage_node.LineageNode):
735
918
  the service compute pool if None.
736
919
  service_compute_pool: The name of the compute pool used to run the inference service.
737
920
  image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
738
- or schema of the model will be used.
921
+ or schema of the model will be used. This can be None, in that case a default hidden image repository
922
+ will be used.
739
923
  ingress_enabled: If true, creates an service endpoint associated with the service. User must have
740
924
  BIND SERVICE ENDPOINT privilege on the account.
741
925
  max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -756,6 +940,10 @@ class ModelVersion(lineage_node.LineageNode):
756
940
  block: A bool value indicating whether this function will wait until the service is available.
757
941
  When it is ``False``, this function executes the underlying service creation asynchronously
758
942
  and returns an :class:`AsyncJob`.
943
+ experimental_options: Experimental options for the service creation with custom inference engine.
944
+ Currently, only `inference_engine` and `inference_engine_args_override` are supported.
945
+ `inference_engine` is the name of the inference engine to use.
946
+ `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
759
947
  """
760
948
  ...
761
949
 
@@ -766,7 +954,7 @@ class ModelVersion(lineage_node.LineageNode):
766
954
  service_name: str,
767
955
  image_build_compute_pool: Optional[str] = None,
768
956
  service_compute_pool: str,
769
- image_repo: str,
957
+ image_repo: Optional[str] = None,
770
958
  ingress_enabled: bool = False,
771
959
  max_instances: int = 1,
772
960
  cpu_requests: Optional[str] = None,
@@ -777,6 +965,7 @@ class ModelVersion(lineage_node.LineageNode):
777
965
  force_rebuild: bool = False,
778
966
  build_external_access_integrations: Optional[list[str]] = None,
779
967
  block: bool = True,
968
+ experimental_options: Optional[dict[str, Any]] = None,
780
969
  ) -> Union[str, async_job.AsyncJob]:
781
970
  """Create an inference service with the given spec.
782
971
 
@@ -787,7 +976,8 @@ class ModelVersion(lineage_node.LineageNode):
787
976
  the service compute pool if None.
788
977
  service_compute_pool: The name of the compute pool used to run the inference service.
789
978
  image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
790
- or schema of the model will be used.
979
+ or schema of the model will be used. This can be None, in that case a default hidden image repository
980
+ will be used.
791
981
  ingress_enabled: If true, creates an service endpoint associated with the service. User must have
792
982
  BIND SERVICE ENDPOINT privilege on the account.
793
983
  max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -808,6 +998,10 @@ class ModelVersion(lineage_node.LineageNode):
808
998
  block: A bool value indicating whether this function will wait until the service is available.
809
999
  When it is ``False``, this function executes the underlying service creation asynchronously
810
1000
  and returns an :class:`AsyncJob`.
1001
+ experimental_options: Experimental options for the service creation with custom inference engine.
1002
+ Currently, only `inference_engine` and `inference_engine_args_override` are supported.
1003
+ `inference_engine` is the name of the inference engine to use.
1004
+ `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
811
1005
  """
812
1006
  ...
813
1007
 
@@ -832,7 +1026,7 @@ class ModelVersion(lineage_node.LineageNode):
832
1026
  service_name: str,
833
1027
  image_build_compute_pool: Optional[str] = None,
834
1028
  service_compute_pool: str,
835
- image_repo: str,
1029
+ image_repo: Optional[str] = None,
836
1030
  ingress_enabled: bool = False,
837
1031
  max_instances: int = 1,
838
1032
  cpu_requests: Optional[str] = None,
@@ -844,6 +1038,7 @@ class ModelVersion(lineage_node.LineageNode):
844
1038
  build_external_access_integration: Optional[str] = None,
845
1039
  build_external_access_integrations: Optional[list[str]] = None,
846
1040
  block: bool = True,
1041
+ experimental_options: Optional[dict[str, Any]] = None,
847
1042
  ) -> Union[str, async_job.AsyncJob]:
848
1043
  """Create an inference service with the given spec.
849
1044
 
@@ -854,7 +1049,8 @@ class ModelVersion(lineage_node.LineageNode):
854
1049
  the service compute pool if None.
855
1050
  service_compute_pool: The name of the compute pool used to run the inference service.
856
1051
  image_repo: The name of the image repository, can be fully qualified. If not fully qualified, the database
857
- or schema of the model will be used.
1052
+ or schema of the model will be used. This can be None, in that case a default hidden image repository
1053
+ will be used.
858
1054
  ingress_enabled: If true, creates an service endpoint associated with the service. User must have
859
1055
  BIND SERVICE ENDPOINT privilege on the account.
860
1056
  max_instances: The maximum number of inference service instances to run. The same value it set to
@@ -877,6 +1073,11 @@ class ModelVersion(lineage_node.LineageNode):
877
1073
  block: A bool value indicating whether this function will wait until the service is available.
878
1074
  When it is False, this function executes the underlying service creation asynchronously
879
1075
  and returns an AsyncJob.
1076
+ experimental_options: Experimental options for the service creation with custom inference engine.
1077
+ Currently, only `inference_engine` and `inference_engine_args_override` are supported.
1078
+ `inference_engine` is the name of the inference engine to use.
1079
+ `inference_engine_args_override` is a list of string arguments to pass to the inference engine.
1080
+
880
1081
 
881
1082
  Raises:
882
1083
  ValueError: Illegal external access integration arguments.
@@ -885,6 +1086,9 @@ class ModelVersion(lineage_node.LineageNode):
885
1086
  Returns:
886
1087
  If `block=True`, return result information about service creation from server.
887
1088
  Otherwise, return the service creation AsyncJob.
1089
+
1090
+ Raises:
1091
+ ValueError: Illegal external access integration arguments.
888
1092
  """
889
1093
  statement_params = telemetry.get_statement_params(
890
1094
  project=_TELEMETRY_PROJECT,
@@ -906,7 +1110,18 @@ class ModelVersion(lineage_node.LineageNode):
906
1110
  build_external_access_integrations = [build_external_access_integration]
907
1111
 
908
1112
  service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
909
- image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
1113
+
1114
+ # Check if model is HuggingFace text-generation before doing inference engine checks
1115
+ if experimental_options:
1116
+ self._check_huggingface_text_generation_model(statement_params)
1117
+
1118
+ inference_engine_args: Optional[service_ops.InferenceEngineArgs] = self._get_inference_engine_args(
1119
+ experimental_options
1120
+ )
1121
+
1122
+ # Enrich inference engine args if inference engine is specified
1123
+ if inference_engine_args is not None:
1124
+ inference_engine_args = self._enrich_inference_engine_args(inference_engine_args, gpu_requests)
910
1125
 
911
1126
  from snowflake.ml.model import event_handler
912
1127
  from snowflake.snowpark import exceptions
@@ -929,7 +1144,7 @@ class ModelVersion(lineage_node.LineageNode):
929
1144
  else sql_identifier.SqlIdentifier(service_compute_pool)
930
1145
  ),
931
1146
  service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
932
- image_repo=image_repo,
1147
+ image_repo_name=image_repo,
933
1148
  ingress_enabled=ingress_enabled,
934
1149
  max_instances=max_instances,
935
1150
  cpu_requests=cpu_requests,
@@ -946,6 +1161,7 @@ class ModelVersion(lineage_node.LineageNode):
946
1161
  block=block,
947
1162
  statement_params=statement_params,
948
1163
  progress_status=status,
1164
+ inference_engine_args=inference_engine_args,
949
1165
  )
950
1166
  status.update(label="Model service created successfully", state="complete", expanded=False)
951
1167
  return result
@@ -1028,69 +1244,5 @@ class ModelVersion(lineage_node.LineageNode):
1028
1244
  statement_params=statement_params,
1029
1245
  )
1030
1246
 
1031
- @snowpark._internal.utils.private_preview(version="1.8.3")
1032
- @telemetry.send_api_usage_telemetry(
1033
- project=_TELEMETRY_PROJECT,
1034
- subproject=_TELEMETRY_SUBPROJECT,
1035
- )
1036
- def _run_job(
1037
- self,
1038
- X: Union[pd.DataFrame, "dataframe.DataFrame"],
1039
- *,
1040
- job_name: str,
1041
- compute_pool: str,
1042
- image_repo: str,
1043
- output_table_name: str,
1044
- function_name: Optional[str] = None,
1045
- cpu_requests: Optional[str] = None,
1046
- memory_requests: Optional[str] = None,
1047
- gpu_requests: Optional[Union[str, int]] = None,
1048
- num_workers: Optional[int] = None,
1049
- max_batch_rows: Optional[int] = None,
1050
- force_rebuild: bool = False,
1051
- build_external_access_integrations: Optional[list[str]] = None,
1052
- ) -> Union[pd.DataFrame, dataframe.DataFrame]:
1053
- statement_params = telemetry.get_statement_params(
1054
- project=_TELEMETRY_PROJECT,
1055
- subproject=_TELEMETRY_SUBPROJECT,
1056
- )
1057
- target_function_info = self._get_function_info(function_name=function_name)
1058
- job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
1059
- output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
1060
- output_table_name
1061
- )
1062
- warehouse = self._service_ops._session.get_current_warehouse()
1063
- assert warehouse, "No active warehouse selected in the current session."
1064
- return self._service_ops.invoke_job_method(
1065
- target_method=target_function_info["target_method"],
1066
- signature=target_function_info["signature"],
1067
- X=X,
1068
- database_name=None,
1069
- schema_name=None,
1070
- model_name=self._model_name,
1071
- version_name=self._version_name,
1072
- job_database_name=job_db_id,
1073
- job_schema_name=job_schema_id,
1074
- job_name=job_id,
1075
- compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
1076
- warehouse_name=sql_identifier.SqlIdentifier(warehouse),
1077
- image_repo=image_repo,
1078
- output_table_database_name=output_table_db_id,
1079
- output_table_schema_name=output_table_schema_id,
1080
- output_table_name=output_table_id,
1081
- cpu_requests=cpu_requests,
1082
- memory_requests=memory_requests,
1083
- gpu_requests=gpu_requests,
1084
- num_workers=num_workers,
1085
- max_batch_rows=max_batch_rows,
1086
- force_rebuild=force_rebuild,
1087
- build_external_access_integrations=(
1088
- None
1089
- if build_external_access_integrations is None
1090
- else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
1091
- ),
1092
- statement_params=statement_params,
1093
- )
1094
-
1095
1247
 
1096
1248
  lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion