snowflake-ml-python 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +66 -31
- snowflake/ml/_internal/exceptions/dataset_error_messages.py +5 -0
- snowflake/ml/_internal/exceptions/dataset_errors.py +24 -0
- snowflake/ml/_internal/exceptions/error_codes.py +3 -0
- snowflake/ml/_internal/lineage/data_source.py +10 -0
- snowflake/ml/_internal/lineage/dataset_dataframe.py +44 -0
- snowflake/ml/dataset/__init__.py +10 -0
- snowflake/ml/dataset/dataset.py +454 -129
- snowflake/ml/dataset/dataset_factory.py +53 -0
- snowflake/ml/dataset/dataset_metadata.py +103 -0
- snowflake/ml/dataset/dataset_reader.py +202 -0
- snowflake/ml/feature_store/feature_store.py +408 -282
- snowflake/ml/feature_store/feature_view.py +37 -8
- snowflake/ml/fileset/embedded_stage_fs.py +146 -0
- snowflake/ml/fileset/sfcfs.py +0 -4
- snowflake/ml/fileset/snowfs.py +159 -0
- snowflake/ml/fileset/stage_fs.py +1 -4
- snowflake/ml/model/__init__.py +2 -2
- snowflake/ml/model/_api.py +16 -1
- snowflake/ml/model/_client/model/model_impl.py +27 -0
- snowflake/ml/model/_client/model/model_version_impl.py +135 -0
- snowflake/ml/model/_client/ops/model_ops.py +137 -67
- snowflake/ml/model/_client/sql/model.py +16 -14
- snowflake/ml/model/_client/sql/model_version.py +109 -1
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +5 -1
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +1 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +2 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +0 -5
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +21 -50
- snowflake/ml/model/_model_composer/model_composer.py +22 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +22 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +11 -0
- snowflake/ml/model/_packager/model_env/model_env.py +41 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +1 -5
- snowflake/ml/model/_packager/model_packager.py +0 -3
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +55 -3
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +34 -18
- snowflake/ml/modeling/_internal/model_trainer.py +7 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +42 -9
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +24 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +261 -16
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -52
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -52
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -52
- snowflake/ml/modeling/cluster/birch.py +53 -52
- snowflake/ml/modeling/cluster/bisecting_k_means.py +53 -52
- snowflake/ml/modeling/cluster/dbscan.py +51 -52
- snowflake/ml/modeling/cluster/feature_agglomeration.py +53 -52
- snowflake/ml/modeling/cluster/k_means.py +53 -52
- snowflake/ml/modeling/cluster/mean_shift.py +51 -52
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +53 -52
- snowflake/ml/modeling/cluster/optics.py +51 -52
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -52
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -52
- snowflake/ml/modeling/compose/column_transformer.py +53 -52
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -52
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -52
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -52
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -52
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -52
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -52
- snowflake/ml/modeling/covariance/oas.py +51 -52
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -52
- snowflake/ml/modeling/decomposition/dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/factor_analysis.py +53 -52
- snowflake/ml/modeling/decomposition/fast_ica.py +53 -52
- snowflake/ml/modeling/decomposition/incremental_pca.py +53 -52
- snowflake/ml/modeling/decomposition/kernel_pca.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +53 -52
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/pca.py +53 -52
- snowflake/ml/modeling/decomposition/sparse_pca.py +53 -52
- snowflake/ml/modeling/decomposition/truncated_svd.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +53 -52
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -52
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -52
- snowflake/ml/modeling/ensemble/stacking_regressor.py +53 -52
- snowflake/ml/modeling/ensemble/voting_classifier.py +53 -52
- snowflake/ml/modeling/ensemble/voting_regressor.py +53 -52
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fdr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fpr.py +53 -52
- snowflake/ml/modeling/feature_selection/select_fwe.py +53 -52
- snowflake/ml/modeling/feature_selection/select_k_best.py +53 -52
- snowflake/ml/modeling/feature_selection/select_percentile.py +53 -52
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +53 -52
- snowflake/ml/modeling/feature_selection/variance_threshold.py +53 -52
- snowflake/ml/modeling/framework/base.py +63 -36
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -52
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -52
- snowflake/ml/modeling/impute/iterative_imputer.py +53 -52
- snowflake/ml/modeling/impute/knn_imputer.py +53 -52
- snowflake/ml/modeling/impute/missing_indicator.py +53 -52
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/nystroem.py +53 -52
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +53 -52
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +53 -52
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +53 -52
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -52
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -52
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/lars.py +51 -52
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -52
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -52
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -52
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -52
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -52
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/perceptron.py +51 -52
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/ridge.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -52
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -52
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -52
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -52
- snowflake/ml/modeling/manifold/isomap.py +53 -52
- snowflake/ml/modeling/manifold/mds.py +53 -52
- snowflake/ml/modeling/manifold/spectral_embedding.py +53 -52
- snowflake/ml/modeling/manifold/tsne.py +53 -52
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -52
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -52
- snowflake/ml/modeling/model_selection/grid_search_cv.py +21 -23
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +38 -20
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -52
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -52
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -52
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -52
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -52
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -52
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +53 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -52
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -52
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +53 -52
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -52
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -52
- snowflake/ml/modeling/pipeline/pipeline.py +514 -32
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +12 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +53 -52
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -52
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -52
- snowflake/ml/modeling/svm/linear_svc.py +51 -52
- snowflake/ml/modeling/svm/linear_svr.py +51 -52
- snowflake/ml/modeling/svm/nu_svc.py +51 -52
- snowflake/ml/modeling/svm/nu_svr.py +51 -52
- snowflake/ml/modeling/svm/svc.py +51 -52
- snowflake/ml/modeling/svm/svr.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -52
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -52
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -52
- snowflake/ml/registry/model_registry.py +3 -149
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/METADATA +63 -2
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/RECORD +204 -196
- snowflake/ml/registry/_artifact_manager.py +0 -156
- snowflake/ml/registry/artifact.py +0 -46
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.4.1.dist-info → snowflake_ml_python-1.5.0.dist-info}/top_level.txt +0 -0
@@ -1,17 +1,29 @@
|
|
1
|
+
import enum
|
2
|
+
import pathlib
|
3
|
+
import tempfile
|
4
|
+
import warnings
|
1
5
|
from typing import Any, Callable, Dict, List, Optional, Union
|
2
6
|
|
3
7
|
import pandas as pd
|
4
8
|
|
5
9
|
from snowflake.ml._internal import telemetry
|
6
10
|
from snowflake.ml._internal.utils import sql_identifier
|
11
|
+
from snowflake.ml.model import type_hints as model_types
|
7
12
|
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
13
|
+
from snowflake.ml.model._model_composer import model_composer
|
8
14
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
15
|
+
from snowflake.ml.model._packager.model_handlers import snowmlmodel
|
9
16
|
from snowflake.snowpark import dataframe
|
10
17
|
|
11
18
|
# Project / subproject labels passed to every telemetry decorator and statement-params call in this module.
_TELEMETRY_PROJECT = "MLOps"
_TELEMETRY_SUBPROJECT = "ModelManagement"
|
13
20
|
|
14
21
|
|
22
|
+
class ExportMode(enum.Enum):
    """Which set of model-version files `ModelVersion.export` writes to the local directory."""

    # All model files needed to load the model locally: environment and model weights.
    MODEL = "model"
    # Everything in MODEL plus the additional files used to run the model in a Warehouse.
    FULL = "full"
