tencent-wedata-feature-engineering-dev 0.1.48__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (64)
  1. {tencent_wedata_feature_engineering_dev-0.1.48.dist-info → tencent_wedata_feature_engineering_dev-0.2.5.dist-info}/METADATA +14 -3
  2. tencent_wedata_feature_engineering_dev-0.2.5.dist-info/RECORD +78 -0
  3. {tencent_wedata_feature_engineering_dev-0.1.48.dist-info → tencent_wedata_feature_engineering_dev-0.2.5.dist-info}/WHEEL +1 -1
  4. wedata/__init__.py +1 -1
  5. wedata/common/base_table_client/__init__.py +1 -0
  6. wedata/common/base_table_client/base.py +58 -0
  7. wedata/common/cloud_sdk_client/__init__.py +2 -0
  8. wedata/{feature_store → common}/cloud_sdk_client/client.py +33 -3
  9. wedata/{feature_store → common}/cloud_sdk_client/models.py +212 -37
  10. wedata/{feature_store → common}/cloud_sdk_client/utils.py +7 -0
  11. wedata/{feature_store → common}/constants/constants.py +3 -2
  12. wedata/common/constants/engine_types.py +34 -0
  13. wedata/{feature_store → common}/entities/column_info.py +6 -5
  14. wedata/{feature_store → common}/entities/feature_column_info.py +2 -1
  15. wedata/{feature_store → common}/entities/feature_lookup.py +1 -1
  16. wedata/{feature_store → common}/entities/feature_spec.py +9 -9
  17. wedata/{feature_store → common}/entities/feature_table_info.py +1 -1
  18. wedata/{feature_store → common}/entities/function_info.py +2 -1
  19. wedata/{feature_store → common}/entities/on_demand_column_info.py +2 -1
  20. wedata/{feature_store → common}/entities/source_data_column_info.py +3 -1
  21. wedata/{feature_store → common}/entities/training_set.py +6 -6
  22. wedata/common/feast_client/__init__.py +1 -0
  23. wedata/{feature_store → common}/feast_client/feast_client.py +1 -1
  24. wedata/common/log/__init__.py +1 -0
  25. wedata/{feature_store/common → common}/log/logger.py +9 -5
  26. wedata/common/spark_client/__init__.py +1 -0
  27. wedata/{feature_store → common}/spark_client/spark_client.py +6 -7
  28. wedata/{feature_store → common}/utils/common_utils.py +7 -9
  29. wedata/{feature_store → common}/utils/env_utils.py +12 -0
  30. wedata/{feature_store → common}/utils/feature_lookup_utils.py +6 -6
  31. wedata/{feature_store → common}/utils/feature_spec_utils.py +13 -8
  32. wedata/{feature_store → common}/utils/feature_utils.py +5 -5
  33. wedata/{feature_store → common}/utils/on_demand_utils.py +5 -4
  34. wedata/{feature_store → common}/utils/schema_utils.py +1 -1
  35. wedata/{feature_store → common}/utils/signature_utils.py +4 -4
  36. wedata/{feature_store → common}/utils/training_set_utils.py +13 -13
  37. wedata/{feature_store → common}/utils/uc_utils.py +1 -1
  38. wedata/feature_engineering/__init__.py +1 -0
  39. wedata/feature_engineering/client.py +417 -0
  40. wedata/feature_engineering/ml_training_client/ml_training_client.py +569 -0
  41. wedata/feature_engineering/mlflow_model.py +9 -0
  42. wedata/feature_engineering/table_client/table_client.py +548 -0
  43. wedata/feature_store/client.py +11 -15
  44. wedata/feature_store/constants/engine_types.py +8 -30
  45. wedata/feature_store/feature_table_client/feature_table_client.py +73 -105
  46. wedata/feature_store/training_set_client/training_set_client.py +12 -23
  47. wedata/tempo/interpol.py +2 -2
  48. tencent_wedata_feature_engineering_dev-0.1.48.dist-info/RECORD +0 -66
  49. {tencent_wedata_feature_engineering_dev-0.1.48.dist-info → tencent_wedata_feature_engineering_dev-0.2.5.dist-info}/top_level.txt +0 -0
  50. /wedata/{feature_store/cloud_sdk_client → common}/__init__.py +0 -0
  51. /wedata/{feature_store/common/log → common/constants}/__init__.py +0 -0
  52. /wedata/{feature_store/common/protos → common/entities}/__init__.py +0 -0
  53. /wedata/{feature_store → common}/entities/environment_variables.py +0 -0
  54. /wedata/{feature_store → common}/entities/feature.py +0 -0
  55. /wedata/{feature_store → common}/entities/feature_function.py +0 -0
  56. /wedata/{feature_store → common}/entities/feature_spec_constants.py +0 -0
  57. /wedata/{feature_store → common}/entities/feature_table.py +0 -0
  58. /wedata/{feature_store/entities → common/protos}/__init__.py +0 -0
  59. /wedata/{feature_store/common → common}/protos/feature_store_pb2.py +0 -0
  60. /wedata/{feature_store/feast_client → common/utils}/__init__.py +0 -0
  61. /wedata/{feature_store → common}/utils/topological_sort.py +0 -0
  62. /wedata/{feature_store → common}/utils/validation_utils.py +0 -0
  63. /wedata/{feature_store/spark_client → feature_engineering/ml_training_client}/__init__.py +0 -0
  64. /wedata/{feature_store/utils → feature_engineering/table_client}/__init__.py +0 -0
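
The bulk of this release is a package reorganization: shared modules move from wedata.feature_store into a new wedata.common package, and a new wedata.feature_engineering package adds a client, an ML training client, and a table client. A minimal sketch of how downstream imports shift under the new layout (paths are taken from the file list above; whether the old wedata.feature_store paths remain importable as aliases is not shown in this diff):

    # 0.1.48 layout: shared helpers lived under wedata.feature_store
    # from wedata.feature_store.spark_client.spark_client import SparkClient
    # from wedata.feature_store.entities.feature_lookup import FeatureLookup

    # 0.2.5 layout: the same modules are imported from wedata.common
    from wedata.common.spark_client import SparkClient
    from wedata.common.entities.feature_lookup import FeatureLookup
    from wedata.common.entities.feature_function import FeatureFunction
    from wedata.common.utils import common_utils, training_set_utils
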
wedata/feature_engineering/ml_training_client/ml_training_client.py (new file, +569 lines)
@@ -0,0 +1,569 @@
+ import logging
+ import os
+ from types import ModuleType
+ from typing import Any, List, Optional, Union, Dict
+
+ import mlflow
+ from mlflow.models import Model
+ from mlflow.utils.file_utils import TempDir, read_yaml
+ from pyspark.sql import DataFrame
+ from pyspark.sql.functions import struct
+
+ from wedata.common.constants import constants
+ from wedata.common.entities.feature_function import FeatureFunction
+ from wedata.common.entities.feature_lookup import FeatureLookup
+ from wedata.common.entities.feature_spec import FeatureSpec
+ from wedata.common.entities.training_set import TrainingSet
+ from wedata.feature_engineering.mlflow_model import _FeatureEngineeringModelWrapper
+ from wedata.common.spark_client import SparkClient
+ from wedata.common.utils import validation_utils, common_utils, training_set_utils
+ from wedata.common.entities.feature_table import FeatureTable
+
+ from wedata.common.constants.constants import (
+     _NO_RESULT_TYPE_PASSED,
+     _USE_SPARK_NATIVE_JOIN,
+     MODEL_DATA_PATH_ROOT,
+     PREDICTION_COLUMN_NAME,
+     _PREBUILT_ENV_URI
+ )
+
+ from wedata.common.utils import uc_utils
+ from wedata.common.utils.signature_utils import get_mlflow_signature_from_feature_spec, \
+     drop_signature_inputs_and_invalid_params
+
+ _logger = logging.getLogger(__name__)
+
+ FEATURE_SPEC_GRAPH_MAX_COLUMN_INFO = 1000
+
+
+ class MLTrainingClient:
+     def __init__(
+         self,
+         spark_client: SparkClient
+     ):
+         self._spark_client = spark_client
+
+     def create_training_set(
+         self,
+         feature_spec: FeatureSpec,
+         label_names: List[str],
+         df: DataFrame,
+         ft_metadata: training_set_utils._FeatureTableMetadata,
+         kwargs,
+     ):
+         uc_function_infos = training_set_utils.get_uc_function_infos(
+             self._spark_client,
+             {odci.udf_name for odci in feature_spec.on_demand_column_infos},
+         )
+
+         training_set_utils.warn_if_non_photon_for_native_spark(
+             kwargs.get(_USE_SPARK_NATIVE_JOIN, False), self._spark_client
+         )
+         return TrainingSet(
+             feature_spec,
+             df,
+             label_names,
+             ft_metadata.feature_table_metadata_map,
+             ft_metadata.feature_table_data_map,
+             uc_function_infos,
+             kwargs.get(_USE_SPARK_NATIVE_JOIN, False),
+         )
+
+     def create_training_set_from_feature_lookups(
+         self,
+         df: DataFrame,
+         feature_lookups: List[Union[FeatureLookup, FeatureFunction]],
+         label: Union[str, List[str], None],
+         exclude_columns: List[str],
+         **kwargs,
+     ) -> TrainingSet:
+
+         # Split the inputs into feature lookups and feature functions
+         features = feature_lookups
+         feature_lookups = [f for f in features if isinstance(f, FeatureLookup)]
+         feature_functions = [f for f in features if isinstance(f, FeatureFunction)]
+
+         # At most MAX_FEATURE_FUNCTIONS FeatureFunctions are supported
+         if len(feature_functions) > training_set_utils.MAX_FEATURE_FUNCTIONS:
+             raise ValueError(
+                 f"A maximum of {training_set_utils.MAX_FEATURE_FUNCTIONS} FeatureFunctions are supported."
+             )
+
+         # If no label is provided, initialize label_names with an empty list
+         label_names = common_utils.as_list(label, [])
+         del label
+
+         # Validate the DataFrame and labels
+         training_set_utils.verify_df_and_labels(df, label_names, exclude_columns)
+
+         # Fetch feature table metadata
+         ft_metadata = training_set_utils.get_table_metadata(
+             self._spark_client,
+             {fl.table_name for fl in feature_lookups}
+         )
+
+         column_infos = training_set_utils.get_column_infos(
+             feature_lookups,
+             feature_functions,
+             ft_metadata,
+             df_columns=df.columns,
+             label_names=label_names,
+         )
+
+         training_set_utils.validate_column_infos(
+             self._spark_client,
+             ft_metadata,
+             column_infos.source_data_column_infos,
+             column_infos.feature_column_infos,
+             column_infos.on_demand_column_infos,
+             label_names,
+         )
+
+         # Build feature_spec locally for comparison with the feature spec yaml generated by the
+         # FeatureStore backend. This will be removed once the migration is validated.
+         feature_spec = training_set_utils.build_feature_spec(
+             feature_lookups,
+             ft_metadata,
+             column_infos,
+             exclude_columns
+         )
+
+         return self.create_training_set(
+             feature_spec,
+             label_names,
+             df,
+             ft_metadata,
+             kwargs=kwargs,
+         )
+
+
+     def create_feature_spec(
+         self,
+         name: str,
+         features: List[Union[FeatureLookup, FeatureFunction]],
+         sparkClient: SparkClient,
+         exclude_columns: List[str] = [],
+     ) -> FeatureSpec:
+
+         feature_lookups = [f for f in features if isinstance(f, FeatureLookup)]
+         feature_functions = [f for f in features if isinstance(f, FeatureFunction)]
+
+         # Maximum of 100 FeatureFunctions is supported
+         if len(feature_functions) > training_set_utils.MAX_FEATURE_FUNCTIONS:
+             raise ValueError(
+                 f"A maximum of {training_set_utils.MAX_FEATURE_FUNCTIONS} FeatureFunctions are supported."
+             )
+
+         # Get feature table metadata and column infos
+         ft_metadata = training_set_utils.get_table_metadata(
+             self._spark_client,
+             {fl.table_name for fl in feature_lookups}
+         )
+         column_infos = training_set_utils.get_column_infos(
+             feature_lookups,
+             feature_functions,
+             ft_metadata,
+         )
+
+         column_infos = training_set_utils.add_inferred_source_columns(column_infos)
+
+         training_set_utils.validate_column_infos(
+             self._spark_client,
+             ft_metadata,
+             column_infos.source_data_column_infos,
+             column_infos.feature_column_infos,
+             column_infos.on_demand_column_infos,
+         )
+
+         feature_spec = training_set_utils.build_feature_spec(
+             feature_lookups,
+             ft_metadata,
+             column_infos,
+             exclude_columns
+         )
+
+         return feature_spec
+
+
+     def log_model(
+         self,
+         model: Any,
+         artifact_path: str,
+         *,
+         flavor: ModuleType,
+         training_set: Optional[TrainingSet],
+         registered_model_name: Optional[str],
+         model_registry_uri: Optional[str],
+         await_registration_for: int,
+         infer_input_example: bool,
+         **kwargs,
+     ):
+         # Verify that a training_set was provided
+         if training_set is None:
+             raise ValueError(
+                 "'training_set' must be provided."
+             )
+
+         # Get the feature spec and reformat the full table names
+         # training_set.feature_spec is guaranteed to be in 3L format because it comes from FeatureStoreClient.create_training_set
+         feature_spec = uc_utils.get_feature_spec_with_reformat_full_table_names(
+             training_set.feature_spec
+         )
+
+         # Get the label type map
+         label_type_map = training_set._label_data_types
+
+         # Collect all feature column names
+         feature_columns = [
+             feature_column.output_name
+             for feature_column in feature_spec.feature_column_infos
+         ]
+         df_head = training_set.load_df().select(*feature_columns).head()
+
+         # Handle the output schema and params
+         override_output_schema = kwargs.pop("output_schema", None)
+         params = kwargs.pop("params", {})
+         params["result_type"] = params.get("result_type", _NO_RESULT_TYPE_PASSED)
+
+         # Try to derive an MLflow signature
+         try:
+             signature = get_mlflow_signature_from_feature_spec(
+                 feature_spec, label_type_map, override_output_schema, params
+             )
+         except Exception as e:
+             _logger.warning(f"Model could not be logged with a signature: {e}")
+             signature = None
+
+         with TempDir() as tmp_location:
+             # wedata data_path: record table paths instead, iterating over the table names to build an array
+             data_path = os.path.join(tmp_location.path(), "feature_store")
+             os.makedirs(data_path, exist_ok=True)
+
+             # Create the raw MLflow model
+             raw_mlflow_model = Model(
+                 signature=drop_signature_inputs_and_invalid_params(signature)
+             )
+             raw_model_path = os.path.join(data_path, constants.RAW_MODEL_FOLDER)
+
+             # Save the model according to its flavor
+             if flavor.FLAVOR_NAME != mlflow.pyfunc.FLAVOR_NAME:
+                 flavor.save_model(
+                     model, raw_model_path, mlflow_model=raw_mlflow_model, **kwargs
+                 )
+             else:
+                 flavor.save_model(
+                     raw_model_path,
+                     mlflow_model=raw_mlflow_model,
+                     python_model=model,
+                     **kwargs,
+                 )
+
+             # Verify that the model supports the python_function flavor
+             if "python_function" not in raw_mlflow_model.flavors:
+                 raise ValueError(
+                     f"FeatureStoreClient.log_model does not support '{flavor.__name__}' "
+                     f"since it does not have a python_function model flavor."
+                 )
+
+             # Fetch and process the conda environment configuration
+             model_env = raw_mlflow_model.flavors["python_function"][mlflow.pyfunc.ENV]
+             if isinstance(model_env, dict):
+                 # mlflow 2.0 has multiple supported environments
+                 conda_file = model_env[mlflow.pyfunc.EnvType.CONDA]
+             else:
+                 conda_file = model_env
+
+             conda_env = read_yaml(raw_model_path, conda_file)
+             # TODO: the databricks-feature-lookup package is not needed for now; it causes Python environment creation to fail
+             # Check if databricks-feature-lookup version is specified in conda_env
+             lookup_client_version_specified = False
+             for dependency in conda_env.get("dependencies", []):
+                 if isinstance(dependency, dict):
+                     for pip_dep in dependency.get("pip", []):
+                         if pip_dep.startswith(
+                             constants.FEATURE_LOOKUP_CLIENT_PIP_PACKAGE
+                         ):
+                             lookup_client_version_specified = True
+                             break
+             # TODO: the databricks-feature-lookup package is not needed for now; it causes Python environment creation to fail
+             # If databricks-feature-lookup version is not specified, add default version
+             if not lookup_client_version_specified:
+                 # Get the pip package string for the databricks-feature-lookup client
+                 default_wedata_feature_lookup_pip_package = common_utils.pip_depependency_pinned_version(
+                     pip_package_name=constants.FEATURE_LOOKUP_CLIENT_PIP_PACKAGE,
+                     version=constants.FEATURE_LOOKUP_CLIENT_MAJOR_VERSION,
+                 )
+                 common_utils.add_mlflow_pip_depependency(
+                     conda_env, default_wedata_feature_lookup_pip_package
+                 )
+
+             # Try to build an input example
+             input_example = None
+             try:
+                 if df_head is not None and infer_input_example:
+                     input_example = df_head.asDict()
+             except Exception:
+                 pass
+
+             feature_spec.save(data_path)
+
+             print(f'artifact_path:{artifact_path},data_path:{data_path},conda_env:{conda_env},'
+                   f'signature:{signature},input_example:{input_example}')
+
+             mlflow.pyfunc.log_model(
+                 artifact_path=artifact_path,
+                 python_model=_FeatureEngineeringModelWrapper(model),
+                 # data_path=data_path,
+                 artifacts={"feature_store": data_path},
+                 code_path=None,
+                 conda_env=conda_env,
+                 signature=signature,
+                 input_example=input_example,
+                 registered_model_name=registered_model_name
+             )
+
+             # mlflow.pyfunc.log_model(
+             #     artifact_path=artifact_path,
+             #     loader_module=constants.MLFLOW_MODEL_NAME,
+             #     data_path=data_path,
+             #     conda_env=conda_env,
+             #     signature=signature,
+             #     input_example=input_example,
+             # )
+
+             # Register the model
+             # if registered_model_name is not None:
+             #     run_id = mlflow.tracking.fluent.active_run().info.run_id
+             #     if model_registry_uri is not None:
+             #         mlflow.set_registry_uri(model_registry_uri)
+             #
+             #     mlflow.register_model(
+             #         f"runs:/{run_id}/{artifact_path}",
+             #         registered_model_name,
+             #         await_registration_for=await_registration_for,
+             #     )
+             #
+             #     print(f"Model registered successfully: {registered_model_name}")
+
+             # # Verify that the model was registered
+             # from mlflow.tracking import MlflowClient
+             # client = MlflowClient()
+             # model_version = client.get_latest_versions(registered_model_name, stages=["None"])[0]
+             # print(f"Registered model version: {model_version.version}")
+
+     def score_batch(
+         self,
+         model_uri: Optional[str],
+         df: DataFrame,
+         result_type: str,
+         env_manager: Optional[str] = None,
+         local_uri: Optional[str] = None,
+         params: Optional[dict[str, Any]] = None,
+         timestamp_key: Optional[str] = None,
+         **kwargs,
+     ) -> DataFrame:
+         # TODO: ML - to be confirmed whether this is needed
+         # req_context = RequestContext(request_context.SCORE_BATCH, client_name)
+
+         # Validate the input DataFrame
+         validation_utils.check_dataframe_type(df)
+         if (model_uri is None) == (local_uri is None):
+             raise ValueError(
+                 "Either 'model_uri' or 'local_uri' must be provided, but not both."
+             )
+         if df.isStreaming:
+             raise ValueError("Streaming DataFrames are not supported.")
+
+         # The result contains a new "prediction" column holding the predictions; the input data must not already use this name
+         if PREDICTION_COLUMN_NAME in df.columns:
+             raise ValueError(
+                 "FeatureStoreClient.score_batch returns a DataFrame with a new column "
+                 f'"{PREDICTION_COLUMN_NAME}". df already has a column with name '
+                 f'"{PREDICTION_COLUMN_NAME}".'
+             )
+
+         # Check for duplicate column names
+         validation_utils.validate_strings_unique(
+             df.columns,
+             "The provided DataFrame for scoring must have unique column names. Found duplicates {}.",
+         )
+         artifact_path = os.path.join("artifacts", MODEL_DATA_PATH_ROOT)
+         with TempDir() as tmp_location:
+             local_path = (
+                 local_uri
+                 if local_uri
+                 else common_utils.download_model_artifacts(model_uri, tmp_location.path())
+             )
+             model_data_path = os.path.join(local_path, artifact_path)
+
+             # Augment local workspace metastore tables from 2L to 3L,
+             # this will prevent us from erroneously reading data from other catalogs
+             feature_spec = uc_utils.get_feature_spec_with_full_table_names(
+                 FeatureSpec.load(model_data_path)
+             )
+
+             raw_model_path = os.path.join(
+                 model_data_path, constants.RAW_MODEL_FOLDER
+             )
+             print(f"raw_model_path: {raw_model_path}")
+             # Build the predict UDF
+             predict_udf = self._spark_client.get_predict_udf(
+                 raw_model_path,
+                 result_type=result_type,
+                 env_manager=env_manager,
+                 params=params,
+                 prebuilt_env_uri=kwargs.get(_PREBUILT_ENV_URI, None))
+             # TODO (ML-17260) Consider reading the timestamp from the backend instead of feature store artifacts
+             ml_model = Model.load(
+                 os.path.join(local_path, constants.ML_MODEL)
+             )
+
+             # Validate that columns needed for joining feature tables exist and are not duplicates.
+             feature_input_keys = []
+             for fci in feature_spec.feature_column_infos:
+                 feature_input_keys.extend([k for k in fci.lookup_key])
+
+             on_demand_input_names = uc_utils.get_unique_list_order(
+                 [
+                     input_name
+                     for odci in feature_spec.on_demand_column_infos
+                     for input_name in odci.input_bindings.values()
+                 ]
+             )
+             intermediate_inputs = set(feature_input_keys + on_demand_input_names)
+             source_data_names = [
+                 sdci.name for sdci in feature_spec.source_data_column_infos
+             ]
+
+             feature_output_names = [
+                 fci.output_name for fci in feature_spec.feature_column_infos
+             ]
+             on_demand_output_names = [
+                 odci.output_name for odci in feature_spec.on_demand_column_infos
+             ]
+             all_output_names = set(
+                 source_data_names + feature_output_names + on_demand_output_names
+             )
+             required_cols = intermediate_inputs.difference(all_output_names)
+             required_cols.update(source_data_names)
+
+             missing_required_columns = [
+                 col for col in required_cols if col not in df.columns
+             ]
+             if missing_required_columns:
+                 missing_columns_formatted = ", ".join(
+                     [f"'{s}'" for s in missing_required_columns]
+                 )
+                 raise ValueError(
+                     f"DataFrame is missing required columns {missing_columns_formatted}."
+                 )
+
+             table_names = {fci.table_name for fci in feature_spec.feature_column_infos}
+             feature_table_features_map = training_set_utils.get_features_for_tables(
+                 self._spark_client, table_names=table_names
+             )
+             feature_table_metadata_map = (
+                 training_set_utils.get_feature_table_metadata_for_tables(
+                     self._spark_client,
+                     table_names=table_names,
+                 )
+             )
+             feature_table_data_map = training_set_utils.load_feature_data_for_tables(
+                 self._spark_client, table_names=table_names
+             )
+             training_set_utils.validate_feature_column_infos_data(
+                 self._spark_client,
+                 feature_spec.feature_column_infos,
+                 feature_table_features_map,
+                 feature_table_data_map,
+             )
+
+             uc_function_infos = training_set_utils.get_uc_function_infos(
+                 self._spark_client,
+                 {odci.udf_name for odci in feature_spec.on_demand_column_infos},
+             )
+
+             # Required source data and feature lookup keys have been validated to exist in `df`.
+             # No additional validation is required before resolving FeatureLookups and applying FeatureFunctions.
+             training_set_utils.warn_if_non_photon_for_native_spark(
+                 kwargs.get(_USE_SPARK_NATIVE_JOIN, False), self._spark_client
+             )
+
+             augmented_df = TrainingSet(
+                 feature_spec=feature_spec,
+                 df=df,
+                 labels=[],
+                 feature_table_metadata_map=feature_table_metadata_map,
+                 feature_table_data_map=feature_table_data_map,
+                 uc_function_infos=uc_function_infos,
+                 use_spark_native_join=kwargs.get(_USE_SPARK_NATIVE_JOIN, False),
+             )._augment_df()
+             # Only included FeatureSpec columns should be part of UDF inputs for scoring.
+             # Note: extra `df` columns not in FeatureSpec should be preserved.
+
+             udf_input_columns = [
+                 ci.output_name for ci in feature_spec.column_infos if ci.include
+             ]
+             print(f"udf_input_columns:{udf_input_columns}")
+             # Apply predictions.
+             df_with_predictions = augmented_df.withColumn(
+                 PREDICTION_COLUMN_NAME, predict_udf(struct(*udf_input_columns))
+             )
+             # Reorder `df_with_predictions` to include:
+             # 1. Preserved `df` columns, in `df` column order.
+             # 2. Computed model input columns, in `FeatureSpec` column order.
+             # 3. Prediction column.
+             output_column_order = (
+                 df.columns
+                 + [col for col in udf_input_columns if col not in df.columns]
+                 + [PREDICTION_COLUMN_NAME]
+             )
+             return_df = df_with_predictions.select(output_column_order)
+             return return_df
+
+     def _warn_if_tables_mismatched_for_model(
+         self,
+         feature_spec: FeatureSpec,
+         feature_table_metadata_map: Dict[str, FeatureTable],
+         model_creation_timestamp_ms: float,
+     ):
+         """
+         Helper method to warn if feature tables were deleted and recreated after a model was logged.
+         For newer FeatureSpec versions >=3, we can compare the FeatureSpec and current table ids.
+         Otherwise, we compare the model and table creation timestamps.
+         """
+         # 1. Compare feature table ids
+         # Check for feature_spec logged with client versions that support table_infos
+         if len(feature_spec.table_infos) > 0:
+             # When feature_spec.yaml is parsed, FeatureSpec.load will assert
+             # that the listed table names in input_tables match table names in input_columns.
+             # The following code assumes this as invariant and only checks for the table IDs.
+             mismatched_tables = []
+             for table_info in feature_spec.table_infos:
+                 feature_table = feature_table_metadata_map[table_info.table_name]
+                 if feature_table:
+                     mismatched_tables.append(table_info.table_name)
+             if len(mismatched_tables) > 0:
+                 plural = len(mismatched_tables) > 1
+                 _logger.warning(
+                     f"Feature table{'s' if plural else ''} {', '.join(mismatched_tables)} "
+                     f"{'were' if plural else 'was'} deleted and recreated after "
+                     f"the model was trained. Model performance may be affected if the features "
+                     f"used in scoring have drifted from the features used in training."
+                 )
+
+         # 2. Creation timestamps are not available, so this check is skipped
+         # feature_tables_created_after_model = []
+         # for name, metadata in feature_table_metadata_map.items():
+         #     if model_creation_timestamp_ms < metadata.creation_timestamp:
+         #         feature_tables_created_after_model.append(name)
+         #
+         # if len(feature_tables_created_after_model) > 0:
+         #     plural = len(feature_tables_created_after_model) > 1
+         #     message = (
+         #         f"Feature table{'s' if plural else ''} {', '.join(feature_tables_created_after_model)} "
+         #         f"{'were' if plural else 'was'} created after the model was logged. "
+         #         f"Model performance may be affected if the features used in scoring have drifted "
+         #         f"from the features used in training."
+         #     )
+         #     _logger.warning(message)
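
For orientation, here is a hedged usage sketch of the MLTrainingClient added above. The class path, method names, and keyword arguments come from this diff; the SparkClient constructor, the FeatureLookup arguments, and all table/column names are assumptions for illustration only (the package also adds a higher-level client in wedata/feature_engineering/client.py, not shown here):

    import mlflow.sklearn
    from sklearn.linear_model import LogisticRegression

    from wedata.common.spark_client import SparkClient
    from wedata.common.entities.feature_lookup import FeatureLookup
    from wedata.feature_engineering.ml_training_client.ml_training_client import MLTrainingClient

    spark_client = SparkClient(spark)        # assumed constructor; wraps an active SparkSession
    client = MLTrainingClient(spark_client)

    # Join stored features onto a raw DataFrame (table and column names are illustrative)
    training_set = client.create_training_set_from_feature_lookups(
        df=raw_df,
        feature_lookups=[FeatureLookup(table_name="db.user_features", lookup_key="user_id")],
        label="label",
        exclude_columns=[],
    )

    pdf = training_set.load_df().toPandas()
    model = LogisticRegression().fit(pdf.drop(columns=["label"]), pdf["label"])

    # Log the model together with its feature spec so score_batch can rebuild the features
    client.log_model(
        model,
        "model",
        flavor=mlflow.sklearn,
        training_set=training_set,
        registered_model_name=None,
        model_registry_uri=None,
        await_registration_for=0,
        infer_input_example=False,
    )

    # Batch scoring: joins the features, applies the predict UDF, and appends a "prediction" column
    scored_df = client.score_batch(
        model_uri="runs:/<run_id>/model",    # illustrative URI
        df=raw_df,
        result_type="double",
    )
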
wedata/feature_engineering/mlflow_model.py (new file, +9 lines)
@@ -0,0 +1,9 @@
+ import mlflow
+
+ class _FeatureEngineeringModelWrapper(mlflow.pyfunc.PythonModel):
+     def __init__(self, model):
+         self.model = model
+
+     def predict(self, context, model_input):
+         return self.model.predict(model_input)
+
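
_FeatureEngineeringModelWrapper is the object that log_model hands to mlflow.pyfunc.log_model as python_model: when the logged model is later loaded, MLflow calls predict(context, model_input) on the wrapper, which simply delegates to the wrapped model's own predict. A minimal sketch of loading such a model directly (the run URI and input frame are illustrative; loading this way does not re-run the feature joins, which is what score_batch is for):

    import mlflow.pyfunc

    # MLflow reconstructs the _FeatureEngineeringModelWrapper behind the pyfunc interface
    loaded = mlflow.pyfunc.load_model("runs:/<run_id>/model")

    # predict() delegates to the wrapped model, so the input must already contain
    # the model's feature columns (score_batch performs the feature joins for you)
    predictions = loaded.predict(features_pdf)
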