awx-zipline-ai 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +1 -0
- agent/constants.py +15 -0
- agent/ttypes.py +1684 -0
- ai/__init__.py +0 -0
- ai/chronon/__init__.py +0 -0
- ai/chronon/airflow_helpers.py +251 -0
- ai/chronon/api/__init__.py +1 -0
- ai/chronon/api/common/__init__.py +1 -0
- ai/chronon/api/common/constants.py +15 -0
- ai/chronon/api/common/ttypes.py +1844 -0
- ai/chronon/api/constants.py +15 -0
- ai/chronon/api/ttypes.py +3624 -0
- ai/chronon/cli/compile/column_hashing.py +313 -0
- ai/chronon/cli/compile/compile_context.py +177 -0
- ai/chronon/cli/compile/compiler.py +160 -0
- ai/chronon/cli/compile/conf_validator.py +590 -0
- ai/chronon/cli/compile/display/class_tracker.py +112 -0
- ai/chronon/cli/compile/display/compile_status.py +95 -0
- ai/chronon/cli/compile/display/compiled_obj.py +12 -0
- ai/chronon/cli/compile/display/console.py +3 -0
- ai/chronon/cli/compile/display/diff_result.py +46 -0
- ai/chronon/cli/compile/fill_templates.py +40 -0
- ai/chronon/cli/compile/parse_configs.py +141 -0
- ai/chronon/cli/compile/parse_teams.py +238 -0
- ai/chronon/cli/compile/serializer.py +115 -0
- ai/chronon/cli/git_utils.py +156 -0
- ai/chronon/cli/logger.py +61 -0
- ai/chronon/constants.py +3 -0
- ai/chronon/eval/__init__.py +122 -0
- ai/chronon/eval/query_parsing.py +19 -0
- ai/chronon/eval/sample_tables.py +100 -0
- ai/chronon/eval/table_scan.py +186 -0
- ai/chronon/fetcher/__init__.py +1 -0
- ai/chronon/fetcher/constants.py +15 -0
- ai/chronon/fetcher/ttypes.py +127 -0
- ai/chronon/group_by.py +692 -0
- ai/chronon/hub/__init__.py +1 -0
- ai/chronon/hub/constants.py +15 -0
- ai/chronon/hub/ttypes.py +1228 -0
- ai/chronon/join.py +566 -0
- ai/chronon/logger.py +24 -0
- ai/chronon/model.py +35 -0
- ai/chronon/observability/__init__.py +1 -0
- ai/chronon/observability/constants.py +15 -0
- ai/chronon/observability/ttypes.py +2192 -0
- ai/chronon/orchestration/__init__.py +1 -0
- ai/chronon/orchestration/constants.py +15 -0
- ai/chronon/orchestration/ttypes.py +4406 -0
- ai/chronon/planner/__init__.py +1 -0
- ai/chronon/planner/constants.py +15 -0
- ai/chronon/planner/ttypes.py +1686 -0
- ai/chronon/query.py +126 -0
- ai/chronon/repo/__init__.py +40 -0
- ai/chronon/repo/aws.py +298 -0
- ai/chronon/repo/cluster.py +65 -0
- ai/chronon/repo/compile.py +56 -0
- ai/chronon/repo/constants.py +164 -0
- ai/chronon/repo/default_runner.py +291 -0
- ai/chronon/repo/explore.py +421 -0
- ai/chronon/repo/extract_objects.py +137 -0
- ai/chronon/repo/gcp.py +585 -0
- ai/chronon/repo/gitpython_utils.py +14 -0
- ai/chronon/repo/hub_runner.py +171 -0
- ai/chronon/repo/hub_uploader.py +108 -0
- ai/chronon/repo/init.py +53 -0
- ai/chronon/repo/join_backfill.py +105 -0
- ai/chronon/repo/run.py +293 -0
- ai/chronon/repo/serializer.py +141 -0
- ai/chronon/repo/team_json_utils.py +46 -0
- ai/chronon/repo/utils.py +472 -0
- ai/chronon/repo/zipline.py +51 -0
- ai/chronon/repo/zipline_hub.py +105 -0
- ai/chronon/resources/gcp/README.md +174 -0
- ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/test/data.py +34 -0
- ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
- ai/chronon/resources/gcp/joins/test/data.py +30 -0
- ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
- ai/chronon/resources/gcp/sources/test/data.py +23 -0
- ai/chronon/resources/gcp/teams.py +70 -0
- ai/chronon/resources/gcp/zipline-cli-install.sh +54 -0
- ai/chronon/source.py +88 -0
- ai/chronon/staging_query.py +185 -0
- ai/chronon/types.py +57 -0
- ai/chronon/utils.py +557 -0
- ai/chronon/windows.py +50 -0
- awx_zipline_ai-0.2.0.dist-info/METADATA +173 -0
- awx_zipline_ai-0.2.0.dist-info/RECORD +93 -0
- awx_zipline_ai-0.2.0.dist-info/WHEEL +5 -0
- awx_zipline_ai-0.2.0.dist-info/entry_points.txt +2 -0
- awx_zipline_ai-0.2.0.dist-info/licenses/LICENSE +202 -0
- awx_zipline_ai-0.2.0.dist-info/top_level.txt +3 -0
- jars/__init__.py +0 -0
ai/chronon/cli/compile/column_hashing.py
@@ -0,0 +1,313 @@
import hashlib
import re
from collections import defaultdict
from typing import Dict, List

from ai.chronon.api.ttypes import Derivation, ExternalPart, GroupBy, Join, Source
from ai.chronon.group_by import get_output_col_names


# Returns a map of output column to semantic hash, including derivations
def compute_group_by_columns_hashes(group_by: GroupBy, exclude_keys: bool = False) -> Dict[str, str]:
    """
    From the group_by object, get the final output columns after derivations.
    """
    # Get the output columns and their input expressions
    output_to_input = get_pre_derived_group_by_columns(group_by)

    # Get the base semantic fields that apply to all columns
    base_semantics = []
    for source in group_by.sources:
        base_semantics.extend(_extract_source_semantic_info(source, group_by.keyColumns))

    # Add the major version to semantic fields
    group_by_minor_version_suffix = f"__{group_by.metaData.version}"
    group_by_major_version = group_by.metaData.name
    if group_by_major_version.endswith(group_by_minor_version_suffix):
        group_by_major_version = group_by_major_version[:-len(group_by_minor_version_suffix)]
    base_semantics.append(f"group_by_name:{group_by_major_version}")

    # Compute the semantic hash for each output column
    output_to_hash = {}
    for output_col, input_expr in output_to_input.items():
        semantic_components = base_semantics + [f"input_expr:{input_expr}"]
        semantic_hash = _compute_semantic_hash(semantic_components)
        output_to_hash[output_col] = semantic_hash

    if exclude_keys:
        output = {k: v for k, v in output_to_hash.items() if k not in group_by.keyColumns}
    else:
        output = output_to_hash
    if group_by.derivations:
        derived = build_derived_columns(output, group_by.derivations, base_semantics)
        if not exclude_keys:
            # We need to add keys back at this point
            for key in group_by.keyColumns:
                if key not in derived:
                    derived[key] = output_to_input.get(key)
        return derived
    else:
        return output


def compute_join_column_hashes(join: Join) -> Dict[str, str]:
    """
    From the join object, get the final output columns -> semantic hash after derivations.
    """
    # Get the base semantics from the left side (table and key expression)
    base_semantic_fields = []

    output_columns = get_pre_derived_join_features(join) | get_pre_derived_source_keys(join.left)

    if join.derivations:
        return build_derived_columns(output_columns, join.derivations, base_semantic_fields)
    else:
        return output_columns


def get_pre_derived_join_features(join: Join) -> Dict[str, str]:
    return get_pre_derived_join_internal_features(join) | get_pre_derived_external_features(join)


def get_pre_derived_external_features(join: Join) -> Dict[str, str]:
    external_cols = []
    if join.onlineExternalParts:
        for external_part in join.onlineExternalParts:
            original_external_columns = [
                param.name for param in external_part.source.valueSchema.params
            ]
            prefix = get_external_part_full_name(external_part) + "_"
            for col in original_external_columns:
                external_cols.append(prefix + col)
    # No meaningful semantic information on external columns, so we just return the column names as a self map
    return {x: x for x in external_cols}


def get_pre_derived_source_keys(source: Source) -> Dict[str, str]:
    base_semantics = _extract_source_semantic_info(source)
    source_keys_to_hashes = {}
    for key, expression in extract_selects(source).items():
        source_keys_to_hashes[key] = _compute_semantic_hash(base_semantics + [f"select:{key}={expression}"])
    return source_keys_to_hashes


def extract_selects(source: Source) -> Dict[str, str]:
    if source.events:
        return source.events.query.selects
    elif source.entities:
        return source.entities.query.selects
    elif source.joinSource:
        return source.joinSource.query.selects


def get_pre_derived_join_internal_features(join: Join) -> Dict[str, str]:

    # Get the base semantic fields from join left side (without key columns)
    join_base_semantic_fields = _extract_source_semantic_info(join.left)

    internal_features = {}
    for jp in join.joinParts:
        # Build key mapping semantics - include left side key expressions
        if jp.keyMapping:
            key_mapping_semantics = ["join_keys:" + ",".join(f"{k}:{v}" for k, v in sorted(jp.keyMapping.items()))]
        else:
            key_mapping_semantics = []

        # Include left side key expressions that this join part uses
        left_key_expressions = []
        left_selects = extract_selects(join.left)
        for gb_key in jp.groupBy.keyColumns:
            # Get the mapped key name or use the original key name
            left_key_name = gb_key
            if jp.keyMapping:
                # Find the left side key that maps to this groupby key
                for left_key, mapped_key in jp.keyMapping.items():
                    if mapped_key == gb_key:
                        left_key_name = left_key
                        break

            # Add the left side expression for this key
            left_key_expr = left_selects.get(left_key_name, left_key_name)
            left_key_expressions.append(f"left_key:{left_key_name}={left_key_expr}")

        # These semantics apply to all features in the joinPart
        jp_base_semantics = key_mapping_semantics + left_key_expressions + join_base_semantic_fields

        pre_derived_group_by_features = get_pre_derived_group_by_features(jp.groupBy, jp_base_semantics)

        if jp.groupBy.derivations:
            derived_group_by_features = build_derived_columns(
                pre_derived_group_by_features, jp.groupBy.derivations, jp_base_semantics
            )
        else:
            derived_group_by_features = pre_derived_group_by_features

        for col, semantic_hash in derived_group_by_features.items():
            prefix = jp.prefix + "_" if jp.prefix else ""
            if join.useLongNames:
                gb_prefix = prefix + jp.groupBy.metaData.name.replace(".", "_")
            else:
                key_str = "_".join(jp.groupBy.keyColumns)
                gb_prefix = prefix + key_str
            internal_features[gb_prefix + "_" + col] = semantic_hash
    return internal_features


def get_pre_derived_group_by_columns(group_by: GroupBy) -> Dict[str, str]:
    output_columns_to_hashes = get_pre_derived_group_by_features(group_by)
    for key_column in group_by.keyColumns:
        key_expressions = []
        for source in group_by.sources:
            source_selects = extract_selects(source)  # Map[str, str]
            key_expressions.append(source_selects.get(key_column, key_column))
        output_columns_to_hashes[key_column] = _compute_semantic_hash(sorted(key_expressions))
    return output_columns_to_hashes


def get_pre_derived_group_by_features(group_by: GroupBy, additional_semantic_fields=None) -> Dict[str, str]:
    # Get the base semantic fields that apply to all aggs
    if additional_semantic_fields is None:
        additional_semantic_fields = []
    base_semantics = _get_base_group_by_semantic_fields(group_by)

    output_columns = {}
    # For group_bys with aggregations, aggregated columns
    if group_by.aggregations:
        for agg in group_by.aggregations:
            input_expression_str = ",".join(get_input_expression_across_sources(group_by, agg.inputColumn))
            for output_col_name in get_output_col_names(agg):
                output_columns[output_col_name] = _compute_semantic_hash(base_semantics + [input_expression_str] + additional_semantic_fields)
    # For group_bys without aggregations, selected fields from query
    else:
        combined_selects = defaultdict(set)

        for source in group_by.sources:
            source_selects = extract_selects(source)  # Map[str, str]
            for key, val in source_selects.items():
                combined_selects[key].add(val)

        # Build a unified map of key to select expression from all sources
        unified_selects = {key: ",".join(sorted(vals)) for key, vals in combined_selects.items()}

        # now compute the hashes on base semantics + expression
        selected_hashes = {key: _compute_semantic_hash(base_semantics + [val] + additional_semantic_fields) for key, val in unified_selects.items()}
        output_columns.update(selected_hashes)
    return output_columns


def _get_base_group_by_semantic_fields(group_by: GroupBy) -> List[str]:
    """
    Extract base semantic fields from the group_by object.
    This includes source table names, key columns, and any other relevant metadata.
    """
    base_semantics = []
    for source in group_by.sources:
        base_semantics.extend(_extract_source_semantic_info(source, group_by.keyColumns))

    return sorted(base_semantics)


def _extract_source_semantic_info(source: Source, key_columns: List[str] = None) -> List[str]:
    """
    Extract source information for semantic hashing.
    Returns list of semantic components.

    Args:
        source: The source to extract info from
        key_columns: List of key columns to include expressions for. If None, includes all selects.
    """
    components = []

    if source.events:
        table = source.events.table
        mutationTable = ""
        query = source.events.query
        cumulative = str(source.events.isCumulative or "")
    elif source.entities:
        table = source.entities.snapshotTable
        mutationTable = source.entities.mutationTable
        query = source.entities.query
        cumulative = ""
    elif source.joinSource:
        table = source.joinSource.join.metaData.name
        mutationTable = ""
        query = source.joinSource.query
        cumulative = ""

    components.append(f"table:{table}")
    components.append(f"mutation_table:{mutationTable}")
    components.append(f"cumulative:{cumulative}")
    components.append(f"filters:{query.wheres or ''}")

    selects = query.selects or {}
    if key_columns:
        # Only include expressions for the specified key columns
        for key_col in sorted(key_columns):
            expr = selects.get(key_col, key_col)
            components.append(f"select:{key_col}={expr}")

    # Add time column expression
    if query.timeColumn:
        time_expr = selects.get(query.timeColumn, query.timeColumn)
        components.append(f"time_column:{query.timeColumn}={time_expr}")

    return sorted(components)


def _compute_semantic_hash(components: List[str]) -> str:
    """
    Compute semantic hash from a list of components.
    Components should be ordered consistently to ensure reproducible hashes.
    """
    # Sort components to ensure consistent ordering
    sorted_components = sorted(components)
    hash_input = "|".join(sorted_components)
    return hashlib.md5(hash_input.encode('utf-8')).hexdigest()


def build_derived_columns(
    base_columns_to_hashes: Dict[str, str], derivations: List[Derivation], additional_semantic_fields: List[str]
) -> Dict[str, str]:
    """
    Build the derived columns from pre-derived columns and derivations.
    """
    # if derivations contain star, then all columns are included except the columns which are renamed
    output_columns = {}
    if derivations:
        found = any(derivation.expression == "*" for derivation in derivations)
        if found:
            output_columns.update(base_columns_to_hashes)
        for derivation in derivations:
            if base_columns_to_hashes.get(derivation.expression):
                # don't change the semantics if you're just passing a base column through derivations
                output_columns[derivation.name] = base_columns_to_hashes[derivation.expression]
            if derivation.name != "*":
                # Identify base fields present within the derivation to include in the semantic hash
                # We go long to short to avoid taking both a windowed feature and the unwindowed feature
                # i.e. f_7d and f
                derivation_expression = derivation.expression
                base_col_semantic_fields = []
                tokens = re.findall(r'\b\w+\b', derivation_expression)
                for token in tokens:
                    if token in base_columns_to_hashes:
                        base_col_semantic_fields.append(base_columns_to_hashes[token])

                output_columns[derivation.name] = _compute_semantic_hash(additional_semantic_fields + [f"derivation:{derivation.expression}"] + base_col_semantic_fields)
    return output_columns


def get_external_part_full_name(external_part: ExternalPart) -> str:
    # The logic should be consistent with the full name logic defined
    # in https://github.com/airbnb/chronon/blob/main/api/src/main/scala/ai/chronon/api/Extensions.scala#L677.
    prefix = external_part.prefix + "_" if external_part.prefix else ""
    name = external_part.source.metadata.name
    sanitized_name = re.sub("[^a-zA-Z0-9_]", "_", name)
    return "ext_" + prefix + sanitized_name


def get_input_expression_across_sources(group_by: GroupBy, input_col: str):
    expressions = []
    for source in group_by.sources:
        selects = extract_selects(source)
        expressions.extend(selects.get(input_col, input_col))
    return sorted(expressions)
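
For orientation, here is a minimal, self-contained sketch of the semantic-hashing scheme used in column_hashing.py above; the table name, select, and input expressions below are made-up examples rather than values shipped in this package. Components are sorted, joined with "|", and MD5-hashed, so changing any source table, filter, key expression, or input expression changes the affected column's hash.

import hashlib

def semantic_hash(components):
    # Same scheme as _compute_semantic_hash: sort, join with "|", take the MD5 digest.
    return hashlib.md5("|".join(sorted(components)).encode("utf-8")).hexdigest()

# Hypothetical semantic components for one GroupBy output column.
base = ["table:data.purchases", "filters:", "select:user_id=user_id"]
before = semantic_hash(base + ["input_expr:amount"])
after = semantic_hash(base + ["input_expr:amount * 100"])
assert before != after  # editing the input expression changes the column's semantic hash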
ai/chronon/cli/compile/compile_context.py
@@ -0,0 +1,177 @@
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Type

import ai.chronon.cli.compile.parse_teams as teams
from ai.chronon.api.ttypes import GroupBy, Join, MetaData, Model, StagingQuery, Team
from ai.chronon.cli.compile.conf_validator import ConfValidator
from ai.chronon.cli.compile.display.compile_status import CompileStatus
from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
from ai.chronon.cli.compile.serializer import file2thrift
from ai.chronon.cli.logger import get_logger, require
from ai.chronon.orchestration.ttypes import ConfType

logger = get_logger()


@dataclass
class ConfigInfo:
    folder_name: str
    cls: Type
    config_type: Optional[ConfType]


@dataclass
class CompileContext:

    def __init__(self):
        self.chronon_root: str = os.getenv("CHRONON_ROOT", os.getcwd())
        self.teams_dict: Dict[str, Team] = teams.load_teams(self.chronon_root)
        self.compile_dir: str = "compiled"

        self.config_infos: List[ConfigInfo] = [
            ConfigInfo(folder_name="joins", cls=Join, config_type=ConfType.JOIN),
            ConfigInfo(
                folder_name="group_bys",
                cls=GroupBy,
                config_type=ConfType.GROUP_BY,
            ),
            ConfigInfo(
                folder_name="staging_queries",
                cls=StagingQuery,
                config_type=ConfType.STAGING_QUERY,
            ),
            ConfigInfo(folder_name="models", cls=Model, config_type=ConfType.MODEL),
            ConfigInfo(folder_name="teams_metadata", cls=MetaData, config_type=None),  # only for team metadata
        ]

        self.compile_status = CompileStatus(use_live=False)

        self.existing_confs: Dict[Type, Dict[str, Any]] = {}
        for config_info in self.config_infos:
            cls = config_info.cls
            self.existing_confs[cls] = self._parse_existing_confs(cls)

        self.validator: ConfValidator = ConfValidator(
            input_root=self.chronon_root,
            output_root=self.compile_dir,
            existing_gbs=self.existing_confs[GroupBy],
            existing_joins=self.existing_confs[Join],
        )

    def input_dir(self, cls: type) -> str:
        """
        - eg., input: group_by class
        - eg., output: root/group_bys/
        """
        config_info = self.config_info_for_class(cls)
        return os.path.join(self.chronon_root, config_info.folder_name)

    def staging_output_dir(self, cls: type = None) -> str:
        """
        - eg., input: group_by class
        - eg., output: root/compiled_staging/group_bys/
        """
        if cls is None:
            return os.path.join(self.chronon_root, self.compile_dir + "_staging")
        else:
            config_info = self.config_info_for_class(cls)
            return os.path.join(
                self.chronon_root,
                self.compile_dir + "_staging",
                config_info.folder_name,
            )

    def output_dir(self, cls: type = None) -> str:
        """
        - eg., input: group_by class
        - eg., output: root/compiled/group_bys/
        """
        if cls is None:
            return os.path.join(self.chronon_root, self.compile_dir)
        else:
            config_info = self.config_info_for_class(cls)
            return os.path.join(
                self.chronon_root, self.compile_dir, config_info.folder_name
            )

    def staging_output_path(self, compiled_obj: CompiledObj):
        """
        - eg., input: group_by with name search.clicks.features.v1
        - eg., output: root/compiled_staging/group_bys/search/clicks.features.v1
        """

        output_dir = self.staging_output_dir(compiled_obj.obj.__class__)  # compiled/joins

        team, rest = compiled_obj.name.split(".", 1)  # search, clicks.features.v1

        return os.path.join(
            output_dir,
            team,
            rest,
        )

    def config_info_for_class(self, cls: type) -> ConfigInfo:
        for info in self.config_infos:
            if info.cls == cls:
                return info

        require(False, f"Class {cls} not found in CONFIG_INFOS")

    def _parse_existing_confs(self, obj_class: type) -> Dict[str, object]:

        result = {}

        output_dir = self.output_dir(obj_class)

        # Check if output_dir exists before walking
        if not os.path.exists(output_dir):
            return result

        for sub_root, _sub_dirs, sub_files in os.walk(output_dir):

            for f in sub_files:

                if f.startswith("."):  # ignore hidden files - such as .DS_Store
                    continue

                full_path = os.path.join(sub_root, f)

                try:
                    obj = file2thrift(full_path, obj_class)

                    if obj:
                        if hasattr(obj, "metaData"):
                            result[obj.metaData.name] = obj
                            compiled_obj = CompiledObj(
                                name=obj.metaData.name,
                                obj=obj,
                                file=obj.metaData.sourceFile,
                                errors=None,
                                obj_type=obj_class.__name__,
                                tjson=open(full_path).read(),
                            )
                            self.compile_status.add_existing_object_update_display(compiled_obj)
                        elif isinstance(obj, MetaData):
                            team_metadata_name = '.'.join(full_path.split('/')[-2:])  # use the name of the file as team metadata won't have name
                            result[team_metadata_name] = obj
                            compiled_obj = CompiledObj(
                                name=team_metadata_name,
                                obj=obj,
                                file=obj.sourceFile,
                                errors=None,
                                obj_type=obj_class.__name__,
                                tjson=open(full_path).read(),
                            )
                            self.compile_status.add_existing_object_update_display(compiled_obj)
                        else:
                            logger.errors(
                                f"Parsed object from {full_path} has no metaData attribute"
                            )

                except Exception as e:
                    print(f"Failed to parse file {full_path}: {str(e)}", e)

        return result
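
As a rough illustration of the directory layout CompileContext produces (the root path below is hypothetical), compiled configs land in a <root>/compiled_staging/<folder>/<team>/<rest-of-name> tree, mirroring the staging_output_path logic above:

import os

def staging_path(chronon_root, folder_name, conf_name, compile_dir="compiled"):
    # Same shape as CompileContext.staging_output_path: the team prefix of the
    # config name becomes a directory under <compile_dir>_staging/<folder_name>/.
    team, rest = conf_name.split(".", 1)
    return os.path.join(chronon_root, compile_dir + "_staging", folder_name, team, rest)

# Using the example from the docstring above:
print(staging_path("/repo", "group_bys", "search.clicks.features.v1"))
# /repo/compiled_staging/group_bys/search/clicks.features.v1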
ai/chronon/cli/compile/compiler.py
@@ -0,0 +1,160 @@
import os
import shutil
import traceback
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import ai.chronon.cli.compile.display.compiled_obj
import ai.chronon.cli.compile.parse_configs as parser
import ai.chronon.cli.logger as logger
from ai.chronon.cli.compile import serializer
from ai.chronon.cli.compile.compile_context import CompileContext, ConfigInfo
from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
from ai.chronon.cli.compile.display.console import console
from ai.chronon.cli.compile.parse_teams import merge_team_execution_info
from ai.chronon.orchestration.ttypes import ConfType
from ai.chronon.types import MetaData

logger = logger.get_logger()


@dataclass
class CompileResult:
    config_info: ConfigInfo
    obj_dict: Dict[str, Any]
    error_dict: Dict[str, List[BaseException]]


class Compiler:

    def __init__(self, compile_context: CompileContext):
        self.compile_context = compile_context

    def compile(self) -> Dict[ConfType, CompileResult]:

        config_infos = self.compile_context.config_infos

        compile_results = {}

        for config_info in config_infos:
            configs = self._compile_class_configs(config_info)

            compile_results[config_info.config_type] = configs
        self._compile_team_metadata()

        # check if staging_output_dir exists
        staging_dir = self.compile_context.staging_output_dir()
        if os.path.exists(staging_dir):
            # replace staging_output_dir to output_dir
            output_dir = self.compile_context.output_dir()
            if os.path.exists(output_dir):
                shutil.rmtree(output_dir)
            shutil.move(staging_dir, output_dir)
        else:
            print(
                f"Staging directory {staging_dir} does not exist. "
                "Happens when every chronon config fails to compile or when no chronon configs exist."
            )

        # TODO: temporarily just print out the final results of the compile until live fix is implemented:
        # https://github.com/Textualize/rich/pull/3637
        console.print(self.compile_context.compile_status.render())

        return compile_results

    def _compile_team_metadata(self):
        """
        Compile the team metadata and return the compiled object.
        """
        teams_dict = self.compile_context.teams_dict
        for team in teams_dict:
            m = MetaData()
            merge_team_execution_info(m, teams_dict, team)

            tjson = serializer.thrift_simple_json(m)
            name = f"{team}.{team}_team_metadata"
            result = CompiledObj(
                name=name,
                obj=m,
                file=name,
                errors=None,
                obj_type=MetaData.__name__,
                tjson=tjson,
            )
            self._write_object(result)
            self.compile_context.compile_status.add_object_update_display(result, MetaData.__name__)

        # Done writing team metadata, close the class
        self.compile_context.compile_status.close_cls(MetaData.__name__)

    def _compile_class_configs(self, config_info: ConfigInfo) -> CompileResult:

        compile_result = CompileResult(
            config_info=config_info, obj_dict={}, error_dict={}
        )

        input_dir = self.compile_context.input_dir(config_info.cls)

        compiled_objects = parser.from_folder(
            config_info.cls, input_dir, self.compile_context
        )

        objects, errors = self._write_objects_in_folder(compiled_objects)

        if objects:
            compile_result.obj_dict.update(objects)

        if errors:
            compile_result.error_dict.update(errors)

        self.compile_context.compile_status.close_cls(config_info.cls.__name__)

        return compile_result

    def _write_objects_in_folder(
        self,
        compiled_objects: List[ai.chronon.cli.compile.display.compiled_obj.CompiledObj],
    ) -> Tuple[Dict[str, Any], Dict[str, List[BaseException]]]:

        error_dict = {}
        object_dict = {}

        for co in compiled_objects:

            if co.obj:

                if co.errors:
                    error_dict[co.name] = co.errors

                    for error in co.errors:
                        self.compile_context.compile_status.print_live_console(
                            f"Error processing conf {co.name}: {error}"
                        )
                        traceback.print_exception(
                            type(error), error, error.__traceback__
                        )

                else:
                    self._write_object(co)
                    object_dict[co.name] = co.obj
            else:
                error_dict[co.file] = co.errors

                self.compile_context.compile_status.print_live_console(
                    f"Error processing file {co.file}: {co.errors}"
                )
                for error in co.errors:
                    traceback.print_exception(type(error), error, error.__traceback__)

        return object_dict, error_dict

    def _write_object(self, compiled_obj: CompiledObj) -> Optional[List[BaseException]]:
        output_path = self.compile_context.staging_output_path(compiled_obj)

        folder = os.path.dirname(output_path)

        if not os.path.exists(folder):
            os.makedirs(folder)

        with open(output_path, "w") as f:
            f.write(compiled_obj.tjson)