staran 0.6.1__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,603 +0,0 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- """
5
- 特征生成器模块
6
- 基于表结构生成各种数据分析特征
7
- """
8
-
9
- from typing import List, Dict, Set, Optional
10
- from datetime import datetime
11
- from .schema import TableSchema, Field, FieldType
12
- from .manager import FeatureManager, FeatureTableManager
13
-
14
-
15
class FeatureType:
    """Names of the feature families a FeatureGenerator can produce."""

    RAW_COPY = "raw_copy"  # copy of non-aggregatable fields
    AGGREGATION = "aggregation"  # aggregate statistics
    MOM = "mom"  # Month over Month
    YOY = "yoy"  # Year over Year

    @classmethod
    def get_all(cls) -> List[str]:
        """Return every feature-type constant in declaration order."""
        attribute_names = ("RAW_COPY", "AGGREGATION", "MOM", "YOY")
        return [getattr(cls, attr) for attr in attribute_names]
26
-
27
-
28
class AggregationType:
    """Names of the SQL aggregate functions the generator can apply."""

    SUM = "sum"
    AVG = "avg"
    MIN = "min"
    MAX = "max"
    COUNT = "count"
    VARIANCE = "variance"
    STDDEV = "stddev"

    @classmethod
    def get_all(cls) -> List[str]:
        """Return every aggregation type, COUNT included."""
        leading = [cls.SUM, cls.AVG, cls.MIN, cls.MAX]
        trailing = [cls.COUNT, cls.VARIANCE, cls.STDDEV]
        return leading + trailing

    @classmethod
    def get_numeric_only(cls) -> List[str]:
        """Return the aggregation types that only apply to numeric fields.

        This is every type except COUNT.
        """
        numeric = (cls.SUM, cls.AVG, cls.MIN, cls.MAX, cls.VARIANCE, cls.STDDEV)
        return list(numeric)
47
-
48
-
49
class FeatureConfig:
    """Fluent configuration of which features a FeatureGenerator emits.

    Every toggle/setter returns ``self`` so calls can be chained, e.g.
    ``FeatureConfig().enable_feature(FeatureType.MOM).set_mom_periods([1, 3])``.
    """

    def __init__(self):
        # Only the basic feature families are generated by default.
        self.enabled_features = {
            FeatureType.RAW_COPY: True,
            FeatureType.AGGREGATION: True,
            FeatureType.MOM: False,  # off by default
            FeatureType.YOY: False  # off by default
        }

        # Aggregation types used by default (a commonly used subset).
        self.aggregation_types = [AggregationType.SUM, AggregationType.AVG, AggregationType.COUNT]

        # Look-back periods; by default a single period each.
        self.mom_periods = [1]  # MoM periods, in months
        self.yoy_periods = [1]  # YoY periods, in years (each year = 12 months)

    def enable_feature(self, feature_type: str) -> 'FeatureConfig':
        """Enable ``feature_type``; types not in ``enabled_features`` are ignored."""
        if feature_type in self.enabled_features:
            self.enabled_features[feature_type] = True
        return self

    def disable_feature(self, feature_type: str) -> 'FeatureConfig':
        """Disable ``feature_type``; types not in ``enabled_features`` are ignored."""
        if feature_type in self.enabled_features:
            self.enabled_features[feature_type] = False
        return self

    def set_aggregation_types(self, types: List[str]) -> 'FeatureConfig':
        """Set the aggregation types to apply.

        The values are copied into a new list, so later mutation of the
        caller's sequence does not silently change this config.
        """
        self.aggregation_types = list(types)
        return self

    def set_mom_periods(self, periods: List[int]) -> 'FeatureConfig':
        """Set the month-over-month periods (in months); the values are copied."""
        self.mom_periods = list(periods)
        return self

    def set_yoy_periods(self, periods: List[int]) -> 'FeatureConfig':
        """Set the year-over-year periods (in years); the values are copied."""
        self.yoy_periods = list(periods)
        return self

    def is_feature_enabled(self, feature_type: str) -> bool:
        """Return whether ``feature_type`` is enabled (unknown types -> False)."""
        return self.enabled_features.get(feature_type, False)
