tencent-wedata-feature-engineering-dev 0.1.42__py3-none-any.whl → 0.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {tencent_wedata_feature_engineering_dev-0.1.42.dist-info → tencent_wedata_feature_engineering_dev-0.2.5.dist-info}/METADATA +14 -3
- tencent_wedata_feature_engineering_dev-0.2.5.dist-info/RECORD +78 -0
- {tencent_wedata_feature_engineering_dev-0.1.42.dist-info → tencent_wedata_feature_engineering_dev-0.2.5.dist-info}/WHEEL +1 -1
- wedata/__init__.py +1 -1
- wedata/common/base_table_client/__init__.py +1 -0
- wedata/common/base_table_client/base.py +58 -0
- wedata/common/cloud_sdk_client/__init__.py +2 -0
- wedata/{feature_store → common}/cloud_sdk_client/client.py +56 -12
- wedata/{feature_store → common}/cloud_sdk_client/models.py +212 -37
- wedata/{feature_store → common}/cloud_sdk_client/utils.py +14 -0
- wedata/{feature_store → common}/constants/constants.py +3 -2
- wedata/common/constants/engine_types.py +34 -0
- wedata/{feature_store → common}/entities/column_info.py +6 -5
- wedata/{feature_store → common}/entities/feature_column_info.py +2 -1
- wedata/{feature_store → common}/entities/feature_lookup.py +1 -1
- wedata/{feature_store → common}/entities/feature_spec.py +9 -9
- wedata/{feature_store → common}/entities/feature_table_info.py +1 -1
- wedata/{feature_store → common}/entities/function_info.py +2 -1
- wedata/{feature_store → common}/entities/on_demand_column_info.py +2 -1
- wedata/{feature_store → common}/entities/source_data_column_info.py +3 -1
- wedata/{feature_store → common}/entities/training_set.py +6 -6
- wedata/common/feast_client/__init__.py +1 -0
- wedata/{feature_store → common}/feast_client/feast_client.py +3 -4
- wedata/common/log/__init__.py +1 -0
- wedata/common/log/logger.py +44 -0
- wedata/common/spark_client/__init__.py +1 -0
- wedata/{feature_store → common}/spark_client/spark_client.py +6 -9
- wedata/{feature_store → common}/utils/common_utils.py +7 -9
- wedata/{feature_store → common}/utils/env_utils.py +31 -10
- wedata/{feature_store → common}/utils/feature_lookup_utils.py +6 -6
- wedata/{feature_store → common}/utils/feature_spec_utils.py +13 -8
- wedata/{feature_store → common}/utils/feature_utils.py +5 -5
- wedata/{feature_store → common}/utils/on_demand_utils.py +5 -4
- wedata/{feature_store → common}/utils/schema_utils.py +1 -1
- wedata/{feature_store → common}/utils/signature_utils.py +4 -4
- wedata/{feature_store → common}/utils/training_set_utils.py +13 -13
- wedata/{feature_store → common}/utils/uc_utils.py +1 -1
- wedata/feature_engineering/__init__.py +1 -0
- wedata/feature_engineering/client.py +417 -0
- wedata/feature_engineering/ml_training_client/ml_training_client.py +569 -0
- wedata/feature_engineering/mlflow_model.py +9 -0
- wedata/feature_engineering/table_client/__init__.py +0 -0
- wedata/feature_engineering/table_client/table_client.py +548 -0
- wedata/feature_store/client.py +13 -16
- wedata/feature_store/constants/engine_types.py +8 -30
- wedata/feature_store/feature_table_client/feature_table_client.py +98 -108
- wedata/feature_store/training_set_client/training_set_client.py +14 -17
- wedata/tempo/interpol.py +2 -2
- tencent_wedata_feature_engineering_dev-0.1.42.dist-info/RECORD +0 -64
- {tencent_wedata_feature_engineering_dev-0.1.42.dist-info → tencent_wedata_feature_engineering_dev-0.2.5.dist-info}/top_level.txt +0 -0
- /wedata/{feature_store/cloud_sdk_client → common}/__init__.py +0 -0
- /wedata/{feature_store/common/protos → common/constants}/__init__.py +0 -0
- /wedata/{feature_store → common}/entities/__init__.py +0 -0
- /wedata/{feature_store → common}/entities/environment_variables.py +0 -0
- /wedata/{feature_store → common}/entities/feature.py +0 -0
- /wedata/{feature_store → common}/entities/feature_function.py +0 -0
- /wedata/{feature_store → common}/entities/feature_spec_constants.py +0 -0
- /wedata/{feature_store → common}/entities/feature_table.py +0 -0
- /wedata/{feature_store/feast_client → common/protos}/__init__.py +0 -0
- /wedata/{feature_store/common → common}/protos/feature_store_pb2.py +0 -0
- /wedata/{feature_store/spark_client → common/utils}/__init__.py +0 -0
- /wedata/{feature_store → common}/utils/topological_sort.py +0 -0
- /wedata/{feature_store → common}/utils/validation_utils.py +0 -0
- /wedata/{feature_store/utils → feature_engineering/ml_training_client}/__init__.py +0 -0
|
@@ -0,0 +1,569 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from types import ModuleType
|
|
4
|
+
from typing import Any, List, Optional, Union, Dict
|
|
5
|
+
|
|
6
|
+
import mlflow
|
|
7
|
+
from mlflow.models import Model
|
|
8
|
+
from mlflow.utils.file_utils import TempDir, read_yaml
|
|
9
|
+
from pyspark.sql import DataFrame
|
|
10
|
+
from pyspark.sql.functions import struct
|
|
11
|
+
|
|
12
|
+
from wedata.common.constants import constants
|
|
13
|
+
from wedata.common.entities.feature_function import FeatureFunction
|
|
14
|
+
from wedata.common.entities.feature_lookup import FeatureLookup
|
|
15
|
+
from wedata.common.entities.feature_spec import FeatureSpec
|
|
16
|
+
from wedata.common.entities.training_set import TrainingSet
|
|
17
|
+
from wedata.feature_engineering.mlflow_model import _FeatureEngineeringModelWrapper
|
|
18
|
+
from wedata.common.spark_client import SparkClient
|
|
19
|
+
from wedata.common.utils import validation_utils, common_utils, training_set_utils
|
|
20
|
+
from wedata.common.entities.feature_table import FeatureTable
|
|
21
|
+
|
|
22
|
+
from wedata.common.constants.constants import (
|
|
23
|
+
_NO_RESULT_TYPE_PASSED,
|
|
24
|
+
_USE_SPARK_NATIVE_JOIN,
|
|
25
|
+
MODEL_DATA_PATH_ROOT,
|
|
26
|
+
PREDICTION_COLUMN_NAME,
|
|
27
|
+
_PREBUILT_ENV_URI
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
from wedata.common.utils import uc_utils
|
|
31
|
+
from wedata.common.utils.signature_utils import get_mlflow_signature_from_feature_spec, \
|
|
32
|
+
drop_signature_inputs_and_invalid_params
|
|
33
|
+
|
|
34
|
+
# Module-level logger used by MLTrainingClient.
_logger = logging.getLogger(name=__name__)

# Upper bound on column infos in a feature spec graph.
# NOTE(review): not referenced by the code visible in this module — confirm it is still needed.
FEATURE_SPEC_GRAPH_MAX_COLUMN_INFO = 1_000
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class MLTrainingClient:
    """Build training sets / feature specs, log feature-aware models and batch-score them.

    All Spark, table and function lookups go through the ``SparkClient`` given at
    construction time.
    """

    def __init__(
        self,
        spark_client: SparkClient
    ):
        """
        :param spark_client: Client used for every Spark / feature-table / function lookup.
        """
        self._spark_client = spark_client

    def create_training_set(
        self,
        feature_spec: FeatureSpec,
        label_names: List[str],
        df: DataFrame,
        ft_metadata: training_set_utils._FeatureTableMetadata,
        kwargs,
    ):
        """Wrap an already-built feature spec into a ``TrainingSet``.

        :param feature_spec: Feature spec describing source, feature and on-demand columns.
        :param label_names: Label column names in ``df``.
        :param df: DataFrame holding lookup keys / source data.
        :param ft_metadata: Metadata and data maps for the looked-up feature tables.
        :param kwargs: Options as a plain dict (not ``**kwargs``); only
            ``_USE_SPARK_NATIVE_JOIN`` is read here.
        :return: The constructed ``TrainingSet``.
        """
        use_spark_native_join = kwargs.get(_USE_SPARK_NATIVE_JOIN, False)

        uc_function_infos = training_set_utils.get_uc_function_infos(
            self._spark_client,
            {odci.udf_name for odci in feature_spec.on_demand_column_infos},
        )

        training_set_utils.warn_if_non_photon_for_native_spark(
            use_spark_native_join, self._spark_client
        )
        return TrainingSet(
            feature_spec,
            df,
            label_names,
            ft_metadata.feature_table_metadata_map,
            ft_metadata.feature_table_data_map,
            uc_function_infos,
            use_spark_native_join,
        )

    def create_training_set_from_feature_lookups(
        self,
        df: DataFrame,
        feature_lookups: List[Union[FeatureLookup, FeatureFunction]],
        label: Union[str, List[str], None],
        exclude_columns: List[str],
        **kwargs,
    ) -> TrainingSet:
        """Build a ``TrainingSet`` from ``df`` plus feature lookups and feature functions.

        :param df: DataFrame with lookup keys, source columns and labels.
        :param feature_lookups: Mix of ``FeatureLookup`` and ``FeatureFunction`` objects.
            Items of any other type are ignored.
        :param label: Label column name(s), or ``None`` for no label.
        :param exclude_columns: Column names to exclude from the feature spec.
        :param kwargs: Extra options forwarded to ``create_training_set`` as a dict.
        :raises ValueError: If more than ``training_set_utils.MAX_FEATURE_FUNCTIONS``
            FeatureFunctions are given.
        """
        # Split the inputs into feature lookups and feature functions.
        features = feature_lookups
        feature_lookups = [f for f in features if isinstance(f, FeatureLookup)]
        feature_functions = [f for f in features if isinstance(f, FeatureFunction)]

        # Cap on the number of FeatureFunctions supported.
        if len(feature_functions) > training_set_utils.MAX_FEATURE_FUNCTIONS:
            raise ValueError(
                f"A maximum of {training_set_utils.MAX_FEATURE_FUNCTIONS} FeatureFunctions are supported."
            )

        # No label given -> empty label list.
        label_names = common_utils.as_list(label, [])
        del label

        # Validate the DataFrame against labels / excluded columns.
        training_set_utils.verify_df_and_labels(df, label_names, exclude_columns)

        # Fetch metadata for every referenced feature table.
        ft_metadata = training_set_utils.get_table_metadata(
            self._spark_client,
            {fl.table_name for fl in feature_lookups}
        )

        column_infos = training_set_utils.get_column_infos(
            feature_lookups,
            feature_functions,
            ft_metadata,
            df_columns=df.columns,
            label_names=label_names,
        )

        training_set_utils.validate_column_infos(
            self._spark_client,
            ft_metadata,
            column_infos.source_data_column_infos,
            column_infos.feature_column_infos,
            column_infos.on_demand_column_infos,
            label_names,
        )

        # Build feature_spec locally for comparison with the feature spec yaml generated by the
        # FeatureStore backend. This will be removed once the migration is validated.
        feature_spec = training_set_utils.build_feature_spec(
            feature_lookups,
            ft_metadata,
            column_infos,
            exclude_columns
        )

        return self.create_training_set(
            feature_spec,
            label_names,
            df,
            ft_metadata,
            kwargs=kwargs,
        )

    def create_feature_spec(
        self,
        name: str,
        features: List[Union[FeatureLookup, FeatureFunction]],
        sparkClient: SparkClient,
        exclude_columns: Optional[List[str]] = None,
    ) -> FeatureSpec:
        """Build a ``FeatureSpec`` from feature lookups and feature functions.

        :param name: Name of the feature spec. It is not used by this implementation.
        :param features: Mix of ``FeatureLookup`` and ``FeatureFunction`` objects.
            Items of any other type are ignored.
        :param sparkClient: Kept for backward compatibility. ``self._spark_client`` is
            used instead.
        :param exclude_columns: Column names to exclude. ``None`` means no exclusions.
        :raises ValueError: If more than ``training_set_utils.MAX_FEATURE_FUNCTIONS``
            FeatureFunctions are given.
        """
        # Avoid a shared mutable default list across calls.
        if exclude_columns is None:
            exclude_columns = []

        feature_lookups = [f for f in features if isinstance(f, FeatureLookup)]
        feature_functions = [f for f in features if isinstance(f, FeatureFunction)]

        if len(feature_functions) > training_set_utils.MAX_FEATURE_FUNCTIONS:
            raise ValueError(
                f"A maximum of {training_set_utils.MAX_FEATURE_FUNCTIONS} FeatureFunctions are supported."
            )

        # Get feature table metadata and column infos
        ft_metadata = training_set_utils.get_table_metadata(
            self._spark_client,
            {fl.table_name for fl in feature_lookups}
        )
        column_infos = training_set_utils.get_column_infos(
            feature_lookups,
            feature_functions,
            ft_metadata,
        )

        column_infos = training_set_utils.add_inferred_source_columns(column_infos)

        training_set_utils.validate_column_infos(
            self._spark_client,
            ft_metadata,
            column_infos.source_data_column_infos,
            column_infos.feature_column_infos,
            column_infos.on_demand_column_infos,
        )

        return training_set_utils.build_feature_spec(
            feature_lookups,
            ft_metadata,
            column_infos,
            exclude_columns
        )

    @staticmethod
    def _conda_env_has_pip_package(
        conda_env: Optional[Dict[str, Any]], package: str
    ) -> bool:
        """Return True if any pip entry of ``conda_env`` starts with ``package``.

        Pip requirements are expected under a ``{"pip": [...]}`` dict inside the
        ``dependencies`` list. Plain (non-dict) conda dependencies are skipped.
        """
        for dependency in (conda_env or {}).get("dependencies", []):
            if isinstance(dependency, dict) and any(
                pip_dep.startswith(package) for pip_dep in dependency.get("pip", [])
            ):
                return True
        return False

    def log_model(
        self,
        model: Any,
        artifact_path: str,
        *,
        flavor: ModuleType,
        training_set: Optional[TrainingSet],
        registered_model_name: Optional[str],
        model_registry_uri: Optional[str],
        await_registration_for: int,
        infer_input_example: bool,
        **kwargs,
    ):
        """Log ``model`` with its feature spec as an MLflow pyfunc model.

        The raw model is saved with ``flavor``. The feature spec is saved next to it,
        and both are logged as the ``feature_store`` artifact of a pyfunc model that
        wraps ``model`` in ``_FeatureEngineeringModelWrapper``.

        :param model: The trained model object.
        :param artifact_path: Run-relative artifact path for the logged model.
        :param flavor: MLflow flavor module used to save the raw model. It must
            produce a ``python_function`` flavor.
        :param training_set: Training set the model was trained on (required).
        :param registered_model_name: If given, the model is registered under this name.
        :param model_registry_uri: If given, set as the MLflow registry URI before logging.
        :param await_registration_for: Seconds to wait for the registered model version
            to become ready. The value is forwarded to ``mlflow.pyfunc.log_model``.
        :param infer_input_example: Whether to take the first training row as input example.
        :param kwargs: ``output_schema`` and ``params`` are consumed for the signature.
            The rest is forwarded to ``flavor.save_model``.
        :raises ValueError: If ``training_set`` is None or the flavor lacks a
            ``python_function`` flavor.
        """
        if training_set is None:
            raise ValueError("'training_set' must be provided.")

        # training_set.feature_spec is guaranteed to use 3-level table names since it
        # comes from create_training_set; reformat them before persisting.
        feature_spec = uc_utils.get_feature_spec_with_reformat_full_table_names(
            training_set.feature_spec
        )
        label_type_map = training_set._label_data_types

        # Output schema and inference params are consumed here, not by save_model.
        override_output_schema = kwargs.pop("output_schema", None)
        # Copy so that the caller's ``params`` dict is never mutated.
        params = dict(kwargs.pop("params", None) or {})
        params.setdefault("result_type", _NO_RESULT_TYPE_PASSED)

        try:
            signature = get_mlflow_signature_from_feature_spec(
                feature_spec, label_type_map, override_output_schema, params
            )
        except Exception as e:
            _logger.warning(f"Model could not be logged with a signature: {e}")
            signature = None

        # Only pay for the Spark job when an input example is actually requested.
        input_example = None
        if infer_input_example:
            feature_columns = [
                feature_column.output_name
                for feature_column in feature_spec.feature_column_infos
            ]
            df_head = training_set.load_df().select(*feature_columns).head()
            if df_head is not None:
                try:
                    input_example = df_head.asDict()
                except Exception as e:
                    _logger.warning("Could not infer input example from training set: %s", e)

        with TempDir() as tmp_location:
            # WeData: this directory holds the raw model plus the feature spec
            # (table references) instead of copying feature data.
            data_path = os.path.join(tmp_location.path(), "feature_store")
            os.makedirs(data_path, exist_ok=True)

            raw_mlflow_model = Model(
                signature=drop_signature_inputs_and_invalid_params(signature)
            )
            raw_model_path = os.path.join(data_path, constants.RAW_MODEL_FOLDER)

            # pyfunc.save_model takes the model as ``python_model``; other flavors
            # take it positionally.
            if flavor.FLAVOR_NAME != mlflow.pyfunc.FLAVOR_NAME:
                flavor.save_model(
                    model, raw_model_path, mlflow_model=raw_mlflow_model, **kwargs
                )
            else:
                flavor.save_model(
                    raw_model_path,
                    mlflow_model=raw_mlflow_model,
                    python_model=model,
                    **kwargs,
                )

            if "python_function" not in raw_mlflow_model.flavors:
                raise ValueError(
                    f"FeatureStoreClient.log_model does not support '{flavor.__name__}' "
                    "since it does not have a python_function model flavor."
                )

            # Locate and read the conda environment written by save_model.
            model_env = raw_mlflow_model.flavors["python_function"][mlflow.pyfunc.ENV]
            if isinstance(model_env, dict):
                # mlflow 2.0 has multiple supported environments
                conda_file = model_env[mlflow.pyfunc.EnvType.CONDA]
            else:
                conda_file = model_env
            conda_env = read_yaml(raw_model_path, conda_file)

            # TODO: the original databricks-feature-lookup package is not needed for now;
            # it causes python environment creation to fail.
            # Add the pinned feature-lookup client package unless already present.
            if not self._conda_env_has_pip_package(
                conda_env, constants.FEATURE_LOOKUP_CLIENT_PIP_PACKAGE
            ):
                default_wedata_feature_lookup_pip_package = common_utils.pip_depependency_pinned_version(
                    pip_package_name=constants.FEATURE_LOOKUP_CLIENT_PIP_PACKAGE,
                    version=constants.FEATURE_LOOKUP_CLIENT_MAJOR_VERSION,
                )
                common_utils.add_mlflow_pip_depependency(
                    conda_env, default_wedata_feature_lookup_pip_package
                )

            feature_spec.save(data_path)

            _logger.debug(
                "artifact_path:%s,data_path:%s,conda_env:%s,signature:%s,input_example:%s",
                artifact_path, data_path, conda_env, signature, input_example,
            )

            # Honour an explicit registry URI so registration targets the intended registry.
            if model_registry_uri is not None:
                mlflow.set_registry_uri(model_registry_uri)

            registration_kwargs = {}
            if await_registration_for is not None:
                registration_kwargs["await_registration_for"] = await_registration_for

            mlflow.pyfunc.log_model(
                artifact_path=artifact_path,
                python_model=_FeatureEngineeringModelWrapper(model),
                artifacts={"feature_store": data_path},
                code_path=None,
                conda_env=conda_env,
                signature=signature,
                input_example=input_example,
                registered_model_name=registered_model_name,
                **registration_kwargs,
            )

    def score_batch(
        self,
        model_uri: Optional[str],
        df: DataFrame,
        result_type: str,
        env_manager: Optional[str] = None,
        local_uri: Optional[str] = None,
        params: Optional[Dict[str, Any]] = None,
        timestamp_key: Optional[str] = None,
        **kwargs,
    ) -> DataFrame:
        """Score ``df`` with a model logged by ``log_model``.

        The model comes from ``model_uri`` (downloaded) or ``local_uri``. Its feature
        spec is loaded, ``df`` is augmented with looked-up and on-demand features, and
        a ``PREDICTION_COLUMN_NAME`` column is appended.

        :param model_uri: MLflow model URI. Exactly one of ``model_uri`` / ``local_uri``
            must be given.
        :param df: Non-streaming DataFrame to score.
        :param result_type: Result type passed to the predict UDF.
        :param env_manager: Environment manager passed to the predict UDF.
        :param local_uri: Local directory with already-downloaded model artifacts.
        :param params: Inference params passed to the predict UDF.
        :param timestamp_key: Currently unused.
        :param kwargs: ``_USE_SPARK_NATIVE_JOIN`` and ``_PREBUILT_ENV_URI`` are read.
        :return: ``df`` columns (in order), then computed model-input columns not in
            ``df``, then the prediction column.
        :raises ValueError: On invalid URIs, streaming input, a clashing prediction
            column, duplicate column names or missing required columns.
        """
        validation_utils.check_dataframe_type(df)
        if (model_uri is None) == (local_uri is None):
            raise ValueError(
                "Either 'model_uri' or 'local_uri' must be provided, but not both."
            )
        if df.isStreaming:
            raise ValueError("Streaming DataFrames are not supported.")

        # The result gets a new prediction column, so the input must not already use that name.
        if PREDICTION_COLUMN_NAME in df.columns:
            raise ValueError(
                "FeatureStoreClient.score_batch returns a DataFrame with a new column "
                f'"{PREDICTION_COLUMN_NAME}". df already has a column with name '
                f'"{PREDICTION_COLUMN_NAME}".'
            )

        # Column names must be unique.
        validation_utils.validate_strings_unique(
            df.columns,
            "The provided DataFrame for scoring must have unique column names. Found duplicates {}.",
        )
        df_columns = set(df.columns)

        # NOTE(review): presumably mirrors where log_model's "feature_store" artifact lands.
        artifact_path = os.path.join("artifacts", MODEL_DATA_PATH_ROOT)
        with TempDir() as tmp_location:
            local_path = (
                local_uri
                if local_uri
                else common_utils.download_model_artifacts(model_uri, tmp_location.path())
            )
            model_data_path = os.path.join(local_path, artifact_path)

            # Augment local workspace metastore tables from 2L to 3L,
            # this will prevent us from erroneously reading data from other catalogs
            feature_spec = uc_utils.get_feature_spec_with_full_table_names(
                FeatureSpec.load(model_data_path)
            )

            raw_model_path = os.path.join(model_data_path, constants.RAW_MODEL_FOLDER)
            _logger.debug("raw_model_path: %s", raw_model_path)

            predict_udf = self._spark_client.get_predict_udf(
                raw_model_path,
                result_type=result_type,
                env_manager=env_manager,
                params=params,
                prebuilt_env_uri=kwargs.get(_PREBUILT_ENV_URI, None),
            )

            # Validate that columns needed for joining feature tables exist and are not duplicates.
            feature_input_keys = [
                key
                for fci in feature_spec.feature_column_infos
                for key in fci.lookup_key
            ]
            on_demand_input_names = uc_utils.get_unique_list_order(
                [
                    input_name
                    for odci in feature_spec.on_demand_column_infos
                    for input_name in odci.input_bindings.values()
                ]
            )
            intermediate_inputs = set(feature_input_keys + on_demand_input_names)
            source_data_names = [
                sdci.name for sdci in feature_spec.source_data_column_infos
            ]
            feature_output_names = [
                fci.output_name for fci in feature_spec.feature_column_infos
            ]
            on_demand_output_names = [
                odci.output_name for odci in feature_spec.on_demand_column_infos
            ]
            all_output_names = set(
                source_data_names + feature_output_names + on_demand_output_names
            )
            # Inputs not produced by the spec itself must come from df, as must source data.
            required_cols = intermediate_inputs.difference(all_output_names)
            required_cols.update(source_data_names)

            # Sorted so the error message is deterministic.
            missing_required_columns = sorted(
                col for col in required_cols if col not in df_columns
            )
            if missing_required_columns:
                missing_columns_formatted = ", ".join(
                    f"'{s}'" for s in missing_required_columns
                )
                raise ValueError(
                    f"DataFrame is missing required columns {missing_columns_formatted}."
                )

            table_names = {fci.table_name for fci in feature_spec.feature_column_infos}
            feature_table_features_map = training_set_utils.get_features_for_tables(
                self._spark_client, table_names=table_names
            )
            feature_table_metadata_map = (
                training_set_utils.get_feature_table_metadata_for_tables(
                    self._spark_client,
                    table_names=table_names,
                )
            )
            feature_table_data_map = training_set_utils.load_feature_data_for_tables(
                self._spark_client, table_names=table_names
            )
            training_set_utils.validate_feature_column_infos_data(
                self._spark_client,
                feature_spec.feature_column_infos,
                feature_table_features_map,
                feature_table_data_map,
            )

            uc_function_infos = training_set_utils.get_uc_function_infos(
                self._spark_client,
                {odci.udf_name for odci in feature_spec.on_demand_column_infos},
            )

            # Required source data and feature lookup keys have been validated to exist in `df`.
            # No additional validation is required before resolving FeatureLookups and applying FeatureFunctions.
            use_spark_native_join = kwargs.get(_USE_SPARK_NATIVE_JOIN, False)
            training_set_utils.warn_if_non_photon_for_native_spark(
                use_spark_native_join, self._spark_client
            )

            augmented_df = TrainingSet(
                feature_spec=feature_spec,
                df=df,
                labels=[],
                feature_table_metadata_map=feature_table_metadata_map,
                feature_table_data_map=feature_table_data_map,
                uc_function_infos=uc_function_infos,
                use_spark_native_join=use_spark_native_join,
            )._augment_df()

            # Only included FeatureSpec columns should be part of UDF inputs for scoring.
            # Note: extra `df` columns not in FeatureSpec should be preserved.
            udf_input_columns = [
                ci.output_name for ci in feature_spec.column_infos if ci.include
            ]
            _logger.debug("udf_input_columns:%s", udf_input_columns)

            df_with_predictions = augmented_df.withColumn(
                PREDICTION_COLUMN_NAME, predict_udf(struct(*udf_input_columns))
            )
            # Reorder `df_with_predictions` to include:
            # 1. Preserved `df` columns, in `df` column order.
            # 2. Computed model input columns, in `FeatureSpec` column order.
            # 3. Prediction column.
            output_column_order = (
                df.columns
                + [col for col in udf_input_columns if col not in df_columns]
                + [PREDICTION_COLUMN_NAME]
            )
            return df_with_predictions.select(output_column_order)

    def _warn_if_tables_mismatched_for_model(
        self,
        feature_spec: FeatureSpec,
        feature_table_metadata_map: Dict[str, FeatureTable],
        model_creation_timestamp_ms: float,
    ):
        """Warn if feature tables appear to have been deleted and recreated after logging.

        For feature specs that record ``table_infos``, the logged table id is compared
        with the id of the current table. A warning is logged only when both ids are
        available and they differ. Tables missing from ``feature_table_metadata_map``
        are skipped.

        :param feature_spec: Feature spec loaded from the model.
        :param feature_table_metadata_map: Current feature table metadata keyed by table name.
        :param model_creation_timestamp_ms: Currently unused. Table creation times are
            not available, so no timestamp check is performed.
        """
        table_infos = feature_spec.table_infos or []
        # When feature_spec.yaml is parsed, FeatureSpec.load will assert that the listed
        # table names in input_tables match table names in input_columns. The code below
        # assumes this invariant and only checks the table IDs.
        mismatched_tables = []
        for table_info in table_infos:
            feature_table = feature_table_metadata_map.get(table_info.table_name)
            if feature_table is None:
                continue
            # NOTE(review): ids are read defensively; confirm both entities expose `table_id`.
            logged_id = getattr(table_info, "table_id", None)
            current_id = getattr(feature_table, "table_id", None)
            if logged_id is not None and current_id is not None and logged_id != current_id:
                mismatched_tables.append(table_info.table_name)

        if mismatched_tables:
            plural = len(mismatched_tables) > 1
            _logger.warning(
                f"Feature table{'s' if plural else ''} {', '.join(mismatched_tables)} "
                f"{'were' if plural else 'was'} deleted and recreated after "
                f"the model was trained. Model performance may be affected if the features "
                f"used in scoring have drifted from the features used in training."
            )
|
|
File without changes
|