wedata-feature-engineering 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wedata/__init__.py +1 -1
- wedata/feature_store/client.py +113 -41
- wedata/feature_store/constants/constants.py +19 -0
- wedata/feature_store/entities/column_info.py +4 -4
- wedata/feature_store/entities/feature_lookup.py +5 -1
- wedata/feature_store/entities/feature_spec.py +46 -46
- wedata/feature_store/entities/feature_table.py +42 -99
- wedata/feature_store/entities/training_set.py +13 -12
- wedata/feature_store/feature_table_client/feature_table_client.py +86 -31
- wedata/feature_store/spark_client/spark_client.py +30 -56
- wedata/feature_store/training_set_client/training_set_client.py +209 -38
- wedata/feature_store/utils/common_utils.py +213 -3
- wedata/feature_store/utils/feature_lookup_utils.py +6 -6
- wedata/feature_store/utils/feature_spec_utils.py +6 -6
- wedata/feature_store/utils/feature_utils.py +5 -5
- wedata/feature_store/utils/on_demand_utils.py +107 -0
- wedata/feature_store/utils/schema_utils.py +1 -1
- wedata/feature_store/utils/signature_utils.py +205 -0
- wedata/feature_store/utils/training_set_utils.py +18 -19
- wedata/feature_store/utils/uc_utils.py +1 -1
- {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.7.dist-info}/METADATA +1 -1
- wedata_feature_engineering-0.1.7.dist-info/RECORD +43 -0
- feature_store/__init__.py +0 -6
- feature_store/client.py +0 -169
- feature_store/constants/__init__.py +0 -0
- feature_store/constants/constants.py +0 -28
- feature_store/entities/__init__.py +0 -0
- feature_store/entities/column_info.py +0 -117
- feature_store/entities/data_type.py +0 -92
- feature_store/entities/environment_variables.py +0 -55
- feature_store/entities/feature.py +0 -53
- feature_store/entities/feature_column_info.py +0 -64
- feature_store/entities/feature_function.py +0 -55
- feature_store/entities/feature_lookup.py +0 -179
- feature_store/entities/feature_spec.py +0 -454
- feature_store/entities/feature_spec_constants.py +0 -25
- feature_store/entities/feature_table.py +0 -164
- feature_store/entities/feature_table_info.py +0 -40
- feature_store/entities/function_info.py +0 -184
- feature_store/entities/on_demand_column_info.py +0 -44
- feature_store/entities/source_data_column_info.py +0 -21
- feature_store/entities/training_set.py +0 -134
- feature_store/feature_table_client/__init__.py +0 -0
- feature_store/feature_table_client/feature_table_client.py +0 -313
- feature_store/spark_client/__init__.py +0 -0
- feature_store/spark_client/spark_client.py +0 -286
- feature_store/training_set_client/__init__.py +0 -0
- feature_store/training_set_client/training_set_client.py +0 -196
- feature_store/utils/__init__.py +0 -0
- feature_store/utils/common_utils.py +0 -96
- feature_store/utils/feature_lookup_utils.py +0 -570
- feature_store/utils/feature_spec_utils.py +0 -286
- feature_store/utils/feature_utils.py +0 -73
- feature_store/utils/schema_utils.py +0 -117
- feature_store/utils/topological_sort.py +0 -158
- feature_store/utils/training_set_utils.py +0 -580
- feature_store/utils/uc_utils.py +0 -281
- feature_store/utils/utils.py +0 -252
- feature_store/utils/validation_utils.py +0 -55
- wedata/feature_store/utils/utils.py +0 -252
- wedata_feature_engineering-0.1.5.dist-info/RECORD +0 -79
- {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.7.dist-info}/WHEEL +0 -0
- {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.7.dist-info}/top_level.txt +0 -0
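The most visible change between 0.1.5 and 0.1.7 is the removal of the duplicated top-level `feature_store` package: all modules now live only under the `wedata.feature_store` namespace, and every internal import is rewritten accordingly (see the hunks below). A minimal migration sketch for downstream code; the `FeatureStoreClient` location in `client.py` is an assumption based on the file list and the class references in this diff, while the entity paths are taken directly from the hunks:

```python
# Import-path migration sketch (0.1.5 -> 0.1.7).
# FeatureLookup / FeatureFunction paths are shown in the hunks below;
# the client path is assumed from wedata/feature_store/client.py in the file list.

# 0.1.5 and earlier also shipped a top-level package:
#   from feature_store.client import FeatureStoreClient
#   from feature_store.entities.feature_lookup import FeatureLookup

# 0.1.7 keeps only the namespaced layout:
from wedata.feature_store.client import FeatureStoreClient
from wedata.feature_store.entities.feature_lookup import FeatureLookup
from wedata.feature_store.entities.feature_function import FeatureFunction
```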
wedata/feature_store/feature_table_client/feature_table_client.py

@@ -8,15 +8,18 @@ from pyspark.sql.streaming import StreamingQuery
 from pyspark.sql.types import StructType
 import os
 
-from feature_store.constants.constants import APPEND, DEFAULT_WRITE_STREAM_TRIGGER
+from wedata.feature_store.constants.constants import APPEND, DEFAULT_WRITE_STREAM_TRIGGER
+from wedata.feature_store.entities.feature_table import FeatureTable
+from wedata.feature_store.spark_client.spark_client import SparkClient
+from wedata.feature_store.utils import common_utils
 
 
 class FeatureTableClient:
     """特征表操作类"""
 
     def __init__(
-
-
+            self,
+            spark: SparkSession
     ):
         self._spark = spark
 
@@ -46,12 +49,6 @@ class FeatureTableClient:
                 f"DataFrame与schema不匹配。差异字段: {diff_fields if diff_fields else '字段类型不一致'}"
             )
 
-    @staticmethod
-    def _validate_table_name(name: str):
-        """验证特征表命名规范"""
-        if name.count('.') < 2:
-            raise ValueError("特征表名称需符合<catalog>.<schema>.<table>格式")
-
     @staticmethod
     def _validate_key_conflicts(primary_keys: List[str], timestamp_keys: List[str]):
         """校验主键与时间戳键是否冲突"""
@@ -75,7 +72,8 @@ class FeatureTableClient:
             schema: Optional[StructType] = None,
             description: Optional[str] = None,
             tags: Optional[Dict[str, str]] = None
-    ):
+    ) -> FeatureTable:
+
         """
         创建特征表(支持批流数据写入)
 
@@ -85,6 +83,7 @@ class FeatureTableClient:
             df: 初始数据(可选,用于推断schema)
             timestamp_keys: 时间戳键(用于时态特征)
             partition_columns: 分区列(优化存储查询)
+            schema: 表结构定义(可选,当不提供df时必需)
             description: 业务描述
            tags: 业务标签
 
@@ -94,6 +93,7 @@ class FeatureTableClient:
         Raises:
             ValueError: 当schema与数据不匹配时
         """
+
         # 参数标准化
         primary_keys = self._normalize_params(primary_keys)
         timestamp_keys = self._normalize_params(timestamp_keys)
@@ -101,23 +101,25 @@ class FeatureTableClient:
 
         # 元数据校验
         self._validate_schema(df, schema)
-        #self._validate_table_name(name)
         self._validate_key_conflicts(primary_keys, timestamp_keys)
 
-        #
-
+        # 表名校验
+        common_utils.validate_table_name(name)
+
+        # 构建完整表名
+        table_name = common_utils.build_full_table_name(name)
 
         # 检查表是否存在
         try:
             if self._spark.catalog.tableExists(table_name):
                 raise ValueError(
-                    f"
-                    "
-                    "1.
-                    "2.
+                    f"Table '{table_name}' already exists\n"
+                    "Solutions:\n"
+                    "1. Use a different table name\n"
+                    "2. Drop the existing table: spark.sql(f'DROP TABLE {name}')\n"
                 )
         except Exception as e:
-            raise ValueError(f"
+            raise ValueError(f"Error checking table existence: {str(e)}") from e
 
         # 推断表schema
         table_schema = schema or df.schema
@@ -126,14 +128,14 @@ class FeatureTableClient:
         timestamp_keys_ddl = []
         for timestamp_key in timestamp_keys:
             if timestamp_key not in primary_keys:
-                raise ValueError(f"
+                raise ValueError(f"Timestamp key '{timestamp_key}' must be a primary key")
             timestamp_keys_ddl.append(f"`{timestamp_key}` TIMESTAMP")
 
         #从环境变量获取额外标签
         env_tags = {
             "project_id": os.getenv("WEDATA_PROJECT_ID", ""), # wedata项目ID
             "engine_name": os.getenv("WEDATA_NOTEBOOK_ENGINE", ""), # wedata引擎名称
-            "user_uin": os.getenv("
+            "user_uin": os.getenv("QCLOUD_SUBUIN", "") # wedata用户UIN
         }
 
         # 构建表属性(通过TBLPROPERTIES)
@@ -185,7 +187,19 @@ class FeatureTableClient:
             if df is not None:
                 df.write.insertInto(table_name)
         except Exception as e:
-            raise ValueError(f"
+            raise ValueError(f"Failed to create table: {str(e)}") from e
+
+        # 构建并返回FeatureTable对象
+        return FeatureTable(
+            name=name,
+            table_id=table_name,
+            description=description or "",
+            primary_keys=primary_keys,
+            partition_columns=partition_columns or [],
+            features=[field.name for field in table_schema.fields],
+            timestamp_keys=timestamp_keys or [],
+            tags=dict(**tags or {}, **env_tags)
+        )
 
     def write_table(
             self,
@@ -195,6 +209,7 @@ class FeatureTableClient:
             checkpoint_location: Optional[str] = None,
             trigger: Optional[Dict[str, Any]] = DEFAULT_WRITE_STREAM_TRIGGER
     ) -> Optional[StreamingQuery]:
+
         """
         写入特征表数据(支持批处理和流式写入)
 
@@ -215,10 +230,13 @@ class FeatureTableClient:
         # 验证写入模式
         valid_modes = ["append", "overwrite"]
         if mode not in valid_modes:
-            raise ValueError(f"
+            raise ValueError(f"Invalid write mode '{mode}', valid options: {valid_modes}")
+
+        # 表名校验
+        common_utils.validate_table_name(name)
 
-        #
-        table_name =
+        # 构建完整表名
+        table_name = common_utils.build_full_table_name(name)
 
         # 判断是否是流式DataFrame
         is_streaming = df.isStreaming
@@ -227,7 +245,7 @@ class FeatureTableClient:
         if is_streaming:
             # 流式写入
             if not checkpoint_location:
-                raise ValueError("
+                raise ValueError("Streaming write requires checkpoint_location parameter")
 
             writer = df.writeStream \
                 .format("parquet") \
@@ -252,6 +270,7 @@ class FeatureTableClient:
             self,
             name: str
     ) -> DataFrame:
+
         """
         从特征表中读取数据
 
@@ -264,8 +283,12 @@ class FeatureTableClient:
         Raises:
             ValueError: 当表不存在或读取失败时抛出
         """
+
+        # 表名校验
+        common_utils.validate_table_name(name)
+
         # 构建完整表名
-        table_name =
+        table_name = common_utils.build_full_table_name(name)
 
         try:
             # 检查表是否存在
@@ -278,10 +301,8 @@ class FeatureTableClient:
         except Exception as e:
             raise ValueError(f"读取表 '{table_name}' 失败: {str(e)}") from e
 
-    def drop_table(
-
-            name: str
-    ) -> None:
+    def drop_table(self, name: str):
+
         """
         删除特征表(表不存在时抛出异常)
 
@@ -296,8 +317,12 @@ class FeatureTableClient:
             # 基本删除
             drop_table("user_features")
         """
+
+        # 表名校验
+        common_utils.validate_table_name(name)
+
         # 构建完整表名
-        table_name =
+        table_name = common_utils.build_full_table_name(name)
 
         try:
             # 检查表是否存在
@@ -311,3 +336,33 @@ class FeatureTableClient:
             raise # 直接抛出已知的ValueError
         except Exception as e:
             raise RuntimeError(f"删除表 '{table_name}' 失败: {str(e)}") from e
+
+    def get_table(
+            self,
+            name: str,
+            spark_client: SparkClient
+    ) -> FeatureTable:
+
+        """获取特征表元数据信息
+
+        参数:
+            name: 特征表名称
+            spark_client: Spark客户端
+
+        返回:
+            FeatureTable对象
+
+        异常:
+            ValueError: 当表不存在或获取失败时抛出
+        """
+
+        # 表名校验
+        common_utils.validate_table_name(name)
+
+        # 构建完整表名
+        table_name = common_utils.build_full_table_name(name)
+
+        try:
+            return spark_client.get_feature_table(table_name)
+        except Exception as e:
+            raise ValueError(f"获取表'{name}'元数据失败: {str(e)}") from e
wedata/feature_store/spark_client/spark_client.py

@@ -6,73 +6,52 @@ from pyspark.sql.catalog import Column
 from pyspark.sql.functions import when, isnull
 from pyspark.sql.types import StructType, StringType, StructField
 
-from feature_store.entities.feature import Feature
-from feature_store.entities.feature_table import FeatureTable
-from feature_store.entities.function_info import FunctionParameterInfo, FunctionInfo
-from feature_store.utils.common_utils import unsanitize_identifier
-from feature_store.utils.utils import sanitize_multi_level_name
+from wedata.feature_store.entities.feature import Feature
+from wedata.feature_store.entities.feature_table import FeatureTable
+from wedata.feature_store.entities.function_info import FunctionParameterInfo, FunctionInfo
+from wedata.feature_store.utils.common_utils import unsanitize_identifier, sanitize_multi_level_name
 
 
 class SparkClient:
     def __init__(self, spark: SparkSession):
         self._spark = spark
 
-    def createDataFrame(self, data, schema) -> DataFrame:
-        return self._spark.createDataFrame(data, schema)
-
-    def read_table(
-        self, qualified_table_name, as_of_delta_timestamp=None, streaming=False
-    ):
-        """
-        Reads a Delta table, optionally as of some timestamp.
-        """
-        if streaming and as_of_delta_timestamp:
-            raise ValueError(
-                "Internal error: as_of_delta_timestamp cannot be specified when"
-                " streaming=True."
-            )
-
-        base_reader = (
-            # By default, Structured Streaming only handles append operations. Because
-            # we have a notion of primary keys, most offline feature store operations
-            # are not appends. For example, FeatureStoreClient.write_table(mode=MERGE)
-            # will issue a MERGE operation.
-            # In order to propagate the non-append operations to the
-            # readStream, we set ignoreChanges to "true".
-            # For more information,
-            # see https://docs.databricks.com/delta/delta-streaming.html#ignore-updates-and-deletes
-            self._spark.readStream.format("delta").option("ignoreChanges", "true")
-            if streaming
-            else self._spark.read.format("delta")
-        )
-
-        if as_of_delta_timestamp:
-            return base_reader.option("timestampAsOf", as_of_delta_timestamp).table(
-                sanitize_multi_level_name(qualified_table_name)
-            )
-        else:
-            return base_reader.table(sanitize_multi_level_name(qualified_table_name))
 
     def get_current_catalog(self):
         """
-
+        获取当前Spark会话的catalog名称(使用spark.catalog.currentCatalog属性)
+
+        返回:
+            str: 当前catalog名称,如果未设置则返回None
         """
         try:
-
-
-        except Exception as e:
+            return unsanitize_identifier(self._spark.catalog.currentCatalog())
+        except Exception:
             return None
 
     def get_current_database(self):
         """
-
+        获取Spark上下文中当前设置的数据库名称
+
+        返回:
+            str: 当前数据库名称,如果获取失败则返回None
         """
         try:
-
-
-
+            # 使用Spark SQL查询当前数据库
+            df = self._spark.sql("SELECT CURRENT_DATABASE()")
+            # 获取第一行第一列的值并去除特殊字符
+            return unsanitize_identifier(df.first()[0])
+        except Exception:
+            # 捕获所有异常并返回None
             return None
 
+
+
+
+    def createDataFrame(self, data, schema) -> DataFrame:
+        return self._spark.createDataFrame(data, schema)
+
+
     def read_table(self, table_name):
         """读取Spark表数据
 
@@ -134,11 +113,6 @@ class SparkClient:
             ) for row in columns
         ]
 
-    def get_online_stores(self, table_name):
-        return None
-
-
-
     def get_feature_table(self, table_name):
 
         # 获取表元数据
@@ -170,19 +144,19 @@ class SparkClient:
         return FeatureTable(
             name=table_name,
             table_id=table_properties.get("table_id", table_name),
-            description=table.description or table_properties.get("
+            description=table.description or table_properties.get("comment", table_name),
             primary_keys=table_properties.get("primaryKeys", "").split(",") if table_properties.get("primaryKeys") else [],
             partition_columns=table.partitionColumnNames if hasattr(table, 'partitionColumnNames') else [],
             features=features,
             creation_timestamp=None, # Spark表元数据不包含创建时间戳
-            online_stores=
+            online_stores=None,
             notebook_producers=None,
             job_producers=None,
             table_data_sources=None,
             path_data_sources=None,
             custom_data_sources=None,
             timestamp_keys=table_properties.get("timestamp_keys"),
-            tags=table_properties
+            tags=table_properties
         )
 
     def _get_routines_with_parameters(self, full_routine_names: List[str]) -> DataFrame:
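The `SparkClient` rewrite drops the Delta-specific `read_table` and the stub `get_online_stores`, and fills in the previously empty catalog/database helpers, which now return `None` on any failure instead of raising. A small sketch of the new behaviour, assuming an active session:

```python
# Sketch of the reworked catalog/database helpers, assuming an active SparkSession.
from pyspark.sql import SparkSession

from wedata.feature_store.spark_client.spark_client import SparkClient

spark = SparkSession.builder.getOrCreate()
client = SparkClient(spark)

# Both helpers swallow exceptions and return None, so callers must handle that case.
catalog = client.get_current_catalog()      # unsanitized spark.catalog.currentCatalog()
database = client.get_current_database()    # unsanitized "SELECT CURRENT_DATABASE()" result
if database is None:
    raise RuntimeError("Could not resolve the current database for this Spark session")
```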
wedata/feature_store/training_set_client/training_set_client.py

@@ -1,35 +1,28 @@
-import json
 import logging
 import os
-from collections import defaultdict
 from types import ModuleType
-from typing import Any,
+from typing import Any, List, Optional, Set, Union
 
 import mlflow
-import
-from mlflow.
-from mlflow.utils.file_utils import TempDir, YamlSafeDumper, read_yaml
+from mlflow.models import Model
+from mlflow.utils.file_utils import TempDir, read_yaml
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import struct
 
-from feature_store.
-from feature_store.entities.feature_function import FeatureFunction
-from feature_store.entities.feature_lookup import FeatureLookup
-from feature_store.entities.feature_spec import FeatureSpec
-from feature_store.entities.training_set import TrainingSet
-from feature_store.spark_client.spark_client import SparkClient
+from wedata.feature_store.constants import constants
+from wedata.feature_store.entities.feature_function import FeatureFunction
+from wedata.feature_store.entities.feature_lookup import FeatureLookup
+from wedata.feature_store.entities.feature_spec import FeatureSpec
+from wedata.feature_store.entities.training_set import TrainingSet
+from wedata.feature_store.spark_client.spark_client import SparkClient
 
-from feature_store.constants.constants import (
+from wedata.feature_store.constants.constants import (
     _NO_RESULT_TYPE_PASSED,
-
-    _USE_SPARK_NATIVE_JOIN,
-    _WARN,
-    MODEL_DATA_PATH_ROOT,
-    PREDICTION_COLUMN_NAME,
+    _USE_SPARK_NATIVE_JOIN
 )
 
-from feature_store.utils import common_utils, training_set_utils
-from feature_store.utils.
+from wedata.feature_store.utils import common_utils, training_set_utils, uc_utils
+from wedata.feature_store.utils.signature_utils import get_mlflow_signature_from_feature_spec, \
+    drop_signature_inputs_and_invalid_params
 
 _logger = logging.getLogger(__name__)
 
@@ -46,7 +39,6 @@ class TrainingSetClient:
     def create_training_set(
         self,
         feature_spec: FeatureSpec,
-        feature_column_infos: List[FeatureColumnInfo],
         label_names: List[str],
         df: DataFrame,
         ft_metadata: training_set_utils._FeatureTableMetadata,
@@ -57,18 +49,6 @@ class TrainingSetClient:
             {odci.udf_name for odci in feature_spec.on_demand_column_infos},
         )
 
-        # TODO(divyagupta-db): Move validation from _validate_join_feature_data in feature_lookup_utils.py
-        # to a helper function called here and in score_batch.
-
-        # Add consumer of each feature and instrument as final step
-        consumer_feature_table_map = defaultdict(list)
-        for feature in feature_column_infos:
-            consumer_feature_table_map[feature.table_name].append(feature.feature_name)
-        consumed_udf_names = [f.udf_name for f in feature_spec.function_infos]
-
-        # Spark query planning is known to cause spark driver to crash if there are many feature tables to PiT join.
-        # See https://docs.google.com/document/d/1EyA4vvlWikTJMeinsLkxmRAVNlXoF1eqoZElOdqlWyY/edit
-        # So we disable native join by default.
         training_set_utils.warn_if_non_photon_for_native_spark(
             kwargs.get(_USE_SPARK_NATIVE_JOIN, False), self._spark_client
         )
@@ -96,6 +76,12 @@ class TrainingSetClient:
         feature_lookups = [f for f in features if isinstance(f, FeatureLookup)]
         feature_functions = [f for f in features if isinstance(f, FeatureFunction)]
 
+        # 最多支持100个FeatureFunctions
+        if len(feature_functions) > training_set_utils.MAX_FEATURE_FUNCTIONS:
+            raise ValueError(
+                f"A maximum of {training_set_utils.MAX_FEATURE_FUNCTIONS} FeatureFunctions are supported."
+            )
+
         # 如果未提供标签,则用空列表初始化label_names
         label_names = common_utils.as_list(label, [])
         del label
@@ -137,7 +123,6 @@ class TrainingSetClient:
 
         return self.create_training_set(
             feature_spec,
-            column_infos.feature_column_infos,
             label_names,
             df,
             ft_metadata,
@@ -145,9 +130,6 @@ class TrainingSetClient:
         )
 
 
-
-
-
     def create_feature_spec(
         self,
         name: str,
@@ -194,3 +176,192 @@ class TrainingSetClient:
         )
 
         return feature_spec
+
+
+    def log_model(
+        self,
+        model: Any,
+        artifact_path: str,
+        *,
+        flavor: ModuleType,
+        training_set: Optional[TrainingSet],
+        registered_model_name: Optional[str],
+        await_registration_for: int,
+        infer_input_example: bool,
+        **kwargs,
+    ):
+        # Validate only one of the training_set, feature_spec_path arguments is provided.
+        # Retrieve the FeatureSpec, then remove training_set, feature_spec_path from local scope.
+        feature_spec_path = kwargs.pop("feature_spec_path", None)
+        if (training_set is None) == (feature_spec_path is None):
+            raise ValueError(
+                "Either 'training_set' or 'feature_spec_path' must be provided, but not both."
+            )
+        # Retrieve the FeatureSpec and then reformat tables in local metastore to 2L before serialization.
+        # This will make sure the format of the feature spec with local metastore tables is always consistent.
+        if training_set:
+            all_uc_tables = all(
+                [
+                    uc_utils.is_uc_entity(table_info.table_name)
+                    for table_info in training_set.feature_spec.table_infos
+                ]
+            )
+            # training_set.feature_spec is guaranteed to be 3L from FeatureStoreClient.create_training_set.
+            feature_spec = uc_utils.get_feature_spec_with_reformat_full_table_names(
+                training_set.feature_spec
+            )
+            label_type_map = training_set._label_data_types
+
+            labels = training_set._labels
+            df_head = training_set._df.drop(*labels).head()
+        else:
+            # FeatureSpec.load expects the root directory of feature_spec.yaml
+            root_dir, file_name = os.path.split(feature_spec_path)
+            if file_name != FeatureSpec.FEATURE_ARTIFACT_FILE:
+                raise ValueError(
+                    f"'feature_spec_path' must be a path to {FeatureSpec.FEATURE_ARTIFACT_FILE}."
+                )
+            feature_spec = FeatureSpec.load(root_dir)
+
+            # The loaded FeatureSpec is not guaranteed to be 3L.
+            # First call get_feature_spec_with_full_table_names to append the default metastore to 2L names,
+            # as get_feature_spec_with_reformat_full_table_names expects full 3L table names and throws otherwise.
+            # TODO (ML-26593): Consolidate this into a single function that allows either 2L/3L names.
+            feature_spec_with_full_table_names = (
+                uc_utils.get_feature_spec_with_full_table_names(feature_spec)
+            )
+            all_uc_tables = all(
+                [
+                    uc_utils.is_uc_entity(table_info.table_name)
+                    for table_info in feature_spec_with_full_table_names.table_infos
+                ]
+            )
+            feature_spec = uc_utils.get_feature_spec_with_reformat_full_table_names(
+                feature_spec_with_full_table_names
+            )
+            label_type_map = None
+            df_head = None
+        del training_set, feature_spec_path
+
+        override_output_schema = kwargs.pop("output_schema", None)
+        params = kwargs.pop("params", {})
+        params["result_type"] = params.get("result_type", _NO_RESULT_TYPE_PASSED)
+        # Signatures will ony be supported for UC-table-only models to
+        # mitigate new online scoring behavior from being a breaking regression for older
+        # models.
+        # See https://docs.google.com/document/d/1L5tLY-kRreRefDfuAM3crXvYlirkcPuUUppU8uIMVM0/edit#
+        try:
+            if all_uc_tables:
+                signature = get_mlflow_signature_from_feature_spec(
+                    feature_spec, label_type_map, override_output_schema, params
+                )
+            else:
+                _logger.warning(
+                    "Model could not be logged with a signature because the training set uses feature tables in "
+                    "Hive Metastore. Migrate the feature tables to Unity Catalog for model to be logged "
+                    "with a signature. "
+                    "See https://docs.databricks.com/en/machine-learning/feature-store/uc/upgrade-feature-table-to-uc.html for more details."
+                )
+                signature = None
+        except Exception as e:
+            _logger.warning(f"Model could not be logged with a signature: {e}")
+            signature = None
+
+        with TempDir() as tmp_location:
+            data_path = os.path.join(tmp_location.path(), "feature_store")
+            raw_mlflow_model = Model(
+                signature=drop_signature_inputs_and_invalid_params(signature)
+            )
+            raw_model_path = os.path.join(
+                data_path, constants.RAW_MODEL_FOLDER
+            )
+            if flavor.FLAVOR_NAME != mlflow.pyfunc.FLAVOR_NAME:
+                flavor.save_model(
+                    model, raw_model_path, mlflow_model=raw_mlflow_model, **kwargs
+                )
+            else:
+                flavor.save_model(
+                    raw_model_path,
+                    mlflow_model=raw_mlflow_model,
+                    python_model=model,
+                    **kwargs,
+                )
+            if not "python_function" in raw_mlflow_model.flavors:
+                raise ValueError(
+                    f"FeatureStoreClient.log_model does not support '{flavor.__name__}' "
+                    f"since it does not have a python_function model flavor."
+                )
+
+            # Re-use the conda environment from the raw model for the packaged model. Later, we may
+            # add an additional requirement for the Feature Store library. At the moment, however,
+            # the databricks-feature-store package is not available via conda or pip.
+            model_env = raw_mlflow_model.flavors["python_function"][mlflow.pyfunc.ENV]
+            if isinstance(model_env, dict):
+                # mlflow 2.0 has multiple supported environments
+                conda_file = model_env[mlflow.pyfunc.EnvType.CONDA]
+            else:
+                conda_file = model_env
+
+            conda_env = read_yaml(raw_model_path, conda_file)
+
+            # Check if databricks-feature-lookup version is specified in conda_env
+            lookup_client_version_specified = False
+            for dependency in conda_env.get("dependencies", []):
+                if isinstance(dependency, dict):
+                    for pip_dep in dependency.get("pip", []):
+                        if pip_dep.startswith(
+                            constants.FEATURE_LOOKUP_CLIENT_PIP_PACKAGE
+                        ):
+                            lookup_client_version_specified = True
+                            break
+
+            # If databricks-feature-lookup version is not specified, add default version
+            if not lookup_client_version_specified:
+                # Get the pip package string for the databricks-feature-lookup client
+                default_databricks_feature_lookup_pip_package = common_utils.pip_depependency_pinned_major_version(
+                    pip_package_name=constants.FEATURE_LOOKUP_CLIENT_PIP_PACKAGE,
+                    major_version=constants.FEATURE_LOOKUP_CLIENT_MAJOR_VERSION,
+                )
+                common_utils.add_mlflow_pip_depependency(
+                    conda_env, default_databricks_feature_lookup_pip_package
+                )
+
+            try:
+                if df_head is not None and infer_input_example:
+                    input_example = df_head.asDict()
+                else:
+                    input_example = None
+            except Exception:
+                input_example = None
+
+            # todo:
+            #feature_spec.save(data_path)
+
+            # Log the packaged model. If no run is active, this call will create an active run.
+            mlflow.pyfunc.log_model(
+                artifact_path=artifact_path,
+                loader_module=constants.MLFLOW_MODEL_NAME,
+                data_path=data_path,
+                code_path=None,
+                conda_env=conda_env,
+                signature=signature,
+                input_example=input_example,
+            )
+        if registered_model_name is not None:
+            # The call to mlflow.pyfunc.log_model will create an active run, so it is safe to
+            # obtain the run_id for the active run.
+            run_id = mlflow.tracking.fluent.active_run().info.run_id
+
+            # If the user provided an explicit model_registry_uri when constructing the FeatureStoreClient,
+            # we respect this by setting the registry URI prior to reading the model from Model
+            # Registry.
+            # todo:
+            # if self._model_registry_uri:
+            #     # This command will override any previously set registry_uri.
+            #     mlflow.set_registry_uri(self._model_registry_uri)
+
+            mlflow.register_model(
+                "runs:/%s/%s" % (run_id, artifact_path),
+                registered_model_name,
+                await_registration_for=await_registration_for,
+            )