98
-
99
-
100
class FeatureGenerator:
    """Generate Spark SQL feature-engineering queries from a TableSchema.

    Rows are grouped per primary key and calendar month. Depending on the
    FeatureConfig, the generated SQL contains:

    * raw_copy    -- ``first()`` of every non-aggregatable field;
    * aggregation -- configured aggregates of every aggregatable field;
    * mom / yoy   -- differences between an aggregate and its ``lag``-ed
      value (these require ``schema.is_monthly_unique``).
    """

    def __init__(self,
                 schema: TableSchema,
                 feature_manager: Optional[FeatureManager] = None,
                 config: Optional[FeatureConfig] = None):
        """
        Initialise the feature generator.

        Args:
            schema: Table structure definition; ``schema.validate()`` is
                called immediately.
            feature_manager: Optional SQL manager. When given, it qualifies
                the source table name and enables feature-table creation.
            config: Feature generation config; a default ``FeatureConfig``
                is used when omitted.
        """
        self.schema = schema
        self.feature_manager = feature_manager
        self.config = config or FeatureConfig()

        # Validate the schema before any SQL is built from it.
        self.schema.validate()

        # The feature-table manager only exists when a SQL manager was given.
        self.table_manager = None
        if self.feature_manager:
            self.table_manager = FeatureTableManager(self.feature_manager)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def generate_feature_by_type(self,
                                 feature_type: str,
                                 year: int,
                                 month: int,
                                 feature_num: int = 1) -> Dict[str, str]:
        """
        Generate the SQL for a single feature type.

        Args:
            feature_type: Feature type (raw_copy, aggregation, mom, yoy).
            year: Year; recorded in the result and used for the table name.
            month: Month; recorded in the result and used for the table name.
            feature_num: Feature number; used for the table name.

        Returns:
            Dict with ``feature_type``, ``sql``, ``year``, ``month`` and
            ``feature_num``, plus ``table_name`` when a SQL manager is set.

        Raises:
            ValueError: If the type is disabled or unsupported, or if mom/yoy
                is requested for a schema that is not monthly-unique.
        """
        if not self.config.is_feature_enabled(feature_type):
            raise ValueError(f"特征类型 {feature_type} 未启用")

        builders = {
            FeatureType.RAW_COPY: self._generate_raw_copy_sql,
            FeatureType.AGGREGATION: self._generate_aggregation_sql,
            FeatureType.MOM: self._generate_mom_sql,
            FeatureType.YOY: self._generate_yoy_sql,
        }
        builder = builders.get(feature_type)
        if builder is None:
            raise ValueError(f"不支持的特征类型: {feature_type}")

        result = {
            'feature_type': feature_type,
            'sql': builder(),
            'year': year,
            'month': month,
            'feature_num': feature_num
        }

        # Only a SQL manager knows how to name feature tables.
        if self.feature_manager:
            result['table_name'] = self.feature_manager.generate_feature_table_name(
                self.schema.table_name, year, month, feature_num
            )

        return result

    def create_feature_table(self,
                             feature_type: str,
                             year: int,
                             month: int,
                             feature_num: int = 1,
                             execute: bool = False) -> str:
        """
        Create a feature table through the FeatureTableManager.

        Args:
            feature_type: Feature type to generate.
            year: Year.
            month: Month.
            feature_num: Feature number.
            execute: Whether to execute immediately (passed through).

        Returns:
            The feature table name, as returned by
            ``FeatureTableManager.create_feature_table``.

        Raises:
            ValueError: If no SQL manager was supplied, or as raised by
                ``generate_feature_by_type``.
        """
        if not self.table_manager:
            raise ValueError("需要SQL管理器才能创建特征表")

        feature_info = self.generate_feature_by_type(feature_type, year, month, feature_num)

        return self.table_manager.create_feature_table(
            base_table=self.schema.table_name,
            year=year,
            month=month,
            feature_num=feature_num,
            sql=feature_info['sql'],
            execute=execute
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _base_table_name(self) -> str:
        """Return the source table name, qualified by the SQL manager if present."""
        if self.feature_manager:
            return self.feature_manager.get_full_table_name(self.schema.table_name)
        return self.schema.table_name

    def _applicable_aggregations(self) -> List[tuple]:
        """Return ``(field, agg_type)`` pairs for every applicable aggregation.

        Order is field-major, then config order of aggregation types.
        """
        return [
            (field, agg_type)
            for field in self.schema.get_aggregatable_fields()
            for agg_type in self.config.aggregation_types
            if self._is_agg_applicable(field, agg_type)
        ]

    def _aggregate_column_names(self) -> List[str]:
        """Return the ``<field>_<agg_type>`` names of all aggregate columns."""
        return [f"{field.name}_{agg_type}" for field, agg_type in self._applicable_aggregations()]

    @staticmethod
    def _join_columns(select_parts: List[str], separator: str = ",\n") -> str:
        """Join SELECT items with a real newline separator.

        Joined outside the f-strings on purpose: a backslash inside an
        f-string expression is a SyntaxError before Python 3.12.
        """
        return separator.join(select_parts)

    def _monthly_key_parts(self) -> List[str]:
        """Return SELECT items identifying one row per primary key and month."""
        pk_field = self.schema.primary_key
        date_field = self.schema.date_field
        return [
            f" {pk_field}",
            f" year({date_field}) as year",
            f" month({date_field}) as month",
            f" date_format({date_field}, 'yyyy-MM') as year_month"
        ]

    def _monthly_select(self, select_parts: List[str]) -> str:
        """Wrap ``select_parts`` in a query grouped per primary key and month.

        The ``date_format`` expression behind ``year_month`` is grouped too:
        Spark rejects a selected expression that is neither aggregated nor
        in the GROUP BY. It is fully determined by year and month, so the
        grouping itself is unchanged.
        """
        pk_field = self.schema.primary_key
        date_field = self.schema.date_field
        columns = self._join_columns(select_parts)
        return f"""SELECT
{columns}
FROM {self._base_table_name()}
GROUP BY {pk_field}, year({date_field}), month({date_field}), date_format({date_field}, 'yyyy-MM')
ORDER BY {pk_field}, year, month"""

    # ------------------------------------------------------------------
    # Per-type SQL generators
    # ------------------------------------------------------------------

    def _generate_raw_copy_sql(self) -> str:
        """Build SQL taking ``first()`` of each non-aggregatable field per key/month."""
        select_parts = self._monthly_key_parts()
        select_parts.extend(
            f" first({field.name}) as {field.name}"
            for field in self.schema.get_non_aggregatable_fields()
        )
        return self._monthly_select(select_parts)

    def _generate_aggregation_sql(self) -> str:
        """Build SQL computing every applicable aggregate per key/month."""
        select_parts = self._monthly_key_parts()
        select_parts.extend(
            f" {self._get_agg_expression(field.name, agg_type)} as {field.name}_{agg_type}"
            for field, agg_type in self._applicable_aggregations()
        )
        return self._monthly_select(select_parts)

    def _generate_lag_sql(self, periods: List[int], months_per_period: int,
                          suffix: str, unit: str) -> str:
        """Build ``aggregate - lag(aggregate, offset)`` columns over the aggregation SQL.

        Args:
            periods: Configured periods; each yields one column per aggregate.
            months_per_period: Multiplier turning a period into a lag offset.
            suffix: Column-name tag ("mom" or "yoy").
            unit: Column-name unit letter ("m" or "y").
        """
        pk_field = self.schema.primary_key
        base_sql = self._generate_aggregation_sql()

        select_parts = [f" {pk_field}", " year_month"]
        for column in self._aggregate_column_names():
            for period in periods:
                offset = period * months_per_period
                # NOTE(review): lag() offsets by rows, not calendar months; if
                # a key has missing months the comparison lands on an older
                # month. Confirm inputs have no gaps.
                expr = (f"{column} - lag({column}, {offset}) "
                        f"OVER (PARTITION BY {pk_field} ORDER BY year_month)")
                select_parts.append(f" {expr} as {column}_{suffix}_{period}{unit}")

        columns = self._join_columns(select_parts)
        return f"""WITH base_agg AS (
{base_sql}
)
SELECT
{columns}
FROM base_agg
ORDER BY {pk_field}, year_month"""

    def _generate_mom_sql(self) -> str:
        """Build month-over-month SQL (lag offset = period).

        Raises:
            ValueError: If the schema is not monthly-unique.
        """
        if not self.schema.is_monthly_unique:
            raise ValueError("环比特征需要每人每月唯一数据")
        return self._generate_lag_sql(self.config.mom_periods, 1, "mom", "m")

    def _generate_yoy_sql(self) -> str:
        """Build year-over-year SQL (lag offset = period * 12).

        Raises:
            ValueError: If the schema is not monthly-unique.
        """
        if not self.schema.is_monthly_unique:
            raise ValueError("同比特征需要每人每月唯一数据")
        return self._generate_lag_sql(self.config.yoy_periods, 12, "yoy", "y")

    # ------------------------------------------------------------------
    # Feature listing
    # ------------------------------------------------------------------

    def generate_feature_list(self) -> Dict[str, List[str]]:
        """
        List the feature column names that would be generated.

        Returns:
            Dict keyed by raw_copy / aggregation / mom / yoy. Disabled types
            (and mom/yoy on non-monthly-unique schemas) map to empty lists.
        """
        features = {
            FeatureType.RAW_COPY: [],
            FeatureType.AGGREGATION: [],
            FeatureType.MOM: [],
            FeatureType.YOY: []
        }

        # 1. Raw field copies
        if self.config.is_feature_enabled(FeatureType.RAW_COPY):
            features[FeatureType.RAW_COPY] = [
                field.name for field in self.schema.get_non_aggregatable_fields()
            ]

        # 2. Aggregate statistics
        if self.config.is_feature_enabled(FeatureType.AGGREGATION):
            features[FeatureType.AGGREGATION] = self._aggregate_column_names()

        # 3. Month-over-month
        if self.config.is_feature_enabled(FeatureType.MOM) and self.schema.is_monthly_unique:
            features[FeatureType.MOM] = [
                f"{column}_mom_{period}m"
                for column in self._aggregate_column_names()
                for period in self.config.mom_periods
            ]

        # 4. Year-over-year
        if self.config.is_feature_enabled(FeatureType.YOY) and self.schema.is_monthly_unique:
            features[FeatureType.YOY] = [
                f"{column}_yoy_{period}y"
                for column in self._aggregate_column_names()
                for period in self.config.yoy_periods
            ]

        return features

    def _is_agg_applicable(self, field: Field, agg_type: str) -> bool:
        """Return whether ``agg_type`` can be applied to ``field``.

        COUNT applies to every field; numeric-only aggregations require a
        numeric field type; any other type is allowed.
        """
        if agg_type == AggregationType.COUNT:
            return True

        if agg_type in AggregationType.get_numeric_only():
            return field.field_type in (
                FieldType.INTEGER, FieldType.BIGINT, FieldType.DECIMAL,
                FieldType.DOUBLE, FieldType.FLOAT
            )

        return True

    def _get_agg_expression(self, field_name: str, agg_type: str) -> str:
        """Return the SQL aggregate call for ``field_name``.

        Every supported type (sum, avg, min, max, count, variance, stddev)
        is a Spark function of the same name, and unknown types were
        already passed through the same way, so the call is simply
        ``agg_type(field_name)``.
        """
        return f"{agg_type}({field_name})"

    # ------------------------------------------------------------------
    # Complete feature SQL
    # ------------------------------------------------------------------

    def generate_spark_sql(self) -> str:
        """Generate Spark SQL (legacy-compatible entry point).

        Non-monthly-unique schemas get the plain aggregation SQL; otherwise
        a single query with every enabled feature family is produced.
        """
        if not self.schema.is_monthly_unique:
            return self._generate_aggregation_sql()
        return self._generate_complete_feature_sql()

    def _generate_complete_feature_sql(self) -> str:
        """Build one ``WITH ... SELECT`` query containing all enabled features.

        ``agg_features`` is always emitted because the mom/yoy CTEs and the
        final SELECT all read from it.
        """
        ctes = [self._build_base_data_cte(), self._build_aggregation_cte()]

        if self.config.is_feature_enabled(FeatureType.MOM):
            ctes.append(self._build_mom_cte())

        if self.config.is_feature_enabled(FeatureType.YOY):
            ctes.append(self._build_yoy_cte())

        header = f"""-- Spark SQL: 特征工程
-- 表: {self.schema.table_name}
-- 生成时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
"""

        # CTEs need a leading WITH and are comma-separated; the final SELECT
        # follows the last CTE without a comma.
        return header + "WITH " + ",\n\n".join(ctes) + "\n\n" + self._build_final_select()

    def _build_base_data_cte(self) -> str:
        """Build the ``base_data`` CTE: source columns plus year/month keys."""
        pk_field = self.schema.primary_key
        date_field = self.schema.date_field

        select_parts = [
            f" {pk_field}",
            f" {date_field}",
            f" year({date_field}) as year",
            f" month({date_field}) as month",
            f" date_format({date_field}, 'yyyy-MM') as year_month"
        ]
        # Every other schema field is passed through unchanged.
        select_parts.extend(
            f" {field.name}"
            for field in self.schema.fields.values()
            if not field.is_primary_key and not field.is_date_field
        )

        columns = self._join_columns(select_parts)
        return f"""base_data AS (
SELECT
{columns}
FROM {self._base_table_name()}
)"""

    def _build_aggregation_cte(self) -> str:
        """Build the ``agg_features`` CTE (one row per key and year_month).

        Raw-copy columns are included when RAW_COPY is enabled. Aggregate
        columns are included when AGGREGATION, MOM or YOY is enabled,
        because the mom/yoy CTEs read them from here.
        """
        pk_field = self.schema.primary_key
        select_parts = [f" {pk_field}", " year_month"]

        if self.config.is_feature_enabled(FeatureType.RAW_COPY):
            select_parts.extend(
                f" first({field.name}) as {field.name}"
                for field in self.schema.get_non_aggregatable_fields()
            )

        needs_aggregates = any(
            self.config.is_feature_enabled(feature_type)
            for feature_type in (FeatureType.AGGREGATION, FeatureType.MOM, FeatureType.YOY)
        )
        if needs_aggregates:
            select_parts.extend(
                f" {self._get_agg_expression(field.name, agg_type)} as {field.name}_{agg_type}"
                for field, agg_type in self._applicable_aggregations()
            )

        columns = self._join_columns(select_parts)
        return f"""agg_features AS (
SELECT
{columns}
FROM base_data
GROUP BY {pk_field}, year_month
)"""

    def _build_lag_cte(self, cte_name: str, periods: List[int],
                       months_per_period: int, suffix: str, unit: str) -> str:
        """Build a CTE of ``aggregate - lag(aggregate, offset)`` over agg_features.

        Args:
            cte_name: Name of the CTE to define.
            periods: Configured periods; each yields one column per aggregate.
            months_per_period: Multiplier turning a period into a lag offset.
            suffix: Column-name tag ("mom" or "yoy").
            unit: Column-name unit letter ("m" or "y").
        """
        pk_field = self.schema.primary_key
        select_parts = [f" a.{pk_field}", " a.year_month"]

        for column in self._aggregate_column_names():
            for period in periods:
                offset = period * months_per_period
                # NOTE(review): lag() offsets by rows, not calendar months;
                # gaps in a key's months shift the comparison.
                expr = (f"a.{column} - lag(a.{column}, {offset}) "
                        f"OVER (PARTITION BY a.{pk_field} ORDER BY a.year_month)")
                select_parts.append(f" {expr} as {column}_{suffix}_{period}{unit}")

        columns = self._join_columns(select_parts)
        return f"""{cte_name} AS (
SELECT
{columns}
FROM agg_features a
)"""

    def _build_mom_cte(self) -> str:
        """Build the ``mom_features`` CTE (lag offset = period)."""
        return self._build_lag_cte("mom_features", self.config.mom_periods, 1, "mom", "m")

    def _build_yoy_cte(self) -> str:
        """Build the ``yoy_features`` CTE (lag offset = period * 12)."""
        return self._build_lag_cte("yoy_features", self.config.yoy_periods, 12, "yoy", "y")

    def _build_final_select(self) -> str:
        """Build the final SELECT joining agg_features with the mom/yoy CTEs."""
        pk_field = self.schema.primary_key

        joins = []
        select_fields = ["a.*"]

        if self.config.is_feature_enabled(FeatureType.MOM):
            joins.append(f"LEFT JOIN mom_features m ON a.{pk_field} = m.{pk_field} AND a.year_month = m.year_month")
            select_fields.extend(
                f"m.{column}_mom_{period}m"
                for column in self._aggregate_column_names()
                for period in self.config.mom_periods
            )

        if self.config.is_feature_enabled(FeatureType.YOY):
            joins.append(f"LEFT JOIN yoy_features y ON a.{pk_field} = y.{pk_field} AND a.year_month = y.year_month")
            select_fields.extend(
                f"y.{column}_yoy_{period}y"
                for column in self._aggregate_column_names()
                for period in self.config.yoy_periods
            )

        join_clause = "\n".join(joins)
        columns = self._join_columns(select_fields, ",\n ")

        return f"""SELECT
 {columns}
FROM agg_features a
{join_clause}
ORDER BY a.{pk_field}, a.year_month"""

    # ------------------------------------------------------------------
    # Summaries
    # ------------------------------------------------------------------

    def get_feature_summary(self) -> Dict[str, int]:
        """Return feature counts per type plus a ``total``."""
        features = self.generate_feature_list()
        return {
            'total': sum(len(feature_list) for feature_list in features.values()),
            'raw_copy': len(features['raw_copy']),
            'aggregation': len(features['aggregation']),
            'mom': len(features['mom']),
            'yoy': len(features['yoy'])
        }

    def print_feature_summary(self):
        """Print feature counts and up to five feature names per type."""
        features = self.generate_feature_list()
        summary = self.get_feature_summary()

        print(f"特征生成摘要 - 表: {self.schema.table_name}")
        print("=" * 50)
        print(f"总特征数: {summary['total']}")
        print(f"原始拷贝: {summary['raw_copy']} (启用: {self.config.is_feature_enabled(FeatureType.RAW_COPY)})")
        print(f"聚合统计: {summary['aggregation']} (启用: {self.config.is_feature_enabled(FeatureType.AGGREGATION)})")
        print(f"环比特征: {summary['mom']} (启用: {self.config.is_feature_enabled(FeatureType.MOM)})")
        print(f"同比特征: {summary['yoy']} (启用: {self.config.is_feature_enabled(FeatureType.YOY)})")
        print()

        for category, feature_list in features.items():
            if not feature_list:
                continue
            print(f"{category.upper()} ({len(feature_list)}):")
            for feature in feature_list[:5]:  # show at most 5 per category
                print(f" - {feature}")
            if len(feature_list) > 5:
                print(f" ... 还有 {len(feature_list) - 5} 个特征")
            print()