wedata-feature-engineering 0.1.8__tar.gz → 0.1.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/PKG-INFO +1 -1
- wedata-feature-engineering-0.1.13/tests/test_feature_store.py +298 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/__init__.py +1 -1
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/client.py +4 -1
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/feature_table_client/feature_table_client.py +10 -12
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/spark_client/spark_client.py +43 -38
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/training_set_client/training_set_client.py +6 -7
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata_feature_engineering.egg-info/PKG-INFO +1 -1
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata_feature_engineering.egg-info/SOURCES.txt +1 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/README.md +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/setup.cfg +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/setup.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/__init__.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/constants/__init__.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/constants/constants.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/__init__.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/column_info.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/data_type.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/environment_variables.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/feature.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/feature_column_info.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/feature_function.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/feature_lookup.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/feature_spec.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/feature_spec_constants.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/feature_table.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/feature_table_info.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/function_info.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/on_demand_column_info.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/source_data_column_info.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/entities/training_set.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/feature_table_client/__init__.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/spark_client/__init__.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/training_set_client/__init__.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/__init__.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/common_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/feature_lookup_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/feature_spec_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/feature_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/on_demand_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/schema_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/signature_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/topological_sort.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/training_set_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/uc_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata/feature_store/utils/validation_utils.py +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata_feature_engineering.egg-info/dependency_links.txt +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata_feature_engineering.egg-info/requires.txt +0 -0
- {wedata-feature-engineering-0.1.8 → wedata-feature-engineering-0.1.13}/wedata_feature_engineering.egg-info/top_level.txt +0 -0
wedata-feature-engineering-0.1.13/tests/test_feature_store.py (new file, 298 lines):

```python
# This is a test script for FeatureStoreClient
from datetime import date

import pandas as pd
from pyspark.sql import SparkSession, DataFrame
from sklearn.ensemble import RandomForestClassifier

import mlflow.sklearn

from wedata.feature_store.client import FeatureStoreClient
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, IntegerType, DoubleType, DateType

from wedata.feature_store.entities.feature_lookup import FeatureLookup
from wedata.feature_store.entities.training_set import TrainingSet


# Create a FeatureStoreClient instance
def create_client() -> FeatureStoreClient:
    spark = SparkSession.builder \
        .appName("FeatureStoreDemo") \
        .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
        .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
        .config("spark.jars.packages", "io.delta:delta-core_2.12:2.4.0") \
        .enableHiveSupport() \
        .getOrCreate()

    # Create the FeatureStoreClient instance
    client = FeatureStoreClient(spark)
    return client

# Create the feature tables
def create_table(client: FeatureStoreClient):
    user_data = [
        (1001, 25, "F", 120.5, date(2020, 5, 15)),  # user_id, age, gender, avg_purchase, member_since
        (1002, 30, "M", 200.0, date(2019, 3, 10)),
        (1003, 35, "F", 180.3, date(2021, 1, 20))
    ]

    # Define the schema
    user_schema = StructType([
        StructField("user_id", IntegerType(), False, metadata={"comment": "Unique user ID"}),
        StructField("age", IntegerType(), True, metadata={"comment": "User age"}),
        StructField("gender", StringType(), True, metadata={"comment": "User gender (F = female, M = male)"}),
        StructField("avg_purchase", DoubleType(), True, metadata={"comment": "Average purchase amount"}),
        StructField("member_since", DateType(), True, metadata={"comment": "Registration date"})
    ])

    # Create the DataFrame
    user_df = client.spark.createDataFrame(user_data, user_schema)

    display(user_df)

    client.create_table(
        name="user_features",                # table name
        primary_keys=["user_id"],            # primary key
        df=user_df,                          # data
        partition_columns=["member_since"],  # partitioned by registration date
        description="Basic user attributes and purchase-behavior features",  # description
        tags={                               # business tags
            "create_by": "tencent",
            "sensitivity": "internal"
        }
    )

    # Sample product data
    product_data = [
        (5001, "electronics", 599.0, 0.85, date(2024, 1, 1)),
        (5002, "clothing", 199.0, 0.92, date(2023, 11, 15)),
        (5003, "home", 299.0, 0.78, date(2024, 2, 20))
    ]

    # Define the schema
    product_schema = StructType([
        StructField("product_id", IntegerType(), False),
        StructField("category", StringType(), True),
        StructField("price", DoubleType(), True),
        StructField("popularity", DoubleType(), True),
        StructField("release_date", DateType(), True)
    ])

    # Create the DataFrame
    product_df = client.spark.createDataFrame(product_data, product_schema)

    display(product_df)

    # Create the product feature table
    client.create_table(
        name="product_features",
        primary_keys=["product_id"],
        df=product_df,
        description="Basic product attributes and popularity",
        tags={  # business tags
            "feature_table": "true",
            "sensitivity": "internal"
        }
    )


# Append data to the feature tables
def append_data(client: FeatureStoreClient):
    user_data = [
        (1004, 45, "F", 120.5, date(2020, 5, 15)),
        (1005, 55, "M", 200.0, date(2019, 3, 10)),
        (1006, 65, "F", 180.3, date(2021, 1, 20))
    ]

    user_schema = StructType([
        StructField("user_id", IntegerType(), False, metadata={"comment": "Unique user ID"}),
        StructField("age", IntegerType(), True, metadata={"comment": "User age"}),
        StructField("gender", StringType(), True, metadata={"comment": "User gender (F = female, M = male)"}),
        StructField("avg_purchase", DoubleType(), True, metadata={"comment": "Average purchase amount"}),
        StructField("member_since", DateType(), True, metadata={"comment": "Registration date"})
    ])

    user_df = client.spark.createDataFrame(user_data, user_schema)

    display(user_df)

    client.write_table(
        name="user_features",
        df=user_df,
        mode="append"
    )

    product_data = [
        (5007, "food", 599.0, 0.85, date(2024, 1, 1)),
        (5008, "toys", 199.0, 0.92, date(2023, 11, 15)),
        (5009, "computers", 299.0, 0.78, date(2024, 2, 20))
    ]

    product_schema = StructType([
        StructField("product_id", IntegerType(), False, metadata={"comment": "Unique product ID"}),
        StructField("category", StringType(), True, metadata={"comment": "Product category"}),
        StructField("price", DoubleType(), True, metadata={"comment": "Product price (CNY)"}),
        StructField("popularity", DoubleType(), True, metadata={"comment": "Product popularity (0-1)"}),
        StructField("release_date", DateType(), True, metadata={"comment": "Product release date"})
    ])

    product_df = client.spark.createDataFrame(product_data, product_schema)

    display(product_df)

    client.write_table(
        name="product_features",
        df=product_df,
        mode="append"
    )

# Read data from the feature tables
def read_table(client: FeatureStoreClient):

    # Read the user feature table
    user_df = client.read_table("user_features")
    display(user_df)

    # Read the product feature table
    product_df = client.read_table("product_features")
    display(product_df)

# Get feature table metadata
def get_table(client: FeatureStoreClient):
    feature_table_user = client.get_table(name="user_features")
    print(feature_table_user)


# Create a training set
def create_training_set(client: FeatureStoreClient) -> TrainingSet:

    # Sample order data
    order_data = [
        (9001, 1001, 5001, date(2025, 3, 1), 1, 0),
        (9002, 1002, 5002, date(2025, 3, 2), 2, 1),
        (9003, 1003, 5003, date(2025, 3, 3), 1, 0)
    ]

    # Define the schema
    order_schema = StructType([
        StructField("order_id", IntegerType(), False, metadata={"comment": "Unique order ID"}),
        StructField("user_id", IntegerType(), True, metadata={"comment": "User ID"}),
        StructField("product_id", IntegerType(), True, metadata={"comment": "Product ID"}),
        StructField("order_date", DateType(), True, metadata={"comment": "Order date"}),
        StructField("quantity", IntegerType(), True, metadata={"comment": "Purchase quantity"}),
        StructField("is_returned", IntegerType(), True, metadata={"comment": "Return flag (0 = not returned, 1 = returned)"})
    ])

    # Create the DataFrame
    order_df = client.spark.createDataFrame(order_data, order_schema)

    # Inspect the order data
    display(order_df)

    # Define the user feature lookup
    user_feature_lookup = FeatureLookup(
        table_name="user_features",
        feature_names=["age", "gender", "avg_purchase"],  # feature columns to select
        lookup_key="user_id"  # join key
    )

    # Define the product feature lookup
    product_feature_lookup = FeatureLookup(
        table_name="product_features",
        feature_names=["category", "price", "popularity"],  # feature columns to select
        lookup_key="product_id"  # join key
    )

    # Create the training set
    training_set = client.create_training_set(
        df=order_df,  # base data
        feature_lookups=[user_feature_lookup, product_feature_lookup],  # feature lookup configuration
        label="is_returned",  # label column
        exclude_columns=["order_id"]  # drop columns that are not needed
    )

    # Get the final training DataFrame
    training_df = training_set.load_df()

    # Inspect the training data
    display(training_df)

    return training_set


# Inspect the data in a DataFrame
def display(df):

    """
    Print the structure and contents of a DataFrame

    Args:
        df (DataFrame): the Spark DataFrame to print
        num_rows (int): number of rows to show, default 20
        truncate (bool): whether to truncate overly long columns, default True
    """
    # Print the schema
    print("=== Schema ===")
    df.printSchema()

    # Print the data
    print("\n=== Sample rows ===")
    df.show(20, True)

    # Print the row count
    print(f"\nTotal rows: {df.count()}")


def log_model(client: FeatureStoreClient,
              training_set: TrainingSet
              ):

    # Initialize the model
    model = RandomForestClassifier(
        n_estimators=100,  # more trees for a more stable model
        random_state=42    # fixed seed for reproducibility
    )

    # Load the training data and convert it to pandas
    train_pd = training_set.load_df().toPandas()

    # Feature engineering
    # 1. Encode the categorical features
    train_pd['gender'] = train_pd['gender'].map({'F': 0, 'M': 1})
    train_pd = pd.get_dummies(train_pd, columns=['category'])

    # 2. Convert the date feature to days elapsed
    current_date = pd.to_datetime('2025-04-19')  # reference "current" date
    train_pd['order_days'] = (current_date - pd.to_datetime(train_pd['order_date'])).dt.days
    train_pd = train_pd.drop('order_date', axis=1)

    # 3. Create an interaction feature (price * quantity)
    train_pd['total_amount'] = train_pd['price'] * train_pd['quantity']

    # Split features and label
    X = train_pd.drop("is_returned", axis=1)
    y = train_pd["is_returned"]

    # Train the model
    model.fit(X, y)
    # Log the model to MLflow
    with mlflow.start_run():
        client.log_model(
            model=model,
            artifact_path="return_prediction_model",  # path name that matches the business scenario
            flavor=mlflow.sklearn,
            training_set=training_set,
            registered_model_name="product_return_prediction_model"  # more precise model name
        )


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    client = create_client()
    #create_table(client)
    #append_data(client)
    #read_table(client)
    #get_table(client)
    training_set = create_training_set(client)
    #log_model(client, training_set)
```
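The script above treats `create_training_set` as a black box. Conceptually, each `FeatureLookup` enriches the base DataFrame by joining the named feature table on the lookup key and keeping only the requested feature columns. A minimal plain-PySpark sketch of what the first lookup amounts to (an illustration of the concept, not the library's actual implementation):

```python
from pyspark.sql import DataFrame

def apply_user_lookup(order_df: DataFrame, user_features: DataFrame) -> DataFrame:
    # Rough equivalent of FeatureLookup(table_name="user_features",
    # feature_names=["age", "gender", "avg_purchase"], lookup_key="user_id"):
    # left-join the feature table onto the base data on the lookup key.
    return order_df.join(
        user_features.select("user_id", "age", "gender", "avg_purchase"),
        on="user_id",
        how="left",
    )
```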
wedata/feature_store/client.py:

```diff
@@ -178,7 +178,7 @@ class FeatureStoreClient:
             raise ValueError("FeatureLookup must specify a table_name")
         # First validate that the table name format is legal
         common_utils.validate_table_name(feature.table_name)
-        #
+        # Then build the full table name and assign it back to the FeatureLookup object
         feature.table_name = common_utils.build_full_table_name(feature.table_name)

         features = feature_lookups
@@ -202,6 +202,7 @@ class FeatureStoreClient:
         flavor: ModuleType,
         training_set: Optional[TrainingSet] = None,
         registered_model_name: Optional[str] = None,
+        model_registry_uri: Optional[str] = None,
         await_registration_for: int = mlflow.tracking._model_registry.DEFAULT_AWAIT_MAX_SLEEP_SECONDS,
         infer_input_example: bool = False,
         **kwargs,
@@ -218,6 +219,7 @@ class FeatureStoreClient:
             flavor: MLflow model flavor module (e.g. mlflow.sklearn)
             training_set: the TrainingSet object used to train the model (optional)
             registered_model_name: name under which to register the model (optional)
+            model_registry_uri: address of the model registry (optional)
             await_registration_for: seconds to wait for model registration to complete (default 300)
             infer_input_example: whether to log an input example automatically (default False)

@@ -231,6 +233,7 @@ class FeatureStoreClient:
             flavor=flavor,
             training_set=training_set,
             registered_model_name=registered_model_name,
+            model_registry_uri=model_registry_uri,
             await_registration_for=await_registration_for,
             infer_input_example=infer_input_example,
             **kwargs
```
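The four hunks above thread a new optional `model_registry_uri` parameter through `FeatureStoreClient.log_model` down to the training-set client. A hedged usage sketch based on the signature shown in the diff (the registry URI is a placeholder; `client`, `model`, and `training_set` stand in for the objects from the test script):

```python
import mlflow.sklearn
from wedata.feature_store.client import FeatureStoreClient
from wedata.feature_store.entities.training_set import TrainingSet

def log_to_custom_registry(client: FeatureStoreClient, model,
                           training_set: TrainingSet) -> None:
    # Sketch: register the model against an explicit registry instead of the
    # globally configured one. The URI below is a placeholder.
    client.log_model(
        model=model,
        artifact_path="return_prediction_model",
        flavor=mlflow.sklearn,
        training_set=training_set,
        registered_model_name="product_return_prediction_model",
        model_registry_uri="http://mlflow.example.internal:5000",
    )
```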
wedata/feature_store/feature_table_client/feature_table_client.py:

```diff
@@ -113,7 +113,7 @@ class FeatureTableClient:
         try:
             if self._spark.catalog.tableExists(table_name):
                 raise ValueError(
-                    f"Table '{
+                    f"Table '{name}' already exists\n"
                     "Solutions:\n"
                     "1. Use a different table name\n"
                     "2. Drop the existing table: spark.sql(f'DROP TABLE {name}')\n"
@@ -125,11 +125,6 @@ class FeatureTableClient:
         table_schema = schema or df.schema

         # Build timestamp-key attributes
-        timestamp_keys_ddl = []
-        for timestamp_key in timestamp_keys:
-            if timestamp_key not in primary_keys:
-                raise ValueError(f"Timestamp key '{timestamp_key}' must be a primary key")
-            timestamp_keys_ddl.append(f"`{timestamp_key}` TIMESTAMP")

         # Get extra tags from environment variables
         env_tags = {
@@ -142,6 +137,7 @@ class FeatureTableClient:
         tbl_properties = {
             "feature_table": "TRUE",
             "primaryKeys": ",".join(primary_keys),
+            "timestampKeys": ",".join(timestamp_keys) if timestamp_keys else "",
             "comment": description or "",
             **{f"{k}": v for k, v in (tags or {}).items()},
             **{f"feature_{k}": v for k, v in (env_tags or {}).items()}
```
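Note the behavioral shift in the two hunks above: 0.1.8 validated each timestamp key against the primary keys and generated `TIMESTAMP` column DDL for them, while 0.1.13 drops that validation and instead records the keys as a comma-joined `timestampKeys` table property. If that reading is right, the keys can be recovered later with standard Spark SQL; a sketch, assuming only the property name taken from the diff:

```python
from pyspark.sql import SparkSession

def read_timestamp_keys(spark: SparkSession, table_name: str) -> list:
    # Fetch all table properties and pick out the comma-joined timestampKeys
    # value written by create_table.
    rows = spark.sql(f"SHOW TBLPROPERTIES {table_name}").collect()
    props = {row["key"]: row["value"] for row in rows}
    return [k for k in props.get("timestampKeys", "").split(",") if k]
```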
wedata/feature_store/feature_table_client/feature_table_client.py (continued):

```diff
@@ -171,7 +167,7 @@ class FeatureTableClient:
             CREATE TABLE {table_name} (
                 {', '.join(columns_ddl)}
             )
-            USING
+            USING iceberg
             {partition_expr}
             TBLPROPERTIES (
                 {', '.join(f"'{k}'='{self._escape_sql_value(v)}'" for k, v in tbl_properties.items())}
@@ -293,13 +289,13 @@ class FeatureTableClient:
         try:
             # Check whether the table exists
             if not self._spark.catalog.tableExists(table_name):
-                raise ValueError(f"
+                raise ValueError(f"Table '{name}' does not exist")

             # Read the table data
             return self._spark.read.table(table_name)

         except Exception as e:
-            raise ValueError(f"
+            raise ValueError(f"Failed to read table '{name}': {str(e)}") from e


     def drop_table(self, name: str):
@@ -327,15 +323,17 @@ class FeatureTableClient:
         try:
             # Check whether the table exists
             if not self._spark.catalog.tableExists(table_name):
-
+                print(f"Table '{name}' does not exist")
+                return

             # Execute the drop
             self._spark.sql(f"DROP TABLE {table_name}")
+            print(f"Table '{name}' dropped")

         except ValueError as e:
             raise  # re-raise known ValueErrors directly
         except Exception as e:
-            raise RuntimeError(f"
+            raise RuntimeError(f"Failed to delete table '{name}': {str(e)}") from e


     def get_table(
         self,
@@ -365,4 +363,4 @@ class FeatureTableClient:
         try:
             return spark_client.get_feature_table(table_name)
         except Exception as e:
-            raise ValueError(f"
+            raise ValueError(f"Failed to get metadata for table '{name}': {str(e)}") from e
```
wedata/feature_store/spark_client/spark_client.py:

```diff
@@ -9,13 +9,39 @@ from pyspark.sql.types import StructType, StringType, StructField
 from wedata.feature_store.entities.feature import Feature
 from wedata.feature_store.entities.feature_table import FeatureTable
 from wedata.feature_store.entities.function_info import FunctionParameterInfo, FunctionInfo
-from wedata.feature_store.utils.common_utils import unsanitize_identifier
+from wedata.feature_store.utils.common_utils import unsanitize_identifier


 class SparkClient:
     def __init__(self, spark: SparkSession):
         self._spark = spark

+    def _parse_table_name(self, table_name):
+        """Parse a table name and return its table-name part
+
+        Args:
+            table_name: full table name; supported formats: catalog.schema.table, schema.table, or table
+
+        Returns:
+            str: the parsed table-name part
+        """
+        if not isinstance(table_name, str):
+            raise ValueError("Table name must be string type")
+
+        table_name = table_name.strip()
+        if not table_name:
+            raise ValueError("Table name cannot be empty")
+
+        parts = table_name.split('.')
+        if len(parts) == 3:
+            # Three-part name (catalog.schema.table): use only the table part
+            return parts[2]
+        elif len(parts) == 2:
+            # Two-part name (schema.table): use only the table part
+            return parts[1]
+        else:
+            # Bare table name: use as-is
+            return table_name

     def get_current_catalog(self):
         """
```
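The new `_parse_table_name` helper centralizes name handling that 0.1.8 duplicated across methods (and left half-broken: the old branches in the hunks below reference `parts` without ever assigning it). Note that it always reduces to the bare table name, so lookups resolve against the session's current database regardless of any catalog or schema qualifier. A condensed standalone re-implementation, for illustration only:

```python
def parse_table_name(table_name: str) -> str:
    # Same behavior as SparkClient._parse_table_name in the diff above:
    # keep only the last component of a dotted two- or three-part name.
    if not isinstance(table_name, str):
        raise ValueError("Table name must be string type")
    table_name = table_name.strip()
    if not table_name:
        raise ValueError("Table name cannot be empty")
    parts = table_name.split(".")
    return parts[-1] if len(parts) in (2, 3) else table_name

assert parse_table_name("prod.fs.user_features") == "user_features"
assert parse_table_name("fs.user_features") == "user_features"
assert parse_table_name("user_features") == "user_features"
```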
wedata/feature_store/spark_client/spark_client.py (continued):

```diff
@@ -66,19 +92,13 @@ class SparkClient:
         """
         try:
             # Parse the table name
-
-            if len(parts) == 3:
-                catalog, schema, table = parts
-            elif len(parts) == 2:
-                schema, table = parts
-            else:
-                table = table_name
+            schema_table_name = self._parse_table_name(table_name)

             # Verify that the table exists
-            if not self._spark.catalog.tableExists(
+            if not self._spark.catalog.tableExists(schema_table_name):
                 raise ValueError(f"Table does not exist: {table_name}")

-            return self._spark.table(
+            return self._spark.table(schema_table_name)

         except Exception as e:
             raise ValueError(f"Failed to read table {table_name}: {str(e)}")
@@ -86,23 +106,10 @@ class SparkClient:

     def get_features(self, table_name):
         # Parse the table name
-
-        if len(parts) == 3:
-            # Three-part name (catalog.schema.table): use the schema.table form
-            _, schema, table = parts
-            full_table_name = f"{schema}.{table}"
-        elif len(parts) == 2:
-            # Two-part name (schema.table): use directly
-            full_table_name = table_name
-        else:
-            # Bare table name: use the current database
-            current_db = self.get_current_database()
-            if not current_db:
-                raise ValueError("Unable to determine the current database")
-            full_table_name = f"{current_db}.{table_name}"
+        schema_table_name = self._parse_table_name(table_name)

         # Query column information using the dbName.tableName format
-        columns = self._spark.catalog.listColumns(tableName=
+        columns = self._spark.catalog.listColumns(tableName=schema_table_name)
         return [
             Feature(
                 feature_table=table_name,
@@ -114,28 +121,26 @@ class SparkClient:
         ]

     def get_feature_table(self, table_name):
+        # Parse the table name
+        schema_table_name = self._parse_table_name(table_name)

         # Get the table metadata
-        table = self._spark.catalog.getTable(
+        table = self._spark.catalog.getTable(schema_table_name)

-        parts = table_name.split('.')
-        if len(parts) == 3:
-            # Three-part name (catalog.schema.table): use only the table part
-            table_to_describe = parts[2]
-        elif len(parts) == 2:
-            # Two-part name (schema.table): use only the table part
-            table_to_describe = parts[1]
-        else:
-            # Bare table name: use as-is
-            table_to_describe = table_name
         # Get the detailed table information
-        table_details = self._spark.sql(f"DESCRIBE TABLE EXTENDED {
+        table_details = self._spark.sql(f"DESCRIBE TABLE EXTENDED {schema_table_name}").collect()

         table_properties = {}
         for row in table_details:
             if row.col_name == "Table Properties":
                 props = row.data_type[1:-1].split(", ")
-                table_properties =
+                table_properties = {}
+                for p in props:
+                    if "=" in p:
+                        parts = p.split("=", 1)
+                        key = parts[0].strip()
+                        value = parts[1].strip() if len(parts) > 1 else ""
+                        table_properties[key] = value

         # Get the feature column information
         features = self.get_features(table_name)
```
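The last hunk completes a statement that 0.1.8 left dangling (`table_properties =`) by parsing the `Table Properties` row that `DESCRIBE TABLE EXTENDED` renders as `[k1=v1, k2=v2, ...]`. Splitting entries on `", "` and each entry on the first `=` recovers the pairs, though a value that itself contains `", "` would still be split incorrectly. The same logic restated as a standalone function, with an illustrative property string:

```python
def parse_table_properties(rendered: str) -> dict:
    # Strip the surrounding brackets, split entries on ", ", then split each
    # entry on the first "=" — mirroring the loop added in the diff above.
    props = {}
    for entry in rendered[1:-1].split(", "):
        if "=" in entry:
            key, _, value = entry.partition("=")
            props[key.strip()] = value.strip()
    return props

assert parse_table_properties(
    "[feature_table=TRUE, primaryKeys=user_id, timestampKeys=]"
) == {"feature_table": "TRUE", "primaryKeys": "user_id", "timestampKeys": ""}
```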
wedata/feature_store/training_set_client/training_set_client.py:

```diff
@@ -186,6 +186,7 @@ class TrainingSetClient:
         flavor: ModuleType,
         training_set: Optional[TrainingSet],
         registered_model_name: Optional[str],
+        model_registry_uri: Optional[str],
         await_registration_for: int,
         infer_input_example: bool,
         **kwargs,
@@ -334,8 +335,7 @@ class TrainingSetClient:
         except Exception:
             input_example = None

-
-        #feature_spec.save(data_path)
+        feature_spec.save(data_path)

         # Log the packaged model. If no run is active, this call will create an active run.
         mlflow.pyfunc.log_model(
@@ -355,13 +355,12 @@ class TrainingSetClient:
         # If the user provided an explicit model_registry_uri when constructing the FeatureStoreClient,
         # we respect this by setting the registry URI prior to reading the model from Model
         # Registry.
-
-
-
-        # mlflow.set_registry_uri(self._model_registry_uri)
+        if model_registry_uri is not None:
+            # This command will override any previously set registry_uri.
+            mlflow.set_registry_uri(model_registry_uri)

         mlflow.register_model(
             "runs:/%s/%s" % (run_id, artifact_path),
             registered_model_name,
             await_registration_for=await_registration_for,
-
+        )
```
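Two changes stand out in this file: `feature_spec.save(data_path)` is re-enabled, and the registry URI is now applied conditionally before `mlflow.register_model`. Since `mlflow.set_registry_uri` is a process-global setting, guarding it behind `is not None` avoids clobbering a registry configured elsewhere when the caller omits the parameter. A minimal sketch of the resulting registration flow (the function name and structure are illustrative, not the library's code):

```python
from typing import Optional
import mlflow

def register_logged_model(run_id: str, artifact_path: str, name: str,
                          model_registry_uri: Optional[str] = None) -> None:
    if model_registry_uri is not None:
        # Process-global override, applied only when explicitly requested.
        mlflow.set_registry_uri(model_registry_uri)
    mlflow.register_model(f"runs:/{run_id}/{artifact_path}", name)
```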