|
25
|
+
|
26
|
+
|
15
27
|
class ModelVersion:
|
16
28
|
"""Model Version Object representing a specific version of the model that could be run."""
|
17
29
|
|
@@ -240,6 +252,7 @@ class ModelVersion:
|
|
240
252
|
X: Union[pd.DataFrame, dataframe.DataFrame],
|
241
253
|
*,
|
242
254
|
function_name: Optional[str] = None,
|
255
|
+
partition_column: Optional[str] = None,
|
243
256
|
strict_input_validation: bool = False,
|
244
257
|
) -> Union[pd.DataFrame, dataframe.DataFrame]:
|
245
258
|
"""Invoke a method in a model version object.
|
@@ -248,12 +261,14 @@ class ModelVersion:
|
|
248
261
|
X: The input data, which could be a pandas DataFrame or Snowpark DataFrame.
|
249
262
|
function_name: The function name to run. It is the name used to call a function in SQL.
|
250
263
|
Defaults to None. It can only be None if there is only 1 method.
|
264
|
+
partition_column: The partition column name to partition by.
|
251
265
|
strict_input_validation: Enable stricter validation for the input data. This will result value range based
|
252
266
|
type validation to make sure your input data won't overflow when providing to the model.
|
253
267
|
|
254
268
|
Raises:
|
255
269
|
ValueError: When no method with the corresponding name is available.
|
256
270
|
ValueError: When there are more than 1 target methods available in the model but no function name specified.
|
271
|
+
ValueError: When the partition column is not a valid Snowflake identifier.
|
257
272
|
|
258
273
|
Returns:
|
259
274
|
The prediction data. It would be the same type dataframe as your input.
|
@@ -263,6 +278,10 @@ class ModelVersion:
|
|
263
278
|
subproject=_TELEMETRY_SUBPROJECT,
|
264
279
|
)
|
265
280
|
|
281
|
+
if partition_column is not None:
|
282
|
+
# Partition column must be a valid identifier
|
283
|
+
partition_column = sql_identifier.SqlIdentifier(partition_column)
|
284
|
+
|
266
285
|
functions: List[model_manifest_schema.ModelFunctionInfo] = self._functions
|
267
286
|
if function_name:
|
268
287
|
req_method_name = sql_identifier.SqlIdentifier(function_name).identifier()
|
@@ -287,10 +306,126 @@ class ModelVersion:
|
|
287
306
|
target_function_info = functions[0]
|
288
307
|
return self._model_ops.invoke_method(
|
289
308
|
method_name=sql_identifier.SqlIdentifier(target_function_info["name"]),
|
309
|
+
method_function_type=target_function_info["target_method_function_type"],
|
290
310
|
signature=target_function_info["signature"],
|
291
311
|
X=X,
|
292
312
|
model_name=self._model_name,
|
293
313
|
version_name=self._version_name,
|
294
314
|
strict_input_validation=strict_input_validation,
|
315
|
+
partition_column=partition_column,
|
316
|
+
statement_params=statement_params,
|
317
|
+
)
|
318
|
+
|
319
|
+
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["export_mode"]
)
def export(self, target_path: str, *, export_mode: ExportMode = ExportMode.MODEL) -> None:
    """Export model files to a local directory.

    Args:
        target_path: Path to a local directory to export files to. The directory is created if it does not
            exist; its parent directory must already exist.
        export_mode: The mode to export the model. Defaults to ExportMode.MODEL.
            ExportMode.MODEL: All model files including environment to load the model and model weights.
            ExportMode.FULL: Additional files to run the model in Warehouse, besides all files in MODEL mode.

    Raises:
        ValueError: Raised when the target path is a file or a non-empty folder.
    """
    target_local_path = pathlib.Path(target_path)
    # Only look inside the directory when it already exists: `iterdir()` on a missing path raises
    # FileNotFoundError, which would otherwise make the "create if missing" behavior unreachable.
    if target_local_path.is_file() or (target_local_path.is_dir() and any(target_local_path.iterdir())):
        raise ValueError(f"Target path {target_local_path} is a file or an non-empty folder.")

    # parents=False: only the leaf directory is created, the parent must exist.
    target_local_path.mkdir(parents=False, exist_ok=True)
    statement_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    self._model_ops.download_files(
        model_name=self._model_name,
        version_name=self._version_name,
        target_path=target_local_path,
        mode=export_mode.value,
        statement_params=statement_params,
    )
|
350
|
+
|
351
|
+
@telemetry.send_api_usage_telemetry(
    project=_TELEMETRY_PROJECT, subproject=_TELEMETRY_SUBPROJECT, func_params_to_log=["force", "options"]
)
def load(
    self,
    *,
    force: bool = False,
    options: Optional[model_types.ModelLoadOption] = None,
) -> model_types.SupportedModelType:
    """Load the original Python model object stored in this model version.

    The local environment should match the one used when the model was logged; otherwise the loaded
    object may not work or other problems may occur.

    Args:
        force: Skip the best-effort comparison between the model's recorded environment and the local one.
            Defaults to False.
        options: Options to specify when loading the model, check `snowflake.ml.model.type_hints` for available
            options. Defaults to None.

    Raises:
        ValueError: Raised when the best-effort environment validation fails.

    Returns:
        The original Python object loaded from the model object.
    """
    statement_params = telemetry.get_statement_params(
        project=_TELEMETRY_PROJECT,
        subproject=_TELEMETRY_SUBPROJECT,
    )
    model_name = self._model_name
    version_name = self._version_name

    if not force:
        # Fetch only the "minimal" file set (metadata + environment) into a throwaway directory and
        # compare the recorded environment against the local one before downloading the full model.
        with tempfile.TemporaryDirectory() as validation_dir:
            validation_path = pathlib.Path(validation_dir)
            self._model_ops.download_files(
                model_name=model_name,
                version_name=version_name,
                target_path=validation_path,
                mode="minimal",
                statement_params=statement_params,
            )
            meta_packager = model_composer.ModelComposer.load(validation_path, meta_only=True, options=options)
            assert meta_packager.meta, (
                "Unable to load model metadata for validation. "
                f"model_name={model_name}, version_name={version_name}"
            )
            metadata = meta_packager.meta
            # The snowpark-ml version itself is only checked for models handled by the SnowML handler.
            is_snowml_model = metadata.model_type == snowmlmodel.SnowMLModelHandler.HANDLER_TYPE
            validation_errors = metadata.env.validate_with_local_env(check_snowpark_ml_version=is_snowml_model)
            if validation_errors:
                raise ValueError(
                    f"Unable to load this model due to following validation errors: {validation_errors}. "
                    "Make sure your local environment is the same as that when you logged the model, "
                    "or if you believe it should work, specify `force=True` to bypass this check."
                )

    # Emitted on every load, whether or not validation ran.
    warnings.warn(
        "Loading model requires to have the exact the same environment as the one when "
        "logging the model, otherwise, the model might be not functional or "
        "some other problems might occur.",
        category=RuntimeWarning,
        stacklevel=2,
    )

    # mkdtemp (not a TemporaryDirectory context) so the folder keeps existing after this method returns;
    # presumably the loaded model may still reference files inside it — confirm before adding cleanup.
    workspace = pathlib.Path(tempfile.mkdtemp())
    self._model_ops.download_files(
        model_name=model_name,
        version_name=version_name,
        target_path=workspace,
        mode="model",
        statement_params=statement_params,
    )
    packager = model_composer.ModelComposer.load(workspace, meta_only=False, options=options)
    assert packager.model, (
        "Unable to load model. "
        f"model_name={model_name}, version_name={version_name}, metadata={packager.meta}"
    )
    return packager.model
|
@@ -1,7 +1,7 @@
|
|
1
|
+
import os
|
1
2
|
import pathlib
|
2
3
|
import tempfile
|
3
|
-
from
|
4
|
-
from typing import Any, Dict, Generator, List, Optional, Union, cast
|
4
|
+
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
5
5
|
|
6
6
|
import yaml
|
7
7
|
|
@@ -19,7 +19,9 @@ from snowflake.ml.model._model_composer.model_manifest import (
|
|
19
19
|
model_manifest,
|
20
20
|
model_manifest_schema,
|
21
21
|
)
|
22
|
+
from snowflake.ml.model._packager.model_env import model_env
|
22
23
|
from snowflake.ml.model._packager.model_meta import model_meta
|
24
|
+
from snowflake.ml.model._packager.model_runtime import model_runtime
|
23
25
|
from snowflake.ml.model._signatures import snowpark_handler
|
24
26
|
from snowflake.snowpark import dataframe, row, session
|
25
27
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
@@ -337,16 +339,6 @@ class ModelOperator:
|
|
337
339
|
mm = model_manifest.ModelManifest(pathlib.Path(tmpdir))
|
338
340
|
return mm.load()
|
339
341
|
|
340
|
-
@contextmanager
|
341
|
-
def _enable_model_details(
|
342
|
-
self,
|
343
|
-
*,
|
344
|
-
statement_params: Optional[Dict[str, Any]] = None,
|
345
|
-
) -> Generator[None, None, None]:
|
346
|
-
self._model_client.config_model_details(enable=True, statement_params=statement_params)
|
347
|
-
yield
|
348
|
-
self._model_client.config_model_details(enable=False, statement_params=statement_params)
|
349
|
-
|
350
342
|
@staticmethod
|
351
343
|
def _match_model_spec_with_sql_functions(
|
352
344
|
sql_functions_names: List[sql_identifier.SqlIdentifier], target_methods: List[str]
|
@@ -374,64 +366,63 @@ class ModelOperator:
|
|
374
366
|
version_name: sql_identifier.SqlIdentifier,
|
375
367
|
statement_params: Optional[Dict[str, Any]] = None,
|
376
368
|
) -> List[model_manifest_schema.ModelFunctionInfo]:
|
377
|
-
|
378
|
-
|
379
|
-
|
380
|
-
|
381
|
-
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
|
369
|
+
raw_model_spec_res = self._model_client.show_versions(
|
370
|
+
model_name=model_name,
|
371
|
+
version_name=version_name,
|
372
|
+
check_model_details=True,
|
373
|
+
statement_params={**(statement_params or {}), "SHOW_MODEL_DETAILS_IN_SHOW_VERSIONS_IN_MODEL": True},
|
374
|
+
)[0][self._model_client.MODEL_VERSION_MODEL_SPEC_COL_NAME]
|
375
|
+
model_spec_dict = yaml.safe_load(raw_model_spec_res)
|
376
|
+
model_spec = model_meta.ModelMetadata._validate_model_metadata(model_spec_dict)
|
377
|
+
show_functions_res = self._model_version_client.show_functions(
|
378
|
+
model_name=model_name,
|
379
|
+
version_name=version_name,
|
380
|
+
statement_params=statement_params,
|
381
|
+
)
|
382
|
+
function_names_and_types = []
|
383
|
+
for r in show_functions_res:
|
384
|
+
function_name = sql_identifier.SqlIdentifier(
|
385
|
+
r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
|
390
386
|
)
|
391
|
-
function_names_and_types = []
|
392
|
-
for r in show_functions_res:
|
393
|
-
function_name = sql_identifier.SqlIdentifier(
|
394
|
-
r[self._model_version_client.FUNCTION_NAME_COL_NAME], case_sensitive=True
|
395
|
-
)
|
396
387
|
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
function_names_and_types.append((function_name, function_type))
|
407
|
-
|
408
|
-
signatures = model_spec["signatures"]
|
409
|
-
function_names = [name for name, _ in function_names_and_types]
|
410
|
-
function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
|
411
|
-
function_names, list(signatures.keys())
|
412
|
-
)
|
388
|
+
function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
|
389
|
+
try:
|
390
|
+
return_type = r[self._model_version_client.FUNCTION_RETURN_TYPE_COL_NAME]
|
391
|
+
except KeyError:
|
392
|
+
pass
|
393
|
+
else:
|
394
|
+
if "TABLE" in return_type:
|
395
|
+
function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
|
413
396
|
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
397
|
+
function_names_and_types.append((function_name, function_type))
|
398
|
+
|
399
|
+
signatures = model_spec["signatures"]
|
400
|
+
function_names = [name for name, _ in function_names_and_types]
|
401
|
+
function_name_mapping = ModelOperator._match_model_spec_with_sql_functions(
|
402
|
+
function_names, list(signatures.keys())
|
403
|
+
)
|
404
|
+
|
405
|
+
return [
|
406
|
+
model_manifest_schema.ModelFunctionInfo(
|
407
|
+
name=function_name.identifier(),
|
408
|
+
target_method=function_name_mapping[function_name],
|
409
|
+
target_method_function_type=function_type,
|
410
|
+
signature=model_signature.ModelSignature.from_dict(signatures[function_name_mapping[function_name]]),
|
411
|
+
)
|
412
|
+
for function_name, function_type in function_names_and_types
|
413
|
+
]
|
425
414
|
|
426
415
|
def invoke_method(
|
427
416
|
self,
|
428
417
|
*,
|
429
418
|
method_name: sql_identifier.SqlIdentifier,
|
419
|
+
method_function_type: str,
|
430
420
|
signature: model_signature.ModelSignature,
|
431
421
|
X: Union[type_hints.SupportedDataType, dataframe.DataFrame],
|
432
422
|
model_name: sql_identifier.SqlIdentifier,
|
433
423
|
version_name: sql_identifier.SqlIdentifier,
|
434
424
|
strict_input_validation: bool = False,
|
425
|
+
partition_column: Optional[sql_identifier.SqlIdentifier] = None,
|
435
426
|
statement_params: Optional[Dict[str, str]] = None,
|
436
427
|
) -> Union[type_hints.SupportedDataType, dataframe.DataFrame]:
|
437
428
|
identifier_rule = model_signature.SnowparkIdentifierRule.INFERRED
|
@@ -469,15 +460,27 @@ class ModelOperator:
|
|
469
460
|
if output_name in original_cols:
|
470
461
|
original_cols.remove(output_name)
|
471
462
|
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
|
463
|
+
if method_function_type == model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value:
|
464
|
+
df_res = self._model_version_client.invoke_function_method(
|
465
|
+
method_name=method_name,
|
466
|
+
input_df=s_df,
|
467
|
+
input_args=input_args,
|
468
|
+
returns=returns,
|
469
|
+
model_name=model_name,
|
470
|
+
version_name=version_name,
|
471
|
+
statement_params=statement_params,
|
472
|
+
)
|
473
|
+
elif method_function_type == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
|
474
|
+
df_res = self._model_version_client.invoke_table_function_method(
|
475
|
+
method_name=method_name,
|
476
|
+
input_df=s_df,
|
477
|
+
input_args=input_args,
|
478
|
+
partition_column=partition_column,
|
479
|
+
returns=returns,
|
480
|
+
model_name=model_name,
|
481
|
+
version_name=version_name,
|
482
|
+
statement_params=statement_params,
|
483
|
+
)
|
481
484
|
|
482
485
|
if keep_order:
|
483
486
|
df_res = df_res.sort(
|
@@ -486,7 +489,11 @@ class ModelOperator:
|
|
486
489
|
)
|
487
490
|
|
488
491
|
if not output_with_input_features:
|
489
|
-
|
492
|
+
cols_to_drop = original_cols
|
493
|
+
if partition_column is not None:
|
494
|
+
# don't drop partition column
|
495
|
+
cols_to_drop.remove(partition_column.identifier())
|
496
|
+
df_res = df_res.drop(*cols_to_drop)
|
490
497
|
|
491
498
|
# Get final result
|
492
499
|
if not isinstance(X, dataframe.DataFrame):
|
@@ -512,3 +519,66 @@ class ModelOperator:
|
|
512
519
|
model_name=model_name,
|
513
520
|
statement_params=statement_params,
|
514
521
|
)
|
522
|
+
|
523
|
+
def rename(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    new_model_db: Optional[sql_identifier.SqlIdentifier],
    new_model_schema: Optional[sql_identifier.SqlIdentifier],
    new_model_name: sql_identifier.SqlIdentifier,
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Rename a model, optionally moving it to another database and/or schema.

    Args:
        model_name: Current name of the model.
        new_model_db: Database for the renamed model, or None to let the SQL client choose its default.
        new_model_schema: Schema for the renamed model, or None to let the SQL client choose its default.
        new_model_name: New name of the model.
        statement_params: Statement parameters forwarded to the SQL call. Defaults to None.
    """
    # Pure delegation; all argument handling lives in the model SQL client.
    client = self._model_client
    client.rename(
        model_name=model_name,
        new_model_db=new_model_db,
        new_model_schema=new_model_schema,
        new_model_name=new_model_name,
        statement_params=statement_params,
    )
