snowflake-ml-python 1.0.2__py3-none-any.whl → 1.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/env_utils.py +2 -1
- snowflake/ml/_internal/file_utils.py +29 -7
- snowflake/ml/_internal/telemetry.py +5 -8
- snowflake/ml/_internal/utils/uri.py +7 -2
- snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +15 -0
- snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +259 -0
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +89 -0
- snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +24 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +118 -0
- snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +40 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +199 -0
- snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +88 -0
- snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +24 -0
- snowflake/ml/model/_deploy_client/utils/constants.py +47 -0
- snowflake/ml/model/_deploy_client/utils/snowservice_client.py +178 -0
- snowflake/ml/model/_deploy_client/warehouse/deploy.py +24 -6
- snowflake/ml/model/_deploy_client/warehouse/infer_template.py +5 -2
- snowflake/ml/model/_deployer.py +14 -27
- snowflake/ml/model/_env.py +4 -4
- snowflake/ml/model/_handlers/custom.py +14 -2
- snowflake/ml/model/_handlers/pytorch.py +186 -0
- snowflake/ml/model/_handlers/sklearn.py +14 -9
- snowflake/ml/model/_handlers/snowmlmodel.py +14 -9
- snowflake/ml/model/_handlers/torchscript.py +180 -0
- snowflake/ml/model/_handlers/xgboost.py +19 -9
- snowflake/ml/model/_model.py +3 -2
- snowflake/ml/model/_model_meta.py +12 -7
- snowflake/ml/model/model_signature.py +446 -66
- snowflake/ml/model/type_hints.py +23 -4
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -26
- snowflake/ml/modeling/cluster/affinity_propagation.py +51 -26
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -26
- snowflake/ml/modeling/cluster/birch.py +51 -26
- snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -26
- snowflake/ml/modeling/cluster/dbscan.py +51 -26
- snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -26
- snowflake/ml/modeling/cluster/k_means.py +51 -26
- snowflake/ml/modeling/cluster/mean_shift.py +51 -26
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -26
- snowflake/ml/modeling/cluster/optics.py +51 -26
- snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -26
- snowflake/ml/modeling/cluster/spectral_clustering.py +51 -26
- snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -26
- snowflake/ml/modeling/compose/column_transformer.py +51 -26
- snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -26
- snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -26
- snowflake/ml/modeling/covariance/empirical_covariance.py +51 -26
- snowflake/ml/modeling/covariance/graphical_lasso.py +51 -26
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -26
- snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -26
- snowflake/ml/modeling/covariance/min_cov_det.py +51 -26
- snowflake/ml/modeling/covariance/oas.py +51 -26
- snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -26
- snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -26
- snowflake/ml/modeling/decomposition/factor_analysis.py +51 -26
- snowflake/ml/modeling/decomposition/fast_ica.py +51 -26
- snowflake/ml/modeling/decomposition/incremental_pca.py +51 -26
- snowflake/ml/modeling/decomposition/kernel_pca.py +51 -26
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -26
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -26
- snowflake/ml/modeling/decomposition/pca.py +51 -26
- snowflake/ml/modeling/decomposition/sparse_pca.py +51 -26
- snowflake/ml/modeling/decomposition/truncated_svd.py +51 -26
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -26
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -26
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/isolation_forest.py +51 -26
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -26
- snowflake/ml/modeling/ensemble/voting_classifier.py +51 -26
- snowflake/ml/modeling/ensemble/voting_regressor.py +51 -26
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fdr.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fpr.py +51 -26
- snowflake/ml/modeling/feature_selection/select_fwe.py +51 -26
- snowflake/ml/modeling/feature_selection/select_k_best.py +51 -26
- snowflake/ml/modeling/feature_selection/select_percentile.py +51 -26
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -26
- snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -26
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -26
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -26
- snowflake/ml/modeling/impute/iterative_imputer.py +51 -26
- snowflake/ml/modeling/impute/knn_imputer.py +51 -26
- snowflake/ml/modeling/impute/missing_indicator.py +51 -26
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -26
- snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -26
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -26
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -26
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -26
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -26
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -26
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ard_regression.py +51 -26
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -26
- snowflake/ml/modeling/linear_model/elastic_net.py +51 -26
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -26
- snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/huber_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/lars.py +51 -26
- snowflake/ml/modeling/linear_model/lars_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -26
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -26
- snowflake/ml/modeling/linear_model/linear_regression.py +51 -26
- snowflake/ml/modeling/linear_model/logistic_regression.py +51 -26
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -26
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -26
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -26
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/perceptron.py +51 -26
- snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/ridge.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -26
- snowflake/ml/modeling/linear_model/ridge_cv.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -26
- snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -26
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -26
- snowflake/ml/modeling/manifold/isomap.py +51 -26
- snowflake/ml/modeling/manifold/mds.py +51 -26
- snowflake/ml/modeling/manifold/spectral_embedding.py +51 -26
- snowflake/ml/modeling/manifold/tsne.py +51 -26
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -26
- snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -26
- snowflake/ml/modeling/model_selection/grid_search_cv.py +51 -26
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +51 -26
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -26
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -26
- snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -26
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -26
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -26
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -26
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -26
- snowflake/ml/modeling/neighbors/kernel_density.py +51 -26
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -26
- snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -26
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -26
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -26
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -26
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -26
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -26
- snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -26
- snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -26
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +2 -0
- snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -26
- snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -26
- snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -26
- snowflake/ml/modeling/svm/linear_svc.py +51 -26
- snowflake/ml/modeling/svm/linear_svr.py +51 -26
- snowflake/ml/modeling/svm/nu_svc.py +51 -26
- snowflake/ml/modeling/svm/nu_svr.py +51 -26
- snowflake/ml/modeling/svm/svc.py +51 -26
- snowflake/ml/modeling/svm/svr.py +51 -26
- snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -26
- snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -26
- snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -26
- snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -26
- snowflake/ml/modeling/xgboost/xgb_classifier.py +51 -26
- snowflake/ml/modeling/xgboost/xgb_regressor.py +51 -26
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +51 -26
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +51 -26
- snowflake/ml/registry/model_registry.py +74 -56
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/METADATA +27 -8
- snowflake_ml_python-1.0.3.dist-info/RECORD +259 -0
- snowflake_ml_python-1.0.2.dist-info/RECORD +0 -246
- {snowflake_ml_python-1.0.2.dist-info → snowflake_ml_python-1.0.3.dist-info}/WHEEL +0 -0
@@ -310,7 +310,8 @@ def validate_requirements_in_snowflake_conda_channel(
|
|
310
310
|
FROM information_schema.packages
|
311
311
|
WHERE ({pkg_names_str})
|
312
312
|
AND language = 'python'
|
313
|
-
AND runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}'
|
313
|
+
AND (runtime_version = '{parsed_python_version.major}.{parsed_python_version.minor}'
|
314
|
+
OR runtime_version is null);
|
314
315
|
"""
|
315
316
|
)
|
316
317
|
else:
|
@@ -59,10 +59,12 @@ def zip_file_or_directory_to_stream(
|
|
59
59
|
Raises:
|
60
60
|
FileNotFoundError: Raised when the given path does not exist.
|
61
61
|
ValueError: Raised when the leading path is not an actual leading path of path
|
62
|
+
ValueError: Raised when the arcname cannot be encoded using ASCII.
|
62
63
|
|
63
64
|
Yields:
|
64
65
|
A bytes IO stream containing the zip file.
|
65
66
|
"""
|
67
|
+
# TODO(SNOW-862576): Should remove check on ASCII encoding after SNOW-862576 is fixed.
|
66
68
|
if not os.path.exists(path):
|
67
69
|
raise FileNotFoundError(f"{path} is not found")
|
68
70
|
if leading_path and not path.startswith(leading_path):
|
@@ -76,23 +78,35 @@ def zip_file_or_directory_to_stream(
|
|
76
78
|
if os.path.realpath(path) != os.path.realpath(start_path):
|
77
79
|
cur_path = os.path.dirname(path)
|
78
80
|
while os.path.realpath(cur_path) != os.path.realpath(start_path):
|
79
|
-
|
81
|
+
arcname = os.path.relpath(cur_path, start_path)
|
82
|
+
if not _able_ascii_encode(arcname):
|
83
|
+
raise ValueError(f"File name {arcname} cannot be encoded using ASCII. Please rename.")
|
84
|
+
zf.write(cur_path, arcname)
|
80
85
|
cur_path = os.path.dirname(cur_path)
|
81
86
|
|
82
87
|
if os.path.isdir(path):
|
83
|
-
for
|
88
|
+
for dirpath, _, files in os.walk(path):
|
84
89
|
# ignore __pycache__
|
85
|
-
if ignore_generated_py_file and "__pycache__" in
|
90
|
+
if ignore_generated_py_file and "__pycache__" in dirpath:
|
86
91
|
continue
|
87
|
-
|
92
|
+
arcname = os.path.relpath(dirpath, start_path)
|
93
|
+
if not _able_ascii_encode(arcname):
|
94
|
+
raise ValueError(f"File name {arcname} cannot be encoded using ASCII. Please rename.")
|
95
|
+
zf.write(dirpath, arcname)
|
88
96
|
for file in files:
|
89
97
|
# ignore generated python files
|
90
98
|
if ignore_generated_py_file and file.endswith(GENERATED_PY_FILE_EXT):
|
91
99
|
continue
|
92
|
-
|
93
|
-
|
100
|
+
file_path = os.path.join(dirpath, file)
|
101
|
+
arcname = os.path.relpath(file_path, start_path)
|
102
|
+
if not _able_ascii_encode(arcname):
|
103
|
+
raise ValueError(f"File name {arcname} cannot be encoded using ASCII. Please rename.")
|
104
|
+
zf.write(file_path, arcname)
|
94
105
|
else:
|
95
|
-
|
106
|
+
arcname = os.path.relpath(path, start_path)
|
107
|
+
if not _able_ascii_encode(arcname):
|
108
|
+
raise ValueError(f"File name {arcname} cannot be encoded using ASCII. Please rename.")
|
109
|
+
zf.write(path, arcname)
|
96
110
|
|
97
111
|
yield input_stream
|
98
112
|
|
@@ -145,3 +159,11 @@ def get_all_modules(dirname: str, prefix: str = "") -> List[pkgutil.ModuleInfo]:
|
|
145
159
|
for dirname in subdirs:
|
146
160
|
modules.extend(get_all_modules(dirname, prefix=f"{prefix}.{dirname}" if prefix else dirname))
|
147
161
|
return modules
|
162
|
+
|
163
|
+
|
164
|
+
def _able_ascii_encode(s: str) -> bool:
|
165
|
+
try:
|
166
|
+
s.encode("ascii", errors="strict")
|
167
|
+
return True
|
168
|
+
except UnicodeEncodeError:
|
169
|
+
return False
|
@@ -6,7 +6,6 @@ import enum
|
|
6
6
|
import functools
|
7
7
|
import inspect
|
8
8
|
import operator
|
9
|
-
import threading
|
10
9
|
import types
|
11
10
|
from typing import (
|
12
11
|
Any,
|
@@ -29,7 +28,6 @@ from snowflake.ml._internal import env
|
|
29
28
|
from snowflake.snowpark import dataframe, exceptions, session
|
30
29
|
from snowflake.snowpark._internal import utils
|
31
30
|
|
32
|
-
_rlock = threading.RLock()
|
33
31
|
_log_counter = 0
|
34
32
|
_FLUSH_SIZE = 10
|
35
33
|
|
@@ -308,12 +306,11 @@ def send_api_usage_telemetry(
|
|
308
306
|
return res
|
309
307
|
finally:
|
310
308
|
telemetry.send_function_usage_telemetry(**telemetry_args)
|
311
|
-
|
312
|
-
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
_log_counter = 0
|
309
|
+
global _log_counter
|
310
|
+
_log_counter += 1
|
311
|
+
if _log_counter >= _FLUSH_SIZE or "error" in telemetry_args:
|
312
|
+
telemetry.send_batch()
|
313
|
+
_log_counter = 0
|
317
314
|
|
318
315
|
return cast(Callable[_Args, _ReturnValue], wrap)
|
319
316
|
|
@@ -1,4 +1,4 @@
|
|
1
|
-
import
|
1
|
+
import posixpath
|
2
2
|
from typing import Optional
|
3
3
|
from urllib.parse import ParseResult, urlparse, urlunparse
|
4
4
|
|
@@ -35,7 +35,12 @@ def get_snowflake_stage_path_from_uri(uri: str) -> Optional[str]:
|
|
35
35
|
if not is_snowflake_stage_uri(uri):
|
36
36
|
return None
|
37
37
|
uri_components = urlparse(uri)
|
38
|
-
|
38
|
+
# posixpath.join discards all preceding components if any argument is an absolute path.
|
39
|
+
# The path we get is actually absolute (starting with '/'), however, since we concat them to stage location,
|
40
|
+
# it should not.
|
41
|
+
return posixpath.normpath(
|
42
|
+
posixpath.join(posixpath.normpath(uri_components.netloc), posixpath.normpath(uri_components.path.lstrip("/")))
|
43
|
+
)
|
39
44
|
|
40
45
|
|
41
46
|
def get_uri_scheme(uri: str) -> str:
|
@@ -0,0 +1,15 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
|
3
|
+
|
4
|
+
class ImageBuilder(ABC):
|
5
|
+
"""
|
6
|
+
Abstract class encapsulating image building and upload to model registry.
|
7
|
+
"""
|
8
|
+
|
9
|
+
@abstractmethod
|
10
|
+
def build_and_upload_image(self) -> str:
|
11
|
+
"""Builds and uploads an image to the model registry.
|
12
|
+
|
13
|
+
Returns: Full image path.
|
14
|
+
"""
|
15
|
+
pass
|
@@ -0,0 +1,259 @@
|
|
1
|
+
import base64
|
2
|
+
import json
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import posixpath
|
6
|
+
import subprocess
|
7
|
+
import tempfile
|
8
|
+
import zipfile
|
9
|
+
from enum import Enum
|
10
|
+
from typing import List
|
11
|
+
|
12
|
+
import yaml
|
13
|
+
|
14
|
+
from snowflake import snowpark
|
15
|
+
from snowflake.ml._internal.utils import query_result_checker
|
16
|
+
from snowflake.ml.model._deploy_client.image_builds import (
|
17
|
+
base_image_builder,
|
18
|
+
docker_context,
|
19
|
+
)
|
20
|
+
from snowflake.ml.model._deploy_client.utils import constants
|
21
|
+
|
22
|
+
|
23
|
+
class Platform(Enum):
|
24
|
+
LINUX_AMD64 = "linux/amd64"
|
25
|
+
|
26
|
+
|
27
|
+
class ClientImageBuilder(base_image_builder.ImageBuilder):
|
28
|
+
"""
|
29
|
+
Client-side image building and upload to model registry.
|
30
|
+
|
31
|
+
Usage requirements:
|
32
|
+
Requires prior installation and running of Docker with BuildKit. See installation instructions in
|
33
|
+
https://docs.docker.com/engine/install/
|
34
|
+
|
35
|
+
|
36
|
+
"""
|
37
|
+
|
38
|
+
def __init__(
|
39
|
+
self, *, id: str, image_repo: str, model_zip_stage_path: str, session: snowpark.Session, use_gpu: bool = False
|
40
|
+
) -> None:
|
41
|
+
"""Initialization
|
42
|
+
|
43
|
+
Args:
|
44
|
+
id: A hexadecimal string used for naming the image tag.
|
45
|
+
image_repo: Path to image repository.
|
46
|
+
model_zip_stage_path: Path to model zip file in stage.
|
47
|
+
use_gpu: Boolean flag for generating the CPU or GPU base image.
|
48
|
+
session: Snowpark session
|
49
|
+
"""
|
50
|
+
self.image_tag = "/".join([image_repo.rstrip("/"), id]) + ":latest"
|
51
|
+
self.image_repo = image_repo
|
52
|
+
self.model_zip_stage_path = model_zip_stage_path
|
53
|
+
self.use_gpu = use_gpu
|
54
|
+
self.session = session
|
55
|
+
|
56
|
+
def build_and_upload_image(self) -> str:
|
57
|
+
"""
|
58
|
+
Builds and uploads an image to the model registry.
|
59
|
+
"""
|
60
|
+
|
61
|
+
def _setup_docker_config(docker_config_dir: str) -> None:
|
62
|
+
"""Set up a temporary docker config, which is used for running all docker commands.
|
63
|
+
|
64
|
+
Args:
|
65
|
+
docker_config_dir: Path to docker configuration directory, which stores the temporary session token.
|
66
|
+
"""
|
67
|
+
ctx = self.session._conn._conn
|
68
|
+
assert ctx._rest, "SnowflakeRestful is not set in session"
|
69
|
+
token_data = ctx._rest._token_request("ISSUE")
|
70
|
+
snowpark_session_token = token_data["data"]["sessionToken"]
|
71
|
+
token_obj = {"token": snowpark_session_token}
|
72
|
+
credentials = f"0sessiontoken:{json.dumps(token_obj)}"
|
73
|
+
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
|
74
|
+
content = {"auths": {self.image_tag: {"auth": encoded_credentials}}}
|
75
|
+
config_path = os.path.join(docker_config_dir, "config.json")
|
76
|
+
with open(config_path, "w", encoding="utf-8") as file:
|
77
|
+
json.dump(content, file)
|
78
|
+
|
79
|
+
self.validate_docker_client_env()
|
80
|
+
|
81
|
+
query_result = (
|
82
|
+
query_result_checker.SqlResultValidator(
|
83
|
+
self.session,
|
84
|
+
query="SHOW PARAMETERS LIKE 'PYTHON_CONNECTOR_QUERY_RESULT_FORMAT' IN SESSION",
|
85
|
+
)
|
86
|
+
.has_dimensions(expected_rows=1)
|
87
|
+
.validate()
|
88
|
+
)
|
89
|
+
prev_format = query_result[0].value
|
90
|
+
|
91
|
+
with tempfile.TemporaryDirectory() as config_dir:
|
92
|
+
try:
|
93
|
+
# Workaround for SNOW-841699: Fail to authenticate to image registry with session token generated from
|
94
|
+
# Snowpark. Need to temporarily set the json query format in order to process GS token response.
|
95
|
+
self.session.sql("ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = 'json'").collect()
|
96
|
+
_setup_docker_config(config_dir)
|
97
|
+
self._build(config_dir)
|
98
|
+
self._upload(config_dir)
|
99
|
+
finally:
|
100
|
+
self.session.sql(f"ALTER SESSION SET PYTHON_CONNECTOR_QUERY_RESULT_FORMAT = '{prev_format}'").collect()
|
101
|
+
commands = ["docker", "--config", config_dir, "rmi", self.image_tag]
|
102
|
+
logging.info(f"Removing local image: {self.image_tag}")
|
103
|
+
self._run_docker_commands(commands)
|
104
|
+
return self.image_tag
|
105
|
+
|
106
|
+
def validate_docker_client_env(self) -> None:
|
107
|
+
"""Ensure docker client is running and BuildKit is enabled. Note that Buildx always uses BuildKit.
|
108
|
+
- Ensure docker daemon is running through the "docker info" command on shell. When docker daemon is running,
|
109
|
+
return code will be 0, else return code will be 1.
|
110
|
+
- Ensure BuildKit is enabled by checking "docker buildx version".
|
111
|
+
|
112
|
+
Raises:
|
113
|
+
ConnectionError: Occurs when Docker is not installed or is not running.
|
114
|
+
|
115
|
+
"""
|
116
|
+
info_command = "docker info"
|
117
|
+
buildx_command = "docker buildx version"
|
118
|
+
|
119
|
+
try:
|
120
|
+
subprocess.check_call(info_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
|
121
|
+
except subprocess.CalledProcessError:
|
122
|
+
raise ConnectionError("Failed to initialize Docker client. Please ensure Docker is installed and running.")
|
123
|
+
|
124
|
+
try:
|
125
|
+
subprocess.check_call(buildx_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
|
126
|
+
except subprocess.CalledProcessError:
|
127
|
+
raise ConnectionError(
|
128
|
+
"Please ensured Docker is installed with BuildKit by following "
|
129
|
+
"https://docs.docker.com/build/buildkit/#getting-started"
|
130
|
+
)
|
131
|
+
|
132
|
+
def _extract_model_zip(self, context_dir: str) -> str:
|
133
|
+
"""Extract a zip file into the specified directory.
|
134
|
+
|
135
|
+
Args:
|
136
|
+
context_dir: Directory to extract the zip to.
|
137
|
+
|
138
|
+
Returns:
|
139
|
+
The extracted model directory.
|
140
|
+
"""
|
141
|
+
|
142
|
+
local_model_zip_path = os.path.join(context_dir, posixpath.basename(self.model_zip_stage_path))
|
143
|
+
if zipfile.is_zipfile(local_model_zip_path):
|
144
|
+
extracted_model_dir = os.path.join(context_dir, constants.MODEL_DIR)
|
145
|
+
with zipfile.ZipFile(local_model_zip_path, "r") as model_zip:
|
146
|
+
if len(model_zip.namelist()) > 1:
|
147
|
+
model_zip.extractall(extracted_model_dir)
|
148
|
+
conda_path = os.path.join(extracted_model_dir, "env", "conda.yaml")
|
149
|
+
|
150
|
+
def remove_snowml_from_conda() -> None:
|
151
|
+
with open(conda_path, encoding="utf-8") as file:
|
152
|
+
conda_yaml = yaml.safe_load(file)
|
153
|
+
|
154
|
+
dependencies = conda_yaml["dependencies"]
|
155
|
+
dependencies = [dep for dep in dependencies if not dep.startswith("snowflake-ml-python")]
|
156
|
+
|
157
|
+
conda_yaml["dependencies"] = dependencies
|
158
|
+
|
159
|
+
with open(conda_path, "w", encoding="utf-8") as file:
|
160
|
+
yaml.dump(conda_yaml, file)
|
161
|
+
|
162
|
+
# TODO(shchen): Remove once SNOW-840411 is landed.
|
163
|
+
remove_snowml_from_conda()
|
164
|
+
return extracted_model_dir
|
165
|
+
|
166
|
+
def _build(self, docker_config_dir: str) -> None:
|
167
|
+
"""Constructs the Docker context directory and then builds a Docker image based on that context.
|
168
|
+
|
169
|
+
Args:
|
170
|
+
docker_config_dir: Path to docker configuration directory, which stores the temporary session token.
|
171
|
+
"""
|
172
|
+
|
173
|
+
with tempfile.TemporaryDirectory() as context_dir:
|
174
|
+
# Download the model zip file that is already uploaded to stage during model registry log_model step.
|
175
|
+
# This is needed in order to obtain the conda and requirement file inside the model zip.
|
176
|
+
self.session.file.get(self.model_zip_stage_path, context_dir)
|
177
|
+
|
178
|
+
extracted_model_dir = self._extract_model_zip(context_dir)
|
179
|
+
|
180
|
+
dc = docker_context.DockerContext(
|
181
|
+
context_dir=context_dir, model_dir=extracted_model_dir, use_gpu=self.use_gpu
|
182
|
+
)
|
183
|
+
dc.build()
|
184
|
+
self._build_image_from_context(context_dir=context_dir, docker_config_dir=docker_config_dir)
|
185
|
+
|
186
|
+
def _run_docker_commands(self, commands: List[str]) -> None:
|
187
|
+
"""Run docker commands in a new child process.
|
188
|
+
|
189
|
+
Args:
|
190
|
+
commands: List of commands to run.
|
191
|
+
|
192
|
+
Raises:
|
193
|
+
RuntimeError: Occurs when docker commands failed to execute.
|
194
|
+
"""
|
195
|
+
proc = subprocess.Popen(commands, cwd=os.getcwd(), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
|
196
|
+
output_lines = []
|
197
|
+
|
198
|
+
if proc.stdout:
|
199
|
+
for line in iter(proc.stdout.readline, ""):
|
200
|
+
output_lines.append(line)
|
201
|
+
logging.info(line)
|
202
|
+
|
203
|
+
if proc.wait():
|
204
|
+
raise RuntimeError(f"Docker build failed: {''.join(output_lines)}")
|
205
|
+
|
206
|
+
def _build_image_from_context(
|
207
|
+
self, context_dir: str, docker_config_dir: str, *, platform: Platform = Platform.LINUX_AMD64
|
208
|
+
) -> None:
|
209
|
+
"""Builds a Docker image based on provided context.
|
210
|
+
|
211
|
+
Args:
|
212
|
+
context_dir: Path to context directory.
|
213
|
+
docker_config_dir: Path to docker configuration directory, which stores the temporary session token.
|
214
|
+
platform: Target platform for the build output, in the format "os[/arch[/variant]]".
|
215
|
+
"""
|
216
|
+
|
217
|
+
commands = [
|
218
|
+
"docker",
|
219
|
+
"--config",
|
220
|
+
docker_config_dir,
|
221
|
+
"buildx",
|
222
|
+
"build",
|
223
|
+
"--platform",
|
224
|
+
platform.value,
|
225
|
+
"--tag",
|
226
|
+
f"{self.image_tag}",
|
227
|
+
context_dir,
|
228
|
+
]
|
229
|
+
|
230
|
+
self._run_docker_commands(commands)
|
231
|
+
|
232
|
+
def _upload(self, docker_config_dir: str) -> None:
|
233
|
+
"""
|
234
|
+
Uploads image to the image registry. This process requires a "docker login" followed by a "docker push". Remove
|
235
|
+
local image at the end of the upload operation to save up local space. Image cache is kept for more performant
|
236
|
+
build experience at the cost of a small storage footprint.
|
237
|
+
|
238
|
+
For image registry authentication, we will use a session token obtained from the Snowpark session object.
|
239
|
+
The token authentication mechanism is automatically used when the username is set to "0sessiontoken" according
|
240
|
+
to the registry implementation detailed in the following link:
|
241
|
+
https://github.com/snowflakedb/snowflake-image-registry/blob/277435c6fd79db2df9f863aa9d04dc875e034d85
|
242
|
+
/AuthAdapter/src/main/java/com/snowflake/registry/service/AuthHeader.java#L122
|
243
|
+
|
244
|
+
By default, Docker overwrites the local Docker config file "/.docker/config.json" whenever a docker login
|
245
|
+
occurs. However, to ensure better isolation between Snowflake-managed Docker credentials and the user's own
|
246
|
+
Docker credentials, we will not use the default Docker config. Instead, we will write the username and session
|
247
|
+
token to a temporary file and use "docker --config" so that it only applies to the specific Docker command being
|
248
|
+
executed, without affecting the user's local Docker setup. The credential file will be automatically removed
|
249
|
+
at the end of upload operation.
|
250
|
+
|
251
|
+
Args:
|
252
|
+
docker_config_dir: Path to docker configuration directory, which stores the temporary session token.
|
253
|
+
"""
|
254
|
+
commands = ["docker", "--config", docker_config_dir, "login", self.image_tag]
|
255
|
+
self._run_docker_commands(commands)
|
256
|
+
|
257
|
+
logging.info(f"Pushing image to image repo {self.image_tag}")
|
258
|
+
commands = ["docker", "--config", docker_config_dir, "push", self.image_tag]
|
259
|
+
self._run_docker_commands(commands)
|
@@ -0,0 +1,89 @@
|
|
1
|
+
import importlib
|
2
|
+
import os
|
3
|
+
import shutil
|
4
|
+
import string
|
5
|
+
from abc import ABC
|
6
|
+
|
7
|
+
from snowflake.ml.model._deploy_client.utils import constants
|
8
|
+
|
9
|
+
|
10
|
+
class DockerContext(ABC):
|
11
|
+
"""
|
12
|
+
Constructs the Docker context directory required for image building.
|
13
|
+
"""
|
14
|
+
|
15
|
+
def __init__(self, context_dir: str, model_dir: str, *, use_gpu: bool = False) -> None:
|
16
|
+
"""Initialization
|
17
|
+
|
18
|
+
Args:
|
19
|
+
context_dir: Path to context directory.
|
20
|
+
model_dir: Path to local model directory.
|
21
|
+
use_gpu: Boolean flag for generating the CPU or GPU base image.
|
22
|
+
"""
|
23
|
+
self.context_dir = context_dir
|
24
|
+
self.model_dir = model_dir
|
25
|
+
# TODO(shchen): SNOW-825995, Define dockerfile template used for model deployment. use_gpu will be used.
|
26
|
+
self.use_gpu = use_gpu
|
27
|
+
|
28
|
+
def build(self) -> None:
|
29
|
+
"""
|
30
|
+
Generates and/or moves resources into the Docker context directory. Rename the random model directory name to
|
31
|
+
constant "model_dir" instead for better readability.
|
32
|
+
"""
|
33
|
+
self._generate_inference_code()
|
34
|
+
self._copy_entrypoint_script_to_docker_context()
|
35
|
+
self._copy_snowml_source_code_to_docker_context()
|
36
|
+
self._generate_docker_file()
|
37
|
+
|
38
|
+
def _copy_snowml_source_code_to_docker_context(self) -> None:
|
39
|
+
"""Copy the entire snowflake/ml source code to docker context. This will be particularly useful for CI tests
|
40
|
+
against latest changes.
|
41
|
+
|
42
|
+
Note that we exclude the experimental directory mainly for development scenarios, as the experimental directory won't
|
43
|
+
be included in the release.
|
44
|
+
"""
|
45
|
+
snow_ml_source_dir = list(importlib.import_module("snowflake.ml").__path__)[0]
|
46
|
+
shutil.copytree(
|
47
|
+
snow_ml_source_dir,
|
48
|
+
os.path.join(self.context_dir, "snowflake", "ml"),
|
49
|
+
ignore=shutil.ignore_patterns("*.pyc", "experimental"),
|
50
|
+
)
|
51
|
+
|
52
|
+
def _copy_entrypoint_script_to_docker_context(self) -> None:
|
53
|
+
"""Copy gunicorn_run.sh entrypoint to docker context directory."""
|
54
|
+
path = os.path.join(os.path.dirname(__file__), constants.ENTRYPOINT_SCRIPT)
|
55
|
+
assert os.path.exists(path), f"Run script file missing at path: {path}"
|
56
|
+
shutil.copy(path, os.path.join(self.context_dir, constants.ENTRYPOINT_SCRIPT))
|
57
|
+
|
58
|
+
def _generate_docker_file(self) -> None:
|
59
|
+
"""
|
60
|
+
Generates dockerfile based on dockerfile template.
|
61
|
+
"""
|
62
|
+
docker_file_path = os.path.join(self.context_dir, "Dockerfile")
|
63
|
+
docker_file_template = os.path.join(os.path.dirname(__file__), "templates/dockerfile_template")
|
64
|
+
|
65
|
+
with open(docker_file_path, "w", encoding="utf-8") as dockerfile, open(
|
66
|
+
docker_file_template, encoding="utf-8"
|
67
|
+
) as template:
|
68
|
+
dockerfile_content = string.Template(template.read()).safe_substitute(
|
69
|
+
{
|
70
|
+
# TODO(shchen): SNOW-835411, Support overwriting base image
|
71
|
+
"base_image": "mambaorg/micromamba:focal-cuda-11.7.1"
|
72
|
+
if self.use_gpu
|
73
|
+
else "mambaorg/micromamba:1.4.3",
|
74
|
+
"model_dir": constants.MODEL_DIR,
|
75
|
+
"inference_server_dir": constants.INFERENCE_SERVER_DIR,
|
76
|
+
"entrypoint_script": constants.ENTRYPOINT_SCRIPT,
|
77
|
+
}
|
78
|
+
)
|
79
|
+
dockerfile.write(dockerfile_content)
|
80
|
+
|
81
|
+
def _generate_inference_code(self) -> None:
|
82
|
+
"""
|
83
|
+
Generates inference code based on the app template and creates a folder named 'server' to house the inference
|
84
|
+
server code.
|
85
|
+
"""
|
86
|
+
inference_server_folder_path = os.path.join(os.path.dirname(__file__), constants.INFERENCE_SERVER_DIR)
|
87
|
+
destination_folder_path = os.path.join(self.context_dir, constants.INFERENCE_SERVER_DIR)
|
88
|
+
ignore_patterns = shutil.ignore_patterns("BUILD.bazel", "*test.py", "*.\\.*", "__pycache__")
|
89
|
+
shutil.copytree(inference_server_folder_path, destination_folder_path, ignore=ignore_patterns)
|
@@ -0,0 +1,24 @@
|
|
1
|
+
#!/bin/sh
# Entrypoint: work out the CPU core count for the current OS, derive the Gunicorn worker count from it,
# then replace this shell with the Gunicorn server process.
set -eu

OS=$(uname)

# Pick the core-count command that matches the detected operating system.
case "${OS}" in
    Linux)
        NUM_CORES=$(nproc)
        ;;
    Darwin)
        # macOS
        NUM_CORES=$(sysctl -n hw.ncpu)
        ;;
    Windows)
        NUM_CORES=$(wmic cpu get NumberOfCores | grep -Eo '[0-9]+')
        ;;
    *)
        echo "Unsupported operating system: ${OS}"
        exit 1
        ;;
esac

# Worker count follows the Gunicorn documentation's rule of thumb: (2 x cores) + 1. The idea is that each core
# serves two processes at once -- one busy with IO while the other computes -- plus one extra worker.
NUM_WORKERS=$((2 * NUM_CORES + 1))
echo "Number of CPU cores: $NUM_CORES"
echo "Setting number of workers to $NUM_WORKERS"
exec /opt/conda/bin/gunicorn --preload -w "$NUM_WORKERS" -k uvicorn.workers.UvicornWorker -b 0.0.0.0:5000 inference_server.main:app
|
@@ -0,0 +1,118 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
import tempfile
|
4
|
+
import zipfile
|
5
|
+
|
6
|
+
import pandas as pd
|
7
|
+
from starlette import applications, requests, responses, routing
|
8
|
+
|
9
|
+
logger = logging.getLogger(__name__)
|
10
|
+
loaded_model = None
|
11
|
+
|
12
|
+
|
13
|
+
def _run_setup() -> None:
    """Set up logging and load model into memory.

    The model zip location is read from the MODEL_ZIP_STAGE_PATH environment variable and resolved against the
    filesystem root. The archive is extracted into a temporary directory, loaded, and the loaded model is stored in
    the module-level `loaded_model`.

    Raises:
        RuntimeError: If MODEL_ZIP_STAGE_PATH is unset or empty, if the resolved path is not a zip file, or if the
            zip does not contain more than one entry.
    """
    # Align the application logger's handler with Gunicorn's to capture logs from all processes.
    gunicorn_logger = logging.getLogger("gunicorn.error")
    logger.handlers = gunicorn_logger.handlers
    logger.setLevel(gunicorn_logger.level)

    # Imported lazily so that importing this module does not require snowflake.ml to be importable.
    from snowflake.ml.model import _model as model_api

    global loaded_model

    MODEL_ZIP_STAGE_PATH = os.getenv("MODEL_ZIP_STAGE_PATH")
    # Explicit check instead of `assert`, so the validation still runs under `python -O`.
    if not MODEL_ZIP_STAGE_PATH:
        raise RuntimeError("Missing environment variable MODEL_ZIP_STAGE_PATH")
    root_path = os.path.abspath(os.sep)
    model_zip_stage_path = os.path.join(root_path, MODEL_ZIP_STAGE_PATH)

    # Fail fast before creating any temporary directory when there is nothing to extract.
    if not zipfile.is_zipfile(model_zip_stage_path):
        raise RuntimeError(f"No model zip found at stage path: {model_zip_stage_path}")

    with tempfile.TemporaryDirectory() as tmp_dir:
        extracted_dir = os.path.join(tmp_dir, "extracted_model_dir")
        logger.info(f"Extracting model zip from {model_zip_stage_path} to {extracted_dir}")
        with zipfile.ZipFile(model_zip_stage_path, "r") as model_zip:
            # A zip with at most one entry is never extracted; stop here rather than trying to load a directory
            # that was never created.
            if len(model_zip.namelist()) <= 1:
                raise RuntimeError(f"No model zip found at stage path: {model_zip_stage_path}")
            model_zip.extractall(extracted_dir)
        logger.info(f"Loading model from {extracted_dir} into memory")
        loaded_model, _ = model_api._load_model_for_deploy(model_dir_path=extracted_dir)
        logger.info("Successfully loaded model into memory")
|
41
|
+
|
42
|
+
|
43
|
+
async def ready(request: requests.Request) -> responses.JSONResponse:
    """Readiness endpoint: replies with a fixed JSON body, regardless of the request contents."""
    status_payload = {"status": "ready"}
    return responses.JSONResponse(status_payload)
|
46
|
+
|
47
|
+
|
48
|
+
async def predict(request: requests.Request) -> responses.JSONResponse:
    """Endpoint to make predictions based on input data.

    Args:
        request: The input data is expected to be in the following JSON format:
            {
                "data": [
                    [0, 5.1, 3.5, 4.2, 1.3],
                    [1, 4.7, 3.2, 4.1, 4.2]
                ]
            }
            Each row is represented as a list, where the first element denotes the index of the row.

    Returns:
        Two possible responses:
        For success, return a JSON response {"data": [[0, 1], [1, 2]]}, where the first element of each resulting list
        denotes the index of the row, and the rest of the elements represent the prediction results for that row.
        For an error, return a JSON response {"error": error_message} with HTTP status code 400.

    Raises:
        RuntimeError: If no model has been loaded into memory.
    """
    try:
        payload = await request.json()
        # Explicit raises instead of `assert`, so validation is not stripped under `python -O`. The messages are
        # unchanged, so the error response bodies are identical.
        if "data" not in payload:
            raise ValueError("missing data field in the request input")
        # The expression row[1:] is used to exclude the index of the data row.
        input_data = [row[1:] for row in payload.get("data")]
        # Reject the request when there are no rows, or every row is empty once its index is removed.
        if not any(input_data):
            raise ValueError("empty data")
        x = pd.DataFrame(input_data)
    except Exception as e:
        error_message = f"Input data malformed: {str(e)}"
        return responses.JSONResponse({"error": error_message}, status_code=400)

    # Compare against None rather than relying on the model object's truthiness.
    if loaded_model is None:
        raise RuntimeError("Model is not loaded")

    try:
        # TODO(shchen): SNOW-835369, Support target method in inference server (Multi-task model).
        # Mypy ignore will be fixed along with the above ticket.
        predictions = loaded_model.predict(x)  # type: ignore[attr-defined]
        # index=True keeps the row index as the first element of every record.
        result = predictions.to_records(index=True).tolist()
        response = {"data": result}
        return responses.JSONResponse(response)
    except Exception as e:
        error_message = f"Prediction failed: {str(e)}"
        return responses.JSONResponse({"error": error_message}, status_code=400)
|
89
|
+
|
90
|
+
|
91
|
+
def _in_test_mode() -> bool:
|
92
|
+
"""Check if the code is running in test mode.
|
93
|
+
|
94
|
+
Specifically, it checks for the presence of
|
95
|
+
- "PYTEST_CURRENT_TEST" environment variable, which is automatically set by Pytest when running tests, and
|
96
|
+
- "TEST_WORKSPACE" environment variable, which is set by Bazel test, and
|
97
|
+
- "TEST_SRCDIR" environment variable, which is set by the Absl test.
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
True if in test mode; otherwise, returns False
|
101
|
+
"""
|
102
|
+
is_running_under_py_test = "PYTEST_CURRENT_TEST" in os.environ
|
103
|
+
is_running_under_bazel_test = "TEST_WORKSPACE" in os.environ
|
104
|
+
is_running_under_absl_test = "TEST_SRCDIR" in os.environ
|
105
|
+
return is_running_under_py_test or is_running_under_bazel_test or is_running_under_absl_test
|
106
|
+
|
107
|
+
|
108
|
+
def run_app() -> applications.Starlette:
    """Build the Starlette application, running the setup step first unless in test mode.

    Returns:
        A Starlette application exposing GET /health and POST /predict.
    """
    in_test = _in_test_mode()
    if not in_test:
        _run_setup()
    health_route = routing.Route("/health", endpoint=ready, methods=["GET"])
    predict_route = routing.Route("/predict", endpoint=predict, methods=["POST"])
    return applications.Starlette(routes=[health_route, predict_route])
|
116
|
+
|
117
|
+
|
118
|
+
app = run_app()
|