wedata-feature-engineering 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. feature_store/constants/__init__.py +0 -0
  2. feature_store/constants/constants.py +28 -0
  3. feature_store/entities/__init__.py +0 -0
  4. feature_store/entities/column_info.py +117 -0
  5. feature_store/entities/data_type.py +92 -0
  6. feature_store/entities/environment_variables.py +55 -0
  7. feature_store/entities/feature.py +53 -0
  8. feature_store/entities/feature_column_info.py +64 -0
  9. feature_store/entities/feature_function.py +55 -0
  10. feature_store/entities/feature_lookup.py +179 -0
  11. feature_store/entities/feature_spec.py +454 -0
  12. feature_store/entities/feature_spec_constants.py +25 -0
  13. feature_store/entities/feature_table.py +164 -0
  14. feature_store/entities/feature_table_info.py +40 -0
  15. feature_store/entities/function_info.py +184 -0
  16. feature_store/entities/on_demand_column_info.py +44 -0
  17. feature_store/entities/source_data_column_info.py +21 -0
  18. feature_store/entities/training_set.py +134 -0
  19. feature_store/feature_table_client/__init__.py +0 -0
  20. feature_store/feature_table_client/feature_table_client.py +313 -0
  21. feature_store/spark_client/__init__.py +0 -0
  22. feature_store/spark_client/spark_client.py +286 -0
  23. feature_store/training_set_client/__init__.py +0 -0
  24. feature_store/training_set_client/training_set_client.py +196 -0
  25. {wedata_feature_engineering-0.1.0.dist-info → wedata_feature_engineering-0.1.2.dist-info}/METADATA +1 -1
  26. wedata_feature_engineering-0.1.2.dist-info/RECORD +30 -0
  27. wedata_feature_engineering-0.1.0.dist-info/RECORD +0 -6
  28. {wedata_feature_engineering-0.1.0.dist-info → wedata_feature_engineering-0.1.2.dist-info}/WHEEL +0 -0
  29. {wedata_feature_engineering-0.1.0.dist-info → wedata_feature_engineering-0.1.2.dist-info}/top_level.txt +0 -0
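The added modules can be checked against the published wheel itself; a minimal sketch using only the Python standard library (the wheel filename follows the dist-info entries above):

    import zipfile

    # List the files packaged in the 0.1.2 wheel to confirm the new feature_store modules.
    with zipfile.ZipFile("wedata_feature_engineering-0.1.2-py3-none-any.whl") as whl:
        for entry in whl.namelist():
            print(entry)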
feature_store/entities/function_info.py
@@ -0,0 +1,184 @@
+ from collections import defaultdict
+ from typing import List, Optional
+
+ from pyspark.sql import Column, DataFrame
+ from pyspark.sql.functions import isnull, when
+ from pyspark.sql.types import StringType, StructField, StructType
+
+
+ class FunctionParameterInfo:
+     def __init__(self, name: str, type_text: str):
+         self._name = name
+         self._type_text = type_text
+
+     @property
+     def name(self) -> str:
+         return self._name
+
+     @property
+     def type_text(self) -> str:
+         return self._type_text
+
+     @classmethod
+     def from_dict(cls, function_parameter_info_json):
+         return FunctionParameterInfo(
+             function_parameter_info_json["name"],
+             function_parameter_info_json["type_text"],
+         )
+
+
+ class FunctionInfo:
+     """
+     Helper entity class that exposes properties in GetFunction's response JSON as attributes.
+     https://docs.databricks.com/api-explorer/workspace/functions/get
+
+     Note: empty fields (e.g. when there are 0 input parameters) are not included in the response JSON.
+     """
+
+     # Python UDFs have external_language = "Python"
+     PYTHON = "Python"
+
+     def __init__(
+         self,
+         full_name: str,
+         input_params: List[FunctionParameterInfo],
+         routine_definition: Optional[str],
+         external_language: Optional[str],
+     ):
+         self._full_name = full_name
+         self._input_params = input_params
+         self._routine_definition = routine_definition
+         self._external_language = external_language
+
+     @property
+     def full_name(self) -> str:
+         return self._full_name
+
+     @property
+     def input_params(self) -> List[FunctionParameterInfo]:
+         return self._input_params
+
+     @property
+     def routine_definition(self) -> Optional[str]:
+         return self._routine_definition
+
+     @property
+     def external_language(self) -> Optional[str]:
+         """
+         Field is None if the language is SQL (i.e. not an external language).
+         """
+         return self._external_language
+
+     @classmethod
+     def from_dict(cls, function_info_json):
+         input_params = function_info_json.get("input_params", {}).get("parameters", [])
+         return FunctionInfo(
+             full_name=function_info_json["full_name"],
+             input_params=[FunctionParameterInfo.from_dict(p) for p in input_params],
+             routine_definition=function_info_json.get("routine_definition", None),
+             external_language=function_info_json.get("external_language", None),
+         )
+
+
+ class InformationSchemaSparkClient:
+     """
+     Internal client to retrieve Unity Catalog metadata from system.information_schema.
+     https://docs.databricks.com/sql/language-manual/sql-ref-information-schema.html
+     """
+
+     def _get_routines_with_parameters(self, full_routine_names: List[str]) -> DataFrame:
+         """
+         Retrieve routines with their parameters from information_schema.routines and information_schema.parameters.
+         The returned DataFrame only contains routines that 1. exist and 2. the caller has GetFunction permission on.
+
+         Note: The returned DataFrame contains the cartesian product of routines and parameters.
+         For efficiency, routines table columns are only present in the first row for each routine.
+         """
+         routine_name_schema = StructType(
+             [
+                 StructField("specific_catalog", StringType(), False),
+                 StructField("specific_schema", StringType(), False),
+                 StructField("specific_name", StringType(), False),
+             ]
+         )
+         routine_names_df = self._spark_client.createDataFrame(
+             [full_routine_name.split(".") for full_routine_name in full_routine_names],
+             routine_name_schema,
+         )
+         routines_table = self._spark_client.read_table(
+             "system.information_schema.routines"
+         )
+         parameters_table = self._spark_client.read_table(
+             "system.information_schema.parameters"
+         )
+
+         # Inner join the routines table to filter out non-existent routines.
+         # Left join parameters, as routines may have no parameters.
+         full_routines_with_parameters_df = routine_names_df.join(
+             routines_table, on=routine_names_df.columns, how="inner"
+         ).join(parameters_table, on=routine_names_df.columns, how="left")
+
+         # Return only the relevant metadata from information_schema, sorted by routine name + parameter order.
+         # For efficiency, only preserve routine column values in the first of each routine's result rows.
+         # In the first row, parameters_table.ordinal_position is either None (no parameters) or 0 (first parameter).
+         def select_if_first_row(col: Column) -> Column:
+             return when(
+                 isnull(parameters_table.ordinal_position)
+                 | (parameters_table.ordinal_position == 0),
+                 col,
+             ).otherwise(None)
+
+         return full_routines_with_parameters_df.select(
+             routine_names_df.columns
+             + [
+                 select_if_first_row(routines_table.routine_definition).alias(
+                     "routine_definition"
+                 ),
+                 select_if_first_row(routines_table.external_language).alias(
+                     "external_language"
+                 ),
+                 parameters_table.ordinal_position,
+                 parameters_table.parameter_name,
+                 parameters_table.full_data_type,
+             ]
+         ).sort(routine_names_df.columns + [parameters_table.ordinal_position])
+
+     def get_functions(self, full_function_names: List[str]) -> List[FunctionInfo]:
+         """
+         Retrieves and maps Unity Catalog functions' metadata as FunctionInfos.
+         """
+         # Avoid unnecessary Spark calls and return early if empty.
+         if not full_function_names:
+             return []
+
+         # Collect a dict of routine name -> DataFrame rows describing the routine.
+         routines_with_parameters_df = self._get_routines_with_parameters(
+             full_routine_names=full_function_names
+         )
+         routine_infos = defaultdict(list)
+         for r in routines_with_parameters_df.collect():
+             routine_name = f"{r.specific_catalog}.{r.specific_schema}.{r.specific_name}"
+             routine_infos[routine_name].append(r)
+
+         # Mock the GetFunction does-not-exist error, since information_schema does not throw.
+         for function_name in full_function_names:
+             if function_name not in routine_infos:
+                 raise ValueError(f"Function '{function_name}' does not exist.")
+
+         # Map routine_infos into FunctionInfos.
+         function_infos = []
+         for function_name in full_function_names:
+             routine_info = routine_infos[function_name][0]
+             input_params = [
+                 FunctionParameterInfo(name=p.parameter_name, type_text=p.full_data_type)
+                 for p in routine_infos[function_name]
+                 if p.ordinal_position is not None
+             ]
+             function_infos.append(
+                 FunctionInfo(
+                     full_name=function_name,
+                     input_params=input_params,
+                     routine_definition=routine_info.routine_definition,
+                     external_language=routine_info.external_language,
+                 )
+             )
+         return function_infos
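For reference, a minimal sketch of the `from_dict` mapping defined above; the response dict is hand-written and its values are illustrative only:

    # Hypothetical GetFunction-style response, shaped the way from_dict expects.
    function_json = {
        "full_name": "main.default.add_one",
        "input_params": {"parameters": [{"name": "x", "type_text": "int"}]},
        "routine_definition": "return x + 1",
        "external_language": "Python",
    }

    info = FunctionInfo.from_dict(function_json)
    print(info.full_name)                       # main.default.add_one
    print([p.name for p in info.input_params])  # ['x']
    print(info.external_language == FunctionInfo.PYTHON)  # True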
feature_store/entities/on_demand_column_info.py
@@ -0,0 +1,44 @@
+ from typing import Dict
+
+ class OnDemandColumnInfo:
+     def __init__(
+         self,
+         udf_name: str,
+         input_bindings: Dict[str, str],
+         output_name: str,
+     ):
+         if not udf_name:
+             raise ValueError("udf_name must be non-empty.")
+         if not output_name:
+             raise ValueError("output_name must be non-empty.")
+
+         self._udf_name = udf_name
+         self._input_bindings = input_bindings
+         self._output_name = output_name
+
+     @property
+     def udf_name(self) -> str:
+         return self._udf_name
+
+     @property
+     def input_bindings(self) -> Dict[str, str]:
+         """
+         input_bindings is serialized as the InputBindings proto message.
+         """
+         return self._input_bindings
+
+     @property
+     def output_name(self) -> str:
+         return self._output_name
+
+     @classmethod
+     def from_proto(cls, on_demand_column_info_proto):
+         input_bindings_dict = {
+             input_binding.parameter: input_binding.bound_to
+             for input_binding in on_demand_column_info_proto.input_bindings
+         }
+         return OnDemandColumnInfo(
+             udf_name=on_demand_column_info_proto.udf_name,
+             input_bindings=input_bindings_dict,
+             output_name=on_demand_column_info_proto.output_name,
+         )
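A small sketch of constructing `OnDemandColumnInfo` directly; the UDF and column names are hypothetical:

    # Bind the UDF parameters lat/lon to input columns and expose the result
    # under the output column name distance_from_home.
    odci = OnDemandColumnInfo(
        udf_name="main.default.haversine_distance",
        input_bindings={"lat": "pickup_lat", "lon": "pickup_lon"},
        output_name="distance_from_home",
    )
    print(odci.udf_name, odci.input_bindings, odci.output_name)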
feature_store/entities/source_data_column_info.py
@@ -0,0 +1,21 @@
+
+ class SourceDataColumnInfo:
+     def __init__(self, name: str):
+         if not name:
+             raise ValueError("name must be non-empty.")
+         self._name = name
+
+     @property
+     def name(self):
+         return self._name
+
+     @property
+     def output_name(self) -> str:
+         """
+         This field does not exist in the proto, and is provided for convenience.
+         """
+         return self._name
+
+     @classmethod
+     def from_proto(cls, source_data_column_info_proto):
+         return cls(name=source_data_column_info_proto.name)
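Correspondingly, a tiny sketch of `SourceDataColumnInfo`, showing that `output_name` is just a convenience alias for `name` and that empty names are rejected:

    col = SourceDataColumnInfo("customer_id")
    print(col.output_name == col.name)  # True

    try:
        SourceDataColumnInfo("")
    except ValueError as err:
        print(err)  # name must be non-empty.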
feature_store/entities/training_set.py
@@ -0,0 +1,134 @@
+ from typing import Dict, List, Optional
+
+ from pyspark.sql import DataFrame
+
+ from feature_store.entities.feature_table import FeatureTable
+ from feature_store.entities.function_info import FunctionInfo
+ from feature_store.utils.feature_lookup_utils import (
+     join_feature_data_if_not_overridden,
+ )
+
+ from feature_store.entities.feature_spec import FeatureSpec
+ from feature_store.utils.feature_spec_utils import (
+     COLUMN_INFO_TYPE_FEATURE,
+     COLUMN_INFO_TYPE_ON_DEMAND,
+     COLUMN_INFO_TYPE_SOURCE,
+     get_feature_execution_groups,
+ )
+
+
+ class TrainingSet:
+     """
+     .. note::
+
+         Aliases: `!databricks.feature_engineering.training_set.TrainingSet`, `!databricks.feature_store.training_set.TrainingSet`
+
+     Class that defines :obj:`TrainingSet` objects.
+
+     .. note::
+
+         The :class:`TrainingSet` constructor should not be called directly. Instead,
+         call :meth:`create_training_set() <databricks.feature_engineering.client.FeatureEngineeringClient.create_training_set>`.
+     """
+
+     def __init__(
+         self,
+         feature_spec: FeatureSpec,
+         df: DataFrame,
+         labels: List[str],
+         feature_table_metadata_map: Dict[str, FeatureTable],
+         feature_table_data_map: Dict[str, DataFrame],
+         uc_function_infos: Dict[str, FunctionInfo],
+         use_spark_native_join: Optional[bool] = False,
+     ):
+         """Initialize a :obj:`TrainingSet` object."""
+         assert isinstance(
+             labels, list
+         ), f"Expected type `list` for argument `labels`. Got '{labels}' with type '{type(labels)}'."
+
+         self._feature_spec = feature_spec
+         self._df = df
+         self._labels = labels
+         self._feature_table_metadata_map = feature_table_metadata_map
+         self._feature_table_data_map = feature_table_data_map
+         self._uc_function_infos = uc_function_infos
+         self._use_spark_native_join = use_spark_native_join
+         # Perform basic validations and resolve FeatureSpec and label column data types.
+         self._validate_and_inject_dtypes()
+         self._label_data_types = {
+             name: data_type for name, data_type in df.dtypes if name in labels
+         }
+
+     @property
+     def feature_spec(self) -> FeatureSpec:
+         """Define a feature spec."""
+         return self._feature_spec
+
+     def _augment_df(self) -> DataFrame:
+         """
+         Internal helper to augment the DataFrame with the feature lookups and on-demand features specified in the FeatureSpec.
+         Does not drop excluded columns, and does not overwrite columns that already exist.
+         The returned column order is df.columns + feature lookups + on-demand features.
+         """
+         execution_groups = get_feature_execution_groups(
+             self.feature_spec, self._df.columns
+         )
+
+         result_df = self._df
+         # Iterate over all levels and types of DAG nodes in the FeatureSpec and execute them.
+         for execution_group in execution_groups:
+             if execution_group.type == COLUMN_INFO_TYPE_SOURCE:
+                 continue
+             if execution_group.type == COLUMN_INFO_TYPE_FEATURE:
+                 # Apply FeatureLookups
+                 result_df = join_feature_data_if_not_overridden(
+                     feature_spec=self.feature_spec,
+                     df=result_df,
+                     features_to_join=execution_group.features,
+                     feature_table_metadata_map=self._feature_table_metadata_map,
+                     feature_table_data_map=self._feature_table_data_map,
+                     use_spark_native_join=self._use_spark_native_join,
+                 )
+             # elif execution_group.type == COLUMN_INFO_TYPE_ON_DEMAND:
+             #     # Apply all on-demand UDFs
+             #     result_df = apply_functions_if_not_overridden(
+             #         df=result_df,
+             #         functions_to_apply=execution_group.features,
+             #         uc_function_infos=self._uc_function_infos,
+             #     )
+             else:
+                 # This should never be reached.
+                 raise Exception("Unknown feature execution type:", execution_group.type)
+         return result_df
+
+     def _validate_and_inject_dtypes(self):
+         """
+         Performs validations through _augment_df (e.g. the Delta table exists, Delta and feature table dtypes match),
+         then injects the result DataFrame dtypes into the FeatureSpec.
+         """
+         augmented_df = self._augment_df()
+         augmented_df_dtypes = {column: dtype for column, dtype in augmented_df.dtypes}
+
+         # Inject the result DataFrame column types into the respective ColumnInfo
+         for ci in self.feature_spec.column_infos:
+             ci._data_type = augmented_df_dtypes[ci.output_name]
+
+     def load_df(self) -> DataFrame:
+         """
+         Load a :class:`DataFrame <pyspark.sql.DataFrame>`.
+
+         Return a :class:`DataFrame <pyspark.sql.DataFrame>` for training.
+
+         The returned :class:`DataFrame <pyspark.sql.DataFrame>` has the columns specified
+         by the ``feature_spec`` and ``labels`` parameters provided
+         to :meth:`create_training_set() <databricks.feature_engineering.client.FeatureEngineeringClient.create_training_set>`.
+
+         :return:
+             A :class:`DataFrame <pyspark.sql.DataFrame>` for training
+         """
+         augmented_df = self._augment_df()
+         # Return only the included columns, in the order defined by the FeatureSpec + labels
+         included_columns = [
+             ci.output_name for ci in self.feature_spec.column_infos if ci.include
+         ] + self._labels
+         return augmented_df.select(included_columns)
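A hedged usage sketch of the intended flow; `fe_client`, `label_df`, and `feature_lookups` are hypothetical placeholders, and `create_training_set()` is assumed to mirror the client API referenced in the docstrings above rather than being defined in this file:

    # Assumed: fe_client exposes create_training_set() as in the referenced API.
    training_set = fe_client.create_training_set(
        df=label_df,                      # DataFrame with join keys and the label column
        feature_lookups=feature_lookups,  # FeatureLookup / FeatureFunction definitions
        label="label",
    )
    train_df = training_set.load_df()  # included FeatureSpec columns + labels, in FeatureSpec order
    train_df.show()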
feature_store/feature_table_client/__init__.py (file without changes)
feature_store/feature_table_client/feature_table_client.py
@@ -0,0 +1,313 @@
+ """
+ Utility methods for feature table operations.
+ """
+
+ from typing import Union, List, Dict, Optional, Sequence, Any
+ from pyspark.sql import DataFrame, SparkSession
+ from pyspark.sql.streaming import StreamingQuery
+ from pyspark.sql.types import StructType
+ import os
+
+ from feature_store.constants.constants import APPEND, DEFAULT_WRITE_STREAM_TRIGGER
+
+
+ class FeatureTableClient:
+     """Client for feature table operations."""
+
+     def __init__(
+         self,
+         spark: SparkSession
+     ):
+         self._spark = spark
+
+     @staticmethod
+     def _normalize_params(
+         param: Optional[Union[str, Sequence[str]]],
+         default_type: type = list
+     ) -> list:
+         """Normalize a parameter into a list (a bare string becomes a one-element list)."""
+         if param is None:
+             return default_type()
+         return [param] if isinstance(param, str) else list(param) if isinstance(param, Sequence) else [param]
+
+     @staticmethod
+     def _validate_schema(df: DataFrame, schema: StructType):
+         """Validate that the DataFrame and schema are valid and consistent."""
+         # Check that not both are missing
+         if df is None and schema is None:
+             raise ValueError("Either a DataFrame or a schema must be provided")
+
+         # Check that the schemas match
+         if df is not None and schema is not None:
+             df_schema = df.schema
+             if df_schema != schema:
+                 diff_fields = set(df_schema.fieldNames()).symmetric_difference(set(schema.fieldNames()))
+                 raise ValueError(
+                     f"DataFrame does not match schema. Differing fields: {diff_fields if diff_fields else 'field types differ'}"
+                 )
+
+     @staticmethod
+     def _validate_table_name(name: str):
+         """Validate the feature table naming convention."""
+         if name.count('.') < 2:
+             raise ValueError("Feature table names must follow the <catalog>.<schema>.<table> format")
+
+     @staticmethod
+     def _validate_key_conflicts(primary_keys: List[str], timestamp_keys: List[str]):
+         """Check whether primary keys and timestamp keys conflict."""
+         conflict_keys = set(timestamp_keys) & set(primary_keys)
+         if conflict_keys:
+             raise ValueError(f"Timestamp keys conflict with primary keys: {conflict_keys}")
+
+     @staticmethod
+     def _escape_sql_value(value: str) -> str:
+         """Escape special characters in SQL values."""
+         return value.replace("'", "''")
+
+     def create_table(
+         self,
+         name: str,
+         primary_keys: Union[str, List[str]],
+         df: Optional[DataFrame] = None,
+         *,
+         timestamp_keys: Union[str, List[str], None] = None,
+         partition_columns: Union[str, List[str], None] = None,
+         schema: Optional[StructType] = None,
+         description: Optional[str] = None,
+         tags: Optional[Dict[str, str]] = None
+     ):
+         """
+         Create a feature table (supports batch and streaming writes).
+
+         Args:
+             name: Feature table name (format: <table>)
+             primary_keys: Primary key column name(s) (composite keys supported)
+             df: Initial data (optional, used to infer the schema)
+             timestamp_keys: Timestamp keys (for temporal features)
+             partition_columns: Partition columns (to optimize storage and queries)
+             description: Business description
+             tags: Business tags
+
+         Returns:
+             A FeatureTable instance
+
+         Raises:
+             ValueError: If the schema does not match the data
+         """
+         # Normalize parameters
+         primary_keys = self._normalize_params(primary_keys)
+         timestamp_keys = self._normalize_params(timestamp_keys)
+         partition_columns = self._normalize_params(partition_columns)
+
+         # Validate metadata
+         self._validate_schema(df, schema)
+         #self._validate_table_name(name)
+         self._validate_key_conflicts(primary_keys, timestamp_keys)
+
+         # Table name format: <catalog>.<schema>.<table>; catalog defaults to DataLakeCatalog, schema defaults to feature_store
+         table_name = f'DataLakeCatalog.feature_store.{name}'
+
+         # Check whether the table already exists
+         try:
+             if self._spark.catalog.tableExists(table_name):
+                 raise ValueError(
+                     f"Table '{table_name}' already exists\n"
+                     "Solutions:\n"
+                     "1. Use a different table name\n"
+                     "2. Drop the existing table: spark.sql(f'DROP TABLE {name}')\n"
+                 )
+         except Exception as e:
+             raise ValueError(f"Error while checking whether the table exists: {str(e)}") from e
+
+         # Infer the table schema
+         table_schema = schema or df.schema
+
+         # Build timestamp key column definitions
+         timestamp_keys_ddl = []
+         for timestamp_key in timestamp_keys:
+             if timestamp_key not in primary_keys:
+                 raise ValueError(f"Timestamp key '{timestamp_key}' must be a primary key")
+             timestamp_keys_ddl.append(f"`{timestamp_key}` TIMESTAMP")
+
+         # Get extra tags from environment variables
+         env_tags = {
+             "project_id": os.getenv("WEDATA_PROJECT_ID", ""),  # wedata project ID
+             "engine_name": os.getenv("WEDATA_NOTEBOOK_ENGINE", ""),  # wedata engine name
+             "user_uin": os.getenv("WEDATA_USER_UIN", "")  # wedata user UIN
+         }
+
+         # Build table properties (via TBLPROPERTIES)
+         tbl_properties = {
+             "feature_table": "TRUE",
+             "primaryKeys": ",".join(primary_keys),
+             "comment": description or "",
+             **{f"{k}": v for k, v in (tags or {}).items()},
+             **{f"feature_{k}": v for k, v in (env_tags or {}).items()}
+         }
+
+         # Build column definitions
+         columns_ddl = []
+         for field in table_schema.fields:
+             data_type = field.dataType.simpleString().upper()
+             col_def = f"`{field.name}` {data_type}"
+             if not field.nullable:
+                 col_def += " NOT NULL"
+             # Add a column comment (if present in the field metadata)
+             if field.metadata and "comment" in field.metadata:
+                 comment = self._escape_sql_value(field.metadata["comment"])
+                 col_def += f" COMMENT '{comment}'"
+             columns_ddl.append(col_def)
+
+         # Build the partition expression
+         partition_expr = (
+             f"PARTITIONED BY ({', '.join([f'`{c}`' for c in partition_columns])})"
+             if partition_columns else ""
+         )
+
+         # Core CREATE TABLE statement
+         ddl = f"""
+         CREATE TABLE {table_name} (
+             {', '.join(columns_ddl)}
+         )
+         USING PARQUET
+         {partition_expr}
+         TBLPROPERTIES (
+             {', '.join(f"'{k}'='{self._escape_sql_value(v)}'" for k, v in tbl_properties.items())}
+         )
+         """
+
+         # Print the SQL
+         print(f"create table ddl: {ddl}")
+
+         # Execute the DDL
+         try:
+             self._spark.sql(ddl)
+             if df is not None:
+                 df.write.insertInto(table_name)
+         except Exception as e:
+             raise ValueError(f"Failed to create table: {str(e)}") from e
+
+     def write_table(
+         self,
+         name: str,
+         df: DataFrame,
+         mode: str = APPEND,
+         checkpoint_location: Optional[str] = None,
+         trigger: Optional[Dict[str, Any]] = DEFAULT_WRITE_STREAM_TRIGGER
+     ) -> Optional[StreamingQuery]:
+         """
+         Write data to a feature table (supports batch and streaming writes).
+
+         Args:
+             name: Feature table name (format: <table>)
+             df: Data to write (DataFrame)
+             mode: Write mode (append/overwrite)
+             checkpoint_location: Checkpoint location for streaming writes (streaming only)
+             trigger: Trigger settings for streaming writes (streaming only)
+
+         Returns:
+             A StreamingQuery object for streaming writes, otherwise None
+
+         Raises:
+             ValueError: If the arguments are invalid
+         """
+
+         # Validate the write mode
+         valid_modes = ["append", "overwrite"]
+         if mode not in valid_modes:
+             raise ValueError(f"Invalid write mode '{mode}', valid values: {valid_modes}")
+
+         # Full table name format: <catalog>.<schema>.<table>
+         table_name = f'DataLakeCatalog.feature_store.{name}'
+
+         # Determine whether the DataFrame is streaming
+         is_streaming = df.isStreaming
+
+         try:
+             if is_streaming:
+                 # Streaming write
+                 if not checkpoint_location:
+                     raise ValueError("checkpoint_location is required for streaming writes")
+
+                 writer = df.writeStream \
+                     .format("parquet") \
+                     .outputMode(mode) \
+                     .option("checkpointLocation", checkpoint_location)
+
+                 if trigger:
+                     writer = writer.trigger(**trigger)
+
+                 return writer.toTable(table_name)
+             else:
+                 # Batch write
+                 df.write \
+                     .mode(mode) \
+                     .insertInto(table_name)
+                 return None
+
+         except Exception as e:
+             raise ValueError(f"Failed to write to table '{table_name}': {str(e)}") from e
+
+     def read_table(
+         self,
+         name: str
+     ) -> DataFrame:
+         """
+         Read data from a feature table.
+
+         Args:
+             name: Feature table name (format: <table>)
+
+         Returns:
+             A DataFrame containing the table data
+
+         Raises:
+             ValueError: If the table does not exist or the read fails
+         """
+         # Build the full table name
+         table_name = f'DataLakeCatalog.feature_store.{name}'
+
+         try:
+             # Check whether the table exists
+             if not self._spark.catalog.tableExists(table_name):
+                 raise ValueError(f"Table '{table_name}' does not exist")
+
+             # Read the table data
+             return self._spark.read.table(table_name)
+
+         except Exception as e:
+             raise ValueError(f"Failed to read table '{table_name}': {str(e)}") from e
+
+     def drop_table(
+         self,
+         name: str
+     ) -> None:
+         """
+         Drop a feature table (raises if the table does not exist).
+
+         Args:
+             name: Feature table name (format: <table>)
+
+         Raises:
+             ValueError: If the table does not exist
+             RuntimeError: If the drop operation fails
+
+         Example:
+             # Basic drop
+             drop_table("user_features")
+         """
+         # Build the full table name
+         table_name = f'DataLakeCatalog.feature_store.{name}'
+
+         try:
+             # Check whether the table exists
+             if not self._spark.catalog.tableExists(table_name):
+                 raise ValueError(f"Table '{table_name}' does not exist")
+
+             # Execute the drop
+             self._spark.sql(f"DROP TABLE {table_name}")
+
+         except ValueError as e:
+             raise  # Re-raise the known ValueError as-is
+         except Exception as e:
+             raise RuntimeError(f"Failed to drop table '{table_name}': {str(e)}") from e
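A minimal usage sketch of FeatureTableClient, assuming the package is installed and the active Spark session can resolve the DataLakeCatalog.feature_store namespace that the client hard-codes; the table and column names are illustrative:

    from pyspark.sql import SparkSession
    from feature_store.feature_table_client.feature_table_client import FeatureTableClient

    spark = SparkSession.builder.getOrCreate()
    client = FeatureTableClient(spark)

    # Illustrative feature data keyed by user_id.
    users_df = spark.createDataFrame([(1, 0.5), (2, 0.9)], ["user_id", "score"])

    # Only <table> is passed; the client prefixes DataLakeCatalog.feature_store itself.
    client.create_table(name="user_features", primary_keys="user_id", df=users_df,
                        description="demo feature table")
    client.write_table(name="user_features", df=users_df, mode="append")
    client.read_table("user_features").show()
    client.drop_table("user_features")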