awx-zipline-ai 0.2.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. agent/__init__.py +1 -0
  2. agent/constants.py +15 -0
  3. agent/ttypes.py +1684 -0
  4. ai/__init__.py +0 -0
  5. ai/chronon/__init__.py +0 -0
  6. ai/chronon/airflow_helpers.py +251 -0
  7. ai/chronon/api/__init__.py +1 -0
  8. ai/chronon/api/common/__init__.py +1 -0
  9. ai/chronon/api/common/constants.py +15 -0
  10. ai/chronon/api/common/ttypes.py +1844 -0
  11. ai/chronon/api/constants.py +15 -0
  12. ai/chronon/api/ttypes.py +3624 -0
  13. ai/chronon/cli/compile/column_hashing.py +313 -0
  14. ai/chronon/cli/compile/compile_context.py +177 -0
  15. ai/chronon/cli/compile/compiler.py +160 -0
  16. ai/chronon/cli/compile/conf_validator.py +590 -0
  17. ai/chronon/cli/compile/display/class_tracker.py +112 -0
  18. ai/chronon/cli/compile/display/compile_status.py +95 -0
  19. ai/chronon/cli/compile/display/compiled_obj.py +12 -0
  20. ai/chronon/cli/compile/display/console.py +3 -0
  21. ai/chronon/cli/compile/display/diff_result.py +46 -0
  22. ai/chronon/cli/compile/fill_templates.py +40 -0
  23. ai/chronon/cli/compile/parse_configs.py +141 -0
  24. ai/chronon/cli/compile/parse_teams.py +238 -0
  25. ai/chronon/cli/compile/serializer.py +115 -0
  26. ai/chronon/cli/git_utils.py +156 -0
  27. ai/chronon/cli/logger.py +61 -0
  28. ai/chronon/constants.py +3 -0
  29. ai/chronon/eval/__init__.py +122 -0
  30. ai/chronon/eval/query_parsing.py +19 -0
  31. ai/chronon/eval/sample_tables.py +100 -0
  32. ai/chronon/eval/table_scan.py +186 -0
  33. ai/chronon/fetcher/__init__.py +1 -0
  34. ai/chronon/fetcher/constants.py +15 -0
  35. ai/chronon/fetcher/ttypes.py +127 -0
  36. ai/chronon/group_by.py +692 -0
  37. ai/chronon/hub/__init__.py +1 -0
  38. ai/chronon/hub/constants.py +15 -0
  39. ai/chronon/hub/ttypes.py +1228 -0
  40. ai/chronon/join.py +566 -0
  41. ai/chronon/logger.py +24 -0
  42. ai/chronon/model.py +35 -0
  43. ai/chronon/observability/__init__.py +1 -0
  44. ai/chronon/observability/constants.py +15 -0
  45. ai/chronon/observability/ttypes.py +2192 -0
  46. ai/chronon/orchestration/__init__.py +1 -0
  47. ai/chronon/orchestration/constants.py +15 -0
  48. ai/chronon/orchestration/ttypes.py +4406 -0
  49. ai/chronon/planner/__init__.py +1 -0
  50. ai/chronon/planner/constants.py +15 -0
  51. ai/chronon/planner/ttypes.py +1686 -0
  52. ai/chronon/query.py +126 -0
  53. ai/chronon/repo/__init__.py +40 -0
  54. ai/chronon/repo/aws.py +298 -0
  55. ai/chronon/repo/cluster.py +65 -0
  56. ai/chronon/repo/compile.py +56 -0
  57. ai/chronon/repo/constants.py +164 -0
  58. ai/chronon/repo/default_runner.py +291 -0
  59. ai/chronon/repo/explore.py +421 -0
  60. ai/chronon/repo/extract_objects.py +137 -0
  61. ai/chronon/repo/gcp.py +585 -0
  62. ai/chronon/repo/gitpython_utils.py +14 -0
  63. ai/chronon/repo/hub_runner.py +171 -0
  64. ai/chronon/repo/hub_uploader.py +108 -0
  65. ai/chronon/repo/init.py +53 -0
  66. ai/chronon/repo/join_backfill.py +105 -0
  67. ai/chronon/repo/run.py +293 -0
  68. ai/chronon/repo/serializer.py +141 -0
  69. ai/chronon/repo/team_json_utils.py +46 -0
  70. ai/chronon/repo/utils.py +472 -0
  71. ai/chronon/repo/zipline.py +51 -0
  72. ai/chronon/repo/zipline_hub.py +105 -0
  73. ai/chronon/resources/gcp/README.md +174 -0
  74. ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
  75. ai/chronon/resources/gcp/group_bys/test/data.py +34 -0
  76. ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
  77. ai/chronon/resources/gcp/joins/test/data.py +30 -0
  78. ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
  79. ai/chronon/resources/gcp/sources/test/data.py +23 -0
  80. ai/chronon/resources/gcp/teams.py +70 -0
  81. ai/chronon/resources/gcp/zipline-cli-install.sh +54 -0
  82. ai/chronon/source.py +88 -0
  83. ai/chronon/staging_query.py +185 -0
  84. ai/chronon/types.py +57 -0
  85. ai/chronon/utils.py +557 -0
  86. ai/chronon/windows.py +50 -0
  87. awx_zipline_ai-0.2.0.dist-info/METADATA +173 -0
  88. awx_zipline_ai-0.2.0.dist-info/RECORD +93 -0
  89. awx_zipline_ai-0.2.0.dist-info/WHEEL +5 -0
  90. awx_zipline_ai-0.2.0.dist-info/entry_points.txt +2 -0
  91. awx_zipline_ai-0.2.0.dist-info/licenses/LICENSE +202 -0
  92. awx_zipline_ai-0.2.0.dist-info/top_level.txt +3 -0
  93. jars/__init__.py +0 -0
ai/chronon/cli/compile/column_hashing.py
@@ -0,0 +1,313 @@
+ import hashlib
+ import re
+ from collections import defaultdict
+ from typing import Dict, List
+
+ from ai.chronon.api.ttypes import Derivation, ExternalPart, GroupBy, Join, Source
+ from ai.chronon.group_by import get_output_col_names
+
+
+ # Returns a map of output column to semantic hash, including derivations
+ def compute_group_by_columns_hashes(group_by: GroupBy, exclude_keys: bool = False) -> Dict[str, str]:
+     """
+     From the group_by object, get the final output columns after derivations.
+     """
+     # Get the output columns and their input expressions
+     output_to_input = get_pre_derived_group_by_columns(group_by)
+
+     # Get the base semantic fields that apply to all columns
+     base_semantics = []
+     for source in group_by.sources:
+         base_semantics.extend(_extract_source_semantic_info(source, group_by.keyColumns))
+
+     # Add the major version to semantic fields
+     group_by_minor_version_suffix = f"__{group_by.metaData.version}"
+     group_by_major_version = group_by.metaData.name
+     if group_by_major_version.endswith(group_by_minor_version_suffix):
+         group_by_major_version = group_by_major_version[:-len(group_by_minor_version_suffix)]
+     base_semantics.append(f"group_by_name:{group_by_major_version}")
+
+     # Compute the semantic hash for each output column
+     output_to_hash = {}
+     for output_col, input_expr in output_to_input.items():
+         semantic_components = base_semantics + [f"input_expr:{input_expr}"]
+         semantic_hash = _compute_semantic_hash(semantic_components)
+         output_to_hash[output_col] = semantic_hash
+
+     if exclude_keys:
+         output = {k: v for k, v in output_to_hash.items() if k not in group_by.keyColumns}
+     else:
+         output = output_to_hash
+     if group_by.derivations:
+         derived = build_derived_columns(output, group_by.derivations, base_semantics)
+         if not exclude_keys:
+             # We need to add keys back at this point
+             for key in group_by.keyColumns:
+                 if key not in derived:
+                     derived[key] = output_to_input.get(key)
+         return derived
+     else:
+         return output
+
+
+ def compute_join_column_hashes(join: Join) -> Dict[str, str]:
+     """
+     From the join object, get the final output columns -> semantic hash after derivations.
+     """
+     # Get the base semantics from the left side (table and key expression)
+     base_semantic_fields = []
+
+     output_columns = get_pre_derived_join_features(join) | get_pre_derived_source_keys(join.left)
+
+     if join.derivations:
+         return build_derived_columns(output_columns, join.derivations, base_semantic_fields)
+     else:
+         return output_columns
+
+
+ def get_pre_derived_join_features(join: Join) -> Dict[str, str]:
+     return get_pre_derived_join_internal_features(join) | get_pre_derived_external_features(join)
+
+
+ def get_pre_derived_external_features(join: Join) -> Dict[str, str]:
+     external_cols = []
+     if join.onlineExternalParts:
+         for external_part in join.onlineExternalParts:
+             original_external_columns = [
+                 param.name for param in external_part.source.valueSchema.params
+             ]
+             prefix = get_external_part_full_name(external_part) + "_"
+             for col in original_external_columns:
+                 external_cols.append(prefix + col)
+     # No meaningful semantic information on external columns, so we just return the column names as a self map
+     return {x: x for x in external_cols}
+
+
+ def get_pre_derived_source_keys(source: Source) -> Dict[str, str]:
+     base_semantics = _extract_source_semantic_info(source)
+     source_keys_to_hashes = {}
+     for key, expression in extract_selects(source).items():
+         source_keys_to_hashes[key] = _compute_semantic_hash(base_semantics + [f"select:{key}={expression}"])
+     return source_keys_to_hashes
+
+
+ def extract_selects(source: Source) -> Dict[str, str]:
+     if source.events:
+         return source.events.query.selects
+     elif source.entities:
+         return source.entities.query.selects
+     elif source.joinSource:
+         return source.joinSource.query.selects
+
+
+ def get_pre_derived_join_internal_features(join: Join) -> Dict[str, str]:
+
+     # Get the base semantic fields from join left side (without key columns)
+     join_base_semantic_fields = _extract_source_semantic_info(join.left)
+
+     internal_features = {}
+     for jp in join.joinParts:
+         # Build key mapping semantics - include left side key expressions
+         if jp.keyMapping:
+             key_mapping_semantics = ["join_keys:" + ",".join(f"{k}:{v}" for k, v in sorted(jp.keyMapping.items()))]
+         else:
+             key_mapping_semantics = []
+
+         # Include left side key expressions that this join part uses
+         left_key_expressions = []
+         left_selects = extract_selects(join.left)
+         for gb_key in jp.groupBy.keyColumns:
+             # Get the mapped key name or use the original key name
+             left_key_name = gb_key
+             if jp.keyMapping:
+                 # Find the left side key that maps to this groupby key
+                 for left_key, mapped_key in jp.keyMapping.items():
+                     if mapped_key == gb_key:
+                         left_key_name = left_key
+                         break
+
+             # Add the left side expression for this key
+             left_key_expr = left_selects.get(left_key_name, left_key_name)
+             left_key_expressions.append(f"left_key:{left_key_name}={left_key_expr}")
+
+         # These semantics apply to all features in the joinPart
+         jp_base_semantics = key_mapping_semantics + left_key_expressions + join_base_semantic_fields
+
+         pre_derived_group_by_features = get_pre_derived_group_by_features(jp.groupBy, jp_base_semantics)
+
+         if jp.groupBy.derivations:
+             derived_group_by_features = build_derived_columns(
+                 pre_derived_group_by_features, jp.groupBy.derivations, jp_base_semantics
+             )
+         else:
+             derived_group_by_features = pre_derived_group_by_features
+
+         for col, semantic_hash in derived_group_by_features.items():
+             prefix = jp.prefix + "_" if jp.prefix else ""
+             if join.useLongNames:
+                 gb_prefix = prefix + jp.groupBy.metaData.name.replace(".", "_")
+             else:
+                 key_str = "_".join(jp.groupBy.keyColumns)
+                 gb_prefix = prefix + key_str
+             internal_features[gb_prefix + "_" + col] = semantic_hash
+     return internal_features
+
+
+ def get_pre_derived_group_by_columns(group_by: GroupBy) -> Dict[str, str]:
+     output_columns_to_hashes = get_pre_derived_group_by_features(group_by)
+     for key_column in group_by.keyColumns:
+         key_expressions = []
+         for source in group_by.sources:
+             source_selects = extract_selects(source)  # Map[str, str]
+             key_expressions.append(source_selects.get(key_column, key_column))
+         output_columns_to_hashes[key_column] = _compute_semantic_hash(sorted(key_expressions))
+     return output_columns_to_hashes
+
+
+ def get_pre_derived_group_by_features(group_by: GroupBy, additional_semantic_fields=None) -> Dict[str, str]:
+     # Get the base semantic fields that apply to all aggs
+     if additional_semantic_fields is None:
+         additional_semantic_fields = []
+     base_semantics = _get_base_group_by_semantic_fields(group_by)
+
+     output_columns = {}
+     # For group_bys with aggregations, aggregated columns
+     if group_by.aggregations:
+         for agg in group_by.aggregations:
+             input_expression_str = ",".join(get_input_expression_across_sources(group_by, agg.inputColumn))
+             for output_col_name in get_output_col_names(agg):
+                 output_columns[output_col_name] = _compute_semantic_hash(base_semantics + [input_expression_str] + additional_semantic_fields)
+     # For group_bys without aggregations, selected fields from query
+     else:
+         combined_selects = defaultdict(set)
+
+         for source in group_by.sources:
+             source_selects = extract_selects(source)  # Map[str, str]
+             for key, val in source_selects.items():
+                 combined_selects[key].add(val)
+
+         # Build a unified map of key to select expression from all sources
+         unified_selects = {key: ",".join(sorted(vals)) for key, vals in combined_selects.items()}
+
+         # now compute the hashes on base semantics + expression
+         selected_hashes = {key: _compute_semantic_hash(base_semantics + [val] + additional_semantic_fields) for key, val in unified_selects.items()}
+         output_columns.update(selected_hashes)
+     return output_columns
+
+
+ def _get_base_group_by_semantic_fields(group_by: GroupBy) -> List[str]:
+     """
+     Extract base semantic fields from the group_by object.
+     This includes source table names, key columns, and any other relevant metadata.
+     """
+     base_semantics = []
+     for source in group_by.sources:
+         base_semantics.extend(_extract_source_semantic_info(source, group_by.keyColumns))
+
+     return sorted(base_semantics)
+
+
+ def _extract_source_semantic_info(source: Source, key_columns: List[str] = None) -> List[str]:
+     """
+     Extract source information for semantic hashing.
+     Returns list of semantic components.
+
+     Args:
+         source: The source to extract info from
+         key_columns: List of key columns to include expressions for. If None, includes all selects.
+     """
+     components = []
+
+     if source.events:
+         table = source.events.table
+         mutationTable = ""
+         query = source.events.query
+         cumulative = str(source.events.isCumulative or "")
+     elif source.entities:
+         table = source.entities.snapshotTable
+         mutationTable = source.entities.mutationTable
+         query = source.entities.query
+         cumulative = ""
+     elif source.joinSource:
+         table = source.joinSource.join.metaData.name
+         mutationTable = ""
+         query = source.joinSource.query
+         cumulative = ""
+
+     components.append(f"table:{table}")
+     components.append(f"mutation_table:{mutationTable}")
+     components.append(f"cumulative:{cumulative}")
+     components.append(f"filters:{query.wheres or ''}")
+
+     selects = query.selects or {}
+     if key_columns:
+         # Only include expressions for the specified key columns
+         for key_col in sorted(key_columns):
+             expr = selects.get(key_col, key_col)
+             components.append(f"select:{key_col}={expr}")
+
+     # Add time column expression
+     if query.timeColumn:
+         time_expr = selects.get(query.timeColumn, query.timeColumn)
+         components.append(f"time_column:{query.timeColumn}={time_expr}")
+
+     return sorted(components)
+
+
+ def _compute_semantic_hash(components: List[str]) -> str:
+     """
+     Compute semantic hash from a list of components.
+     Components should be ordered consistently to ensure reproducible hashes.
+     """
+     # Sort components to ensure consistent ordering
+     sorted_components = sorted(components)
+     hash_input = "|".join(sorted_components)
+     return hashlib.md5(hash_input.encode('utf-8')).hexdigest()
+
+
+ def build_derived_columns(
+     base_columns_to_hashes: Dict[str, str], derivations: List[Derivation], additional_semantic_fields: List[str]
+ ) -> Dict[str, str]:
+     """
+     Build the derived columns from pre-derived columns and derivations.
+     """
+     # if derivations contain star, then all columns are included except the columns which are renamed
+     output_columns = {}
+     if derivations:
+         found = any(derivation.expression == "*" for derivation in derivations)
+         if found:
+             output_columns.update(base_columns_to_hashes)
+         for derivation in derivations:
+             if base_columns_to_hashes.get(derivation.expression):
+                 # don't change the semantics if you're just passing a base column through derivations
+                 output_columns[derivation.name] = base_columns_to_hashes[derivation.expression]
+             if derivation.name != "*":
+                 # Identify base fields present within the derivation to include in the semantic hash
+                 # We go long to short to avoid taking both a windowed feature and the unwindowed feature
+                 # i.e. f_7d and f
+                 derivation_expression = derivation.expression
+                 base_col_semantic_fields = []
+                 tokens = re.findall(r'\b\w+\b', derivation_expression)
+                 for token in tokens:
+                     if token in base_columns_to_hashes:
+                         base_col_semantic_fields.append(base_columns_to_hashes[token])
+
+                 output_columns[derivation.name] = _compute_semantic_hash(additional_semantic_fields + [f"derivation:{derivation.expression}"] + base_col_semantic_fields)
+     return output_columns
+
+
+ def get_external_part_full_name(external_part: ExternalPart) -> str:
+     # The logic should be consistent with the full name logic defined
+     # in https://github.com/airbnb/chronon/blob/main/api/src/main/scala/ai/chronon/api/Extensions.scala#L677.
+     prefix = external_part.prefix + "_" if external_part.prefix else ""
+     name = external_part.source.metadata.name
+     sanitized_name = re.sub("[^a-zA-Z0-9_]", "_", name)
+     return "ext_" + prefix + sanitized_name
+
+
+ def get_input_expression_across_sources(group_by: GroupBy, input_col: str):
+     expressions = []
+     for source in group_by.sources:
+         selects = extract_selects(source)
+         expressions.extend(selects.get(input_col, input_col))
+     return sorted(expressions)
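
The hashing contract above is easiest to see in isolation: components are sorted, joined with "|", and MD5-hashed, so the same semantic fields always produce the same digest regardless of order. A minimal sketch, not part of the wheel, with illustrative field values:

    import hashlib

    def semantic_hash(components):
        # mirrors _compute_semantic_hash: sort, join with "|", md5
        return hashlib.md5("|".join(sorted(components)).encode("utf-8")).hexdigest()

    base = ["table:data.purchases", "filters:", "select:user_id=user_id"]  # hypothetical values
    # order of components does not affect the hash
    assert semantic_hash(base + ["input_expr:purchase_price"]) == semantic_hash(["input_expr:purchase_price"] + base)
    # changing any semantic field (e.g. the input expression) changes the hash
    assert semantic_hash(base + ["input_expr:purchase_price"]) != semantic_hash(base + ["input_expr:refund_amount"])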
ai/chronon/cli/compile/compile_context.py
@@ -0,0 +1,177 @@
+ import os
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Type
+
+ import ai.chronon.cli.compile.parse_teams as teams
+ from ai.chronon.api.ttypes import GroupBy, Join, MetaData, Model, StagingQuery, Team
+ from ai.chronon.cli.compile.conf_validator import ConfValidator
+ from ai.chronon.cli.compile.display.compile_status import CompileStatus
+ from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
+ from ai.chronon.cli.compile.serializer import file2thrift
+ from ai.chronon.cli.logger import get_logger, require
+ from ai.chronon.orchestration.ttypes import ConfType
+
+ logger = get_logger()
+
+
+ @dataclass
+ class ConfigInfo:
+     folder_name: str
+     cls: Type
+     config_type: Optional[ConfType]
+
+
+ @dataclass
+ class CompileContext:
+
+     def __init__(self):
+         self.chronon_root: str = os.getenv("CHRONON_ROOT", os.getcwd())
+         self.teams_dict: Dict[str, Team] = teams.load_teams(self.chronon_root)
+         self.compile_dir: str = "compiled"
+
+         self.config_infos: List[ConfigInfo] = [
+             ConfigInfo(folder_name="joins", cls=Join, config_type=ConfType.JOIN),
+             ConfigInfo(
+                 folder_name="group_bys",
+                 cls=GroupBy,
+                 config_type=ConfType.GROUP_BY,
+             ),
+             ConfigInfo(
+                 folder_name="staging_queries",
+                 cls=StagingQuery,
+                 config_type=ConfType.STAGING_QUERY,
+             ),
+             ConfigInfo(folder_name="models", cls=Model, config_type=ConfType.MODEL),
+             ConfigInfo(folder_name="teams_metadata", cls=MetaData, config_type=None),  # only for team metadata
+         ]
+
+         self.compile_status = CompileStatus(use_live=False)
+
+         self.existing_confs: Dict[Type, Dict[str, Any]] = {}
+         for config_info in self.config_infos:
+             cls = config_info.cls
+             self.existing_confs[cls] = self._parse_existing_confs(cls)
+
+
+
+         self.validator: ConfValidator = ConfValidator(
+             input_root=self.chronon_root,
+             output_root=self.compile_dir,
+             existing_gbs=self.existing_confs[GroupBy],
+             existing_joins=self.existing_confs[Join],
+         )
+
+     def input_dir(self, cls: type) -> str:
+         """
+         - eg., input: group_by class
+         - eg., output: root/group_bys/
+         """
+         config_info = self.config_info_for_class(cls)
+         return os.path.join(self.chronon_root, config_info.folder_name)
+
+     def staging_output_dir(self, cls: type = None) -> str:
+         """
+         - eg., input: group_by class
+         - eg., output: root/compiled_staging/group_bys/
+         """
+         if cls is None:
+             return os.path.join(self.chronon_root, self.compile_dir + "_staging")
+         else:
+             config_info = self.config_info_for_class(cls)
+             return os.path.join(
+                 self.chronon_root,
+                 self.compile_dir + "_staging",
+                 config_info.folder_name,
+             )
+
+     def output_dir(self, cls: type = None) -> str:
+         """
+         - eg., input: group_by class
+         - eg., output: root/compiled/group_bys/
+         """
+         if cls is None:
+             return os.path.join(self.chronon_root, self.compile_dir)
+         else:
+             config_info = self.config_info_for_class(cls)
+             return os.path.join(
+                 self.chronon_root, self.compile_dir, config_info.folder_name
+             )
+
+     def staging_output_path(self, compiled_obj: CompiledObj):
+         """
+         - eg., input: group_by with name search.clicks.features.v1
+         - eg., output: root/compiled_staging/group_bys/search/clicks.features.v1
+         """
+
+         output_dir = self.staging_output_dir(compiled_obj.obj.__class__)  # compiled/joins
+
+         team, rest = compiled_obj.name.split(".", 1)  # search, clicks.features.v1
+
+         return os.path.join(
+             output_dir,
+             team,
+             rest,
+         )
+
+     def config_info_for_class(self, cls: type) -> ConfigInfo:
+         for info in self.config_infos:
+             if info.cls == cls:
+                 return info
+
+         require(False, f"Class {cls} not found in CONFIG_INFOS")
+
+     def _parse_existing_confs(self, obj_class: type) -> Dict[str, object]:
+
+         result = {}
+
+         output_dir = self.output_dir(obj_class)
+
+         # Check if output_dir exists before walking
+         if not os.path.exists(output_dir):
+             return result
+
+         for sub_root, _sub_dirs, sub_files in os.walk(output_dir):
+
+             for f in sub_files:
+
+                 if f.startswith("."):  # ignore hidden files - such as .DS_Store
+                     continue
+
+                 full_path = os.path.join(sub_root, f)
+
+                 try:
+                     obj = file2thrift(full_path, obj_class)
+
+                     if obj:
+                         if hasattr(obj, "metaData"):
+                             result[obj.metaData.name] = obj
+                             compiled_obj = CompiledObj(
+                                 name=obj.metaData.name,
+                                 obj=obj,
+                                 file=obj.metaData.sourceFile,
+                                 errors=None,
+                                 obj_type=obj_class.__name__,
+                                 tjson=open(full_path).read(),
+                             )
+                             self.compile_status.add_existing_object_update_display(compiled_obj)
+                         elif isinstance(obj, MetaData):
+                             team_metadata_name = '.'.join(full_path.split('/')[-2:])  # use the name of the file as team metadata won't have name
+                             result[team_metadata_name] = obj
+                             compiled_obj = CompiledObj(
+                                 name=team_metadata_name,
+                                 obj=obj,
+                                 file=obj.sourceFile,
+                                 errors=None,
+                                 obj_type=obj_class.__name__,
+                                 tjson=open(full_path).read(),
+                             )
+                             self.compile_status.add_existing_object_update_display(compiled_obj)
+                         else:
+                             logger.errors(
+                                 f"Parsed object from {full_path} has no metaData attribute"
+                             )
+
+                 except Exception as e:
+                     print(f"Failed to parse file {full_path}: {str(e)}", e)
+
+         return result
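
The directory resolution above reduces to a few path joins. A minimal sketch, not part of the wheel, assuming CHRONON_ROOT=/repo and a GroupBy named "search.clicks.features.v1":

    import os

    root, compile_dir = "/repo", "compiled"
    input_dir = os.path.join(root, "group_bys")                               # /repo/group_bys
    staging_dir = os.path.join(root, compile_dir + "_staging", "group_bys")   # /repo/compiled_staging/group_bys
    output_dir = os.path.join(root, compile_dir, "group_bys")                 # /repo/compiled/group_bys
    team, rest = "search.clicks.features.v1".split(".", 1)                    # "search", "clicks.features.v1"
    staging_path = os.path.join(staging_dir, team, rest)                      # /repo/compiled_staging/group_bys/search/clicks.features.v1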
ai/chronon/cli/compile/compiler.py
@@ -0,0 +1,160 @@
+ import os
+ import shutil
+ import traceback
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import ai.chronon.cli.compile.display.compiled_obj
+ import ai.chronon.cli.compile.parse_configs as parser
+ import ai.chronon.cli.logger as logger
+ from ai.chronon.cli.compile import serializer
+ from ai.chronon.cli.compile.compile_context import CompileContext, ConfigInfo
+ from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
+ from ai.chronon.cli.compile.display.console import console
+ from ai.chronon.cli.compile.parse_teams import merge_team_execution_info
+ from ai.chronon.orchestration.ttypes import ConfType
+ from ai.chronon.types import MetaData
+
+ logger = logger.get_logger()
+
+
+ @dataclass
+ class CompileResult:
+     config_info: ConfigInfo
+     obj_dict: Dict[str, Any]
+     error_dict: Dict[str, List[BaseException]]
+
+
+ class Compiler:
+
+     def __init__(self, compile_context: CompileContext):
+         self.compile_context = compile_context
+
+     def compile(self) -> Dict[ConfType, CompileResult]:
+
+         config_infos = self.compile_context.config_infos
+
+         compile_results = {}
+
+         for config_info in config_infos:
+             configs = self._compile_class_configs(config_info)
+
+             compile_results[config_info.config_type] = configs
+         self._compile_team_metadata()
+
+         # check if staging_output_dir exists
+         staging_dir = self.compile_context.staging_output_dir()
+         if os.path.exists(staging_dir):
+             # replace staging_output_dir to output_dir
+             output_dir = self.compile_context.output_dir()
+             if os.path.exists(output_dir):
+                 shutil.rmtree(output_dir)
+             shutil.move(staging_dir, output_dir)
+         else:
+             print(
+                 f"Staging directory {staging_dir} does not exist. "
+                 "Happens when every chronon config fails to compile or when no chronon configs exist."
+             )
+
+         # TODO: temporarily just print out the final results of the compile until live fix is implemented:
+         # https://github.com/Textualize/rich/pull/3637
+         console.print(self.compile_context.compile_status.render())
+
+         return compile_results
+
+     def _compile_team_metadata(self):
+         """
+         Compile the team metadata and return the compiled object.
+         """
+         teams_dict = self.compile_context.teams_dict
+         for team in teams_dict:
+             m = MetaData()
+             merge_team_execution_info(m, teams_dict, team)
+
+             tjson = serializer.thrift_simple_json(m)
+             name = f"{team}.{team}_team_metadata"
+             result = CompiledObj(
+                 name=name,
+                 obj=m,
+                 file=name,
+                 errors=None,
+                 obj_type=MetaData.__name__,
+                 tjson=tjson,
+             )
+             self._write_object(result)
+             self.compile_context.compile_status.add_object_update_display(result, MetaData.__name__)
+
+         # Done writing team metadata, close the class
+         self.compile_context.compile_status.close_cls(MetaData.__name__)
+
+     def _compile_class_configs(self, config_info: ConfigInfo) -> CompileResult:
+
+         compile_result = CompileResult(
+             config_info=config_info, obj_dict={}, error_dict={}
+         )
+
+         input_dir = self.compile_context.input_dir(config_info.cls)
+
+         compiled_objects = parser.from_folder(
+             config_info.cls, input_dir, self.compile_context
+         )
+
+         objects, errors = self._write_objects_in_folder(compiled_objects)
+
+         if objects:
+             compile_result.obj_dict.update(objects)
+
+         if errors:
+             compile_result.error_dict.update(errors)
+
+         self.compile_context.compile_status.close_cls(config_info.cls.__name__)
+
+         return compile_result
+
+     def _write_objects_in_folder(
+         self,
+         compiled_objects: List[ai.chronon.cli.compile.display.compiled_obj.CompiledObj],
+     ) -> Tuple[Dict[str, Any], Dict[str, List[BaseException]]]:
+
+         error_dict = {}
+         object_dict = {}
+
+         for co in compiled_objects:
+
+             if co.obj:
+
+                 if co.errors:
+                     error_dict[co.name] = co.errors
+
+                     for error in co.errors:
+                         self.compile_context.compile_status.print_live_console(
+                             f"Error processing conf {co.name}: {error}"
+                         )
+                         traceback.print_exception(
+                             type(error), error, error.__traceback__
+                         )
+
+                 else:
+                     self._write_object(co)
+                     object_dict[co.name] = co.obj
+             else:
+                 error_dict[co.file] = co.errors
+
+                 self.compile_context.compile_status.print_live_console(
+                     f"Error processing file {co.file}: {co.errors}"
+                 )
+                 for error in co.errors:
+                     traceback.print_exception(type(error), error, error.__traceback__)
+
+         return object_dict, error_dict
+
+     def _write_object(self, compiled_obj: CompiledObj) -> Optional[List[BaseException]]:
+         output_path = self.compile_context.staging_output_path(compiled_obj)
+
+         folder = os.path.dirname(output_path)
+
+         if not os.path.exists(folder):
+             os.makedirs(folder)
+
+         with open(output_path, "w") as f:
+             f.write(compiled_obj.tjson)
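
Putting the two modules together, a minimal driver sketch (not part of the wheel; assumes CHRONON_ROOT points at a chronon repo with a valid teams config, and that the conf folders exist):

    from ai.chronon.cli.compile.compile_context import CompileContext
    from ai.chronon.cli.compile.compiler import Compiler

    ctx = CompileContext()             # reads CHRONON_ROOT, loads teams and previously compiled confs
    results = Compiler(ctx).compile()  # writes to <root>/compiled_staging, then swaps it into <root>/compiled
    for conf_type, result in results.items():
        print(conf_type, len(result.obj_dict), "compiled,", len(result.error_dict), "failed")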