snowflake-ml-python 1.11.0__py3-none-any.whl → 1.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/telemetry.py +3 -1
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/experiment_tracking.py +113 -6
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/constants.py +8 -16
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +19 -5
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +23 -5
- snowflake/ml/jobs/_utils/spec_utils.py +4 -6
- snowflake/ml/jobs/_utils/types.py +2 -1
- snowflake/ml/jobs/job.py +38 -19
- snowflake/ml/jobs/manager.py +136 -19
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +25 -0
- snowflake/ml/model/_client/model/model_version_impl.py +62 -65
- snowflake/ml/model/_client/ops/model_ops.py +42 -9
- snowflake/ml/model/_client/ops/service_ops.py +75 -154
- snowflake/ml/model/_client/service/model_deployment_spec.py +23 -37
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +15 -4
- snowflake/ml/model/_client/sql/service.py +4 -0
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +309 -22
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_handlers/xgboost.py +5 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/models/huggingface_pipeline.py +23 -0
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/METADATA +82 -5
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/RECORD +198 -194
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.13.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import logging
|
|
2
3
|
import pathlib
|
|
3
4
|
import textwrap
|
|
5
|
+
from pathlib import PurePath
|
|
4
6
|
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
|
|
5
7
|
from uuid import uuid4
|
|
6
8
|
|
|
@@ -11,7 +13,13 @@ from snowflake import snowpark
|
|
|
11
13
|
from snowflake.ml._internal import telemetry
|
|
12
14
|
from snowflake.ml._internal.utils import identifier
|
|
13
15
|
from snowflake.ml.jobs import job as jb
|
|
14
|
-
from snowflake.ml.jobs._utils import
|
|
16
|
+
from snowflake.ml.jobs._utils import (
|
|
17
|
+
feature_flags,
|
|
18
|
+
payload_utils,
|
|
19
|
+
query_helper,
|
|
20
|
+
spec_utils,
|
|
21
|
+
types,
|
|
22
|
+
)
|
|
15
23
|
from snowflake.snowpark.context import get_active_session
|
|
16
24
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
|
17
25
|
from snowflake.snowpark.functions import coalesce, col, lit, when
|
|
@@ -50,7 +58,8 @@ def list_jobs(
|
|
|
50
58
|
>>> from snowflake.ml.jobs import list_jobs
|
|
51
59
|
>>> list_jobs(limit=5)
|
|
52
60
|
"""
|
|
53
|
-
|
|
61
|
+
|
|
62
|
+
session = _ensure_session(session)
|
|
54
63
|
try:
|
|
55
64
|
df = _get_job_history_spcs(
|
|
56
65
|
session,
|
|
@@ -154,7 +163,7 @@ def _get_job_history_spcs(
|
|
|
154
163
|
@telemetry.send_api_usage_telemetry(project=_PROJECT)
|
|
155
164
|
def get_job(job_id: str, session: Optional[snowpark.Session] = None) -> jb.MLJob[Any]:
|
|
156
165
|
"""Retrieve a job service from the backend."""
|
|
157
|
-
session = session
|
|
166
|
+
session = _ensure_session(session)
|
|
158
167
|
try:
|
|
159
168
|
database, schema, job_name = identifier.parse_schema_level_object_identifier(job_id)
|
|
160
169
|
database = identifier.resolve_identifier(cast(str, database or session.get_current_database()))
|
|
@@ -426,8 +435,10 @@ def _submit_job(
|
|
|
426
435
|
|
|
427
436
|
Raises:
|
|
428
437
|
ValueError: If database or schema value(s) are invalid
|
|
438
|
+
RuntimeError: If schema is not specified in session context or job submission
|
|
439
|
+
snowpark.exceptions.SnowparkSQLException: if failed to upload payload
|
|
429
440
|
"""
|
|
430
|
-
session = session
|
|
441
|
+
session = _ensure_session(session)
|
|
431
442
|
|
|
432
443
|
# Check for deprecated args
|
|
433
444
|
if "num_instances" in kwargs:
|
|
@@ -445,7 +456,7 @@ def _submit_job(
|
|
|
445
456
|
env_vars = kwargs.pop("env_vars", None)
|
|
446
457
|
spec_overrides = kwargs.pop("spec_overrides", None)
|
|
447
458
|
enable_metrics = kwargs.pop("enable_metrics", True)
|
|
448
|
-
query_warehouse = kwargs.pop("query_warehouse",
|
|
459
|
+
query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
|
|
449
460
|
additional_payloads = kwargs.pop("additional_payloads", None)
|
|
450
461
|
|
|
451
462
|
if additional_payloads:
|
|
@@ -478,11 +489,39 @@ def _submit_job(
|
|
|
478
489
|
stage_name = f"@{'.'.join(filter(None, stage_path_parts[:3]))}"
|
|
479
490
|
stage_path = pathlib.PurePosixPath(f"{stage_name}{stage_path_parts[-1].rstrip('/')}/{job_name}")
|
|
480
491
|
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
492
|
+
try:
|
|
493
|
+
# Upload payload
|
|
494
|
+
uploaded_payload = payload_utils.JobPayload(
|
|
495
|
+
source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
|
|
496
|
+
).upload(session, stage_path)
|
|
497
|
+
except snowpark.exceptions.SnowparkSQLException as e:
|
|
498
|
+
if e.sql_error_code == 90106:
|
|
499
|
+
raise RuntimeError(
|
|
500
|
+
"Please specify a schema, either in the session context or as a parameter in the job submission"
|
|
501
|
+
)
|
|
502
|
+
raise
|
|
485
503
|
|
|
504
|
+
if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
|
|
505
|
+
# Add default env vars (extracted from spec_utils.generate_service_spec)
|
|
506
|
+
combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
|
|
507
|
+
|
|
508
|
+
return _do_submit_job_v2(
|
|
509
|
+
session=session,
|
|
510
|
+
payload=uploaded_payload,
|
|
511
|
+
args=args,
|
|
512
|
+
env_vars=combined_env_vars,
|
|
513
|
+
spec_overrides=spec_overrides,
|
|
514
|
+
compute_pool=compute_pool,
|
|
515
|
+
job_id=job_id,
|
|
516
|
+
external_access_integrations=external_access_integrations,
|
|
517
|
+
query_warehouse=query_warehouse,
|
|
518
|
+
target_instances=target_instances,
|
|
519
|
+
min_instances=min_instances,
|
|
520
|
+
enable_metrics=enable_metrics,
|
|
521
|
+
use_async=True,
|
|
522
|
+
)
|
|
523
|
+
|
|
524
|
+
# Fall back to v1
|
|
486
525
|
# Generate service spec
|
|
487
526
|
spec = spec_utils.generate_service_spec(
|
|
488
527
|
session,
|
|
@@ -493,6 +532,8 @@ def _submit_job(
|
|
|
493
532
|
min_instances=min_instances,
|
|
494
533
|
enable_metrics=enable_metrics,
|
|
495
534
|
)
|
|
535
|
+
|
|
536
|
+
# Generate spec overrides
|
|
496
537
|
spec_overrides = spec_utils.generate_spec_overrides(
|
|
497
538
|
environment_vars=env_vars,
|
|
498
539
|
custom_overrides=spec_overrides,
|
|
@@ -500,26 +541,25 @@ def _submit_job(
|
|
|
500
541
|
if spec_overrides:
|
|
501
542
|
spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
|
|
502
543
|
|
|
503
|
-
|
|
504
|
-
spec, external_access_integrations, query_warehouse, target_instances,
|
|
544
|
+
return _do_submit_job_v1(
|
|
545
|
+
session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
|
|
505
546
|
)
|
|
506
|
-
_ = query_helper.run_query(session, query_text, params=params)
|
|
507
|
-
return get_job(job_id, session=session)
|
|
508
547
|
|
|
509
548
|
|
|
510
|
-
def
|
|
549
|
+
def _do_submit_job_v1(
|
|
550
|
+
session: snowpark.Session,
|
|
511
551
|
spec: dict[str, Any],
|
|
512
552
|
external_access_integrations: list[str],
|
|
513
553
|
query_warehouse: Optional[str],
|
|
514
554
|
target_instances: int,
|
|
515
|
-
session: snowpark.Session,
|
|
516
555
|
compute_pool: str,
|
|
517
556
|
job_id: str,
|
|
518
|
-
) ->
|
|
557
|
+
) -> jb.MLJob[Any]:
|
|
519
558
|
"""
|
|
520
559
|
Generate the SQL query for job submission.
|
|
521
560
|
|
|
522
561
|
Args:
|
|
562
|
+
session: The Snowpark session to use.
|
|
523
563
|
spec: The service spec for the job.
|
|
524
564
|
external_access_integrations: The external access integrations for the job.
|
|
525
565
|
query_warehouse: The query warehouse for the job.
|
|
@@ -529,7 +569,7 @@ def _generate_submission_query(
|
|
|
529
569
|
job_id: The ID of the job.
|
|
530
570
|
|
|
531
571
|
Returns:
|
|
532
|
-
|
|
572
|
+
The job object.
|
|
533
573
|
"""
|
|
534
574
|
query_template = textwrap.dedent(
|
|
535
575
|
"""\
|
|
@@ -547,12 +587,89 @@ def _generate_submission_query(
|
|
|
547
587
|
if external_access_integrations:
|
|
548
588
|
external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
|
|
549
589
|
query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
|
|
550
|
-
query_warehouse = query_warehouse or session.get_current_warehouse()
|
|
551
590
|
if query_warehouse:
|
|
552
591
|
query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
|
|
553
592
|
params.append(query_warehouse)
|
|
554
593
|
if target_instances > 1:
|
|
555
594
|
query.append("REPLICAS = ?")
|
|
556
595
|
params.append(target_instances)
|
|
596
|
+
|
|
557
597
|
query_text = "\n".join(line for line in query if line)
|
|
558
|
-
|
|
598
|
+
_ = query_helper.run_query(session, query_text, params=params)
|
|
599
|
+
|
|
600
|
+
return get_job(job_id, session=session)
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def _do_submit_job_v2(
|
|
604
|
+
session: snowpark.Session,
|
|
605
|
+
payload: types.UploadedPayload,
|
|
606
|
+
args: Optional[list[str]],
|
|
607
|
+
env_vars: dict[str, str],
|
|
608
|
+
spec_overrides: dict[str, Any],
|
|
609
|
+
compute_pool: str,
|
|
610
|
+
job_id: Optional[str] = None,
|
|
611
|
+
external_access_integrations: Optional[list[str]] = None,
|
|
612
|
+
query_warehouse: Optional[str] = None,
|
|
613
|
+
target_instances: int = 1,
|
|
614
|
+
min_instances: int = 1,
|
|
615
|
+
enable_metrics: bool = True,
|
|
616
|
+
use_async: bool = True,
|
|
617
|
+
) -> jb.MLJob[Any]:
|
|
618
|
+
"""
|
|
619
|
+
Generate the SQL query for job submission.
|
|
620
|
+
|
|
621
|
+
Args:
|
|
622
|
+
session: The Snowpark session to use.
|
|
623
|
+
payload: The uploaded job payload.
|
|
624
|
+
args: Arguments to pass to the entrypoint script.
|
|
625
|
+
env_vars: Environment variables to set in the job container.
|
|
626
|
+
spec_overrides: Custom service specification overrides.
|
|
627
|
+
compute_pool: The compute pool to use for job execution.
|
|
628
|
+
job_id: The ID of the job.
|
|
629
|
+
external_access_integrations: Optional list of external access integrations.
|
|
630
|
+
query_warehouse: Optional query warehouse to use.
|
|
631
|
+
target_instances: Number of instances for multi-node job.
|
|
632
|
+
min_instances: Minimum number of instances required to start the job.
|
|
633
|
+
enable_metrics: Whether to enable platform metrics for the job.
|
|
634
|
+
use_async: Whether to run the job asynchronously.
|
|
635
|
+
|
|
636
|
+
Returns:
|
|
637
|
+
The job object.
|
|
638
|
+
"""
|
|
639
|
+
args = [
|
|
640
|
+
(payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
|
|
641
|
+
] + (args or [])
|
|
642
|
+
spec_options = {
|
|
643
|
+
"STAGE_PATH": payload.stage_path.as_posix(),
|
|
644
|
+
"ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
|
|
645
|
+
"ARGS": args,
|
|
646
|
+
"ENV_VARS": env_vars,
|
|
647
|
+
"ENABLE_METRICS": enable_metrics,
|
|
648
|
+
"SPEC_OVERRIDES": spec_overrides,
|
|
649
|
+
}
|
|
650
|
+
job_options = {
|
|
651
|
+
"EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
|
|
652
|
+
"QUERY_WAREHOUSE": query_warehouse,
|
|
653
|
+
"TARGET_INSTANCES": target_instances,
|
|
654
|
+
"MIN_INSTANCES": min_instances,
|
|
655
|
+
"ASYNC": use_async,
|
|
656
|
+
}
|
|
657
|
+
job_options = {k: v for k, v in job_options.items() if v is not None}
|
|
658
|
+
|
|
659
|
+
query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
|
|
660
|
+
params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
|
|
661
|
+
actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
|
|
662
|
+
|
|
663
|
+
return get_job(actual_job_id, session=session)
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def _ensure_session(session: Optional[snowpark.Session]) -> snowpark.Session:
|
|
667
|
+
try:
|
|
668
|
+
session = session or get_active_session()
|
|
669
|
+
except snowpark.exceptions.SnowparkSessionException as e:
|
|
670
|
+
if "More than one active session" in e.message:
|
|
671
|
+
raise RuntimeError("Please specify the session as a parameter in API call")
|
|
672
|
+
if "No default Session is found" in e.message:
|
|
673
|
+
raise RuntimeError("Please create a session before API call")
|
|
674
|
+
raise
|
|
675
|
+
return session
|
snowflake/ml/model/__init__.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
|
1
|
+
from snowflake.ml.model._client.model.batch_inference_specs import (
|
|
2
|
+
InputSpec,
|
|
3
|
+
JobSpec,
|
|
4
|
+
OutputSpec,
|
|
5
|
+
)
|
|
1
6
|
from snowflake.ml.model._client.model.model_impl import Model
|
|
2
7
|
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
|
3
8
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
|
4
9
|
|
|
5
|
-
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"]
|
|
10
|
+
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class InputSpec(BaseModel):
|
|
7
|
+
stage_location: str
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class OutputSpec(BaseModel):
|
|
11
|
+
stage_location: str
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class JobSpec(BaseModel):
|
|
15
|
+
image_repo: Optional[str] = None
|
|
16
|
+
job_name: Optional[str] = None
|
|
17
|
+
num_workers: Optional[int] = None
|
|
18
|
+
function_name: Optional[str] = None
|
|
19
|
+
gpu: Optional[Union[str, int]] = None
|
|
20
|
+
force_rebuild: bool = False
|
|
21
|
+
max_batch_rows: int = 1024
|
|
22
|
+
warehouse: Optional[str] = None
|
|
23
|
+
cpu_requests: Optional[str] = None
|
|
24
|
+
memory_requests: Optional[str] = None
|
|
25
|
+
replicas: Optional[int] = None
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
import enum
|
|
2
2
|
import pathlib
|
|
3
3
|
import tempfile
|
|
4
|
+
import uuid
|
|
4
5
|
import warnings
|
|
5
6
|
from typing import Any, Callable, Optional, Union, overload
|
|
6
7
|
|
|
7
8
|
import pandas as pd
|
|
8
9
|
|
|
9
|
-
from snowflake import
|
|
10
|
+
from snowflake.ml import jobs
|
|
10
11
|
from snowflake.ml._internal import telemetry
|
|
11
12
|
from snowflake.ml._internal.utils import sql_identifier
|
|
12
13
|
from snowflake.ml.lineage import lineage_node
|
|
13
14
|
from snowflake.ml.model import task, type_hints
|
|
15
|
+
from snowflake.ml.model._client.model import batch_inference_specs
|
|
14
16
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
15
17
|
from snowflake.ml.model._model_composer import model_composer
|
|
16
18
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
@@ -19,6 +21,7 @@ from snowflake.snowpark import Session, async_job, dataframe
|
|
|
19
21
|
|
|
20
22
|
_TELEMETRY_PROJECT = "MLOps"
|
|
21
23
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
|
24
|
+
_BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class ExportMode(enum.Enum):
|
|
@@ -539,6 +542,64 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
539
542
|
is_partitioned=target_function_info["is_partitioned"],
|
|
540
543
|
)
|
|
541
544
|
|
|
545
|
+
@telemetry.send_api_usage_telemetry(
|
|
546
|
+
project=_TELEMETRY_PROJECT,
|
|
547
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
|
548
|
+
func_params_to_log=[
|
|
549
|
+
"compute_pool",
|
|
550
|
+
],
|
|
551
|
+
)
|
|
552
|
+
def _run_batch(
|
|
553
|
+
self,
|
|
554
|
+
*,
|
|
555
|
+
compute_pool: str,
|
|
556
|
+
input_spec: batch_inference_specs.InputSpec,
|
|
557
|
+
output_spec: batch_inference_specs.OutputSpec,
|
|
558
|
+
job_spec: Optional[batch_inference_specs.JobSpec] = None,
|
|
559
|
+
) -> jobs.MLJob[Any]:
|
|
560
|
+
statement_params = telemetry.get_statement_params(
|
|
561
|
+
project=_TELEMETRY_PROJECT,
|
|
562
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
if job_spec is None:
|
|
566
|
+
job_spec = batch_inference_specs.JobSpec()
|
|
567
|
+
|
|
568
|
+
warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
|
|
569
|
+
if warehouse is None:
|
|
570
|
+
raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
|
|
571
|
+
|
|
572
|
+
if job_spec.job_name is None:
|
|
573
|
+
# Same as the MLJob ID generation logic with a different prefix
|
|
574
|
+
job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
|
|
575
|
+
else:
|
|
576
|
+
job_name = job_spec.job_name
|
|
577
|
+
|
|
578
|
+
return self._service_ops.invoke_batch_job_method(
|
|
579
|
+
# model version info
|
|
580
|
+
model_name=self._model_name,
|
|
581
|
+
version_name=self._version_name,
|
|
582
|
+
# job spec
|
|
583
|
+
function_name=self._get_function_info(function_name=job_spec.function_name)["target_method"],
|
|
584
|
+
compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
|
|
585
|
+
force_rebuild=job_spec.force_rebuild,
|
|
586
|
+
image_repo_name=job_spec.image_repo,
|
|
587
|
+
num_workers=job_spec.num_workers,
|
|
588
|
+
max_batch_rows=job_spec.max_batch_rows,
|
|
589
|
+
warehouse=sql_identifier.SqlIdentifier(warehouse),
|
|
590
|
+
cpu_requests=job_spec.cpu_requests,
|
|
591
|
+
memory_requests=job_spec.memory_requests,
|
|
592
|
+
job_name=job_name,
|
|
593
|
+
replicas=job_spec.replicas,
|
|
594
|
+
# input and output
|
|
595
|
+
input_stage_location=input_spec.stage_location,
|
|
596
|
+
input_file_pattern="*",
|
|
597
|
+
output_stage_location=output_spec.stage_location,
|
|
598
|
+
completion_filename="_SUCCESS",
|
|
599
|
+
# misc
|
|
600
|
+
statement_params=statement_params,
|
|
601
|
+
)
|
|
602
|
+
|
|
542
603
|
def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
|
|
543
604
|
functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
|
|
544
605
|
|
|
@@ -1184,69 +1245,5 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1184
1245
|
statement_params=statement_params,
|
|
1185
1246
|
)
|
|
1186
1247
|
|
|
1187
|
-
@snowpark._internal.utils.private_preview(version="1.8.3")
|
|
1188
|
-
@telemetry.send_api_usage_telemetry(
|
|
1189
|
-
project=_TELEMETRY_PROJECT,
|
|
1190
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
|
1191
|
-
)
|
|
1192
|
-
def _run_job(
|
|
1193
|
-
self,
|
|
1194
|
-
X: Union[pd.DataFrame, "dataframe.DataFrame"],
|
|
1195
|
-
*,
|
|
1196
|
-
job_name: str,
|
|
1197
|
-
compute_pool: str,
|
|
1198
|
-
image_repo: Optional[str] = None,
|
|
1199
|
-
output_table_name: str,
|
|
1200
|
-
function_name: Optional[str] = None,
|
|
1201
|
-
cpu_requests: Optional[str] = None,
|
|
1202
|
-
memory_requests: Optional[str] = None,
|
|
1203
|
-
gpu_requests: Optional[Union[str, int]] = None,
|
|
1204
|
-
num_workers: Optional[int] = None,
|
|
1205
|
-
max_batch_rows: Optional[int] = None,
|
|
1206
|
-
force_rebuild: bool = False,
|
|
1207
|
-
build_external_access_integrations: Optional[list[str]] = None,
|
|
1208
|
-
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
|
1209
|
-
statement_params = telemetry.get_statement_params(
|
|
1210
|
-
project=_TELEMETRY_PROJECT,
|
|
1211
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
|
1212
|
-
)
|
|
1213
|
-
target_function_info = self._get_function_info(function_name=function_name)
|
|
1214
|
-
job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
|
|
1215
|
-
output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
|
|
1216
|
-
output_table_name
|
|
1217
|
-
)
|
|
1218
|
-
warehouse = self._service_ops._session.get_current_warehouse()
|
|
1219
|
-
assert warehouse, "No active warehouse selected in the current session."
|
|
1220
|
-
return self._service_ops.invoke_job_method(
|
|
1221
|
-
target_method=target_function_info["target_method"],
|
|
1222
|
-
signature=target_function_info["signature"],
|
|
1223
|
-
X=X,
|
|
1224
|
-
database_name=None,
|
|
1225
|
-
schema_name=None,
|
|
1226
|
-
model_name=self._model_name,
|
|
1227
|
-
version_name=self._version_name,
|
|
1228
|
-
job_database_name=job_db_id,
|
|
1229
|
-
job_schema_name=job_schema_id,
|
|
1230
|
-
job_name=job_id,
|
|
1231
|
-
compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
|
|
1232
|
-
warehouse_name=sql_identifier.SqlIdentifier(warehouse),
|
|
1233
|
-
image_repo_name=image_repo,
|
|
1234
|
-
output_table_database_name=output_table_db_id,
|
|
1235
|
-
output_table_schema_name=output_table_schema_id,
|
|
1236
|
-
output_table_name=output_table_id,
|
|
1237
|
-
cpu_requests=cpu_requests,
|
|
1238
|
-
memory_requests=memory_requests,
|
|
1239
|
-
gpu_requests=gpu_requests,
|
|
1240
|
-
num_workers=num_workers,
|
|
1241
|
-
max_batch_rows=max_batch_rows,
|
|
1242
|
-
force_rebuild=force_rebuild,
|
|
1243
|
-
build_external_access_integrations=(
|
|
1244
|
-
None
|
|
1245
|
-
if build_external_access_integrations is None
|
|
1246
|
-
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
|
1247
|
-
),
|
|
1248
|
-
statement_params=statement_params,
|
|
1249
|
-
)
|
|
1250
|
-
|
|
1251
1248
|
|
|
1252
1249
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
|
@@ -47,6 +47,7 @@ class ServiceInfo(TypedDict):
|
|
|
47
47
|
class ModelOperator:
|
|
48
48
|
INFERENCE_SERVICE_ENDPOINT_NAME = "inference"
|
|
49
49
|
INGRESS_ENDPOINT_URL_SUFFIX = "snowflakecomputing.app"
|
|
50
|
+
PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING = "privatelink.snowflakecomputing"
|
|
50
51
|
|
|
51
52
|
def __init__(
|
|
52
53
|
self,
|
|
@@ -612,6 +613,30 @@ class ModelOperator:
|
|
|
612
613
|
statement_params=statement_params,
|
|
613
614
|
)
|
|
614
615
|
|
|
616
|
+
def _is_privatelink_connection(self) -> bool:
    """Detect whether the current session connects through a privatelink host.

    The session connection's host name is searched for
    ``ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING``. The comparison
    is case-insensitive because DNS host names are case-insensitive.

    Returns:
        True if the connection host contains the privatelink substring.
        False otherwise, including when the session or connection lacks the
        expected attributes or the host is ``None``/not a string.
    """
    try:
        host = self._session.connection.host
    except AttributeError:
        # Session/connection objects without a ``host`` are treated as non-privatelink.
        return False
    # A missing host (e.g. None) would make the ``in`` test raise TypeError, which the
    # AttributeError handler above would not catch; treat it as non-privatelink instead.
    if not isinstance(host, str):
        return False
    return ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING.lower() in host.lower()
|
|
623
|
+
|
|
624
|
+
def _extract_and_validate_ingress_url(self, res_row: "row.Row") -> Optional[str]:
    """Return the public ingress URL held in an endpoint row, if it looks valid.

    Args:
        res_row: One row describing a service endpoint.

    Returns:
        The ingress URL column converted to ``str`` when it is populated and ends
        with ``ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX``; otherwise ``None``.
    """
    raw_value = res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME]
    if raw_value is not None:
        candidate = str(raw_value)
        if candidate.endswith(ModelOperator.INGRESS_ENDPOINT_URL_SUFFIX):
            return candidate
    return None
|
|
631
|
+
|
|
632
|
+
def _extract_and_validate_privatelink_url(self, res_row: "row.Row") -> Optional[str]:
    """Return the privatelink ingress URL held in an endpoint row, if it looks valid.

    Args:
        res_row: One row describing a service endpoint.

    Returns:
        The privatelink ingress URL column converted to ``str`` when it is populated
        and contains ``ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING``;
        otherwise ``None``.
    """
    column_name = self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_PRIVATELINK_INGRESS_URL_COL_NAME
    raw_value = res_row[column_name]
    if raw_value is not None:
        candidate = str(raw_value)
        if ModelOperator.PRIVATELINK_INGRESS_ENDPOINT_URL_SUBSTRING in candidate:
            return candidate
    return None
|
|
639
|
+
|
|
615
640
|
def show_services(
|
|
616
641
|
self,
|
|
617
642
|
*,
|
|
@@ -644,8 +669,10 @@ class ModelOperator:
|
|
|
644
669
|
fully_qualified_service_names = [str(service) for service in json_array if "MODEL_BUILD_" not in service]
|
|
645
670
|
|
|
646
671
|
result: list[ServiceInfo] = []
|
|
672
|
+
is_privatelink_connection = self._is_privatelink_connection()
|
|
673
|
+
|
|
647
674
|
for fully_qualified_service_name in fully_qualified_service_names:
|
|
648
|
-
|
|
675
|
+
inference_endpoint: Optional[str] = None
|
|
649
676
|
db, schema, service_name = sql_identifier.parse_fully_qualified_name(fully_qualified_service_name)
|
|
650
677
|
statuses = self._service_client.get_service_container_statuses(
|
|
651
678
|
database_name=db, schema_name=schema, service_name=service_name, statement_params=statement_params
|
|
@@ -659,17 +686,23 @@ class ModelOperator:
|
|
|
659
686
|
):
|
|
660
687
|
if (
|
|
661
688
|
res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME]
|
|
662
|
-
|
|
663
|
-
and res_row[self._service_client.MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME] is not None
|
|
689
|
+
!= self.INFERENCE_SERVICE_ENDPOINT_NAME
|
|
664
690
|
):
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
691
|
+
continue
|
|
692
|
+
|
|
693
|
+
ingress_url = self._extract_and_validate_ingress_url(res_row)
|
|
694
|
+
privatelink_ingress_url = self._extract_and_validate_privatelink_url(res_row)
|
|
695
|
+
|
|
696
|
+
if is_privatelink_connection and privatelink_ingress_url is not None:
|
|
697
|
+
inference_endpoint = privatelink_ingress_url
|
|
698
|
+
else:
|
|
699
|
+
inference_endpoint = ingress_url
|
|
700
|
+
|
|
670
701
|
result.append(
|
|
671
702
|
ServiceInfo(
|
|
672
|
-
name=fully_qualified_service_name,
|
|
703
|
+
name=fully_qualified_service_name,
|
|
704
|
+
status=service_status.value,
|
|
705
|
+
inference_endpoint=inference_endpoint,
|
|
673
706
|
)
|
|
674
707
|
)
|
|
675
708
|
|