snowflake-ml-python 1.11.0__py3-none-any.whl → 1.12.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +3 -2
- snowflake/ml/_internal/utils/service_logger.py +26 -1
- snowflake/ml/experiment/_client/artifact.py +76 -0
- snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +64 -1
- snowflake/ml/experiment/experiment_tracking.py +89 -4
- snowflake/ml/feature_store/feature_store.py +1150 -131
- snowflake/ml/feature_store/feature_view.py +122 -0
- snowflake/ml/jobs/_utils/constants.py +8 -16
- snowflake/ml/jobs/_utils/feature_flags.py +16 -0
- snowflake/ml/jobs/_utils/payload_utils.py +19 -5
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +18 -7
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +12 -4
- snowflake/ml/jobs/_utils/spec_utils.py +4 -6
- snowflake/ml/jobs/_utils/types.py +2 -1
- snowflake/ml/jobs/job.py +33 -17
- snowflake/ml/jobs/manager.py +107 -12
- snowflake/ml/model/__init__.py +6 -1
- snowflake/ml/model/_client/model/batch_inference_specs.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +61 -65
- snowflake/ml/model/_client/ops/service_ops.py +73 -154
- snowflake/ml/model/_client/service/model_deployment_spec.py +20 -37
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +14 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +207 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +3 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -3
- snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +4 -2
- snowflake/ml/model/openai_signatures.py +57 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +43 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +14 -3
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +1 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -1
- snowflake/ml/modeling/cluster/birch.py +1 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -1
- snowflake/ml/modeling/cluster/dbscan.py +1 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -1
- snowflake/ml/modeling/cluster/k_means.py +1 -1
- snowflake/ml/modeling/cluster/mean_shift.py +1 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -1
- snowflake/ml/modeling/cluster/optics.py +1 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +1 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -1
- snowflake/ml/modeling/compose/column_transformer.py +1 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +1 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +1 -1
- snowflake/ml/modeling/covariance/oas.py +1 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +1 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +1 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +1 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/pca.py +1 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +1 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +1 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +1 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +1 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +1 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +1 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +1 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +1 -1
- snowflake/ml/modeling/impute/knn_imputer.py +1 -1
- snowflake/ml/modeling/impute/missing_indicator.py +1 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +1 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/lars.py +1 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +1 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/perceptron.py +1 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/ridge.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -1
- snowflake/ml/modeling/manifold/isomap.py +1 -1
- snowflake/ml/modeling/manifold/mds.py +1 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +1 -1
- snowflake/ml/modeling/manifold/tsne.py +1 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +1 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -1
- snowflake/ml/modeling/svm/linear_svc.py +1 -1
- snowflake/ml/modeling/svm/linear_svr.py +1 -1
- snowflake/ml/modeling/svm/nu_svc.py +1 -1
- snowflake/ml/modeling/svm/nu_svr.py +1 -1
- snowflake/ml/modeling/svm/svc.py +1 -1
- snowflake/ml/modeling/svm/svr.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -1
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +91 -6
- snowflake/ml/monitoring/_manager/model_monitor_manager.py +3 -0
- snowflake/ml/monitoring/entities/model_monitor_config.py +3 -0
- snowflake/ml/monitoring/model_monitor.py +26 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/METADATA +66 -5
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/RECORD +192 -188
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.11.0.dist-info → snowflake_ml_python-1.12.0.dist-info}/top_level.txt +0 -0
snowflake/ml/jobs/manager.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
|
+
import json
|
|
1
2
|
import logging
|
|
2
3
|
import pathlib
|
|
3
4
|
import textwrap
|
|
5
|
+
from pathlib import PurePath
|
|
4
6
|
from typing import Any, Callable, Optional, TypeVar, Union, cast, overload
|
|
5
7
|
from uuid import uuid4
|
|
6
8
|
|
|
@@ -11,7 +13,13 @@ from snowflake import snowpark
|
|
|
11
13
|
from snowflake.ml._internal import telemetry
|
|
12
14
|
from snowflake.ml._internal.utils import identifier
|
|
13
15
|
from snowflake.ml.jobs import job as jb
|
|
14
|
-
from snowflake.ml.jobs._utils import
|
|
16
|
+
from snowflake.ml.jobs._utils import (
|
|
17
|
+
feature_flags,
|
|
18
|
+
payload_utils,
|
|
19
|
+
query_helper,
|
|
20
|
+
spec_utils,
|
|
21
|
+
types,
|
|
22
|
+
)
|
|
15
23
|
from snowflake.snowpark.context import get_active_session
|
|
16
24
|
from snowflake.snowpark.exceptions import SnowparkSQLException
|
|
17
25
|
from snowflake.snowpark.functions import coalesce, col, lit, when
|
|
@@ -445,7 +453,7 @@ def _submit_job(
|
|
|
445
453
|
env_vars = kwargs.pop("env_vars", None)
|
|
446
454
|
spec_overrides = kwargs.pop("spec_overrides", None)
|
|
447
455
|
enable_metrics = kwargs.pop("enable_metrics", True)
|
|
448
|
-
query_warehouse = kwargs.pop("query_warehouse",
|
|
456
|
+
query_warehouse = kwargs.pop("query_warehouse", session.get_current_warehouse())
|
|
449
457
|
additional_payloads = kwargs.pop("additional_payloads", None)
|
|
450
458
|
|
|
451
459
|
if additional_payloads:
|
|
@@ -483,6 +491,27 @@ def _submit_job(
|
|
|
483
491
|
source, entrypoint=entrypoint, pip_requirements=pip_requirements, additional_payloads=additional_payloads
|
|
484
492
|
).upload(session, stage_path)
|
|
485
493
|
|
|
494
|
+
if feature_flags.FeatureFlags.USE_SUBMIT_JOB_V2.is_enabled():
|
|
495
|
+
# Add default env vars (extracted from spec_utils.generate_service_spec)
|
|
496
|
+
combined_env_vars = {**uploaded_payload.env_vars, **(env_vars or {})}
|
|
497
|
+
|
|
498
|
+
return _do_submit_job_v2(
|
|
499
|
+
session=session,
|
|
500
|
+
payload=uploaded_payload,
|
|
501
|
+
args=args,
|
|
502
|
+
env_vars=combined_env_vars,
|
|
503
|
+
spec_overrides=spec_overrides,
|
|
504
|
+
compute_pool=compute_pool,
|
|
505
|
+
job_id=job_id,
|
|
506
|
+
external_access_integrations=external_access_integrations,
|
|
507
|
+
query_warehouse=query_warehouse,
|
|
508
|
+
target_instances=target_instances,
|
|
509
|
+
min_instances=min_instances,
|
|
510
|
+
enable_metrics=enable_metrics,
|
|
511
|
+
use_async=True,
|
|
512
|
+
)
|
|
513
|
+
|
|
514
|
+
# Fall back to v1
|
|
486
515
|
# Generate service spec
|
|
487
516
|
spec = spec_utils.generate_service_spec(
|
|
488
517
|
session,
|
|
@@ -493,6 +522,8 @@ def _submit_job(
|
|
|
493
522
|
min_instances=min_instances,
|
|
494
523
|
enable_metrics=enable_metrics,
|
|
495
524
|
)
|
|
525
|
+
|
|
526
|
+
# Generate spec overrides
|
|
496
527
|
spec_overrides = spec_utils.generate_spec_overrides(
|
|
497
528
|
environment_vars=env_vars,
|
|
498
529
|
custom_overrides=spec_overrides,
|
|
@@ -500,26 +531,25 @@ def _submit_job(
|
|
|
500
531
|
if spec_overrides:
|
|
501
532
|
spec = spec_utils.merge_patch(spec, spec_overrides, display_name="spec_overrides")
|
|
502
533
|
|
|
503
|
-
|
|
504
|
-
spec, external_access_integrations, query_warehouse, target_instances,
|
|
534
|
+
return _do_submit_job_v1(
|
|
535
|
+
session, spec, external_access_integrations, query_warehouse, target_instances, compute_pool, job_id
|
|
505
536
|
)
|
|
506
|
-
_ = query_helper.run_query(session, query_text, params=params)
|
|
507
|
-
return get_job(job_id, session=session)
|
|
508
537
|
|
|
509
538
|
|
|
510
|
-
def
|
|
539
|
+
def _do_submit_job_v1(
|
|
540
|
+
session: snowpark.Session,
|
|
511
541
|
spec: dict[str, Any],
|
|
512
542
|
external_access_integrations: list[str],
|
|
513
543
|
query_warehouse: Optional[str],
|
|
514
544
|
target_instances: int,
|
|
515
|
-
session: snowpark.Session,
|
|
516
545
|
compute_pool: str,
|
|
517
546
|
job_id: str,
|
|
518
|
-
) ->
|
|
547
|
+
) -> jb.MLJob[Any]:
|
|
519
548
|
"""
|
|
520
549
|
Generate the SQL query for job submission.
|
|
521
550
|
|
|
522
551
|
Args:
|
|
552
|
+
session: The Snowpark session to use.
|
|
523
553
|
spec: The service spec for the job.
|
|
524
554
|
external_access_integrations: The external access integrations for the job.
|
|
525
555
|
query_warehouse: The query warehouse for the job.
|
|
@@ -529,7 +559,7 @@ def _generate_submission_query(
|
|
|
529
559
|
job_id: The ID of the job.
|
|
530
560
|
|
|
531
561
|
Returns:
|
|
532
|
-
|
|
562
|
+
The job object.
|
|
533
563
|
"""
|
|
534
564
|
query_template = textwrap.dedent(
|
|
535
565
|
"""\
|
|
@@ -547,12 +577,77 @@ def _generate_submission_query(
|
|
|
547
577
|
if external_access_integrations:
|
|
548
578
|
external_access_integration_list = ",".join(f"{e}" for e in external_access_integrations)
|
|
549
579
|
query.append(f"EXTERNAL_ACCESS_INTEGRATIONS = ({external_access_integration_list})")
|
|
550
|
-
query_warehouse = query_warehouse or session.get_current_warehouse()
|
|
551
580
|
if query_warehouse:
|
|
552
581
|
query.append("QUERY_WAREHOUSE = IDENTIFIER(?)")
|
|
553
582
|
params.append(query_warehouse)
|
|
554
583
|
if target_instances > 1:
|
|
555
584
|
query.append("REPLICAS = ?")
|
|
556
585
|
params.append(target_instances)
|
|
586
|
+
|
|
557
587
|
query_text = "\n".join(line for line in query if line)
|
|
558
|
-
|
|
588
|
+
_ = query_helper.run_query(session, query_text, params=params)
|
|
589
|
+
|
|
590
|
+
return get_job(job_id, session=session)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
def _do_submit_job_v2(
|
|
594
|
+
session: snowpark.Session,
|
|
595
|
+
payload: types.UploadedPayload,
|
|
596
|
+
args: Optional[list[str]],
|
|
597
|
+
env_vars: dict[str, str],
|
|
598
|
+
spec_overrides: dict[str, Any],
|
|
599
|
+
compute_pool: str,
|
|
600
|
+
job_id: Optional[str] = None,
|
|
601
|
+
external_access_integrations: Optional[list[str]] = None,
|
|
602
|
+
query_warehouse: Optional[str] = None,
|
|
603
|
+
target_instances: int = 1,
|
|
604
|
+
min_instances: int = 1,
|
|
605
|
+
enable_metrics: bool = True,
|
|
606
|
+
use_async: bool = True,
|
|
607
|
+
) -> jb.MLJob[Any]:
|
|
608
|
+
"""
|
|
609
|
+
Generate the SQL query for job submission.
|
|
610
|
+
|
|
611
|
+
Args:
|
|
612
|
+
session: The Snowpark session to use.
|
|
613
|
+
payload: The uploaded job payload.
|
|
614
|
+
args: Arguments to pass to the entrypoint script.
|
|
615
|
+
env_vars: Environment variables to set in the job container.
|
|
616
|
+
spec_overrides: Custom service specification overrides.
|
|
617
|
+
compute_pool: The compute pool to use for job execution.
|
|
618
|
+
job_id: The ID of the job.
|
|
619
|
+
external_access_integrations: Optional list of external access integrations.
|
|
620
|
+
query_warehouse: Optional query warehouse to use.
|
|
621
|
+
target_instances: Number of instances for multi-node job.
|
|
622
|
+
min_instances: Minimum number of instances required to start the job.
|
|
623
|
+
enable_metrics: Whether to enable platform metrics for the job.
|
|
624
|
+
use_async: Whether to run the job asynchronously.
|
|
625
|
+
|
|
626
|
+
Returns:
|
|
627
|
+
The job object.
|
|
628
|
+
"""
|
|
629
|
+
args = [
|
|
630
|
+
(payload.stage_path.joinpath(v).as_posix() if isinstance(v, PurePath) else v) for v in payload.entrypoint
|
|
631
|
+
] + (args or [])
|
|
632
|
+
spec_options = {
|
|
633
|
+
"STAGE_PATH": payload.stage_path.as_posix(),
|
|
634
|
+
"ENTRYPOINT": ["/usr/local/bin/_entrypoint.sh"],
|
|
635
|
+
"ARGS": args,
|
|
636
|
+
"ENV_VARS": env_vars,
|
|
637
|
+
"ENABLE_METRICS": enable_metrics,
|
|
638
|
+
"SPEC_OVERRIDES": spec_overrides,
|
|
639
|
+
}
|
|
640
|
+
job_options = {
|
|
641
|
+
"EXTERNAL_ACCESS_INTEGRATIONS": external_access_integrations,
|
|
642
|
+
"QUERY_WAREHOUSE": query_warehouse,
|
|
643
|
+
"TARGET_INSTANCES": target_instances,
|
|
644
|
+
"MIN_INSTANCES": min_instances,
|
|
645
|
+
"ASYNC": use_async,
|
|
646
|
+
}
|
|
647
|
+
job_options = {k: v for k, v in job_options.items() if v is not None}
|
|
648
|
+
|
|
649
|
+
query_template = "CALL SYSTEM$EXECUTE_ML_JOB(?, ?, ?, ?)"
|
|
650
|
+
params = [job_id, compute_pool, json.dumps(spec_options), json.dumps(job_options)]
|
|
651
|
+
actual_job_id = query_helper.run_query(session, query_template, params=params)[0][0]
|
|
652
|
+
|
|
653
|
+
return get_job(actual_job_id, session=session)
|
snowflake/ml/model/__init__.py
CHANGED
|
@@ -1,5 +1,10 @@
|
|
|
1
|
+
from snowflake.ml.model._client.model.batch_inference_specs import (
|
|
2
|
+
InputSpec,
|
|
3
|
+
JobSpec,
|
|
4
|
+
OutputSpec,
|
|
5
|
+
)
|
|
1
6
|
from snowflake.ml.model._client.model.model_impl import Model
|
|
2
7
|
from snowflake.ml.model._client.model.model_version_impl import ExportMode, ModelVersion
|
|
3
8
|
from snowflake.ml.model.models.huggingface_pipeline import HuggingFacePipelineModel
|
|
4
9
|
|
|
5
|
-
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel"]
|
|
10
|
+
__all__ = ["Model", "ModelVersion", "ExportMode", "HuggingFacePipelineModel", "InputSpec", "JobSpec", "OutputSpec"]
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from typing import Optional, Union
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class InputSpec(BaseModel):
|
|
7
|
+
input_stage_location: str
|
|
8
|
+
input_file_pattern: str = "*"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class OutputSpec(BaseModel):
|
|
12
|
+
output_stage_location: str
|
|
13
|
+
output_file_prefix: Optional[str] = None
|
|
14
|
+
completion_filename: str = "_SUCCESS"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class JobSpec(BaseModel):
|
|
18
|
+
image_repo: Optional[str] = None
|
|
19
|
+
job_name: Optional[str] = None
|
|
20
|
+
num_workers: Optional[int] = None
|
|
21
|
+
function_name: Optional[str] = None
|
|
22
|
+
gpu: Optional[Union[str, int]] = None
|
|
23
|
+
force_rebuild: bool = False
|
|
24
|
+
max_batch_rows: int = 1024
|
|
25
|
+
warehouse: Optional[str] = None
|
|
26
|
+
cpu_requests: Optional[str] = None
|
|
27
|
+
memory_requests: Optional[str] = None
|
|
@@ -1,16 +1,18 @@
|
|
|
1
1
|
import enum
|
|
2
2
|
import pathlib
|
|
3
3
|
import tempfile
|
|
4
|
+
import uuid
|
|
4
5
|
import warnings
|
|
5
6
|
from typing import Any, Callable, Optional, Union, overload
|
|
6
7
|
|
|
7
8
|
import pandas as pd
|
|
8
9
|
|
|
9
|
-
from snowflake import
|
|
10
|
+
from snowflake.ml import jobs
|
|
10
11
|
from snowflake.ml._internal import telemetry
|
|
11
12
|
from snowflake.ml._internal.utils import sql_identifier
|
|
12
13
|
from snowflake.ml.lineage import lineage_node
|
|
13
14
|
from snowflake.ml.model import task, type_hints
|
|
15
|
+
from snowflake.ml.model._client.model import batch_inference_specs
|
|
14
16
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
15
17
|
from snowflake.ml.model._model_composer import model_composer
|
|
16
18
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
@@ -19,6 +21,7 @@ from snowflake.snowpark import Session, async_job, dataframe
|
|
|
19
21
|
|
|
20
22
|
_TELEMETRY_PROJECT = "MLOps"
|
|
21
23
|
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
|
24
|
+
_BATCH_INFERENCE_JOB_ID_PREFIX = "BATCH_INFERENCE_"
|
|
22
25
|
|
|
23
26
|
|
|
24
27
|
class ExportMode(enum.Enum):
|
|
@@ -539,6 +542,63 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
539
542
|
is_partitioned=target_function_info["is_partitioned"],
|
|
540
543
|
)
|
|
541
544
|
|
|
545
|
+
@telemetry.send_api_usage_telemetry(
|
|
546
|
+
project=_TELEMETRY_PROJECT,
|
|
547
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
|
548
|
+
func_params_to_log=[
|
|
549
|
+
"compute_pool",
|
|
550
|
+
],
|
|
551
|
+
)
|
|
552
|
+
def _run_batch(
|
|
553
|
+
self,
|
|
554
|
+
*,
|
|
555
|
+
compute_pool: str,
|
|
556
|
+
input_spec: batch_inference_specs.InputSpec,
|
|
557
|
+
output_spec: batch_inference_specs.OutputSpec,
|
|
558
|
+
job_spec: Optional[batch_inference_specs.JobSpec] = None,
|
|
559
|
+
) -> jobs.MLJob[Any]:
|
|
560
|
+
statement_params = telemetry.get_statement_params(
|
|
561
|
+
project=_TELEMETRY_PROJECT,
|
|
562
|
+
subproject=_TELEMETRY_SUBPROJECT,
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
if job_spec is None:
|
|
566
|
+
job_spec = batch_inference_specs.JobSpec()
|
|
567
|
+
|
|
568
|
+
warehouse = job_spec.warehouse or self._service_ops._session.get_current_warehouse()
|
|
569
|
+
if warehouse is None:
|
|
570
|
+
raise ValueError("Warehouse is not set. Please set the warehouse field in the JobSpec.")
|
|
571
|
+
|
|
572
|
+
if job_spec.job_name is None:
|
|
573
|
+
# Same as the MLJob ID generation logic with a different prefix
|
|
574
|
+
job_name = f"{_BATCH_INFERENCE_JOB_ID_PREFIX}{str(uuid.uuid4()).replace('-', '_').upper()}"
|
|
575
|
+
else:
|
|
576
|
+
job_name = job_spec.job_name
|
|
577
|
+
|
|
578
|
+
return self._service_ops.invoke_batch_job_method(
|
|
579
|
+
# model version info
|
|
580
|
+
model_name=self._model_name,
|
|
581
|
+
version_name=self._version_name,
|
|
582
|
+
# job spec
|
|
583
|
+
function_name=self._get_function_info(function_name=job_spec.function_name)["target_method"],
|
|
584
|
+
compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
|
|
585
|
+
force_rebuild=job_spec.force_rebuild,
|
|
586
|
+
image_repo_name=job_spec.image_repo,
|
|
587
|
+
num_workers=job_spec.num_workers,
|
|
588
|
+
max_batch_rows=job_spec.max_batch_rows,
|
|
589
|
+
warehouse=sql_identifier.SqlIdentifier(warehouse),
|
|
590
|
+
cpu_requests=job_spec.cpu_requests,
|
|
591
|
+
memory_requests=job_spec.memory_requests,
|
|
592
|
+
job_name=job_name,
|
|
593
|
+
# input and output
|
|
594
|
+
input_stage_location=input_spec.input_stage_location,
|
|
595
|
+
input_file_pattern=input_spec.input_file_pattern,
|
|
596
|
+
output_stage_location=output_spec.output_stage_location,
|
|
597
|
+
completion_filename=output_spec.completion_filename,
|
|
598
|
+
# misc
|
|
599
|
+
statement_params=statement_params,
|
|
600
|
+
)
|
|
601
|
+
|
|
542
602
|
def _get_function_info(self, function_name: Optional[str]) -> model_manifest_schema.ModelFunctionInfo:
|
|
543
603
|
functions: list[model_manifest_schema.ModelFunctionInfo] = self._functions
|
|
544
604
|
|
|
@@ -1184,69 +1244,5 @@ class ModelVersion(lineage_node.LineageNode):
|
|
|
1184
1244
|
statement_params=statement_params,
|
|
1185
1245
|
)
|
|
1186
1246
|
|
|
1187
|
-
@snowpark._internal.utils.private_preview(version="1.8.3")
|
|
1188
|
-
@telemetry.send_api_usage_telemetry(
|
|
1189
|
-
project=_TELEMETRY_PROJECT,
|
|
1190
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
|
1191
|
-
)
|
|
1192
|
-
def _run_job(
|
|
1193
|
-
self,
|
|
1194
|
-
X: Union[pd.DataFrame, "dataframe.DataFrame"],
|
|
1195
|
-
*,
|
|
1196
|
-
job_name: str,
|
|
1197
|
-
compute_pool: str,
|
|
1198
|
-
image_repo: Optional[str] = None,
|
|
1199
|
-
output_table_name: str,
|
|
1200
|
-
function_name: Optional[str] = None,
|
|
1201
|
-
cpu_requests: Optional[str] = None,
|
|
1202
|
-
memory_requests: Optional[str] = None,
|
|
1203
|
-
gpu_requests: Optional[Union[str, int]] = None,
|
|
1204
|
-
num_workers: Optional[int] = None,
|
|
1205
|
-
max_batch_rows: Optional[int] = None,
|
|
1206
|
-
force_rebuild: bool = False,
|
|
1207
|
-
build_external_access_integrations: Optional[list[str]] = None,
|
|
1208
|
-
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
|
1209
|
-
statement_params = telemetry.get_statement_params(
|
|
1210
|
-
project=_TELEMETRY_PROJECT,
|
|
1211
|
-
subproject=_TELEMETRY_SUBPROJECT,
|
|
1212
|
-
)
|
|
1213
|
-
target_function_info = self._get_function_info(function_name=function_name)
|
|
1214
|
-
job_db_id, job_schema_id, job_id = sql_identifier.parse_fully_qualified_name(job_name)
|
|
1215
|
-
output_table_db_id, output_table_schema_id, output_table_id = sql_identifier.parse_fully_qualified_name(
|
|
1216
|
-
output_table_name
|
|
1217
|
-
)
|
|
1218
|
-
warehouse = self._service_ops._session.get_current_warehouse()
|
|
1219
|
-
assert warehouse, "No active warehouse selected in the current session."
|
|
1220
|
-
return self._service_ops.invoke_job_method(
|
|
1221
|
-
target_method=target_function_info["target_method"],
|
|
1222
|
-
signature=target_function_info["signature"],
|
|
1223
|
-
X=X,
|
|
1224
|
-
database_name=None,
|
|
1225
|
-
schema_name=None,
|
|
1226
|
-
model_name=self._model_name,
|
|
1227
|
-
version_name=self._version_name,
|
|
1228
|
-
job_database_name=job_db_id,
|
|
1229
|
-
job_schema_name=job_schema_id,
|
|
1230
|
-
job_name=job_id,
|
|
1231
|
-
compute_pool_name=sql_identifier.SqlIdentifier(compute_pool),
|
|
1232
|
-
warehouse_name=sql_identifier.SqlIdentifier(warehouse),
|
|
1233
|
-
image_repo_name=image_repo,
|
|
1234
|
-
output_table_database_name=output_table_db_id,
|
|
1235
|
-
output_table_schema_name=output_table_schema_id,
|
|
1236
|
-
output_table_name=output_table_id,
|
|
1237
|
-
cpu_requests=cpu_requests,
|
|
1238
|
-
memory_requests=memory_requests,
|
|
1239
|
-
gpu_requests=gpu_requests,
|
|
1240
|
-
num_workers=num_workers,
|
|
1241
|
-
max_batch_rows=max_batch_rows,
|
|
1242
|
-
force_rebuild=force_rebuild,
|
|
1243
|
-
build_external_access_integrations=(
|
|
1244
|
-
None
|
|
1245
|
-
if build_external_access_integrations is None
|
|
1246
|
-
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
|
1247
|
-
),
|
|
1248
|
-
statement_params=statement_params,
|
|
1249
|
-
)
|
|
1250
|
-
|
|
1251
1247
|
|
|
1252
1248
|
lineage_node.DOMAIN_LINEAGE_REGISTRY["model"] = ModelVersion
|
|
@@ -10,17 +10,13 @@ import time
|
|
|
10
10
|
from typing import Any, Optional, Union, cast
|
|
11
11
|
|
|
12
12
|
from snowflake import snowpark
|
|
13
|
+
from snowflake.ml import jobs
|
|
13
14
|
from snowflake.ml._internal import file_utils, platform_capabilities as pc
|
|
14
15
|
from snowflake.ml._internal.utils import identifier, service_logger, sql_identifier
|
|
15
|
-
from snowflake.ml.model import
|
|
16
|
-
inference_engine as inference_engine_module,
|
|
17
|
-
model_signature,
|
|
18
|
-
type_hints,
|
|
19
|
-
)
|
|
16
|
+
from snowflake.ml.model import inference_engine as inference_engine_module, type_hints
|
|
20
17
|
from snowflake.ml.model._client.service import model_deployment_spec
|
|
21
18
|
from snowflake.ml.model._client.sql import service as service_sql, stage as stage_sql
|
|
22
|
-
from snowflake.
|
|
23
|
-
from snowflake.snowpark import async_job, dataframe, exceptions, row, session
|
|
19
|
+
from snowflake.snowpark import async_job, exceptions, row, session
|
|
24
20
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
25
21
|
|
|
26
22
|
module_logger = service_logger.get_logger(__name__, service_logger.LogColor.GREY)
|
|
@@ -866,174 +862,97 @@ class ServiceOperator:
|
|
|
866
862
|
except exceptions.SnowparkSQLException:
|
|
867
863
|
return False
|
|
868
864
|
|
|
869
|
-
def
|
|
865
|
+
def invoke_batch_job_method(
|
|
870
866
|
self,
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
|
874
|
-
database_name: Optional[sql_identifier.SqlIdentifier],
|
|
875
|
-
schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
867
|
+
*,
|
|
868
|
+
function_name: str,
|
|
876
869
|
model_name: sql_identifier.SqlIdentifier,
|
|
877
870
|
version_name: sql_identifier.SqlIdentifier,
|
|
878
|
-
|
|
879
|
-
job_schema_name: Optional[sql_identifier.SqlIdentifier],
|
|
880
|
-
job_name: sql_identifier.SqlIdentifier,
|
|
871
|
+
job_name: str,
|
|
881
872
|
compute_pool_name: sql_identifier.SqlIdentifier,
|
|
882
|
-
|
|
873
|
+
warehouse: sql_identifier.SqlIdentifier,
|
|
883
874
|
image_repo_name: Optional[str],
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
gpu_requests: Optional[Union[int, str]],
|
|
875
|
+
input_stage_location: str,
|
|
876
|
+
input_file_pattern: str,
|
|
877
|
+
output_stage_location: str,
|
|
878
|
+
completion_filename: str,
|
|
879
|
+
force_rebuild: bool,
|
|
890
880
|
num_workers: Optional[int],
|
|
891
881
|
max_batch_rows: Optional[int],
|
|
892
|
-
|
|
893
|
-
|
|
882
|
+
cpu_requests: Optional[str],
|
|
883
|
+
memory_requests: Optional[str],
|
|
894
884
|
statement_params: Optional[dict[str, Any]] = None,
|
|
895
|
-
) ->
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
schema_name = schema_name or self._schema_name
|
|
899
|
-
|
|
900
|
-
# fall back to the model's database and schema if not provided then to the registry's database and schema
|
|
901
|
-
job_database_name = job_database_name or database_name or self._database_name
|
|
902
|
-
job_schema_name = job_schema_name or schema_name or self._schema_name
|
|
885
|
+
) -> jobs.MLJob[Any]:
|
|
886
|
+
database_name = self._database_name
|
|
887
|
+
schema_name = self._schema_name
|
|
903
888
|
|
|
904
|
-
|
|
889
|
+
job_database_name, job_schema_name, job_name = sql_identifier.parse_fully_qualified_name(job_name)
|
|
890
|
+
job_database_name = job_database_name or database_name
|
|
891
|
+
job_schema_name = job_schema_name or schema_name
|
|
905
892
|
|
|
906
|
-
|
|
907
|
-
input_table_schema_name = job_schema_name
|
|
908
|
-
output_table_database_name = output_table_database_name or database_name or self._database_name
|
|
909
|
-
output_table_schema_name = output_table_schema_name or schema_name or self._schema_name
|
|
910
|
-
|
|
911
|
-
if self._workspace:
|
|
912
|
-
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
|
913
|
-
else:
|
|
914
|
-
stage_path = None
|
|
893
|
+
self._model_deployment_spec.clear()
|
|
915
894
|
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
s_df = snowpark_handler.SnowparkDataFrameHandler.convert_from_df(
|
|
922
|
-
self._session, df, keep_order=keep_order, features=signature.inputs, statement_params=statement_params
|
|
923
|
-
)
|
|
924
|
-
else:
|
|
925
|
-
keep_order = False
|
|
926
|
-
output_with_input_features = True
|
|
927
|
-
s_df = X
|
|
928
|
-
|
|
929
|
-
# only write the index and feature input columns
|
|
930
|
-
cols = [snowpark_handler._KEEP_ORDER_COL_NAME] if snowpark_handler._KEEP_ORDER_COL_NAME in s_df.columns else []
|
|
931
|
-
cols += [
|
|
932
|
-
sql_identifier.SqlIdentifier(feature.name, case_sensitive=True).identifier() for feature in signature.inputs
|
|
933
|
-
]
|
|
934
|
-
s_df = s_df.select(cols)
|
|
935
|
-
original_cols = s_df.columns
|
|
936
|
-
|
|
937
|
-
# input/output tables
|
|
938
|
-
fq_output_table_name = identifier.get_schema_level_object_identifier(
|
|
939
|
-
output_table_database_name.identifier(),
|
|
940
|
-
output_table_schema_name.identifier(),
|
|
941
|
-
output_table_name.identifier(),
|
|
942
|
-
)
|
|
943
|
-
tmp_input_table_id = sql_identifier.SqlIdentifier(
|
|
944
|
-
snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
|
945
|
-
)
|
|
946
|
-
fq_tmp_input_table_name = identifier.get_schema_level_object_identifier(
|
|
947
|
-
job_database_name.identifier(),
|
|
948
|
-
job_schema_name.identifier(),
|
|
949
|
-
tmp_input_table_id.identifier(),
|
|
950
|
-
)
|
|
951
|
-
s_df.write.save_as_table(
|
|
952
|
-
table_name=fq_tmp_input_table_name,
|
|
953
|
-
mode="errorifexists",
|
|
954
|
-
statement_params=statement_params,
|
|
895
|
+
self._model_deployment_spec.add_model_spec(
|
|
896
|
+
database_name=database_name,
|
|
897
|
+
schema_name=schema_name,
|
|
898
|
+
model_name=model_name,
|
|
899
|
+
version_name=version_name,
|
|
955
900
|
)
|
|
956
901
|
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
gpu=gpu_requests,
|
|
974
|
-
num_workers=num_workers,
|
|
975
|
-
max_batch_rows=max_batch_rows,
|
|
976
|
-
warehouse=warehouse_name,
|
|
977
|
-
target_method=target_method,
|
|
978
|
-
input_table_database_name=input_table_database_name,
|
|
979
|
-
input_table_schema_name=input_table_schema_name,
|
|
980
|
-
input_table_name=tmp_input_table_id,
|
|
981
|
-
output_table_database_name=output_table_database_name,
|
|
982
|
-
output_table_schema_name=output_table_schema_name,
|
|
983
|
-
output_table_name=output_table_name,
|
|
984
|
-
)
|
|
902
|
+
self._model_deployment_spec.add_job_spec(
|
|
903
|
+
job_database_name=job_database_name,
|
|
904
|
+
job_schema_name=job_schema_name,
|
|
905
|
+
job_name=job_name,
|
|
906
|
+
inference_compute_pool_name=compute_pool_name,
|
|
907
|
+
num_workers=num_workers,
|
|
908
|
+
max_batch_rows=max_batch_rows,
|
|
909
|
+
input_stage_location=input_stage_location,
|
|
910
|
+
input_file_pattern=input_file_pattern,
|
|
911
|
+
output_stage_location=output_stage_location,
|
|
912
|
+
completion_filename=completion_filename,
|
|
913
|
+
function_name=function_name,
|
|
914
|
+
warehouse=warehouse,
|
|
915
|
+
cpu=cpu_requests,
|
|
916
|
+
memory=memory_requests,
|
|
917
|
+
)
|
|
985
918
|
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
)
|
|
919
|
+
self._model_deployment_spec.add_image_build_spec(
|
|
920
|
+
image_build_compute_pool_name=compute_pool_name,
|
|
921
|
+
fully_qualified_image_repo_name=self._get_image_repo_fqn(image_repo_name, database_name, schema_name),
|
|
922
|
+
force_rebuild=force_rebuild,
|
|
923
|
+
)
|
|
992
924
|
|
|
993
|
-
|
|
994
|
-
if self._workspace:
|
|
995
|
-
assert stage_path is not None
|
|
996
|
-
file_utils.upload_directory_to_stage(
|
|
997
|
-
self._session,
|
|
998
|
-
local_path=pathlib.Path(self._workspace.name),
|
|
999
|
-
stage_path=pathlib.PurePosixPath(stage_path),
|
|
1000
|
-
statement_params=statement_params,
|
|
1001
|
-
)
|
|
925
|
+
spec_yaml_str_or_path = self._model_deployment_spec.save()
|
|
1002
926
|
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
),
|
|
1009
|
-
|
|
927
|
+
if self._workspace:
|
|
928
|
+
module_logger.info("using workspace")
|
|
929
|
+
stage_path = self._create_temp_stage(database_name, schema_name, statement_params)
|
|
930
|
+
file_utils.upload_directory_to_stage(
|
|
931
|
+
self._session,
|
|
932
|
+
local_path=pathlib.Path(self._workspace.name),
|
|
933
|
+
stage_path=pathlib.PurePosixPath(stage_path),
|
|
1010
934
|
statement_params=statement_params,
|
|
1011
935
|
)
|
|
936
|
+
else:
|
|
937
|
+
module_logger.info("not using workspace")
|
|
938
|
+
stage_path = None
|
|
1012
939
|
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
df_res = df_res.sort(
|
|
1022
|
-
snowpark_handler._KEEP_ORDER_COL_NAME,
|
|
1023
|
-
ascending=True,
|
|
1024
|
-
)
|
|
1025
|
-
df_res = df_res.drop(snowpark_handler._KEEP_ORDER_COL_NAME)
|
|
940
|
+
_, async_job = self._service_client.deploy_model(
|
|
941
|
+
stage_path=stage_path if self._workspace else None,
|
|
942
|
+
model_deployment_spec_file_rel_path=(
|
|
943
|
+
model_deployment_spec.ModelDeploymentSpec.DEPLOY_SPEC_FILE_REL_PATH if self._workspace else None
|
|
944
|
+
),
|
|
945
|
+
model_deployment_spec_yaml_str=None if self._workspace else spec_yaml_str_or_path,
|
|
946
|
+
statement_params=statement_params,
|
|
947
|
+
)
|
|
1026
948
|
|
|
1027
|
-
|
|
1028
|
-
|
|
949
|
+
# Block until the async job is done
|
|
950
|
+
async_job.result()
|
|
1029
951
|
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
)
|
|
1035
|
-
else:
|
|
1036
|
-
return df_res
|
|
952
|
+
return jobs.MLJob(
|
|
953
|
+
id=sql_identifier.get_fully_qualified_name(job_database_name, job_schema_name, job_name),
|
|
954
|
+
session=self._session,
|
|
955
|
+
)
|
|
1037
956
|
|
|
1038
957
|
def _create_temp_stage(
|
|
1039
958
|
self,
|