snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +142 -20
- snowflake/ml/_internal/utils/identifier.py +48 -11
- snowflake/ml/_internal/utils/snowflake_env.py +23 -13
- snowflake/ml/_internal/utils/sql_identifier.py +1 -1
- snowflake/ml/_internal/utils/table_manager.py +19 -1
- snowflake/ml/_internal/utils/uri.py +2 -2
- snowflake/ml/data/data_connector.py +33 -7
- snowflake/ml/data/torch_utils.py +68 -0
- snowflake/ml/dataset/dataset.py +1 -3
- snowflake/ml/feature_store/feature_store.py +41 -17
- snowflake/ml/feature_store/feature_view.py +2 -2
- snowflake/ml/fileset/embedded_stage_fs.py +1 -1
- snowflake/ml/fileset/fileset.py +1 -1
- snowflake/ml/fileset/sfcfs.py +9 -3
- snowflake/ml/model/_client/model/model_version_impl.py +22 -7
- snowflake/ml/model/_client/ops/model_ops.py +39 -3
- snowflake/ml/model/_client/ops/service_ops.py +198 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
- snowflake/ml/model/_client/sql/service.py +85 -18
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
- snowflake/ml/model/_model_composer/model_composer.py +2 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
- snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
- snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
- snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
- snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
- snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
- snowflake/ml/model/_packager/model_packager.py +2 -0
- snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
- snowflake/ml/model/_signatures/utils.py +9 -0
- snowflake/ml/model/models/llm.py +3 -1
- snowflake/ml/model/type_hints.py +9 -1
- snowflake/ml/modeling/_internal/constants.py +1 -0
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
- snowflake/ml/modeling/_internal/model_specifications.py +2 -0
- snowflake/ml/modeling/_internal/model_trainer.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
- snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
- snowflake/ml/modeling/cluster/birch.py +60 -21
- snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
- snowflake/ml/modeling/cluster/dbscan.py +60 -21
- snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
- snowflake/ml/modeling/cluster/k_means.py +60 -21
- snowflake/ml/modeling/cluster/mean_shift.py +60 -21
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
- snowflake/ml/modeling/cluster/optics.py +60 -21
- snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
- snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
- snowflake/ml/modeling/compose/column_transformer.py +60 -21
- snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
- snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
- snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
- snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
- snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
- snowflake/ml/modeling/covariance/oas.py +60 -21
- snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
- snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
- snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
- snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
- snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/pca.py +60 -21
- snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
- snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
- snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
- snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
- snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
- snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
- snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
- snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
- snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
- snowflake/ml/modeling/impute/knn_imputer.py +60 -21
- snowflake/ml/modeling/impute/missing_indicator.py +60 -21
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/lars.py +60 -21
- snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
- snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/perceptron.py +60 -21
- snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/ridge.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
- snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
- snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
- snowflake/ml/modeling/manifold/isomap.py +60 -21
- snowflake/ml/modeling/manifold/mds.py +60 -21
- snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
- snowflake/ml/modeling/manifold/tsne.py +60 -21
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
- snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
- snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
- snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
- snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
- snowflake/ml/modeling/pipeline/pipeline.py +1 -12
- snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
- snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
- snowflake/ml/modeling/svm/linear_svc.py +60 -21
- snowflake/ml/modeling/svm/linear_svr.py +60 -21
- snowflake/ml/modeling/svm/nu_svc.py +60 -21
- snowflake/ml/modeling/svm/nu_svr.py +60 -21
- snowflake/ml/modeling/svm/svc.py +60 -21
- snowflake/ml/modeling/svm/svr.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
- snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
- snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
- snowflake/ml/registry/_manager/model_manager.py +4 -0
- snowflake/ml/registry/model_registry.py +1 -1
- snowflake/ml/registry/registry.py +1 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
- snowflake/ml/data/torch_dataset.py +0 -33
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,9 @@
|
|
1
|
+
import enum
|
2
|
+
import json
|
1
3
|
import textwrap
|
2
4
|
from typing import Any, Dict, List, Optional, Tuple
|
3
5
|
|
6
|
+
from snowflake import snowpark
|
4
7
|
from snowflake.ml._internal.utils import (
|
5
8
|
identifier,
|
6
9
|
query_result_checker,
|
@@ -11,6 +14,17 @@ from snowflake.snowpark import dataframe, functions as F, types as spt
|
|
11
14
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
12
15
|
|
13
16
|
|
17
|
+
class ServiceStatus(enum.Enum):
|
18
|
+
UNKNOWN = "UNKNOWN" # status is unknown because we have not received enough data from K8s yet.
|
19
|
+
PENDING = "PENDING" # resource set is being created, can't be used yet
|
20
|
+
READY = "READY" # resource set has been deployed.
|
21
|
+
DELETING = "DELETING" # resource set is being deleted
|
22
|
+
FAILED = "FAILED" # resource set has failed and cannot be used anymore
|
23
|
+
DONE = "DONE" # resource set has finished running
|
24
|
+
NOT_FOUND = "NOT_FOUND" # not found or deleted
|
25
|
+
INTERNAL_ERROR = "INTERNAL_ERROR" # there was an internal service error.
|
26
|
+
|
27
|
+
|
14
28
|
class ServiceSQLClient(_base._BaseSQLClient):
|
15
29
|
def build_model_container(
|
16
30
|
self,
|
@@ -30,20 +44,21 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
30
44
|
) -> None:
|
31
45
|
actual_image_repo_database = image_repo_database_name or self._database_name
|
32
46
|
actual_image_repo_schema = image_repo_schema_name or self._schema_name
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
47
|
+
actual_model_database = database_name or self._database_name
|
48
|
+
actual_model_schema = schema_name or self._schema_name
|
49
|
+
fq_model_name = self.fully_qualified_object_name(actual_model_database, actual_model_schema, model_name)
|
50
|
+
fq_image_repo_name = identifier.get_schema_level_object_identifier(
|
51
|
+
actual_image_repo_database.identifier(),
|
52
|
+
actual_image_repo_schema.identifier(),
|
53
|
+
image_repo_name.identifier(),
|
40
54
|
)
|
41
|
-
|
55
|
+
is_gpu_str = "TRUE" if gpu else "FALSE"
|
56
|
+
force_rebuild_str = "TRUE" if force_rebuild else "FALSE"
|
42
57
|
query_result_checker.SqlResultValidator(
|
43
58
|
self._session,
|
44
59
|
(
|
45
60
|
f"CALL SYSTEM$BUILD_MODEL_CONTAINER('{fq_model_name}', '{version_name}', '{compute_pool_name}',"
|
46
|
-
f" '{fq_image_repo_name}', '{
|
61
|
+
f" '{fq_image_repo_name}', '{is_gpu_str}', '{force_rebuild_str}', '', '{external_access_integration}')"
|
47
62
|
),
|
48
63
|
statement_params=statement_params,
|
49
64
|
).has_dimensions(expected_rows=1, expected_cols=1).validate()
|
@@ -54,12 +69,12 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
54
69
|
stage_path: str,
|
55
70
|
model_deployment_spec_file_rel_path: str,
|
56
71
|
statement_params: Optional[Dict[str, Any]] = None,
|
57
|
-
) ->
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
72
|
+
) -> Tuple[str, snowpark.AsyncJob]:
|
73
|
+
async_job = self._session.sql(
|
74
|
+
f"CALL SYSTEM$DEPLOY_MODEL('@{stage_path}/{model_deployment_spec_file_rel_path}')"
|
75
|
+
).collect(block=False, statement_params=statement_params)
|
76
|
+
assert isinstance(async_job, snowpark.AsyncJob)
|
77
|
+
return async_job.query_id, async_job
|
63
78
|
|
64
79
|
def invoke_function_method(
|
65
80
|
self,
|
@@ -74,12 +89,20 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
74
89
|
statement_params: Optional[Dict[str, Any]] = None,
|
75
90
|
) -> dataframe.DataFrame:
|
76
91
|
with_statements = []
|
92
|
+
actual_database_name = database_name or self._database_name
|
93
|
+
actual_schema_name = schema_name or self._schema_name
|
94
|
+
|
95
|
+
function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
|
96
|
+
fully_qualified_function_name = identifier.get_schema_level_object_identifier(
|
97
|
+
actual_database_name.identifier(),
|
98
|
+
actual_schema_name.identifier(),
|
99
|
+
function_name,
|
100
|
+
)
|
101
|
+
|
77
102
|
if len(input_df.queries["queries"]) == 1 and len(input_df.queries["post_actions"]) == 0:
|
78
103
|
INTERMEDIATE_TABLE_NAME = "SNOWPARK_ML_MODEL_INFERENCE_INPUT"
|
79
104
|
with_statements.append(f"{INTERMEDIATE_TABLE_NAME} AS ({input_df.queries['queries'][0]})")
|
80
105
|
else:
|
81
|
-
actual_database_name = database_name or self._database_name
|
82
|
-
actual_schema_name = schema_name or self._schema_name
|
83
106
|
tmp_table_name = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE)
|
84
107
|
INTERMEDIATE_TABLE_NAME = identifier.get_schema_level_object_identifier(
|
85
108
|
actual_database_name.identifier(),
|
@@ -104,7 +127,7 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
104
127
|
sql = textwrap.dedent(
|
105
128
|
f"""{with_sql}
|
106
129
|
SELECT *,
|
107
|
-
{
|
130
|
+
{fully_qualified_function_name}({args_sql}) AS {INTERMEDIATE_OBJ_NAME}
|
108
131
|
FROM {INTERMEDIATE_TABLE_NAME}"""
|
109
132
|
)
|
110
133
|
|
@@ -127,3 +150,47 @@ class ServiceSQLClient(_base._BaseSQLClient):
|
|
127
150
|
output_df._statement_params = statement_params # type: ignore[assignment]
|
128
151
|
|
129
152
|
return output_df
|
153
|
+
|
154
|
+
def get_service_logs(
|
155
|
+
self,
|
156
|
+
*,
|
157
|
+
service_name: str,
|
158
|
+
instance_id: str = "0",
|
159
|
+
container_name: str,
|
160
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
161
|
+
) -> str:
|
162
|
+
system_func = "SYSTEM$GET_SERVICE_LOGS"
|
163
|
+
rows = (
|
164
|
+
query_result_checker.SqlResultValidator(
|
165
|
+
self._session,
|
166
|
+
f"CALL {system_func}('{service_name}', '{instance_id}', '{container_name}')",
|
167
|
+
statement_params=statement_params,
|
168
|
+
)
|
169
|
+
.has_dimensions(expected_rows=1, expected_cols=1)
|
170
|
+
.validate()
|
171
|
+
)
|
172
|
+
return str(rows[0][system_func])
|
173
|
+
|
174
|
+
def get_service_status(
|
175
|
+
self,
|
176
|
+
*,
|
177
|
+
service_name: str,
|
178
|
+
include_message: bool = False,
|
179
|
+
statement_params: Optional[Dict[str, Any]] = None,
|
180
|
+
) -> Tuple[ServiceStatus, Optional[str]]:
|
181
|
+
system_func = "SYSTEM$GET_SERVICE_STATUS"
|
182
|
+
rows = (
|
183
|
+
query_result_checker.SqlResultValidator(
|
184
|
+
self._session,
|
185
|
+
f"CALL {system_func}('{service_name}')",
|
186
|
+
statement_params=statement_params,
|
187
|
+
)
|
188
|
+
.has_dimensions(expected_rows=1, expected_cols=1)
|
189
|
+
.validate()
|
190
|
+
)
|
191
|
+
metadata = json.loads(rows[0][system_func])[0]
|
192
|
+
if metadata and metadata["status"]:
|
193
|
+
service_status = ServiceStatus(metadata["status"])
|
194
|
+
message = metadata["message"] if include_message else None
|
195
|
+
return service_status, message
|
196
|
+
return ServiceStatus.UNKNOWN, None
|
@@ -182,7 +182,7 @@ class ServerImageBuilder(base_image_builder.ImageBuilder):
|
|
182
182
|
with file_utils.open_file(spec_file_path, "w+") as spec_file:
|
183
183
|
assert self.artifact_stage_location.startswith("@")
|
184
184
|
normed_artifact_stage_path = posixpath.normpath(identifier.remove_prefix(self.artifact_stage_location, "@"))
|
185
|
-
(db, schema, stage, path) = identifier.
|
185
|
+
(db, schema, stage, path) = identifier.parse_snowflake_stage_path(normed_artifact_stage_path)
|
186
186
|
content = Template(spec_template).safe_substitute(
|
187
187
|
{
|
188
188
|
"base_image": base_image,
|
@@ -280,7 +280,7 @@ def _get_or_create_image_repo(session: Session, *, service_func_name: str, image
|
|
280
280
|
conn = session._conn._conn
|
281
281
|
# We try to use the same db and schema as the service function locates, as we could retrieve those information
|
282
282
|
# if that is a fully qualified one. If not we use the current session one.
|
283
|
-
(_db, _schema, _
|
283
|
+
(_db, _schema, _) = identifier.parse_schema_level_object_identifier(service_func_name)
|
284
284
|
db = _db if _db is not None else conn._database
|
285
285
|
schema = _schema if _schema is not None else conn._schema
|
286
286
|
assert isinstance(db, str) and isinstance(schema, str)
|
@@ -343,7 +343,7 @@ class SnowServiceDeployment:
|
|
343
343
|
self.model_zip_stage_path = model_zip_stage_path
|
344
344
|
self.options = options
|
345
345
|
self.target_method = target_method
|
346
|
-
(db, schema, _
|
346
|
+
(db, schema, _) = identifier.parse_schema_level_object_identifier(service_func_name)
|
347
347
|
|
348
348
|
self._service_name = identifier.get_schema_level_object_identifier(db, schema, f"service_{model_id}")
|
349
349
|
self._job_name = identifier.get_schema_level_object_identifier(db, schema, f"build_{model_id}")
|
@@ -503,7 +503,7 @@ class SnowServiceDeployment:
|
|
503
503
|
norm_stage_path = posixpath.normpath(identifier.remove_prefix(self.model_zip_stage_path, "@"))
|
504
504
|
# Ensure model stage path has root prefix as stage mount will it mount it to root.
|
505
505
|
absolute_model_stage_path = os.path.join("/", norm_stage_path)
|
506
|
-
(db, schema, stage, path) = identifier.
|
506
|
+
(db, schema, stage, path) = identifier.parse_snowflake_stage_path(norm_stage_path)
|
507
507
|
substitutes = {
|
508
508
|
"image": image,
|
509
509
|
"predict_endpoint_name": constants.PREDICT,
|
@@ -92,6 +92,7 @@ class ModelComposer:
|
|
92
92
|
python_version: Optional[str] = None,
|
93
93
|
ext_modules: Optional[List[ModuleType]] = None,
|
94
94
|
code_paths: Optional[List[str]] = None,
|
95
|
+
model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
|
95
96
|
options: Optional[model_types.ModelSaveOption] = None,
|
96
97
|
) -> model_meta.ModelMetadata:
|
97
98
|
if not options:
|
@@ -120,6 +121,7 @@ class ModelComposer:
|
|
120
121
|
python_version=python_version,
|
121
122
|
ext_modules=ext_modules,
|
122
123
|
code_paths=code_paths,
|
124
|
+
model_objective=model_objective,
|
123
125
|
options=options,
|
124
126
|
)
|
125
127
|
assert self.packager.meta is not None
|
@@ -1,7 +1,6 @@
|
|
1
1
|
import collections
|
2
2
|
import copy
|
3
3
|
import pathlib
|
4
|
-
import warnings
|
5
4
|
from typing import List, Optional, cast
|
6
5
|
|
7
6
|
import yaml
|
@@ -78,13 +77,9 @@ class ModelManifest:
|
|
78
77
|
)
|
79
78
|
|
80
79
|
dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
"be warehouse-compabible. The model may need to be run in SPCS.",
|
85
|
-
category=UserWarning,
|
86
|
-
stacklevel=1,
|
87
|
-
)
|
80
|
+
|
81
|
+
# We only want to include pip dependencies file if there are any pip requirements.
|
82
|
+
if len(model_meta.env.pip_requirements) > 0:
|
88
83
|
dependencies["pip"] = runtime_dict["dependencies"]["pip"]
|
89
84
|
|
90
85
|
manifest_dict = model_manifest_schema.ModelManifestDict(
|
@@ -1,9 +1,11 @@
|
|
1
1
|
import json
|
2
|
+
import warnings
|
2
3
|
from typing import Any, Callable, Iterable, Optional, Sequence, cast
|
3
4
|
|
4
5
|
import numpy as np
|
5
6
|
import numpy.typing as npt
|
6
7
|
import pandas as pd
|
8
|
+
from absl import logging
|
7
9
|
|
8
10
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
9
11
|
from snowflake.ml.model._packager.model_meta import model_meta
|
@@ -11,6 +13,17 @@ from snowflake.ml.model._signatures import snowpark_handler
|
|
11
13
|
from snowflake.snowpark import DataFrame as SnowparkDataFrame
|
12
14
|
|
13
15
|
|
16
|
+
class NumpyEncoder(json.JSONEncoder):
|
17
|
+
def default(self, obj: Any) -> Any:
|
18
|
+
if isinstance(obj, np.integer):
|
19
|
+
return int(obj)
|
20
|
+
if isinstance(obj, np.floating):
|
21
|
+
return float(obj)
|
22
|
+
if isinstance(obj, np.ndarray):
|
23
|
+
return obj.tolist()
|
24
|
+
return super().default(obj)
|
25
|
+
|
26
|
+
|
14
27
|
def _is_callable(model: model_types.SupportedModelType, method_name: str) -> bool:
|
15
28
|
return callable(getattr(model, method_name, None))
|
16
29
|
|
@@ -93,23 +106,42 @@ def convert_explanations_to_2D_df(
|
|
93
106
|
return pd.DataFrame(explanations)
|
94
107
|
|
95
108
|
if hasattr(model, "classes_"):
|
96
|
-
classes_list = [cl for cl in model.classes_] # type:ignore[union-attr]
|
109
|
+
classes_list = [str(cl) for cl in model.classes_] # type:ignore[union-attr]
|
97
110
|
len_classes = len(classes_list)
|
98
111
|
if explanations.shape[2] != len_classes:
|
99
112
|
raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
|
100
113
|
else:
|
101
|
-
classes_list = [i for i in range(explanations.shape[2])]
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
if isinstance(cl, (int, np.integer)):
|
110
|
-
cl = int(cl)
|
111
|
-
class_explanations[cl] = cl_exp
|
112
|
-
col_list.append(json.dumps(class_explanations))
|
113
|
-
exp_2d.append(col_list)
|
114
|
+
classes_list = [str(i) for i in range(explanations.shape[2])]
|
115
|
+
|
116
|
+
def row_to_dict(row: npt.NDArray[Any]) -> npt.NDArray[Any]:
|
117
|
+
"""Converts a single row to a dictionary."""
|
118
|
+
# convert to object or numpy creates strings of fixed length
|
119
|
+
return np.asarray(json.dumps(dict(zip(classes_list, row)), cls=NumpyEncoder), dtype=object)
|
120
|
+
|
121
|
+
exp_2d = np.apply_along_axis(row_to_dict, -1, explanations)
|
114
122
|
|
115
123
|
return pd.DataFrame(exp_2d)
|
124
|
+
|
125
|
+
|
126
|
+
def validate_model_objective(
|
127
|
+
passed_model_objective: model_types.ModelObjective, inferred_model_objective: model_types.ModelObjective
|
128
|
+
) -> model_types.ModelObjective:
|
129
|
+
if (
|
130
|
+
passed_model_objective != model_types.ModelObjective.UNKNOWN
|
131
|
+
and inferred_model_objective != model_types.ModelObjective.UNKNOWN
|
132
|
+
):
|
133
|
+
if passed_model_objective != inferred_model_objective:
|
134
|
+
warnings.warn(
|
135
|
+
f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model "
|
136
|
+
f"version and passed argument ModelObjective: {passed_model_objective.name} is ignored",
|
137
|
+
category=UserWarning,
|
138
|
+
stacklevel=1,
|
139
|
+
)
|
140
|
+
return inferred_model_objective
|
141
|
+
elif inferred_model_objective != model_types.ModelObjective.UNKNOWN:
|
142
|
+
logging.info(
|
143
|
+
f"Inferred ModelObjective: {inferred_model_objective.name} is used as model objective for this model "
|
144
|
+
f"version"
|
145
|
+
)
|
146
|
+
return inferred_model_objective
|
147
|
+
return passed_model_objective
|
@@ -34,20 +34,20 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
34
34
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
35
35
|
|
36
36
|
@classmethod
|
37
|
-
def
|
37
|
+
def get_model_objective_and_output_type(cls, model: "catboost.CatBoost") -> model_types.ModelObjective:
|
38
38
|
import catboost
|
39
39
|
|
40
40
|
if isinstance(model, catboost.CatBoostClassifier):
|
41
41
|
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
42
42
|
if num_classes == 2:
|
43
|
-
return
|
44
|
-
return
|
43
|
+
return model_types.ModelObjective.BINARY_CLASSIFICATION
|
44
|
+
return model_types.ModelObjective.MULTI_CLASSIFICATION
|
45
45
|
if isinstance(model, catboost.CatBoostRanker):
|
46
|
-
return
|
46
|
+
return model_types.ModelObjective.RANKING
|
47
47
|
if isinstance(model, catboost.CatBoostRegressor):
|
48
|
-
return
|
48
|
+
return model_types.ModelObjective.REGRESSION
|
49
49
|
# TODO: Find out model type from the generic Catboost Model
|
50
|
-
return
|
50
|
+
return model_types.ModelObjective.UNKNOWN
|
51
51
|
|
52
52
|
@classmethod
|
53
53
|
def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
|
@@ -77,6 +77,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
77
77
|
is_sub_model: Optional[bool] = False,
|
78
78
|
**kwargs: Unpack[model_types.CatBoostModelSaveOptions],
|
79
79
|
) -> None:
|
80
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
81
|
+
|
80
82
|
import catboost
|
81
83
|
|
82
84
|
assert isinstance(model, catboost.CatBoost)
|
@@ -105,11 +107,14 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
105
107
|
sample_input_data=sample_input_data,
|
106
108
|
get_prediction_fn=get_prediction,
|
107
109
|
)
|
108
|
-
|
109
|
-
model_meta.model_objective =
|
110
|
-
|
110
|
+
inferred_model_objective = cls.get_model_objective_and_output_type(model)
|
111
|
+
model_meta.model_objective = handlers_utils.validate_model_objective(
|
112
|
+
model_meta.model_objective, inferred_model_objective
|
113
|
+
)
|
114
|
+
model_objective = model_meta.model_objective
|
115
|
+
if enable_explainability:
|
111
116
|
output_type = model_signature.DataType.DOUBLE
|
112
|
-
if model_objective ==
|
117
|
+
if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
|
113
118
|
output_type = model_signature.DataType.STRING
|
114
119
|
model_meta = handlers_utils.add_explain_method_signature(
|
115
120
|
model_meta=model_meta,
|
@@ -143,11 +148,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
143
148
|
],
|
144
149
|
check_local_version=True,
|
145
150
|
)
|
146
|
-
if
|
147
|
-
model_meta.env.include_if_absent(
|
148
|
-
[model_env.ModelDependency(requirement="shap", pip_name="shap")],
|
149
|
-
check_local_version=True,
|
150
|
-
)
|
151
|
+
if enable_explainability:
|
152
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
151
153
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
152
154
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
153
155
|
|
@@ -369,7 +369,9 @@ class HuggingFacePipelineHandler(
|
|
369
369
|
else:
|
370
370
|
# For others, we could offer the whole dataframe as a list.
|
371
371
|
# Some of them may need some conversion
|
372
|
-
if
|
372
|
+
if hasattr(transformers, "ConversationalPipeline") and isinstance(
|
373
|
+
raw_model, transformers.ConversationalPipeline
|
374
|
+
):
|
373
375
|
input_data = [
|
374
376
|
transformers.Conversation(
|
375
377
|
text=conv_data["user_inputs"][0],
|
@@ -391,27 +393,33 @@ class HuggingFacePipelineHandler(
|
|
391
393
|
# Making it not aligned with the auto-inferred signature.
|
392
394
|
# If the output is a dict, we could blindly create a list containing that.
|
393
395
|
# Otherwise, creating pandas DataFrame won't succeed.
|
394
|
-
if
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
396
|
+
if (
|
397
|
+
(hasattr(transformers, "Conversation") and isinstance(temp_res, transformers.Conversation))
|
398
|
+
or isinstance(temp_res, dict)
|
399
|
+
or (
|
400
|
+
# For some pipeline that is expected to generate a list of dict per input
|
401
|
+
# When it omit outer list, it becomes list of dict instead of list of list of dict.
|
402
|
+
# We need to distinguish them from those pipelines that designed to output a dict per input
|
403
|
+
# So we need to check the pipeline type.
|
404
|
+
isinstance(
|
405
|
+
raw_model,
|
406
|
+
(
|
407
|
+
transformers.FillMaskPipeline,
|
408
|
+
transformers.QuestionAnsweringPipeline,
|
409
|
+
),
|
410
|
+
)
|
411
|
+
and X.shape[0] == 1
|
412
|
+
and isinstance(temp_res[0], dict)
|
405
413
|
)
|
406
|
-
and X.shape[0] == 1
|
407
|
-
and isinstance(temp_res[0], dict)
|
408
414
|
):
|
409
415
|
temp_res = [temp_res]
|
410
416
|
|
411
417
|
if len(temp_res) == 0:
|
412
418
|
return pd.DataFrame()
|
413
419
|
|
414
|
-
if
|
420
|
+
if hasattr(transformers, "ConversationalPipeline") and isinstance(
|
421
|
+
raw_model, transformers.ConversationalPipeline
|
422
|
+
):
|
415
423
|
temp_res = [[conv.generated_responses] for conv in temp_res]
|
416
424
|
|
417
425
|
# To concat those who outputs a list with one input.
|
@@ -19,7 +19,11 @@ from typing_extensions import TypeGuard, Unpack
|
|
19
19
|
from snowflake.ml._internal import type_utils
|
20
20
|
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
21
21
|
from snowflake.ml.model._packager.model_env import model_env
|
22
|
-
from snowflake.ml.model._packager.model_handlers import
|
22
|
+
from snowflake.ml.model._packager.model_handlers import (
|
23
|
+
_base,
|
24
|
+
_utils as handlers_utils,
|
25
|
+
model_objective_utils,
|
26
|
+
)
|
23
27
|
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
24
28
|
from snowflake.ml.model._packager.model_meta import (
|
25
29
|
model_blob_meta,
|
@@ -43,47 +47,6 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
43
47
|
|
44
48
|
MODEL_BLOB_FILE_OR_DIR = "model.pkl"
|
45
49
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
46
|
-
_BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
|
47
|
-
_MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
|
48
|
-
_RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
|
49
|
-
_REGRESSION_OBJECTIVES = [
|
50
|
-
"regression",
|
51
|
-
"regression_l1",
|
52
|
-
"huber",
|
53
|
-
"fair",
|
54
|
-
"poisson",
|
55
|
-
"quantile",
|
56
|
-
"tweedie",
|
57
|
-
"mape",
|
58
|
-
"gamma",
|
59
|
-
]
|
60
|
-
|
61
|
-
@classmethod
|
62
|
-
def get_model_objective(
|
63
|
-
cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]
|
64
|
-
) -> model_meta_schema.ModelObjective:
|
65
|
-
import lightgbm
|
66
|
-
|
67
|
-
# does not account for cross-entropy and custom
|
68
|
-
if isinstance(model, lightgbm.LGBMClassifier):
|
69
|
-
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
70
|
-
if num_classes == 2:
|
71
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
72
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
73
|
-
if isinstance(model, lightgbm.LGBMRanker):
|
74
|
-
return model_meta_schema.ModelObjective.RANKING
|
75
|
-
if isinstance(model, lightgbm.LGBMRegressor):
|
76
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
77
|
-
model_objective = model.params["objective"]
|
78
|
-
if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
|
79
|
-
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
80
|
-
if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
|
81
|
-
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
82
|
-
if model_objective in cls._RANKING_OBJECTIVES:
|
83
|
-
return model_meta_schema.ModelObjective.RANKING
|
84
|
-
if model_objective in cls._REGRESSION_OBJECTIVES:
|
85
|
-
return model_meta_schema.ModelObjective.REGRESSION
|
86
|
-
return model_meta_schema.ModelObjective.UNKNOWN
|
87
50
|
|
88
51
|
@classmethod
|
89
52
|
def can_handle(
|
@@ -118,6 +81,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
118
81
|
is_sub_model: Optional[bool] = False,
|
119
82
|
**kwargs: Unpack[model_types.LGBMModelSaveOptions],
|
120
83
|
) -> None:
|
84
|
+
enable_explainability = kwargs.get("enable_explainability", True)
|
85
|
+
|
121
86
|
import lightgbm
|
122
87
|
|
123
88
|
assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
|
@@ -146,20 +111,16 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
146
111
|
sample_input_data=sample_input_data,
|
147
112
|
get_prediction_fn=get_prediction,
|
148
113
|
)
|
149
|
-
|
150
|
-
model_meta.model_objective =
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
model_meta_schema.ModelObjective.BINARY_CLASSIFICATION,
|
155
|
-
model_meta_schema.ModelObjective.MULTI_CLASSIFICATION,
|
156
|
-
]:
|
157
|
-
output_type = model_signature.DataType.STRING
|
114
|
+
model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
|
115
|
+
model_meta.model_objective = handlers_utils.validate_model_objective(
|
116
|
+
model_meta.model_objective, model_objective_and_output.objective
|
117
|
+
)
|
118
|
+
if enable_explainability:
|
158
119
|
model_meta = handlers_utils.add_explain_method_signature(
|
159
120
|
model_meta=model_meta,
|
160
121
|
explain_method="explain",
|
161
122
|
target_method="predict",
|
162
|
-
output_return_type=output_type,
|
123
|
+
output_return_type=model_objective_and_output.output_type,
|
163
124
|
)
|
164
125
|
model_meta.function_properties = {
|
165
126
|
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
@@ -189,11 +150,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
189
150
|
],
|
190
151
|
check_local_version=True,
|
191
152
|
)
|
192
|
-
if
|
193
|
-
model_meta.env.include_if_absent(
|
194
|
-
[model_env.ModelDependency(requirement="shap", pip_name="shap")],
|
195
|
-
check_local_version=True,
|
196
|
-
)
|
153
|
+
if enable_explainability:
|
154
|
+
model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
|
197
155
|
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
198
156
|
|
199
157
|
return None
|
@@ -205,7 +205,9 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
205
205
|
"token": raw_model.token,
|
206
206
|
}
|
207
207
|
model_dir_path = raw_model.model_id_or_path
|
208
|
-
peft_config = peft.PeftConfig.from_pretrained(
|
208
|
+
peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined]
|
209
|
+
model_dir_path
|
210
|
+
)
|
209
211
|
base_model_path = peft_config.base_model_name_or_path
|
210
212
|
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
211
213
|
base_model_path,
|
@@ -221,7 +223,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
|
|
221
223
|
model_dir_path,
|
222
224
|
device_map="auto",
|
223
225
|
torch_dtype="auto",
|
224
|
-
**hub_kwargs,
|
226
|
+
**hub_kwargs, # type: ignore[arg-type]
|
225
227
|
)
|
226
228
|
hf_model.eval()
|
227
229
|
hf_model = hf_model.merge_and_unload()
|