|
539
|
+
|
540
|
+
# Download mode -> {relative path inside the model version: is_dir}.
# A True value means the path is a directory whose whole listing is downloaded; False means a single file.
# "minimal" fetches just the metadata file plus the env and runtime directories of the model dir,
# "model" fetches the whole model dir, and "full" fetches everything from the version root (os.curdir).
MODEL_FILE_DOWNLOAD_PATTERN = {
    "minimal": {
        (
            pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH) / model_meta.MODEL_METADATA_FILE
        ): False,
        (pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH) / model_env._DEFAULT_ENV_DIR): True,
        (
            pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH)
            / model_runtime.ModelRuntime.RUNTIME_DIR_REL_PATH
        ): True,
    },
    "model": {
        pathlib.PurePosixPath(model_composer.ModelComposer.MODEL_DIR_REL_PATH): True,
    },
    "full": {
        pathlib.PurePosixPath(os.curdir): True,
    },
}
|
553
|
+
|
554
|
+
def download_files(
    self,
    *,
    model_name: sql_identifier.SqlIdentifier,
    version_name: sql_identifier.SqlIdentifier,
    target_path: pathlib.Path,
    mode: Literal["full", "model", "minimal"] = "model",
    statement_params: Optional[Dict[str, Any]] = None,
) -> None:
    """Download a model version's files into a local directory, mirroring their relative layout.

    Args:
        model_name: Name of the model.
        version_name: Name of the model version.
        target_path: Local directory receiving the files; sub-directories are created as needed.
        mode: Key into ``MODEL_FILE_DOWNLOAD_PATTERN`` selecting which paths to list and fetch.
            Defaults to "model".
        statement_params: Statement parameters forwarded to the list/get SQL calls. Defaults to None.
    """
    version_client = self._model_version_client
    for remote_rel_path, is_dir in self.MODEL_FILE_DOWNLOAD_PATTERN[mode].items():
        listing = version_client.list_file(
            model_name=model_name,
            version_name=version_name,
            file_path=remote_rel_path,
            is_dir=is_dir,
            statement_params=statement_params,
        )
        # Listed names start with versions/<version_name>/; drop those two leading parts.
        # Built eagerly (before any download) and named `entry` to avoid shadowing the imported `row` module.
        relative_paths = [pathlib.PurePosixPath(*pathlib.PurePosixPath(entry.name).parts[2:]) for entry in listing]
        for relative_path in relative_paths:
            destination_dir = target_path / relative_path.parent
            destination_dir.mkdir(parents=True, exist_ok=True)
            version_client.get_file(
                model_name=model_name,
                version_name=version_name,
                file_path=relative_path,
                target_path=destination_dir,
                statement_params=statement_params,
            )
