wedata-feature-engineering 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. wedata/__init__.py +1 -1
  2. wedata/feature_store/client.py +113 -41
  3. wedata/feature_store/constants/constants.py +19 -0
  4. wedata/feature_store/entities/column_info.py +4 -4
  5. wedata/feature_store/entities/feature_lookup.py +5 -1
  6. wedata/feature_store/entities/feature_spec.py +46 -46
  7. wedata/feature_store/entities/feature_table.py +42 -99
  8. wedata/feature_store/entities/training_set.py +13 -12
  9. wedata/feature_store/feature_table_client/feature_table_client.py +86 -31
  10. wedata/feature_store/spark_client/spark_client.py +30 -56
  11. wedata/feature_store/training_set_client/training_set_client.py +209 -38
  12. wedata/feature_store/utils/common_utils.py +213 -3
  13. wedata/feature_store/utils/feature_lookup_utils.py +6 -6
  14. wedata/feature_store/utils/feature_spec_utils.py +6 -6
  15. wedata/feature_store/utils/feature_utils.py +5 -5
  16. wedata/feature_store/utils/on_demand_utils.py +107 -0
  17. wedata/feature_store/utils/schema_utils.py +1 -1
  18. wedata/feature_store/utils/signature_utils.py +205 -0
  19. wedata/feature_store/utils/training_set_utils.py +18 -19
  20. wedata/feature_store/utils/uc_utils.py +1 -1
  21. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.7.dist-info}/METADATA +1 -1
  22. wedata_feature_engineering-0.1.7.dist-info/RECORD +43 -0
  23. feature_store/__init__.py +0 -6
  24. feature_store/client.py +0 -169
  25. feature_store/constants/__init__.py +0 -0
  26. feature_store/constants/constants.py +0 -28
  27. feature_store/entities/__init__.py +0 -0
  28. feature_store/entities/column_info.py +0 -117
  29. feature_store/entities/data_type.py +0 -92
  30. feature_store/entities/environment_variables.py +0 -55
  31. feature_store/entities/feature.py +0 -53
  32. feature_store/entities/feature_column_info.py +0 -64
  33. feature_store/entities/feature_function.py +0 -55
  34. feature_store/entities/feature_lookup.py +0 -179
  35. feature_store/entities/feature_spec.py +0 -454
  36. feature_store/entities/feature_spec_constants.py +0 -25
  37. feature_store/entities/feature_table.py +0 -164
  38. feature_store/entities/feature_table_info.py +0 -40
  39. feature_store/entities/function_info.py +0 -184
  40. feature_store/entities/on_demand_column_info.py +0 -44
  41. feature_store/entities/source_data_column_info.py +0 -21
  42. feature_store/entities/training_set.py +0 -134
  43. feature_store/feature_table_client/__init__.py +0 -0
  44. feature_store/feature_table_client/feature_table_client.py +0 -313
  45. feature_store/spark_client/__init__.py +0 -0
  46. feature_store/spark_client/spark_client.py +0 -286
  47. feature_store/training_set_client/__init__.py +0 -0
  48. feature_store/training_set_client/training_set_client.py +0 -196
  49. feature_store/utils/__init__.py +0 -0
  50. feature_store/utils/common_utils.py +0 -96
  51. feature_store/utils/feature_lookup_utils.py +0 -570
  52. feature_store/utils/feature_spec_utils.py +0 -286
  53. feature_store/utils/feature_utils.py +0 -73
  54. feature_store/utils/schema_utils.py +0 -117
  55. feature_store/utils/topological_sort.py +0 -158
  56. feature_store/utils/training_set_utils.py +0 -580
  57. feature_store/utils/uc_utils.py +0 -281
  58. feature_store/utils/utils.py +0 -252
  59. feature_store/utils/validation_utils.py +0 -55
  60. wedata/feature_store/utils/utils.py +0 -252
  61. wedata_feature_engineering-0.1.5.dist-info/RECORD +0 -79
  62. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.7.dist-info}/WHEEL +0 -0
  63. {wedata_feature_engineering-0.1.5.dist-info → wedata_feature_engineering-0.1.7.dist-info}/top_level.txt +0 -0
wedata/feature_store/utils/common_utils.py
@@ -1,12 +1,217 @@
 """
 General-purpose utility functions
 """
-
+import os
 from collections import Counter
-from typing import Any, List
+from datetime import datetime, timezone
+from functools import wraps
+from typing import Any, Dict, List, Optional
+from urllib.parse import urlparse
 
+import mlflow
+from mlflow.exceptions import RestException
+from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
 from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository
 from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
+from mlflow.utils import databricks_utils
+
+from wedata.feature_store.constants.constants import MODEL_DATA_PATH_ROOT
+
+
+def validate_table_name(name: str):
+    """
+    Validate the feature table name. Only a bare table name is supported; it must not contain dots (e.g. <catalog>.<schema>.<table>).
+
+    Args:
+        name: table name to validate
+
+    Raises:
+        ValueError: if the table name contains dots or is otherwise invalid
+    """
+    if not name or not isinstance(name, str):
+        raise ValueError("Table name must be a non-empty string")
+    if name.count('.') > 0:
+        raise ValueError("Feature table name only supports single table name, cannot contain dots (e.g. <catalog>.<schema>.<table>)")
+    if not name[0].isalpha():
+        raise ValueError("Table name must start with a letter")
+    if not all(c.isalnum() or c == '_' for c in name):
+        raise ValueError("Table name can only contain letters, numbers and underscores")
+
+
+def build_full_table_name(table_name: str) -> str:
+    """
+    Build the full table name, formatted as `<database>.<table>`.
+
+    Args:
+        table_name: the input table name (either a short name or a full name).
+
+    Returns:
+        The full table name (`<database>.<table>`).
+    """
+
+    # Read the current owner account name from the environment
+    owner_uin = os.environ.get("QCLOUD_UIN", "default")
+
+    # Fail if owner_uin is empty
+    if not owner_uin:
+        raise ValueError("WEDATA_OWNER_UIN environment variable is not set")
+
+    feature_store_database = f"{owner_uin}.{table_name}"
+
+    return feature_store_database
+
+
+def enable_if(condition):
+    """
+    A decorator that conditionally enables a function based on a condition.
+    If the condition is not truthy, calling the function raises a NotImplementedError.
+
+    :param condition: A callable that returns a truthy or falsy value.
+    """
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            if not condition():
+                raise NotImplementedError
+            return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def is_empty(target: str):
+    return target is None or len(target.strip()) == 0
+
+
+class _NoDbutilsError(Exception):
+    pass
+
+
+def _get_dbutils():
+    try:
+        import IPython
+
+        ip_shell = IPython.get_ipython()
+        if ip_shell is None:
+            raise _NoDbutilsError
+        return ip_shell.ns_table["user_global"]["dbutils"]
+    except ImportError:
+        raise _NoDbutilsError
+    except KeyError:
+        raise _NoDbutilsError
+
+
+def utc_timestamp_ms_from_iso_datetime_string(date_string: str) -> int:
+    dt = datetime.fromisoformat(date_string)
+    utc_dt = dt.replace(tzinfo=timezone.utc)
+    return 1000 * utc_dt.timestamp()
+
+
+def pip_depependency_pinned_major_version(pip_package_name, major_version):
+    """
+    Generate a pip dependency string that is pinned to a major version, for example: "databricks-feature-lookup==0.*"
+    """
+    return f"{pip_package_name}=={major_version}.*"
+
+
+def add_mlflow_pip_depependency(conda_env, pip_package_name):
+    """
+    Add a new pip dependency to the conda environment taken from the raw MLflow model.
+    """
+    if pip_package_name is None or len(pip_package_name) == 0:
+        raise ValueError(
+            "Unexpected input: missing or empty pip_package_name parameter"
+        )
+
+    found_pip_dependency = False
+    if conda_env is not None:
+        for dep in conda_env["dependencies"]:
+            if isinstance(dep, dict) and "pip" in dep:
+                found_pip_dependency = True
+                pip_deps = dep["pip"]
+                if pip_package_name not in pip_deps:
+                    pip_deps.append(pip_package_name)
+        if "dependencies" in conda_env and not found_pip_dependency:
+            raise ValueError(
+                "Unexpected input: mlflow conda_env did not contain pip as a dependency"
+            )
+
+
+def download_model_artifacts(model_uri, dir):
+    """
+    Downloads model artifacts from model_uri to dir.
+    """
+    if not is_artifact_uri(model_uri):
+        raise ValueError(
+            f"Invalid model URI '{model_uri}'."
+            f"Use ``models:/model_name>/<version_number>`` or "
+            f"``runs:/<mlflow_run_id>/run-relative/path/to/model``."
+        )
+
+    try:
+        repo = get_artifact_repository(model_uri)
+    except RestException as e:
+        raise ValueError(f"The model at '{model_uri}' does not exist.", e)
+
+    artifact_path = os.path.join(mlflow.pyfunc.DATA, MODEL_DATA_PATH_ROOT)
+    if len(repo.list_artifacts(artifact_path)) == 0:
+        raise ValueError(
+            f"No suitable model found at '{model_uri}'. Either no model exists in this "
+            f"artifact location or an existing model was not packaged with Feature Store metadata. "
+            f"Only models logged by FeatureStoreClient.log_model can be used in inference."
+        )
+
+    return repo.download_artifacts(artifact_path="", dst_path=dir)
+
+
+def validate_params_non_empty(params: Dict[str, Any], expected_params: List[str]):
+    """
+    Validate that none of the expected parameters are empty.
+    """
+    for expected_param in expected_params:
+        if expected_param not in params:
+            raise ValueError(
+                f'Internal error: expected parameter "{expected_param}" not found in params dictionary'
+            )
+        param_value = params[expected_param]
+        if not param_value:
+            raise ValueError(f'Parameter "{expected_param}" cannot be empty')
+
+
+def is_in_databricks_job():
+    """
+    Overrides the behavior of the mlflow databricks_utils.is_in_databricks_job().
+    """
+    try:
+        return databricks_utils.get_job_id() is not None
+    except Exception:
+        return False
+
+
+def get_workspace_url() -> Optional[str]:
+    """
+    Overrides the behavior of the mlflow.utils.databricks_utils.get_workspace_url().
+    """
+    workspace_url = databricks_utils.get_workspace_url()
+    if workspace_url and not urlparse(workspace_url).scheme:
+        workspace_url = "https://" + workspace_url
+    return workspace_url
+
+
+def is_in_databricks_env():
+    """
+    Determine if we are running in a Databricks environment.
+    """
+    try:
+        return (
+            is_in_databricks_job()
+            or databricks_utils.is_in_databricks_notebook()
+            or databricks_utils.is_in_databricks_runtime()
+        )
+    except Exception:
+        return False
 
 
 def is_artifact_uri(uri):
@@ -18,6 +223,7 @@ def is_artifact_uri(uri):
         uri
     ) or RunsArtifactRepository.is_runs_uri(uri)
 
+
 def as_list(obj, default=None):
     if not obj:
         return default
@@ -26,6 +232,7 @@ def as_list(obj, default=None):
     else:
         return [obj]
 
+
 def get_duplicates(elements: List[Any]) -> List[Any]:
     """
     Returns duplicate elements in the order they first appear.
@@ -37,6 +244,7 @@ def get_duplicates(elements: List[Any]) -> List[Any]:
             duplicates.append(e)
     return duplicates
 
+
 def validate_strings_unique(strings: List[str], error_template: str):
     """
    Validates all strings are unique, otherwise raise ValueError with the error template and duplicates.
@@ -47,6 +255,7 @@ def validate_strings_unique(strings: List[str], error_template: str):
        duplicates_formatted = ", ".join([f"'{s}'" for s in duplicate_strings])
        raise ValueError(error_template.format(duplicates_formatted))
 
+
 def sanitize_identifier(identifier: str):
     """
     Sanitize and wrap an identifier with backquotes. For example, "a`b" becomes "`a``b`".
@@ -89,8 +298,9 @@ def unsanitize_identifier(identifier: str):
 def escape_sql_string(input_str: str) -> str:
     return input_str.replace("\\", "\\\\").replace("'", "\\'")
 
+
 def get_unique_list_order(elements: List[Any]) -> List[Any]:
     """
     Returns unique elements in the order they first appear.
     """
-    return list(dict.fromkeys(elements))
+    return list(dict.fromkeys(elements))
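
For illustration (not part of the diff): a minimal sketch of how the new common_utils.py helpers behave, assuming the 0.1.7 wheel is installed. The table name and QCLOUD_UIN value are hypothetical.

    import os
    from wedata.feature_store.utils import common_utils

    os.environ["QCLOUD_UIN"] = "100012345678"          # hypothetical owner UIN
    common_utils.validate_table_name("user_features")  # passes: single name, no dots
    print(common_utils.build_full_table_name("user_features"))  # -> "100012345678.user_features"

    # enable_if gates a function on a runtime condition:
    @common_utils.enable_if(common_utils.is_in_databricks_env)
    def databricks_only_action():
        return "ran on Databricks"
    # Outside a Databricks environment, calling databricks_only_action() raises NotImplementedError.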
wedata/feature_store/utils/feature_lookup_utils.py
@@ -10,13 +10,13 @@ from pyspark.sql import DataFrame, Window
 from pyspark.sql import functions as F
 from pyspark.sql.functions import sum, unix_timestamp
 
-from feature_store.entities.environment_variables import BROADCAST_JOIN_THRESHOLD
-from feature_store.entities.feature_column_info import FeatureColumnInfo
-from feature_store.entities.feature_lookup import FeatureLookup
-from feature_store.entities.feature_spec import FeatureSpec
-from feature_store.entities.feature_table import FeatureTable
+from wedata.feature_store.entities.environment_variables import BROADCAST_JOIN_THRESHOLD
+from wedata.feature_store.entities.feature_column_info import FeatureColumnInfo
+from wedata.feature_store.entities.feature_lookup import FeatureLookup
+from wedata.feature_store.entities.feature_spec import FeatureSpec
+from wedata.feature_store.entities.feature_table import FeatureTable
 
-from feature_store.utils import common_utils, validation_utils, uc_utils
+from wedata.feature_store.utils import common_utils, validation_utils, uc_utils
 
 _logger = logging.getLogger(__name__)
 
wedata/feature_store/utils/feature_spec_utils.py
@@ -6,12 +6,12 @@ from typing import Dict, List, Set, Tuple, Type, Union
 import yaml
 from mlflow.utils.file_utils import YamlSafeDumper
 
-from feature_store.entities.column_info import ColumnInfo
-from feature_store.entities.feature_column_info import FeatureColumnInfo
-from feature_store.entities.feature_spec import FeatureSpec
-from feature_store.entities.on_demand_column_info import OnDemandColumnInfo
-from feature_store.entities.source_data_column_info import SourceDataColumnInfo
-from feature_store.utils.topological_sort import topological_sort
+from wedata.feature_store.entities.column_info import ColumnInfo
+from wedata.feature_store.entities.feature_column_info import FeatureColumnInfo
+from wedata.feature_store.entities.feature_spec import FeatureSpec
+from wedata.feature_store.entities.on_demand_column_info import OnDemandColumnInfo
+from wedata.feature_store.entities.source_data_column_info import SourceDataColumnInfo
+from wedata.feature_store.utils.topological_sort import topological_sort
 
 DEFAULT_GRAPH_DEPTH_LIMIT = 5
 
wedata/feature_store/utils/feature_utils.py
@@ -1,11 +1,11 @@
 import copy
 from typing import List, Union
 
-from feature_store.entities.feature_function import FeatureFunction
-from feature_store.entities.feature_lookup import FeatureLookup
-from feature_store.spark_client.spark_client import SparkClient
-from feature_store.utils import uc_utils
-from feature_store.utils.feature_lookup_utils import get_feature_lookups_with_full_table_names
+from wedata.feature_store.entities.feature_function import FeatureFunction
+from wedata.feature_store.entities.feature_lookup import FeatureLookup
+from wedata.feature_store.spark_client.spark_client import SparkClient
+from wedata.feature_store.utils import uc_utils
+from wedata.feature_store.utils.feature_lookup_utils import get_feature_lookups_with_full_table_names
 
 
 def format_feature_lookups_and_functions(
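
These hunks are mechanical: every internal import moves from the old top-level feature_store package (deleted in this release) to the wedata.feature_store namespace. Any code that imported the old modules directly needs the same one-line change, for example:

    # 0.1.5 (top-level package, removed in 0.1.7):
    # from feature_store.entities.feature_lookup import FeatureLookup
    # 0.1.7:
    from wedata.feature_store.entities.feature_lookup import FeatureLookup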
wedata/feature_store/utils/on_demand_utils.py (new file)
@@ -0,0 +1,107 @@
+import copy
+from typing import Dict, List
+
+from pyspark.sql import DataFrame
+from pyspark.sql.functions import expr
+
+from wedata.feature_store.entities.feature_function import FeatureFunction
+from wedata.feature_store.entities.function_info import FunctionInfo
+from wedata.feature_store.entities.on_demand_column_info import OnDemandColumnInfo
+from wedata.feature_store.utils import common_utils, uc_utils
+
+
+def _udf_expr(udf_name: str, arguments: List[str]) -> expr:
+    """
+    Generate a Spark SQL expression, e.g. expr("udf_name(col1, col2)")
+    """
+    arguments_str = ", ".join(common_utils.sanitize_identifiers(arguments))
+    return expr(f"{udf_name}({arguments_str})")
+
+
+def _validate_apply_functions_df(
+    df: DataFrame,
+    functions_to_apply: List[OnDemandColumnInfo],
+    uc_function_infos: Dict[str, FunctionInfo],
+):
+    """
+    Validate the following:
+    1. On-demand input columns specified by functions_to_apply exist in the DataFrame.
+    2. On-demand input columns have data types that match those of UDF parameters.
+    """
+    for odci in functions_to_apply:
+        function_info = uc_function_infos[odci.udf_name]
+        types_dict = dict(df.dtypes)
+
+        for p in function_info.input_params:
+            arg_column = odci.input_bindings[p.name]
+            if arg_column not in df.columns:
+                raise ValueError(
+                    f"FeatureFunction argument column '{arg_column}' for UDF '{odci.udf_name}' parameter '{p.name}' "
+                    f"does not exist in provided DataFrame with schema '{df.schema}'."
+                )
+            if types_dict[arg_column] != p.type_text:
+                raise ValueError(
+                    f"FeatureFunction argument column '{arg_column}' for UDF '{odci.udf_name}' parameter '{p.name}' "
+                    f"does not have the expected type. Argument column '{arg_column}' has type "
+                    f"'{types_dict[arg_column]}' and parameter '{p.name}' has type '{p.type_text}'."
+                )
+
+
+def apply_functions_if_not_overridden(
+    df: DataFrame,
+    functions_to_apply: List[OnDemandColumnInfo],
+    uc_function_infos: Dict[str, FunctionInfo],
+) -> DataFrame:
+    """
+    For all on-demand features, in the order defined by the FeatureSpec:
+    If the feature does not already exist, append the evaluated UDF expression.
+    Existing column values or column positions are not modified.
+
+    `_validate_apply_functions_df` validates UDFs can be applied on `df` schema.
+
+    The caller should validate:
+    1. FeatureFunction bound argument columns for UDF parameters exist in FeatureSpec defined features.
+    2. FeatureFunction output feature names are unique.
+    """
+    _validate_apply_functions_df(
+        df=df,
+        functions_to_apply=functions_to_apply,
+        uc_function_infos=uc_function_infos,
+    )
+
+    columns = {}
+    for odci in functions_to_apply:
+        if odci.output_name not in df.columns:
+            function_info = uc_function_infos[odci.udf_name]
+            # Resolve the bound arguments in the UDF parameter order
+            udf_arguments = [
+                odci.input_bindings[p.name] for p in function_info.input_params
+            ]
+            columns[odci.output_name] = _udf_expr(odci.udf_name, udf_arguments)
+    return df.withColumns(columns)
+
+
+def get_feature_functions_with_full_udf_names(
+    feature_functions: List[FeatureFunction], current_catalog: str, current_schema: str
+):
+    """
+    Takes in a list of FeatureFunctions, and returns copies with:
+    1. Fully qualified UDF names.
+    2. If output_name is empty, fully qualified UDF names as output_name.
+    """
+    udf_names = {ff.udf_name for ff in feature_functions}
+    uc_utils._check_qualified_udf_names(udf_names)
+    uc_utils._verify_all_udfs_in_uc(udf_names, current_catalog, current_schema)
+
+    standardized_feature_functions = []
+    for ff in feature_functions:
+        ff_copy = copy.deepcopy(ff)
+        del ff
+
+        ff_copy._udf_name = uc_utils.get_full_udf_name(
+            ff_copy.udf_name, current_catalog, current_schema
+        )
+        if not ff_copy.output_name:
+            ff_copy._output_name = ff_copy.udf_name
+        standardized_feature_functions.append(ff_copy)
+    return standardized_feature_functions
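
For illustration (not part of the diff): apply_functions_if_not_overridden boils down to building expr("udf(args...)") columns and attaching them with DataFrame.withColumns, skipping any output column the caller already provided. A standalone PySpark sketch of that pattern, with a hypothetical UDF and hypothetical column names:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import expr

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    # Hypothetical on-demand feature: a SQL-callable UDF registered in the session
    spark.udf.register("trip_distance", lambda a, b: abs(a - b), "double")
    df = spark.createDataFrame([(1.0, 4.5)], ["pickup_pos", "dropoff_pos"])

    # Mirror of _udf_expr + withColumns: only add the output column if the
    # caller has not already overridden it with a column of the same name.
    if "trip_distance" not in df.columns:
        df = df.withColumns({"trip_distance": expr("trip_distance(`pickup_pos`, `dropoff_pos`)")})
    df.show()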
wedata/feature_store/utils/schema_utils.py
@@ -1,6 +1,6 @@
 import logging
 
-from feature_store.constants.constants import _ERROR, _WARN
+from wedata.feature_store.constants.constants import _ERROR, _WARN
 
 _logger = logging.getLogger(__name__)
 
wedata/feature_store/utils/signature_utils.py (new file)
@@ -0,0 +1,205 @@
+import logging
+from typing import Any, Dict, Optional
+
+import mlflow
+from mlflow.models import ModelSignature
+from mlflow.types import ColSpec
+from mlflow.types import DataType as MlflowDataType
+from mlflow.types import ParamSchema, Schema
+
+from wedata.feature_store.entities.feature_column_info import FeatureColumnInfo
+from wedata.feature_store.entities.feature_spec import FeatureSpec
+from wedata.feature_store.entities.on_demand_column_info import OnDemandColumnInfo
+from wedata.feature_store.entities.source_data_column_info import SourceDataColumnInfo
+
+_logger = logging.getLogger(__name__)
+
+# Some types (array, map, decimal, timestamp_ntz) are unsupported due to MLflow signatures
+# lacking any equivalent types. We thus cannot construct a ColSpec for any column
+# that uses these types.
+SUPPORTED_TYPE_MAP = {
+    "smallint": MlflowDataType.integer,  # Upcast to integer
+    "int": MlflowDataType.integer,
+    "bigint": MlflowDataType.long,
+    "float": MlflowDataType.float,
+    "double": MlflowDataType.double,
+    "boolean": MlflowDataType.boolean,
+    "date": MlflowDataType.datetime,
+    "timestamp": MlflowDataType.datetime,
+    "string": MlflowDataType.string,
+    "binary": MlflowDataType.binary,
+}
+
+
+def is_unsupported_type(type_str: str):
+    return type_str not in SUPPORTED_TYPE_MAP
+
+
+def convert_spark_data_type_to_mlflow_signature_type(spark_type):
+    """
+    Maps Databricks SQL types to MLflow signature types.
+    docs.databricks.com/sql/language-manual/sql-ref-datatypes.html#language-mappings
+    """
+    return SUPPORTED_TYPE_MAP.get(spark_type)
+
+
+def get_input_schema_from_feature_spec(feature_spec: FeatureSpec) -> Schema:
+    """
+    Produces an MLflow signature schema from a feature spec.
+    Source data columns are marked as required inputs and feature columns
+    (both lookups and on-demand features) are marked as optional inputs.
+
+    :param feature_spec: FeatureSpec object with datatypes for each column.
+    """
+    # If we're missing any data types for any column, we are likely dealing with a
+    # malformed feature spec and should halt signature construction.
+    if any([ci.data_type is None for ci in feature_spec.column_infos]):
+        raise Exception("Training set does not contain column data types.")
+
+    source_data_cols = [
+        ci
+        for ci in feature_spec.column_infos
+        if isinstance(ci.info, SourceDataColumnInfo)
+    ]
+    # Don't create signature if any source data columns (required) are of complex types.
+    if any(
+        [
+            ci.data_type is None or is_unsupported_type(ci.data_type)
+            for ci in source_data_cols
+        ]
+    ):
+        raise Exception(
+            "Input DataFrame contains column data types not supported by "
+            "MLflow model signatures."
+        )
+    required_input_colspecs = [
+        ColSpec(
+            convert_spark_data_type_to_mlflow_signature_type(ci.data_type),
+            ci.info.output_name,
+            required=True,
+        )
+        for ci in source_data_cols
+    ]
+    feature_cols = [
+        ci
+        for ci in feature_spec.column_infos
+        if isinstance(ci.info, (FeatureColumnInfo, OnDemandColumnInfo))
+    ]
+    unsupported_feature_cols = [
+        ci for ci in feature_cols if is_unsupported_type(ci.data_type)
+    ]
+    optional_input_colspecs = [
+        ColSpec(
+            convert_spark_data_type_to_mlflow_signature_type(ci.data_type),
+            ci.output_name,
+            required=False,
+        )
+        for ci in feature_cols
+        if not is_unsupported_type(ci.data_type)
+    ]
+    if unsupported_feature_cols:
+        feat_string = ", ".join(
+            [f"{ci.output_name} ({ci.data_type})" for ci in unsupported_feature_cols]
+        )
+        _logger.warning(
+            f"The following features will not be included in the input schema because their"
+            f" data types are not supported by MLflow model signatures: {feat_string}. "
+            f"These features cannot be overridden during model serving."
+        )
+    return Schema(required_input_colspecs + optional_input_colspecs)
+
+
+def get_output_schema_from_labels(label_type_map: Optional[Dict[str, str]]) -> Schema:
+    """
+    Produces an MLflow signature schema from the provided label type map.
+    :param label_type_map: Map label column name -> data type
+    """
+    if not label_type_map:
+        raise Exception("Training set does not contain a label.")
+    if any([is_unsupported_type(dtype) for dtype in label_type_map.values()]):
+        raise Exception(
+            "Labels are of data types not supported by MLflow model signatures."
+        )
+    else:
+        output_colspecs = [
+            ColSpec(
+                convert_spark_data_type_to_mlflow_signature_type(spark_type),
+                col_name,
+                required=True,
+            )
+            for col_name, spark_type in label_type_map.items()
+        ]
+        return Schema(output_colspecs)
+
+
+def get_mlflow_signature_from_feature_spec(
+    feature_spec: FeatureSpec,
+    label_type_map: Optional[Dict[str, str]],
+    override_output_schema: Optional[Schema],
+    params: Optional[Dict[str, Any]] = None,
+) -> Optional[ModelSignature]:
+    """
+    Produce an MLflow signature from a feature spec and label type map.
+    Source data columns are marked as required inputs and feature columns
+    (both lookups and on-demand features) are marked as optional inputs.
+
+    Reads output types from the cached label -> datatype map in the training set.
+    If override_output_schema is provided, it will always be used as the output schema.
+
+    :param feature_spec: FeatureSpec object with datatypes for each column.
+    :param label_type_map: Map of label column name -> datatype
+    :param override_output_schema: User-provided output schema to use if provided.
+    """
+    kwargs = {}
+    kwargs["inputs"] = get_input_schema_from_feature_spec(feature_spec)
+    try:
+        output_schema = override_output_schema or get_output_schema_from_labels(
+            label_type_map
+        )
+        kwargs["outputs"] = output_schema
+    except Exception as e:
+        _logger.warning(f"Could not infer an output schema: {e}")
+
+    if params:
+        try:
+            from mlflow.types.utils import _infer_param_schema
+
+            kwargs["params"] = _infer_param_schema(params)
+        except Exception as e:
+            _logger.warning(f"Could not infer params schema: {e}")
+
+    return mlflow.models.ModelSignature(**kwargs)
+
+
+def drop_signature_inputs_and_invalid_params(signature):
+    """
+    Drop ModelSignature inputs field and invalid params from params field.
+    This is useful for feature store model's raw_model.
+    Feature store model's input schema does not apply to raw_model's input,
+    so we drop the inputs field of raw_model's signature.
+    Feature store model's result_type param enables setting and overriding
+    a default result_type for predictions, but this interferes with params
+    passed to MLflow's predict function, so we drop result_type from
+    the params field of raw_model's signature.
+
+    :param signature: ModelSignature object.
+    """
+    if signature:
+        outputs_schema = signature.outputs
+        params_schema = signature.params if hasattr(signature, "params") else None
+        try:
+            # Only for mlflow>=2.6.0 ModelSignature contains params attribute
+            if params_schema:
+                updated_params_schema = ParamSchema(
+                    [param for param in params_schema if param.name != "result_type"]
+                )
+                return ModelSignature(
+                    outputs=outputs_schema, params=updated_params_schema
+                )
+            if outputs_schema:
+                return ModelSignature(outputs=outputs_schema)
+        except TypeError:
+            _logger.warning(
+                "ModelSignature without inputs is not supported, please upgrade "
+                "mlflow >= 2.7.0 to use the feature."
+            )
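
For illustration (not part of the diff): the object this module builds is an ordinary mlflow.models.ModelSignature whose required entries come from source data columns and whose optional entries come from looked-up or on-demand feature columns. A hand-built equivalent with hypothetical column names, using the same Spark-to-MLflow type mapping as SUPPORTED_TYPE_MAP above ("bigint" -> long, "double" -> double); requires mlflow >= 2.6 for the required= flag:

    from mlflow.models import ModelSignature
    from mlflow.types import ColSpec, DataType, Schema

    inputs = Schema([
        ColSpec(DataType.long, "user_id", required=True),           # source data column
        ColSpec(DataType.double, "avg_spend_30d", required=False),  # looked-up feature
    ])
    outputs = Schema([ColSpec(DataType.double, "label", required=True)])
    signature = ModelSignature(inputs=inputs, outputs=outputs)
    print(signature.to_dict())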