snowflake-ml-python 1.11.0__py3-none-any.whl → 1.12.0__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (192)
  1. snowflake/cortex/_complete.py +3 -2
  2. snowflake/ml/_internal/utils/service_logger.py +26 -1
  3. snowflake/ml/experiment/_client/artifact.py +76 -0
  4. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
  5. snowflake/ml/experiment/experiment_tracking.py +89 -4
  6. snowflake/ml/feature_store/feature_store.py +1150 -131
  7. snowflake/ml/feature_store/feature_view.py +122 -0
  8. snowflake/ml/jobs/_utils/constants.py +8 -16
  9. snowflake/ml/jobs/_utils/feature_flags.py +16 -0
  10. snowflake/ml/jobs/_utils/payload_utils.py +19 -5
  11. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +12 -4
  13. snowflake/ml/jobs/_utils/spec_utils.py +4 -6
  14. snowflake/ml/jobs/_utils/types.py +2 -1
  15. snowflake/ml/jobs/job.py +33 -17
  16. snowflake/ml/jobs/manager.py +107 -12
  17. snowflake/ml/model/__init__.py +6 -1
  18. snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
  19. snowflake/ml/model/_client/model/model_version_impl.py +61 -65
  20. snowflake/ml/model/_client/ops/service_ops.py +73 -154
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +20 -37
  22. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +14 -4
  23. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
  24. snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
  25. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
  26. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  27. snowflake/ml/model/_signatures/utils.py +4 -2
  28. snowflake/ml/model/openai_signatures.py +57 -0
  29. snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
  30. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
  31. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  32. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
  33. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
  34. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
  35. snowflake/ml/modeling/cluster/birch.py +1 -1
  36. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
  37. snowflake/ml/modeling/cluster/dbscan.py +1 -1
  38. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
  39. snowflake/ml/modeling/cluster/k_means.py +1 -1
  40. snowflake/ml/modeling/cluster/mean_shift.py +1 -1
  41. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
  42. snowflake/ml/modeling/cluster/optics.py +1 -1
  43. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
  44. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
  45. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
  46. snowflake/ml/modeling/compose/column_transformer.py +1 -1
  47. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
  48. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
  49. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
  51. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
  52. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
  53. snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
  54. snowflake/ml/modeling/covariance/oas.py +1 -1
  55. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
  56. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
  57. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
  58. snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
  59. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
  60. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
  62. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
  63. snowflake/ml/modeling/decomposition/pca.py +1 -1
  64. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
  65. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
  66. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
  67. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
  69. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
  70. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
  71. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
  73. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
  75. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
  77. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
  78. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
  79. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
  80. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
  81. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
  82. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
  83. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
  84. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
  85. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
  86. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
  87. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
  88. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
  89. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
  90. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
  91. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
  93. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
  94. snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
  95. snowflake/ml/modeling/impute/knn_imputer.py +1 -1
  96. snowflake/ml/modeling/impute/missing_indicator.py +1 -1
  97. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
  98. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
  99. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
  100. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
  101. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
  102. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
  104. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
  105. snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
  106. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
  107. snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
  108. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
  109. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
  110. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
  111. snowflake/ml/modeling/linear_model/lars.py +1 -1
  112. snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
  113. snowflake/ml/modeling/linear_model/lasso.py +1 -1
  114. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
  117. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
  118. snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
  120. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
  122. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
  124. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
  125. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
  127. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
  128. snowflake/ml/modeling/linear_model/perceptron.py +1 -1
  129. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
  130. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
  131. snowflake/ml/modeling/linear_model/ridge.py +1 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
  133. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
  134. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
  135. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
  136. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
  137. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
  138. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
  139. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
  140. snowflake/ml/modeling/manifold/isomap.py +1 -1
  141. snowflake/ml/modeling/manifold/mds.py +1 -1
  142. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
  143. snowflake/ml/modeling/manifold/tsne.py +1 -1
  144. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
  145. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
  146. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
  147. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
  148. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
  149. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
  150. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
  151. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
  152. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
  153. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
  155. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
  156. snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
  157. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
  158. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
  159. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
  160. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
  162. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
  163. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
  164. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
  165. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
  166. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
  167. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
  168. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
  169. snowflake/ml/modeling/svm/linear_svc.py +1 -1
  170. snowflake/ml/modeling/svm/linear_svr.py +1 -1
  171. snowflake/ml/modeling/svm/nu_svc.py +1 -1
  172. snowflake/ml/modeling/svm/nu_svr.py +1 -1
  173. snowflake/ml/modeling/svm/svc.py +1 -1
  174. snowflake/ml/modeling/svm/svr.py +1 -1
  175. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
  176. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
  177. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
  178. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
  179. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
  180. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
  182. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
  183. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
  184. snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
  185. snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
  186. snowflake/ml/monitoring/model_monitor.py +26 -0
  187. snowflake/ml/version.py +1 -1
  188. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +66 -5
  189. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +192 -188
  190. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
  191. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
  192. {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,8 @@
1
+ import json
1
2
  import logging
2
3
  import pathlib
3
4
  import textwrap
5
+ from pathlib import PurePath
4
6
  from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
5
7
  from uuid import uuid4
6
8
 
@@ -11,7 +13,13 @@ from snowflake import snowpark
11
13
  from snowflake.ml._internal import telemetry
12
14
  from snowflake.ml._internal.utils import identifier
13
15
  from snowflake.ml.jobs import job as jb
14
- from snowflake.ml.jobs._utils import payload_utils, query_helper, spec_utils
16
+ from snowflake.ml.jobs._utils import (
17
+ feature_flags,
18
+ payload_utils,
19
+ query_helper,
20
+ spec_utils,
21
+ types,
22
+ )
15
23
  from snowflake.snowpark.context import get_active_session
16
24
  from snowflake.snowpark.exceptions import SnowparkSQLException
17
25
  from snowflake.snowpark.functions import coalesce, col, lit, when
@@ -445,7 +453,7 @@ def _submit_job(
445
453
  env_vars = kwargs.pop("env_vars", None)
446
454
  spec_overrides = kwargs.pop("spec_overrides", None)
447
455
  enable_metrics = kwargs.pop("enable_metrics", True)
448
- query_warehouse = kwargs.pop("query_warehouse", None)
456
+ query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
449
457
  additional_payloads = kwargs.pop("additional_payloads", None)
450
458
 
451
459
  if additional_payloads:
@@ -483,6 +491,27 @@ def _submit_job(
483
491
  source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
484
492
  ).upload(session, stage_path)
485
493
 
494
+ if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
495
+ # Add default env vars (extracted from spec_utils.generate_service_spec)
496
+ combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
497
+
498
+ return _do_submit_job_v2(
499
+ session=session,
500
+ payload=uploaded_payload,
501
+ args=args,
502
+ env_vars=combined_env_vars,
503
+ spec_overrides=spec_overrides,
504
+ compute_pool=compute_pool,
505
+ job_id=job_id,
506
+ external_access_integrations=external_access_integrations,
507
+ query_warehouse=query_warehouse,
508
+ target_instances=target_instances,
509
+ min_instances=min_instances,
510
+ enable_metrics=enable_metrics,
511
+ use_async=True,
512
+ )
513
+
514
+ # Fall back to v1
486
515
  # Generate service spec
487
516
  spec = spec_utils.generate_service_spec(
488
517
  session,
@@ -493,6 +522,8 @@ def _submit_job(
493
522
  min_instances=min_instances,
494
523
  enable_metrics=enable_metrics,
495
524
  )
525
+
526
+ # Generate spec overrides
496
527
  spec_overrides = spec_utils.generate_spec_overrides(
497
528
  environment_vars=env_vars,
498
529
  custom_overrides=spec_overrides,
@@ -500,26 +531,25 @@ def _submit_job(
500
531
  if spec_overrides:
501
532
  spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
502
533
 
503
- query_text, params = _generate_submission_query(
504
- spec, external_access_integrations, query_warehouse, target_instances, session, compute_pool, job_id
534
+ return _do_submit_job_v1(
535
+ session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
505
536
  )
506
- _ = query_helper.run_query(session, query_text, params=params)
507
- return get_job(job_id, session=session)
508
537
 
509
538
 
510
- def _generate_submission_query(
539
+ def _do_submit_job_v1(
540
+ session: snowpark.Session,
511
541
  spec: dict[str, Any],
512
542
  external_access_integrations: list[str],
513
543
  query_warehouse: Optional[str],
514
544
  target_instances: int,
515
- session: snowpark.Session,
516
545
  compute_pool: str,
517
546
  job_id: str,
518
- ) -> tuple[str, list[Any]]:
547
+ ) -> jb.MLJob[Any]:
519
548
  """
520
549
  Generate the SQL query for job submission.
521
550
 
522
551
  Args:
552
+ session: The Snowpark session to use.
523
553
  spec: The service spec for the job.
524
554
  external_access_integrations: The external access integrations for the job.
525
555
  query_warehouse: The query warehouse for the job.
@@ -529,7 +559,7 @@ def _generate_submission_query(
529
559
  job_id: The ID of the job.
530
560
 
531
561
  Returns:
532
- A tuple containing the SQL query text and the parameters for the query.
562
+ The job object.
533
563
  """
534
564
  query_template = textwrap.dedent(
535
565
  """\
@@ -547,12 +577,77 @@ def _generate_submission_query(
547
577
  if external_access_integrations:
548
578
  external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
549
579
  query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
550
- query_warehouse = query_warehouse or session.get_current_warehouse()
551
580
  if query_warehouse:
552
581
  query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
553
582
  params.append(query_warehouse)
554
583
  if target_instances > 1:
555
584
  query.append("REPLICAS = ?")
556
585
  params.append(target_instances)
586
+
557
587
  query_text = "\n".join(line for line in query if line)
558
- return query_text, params
588
+ _ = query_helper.run_query(session, query_text, params=params)
589
+
590
+ return get_job(job_id, session=session)
591
+
592
+
593
+ def _do_submit_job_v2(
594
+ session: snowpark.Session,
595
+ payload: types.UploadedPayload,
596
+ args: Optional[list[str]],
597
+ env_vars: dict[str, str],
598
+ spec_overrides: dict[str, Any],
599
+ compute_pool: str,
600
+ job_id: Optional[str] = None,
601
+ external_access_integrations: Optional[list[str]] = None,
602
+ query_warehouse: Optional[str] = None,
603
+ target_instances: int = 1,
604
+ min_instances: int = 1,
605
+ enable_metrics: bool = True,
606
+ use_async: bool = True,
607
+ ) -> jb.MLJob[Any]:
608
+ """
609
+ Generate the SQL query for job submission.
610
+
611
+ Args:
612
+ session: The Snowpark session to use.
613
+ payload: The uploaded job payload.
614
+ args: Arguments to pass to the entrypoint script.
615
+ env_vars: Environment variables to set in the job container.
616
+ spec_overrides: Custom service specification overrides.
617
+ compute_pool: The compute pool to use for job execution.
618
+ job_id: The ID of the job.
619
+ external_access_integrations: Optional list of external access integrations.
620
+ query_warehouse: Optional query warehouse to use.
621
+ target_instances: Number of instances for multi-node job.
622
+ min_instances: Minimum number of instances required to start the job.
623
+ enable_metrics: Whether to enable platform metrics for the job.
624
+ use_async: Whether to run the job asynchronously.
625
+
626
+ Returns:
627
+ The job object.
628
+ """
629
+ args = [
630
+ (payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
631
+ ] + (args or [])
632
+ spec_options = {
633
+ "STAGE_PATH": payload.stage_path.as_posix(),
634
+ "ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
635
+ "ARGS": args,
636
+ "ENV_VARS": env_vars,
637
+ "ENABLE_METRICS": enable_metrics,
638
+ "SPEC_OVERRIDES": spec_overrides,
639
+ }
640
+ job_options = {
641
+ "EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
642
+ "QUERY_WAREHOUSE": query_warehouse,
643
+ "TARGET_INSTANCES": target_instances,
644
+ "MIN_INSTANCES": min_instances,
645
+ "ASYNC": use_async,
646
+ }
647
+ job_options = {k: v for k, v in job_options.items() if v is not None}
648
+
649
+ query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
650
+ params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
651
+ actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
652
+
653
+ return get_job(actual_job_id, session=session)
@@ -1,5 +1,10 @@
1
+ from snowflake.ml.model._client.model.batch_inference_specs import (
2
+ InputSpec,
3
+ JobSpec,
4
+ OutputSpec,
5
+ )
1
6
  from snowflake.ml.model._client.model.model_impl import Model
2
7
  from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
3
8
  from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
4
9
 
5
- __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"]
10
+ __all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
@@ -0,0 +1,27 @@
1
+ from typing import Optional, Union
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
+ class InputSpec(BaseModel):
7
+ input_stage_location: str
8
+ input_file_pattern: str = "*"
9
+
10
+
11
+ class OutputSpec(BaseModel):
12
+ output_stage_location: str
13
+ output_file_prefix: Optional[str] = None
14
+ completion_filename: str = "_SUCCESS"
15
+
16
+
17
+ class JobSpec(BaseModel):
18
+ image_repo: Optional[str] = None
19
+ job_name: Optional[str] = None
20
+ num_workers: Optional[int] = None
21
+ function_name: Optional[str] = None
22
+ gpu: Optional[Union[str, int]] = None
23
+ force_rebuild: bool = False
24
+ max_batch_rows: int = 1024
25
+ warehouse: Optional[str] = None
26
+ cpu_requests: Optional[str] = None
27
+ memory_requests: Optional[str] = None
@@ -1,16 +1,18 @@
1
1
  import enum
2
2
  import pathlib
3
3
  import tempfile
4
+ import uuid
4
5
  import warnings
5
6
  from typing import Any, Callable, Optional, Union, overload
6
7
 
7
8
  import pandas as pd
8
9
 
9
- from snowflake import snowpark
10
+ from snowflake.ml import jobs
10
11
  from snowflake.ml._internal import telemetry
11
12
  from snowflake.ml._internal.utils import sql_identifier
12
13
  from snowflake.ml.lineage import lineage_node
13
14
  from snowflake.ml.model import task, type_hints
15
+ from snowflake.ml.model._client.model import batch_inference_specs
14
16
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
15
17
  from snowflake.ml.model._model_composer import model_composer
16
18
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -19,6 +21,7 @@ from snowflake.snowpark import Session, async_job, dataframe
19
21
 
20
22
  _TELEMETRY_PROJECT = "MLOps"
21
23
  _TELEMETRY_SUBPROJECT = "ModelManagement"
24
+ _BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
22
25
 
23
26
 
24
27
  class ExportMode(enum.Enum):
@@ -539,6 +542,63 @@ class ModelVersion(lineage_node.LineageNode):
539
542
  is_partitioned=target_function_info["is_partitioned"],
540
543
  )
541
544
 
545
+ @telemetry.send_api_usage_telemetry(
546
+ project=_TELEMETRY_PROJECT,
547
+ subproject=_TELEMETRY_SUBPROJECT,
548
+ func_params_to_log=[
549
+ "compute_pool",
550
+ ],
551
+ )
552
+ def _run_batch(
553
+ self,
554
+ *,
555
+ compute_pool: str,
556
+ input_spec: batch_inference_specs.InputSpec,
557
+ output_spec: batch_inference_specs.OutputSpec,
558
+ job_spec: Optional[batch_inference_specs.JobSpec] = None,
559
+ ) -> jobs.MLJob[Any]:
560
+ statement_params = telemetry.get_statement_params(
561
+ project=_TELEMETRY_PROJECT,
562
+ subproject=_TELEMETRY_SUBPROJECT,
563
+ )
564
+
565
+ if job_spec is None:
566
+ job_spec = batch_inference_specs.JobSpec()
567
+
568
+ warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
569
+ if warehouse is None:
570
+ raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
571
+
572
+ if job_spec.job_name is None:
573
+ # Same as the MLJob ID generation logic with a different prefix
574
+ job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
575
+ else:
576
+ job_name = job_spec.job_name
577
+
578
+ return self._service_ops.invoke_batch_job_method(
579
+ # model version info
580
+ model_name=self._model_name,
581
+ version_name=self._version_name,
582
+ # job spec
583
+ function_name=self._get_function_info(function_name=job_spec.function_name)["target_method"],
584
+ compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
585
+ force_rebuild=job_spec.force_rebuild,
586
+ image_repo_name=job_spec.image_repo,
587
+ num_workers=job_spec.num_workers,
588
+ max_batch_rows=job_spec.max_batch_rows,
589
+ warehouse=sql_identifier.SqlIdentifier(warehouse),
590
+ cpu_requests=job_spec.cpu_requests,
591
+ memory_requests=job_spec.memory_requests,
592
+ job_name=job_name,
593
+ # input and output
594
+ input_stage_location=input_spec.input_stage_location,
595
+ input_file_pattern=input_spec.input_file_pattern,
596
+ output_stage_location=output_spec.output_stage_location,
597
+ completion_filename=output_spec.completion_filename,
598
+ # misc
599
+ statement_params=statement_params,
600
+ )
601
+
542
602
  def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
543
603
  functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
544
604
 
@@ -1184,69 +1244,5 @@ class ModelVersion(lineage_node.LineageNode):
1184
1244
  statement_params=statement_params,
1185
1245
  )
1186
1246
 
1187
- @snowpark._internal.utils.private_preview(version="1.8.3")
1188
- @telemetry.send_api_usage_telemetry(
1189
- project=_TELEMETRY_PROJECT,
1190
- subproject=_TELEMETRY_SUBPROJECT,
1191
- )
1192
- def _run_job(
1193
- self,
1194
- X: Union[pd.DataFrame, "dataframe.DataFrame"],
1195
- *,
1196
- job_name: str,
1197
- compute_pool: str,
1198
- image_repo: Optional[str] = None,
1199
- output_table_name: str,
1200
- function_name: Optional[str] = None,
1201
- cpu_requests: Optional[str] = None,
1202
- memory_requests: Optional[str] = None,
1203
- gpu_requests: Optional[Union[str, int]] = None,
1204
- num_workers: Optional[int] = None,
1205
- max_batch_rows: Optional[int] = None,
1206
- force_rebuild: bool = False,
1207
- build_external_access_integrations: Optional[list[str]] = None,
1208
- ) -> Union[pd.DataFrame, dataframe.DataFrame]:
1209
- statement_params = telemetry.get_statement_params(
1210
- project=_TELEMETRY_PROJECT,
1211
- subproject=_TELEMETRY_SUBPROJECT,
1212
- )
1213
- target_function_info = self._get_function_info(function_name=function_name)
1214
- job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
1215
- output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
1216
- output_table_name
1217
- )
1218
- warehouse = self._service_ops._session.get_current_warehouse()
1219
- assert warehouse, "No active warehouse selected in the current session."
1220
- return self._service_ops.invoke_job_method(
1221
- target_method=target_function_info["target_method"],
1222
- signature=target_function_info["signature"],
1223
- X=X,
1224
- database_name=None,
1225
- schema_name=None,
1226
- model_name=self._model_name,
1227
- version_name=self._version_name,
1228
- job_database_name=job_db_id,
1229
- job_schema_name=job_schema_id,
1230
- job_name=job_id,
1231
- compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
1232
- warehouse_name=sql_identifier.SqlIdentifier(warehouse),
1233
- image_repo_name=image_repo,
1234
- output_table_database_name=output_table_db_id,
1235
- output_table_schema_name=output_table_schema_id,
1236
- output_table_name=output_table_id,
1237
- cpu_requests=cpu_requests,
1238
- memory_requests=memory_requests,
1239
- gpu_requests=gpu_requests,
1240
- num_workers=num_workers,
1241
- max_batch_rows=max_batch_rows,
1242
- force_rebuild=force_rebuild,
1243
- build_external_access_integrations=(
1244
- None
1245
- if build_external_access_integrations is None
1246
- else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
1247
- ),
1248
- statement_params=statement_params,
1249
- )
1250
-
1251
1247
 
1252
1248
  lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
@@ -10,17 +10,13 @@ import time
10
10
  from typing import Any, Optional, Union, cast
11
11
 
12
12
  from snowflake import snowpark
13
+ from snowflake.ml import jobs
13
14
  from snowflake.ml._internal import file_utils, platform_capabilities as pc
14
15
  from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
15
- from snowflake.ml.model import (
16
- inference_engine as inference_engine_module,
17
- model_signature,
18
- type_hints,
19
- )
16
+ from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
20
17
  from snowflake.ml.model._client.service import model_deployment_spec
21
18
  from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
22
- from snowflake.ml.model._signatures import snowpark_handler
23
- from snowflake.snowpark import async_job, dataframe, exceptions, row, session
19
+ from snowflake.snowpark import async_job, exceptions, row, session
24
20
  from snowflake.snowpark._internal import utils as snowpark_utils
25
21
 
26
22
  module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
@@ -866,174 +862,97 @@ class ServiceOperator:
866
862
  except exceptions.SnowparkSQLException:
867
863
  return False
868
864
 
869
- def invoke_job_method(
865
+ def invoke_batch_job_method(
870
866
  self,
871
- target_method: str,
872
- signature: model_signature.ModelSignature,
873
- X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
874
- database_name: Optional[sql_identifier.SqlIdentifier],
875
- schema_name: Optional[sql_identifier.SqlIdentifier],
867
+ *,
868
+ function_name: str,
876
869
  model_name: sql_identifier.SqlIdentifier,
877
870
  version_name: sql_identifier.SqlIdentifier,
878
- job_database_name: Optional[sql_identifier.SqlIdentifier],
879
- job_schema_name: Optional[sql_identifier.SqlIdentifier],
880
- job_name: sql_identifier.SqlIdentifier,
871
+ job_name: str,
881
872
  compute_pool_name: sql_identifier.SqlIdentifier,
882
- warehouse_name: sql_identifier.SqlIdentifier,
873
+ warehouse: sql_identifier.SqlIdentifier,
883
874
  image_repo_name: Optional[str],
884
- output_table_database_name: Optional[sql_identifier.SqlIdentifier],
885
- output_table_schema_name: Optional[sql_identifier.SqlIdentifier],
886
- output_table_name: sql_identifier.SqlIdentifier,
887
- cpu_requests: Optional[str],
888
- memory_requests: Optional[str],
889
- gpu_requests: Optional[Union[int, str]],
875
+ input_stage_location: str,
876
+ input_file_pattern: str,
877
+ output_stage_location: str,
878
+ completion_filename: str,
879
+ force_rebuild: bool,
890
880
  num_workers: Optional[int],
891
881
  max_batch_rows: Optional[int],
892
- force_rebuild: bool,
893
- build_external_access_integrations: Optional[list[sql_identifier.SqlIdentifier]],
882
+ cpu_requests: Optional[str],
883
+ memory_requests: Optional[str],
894
884
  statement_params: Optional[dict[str, Any]] = None,
895
- ) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
896
- # fall back to the registry's database and schema if not provided
897
- database_name = database_name or self._database_name
898
- schema_name = schema_name or self._schema_name
899
-
900
- # fall back to the model's database and schema if not provided then to the registry's database and schema
901
- job_database_name = job_database_name or database_name or self._database_name
902
- job_schema_name = job_schema_name or schema_name or self._schema_name
885
+ ) -> jobs.MLJob[Any]:
886
+ database_name = self._database_name
887
+ schema_name = self._schema_name
903
888
 
904
- image_repo_fqn = self._get_image_repo_fqn(image_repo_name, database_name, schema_name)
889
+ job_database_name, job_schema_name, job_name = sql_identifier.parse_fully_qualified_name(job_name)
890
+ job_database_name = job_database_name or database_name
891
+ job_schema_name = job_schema_name or schema_name
905
892
 
906
- input_table_database_name = job_database_name
907
- input_table_schema_name = job_schema_name
908
- output_table_database_name = output_table_database_name or database_name or self._database_name
909
- output_table_schema_name = output_table_schema_name or schema_name or self._schema_name
910
-
911
- if self._workspace:
912
- stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
913
- else:
914
- stage_path = None
893
+ self._model_deployment_spec.clear()
915
894
 
916
- # validate and prepare input
917
- if not isinstance(X, dataframe.DataFrame):
918
- keep_order = True
919
- output_with_input_features = False
920
- df = model_signature._convert_and_validate_local_data(X, signature.inputs)
921
- s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
922
- self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
923
- )
924
- else:
925
- keep_order = False
926
- output_with_input_features = True
927
- s_df = X
928
-
929
- # only write the index and feature input columns
930
- cols = [snowpark_handler._KEEP_ORDER_COL_NAME] if snowpark_handler._KEEP_ORDER_COL_NAME in s_df.columns else []
931
- cols += [
932
- sql_identifier.SqlIdentifier(feature.name, case_sensitive=True).identifier() for feature in signature.inputs
933
- ]
934
- s_df = s_df.select(cols)
935
- original_cols = s_df.columns
936
-
937
- # input/output tables
938
- fq_output_table_name = identifier.get_schema_level_object_identifier(
939
- output_table_database_name.identifier(),
940
- output_table_schema_name.identifier(),
941
- output_table_name.identifier(),
942
- )
943
- tmp_input_table_id = sql_identifier.SqlIdentifier(
944
- snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
945
- )
946
- fq_tmp_input_table_name = identifier.get_schema_level_object_identifier(
947
- job_database_name.identifier(),
948
- job_schema_name.identifier(),
949
- tmp_input_table_id.identifier(),
950
- )
951
- s_df.write.save_as_table(
952
- table_name=fq_tmp_input_table_name,
953
- mode="errorifexists",
954
- statement_params=statement_params,
895
+ self._model_deployment_spec.add_model_spec(
896
+ database_name=database_name,
897
+ schema_name=schema_name,
898
+ model_name=model_name,
899
+ version_name=version_name,
955
900
  )
956
901
 
957
- try:
958
- self._model_deployment_spec.clear()
959
- # save the spec
960
- self._model_deployment_spec.add_model_spec(
961
- database_name=database_name,
962
- schema_name=schema_name,
963
- model_name=model_name,
964
- version_name=version_name,
965
- )
966
- self._model_deployment_spec.add_job_spec(
967
- job_database_name=job_database_name,
968
- job_schema_name=job_schema_name,
969
- job_name=job_name,
970
- inference_compute_pool_name=compute_pool_name,
971
- cpu=cpu_requests,
972
- memory=memory_requests,
973
- gpu=gpu_requests,
974
- num_workers=num_workers,
975
- max_batch_rows=max_batch_rows,
976
- warehouse=warehouse_name,
977
- target_method=target_method,
978
- input_table_database_name=input_table_database_name,
979
- input_table_schema_name=input_table_schema_name,
980
- input_table_name=tmp_input_table_id,
981
- output_table_database_name=output_table_database_name,
982
- output_table_schema_name=output_table_schema_name,
983
- output_table_name=output_table_name,
984
- )
902
+ self._model_deployment_spec.add_job_spec(
903
+ job_database_name=job_database_name,
904
+ job_schema_name=job_schema_name,
905
+ job_name=job_name,
906
+ inference_compute_pool_name=compute_pool_name,
907
+ num_workers=num_workers,
908
+ max_batch_rows=max_batch_rows,
909
+ input_stage_location=input_stage_location,
910
+ input_file_pattern=input_file_pattern,
911
+ output_stage_location=output_stage_location,
912
+ completion_filename=completion_filename,
913
+ function_name=function_name,
914
+ warehouse=warehouse,
915
+ cpu=cpu_requests,
916
+ memory=memory_requests,
917
+ )
985
918
 
986
- self._model_deployment_spec.add_image_build_spec(
987
- image_build_compute_pool_name=compute_pool_name,
988
- fully_qualified_image_repo_name=image_repo_fqn,
989
- force_rebuild=force_rebuild,
990
- external_access_integrations=build_external_access_integrations,
991
- )
919
+ self._model_deployment_spec.add_image_build_spec(
920
+ image_build_compute_pool_name=compute_pool_name,
921
+ fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
922
+ force_rebuild=force_rebuild,
923
+ )
992
924
 
993
- spec_yaml_str_or_path = self._model_deployment_spec.save()
994
- if self._workspace:
995
- assert stage_path is not None
996
- file_utils.upload_directory_to_stage(
997
- self._session,
998
- local_path=pathlib.Path(self._workspace.name),
999
- stage_path=pathlib.PurePosixPath(stage_path),
1000
- statement_params=statement_params,
1001
- )
925
+ spec_yaml_str_or_path = self._model_deployment_spec.save()
1002
926
 
1003
- # deploy the job
1004
- query_id, async_job = self._service_client.deploy_model(
1005
- stage_path=stage_path if self._workspace else None,
1006
- model_deployment_spec_file_rel_path=(
1007
- model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
1008
- ),
1009
- model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
927
+ if self._workspace:
928
+ module_logger.info("using workspace")
929
+ stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
930
+ file_utils.upload_directory_to_stage(
931
+ self._session,
932
+ local_path=pathlib.Path(self._workspace.name),
933
+ stage_path=pathlib.PurePosixPath(stage_path),
1010
934
  statement_params=statement_params,
1011
935
  )
936
+ else:
937
+ module_logger.info("not using workspace")
938
+ stage_path = None
1012
939
 
1013
- while not async_job.is_done():
1014
- time.sleep(5)
1015
- finally:
1016
- self._session.table(fq_tmp_input_table_name).drop_table()
1017
-
1018
- # handle the output
1019
- df_res = self._session.table(fq_output_table_name)
1020
- if keep_order:
1021
- df_res = df_res.sort(
1022
- snowpark_handler._KEEP_ORDER_COL_NAME,
1023
- ascending=True,
1024
- )
1025
- df_res = df_res.drop(snowpark_handler._KEEP_ORDER_COL_NAME)
940
+ _, async_job = self._service_client.deploy_model(
941
+ stage_path=stage_path if self._workspace else None,
942
+ model_deployment_spec_file_rel_path=(
943
+ model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
944
+ ),
945
+ model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
946
+ statement_params=statement_params,
947
+ )
1026
948
 
1027
- if not output_with_input_features:
1028
- df_res = df_res.drop(*original_cols)
949
+ # Block until the async job is done
950
+ async_job.result()
1029
951
 
1030
- # get final result
1031
- if not isinstance(X, dataframe.DataFrame):
1032
- return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(
1033
- df_res, features=signature.outputs, statement_params=statement_params
1034
- )
1035
- else:
1036
- return df_res
952
+ return jobs.MLJob(
953
+ id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
954
+ session=self._session,
955
+ )
1037
956
 
1038
957
  def _create_temp_stage(
1039
958
  self,