|
@@ -121,21 +121,23 @@ class ModelSQLClient:
|
|
121
121
|
statement_params=statement_params,
|
122
122
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
123
123
|
|
124
|
-
def
|
124
|
+
def rename(
|
125
125
|
self,
|
126
126
|
*,
|
127
|
-
|
127
|
+
model_name: sql_identifier.SqlIdentifier,
|
128
|
+
new_model_db: Optional[sql_identifier.SqlIdentifier],
|
129
|
+
new_model_schema: Optional[sql_identifier.SqlIdentifier],
|
130
|
+
new_model_name: sql_identifier.SqlIdentifier,
|
128
131
|
statement_params: Optional[Dict[str, Any]] = None,
|
129
132
|
) -> None:
|
130
|
-
if
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
133
|
+
# Use registry's database and schema if a non fully qualified new model name is provided.
|
134
|
+
new_fully_qualified_name = identifier.get_schema_level_object_identifier(
|
135
|
+
new_model_db.identifier() if new_model_db else self._database_name.identifier(),
|
136
|
+
new_model_schema.identifier() if new_model_schema else self._schema_name.identifier(),
|
137
|
+
new_model_name.identifier(),
|
138
|
+
)
|
139
|
+
query_result_checker.SqlResultValidator(
|
140
|
+
self._session,
|
141
|
+
f"ALTER MODEL {self.fully_qualified_model_name(model_name)} RENAME TO {new_fully_qualified_name}",
|
142
|
+
statement_params=statement_params,
|
143
|
+
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -96,6 +96,38 @@ class ModelVersionSQLClient:
|
|
96
96
|
statement_params=statement_params,
|
97
97
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
98
98
|
|
99
|
+
def list_file(
|
100
|
+
self,
|
101
|
+
*,
|
102
|
+
model_name: sql_identifier.SqlIdentifier,
|
103
|
+
version_name: sql_identifier.SqlIdentifier,
|
104
|
+
file_path: pathlib.PurePosixPath,
|
105
|
+
is_dir: bool = False,
|
106
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
107
|
+
) -> List[row.Row]:
|
108
|
+
# Workaround for snowURL bug.
|
109
|
+
trailing_slash = "/" if is_dir else ""
|
110
|
+
|
111
|
+
stage_location = (
|
112
|
+
pathlib.PurePosixPath(
|
113
|
+
self.fully_qualified_model_name(model_name), "versions", version_name.resolved(), file_path
|
114
|
+
).as_posix()
|
115
|
+
+ trailing_slash
|
116
|
+
)
|
117
|
+
stage_location_url = ParseResult(
|
118
|
+
scheme="snow", netloc="model", path=stage_location, params="", query="", fragment=""
|
119
|
+
).geturl()
|
120
|
+
|
121
|
+
return (
|
122
|
+
query_result_checker.SqlResultValidator(
|
123
|
+
self._session,
|
124
|
+
f"List {_normalize_url_for_sql(stage_location_url)}",
|
125
|
+
statement_params=statement_params,
|
126
|
+
)
|
127
|
+
.has_column("name")
|
128
|
+
.validate()
|
129
|
+
)
|
130
|
+
|
99
131
|
def get_file(
|
100
132
|
self,
|
101
133
|
*,
|
@@ -162,7 +194,7 @@ class ModelVersionSQLClient:
|
|
162
194
|
statement_params=statement_params,
|
163
195
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
164
196
|
|
165
|
-
def
|
197
|
+
def invoke_function_method(
|
166
198
|
self,
|
167
199
|
*,
|
168
200
|
model_name: sql_identifier.SqlIdentifier,
|
@@ -232,6 +264,82 @@ class ModelVersionSQLClient:
|
|
232
264
|
|
233
265
|
return output_df
|
234
266
|
|
267
|
+
def invoke_table_function_method(
|
268
|
+
self,
|
269
|
+
*,
|
270
|
+
model_name: sql_identifier.SqlIdentifier,
|
271
|
+
version_name: sql_identifier.SqlIdentifier,
|
272
|
+
method_name: sql_identifier.SqlIdentifier,
|
273
|
+
input_df: dataframe.DataFrame,
|
274
|
+
input_args: List[sql_identifier.SqlIdentifier],
|
275
|
+
returns: List[Tuple[str, spt.DataType, sql_identifier.SqlIdentifier]],
|
276
|
+
partition_column: Optional[sql_identifier.SqlIdentifier],
|
277
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
278
|
+
) -> dataframe.DataFrame:
|
279
|
+
with_statements = []
|
280
|
+
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
281
|
+
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
282
|
+
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
283
|
+
else:
|
284
|
+
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
285
|
+
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
286
|
+
self._database_name.identifier(),
|
287
|
+
self._schema_name.identifier(),
|
288
|
+
tmp_table_name,
|
289
|
+
)
|
290
|
+
input_df.write.save_as_table( # type: ignore[call-overload]
|
291
|
+
table_name=INTERMEDIATE_TABLE_NAME,
|
292
|
+
mode="errorifexists",
|
293
|
+
table_type="temporary",
|
294
|
+
statement_params=statement_params,
|
295
|
+
)
|
296
|
+
|
297
|
+
module_version_alias = "MODEL_VERSION_ALIAS"
|
298
|
+
with_statements.append(
|
299
|
+
f"{module_version_alias} AS "
|
300
|
+
f"MODEL {self.fully_qualified_model_name(model_name)} VERSION {version_name.identifier()}"
|
301
|
+
)
|
302
|
+
|
303
|
+
partition_by = partition_column.identifier() if partition_column is not None else "1"
|
304
|
+
|
305
|
+
args_sql_list = []
|
306
|
+
for input_arg_value in input_args:
|
307
|
+
args_sql_list.append(input_arg_value)
|
308
|
+
|
309
|
+
args_sql = ", ".join(args_sql_list)
|
310
|
+
|
311
|
+
sql = textwrap.dedent(
|
312
|
+
f"""WITH {','.join(with_statements)}
|
313
|
+
SELECT *,
|
314
|
+
FROM {INTERMEDIATE_TABLE_NAME},
|
315
|
+
TABLE({module_version_alias}!{method_name.identifier()}({args_sql})
|
316
|
+
OVER (PARTITION BY {partition_by}))"""
|
317
|
+
)
|
318
|
+
|
319
|
+
output_df = self._session.sql(sql)
|
320
|
+
|
321
|
+
# Prepare the output
|
322
|
+
output_cols = []
|
323
|
+
output_names = []
|
324
|
+
|
325
|
+
for output_name, output_type, output_col_name in returns:
|
326
|
+
output_cols.append(F.col(output_name).astype(output_type))
|
327
|
+
output_names.append(output_col_name)
|
328
|
+
|
329
|
+
if partition_column is not None:
|
330
|
+
output_cols.append(F.col(partition_column.identifier()))
|
331
|
+
output_names.append(partition_column)
|
332
|
+
|
333
|
+
output_df = output_df.with_columns(
|
334
|
+
col_names=output_names,
|
335
|
+
values=output_cols,
|
336
|
+
)
|
337
|
+
|
338
|
+
if statement_params:
|
339
|
+
output_df._statement_params = statement_params # type: ignore[assignment]
|
340
|
+
|
341
|
+
return output_df
|
342
|
+
|
235
343
|
def set_metadata(
|
236
344
|
self,
|
237
345
|
metadata_dict: Dict[str, Any],
|
@@ -37,6 +37,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
37
37
|
session: snowpark.Session,
|
38
38
|
artifact_stage_location: str,
|
39
39
|
compute_pool: str,
|
40
|
+
job_name: str,
|
40
41
|
external_access_integrations: List[str],
|
41
42
|
) -> None:
|
42
43
|
"""Initialization
|
@@ -49,6 +50,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
49
50
|
artifact_stage_location: Spec file and future deployment related artifacts will be stored under
|
50
51
|
{stage}/models/{model_id}
|
51
52
|
compute_pool: The compute pool used to run docker image build workload.
|
53
|
+
job_name: job_name to use.
|
52
54
|
external_access_integrations: EAIs for network connection.
|
53
55
|
"""
|
54
56
|
self.context_dir = context_dir
|
@@ -58,6 +60,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
58
60
|
self.artifact_stage_location = artifact_stage_location
|
59
61
|
self.compute_pool = compute_pool
|
60
62
|
self.external_access_integrations = external_access_integrations
|
63
|
+
self.job_name = job_name
|
61
64
|
self.client = snowservice_client.SnowServiceClient(session)
|
62
65
|
|
63
66
|
assert artifact_stage_location.startswith(
|
@@ -203,8 +206,9 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
203
206
|
)
|
204
207
|
|
205
208
|
def _launch_kaniko_job(self, spec_stage_location: str) -> None:
|
206
|
-
logger.debug("Submitting job for building docker image with kaniko")
|
209
|
+
logger.debug(f"Submitting job {self.job_name} for building docker image with kaniko")
|
207
210
|
self.client.create_job(
|
211
|
+
job_name=self.job_name,
|
208
212
|
compute_pool=self.compute_pool,
|
209
213
|
spec_stage_location=spec_stage_location,
|
210
214
|
external_access_integrations=self.external_access_integrations,
|
@@ -30,6 +30,7 @@ USER mambauser
|
|
30
30
|
|
31
31
|
# Set MAMBA_DOCKERFILE_ACTIVATE=1 to activate the conda environment during build time.
|
32
32
|
ARG MAMBA_DOCKERFILE_ACTIVATE=1
|
33
|
+
ARG MAMBA_NO_LOW_SPEED_LIMIT=1
|
33
34
|
|
34
35
|
# Bitsandbytes uses this ENVVAR to determine CUDA library location
|
35
36
|
ENV CONDA_PREFIX=/opt/conda
|