awx-zipline-ai 0.0.32 (awx_zipline_ai-0.0.32-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- __init__.py +0 -0
- agent/__init__.py +1 -0
- agent/constants.py +15 -0
- agent/ttypes.py +1684 -0
- ai/__init__.py +0 -0
- ai/chronon/__init__.py +0 -0
- ai/chronon/airflow_helpers.py +248 -0
- ai/chronon/cli/__init__.py +0 -0
- ai/chronon/cli/compile/__init__.py +0 -0
- ai/chronon/cli/compile/column_hashing.py +336 -0
- ai/chronon/cli/compile/compile_context.py +173 -0
- ai/chronon/cli/compile/compiler.py +183 -0
- ai/chronon/cli/compile/conf_validator.py +742 -0
- ai/chronon/cli/compile/display/__init__.py +0 -0
- ai/chronon/cli/compile/display/class_tracker.py +102 -0
- ai/chronon/cli/compile/display/compile_status.py +95 -0
- ai/chronon/cli/compile/display/compiled_obj.py +12 -0
- ai/chronon/cli/compile/display/console.py +3 -0
- ai/chronon/cli/compile/display/diff_result.py +111 -0
- ai/chronon/cli/compile/fill_templates.py +35 -0
- ai/chronon/cli/compile/parse_configs.py +134 -0
- ai/chronon/cli/compile/parse_teams.py +242 -0
- ai/chronon/cli/compile/serializer.py +109 -0
- ai/chronon/cli/compile/version_utils.py +42 -0
- ai/chronon/cli/git_utils.py +145 -0
- ai/chronon/cli/logger.py +59 -0
- ai/chronon/constants.py +3 -0
- ai/chronon/group_by.py +692 -0
- ai/chronon/join.py +580 -0
- ai/chronon/logger.py +23 -0
- ai/chronon/model.py +40 -0
- ai/chronon/query.py +126 -0
- ai/chronon/repo/__init__.py +39 -0
- ai/chronon/repo/aws.py +284 -0
- ai/chronon/repo/cluster.py +136 -0
- ai/chronon/repo/compile.py +62 -0
- ai/chronon/repo/constants.py +164 -0
- ai/chronon/repo/default_runner.py +269 -0
- ai/chronon/repo/explore.py +418 -0
- ai/chronon/repo/extract_objects.py +134 -0
- ai/chronon/repo/gcp.py +586 -0
- ai/chronon/repo/gitpython_utils.py +15 -0
- ai/chronon/repo/hub_runner.py +261 -0
- ai/chronon/repo/hub_uploader.py +109 -0
- ai/chronon/repo/init.py +60 -0
- ai/chronon/repo/join_backfill.py +119 -0
- ai/chronon/repo/run.py +296 -0
- ai/chronon/repo/serializer.py +133 -0
- ai/chronon/repo/team_json_utils.py +46 -0
- ai/chronon/repo/utils.py +481 -0
- ai/chronon/repo/zipline.py +35 -0
- ai/chronon/repo/zipline_hub.py +277 -0
- ai/chronon/resources/__init__.py +0 -0
- ai/chronon/resources/gcp/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/test/data.py +30 -0
- ai/chronon/resources/gcp/joins/__init__.py +0 -0
- ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
- ai/chronon/resources/gcp/joins/test/data.py +26 -0
- ai/chronon/resources/gcp/sources/__init__.py +0 -0
- ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
- ai/chronon/resources/gcp/sources/test/data.py +26 -0
- ai/chronon/resources/gcp/teams.py +58 -0
- ai/chronon/source.py +86 -0
- ai/chronon/staging_query.py +226 -0
- ai/chronon/types.py +58 -0
- ai/chronon/utils.py +510 -0
- ai/chronon/windows.py +48 -0
- awx_zipline_ai-0.0.32.dist-info/METADATA +197 -0
- awx_zipline_ai-0.0.32.dist-info/RECORD +96 -0
- awx_zipline_ai-0.0.32.dist-info/WHEEL +5 -0
- awx_zipline_ai-0.0.32.dist-info/entry_points.txt +2 -0
- awx_zipline_ai-0.0.32.dist-info/top_level.txt +4 -0
- gen_thrift/__init__.py +0 -0
- gen_thrift/api/__init__.py +1 -0
- gen_thrift/api/constants.py +15 -0
- gen_thrift/api/ttypes.py +3754 -0
- gen_thrift/common/__init__.py +1 -0
- gen_thrift/common/constants.py +15 -0
- gen_thrift/common/ttypes.py +1814 -0
- gen_thrift/eval/__init__.py +1 -0
- gen_thrift/eval/constants.py +15 -0
- gen_thrift/eval/ttypes.py +660 -0
- gen_thrift/fetcher/__init__.py +1 -0
- gen_thrift/fetcher/constants.py +15 -0
- gen_thrift/fetcher/ttypes.py +127 -0
- gen_thrift/hub/__init__.py +1 -0
- gen_thrift/hub/constants.py +15 -0
- gen_thrift/hub/ttypes.py +1109 -0
- gen_thrift/observability/__init__.py +1 -0
- gen_thrift/observability/constants.py +15 -0
- gen_thrift/observability/ttypes.py +2355 -0
- gen_thrift/planner/__init__.py +1 -0
- gen_thrift/planner/constants.py +15 -0
- gen_thrift/planner/ttypes.py +1967 -0
ai/__init__.py
ADDED
File without changes

ai/chronon/__init__.py
ADDED
File without changes
ai/chronon/airflow_helpers.py
ADDED
@@ -0,0 +1,248 @@
import json
import math
from typing import OrderedDict

from gen_thrift.api.ttypes import GroupBy, Join
from gen_thrift.common.ttypes import TimeUnit

import ai.chronon.utils as utils
from ai.chronon.constants import (
    AIRFLOW_DEPENDENCIES_KEY,
    AIRFLOW_LABEL_DEPENDENCIES_KEY,
    PARTITION_COLUMN_KEY,
)


def create_airflow_dependency(table, partition_column, additional_partitions=None, offset=0):
    """
    Create an Airflow dependency object for a table.

    Args:
        table: The table name (with namespace)
        partition_column: The partition column to use (defaults to 'ds')
        additional_partitions: Additional partitions to include in the dependency

    Returns:
        A dictionary with name and spec for the Airflow dependency
    """
    assert (
        partition_column is not None
    ), """Partition column must be provided via the spark.chronon.partition.column
config. This can be set as a default in teams.py, or at the individual config level. For example:
```
Team(
    conf=ConfigProperties(
        common={
            "spark.chronon.partition.column": "_test_column",
        }
    )
)
```
"""

    additional_partitions_str = ""
    if additional_partitions:
        additional_partitions_str = "/" + "/".join(additional_partitions)

    return {
        "name": f"wf_{utils.sanitize(table)}_with_offset_{offset}",
        "spec": f"{table}/{partition_column}={{{{ macros.ds_add(ds, {offset}) }}}}{additional_partitions_str}",
    }


def _get_partition_col_from_query(query):
    """Gets partition column from query if available"""
    if query:
        return query.partitionColumn
    return None


def _get_additional_subPartitionsToWaitFor_from_query(query):
    """Gets additional subPartitionsToWaitFor from query if available"""
    if query:
        return query.subPartitionsToWaitFor
    return None


def _get_airflow_deps_from_source(source, partition_column=None):
    """
    Given a source, return a list of Airflow dependencies.

    Args:
        source: The source object (events, entities, or joinSource)
        partition_column: The partition column to use

    Returns:
        A list of Airflow dependency objects
    """
    tables = []
    additional_partitions = None
    # Assumes source has already been normalized
    if source.events:
        tables = [source.events.table]
        # Use partition column from query if available, otherwise use the provided one
        source_partition_column, additional_partitions = (
            _get_partition_col_from_query(source.events.query) or partition_column,
            _get_additional_subPartitionsToWaitFor_from_query(source.events.query),
        )

    elif source.entities:
        # Given the setup of Query, we currently mandate the same partition column for snapshot and mutations tables
        tables = [source.entities.snapshotTable]
        if source.entities.mutationTable:
            tables.append(source.entities.mutationTable)
        source_partition_column, additional_partitions = (
            _get_partition_col_from_query(source.entities.query) or partition_column,
            _get_additional_subPartitionsToWaitFor_from_query(source.entities.query),
        )
    elif source.joinSource:
        # TODO: Handle joinSource -- it doesn't work right now because the metadata isn't set on joinSource at this point
        return []
    else:
        # Unknown source type
        return []

    return [
        create_airflow_dependency(table, source_partition_column, additional_partitions)
        for table in tables
    ]


def extract_default_partition_column(obj):
    try:
        return obj.metaData.executionInfo.conf.common.get("spark.chronon.partition.column")
    except Exception:
        # Error handling occurs in `create_airflow_dependency`
        return None


def _get_distinct_day_windows(group_by):
    windows = []
    aggs = group_by.aggregations
    if aggs:
        for agg in aggs:
            for window in agg.windows:
                time_unit = window.timeUnit
                length = window.length
                if time_unit == TimeUnit.DAYS:
                    windows.append(length)
                elif time_unit == TimeUnit.HOURS:
                    windows.append(math.ceil(length / 24))
                elif time_unit == TimeUnit.MINUTES:
                    windows.append(math.ceil(length / (24 * 60)))
    return set(windows)


def _set_join_deps(join):
    default_partition_col = extract_default_partition_column(join)

    deps = []

    # Handle left source
    left_query = utils.get_query(join.left)
    left_partition_column = _get_partition_col_from_query(left_query) or default_partition_col
    deps.extend(_get_airflow_deps_from_source(join.left, left_partition_column))

    # Handle right parts (join parts)
    if join.joinParts:
        for join_part in join.joinParts:
            if join_part.groupBy and join_part.groupBy.sources:
                for source in join_part.groupBy.sources:
                    source_query = utils.get_query(source)
                    source_partition_column = (
                        _get_partition_col_from_query(source_query) or default_partition_col
                    )
                    deps.extend(_get_airflow_deps_from_source(source, source_partition_column))

    label_deps = []
    # Handle label parts
    if join.labelParts and join.labelParts.labels:
        join_output_table = utils.output_table_name(join, full_name=True)
        partition_column = join.metaData.executionInfo.conf.common[PARTITION_COLUMN_KEY]

        # set the dependencies on the label sources
        for label_part in join.labelParts.labels:
            group_by = label_part.groupBy

            # set the dependency on the join output -- one for each distinct window offset
            windows = _get_distinct_day_windows(group_by)
            for window in windows:
                label_deps.append(
                    create_airflow_dependency(
                        join_output_table, partition_column, offset=-1 * window
                    )
                )

            if group_by and group_by.sources:
                for source in label_part.groupBy.sources:
                    source_query = utils.get_query(source)
                    source_partition_column = (
                        _get_partition_col_from_query(source_query) or default_partition_col
                    )
                    label_deps.extend(
                        _get_airflow_deps_from_source(source, source_partition_column)
                    )

    # Update the metadata customJson with dependencies
    _dedupe_and_set_airflow_deps_json(join, deps, AIRFLOW_DEPENDENCIES_KEY)

    # Update the metadata customJson with label join deps
    if label_deps:
        _dedupe_and_set_airflow_deps_json(join, label_deps, AIRFLOW_LABEL_DEPENDENCIES_KEY)

    # Set the t/f flag for label_join
    _set_label_join_flag(join)


def _set_group_by_deps(group_by):
    if not group_by.sources:
        return

    default_partition_col = extract_default_partition_column(group_by)

    deps = []

    # Process each source in the group_by
    for source in group_by.sources:
        source_query = utils.get_query(source)
        source_partition_column = (
            _get_partition_col_from_query(source_query) or default_partition_col
        )
        deps.extend(_get_airflow_deps_from_source(source, source_partition_column))

    # Update the metadata customJson with dependencies
    _dedupe_and_set_airflow_deps_json(group_by, deps, AIRFLOW_DEPENDENCIES_KEY)


def _set_label_join_flag(join):
    existing_json = join.metaData.customJson or "{}"
    json_map = json.loads(existing_json)
    label_join_flag = False
    if join.labelParts:
        label_join_flag = True
    json_map["label_join"] = label_join_flag
    join.metaData.customJson = json.dumps(json_map)


def _dedupe_and_set_airflow_deps_json(obj, deps, custom_json_key):
    sorted_items = [tuple(sorted(d.items())) for d in deps]
    # Use OrderedDict for re-producible ordering of dependencies
    unique = [OrderedDict(t) for t in sorted_items]
    existing_json = obj.metaData.customJson or "{}"
    json_map = json.loads(existing_json)
    json_map[custom_json_key] = unique
    obj.metaData.customJson = json.dumps(json_map)


def set_airflow_deps(obj):
    """
    Set Airflow dependencies for a Chronon object.

    Args:
        obj: A Join, GroupBy
    """
    # StagingQuery dependency setting is handled directly in object init
    if isinstance(obj, Join):
        _set_join_deps(obj)
    elif isinstance(obj, GroupBy):
        _set_group_by_deps(obj)
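For orientation (not part of the wheel): each dependency produced above is just a name plus an Airflow partition-sensor spec of the form table/partition_column={{ macros.ds_add(ds, offset) }}. The minimal sketch below shows the shape of that spec; the table name, partition column, and offset are made-up inputs, and sanitize is a local stand-in assumed to behave like ai.chronon.utils.sanitize (non-identifier characters mapped to underscores).

import re

# Hypothetical stand-in for ai.chronon.utils.sanitize (assumption: it replaces
# non-identifier characters with underscores); the real helper ships in the wheel.
def sanitize(name: str) -> str:
    return re.sub(r"[^a-zA-Z0-9_]", "_", name)

# Assumed example inputs.
table = "analytics.user_events"
partition_column = "ds"
offset = -7  # e.g. a label dependency looking 7 days back

dep = {
    "name": f"wf_{sanitize(table)}_with_offset_{offset}",
    "spec": f"{table}/{partition_column}={{{{ macros.ds_add(ds, {offset}) }}}}",
}
print(dep["name"])  # wf_analytics_user_events_with_offset_-7
print(dep["spec"])  # analytics.user_events/ds={{ macros.ds_add(ds, -7) }}

The double braces survive f-string formatting, so the spec keeps the literal Jinja expression for Airflow to render at run time.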
ai/chronon/cli/__init__.py
ADDED
File without changes

ai/chronon/cli/compile/__init__.py
ADDED
File without changes
ai/chronon/cli/compile/column_hashing.py
ADDED
@@ -0,0 +1,336 @@
import hashlib
import re
from collections import defaultdict
from typing import Dict, List

from gen_thrift.api.ttypes import Derivation, ExternalPart, GroupBy, Join, Source

from ai.chronon.group_by import get_output_col_names


# Returns a map of output column to semantic hash, including derivations
def compute_group_by_columns_hashes(
    group_by: GroupBy, exclude_keys: bool = False
) -> Dict[str, str]:
    """
    From the group_by object, get the final output columns after derivations.
    """
    # Get the output columns and their input expressions
    output_to_input = get_pre_derived_group_by_columns(group_by)

    # Get the base semantic fields that apply to all columns
    base_semantics = []
    for source in group_by.sources:
        base_semantics.extend(_extract_source_semantic_info(source, group_by.keyColumns))

    # Add the major version to semantic fields
    group_by_minor_version_suffix = f"__{group_by.metaData.version}"
    group_by_major_version = group_by.metaData.name
    if group_by_major_version.endswith(group_by_minor_version_suffix):
        group_by_major_version = group_by_major_version[: -len(group_by_minor_version_suffix)]
    base_semantics.append(f"group_by_name:{group_by_major_version}")

    # Compute the semantic hash for each output column
    output_to_hash = {}
    for output_col, input_expr in output_to_input.items():
        semantic_components = base_semantics + [f"input_expr:{input_expr}"]
        semantic_hash = _compute_semantic_hash(semantic_components)
        output_to_hash[output_col] = semantic_hash

    if exclude_keys:
        output = {k: v for k, v in output_to_hash.items() if k not in group_by.keyColumns}
    else:
        output = output_to_hash
    if group_by.derivations:
        derived = build_derived_columns(output, group_by.derivations, base_semantics)
        if not exclude_keys:
            # We need to add keys back at this point
            for key in group_by.keyColumns:
                if key not in derived:
                    derived[key] = output_to_input.get(key)
        return derived
    else:
        return output


def compute_join_column_hashes(join: Join) -> Dict[str, str]:
    """
    From the join object, get the final output columns -> semantic hash after derivations.
    """
    # Get the base semantics from the left side (table and key expression)
    base_semantic_fields = []

    output_columns = get_pre_derived_join_features(join) | get_pre_derived_source_keys(join.left)

    if join.derivations:
        return build_derived_columns(output_columns, join.derivations, base_semantic_fields)
    else:
        return output_columns


def get_pre_derived_join_features(join: Join) -> Dict[str, str]:
    return get_pre_derived_join_internal_features(join) | get_pre_derived_external_features(join)


def get_pre_derived_external_features(join: Join) -> Dict[str, str]:
    external_cols = []
    if join.onlineExternalParts:
        for external_part in join.onlineExternalParts:
            original_external_columns = [
                param.name for param in external_part.source.valueSchema.params
            ]
            prefix = get_external_part_full_name(external_part) + "_"
            for col in original_external_columns:
                external_cols.append(prefix + col)
    # No meaningful semantic information on external columns, so we just return the column names as a self map
    return {x: x for x in external_cols}


def get_pre_derived_source_keys(source: Source) -> Dict[str, str]:
    base_semantics = _extract_source_semantic_info(source)
    source_keys_to_hashes = {}
    for key, expression in extract_selects(source).items():
        source_keys_to_hashes[key] = _compute_semantic_hash(
            base_semantics + [f"select:{key}={expression}"]
        )
    return source_keys_to_hashes


def extract_selects(source: Source) -> Dict[str, str]:
    if source.events:
        return source.events.query.selects
    elif source.entities:
        return source.entities.query.selects
    elif source.joinSource:
        return source.joinSource.query.selects


def get_pre_derived_join_internal_features(join: Join) -> Dict[str, str]:
    # Get the base semantic fields from join left side (without key columns)
    join_base_semantic_fields = _extract_source_semantic_info(join.left)

    internal_features = {}
    for jp in join.joinParts:
        # Build key mapping semantics - include left side key expressions
        if jp.keyMapping:
            key_mapping_semantics = [
                "join_keys:" + ",".join(f"{k}:{v}" for k, v in sorted(jp.keyMapping.items()))
            ]
        else:
            key_mapping_semantics = []

        # Include left side key expressions that this join part uses
        left_key_expressions = []
        left_selects = extract_selects(join.left)
        for gb_key in jp.groupBy.keyColumns:
            # Get the mapped key name or use the original key name
            left_key_name = gb_key
            if jp.keyMapping:
                # Find the left side key that maps to this groupby key
                for left_key, mapped_key in jp.keyMapping.items():
                    if mapped_key == gb_key:
                        left_key_name = left_key
                        break

            # Add the left side expression for this key
            left_key_expr = left_selects.get(left_key_name, left_key_name)
            left_key_expressions.append(f"left_key:{left_key_name}={left_key_expr}")

        # These semantics apply to all features in the joinPart
        jp_base_semantics = key_mapping_semantics + left_key_expressions + join_base_semantic_fields

        pre_derived_group_by_features = get_pre_derived_group_by_features(
            jp.groupBy, jp_base_semantics
        )

        if jp.groupBy.derivations:
            derived_group_by_features = build_derived_columns(
                pre_derived_group_by_features, jp.groupBy.derivations, jp_base_semantics
            )
        else:
            derived_group_by_features = pre_derived_group_by_features

        for col, semantic_hash in derived_group_by_features.items():
            prefix = jp.prefix + "_" if jp.prefix else ""
            if join.useLongNames:
                gb_prefix = prefix + jp.groupBy.metaData.name.replace(".", "_")
            else:
                key_str = "_".join(jp.groupBy.keyColumns)
                gb_prefix = prefix + key_str
            internal_features[gb_prefix + "_" + col] = semantic_hash
    return internal_features


def get_pre_derived_group_by_columns(group_by: GroupBy) -> Dict[str, str]:
    output_columns_to_hashes = get_pre_derived_group_by_features(group_by)
    for key_column in group_by.keyColumns:
        key_expressions = []
        for source in group_by.sources:
            source_selects = extract_selects(source)  # Map[str, str]
            key_expressions.append(source_selects.get(key_column, key_column))
        output_columns_to_hashes[key_column] = _compute_semantic_hash(sorted(key_expressions))
    return output_columns_to_hashes


def get_pre_derived_group_by_features(
    group_by: GroupBy, additional_semantic_fields=None
) -> Dict[str, str]:
    # Get the base semantic fields that apply to all aggs
    if additional_semantic_fields is None:
        additional_semantic_fields = []
    base_semantics = _get_base_group_by_semantic_fields(group_by)

    output_columns = {}
    # For group_bys with aggregations, aggregated columns
    if group_by.aggregations:
        for agg in group_by.aggregations:
            input_expression_str = ",".join(
                get_input_expression_across_sources(group_by, agg.inputColumn)
            )
            for output_col_name in get_output_col_names(agg):
                output_columns[output_col_name] = _compute_semantic_hash(
                    base_semantics + [input_expression_str] + additional_semantic_fields
                )
    # For group_bys without aggregations, selected fields from query
    else:
        combined_selects = defaultdict(set)

        for source in group_by.sources:
            source_selects = extract_selects(source)  # Map[str, str]
            for key, val in source_selects.items():
                combined_selects[key].add(val)

        # Build a unified map of key to select expression from all sources
        unified_selects = {key: ",".join(sorted(vals)) for key, vals in combined_selects.items()}

        # now compute the hashes on base semantics + expression
        selected_hashes = {
            key: _compute_semantic_hash(base_semantics + [val] + additional_semantic_fields)
            for key, val in unified_selects.items()
        }
        output_columns.update(selected_hashes)
    return output_columns


def _get_base_group_by_semantic_fields(group_by: GroupBy) -> List[str]:
    """
    Extract base semantic fields from the group_by object.
    This includes source table names, key columns, and any other relevant metadata.
    """
    base_semantics = []
    for source in group_by.sources:
        base_semantics.extend(_extract_source_semantic_info(source, group_by.keyColumns))

    return sorted(base_semantics)


def _extract_source_semantic_info(source: Source, key_columns: List[str] = None) -> List[str]:
    """
    Extract source information for semantic hashing.
    Returns list of semantic components.

    Args:
        source: The source to extract info from
        key_columns: List of key columns to include expressions for. If None, includes all selects.
    """
    components = []

    if source.events:
        table = source.events.table
        mutationTable = ""
        query = source.events.query
        cumulative = str(source.events.isCumulative or "")
    elif source.entities:
        table = source.entities.snapshotTable
        mutationTable = source.entities.mutationTable
        query = source.entities.query
        cumulative = ""
    elif source.joinSource:
        table = source.joinSource.join.metaData.name
        mutationTable = ""
        query = source.joinSource.query
        cumulative = ""

    components.append(f"table:{table}")
    components.append(f"mutation_table:{mutationTable}")
    components.append(f"cumulative:{cumulative}")
    components.append(f"filters:{query.wheres or ''}")

    selects = query.selects or {}
    if key_columns:
        # Only include expressions for the specified key columns
        for key_col in sorted(key_columns):
            expr = selects.get(key_col, key_col)
            components.append(f"select:{key_col}={expr}")

    # Add time column expression
    if query.timeColumn:
        time_expr = selects.get(query.timeColumn, query.timeColumn)
        components.append(f"time_column:{query.timeColumn}={time_expr}")

    return sorted(components)


def _compute_semantic_hash(components: List[str]) -> str:
    """
    Compute semantic hash from a list of components.
    Components should be ordered consistently to ensure reproducible hashes.
    """
    # Sort components to ensure consistent ordering
    sorted_components = sorted(components)
    hash_input = "|".join(sorted_components)
    return hashlib.md5(hash_input.encode("utf-8")).hexdigest()


def build_derived_columns(
    base_columns_to_hashes: Dict[str, str],
    derivations: List[Derivation],
    additional_semantic_fields: List[str],
) -> Dict[str, str]:
    """
    Build the derived columns from pre-derived columns and derivations.
    """
    # if derivations contain star, then all columns are included except the columns which are renamed
    output_columns = {}
    if derivations:
        found = any(derivation.expression == "*" for derivation in derivations)
        if found:
            output_columns.update(base_columns_to_hashes)
        for derivation in derivations:
            if base_columns_to_hashes.get(derivation.expression):
                # don't change the semantics if you're just passing a base column through derivations
                output_columns[derivation.name] = base_columns_to_hashes[derivation.expression]
            if derivation.name != "*":
                # Identify base fields present within the derivation to include in the semantic hash
                # We go long to short to avoid taking both a windowed feature and the unwindowed feature
                # i.e. f_7d and f
                derivation_expression = derivation.expression
                base_col_semantic_fields = []
                tokens = re.findall(r"\b\w+\b", derivation_expression)
                for token in tokens:
                    if token in base_columns_to_hashes:
                        base_col_semantic_fields.append(base_columns_to_hashes[token])

                output_columns[derivation.name] = _compute_semantic_hash(
                    additional_semantic_fields
                    + [f"derivation:{derivation.expression}"]
                    + base_col_semantic_fields
                )
    return output_columns


def get_external_part_full_name(external_part: ExternalPart) -> str:
    # The logic should be consistent with the full name logic defined
    # in https://github.com/airbnb/chronon/blob/main/api/src/main/scala/ai/chronon/api/Extensions.scala#L677.
    prefix = external_part.prefix + "_" if external_part.prefix else ""
    name = external_part.source.metadata.name
    sanitized_name = re.sub("[^a-zA-Z0-9_]", "_", name)
    return "ext_" + prefix + sanitized_name


def get_input_expression_across_sources(group_by: GroupBy, input_col: str):
    expressions = []
    for source in group_by.sources:
        selects = extract_selects(source)
        expressions.extend(selects.get(input_col, input_col))
    return sorted(expressions)
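For orientation (not part of the wheel): the semantic hashes above all funnel through the same scheme as _compute_semantic_hash, which sorts its component strings, joins them with "|", and takes an MD5 hex digest, so the hash depends only on which semantic fields are present, not on the order they were collected. A minimal sketch with made-up component strings:

import hashlib
from typing import List

# Mirrors the hashing scheme in column_hashing.py: sort, join with "|", MD5 hex digest.
def semantic_hash(components: List[str]) -> str:
    return hashlib.md5("|".join(sorted(components)).encode("utf-8")).hexdigest()

# Made-up semantic components for one output column.
a = semantic_hash(["table:db.events", "filters:", "input_expr:amount"])
b = semantic_hash(["input_expr:amount", "table:db.events", "filters:"])
assert a == b  # the same fields in any order produce the same column hash
print(a)

Any change to a source table, filter, or input expression changes one component and therefore the hash, which is what lets the compiler detect that a column's semantics have drifted between config versions.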