wedata-feature-engineering 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in that registry.
Files changed (63)
  1. wedata/__init__.py +1 -1
  2. wedata/feature_store/client.py +113 -41
  3. wedata/feature_store/constants/constants.py +19 -0
  4. wedata/feature_store/entities/column_info.py +4 -4
  5. wedata/feature_store/entities/feature_lookup.py +5 -1
  6. wedata/feature_store/entities/feature_spec.py +46 -46
  7. wedata/feature_store/entities/feature_table.py +42 -99
  8. wedata/feature_store/entities/training_set.py +13 -12
  9. wedata/feature_store/feature_table_client/feature_table_client.py +85 -30
  10. wedata/feature_store/spark_client/spark_client.py +30 -56
  11. wedata/feature_store/training_set_client/training_set_client.py +209 -38
  12. wedata/feature_store/utils/common_utils.py +213 -3
  13. wedata/feature_store/utils/feature_lookup_utils.py +6 -6
  14. wedata/feature_store/utils/feature_spec_utils.py +6 -6
  15. wedata/feature_store/utils/feature_utils.py +5 -5
  16. wedata/feature_store/utils/on_demand_utils.py +107 -0
  17. wedata/feature_store/utils/schema_utils.py +1 -1
  18. wedata/feature_store/utils/signature_utils.py +205 -0
  19. wedata/feature_store/utils/training_set_utils.py +18 -19
  20. wedata/feature_store/utils/uc_utils.py +1 -1
  21. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/METADATA +1 -1
  22. wedata_feature_engineering-0.1.6.dist-info/RECORD +43 -0
  23. feature_store/__init__.py +0 -6
  24. feature_store/client.py +0 -169
  25. feature_store/constants/__init__.py +0 -0
  26. feature_store/constants/constants.py +0 -28
  27. feature_store/entities/__init__.py +0 -0
  28. feature_store/entities/column_info.py +0 -117
  29. feature_store/entities/data_type.py +0 -92
  30. feature_store/entities/environment_variables.py +0 -55
  31. feature_store/entities/feature.py +0 -53
  32. feature_store/entities/feature_column_info.py +0 -64
  33. feature_store/entities/feature_function.py +0 -55
  34. feature_store/entities/feature_lookup.py +0 -179
  35. feature_store/entities/feature_spec.py +0 -454
  36. feature_store/entities/feature_spec_constants.py +0 -25
  37. feature_store/entities/feature_table.py +0 -164
  38. feature_store/entities/feature_table_info.py +0 -40
  39. feature_store/entities/function_info.py +0 -184
  40. feature_store/entities/on_demand_column_info.py +0 -44
  41. feature_store/entities/source_data_column_info.py +0 -21
  42. feature_store/entities/training_set.py +0 -134
  43. feature_store/feature_table_client/__init__.py +0 -0
  44. feature_store/feature_table_client/feature_table_client.py +0 -313
  45. feature_store/spark_client/__init__.py +0 -0
  46. feature_store/spark_client/spark_client.py +0 -286
  47. feature_store/training_set_client/__init__.py +0 -0
  48. feature_store/training_set_client/training_set_client.py +0 -196
  49. feature_store/utils/__init__.py +0 -0
  50. feature_store/utils/common_utils.py +0 -96
  51. feature_store/utils/feature_lookup_utils.py +0 -570
  52. feature_store/utils/feature_spec_utils.py +0 -286
  53. feature_store/utils/feature_utils.py +0 -73
  54. feature_store/utils/schema_utils.py +0 -117
  55. feature_store/utils/topological_sort.py +0 -158
  56. feature_store/utils/training_set_utils.py +0 -580
  57. feature_store/utils/uc_utils.py +0 -281
  58. feature_store/utils/utils.py +0 -252
  59. feature_store/utils/validation_utils.py +0 -55
  60. wedata/feature_store/utils/utils.py +0 -252
  61. wedata_feature_engineering-0.1.5.dist-info/RECORD +0 -79
  62. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/WHEEL +0 -0
  63. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.6.dist-info}/top_level.txt +0 -0
wedata/feature_store/feature_table_client/feature_table_client.py

@@ -8,15 +8,18 @@ from pyspark.sql.streaming import StreamingQuery
 from pyspark.sql.types import StructType
 import os
 
-from feature_store.constants.constants import APPEND, DEFAULT_WRITE_STREAM_TRIGGER
+from wedata.feature_store.constants.constants import APPEND, DEFAULT_WRITE_STREAM_TRIGGER
+from wedata.feature_store.entities.feature_table import FeatureTable
+from wedata.feature_store.spark_client.spark_client import SparkClient
+from wedata.feature_store.utils import common_utils
 
 
 class FeatureTableClient:
     """特征表操作类"""
 
     def __init__(
-        self,
-        spark: SparkSession
+            self,
+            spark: SparkSession
     ):
         self._spark = spark
 
@@ -46,12 +49,6 @@ class FeatureTableClient:
                 f"DataFrame与schema不匹配。差异字段: {diff_fields if diff_fields else '字段类型不一致'}"
             )
 
-    @staticmethod
-    def _validate_table_name(name: str):
-        """验证特征表命名规范"""
-        if name.count('.') < 2:
-            raise ValueError("特征表名称需符合<catalog>.<schema>.<table>格式")
-
     @staticmethod
     def _validate_key_conflicts(primary_keys: List[str], timestamp_keys: List[str]):
         """校验主键与时间戳键是否冲突"""
@@ -75,7 +72,8 @@ class FeatureTableClient:
             schema: Optional[StructType] = None,
             description: Optional[str] = None,
             tags: Optional[Dict[str, str]] = None
-    ):
+    ) -> FeatureTable:
+
         """
         创建特征表(支持批流数据写入)
 
@@ -85,6 +83,7 @@ class FeatureTableClient:
             df: 初始数据(可选,用于推断schema)
             timestamp_keys: 时间戳键(用于时态特征)
             partition_columns: 分区列(优化存储查询)
+            schema: 表结构定义(可选,当不提供df时必需)
             description: 业务描述
             tags: 业务标签
 
@@ -94,6 +93,7 @@ class FeatureTableClient:
         Raises:
             ValueError: 当schema与数据不匹配时
         """
+
         # 参数标准化
         primary_keys = self._normalize_params(primary_keys)
         timestamp_keys = self._normalize_params(timestamp_keys)
@@ -101,23 +101,25 @@
 
         # 元数据校验
         self._validate_schema(df, schema)
-        #self._validate_table_name(name)
         self._validate_key_conflicts(primary_keys, timestamp_keys)
 
-        # 表名 格式:<catalog>.<schema>.<table> catalog默认值:DataLakeCatalog,schema默认值:feature_store
-        table_name = f'DataLakeCatalog.feature_store.{name}'
+        # 表名校验
+        common_utils.validate_table_name(name)
+
+        # 构建完整表名
+        table_name = common_utils.build_full_table_name(name)
 
         # 检查表是否存在
         try:
             if self._spark.catalog.tableExists(table_name):
                 raise ValueError(
-                    f" '{table_name}' 已存在\n"
-                    "解决方案:\n"
-                    "1. 使用不同的表名\n"
-                    "2. 删除现有表: spark.sql(f'DROP TABLE {name}')\n"
+                    f"Table '{table_name}' already exists\n"
+                    "Solutions:\n"
+                    "1. Use a different table name\n"
+                    "2. Drop the existing table: spark.sql(f'DROP TABLE {name}')\n"
                 )
         except Exception as e:
-            raise ValueError(f"检查表存在性时出错: {str(e)}") from e
+            raise ValueError(f"Error checking table existence: {str(e)}") from e
 
         # 推断表schema
         table_schema = schema or df.schema
@@ -126,7 +128,7 @@ class FeatureTableClient:
         timestamp_keys_ddl = []
         for timestamp_key in timestamp_keys:
             if timestamp_key not in primary_keys:
-                raise ValueError(f"时间戳键 '{timestamp_key}' 必须是主键")
+                raise ValueError(f"Timestamp key '{timestamp_key}' must be a primary key")
             timestamp_keys_ddl.append(f"`{timestamp_key}` TIMESTAMP")
 
         #从环境变量获取额外标签
@@ -185,7 +187,19 @@ class FeatureTableClient:
             if df is not None:
                 df.write.insertInto(table_name)
         except Exception as e:
-            raise ValueError(f"建表失败: {str(e)}") from e
+            raise ValueError(f"Failed to create table: {str(e)}") from e
+
+        # 构建并返回FeatureTable对象
+        return FeatureTable(
+            name=name,
+            table_id=table_name,
+            description=description or "",
+            primary_keys=primary_keys,
+            partition_columns=partition_columns or [],
+            features=[field.name for field in table_schema.fields],
+            timestamp_keys=timestamp_keys or [],
+            tags=dict(**tags or {}, **env_tags)
+        )
 
     def write_table(
             self,
@@ -195,6 +209,7 @@ class FeatureTableClient:
             checkpoint_location: Optional[str] = None,
             trigger: Optional[Dict[str, Any]] = DEFAULT_WRITE_STREAM_TRIGGER
     ) -> Optional[StreamingQuery]:
+
         """
         写入特征表数据(支持批处理和流式写入)
 
@@ -215,10 +230,13 @@ class FeatureTableClient:
         # 验证写入模式
         valid_modes = ["append", "overwrite"]
         if mode not in valid_modes:
-            raise ValueError(f"无效的写入模式 '{mode}',可选值: {valid_modes}")
+            raise ValueError(f"Invalid write mode '{mode}', valid options: {valid_modes}")
+
+        # 表名校验
+        common_utils.validate_table_name(name)
 
-        # 完整表名格式:<catalog>.<schema>.<table>
-        table_name = f'DataLakeCatalog.feature_store.{name}'
+        # 构建完整表名
+        table_name = common_utils.build_full_table_name(name)
 
         # 判断是否是流式DataFrame
         is_streaming = df.isStreaming
@@ -227,7 +245,7 @@ class FeatureTableClient:
         if is_streaming:
             # 流式写入
             if not checkpoint_location:
-                raise ValueError("流式写入必须提供checkpoint_location参数")
+                raise ValueError("Streaming write requires checkpoint_location parameter")
 
             writer = df.writeStream \
                 .format("parquet") \
@@ -252,6 +270,7 @@ class FeatureTableClient:
             self,
             name: str
     ) -> DataFrame:
+
         """
         从特征表中读取数据
 
@@ -264,8 +283,12 @@ class FeatureTableClient:
         Raises:
             ValueError: 当表不存在或读取失败时抛出
         """
+
+        # 表名校验
+        common_utils.validate_table_name(name)
+
         # 构建完整表名
-        table_name = f'DataLakeCatalog.feature_store.{name}'
+        table_name = common_utils.build_full_table_name(name)
 
         try:
             # 检查表是否存在
@@ -278,10 +301,8 @@ class FeatureTableClient:
         except Exception as e:
             raise ValueError(f"读取表 '{table_name}' 失败: {str(e)}") from e
 
-    def drop_table(
-            self,
-            name: str
-    ) -> None:
+    def drop_table(self, name: str):
+
         """
         删除特征表(表不存在时抛出异常)
 
@@ -296,8 +317,12 @@ class FeatureTableClient:
            # 基本删除
            drop_table("user_features")
        """
+
+        # 表名校验
+        common_utils.validate_table_name(name)
+
        # 构建完整表名
-        table_name = f'DataLakeCatalog.feature_store.{name}'
+        table_name = common_utils.build_full_table_name(name)
 
        try:
            # 检查表是否存在
@@ -311,3 +336,33 @@ class FeatureTableClient:
             raise # 直接抛出已知的ValueError
         except Exception as e:
             raise RuntimeError(f"删除表 '{table_name}' 失败: {str(e)}") from e
+
+    def get_table(
+            self,
+            name: str,
+            spark_client: SparkClient
+    ) -> FeatureTable:
+
+        """获取特征表元数据信息
+
+        参数:
+            name: 特征表名称
+            spark_client: Spark客户端
+
+        返回:
+            FeatureTable对象
+
+        异常:
+            ValueError: 当表不存在或获取失败时抛出
+        """
+
+        # 表名校验
+        common_utils.validate_table_name(name)
+
+        # 构建完整表名
+        table_name = common_utils.build_full_table_name(name)
+
+        try:
+            return spark_client.get_feature_table(table_name)
+        except Exception as e:
+            raise ValueError(f"获取表'{name}'元数据失败: {str(e)}") from e
wedata/feature_store/spark_client/spark_client.py

@@ -6,73 +6,52 @@ from pyspark.sql.catalog import Column
 from pyspark.sql.functions import when, isnull
 from pyspark.sql.types import StructType, StringType, StructField
 
-from feature_store.entities.feature import Feature
-from feature_store.entities.feature_table import FeatureTable
-from feature_store.entities.function_info import FunctionParameterInfo, FunctionInfo
-from feature_store.utils.common_utils import unsanitize_identifier
-from feature_store.utils.utils import sanitize_multi_level_name
+from wedata.feature_store.entities.feature import Feature
+from wedata.feature_store.entities.feature_table import FeatureTable
+from wedata.feature_store.entities.function_info import FunctionParameterInfo, FunctionInfo
+from wedata.feature_store.utils.common_utils import unsanitize_identifier, sanitize_multi_level_name
 
 
 class SparkClient:
     def __init__(self, spark: SparkSession):
         self._spark = spark
 
-    def createDataFrame(self, data, schema) -> DataFrame:
-        return self._spark.createDataFrame(data, schema)
-
-    def read_table(
-        self, qualified_table_name, as_of_delta_timestamp=None, streaming=False
-    ):
-        """
-        Reads a Delta table, optionally as of some timestamp.
-        """
-        if streaming and as_of_delta_timestamp:
-            raise ValueError(
-                "Internal error: as_of_delta_timestamp cannot be specified when"
-                " streaming=True."
-            )
-
-        base_reader = (
-            # By default, Structured Streaming only handles append operations. Because
-            # we have a notion of primary keys, most offline feature store operations
-            # are not appends. For example, FeatureStoreClient.write_table(mode=MERGE)
-            # will issue a MERGE operation.
-            # In order to propagate the non-append operations to the
-            # readStream, we set ignoreChanges to "true".
-            # For more information,
-            # see https://docs.databricks.com/delta/delta-streaming.html#ignore-updates-and-deletes
-            self._spark.readStream.format("delta").option("ignoreChanges", "true")
-            if streaming
-            else self._spark.read.format("delta")
-        )
-
-        if as_of_delta_timestamp:
-            return base_reader.option("timestampAsOf", as_of_delta_timestamp).table(
-                sanitize_multi_level_name(qualified_table_name)
-            )
-        else:
-            return base_reader.table(sanitize_multi_level_name(qualified_table_name))
 
     def get_current_catalog(self):
         """
-        Get current set catalog in the spark context.
+        获取当前Spark会话的catalog名称(使用spark.catalog.currentCatalog属性)
+
+        返回:
+            str: 当前catalog名称,如果未设置则返回None
         """
         try:
-            df = self._spark.sql("SELECT CURRENT_CATALOG()").collect()
-            return unsanitize_identifier(df[0][0])
-        except Exception as e:
+            return unsanitize_identifier(self._spark.catalog.currentCatalog())
+        except Exception:
             return None
 
     def get_current_database(self):
         """
-        Get current set database in the spark context.
+        获取Spark上下文中当前设置的数据库名称
+
+        返回:
+            str: 当前数据库名称,如果获取失败则返回None
         """
         try:
-            df = self._spark.sql("SELECT CURRENT_DATABASE()").collect()
-            return unsanitize_identifier(df[0][0])
-        except Exception as e:
+            # 使用Spark SQL查询当前数据库
+            df = self._spark.sql("SELECT CURRENT_DATABASE()")
+            # 获取第一行第一列的值并去除特殊字符
+            return unsanitize_identifier(df.first()[0])
+        except Exception:
+            # 捕获所有异常并返回None
             return None
 
+
+
+
+    def createDataFrame(self, data, schema) -> DataFrame:
+        return self._spark.createDataFrame(data, schema)
+
+
     def read_table(self, table_name):
         """读取Spark表数据
 
@@ -134,11 +113,6 @@ class SparkClient:
             ) for row in columns
         ]
 
-    def get_online_stores(self, table_name):
-        return None
-
-
-
     def get_feature_table(self, table_name):
 
         # 获取表元数据
@@ -170,19 +144,19 @@
         return FeatureTable(
             name=table_name,
             table_id=table_properties.get("table_id", table_name),
-            description=table.description or table_properties.get("description", table_name),
+            description=table.description or table_properties.get("comment", table_name),
             primary_keys=table_properties.get("primaryKeys", "").split(",") if table_properties.get("primaryKeys") else [],
             partition_columns=table.partitionColumnNames if hasattr(table, 'partitionColumnNames') else [],
             features=features,
             creation_timestamp=None, # Spark表元数据不包含创建时间戳
-            online_stores=self.get_online_stores(table_name),
+            online_stores=None,
             notebook_producers=None,
             job_producers=None,
             table_data_sources=None,
             path_data_sources=None,
             custom_data_sources=None,
             timestamp_keys=table_properties.get("timestamp_keys"),
-            tags=table_properties.get("tags")
+            tags=table_properties
         )
 
     def _get_routines_with_parameters(self, full_routine_names: List[str]) -> DataFrame:
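For the catalog/database getters, 0.1.6 replaces the SELECT CURRENT_CATALOG() round trip with spark.catalog.currentCatalog() and makes both getters swallow all exceptions and return None. A small sketch of the resulting behavior, assuming PySpark 3.4+ (where Catalog.currentCatalog() is available):

    from pyspark.sql import SparkSession
    from wedata.feature_store.spark_client.spark_client import SparkClient

    spark = SparkSession.builder.getOrCreate()
    client = SparkClient(spark)

    # Both return unsanitized identifiers, or None if the lookup fails.
    print(client.get_current_catalog())   # e.g. "spark_catalog"
    print(client.get_current_database())  # e.g. "default"

Note also that get_feature_table now reads the description from the "comment" table property and returns the raw table-properties dict as tags, rather than a dedicated "tags" property.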
wedata/feature_store/training_set_client/training_set_client.py

@@ -1,35 +1,28 @@
-import json
 import logging
 import os
-from collections import defaultdict
 from types import ModuleType
-from typing import Any, Dict, List, Optional, Set, Union
+from typing import Any, List, Optional, Set, Union
 
 import mlflow
-import yaml
-from mlflow.models import Model, ModelSignature
-from mlflow.utils.file_utils import TempDir, YamlSafeDumper, read_yaml
+from mlflow.models import Model
+from mlflow.utils.file_utils import TempDir, read_yaml
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import struct
 
-from feature_store.entities.feature_column_info import FeatureColumnInfo
-from feature_store.entities.feature_function import FeatureFunction
-from feature_store.entities.feature_lookup import FeatureLookup
-from feature_store.entities.feature_spec import FeatureSpec
-from feature_store.entities.training_set import TrainingSet
-from feature_store.spark_client.spark_client import SparkClient
+from wedata.feature_store.constants import constants
+from wedata.feature_store.entities.feature_function import FeatureFunction
+from wedata.feature_store.entities.feature_lookup import FeatureLookup
+from wedata.feature_store.entities.feature_spec import FeatureSpec
+from wedata.feature_store.entities.training_set import TrainingSet
+from wedata.feature_store.spark_client.spark_client import SparkClient
 
-from feature_store.constants.constants import (
+from wedata.feature_store.constants.constants import (
     _NO_RESULT_TYPE_PASSED,
-    _PREBUILT_ENV_URI,
-    _USE_SPARK_NATIVE_JOIN,
-    _WARN,
-    MODEL_DATA_PATH_ROOT,
-    PREDICTION_COLUMN_NAME,
+    _USE_SPARK_NATIVE_JOIN
 )
 
-from feature_store.utils import common_utils, training_set_utils
-from feature_store.utils.feature_spec_utils import convert_to_yaml_string
+from wedata.feature_store.utils import common_utils, training_set_utils, uc_utils
+from wedata.feature_store.utils.signature_utils import get_mlflow_signature_from_feature_spec, \
+    drop_signature_inputs_and_invalid_params
 
 _logger = logging.getLogger(__name__)
 
@@ -46,7 +39,6 @@ class TrainingSetClient:
     def create_training_set(
         self,
         feature_spec: FeatureSpec,
-        feature_column_infos: List[FeatureColumnInfo],
         label_names: List[str],
         df: DataFrame,
         ft_metadata: training_set_utils._FeatureTableMetadata,
@@ -57,18 +49,6 @@
             {odci.udf_name for odci in feature_spec.on_demand_column_infos},
         )
 
-        # TODO(divyagupta-db): Move validation from _validate_join_feature_data in feature_lookup_utils.py
-        # to a helper function called here and in score_batch.
-
-        # Add consumer of each feature and instrument as final step
-        consumer_feature_table_map = defaultdict(list)
-        for feature in feature_column_infos:
-            consumer_feature_table_map[feature.table_name].append(feature.feature_name)
-        consumed_udf_names = [f.udf_name for f in feature_spec.function_infos]
-
-        # Spark query planning is known to cause spark driver to crash if there are many feature tables to PiT join.
-        # See https://docs.google.com/document/d/1EyA4vvlWikTJMeinsLkxmRAVNlXoF1eqoZElOdqlWyY/edit
-        # So we disable native join by default.
         training_set_utils.warn_if_non_photon_for_native_spark(
             kwargs.get(_USE_SPARK_NATIVE_JOIN, False), self._spark_client
         )
@@ -96,6 +76,12 @@
         feature_lookups = [f for f in features if isinstance(f, FeatureLookup)]
         feature_functions = [f for f in features if isinstance(f, FeatureFunction)]
 
+        # 最多支持100个FeatureFunctions
+        if len(feature_functions) > training_set_utils.MAX_FEATURE_FUNCTIONS:
+            raise ValueError(
+                f"A maximum of {training_set_utils.MAX_FEATURE_FUNCTIONS} FeatureFunctions are supported."
+            )
+
         # 如果未提供标签,则用空列表初始化label_names
         label_names = common_utils.as_list(label, [])
         del label
@@ -137,7 +123,6 @@
 
         return self.create_training_set(
             feature_spec,
-            column_infos.feature_column_infos,
             label_names,
             df,
             ft_metadata,
@@ -145,9 +130,6 @@
         )
 
 
-
-
-
     def create_feature_spec(
         self,
         name: str,
@@ -194,3 +176,192 @@
         )
 
         return feature_spec
+
+
+    def log_model(
+        self,
+        model: Any,
+        artifact_path: str,
+        *,
+        flavor: ModuleType,
+        training_set: Optional[TrainingSet],
+        registered_model_name: Optional[str],
+        await_registration_for: int,
+        infer_input_example: bool,
+        **kwargs,
+    ):
+        # Validate only one of the training_set, feature_spec_path arguments is provided.
+        # Retrieve the FeatureSpec, then remove training_set, feature_spec_path from local scope.
+        feature_spec_path = kwargs.pop("feature_spec_path", None)
+        if (training_set is None) == (feature_spec_path is None):
+            raise ValueError(
+                "Either 'training_set' or 'feature_spec_path' must be provided, but not both."
+            )
+        # Retrieve the FeatureSpec and then reformat tables in local metastore to 2L before serialization.
+        # This will make sure the format of the feature spec with local metastore tables is always consistent.
+        if training_set:
+            all_uc_tables = all(
+                [
+                    uc_utils.is_uc_entity(table_info.table_name)
+                    for table_info in training_set.feature_spec.table_infos
+                ]
+            )
+            # training_set.feature_spec is guaranteed to be 3L from FeatureStoreClient.create_training_set.
+            feature_spec = uc_utils.get_feature_spec_with_reformat_full_table_names(
+                training_set.feature_spec
+            )
+            label_type_map = training_set._label_data_types
+
+            labels = training_set._labels
+            df_head = training_set._df.drop(*labels).head()
+        else:
+            # FeatureSpec.load expects the root directory of feature_spec.yaml
+            root_dir, file_name = os.path.split(feature_spec_path)
+            if file_name != FeatureSpec.FEATURE_ARTIFACT_FILE:
+                raise ValueError(
+                    f"'feature_spec_path' must be a path to {FeatureSpec.FEATURE_ARTIFACT_FILE}."
+                )
+            feature_spec = FeatureSpec.load(root_dir)
+
+            # The loaded FeatureSpec is not guaranteed to be 3L.
+            # First call get_feature_spec_with_full_table_names to append the default metastore to 2L names,
+            # as get_feature_spec_with_reformat_full_table_names expects full 3L table names and throws otherwise.
+            # TODO (ML-26593): Consolidate this into a single function that allows either 2L/3L names.
+            feature_spec_with_full_table_names = (
+                uc_utils.get_feature_spec_with_full_table_names(feature_spec)
+            )
+            all_uc_tables = all(
+                [
+                    uc_utils.is_uc_entity(table_info.table_name)
+                    for table_info in feature_spec_with_full_table_names.table_infos
+                ]
+            )
+            feature_spec = uc_utils.get_feature_spec_with_reformat_full_table_names(
+                feature_spec_with_full_table_names
+            )
+            label_type_map = None
+            df_head = None
+        del training_set, feature_spec_path
+
+        override_output_schema = kwargs.pop("output_schema", None)
+        params = kwargs.pop("params", {})
+        params["result_type"] = params.get("result_type", _NO_RESULT_TYPE_PASSED)
+        # Signatures will ony be supported for UC-table-only models to
+        # mitigate new online scoring behavior from being a breaking regression for older
+        # models.
+        # See https://docs.google.com/document/d/1L5tLY-kRreRefDfuAM3crXvYlirkcPuUUppU8uIMVM0/edit#
+        try:
+            if all_uc_tables:
+                signature = get_mlflow_signature_from_feature_spec(
+                    feature_spec, label_type_map, override_output_schema, params
+                )
+            else:
+                _logger.warning(
+                    "Model could not be logged with a signature because the training set uses feature tables in "
+                    "Hive Metastore. Migrate the feature tables to Unity Catalog for model to be logged "
+                    "with a signature. "
+                    "See https://docs.databricks.com/en/machine-learning/feature-store/uc/upgrade-feature-table-to-uc.html for more details."
+                )
+                signature = None
+        except Exception as e:
+            _logger.warning(f"Model could not be logged with a signature: {e}")
+            signature = None
+
+        with TempDir() as tmp_location:
+            data_path = os.path.join(tmp_location.path(), "feature_store")
+            raw_mlflow_model = Model(
+                signature=drop_signature_inputs_and_invalid_params(signature)
+            )
+            raw_model_path = os.path.join(
+                data_path, constants.RAW_MODEL_FOLDER
+            )
+            if flavor.FLAVOR_NAME != mlflow.pyfunc.FLAVOR_NAME:
+                flavor.save_model(
+                    model, raw_model_path, mlflow_model=raw_mlflow_model, **kwargs
+                )
+            else:
+                flavor.save_model(
+                    raw_model_path,
+                    mlflow_model=raw_mlflow_model,
+                    python_model=model,
+                    **kwargs,
+                )
+            if not "python_function" in raw_mlflow_model.flavors:
+                raise ValueError(
+                    f"FeatureStoreClient.log_model does not support '{flavor.__name__}' "
+                    f"since it does not have a python_function model flavor."
+                )
+
+            # Re-use the conda environment from the raw model for the packaged model. Later, we may
+            # add an additional requirement for the Feature Store library. At the moment, however,
+            # the databricks-feature-store package is not available via conda or pip.
+            model_env = raw_mlflow_model.flavors["python_function"][mlflow.pyfunc.ENV]
+            if isinstance(model_env, dict):
+                # mlflow 2.0 has multiple supported environments
+                conda_file = model_env[mlflow.pyfunc.EnvType.CONDA]
+            else:
+                conda_file = model_env
+
+            conda_env = read_yaml(raw_model_path, conda_file)
+
+            # Check if databricks-feature-lookup version is specified in conda_env
+            lookup_client_version_specified = False
+            for dependency in conda_env.get("dependencies", []):
+                if isinstance(dependency, dict):
+                    for pip_dep in dependency.get("pip", []):
+                        if pip_dep.startswith(
+                            constants.FEATURE_LOOKUP_CLIENT_PIP_PACKAGE
+                        ):
+                            lookup_client_version_specified = True
+                            break
+
+            # If databricks-feature-lookup version is not specified, add default version
+            if not lookup_client_version_specified:
+                # Get the pip package string for the databricks-feature-lookup client
+                default_databricks_feature_lookup_pip_package = common_utils.pip_depependency_pinned_major_version(
+                    pip_package_name=constants.FEATURE_LOOKUP_CLIENT_PIP_PACKAGE,
+                    major_version=constants.FEATURE_LOOKUP_CLIENT_MAJOR_VERSION,
+                )
+                common_utils.add_mlflow_pip_depependency(
+                    conda_env, default_databricks_feature_lookup_pip_package
+                )
+
+            try:
+                if df_head is not None and infer_input_example:
+                    input_example = df_head.asDict()
+                else:
+                    input_example = None
+            except Exception:
+                input_example = None
+
+            # todo:
+            #feature_spec.save(data_path)
+
+            # Log the packaged model. If no run is active, this call will create an active run.
+            mlflow.pyfunc.log_model(
+                artifact_path=artifact_path,
+                loader_module=constants.MLFLOW_MODEL_NAME,
+                data_path=data_path,
+                code_path=None,
+                conda_env=conda_env,
+                signature=signature,
+                input_example=input_example,
+            )
+        if registered_model_name is not None:
+            # The call to mlflow.pyfunc.log_model will create an active run, so it is safe to
+            # obtain the run_id for the active run.
+            run_id = mlflow.tracking.fluent.active_run().info.run_id
+
+            # If the user provided an explicit model_registry_uri when constructing the FeatureStoreClient,
+            # we respect this by setting the registry URI prior to reading the model from Model
+            # Registry.
+            # todo:
+            # if self._model_registry_uri:
+            #     # This command will override any previously set registry_uri.
+            #     mlflow.set_registry_uri(self._model_registry_uri)
+
+            mlflow.register_model(
+                "runs:/%s/%s" % (run_id, artifact_path),
+                registered_model_name,
+                await_registration_for=await_registration_for,
+            )
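The new log_model mirrors the Databricks Feature Store packaging flow: it resolves a FeatureSpec from either a TrainingSet or a feature_spec_path (exactly one of the two), attaches an MLflow signature only when every referenced table is a Unity Catalog entity, pins a databricks-feature-lookup pip dependency into the conda environment if none is declared, logs the wrapped pyfunc model, and optionally registers it. A hedged calling sketch, assuming training_set_client and training_set come from the surrounding wedata workflow (e.g. a prior create_training_set call) and that all keyword-only arguments must be passed explicitly, as the signature in this diff suggests:

    import mlflow.sklearn
    from sklearn.linear_model import LogisticRegression

    # Train a trivial stand-in model; in practice the features would come from
    # the training set produced by the wedata workflow.
    model = LogisticRegression().fit([[0.0], [1.0]], [0, 1])

    training_set_client.log_model(
        model,
        "recommender_model",          # artifact_path
        flavor=mlflow.sklearn,        # must produce a python_function flavor
        training_set=training_set,    # or feature_spec_path=..., but not both
        registered_model_name=None,   # pass a name to also register the model
        await_registration_for=300,
        infer_input_example=False,
    )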