snowflake-ml-python 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/telemetry.py +3 -1
  3. snowflake/ml/_internal/utils/service_logger.py +26 -1
  4. snowflake/ml/experiment/_client/artifact.py +76 -0
  5. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  6. snowflake/ml/experiment/experiment_tracking.py +113 -6
  7. snowflake/ml/feature_store/feature_store.py +1150 -131
  8. snowflake/ml/feature_store/feature_view.py +122 -0
  9. snowflake/ml/jobs/_utils/constants.py +8 -16
  10. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  11. snowflake/ml/jobs/_utils/payload_utils.py +19 -5
  12. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  13. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +23 -5
  14. snowflake/ml/jobs/_utils/spec_utils.py +4 -6
  15. snowflake/ml/jobs/_utils/types.py +2 -1
  16. snowflake/ml/jobs/job.py +38 -19
  17. snowflake/ml/jobs/manager.py +136 -19
  18. snowflake/ml/model/__init__.py +6 -1
  19. snowflake/ml/model/_client/model/batch_inference_specs.py +25 -0
  20. snowflake/ml/model/_client/model/model_version_impl.py +62 -65
  21. snowflake/ml/model/_client/ops/model_ops.py +42 -9
  22. snowflake/ml/model/_client/ops/service_ops.py +75 -154
  23. snowflake/ml/model/_client/service/model_deployment_spec.py +23 -37
  24. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +15 -4
  25. snowflake/ml/model/_client/sql/service.py +4 -0
  26. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +309 -22
  27. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  28. snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -0
  29. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  30. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  31. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  32. snowflake/ml/model/_signatures/utils.py +4 -2
  33. snowflake/ml/model/models/huggingface_pipeline.py +23 -0
  34. snowflake/ml/model/openai_signatures.py +57 -0
  35. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  36. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  37. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  38. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  39. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  40. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  41. snowflake/ml/modeling/cluster/birch.py +1 -1
  42. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  43. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  44. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  45. snowflake/ml/modeling/cluster/k_means.py +1 -1
  46. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  47. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  48. snowflake/ml/modeling/cluster/optics.py +1 -1
  49. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  50. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  51. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  52. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  53. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  54. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  55. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  56. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  57. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  58. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  59. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  60. snowflake/ml/modeling/covariance/oas.py +1 -1
  61. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  62. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  63. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  64. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  65. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  66. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  67. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  68. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  69. snowflake/ml/modeling/decomposition/pca.py +1 -1
  70. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  71. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  72. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  73. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  74. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  75. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  76. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  77. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  78. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  79. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  80. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  81. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  82. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  83. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  84. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  85. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  86. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  87. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  88. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  89. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  90. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  91. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  92. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  93. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  94. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  95. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  96. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  97. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  98. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  99. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  100. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  101. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  102. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  103. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  104. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  105. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  106. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  107. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  108. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  109. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  110. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  111. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  112. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  113. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  114. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  115. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  116. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  117. snowflake/ml/modeling/linear_model/lars.py +1 -1
  118. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  119. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  120. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  122. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  123. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  124. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  125. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  126. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  127. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  128. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  129. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  130. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  131. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  132. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  133. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  134. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  135. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  136. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  137. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  138. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  139. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  140. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  141. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  142. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  143. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  144. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  145. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  146. snowflake/ml/modeling/manifold/isomap.py +1 -1
  147. snowflake/ml/modeling/manifold/mds.py +1 -1
  148. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  149. snowflake/ml/modeling/manifold/tsne.py +1 -1
  150. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  151. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  152. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  153. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  154. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  155. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  156. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  157. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  158. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  159. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  160. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  161. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  162. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  163. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  164. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  165. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  166. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  167. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  168. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  169. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  170. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  171. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  172. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  173. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  174. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  175. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  176. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  177. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  178. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  179. snowflake/ml/modeling/svm/svc.py +1 -1
  180. snowflake/ml/modeling/svm/svr.py +1 -1
  181. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  182. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  183. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  184. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  185. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  186. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  187. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  188. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  189. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  190. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  191. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  192. snowflake/ml/monitoring/model_monitor.py +26 -0
  193. snowflake/ml/version.py +1 -1
  194. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/METADATA +82 -5
  195. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/RECORD +198 -194
  196. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/WHEEL +0 -0
  197. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/licenses/LICENSE.txt +0 -0
  198. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py

@@ -1,6 +1,8 @@
+ import json
  import logging
  import pathlib
  import textwrap
+ from pathlib import PurePath
  from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
  from uuid import uuid4

@@ -11,7 +13,13 @@ from snowflake import snowpark
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import identifier
  from snowflake.ml.jobs import job as jb
- from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
+ from snowflake.ml.jobs._utils import (
+     feature_flags,
+     payload_utils,
+     query_helper,
+     spec_utils,
+     types,
+ )
  from snowflake.snowpark.context import get_active_session
  from snowflake.snowpark.exceptions import SnowparkSQLException
  from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -50,7 +58,8 @@ def list_jobs(
          >>> from snowflake.ml.jobs import list_jobs
          >>> list_jobs(limit=5)
      """
-     session = session or get_active_session()
+
+     session = _ensure_session(session)
      try:
          df = _get_job_history_spcs(
              session,
@@ -154,7 +163,7 @@ def _get_job_history_spcs(
  @telemetry.send_api_usage_telemetry(project=_PROJECT)
  def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
      """Retrieve a job service from the backend."""
-     session = session or get_active_session()
+     session = _ensure_session(session)
      try:
          database, schema, job_name = identifier.parse_schema_level_object_identifier(job_id)
          database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
@@ -426,8 +435,10 @@ def _submit_job(

      Raises:
          ValueError: If database or schema value(s) are invalid
+         RuntimeError: If schema is not specified in session context or job submission
+         snowpark.exceptions.SnowparkSQLException: if failed to upload payload
      """
-     session = session or get_active_session()
+     session = _ensure_session(session)

      # Check for deprecated args
      if "num_instances" in kwargs:
@@ -445,7 +456,7 @@
      env_vars = kwargs.pop("env_vars", None)
      spec_overrides = kwargs.pop("spec_overrides", None)
      enable_metrics = kwargs.pop("enable_metrics", True)
-     query_warehouse = kwargs.pop("query_warehouse", None)
+     query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
      additional_payloads = kwargs.pop("additional_payloads", None)

      if additional_payloads:
@@ -478,11 +489,39 @@
      stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
      stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")

-     # Upload payload
-     uploaded_payload = payload_utils.JobPayload(
-         source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
-     ).upload(session, stage_path)
+     try:
+         # Upload payload
+         uploaded_payload = payload_utils.JobPayload(
+             source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
+         ).upload(session, stage_path)
+     except snowpark.exceptions.SnowparkSQLException as e:
+         if e.sql_error_code == 90106:
+             raise RuntimeError(
+                 "Please specify a schema, either in the session context or as a parameter in the job submission"
+             )
+         raise

+     if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
+         # Add default env vars (extracted from spec_utils.generate_service_spec)
+         combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
+
+         return _do_submit_job_v2(
+             session=session,
+             payload=uploaded_payload,
+             args=args,
+             env_vars=combined_env_vars,
+             spec_overrides=spec_overrides,
+             compute_pool=compute_pool,
+             job_id=job_id,
+             external_access_integrations=external_access_integrations,
+             query_warehouse=query_warehouse,
+             target_instances=target_instances,
+             min_instances=min_instances,
+             enable_metrics=enable_metrics,
+             use_async=True,
+         )
+
+     # Fall back to v1
      # Generate service spec
      spec = spec_utils.generate_service_spec(
          session,
@@ -493,6 +532,8 @@
          min_instances=min_instances,
          enable_metrics=enable_metrics,
      )
+
+     # Generate spec overrides
      spec_overrides = spec_utils.generate_spec_overrides(
          environment_vars=env_vars,
          custom_overrides=spec_overrides,
@@ -500,26 +541,25 @@
      if spec_overrides:
          spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")

-     query_text, params = _generate_submission_query(
-         spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
+     return _do_submit_job_v1(
+         session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
      )
-     _ = query_helper.run_query(session, query_text, params=params)
-     return get_job(job_id, session=session)


- def _generate_submission_query(
+ def _do_submit_job_v1(
+     session: snowpark.Session,
      spec: dict[str, Any],
      external_access_integrations: list[str],
      query_warehouse: Optional[str],
      target_instances: int,
-     session: snowpark.Session,
      compute_pool: str,
      job_id: str,
- ) -> tuple[str, list[Any]]:
+ ) -> jb.MLJob[Any]:
      """
      Generate the SQL query for job submission.

      Args:
+         session: The Snowpark session to use.
          spec: The service spec for the job.
          external_access_integrations: The external access integrations for the job.
          query_warehouse: The query warehouse for the job.
@@ -529,7 +569,7 @@ def _generate_submission_query(
          job_id: The ID of the job.

      Returns:
-         A tuple containing the SQL query text and the parameters for the query.
+         The job object.
      """
      query_template = textwrap.dedent(
          """\
@@ -547,12 +587,89 @@
      if external_access_integrations:
          external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
          query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
-     query_warehouse = query_warehouse or session.get_current_warehouse()
      if query_warehouse:
          query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
          params.append(query_warehouse)
      if target_instances > 1:
          query.append("REPLICAS = ?")
          params.append(target_instances)
+
      query_text = "\n".join(line for line in query if line)
-     return query_text, params
+     _ = query_helper.run_query(session, query_text, params=params)
+
+     return get_job(job_id, session=session)
+
+
+ def _do_submit_job_v2(
+     session: snowpark.Session,
+     payload: types.UploadedPayload,
+     args: Optional[list[str]],
+     env_vars: dict[str, str],
+     spec_overrides: dict[str, Any],
+     compute_pool: str,
+     job_id: Optional[str] = None,
+     external_access_integrations: Optional[list[str]] = None,
+     query_warehouse: Optional[str] = None,
+     target_instances: int = 1,
+     min_instances: int = 1,
+     enable_metrics: bool = True,
+     use_async: bool = True,
+ ) -> jb.MLJob[Any]:
+     """
+     Generate the SQL query for job submission.
+
+     Args:
+         session: The Snowpark session to use.
+         payload: The uploaded job payload.
+         args: Arguments to pass to the entrypoint script.
+         env_vars: Environment variables to set in the job container.
+         spec_overrides: Custom service specification overrides.
+         compute_pool: The compute pool to use for job execution.
+         job_id: The ID of the job.
+         external_access_integrations: Optional list of external access integrations.
+         query_warehouse: Optional query warehouse to use.
+         target_instances: Number of instances for multi-node job.
+         min_instances: Minimum number of instances required to start the job.
+         enable_metrics: Whether to enable platform metrics for the job.
+         use_async: Whether to run the job asynchronously.
+
+     Returns:
+         The job object.
+     """
+     args = [
+         (payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
+     ] + (args or [])
+     spec_options = {
+         "STAGE_PATH": payload.stage_path.as_posix(),
+         "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
+         "ARGS": args,
+         "ENV_VARS": env_vars,
+         "ENABLE_METRICS": enable_metrics,
+         "SPEC_OVERRIDES": spec_overrides,
+     }
+     job_options = {
+         "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
+         "QUERY_WAREHOUSE": query_warehouse,
+         "TARGET_INSTANCES": target_instances,
+         "MIN_INSTANCES": min_instances,
+         "ASYNC": use_async,
+     }
+     job_options = {k: v for k, v in job_options.items() if v is not None}
+
+     query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
+     params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
+     actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
+
+     return get_job(actual_job_id, session=session)
+
+
+ def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session:
+     try:
+         session = session or get_active_session()
+     except snowpark.exceptions.SnowparkSessionException as e:
+         if "More than one active session" in e.message:
+             raise RuntimeError("Please specify the session as a parameter in API call")
+         if "No default Session is found" in e.message:
+             raise RuntimeError("Please create a session before API call")
+         raise
+     return session
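Note on the manager.py changes above: the jobs APIs now resolve their Snowpark session through the new _ensure_session helper, which turns the "more than one active session" and "no default session" cases into explicit RuntimeErrors asking the caller to supply a session. A minimal sketch of passing the session explicitly, using the list_jobs and get_job signatures shown in this diff (the connection parameters and job name are placeholders, not taken from the diff):

    from snowflake.snowpark import Session
    from snowflake.ml.jobs import get_job, list_jobs

    # Placeholder connection parameters; fill in your own account details.
    session = Session.builder.configs({"account": "...", "user": "...", "password": "..."}).create()

    # Passing session= sidesteps get_active_session() when several sessions are open.
    jobs_df = list_jobs(limit=5, session=session)
    job = get_job("MY_DB.MY_SCHEMA.MLJOB_01ABCDEF", session=session)  # hypothetical job name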
snowflake/ml/model/__init__.py

@@ -1,5 +1,10 @@
+ from snowflake.ml.model._client.model.batch_inference_specs import (
+     InputSpec,
+     JobSpec,
+     OutputSpec,
+ )
  from snowflake.ml.model._client.model.model_impl import Model
  from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
  from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel

- __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"]
+ __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
snowflake/ml/model/_client/model/batch_inference_specs.py (new file)

@@ -0,0 +1,25 @@
+ from typing import Optional, Union
+
+ from pydantic import BaseModel
+
+
+ class InputSpec(BaseModel):
+     stage_location: str
+
+
+ class OutputSpec(BaseModel):
+     stage_location: str
+
+
+ class JobSpec(BaseModel):
+     image_repo: Optional[str] = None
+     job_name: Optional[str] = None
+     num_workers: Optional[int] = None
+     function_name: Optional[str] = None
+     gpu: Optional[Union[str, int]] = None
+     force_rebuild: bool = False
+     max_batch_rows: int = 1024
+     warehouse: Optional[str] = None
+     cpu_requests: Optional[str] = None
+     memory_requests: Optional[str] = None
+     replicas: Optional[int] = None
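Note: InputSpec, OutputSpec, and JobSpec above are plain pydantic models, and the __init__.py hunk earlier re-exports them from snowflake.ml.model. A minimal construction sketch (stage and warehouse names are placeholders):

    from snowflake.ml.model import InputSpec, JobSpec, OutputSpec

    # Placeholder stage locations; any stage path readable by the job works here.
    input_spec = InputSpec(stage_location="@MY_DB.MY_SCHEMA.BATCH_STAGE/input/")
    output_spec = OutputSpec(stage_location="@MY_DB.MY_SCHEMA.BATCH_STAGE/output/")

    # All JobSpec fields are optional; force_rebuild defaults to False, max_batch_rows
    # to 1024, and warehouse falls back to the session warehouse inside _run_batch.
    job_spec = JobSpec(warehouse="MY_WH", num_workers=2)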
snowflake/ml/model/_client/model/model_version_impl.py

@@ -1,16 +1,18 @@
  import enum
  import pathlib
  import tempfile
+ import uuid
  import warnings
  from typing import Any, Callable, Optional, Union, overload

  import pandas as pd

- from snowflake import snowpark
+ from snowflake.ml import jobs
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.utils import sql_identifier
  from snowflake.ml.lineage import lineage_node
  from snowflake.ml.model import task, type_hints
+ from snowflake.ml.model._client.model import batch_inference_specs
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
  from snowflake.ml.model._model_composer import model_composer
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -19,6 +21,7 @@ from snowflake.snowpark import Session, async_job, dataframe

  _TELEMETRY_PROJECT = "MLOps"
  _TELEMETRY_SUBPROJECT = "ModelManagement"
+ _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"


  class ExportMode(enum.Enum):
@@ -539,6 +542,64 @@ class ModelVersion(lineage_node.LineageNode):
              is_partitioned=target_function_info["is_partitioned"],
          )

+     @telemetry.send_api_usage_telemetry(
+         project=_TELEMETRY_PROJECT,
+         subproject=_TELEMETRY_SUBPROJECT,
+         func_params_to_log=[
+             "compute_pool",
+         ],
+     )
+     def _run_batch(
+         self,
+         *,
+         compute_pool: str,
+         input_spec: batch_inference_specs.InputSpec,
+         output_spec: batch_inference_specs.OutputSpec,
+         job_spec: Optional[batch_inference_specs.JobSpec] = None,
+     ) -> jobs.MLJob[Any]:
+         statement_params = telemetry.get_statement_params(
+             project=_TELEMETRY_PROJECT,
+             subproject=_TELEMETRY_SUBPROJECT,
+         )
+
+         if job_spec is None:
+             job_spec = batch_inference_specs.JobSpec()
+
+         warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
+         if warehouse is None:
+             raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
+
+         if job_spec.job_name is None:
+             # Same as the MLJob ID generation logic with a different prefix
+             job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
+         else:
+             job_name = job_spec.job_name
+
+         return self._service_ops.invoke_batch_job_method(
+             # model version info
+             model_name=self._model_name,
+             version_name=self._version_name,
+             # job spec
+             function_name=self._get_function_info(function_name=job_spec.function_name)["target_method"],
+             compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
+             force_rebuild=job_spec.force_rebuild,
+             image_repo_name=job_spec.image_repo,
+             num_workers=job_spec.num_workers,
+             max_batch_rows=job_spec.max_batch_rows,
+             warehouse=sql_identifier.SqlIdentifier(warehouse),
+             cpu_requests=job_spec.cpu_requests,
+             memory_requests=job_spec.memory_requests,
+             job_name=job_name,
+             replicas=job_spec.replicas,
+             # input and output
+             input_stage_location=input_spec.stage_location,
+             input_file_pattern="*",
+             output_stage_location=output_spec.stage_location,
+             completion_filename="_SUCCESS",
+             # misc
+             statement_params=statement_params,
+         )
+
      def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
          functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions

@@ -1184,69 +1245,5 @@ class ModelVersion(lineage_node.LineageNode):
              statement_params=statement_params,
          )

-     @snowpark._internal.utils.private_preview(version="1.8.3")
-     @telemetry.send_api_usage_telemetry(
-         project=_TELEMETRY_PROJECT,
-         subproject=_TELEMETRY_SUBPROJECT,
-     )
-     def _run_job(
-         self,
-         X: Union[pd.DataFrame, "dataframe.DataFrame"],
-         *,
-         job_name: str,
-         compute_pool: str,
-         image_repo: Optional[str] = None,
-         output_table_name: str,
-         function_name: Optional[str] = None,
-         cpu_requests: Optional[str] = None,
-         memory_requests: Optional[str] = None,
-         gpu_requests: Optional[Union[str, int]] = None,
-         num_workers: Optional[int] = None,
-         max_batch_rows: Optional[int] = None,
-         force_rebuild: bool = False,
-         build_external_access_integrations: Optional[list[str]] = None,
-     ) -> Union[pd.DataFrame, dataframe.DataFrame]:
-         statement_params = telemetry.get_statement_params(
-             project=_TELEMETRY_PROJECT,
-             subproject=_TELEMETRY_SUBPROJECT,
-         )
-         target_function_info = self._get_function_info(function_name=function_name)
-         job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
-         output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
-             output_table_name
-         )
-         warehouse = self._service_ops._session.get_current_warehouse()
-         assert warehouse, "No active warehouse selected in the current session."
-         return self._service_ops.invoke_job_method(
-             target_method=target_function_info["target_method"],
-             signature=target_function_info["signature"],
-             X=X,
-             database_name=None,
-             schema_name=None,
-             model_name=self._model_name,
-             version_name=self._version_name,
-             job_database_name=job_db_id,
-             job_schema_name=job_schema_id,
-             job_name=job_id,
-             compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
-             warehouse_name=sql_identifier.SqlIdentifier(warehouse),
-             image_repo_name=image_repo,
-             output_table_database_name=output_table_db_id,
-             output_table_schema_name=output_table_schema_id,
-             output_table_name=output_table_id,
-             cpu_requests=cpu_requests,
-             memory_requests=memory_requests,
-             gpu_requests=gpu_requests,
-             num_workers=num_workers,
-             max_batch_rows=max_batch_rows,
-             force_rebuild=force_rebuild,
-             build_external_access_integrations=(
-                 None
-                 if build_external_access_integrations is None
-                 else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
-             ),
-             statement_params=statement_params,
-         )
-

  lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
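Note: _run_batch above replaces the removed _run_job; instead of taking a DataFrame and writing to an output table, it reads from and writes to stages via the spec models and returns a snowflake.ml.jobs MLJob handle. A hedged sketch of a call, reusing the spec objects from the previous sketch and assuming mv is a ModelVersion obtained from the registry (the compute pool name is a placeholder; the method is still underscore-private in this release):

    # `mv` is assumed to be a ModelVersion, e.g. Registry(session).get_model("MY_MODEL").version("V1").
    job = mv._run_batch(
        compute_pool="MY_COMPUTE_POOL",  # placeholder compute pool
        input_spec=input_spec,
        output_spec=output_spec,
        job_spec=job_spec,
    )
    # `job` is the same MLJob type returned by the snowflake.ml.jobs API; per the call above,
    # results land under output_spec.stage_location with a "_SUCCESS" completion marker.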
snowflake/ml/model/_client/ops/model_ops.py

@@ -47,6 +47,7 @@ class ServiceInfo(TypedDict):
  class ModelOperator:
      INFERENCE_SERVICE_ENDPOINT_NAME = "inference"
      INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"
+     PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING = "privatelink.snowflakecomputing"

      def __init__(
          self,
@@ -612,6 +613,30 @@
              statement_params=statement_params,
          )

+     def _is_privatelink_connection(self) -> bool:
+         """Detect if the current session is using a privatelink connection."""
+         try:
+             host = self._session.connection.host
+             return ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING in host
+         except AttributeError:
+             return False
+
+     def _extract_and_validate_ingress_url(self, res_row: "row.Row") -> Optional[str]:
+         """Extract and validate ingress URL from endpoint row."""
+         url_value = res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
+         if url_value is None:
+             return None
+         url_str = str(url_value)
+         return url_str if url_str.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX) else None
+
+     def _extract_and_validate_privatelink_url(self, res_row: "row.Row") -> Optional[str]:
+         """Extract and validate privatelink ingress URL from endpoint row."""
+         url_value = res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME]
+         if url_value is None:
+             return None
+         url_str = str(url_value)
+         return url_str if ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING in url_str else None
+
      def show_services(
          self,
          *,
@@ -644,8 +669,10 @@
          fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]

          result: list[ServiceInfo] = []
+         is_privatelink_connection = self._is_privatelink_connection()
+
          for fully_qualified_service_name in fully_qualified_service_names:
-             ingress_url: Optional[str] = None
+             inference_endpoint: Optional[str] = None
              db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
              statuses = self._service_client.get_service_container_statuses(
                  database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
@@ -659,17 +686,23 @@
              ):
                  if (
                      res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME]
-                     == self.INFERENCE_SERVICE_ENDPOINT_NAME
-                     and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None
+                     != self.INFERENCE_SERVICE_ENDPOINT_NAME
                  ):
-                     ingress_url = str(
-                         res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
-                     )
-                     if not ingress_url.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
-                         ingress_url = None
+                     continue
+
+                 ingress_url = self._extract_and_validate_ingress_url(res_row)
+                 privatelink_ingress_url = self._extract_and_validate_privatelink_url(res_row)
+
+                 if is_privatelink_connection and privatelink_ingress_url is not None:
+                     inference_endpoint = privatelink_ingress_url
+                 else:
+                     inference_endpoint = ingress_url
+
              result.append(
                  ServiceInfo(
-                     name=fully_qualified_service_name, status=service_status.value, inference_endpoint=ingress_url
+                     name=fully_qualified_service_name,
+                     status=service_status.value,
+                     inference_endpoint=inference_endpoint,
                  )
              )
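Note: with the model_ops.py hunks above, show_services fills inference_endpoint by preferring the privatelink ingress URL whenever the session host contains "privatelink.snowflakecomputing", and otherwise falls back to the public ingress URL ending in "snowflakecomputing.app". A small standalone sketch of that selection rule (URLs are placeholders):

    from typing import Optional

    PRIVATELINK_SUBSTRING = "privatelink.snowflakecomputing"

    def pick_inference_endpoint(
        is_privatelink_connection: bool,
        ingress_url: Optional[str],
        privatelink_ingress_url: Optional[str],
    ) -> Optional[str]:
        # Mirrors the branch added to show_services: use the privatelink URL only when
        # the current connection is itself a privatelink connection and the URL is present.
        if is_privatelink_connection and privatelink_ingress_url is not None:
            return privatelink_ingress_url
        return ingress_url

    # Placeholder endpoint values for illustration.
    print(pick_inference_endpoint(True, "myorg.snowflakecomputing.app", "myorg.privatelink.snowflakecomputing.app"))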