awx-zipline-ai 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of awx-zipline-ai might be problematic.

Files changed (96)
  1. agent/ttypes.py +6 -6
  2. ai/chronon/airflow_helpers.py +20 -23
  3. ai/chronon/cli/__init__.py +0 -0
  4. ai/chronon/cli/compile/__init__.py +0 -0
  5. ai/chronon/cli/compile/column_hashing.py +40 -17
  6. ai/chronon/cli/compile/compile_context.py +13 -17
  7. ai/chronon/cli/compile/compiler.py +59 -36
  8. ai/chronon/cli/compile/conf_validator.py +251 -99
  9. ai/chronon/cli/compile/display/__init__.py +0 -0
  10. ai/chronon/cli/compile/display/class_tracker.py +6 -16
  11. ai/chronon/cli/compile/display/compile_status.py +10 -10
  12. ai/chronon/cli/compile/display/diff_result.py +79 -14
  13. ai/chronon/cli/compile/fill_templates.py +3 -8
  14. ai/chronon/cli/compile/parse_configs.py +10 -17
  15. ai/chronon/cli/compile/parse_teams.py +38 -34
  16. ai/chronon/cli/compile/serializer.py +3 -9
  17. ai/chronon/cli/compile/version_utils.py +42 -0
  18. ai/chronon/cli/git_utils.py +2 -13
  19. ai/chronon/cli/logger.py +0 -2
  20. ai/chronon/constants.py +1 -1
  21. ai/chronon/group_by.py +47 -47
  22. ai/chronon/join.py +46 -32
  23. ai/chronon/logger.py +1 -2
  24. ai/chronon/model.py +9 -4
  25. ai/chronon/query.py +2 -2
  26. ai/chronon/repo/__init__.py +1 -2
  27. ai/chronon/repo/aws.py +17 -31
  28. ai/chronon/repo/cluster.py +121 -50
  29. ai/chronon/repo/compile.py +14 -8
  30. ai/chronon/repo/constants.py +1 -1
  31. ai/chronon/repo/default_runner.py +32 -54
  32. ai/chronon/repo/explore.py +70 -73
  33. ai/chronon/repo/extract_objects.py +6 -9
  34. ai/chronon/repo/gcp.py +89 -88
  35. ai/chronon/repo/gitpython_utils.py +3 -2
  36. ai/chronon/repo/hub_runner.py +145 -55
  37. ai/chronon/repo/hub_uploader.py +2 -1
  38. ai/chronon/repo/init.py +12 -5
  39. ai/chronon/repo/join_backfill.py +19 -5
  40. ai/chronon/repo/run.py +42 -39
  41. ai/chronon/repo/serializer.py +4 -12
  42. ai/chronon/repo/utils.py +72 -63
  43. ai/chronon/repo/zipline.py +3 -19
  44. ai/chronon/repo/zipline_hub.py +211 -39
  45. ai/chronon/resources/__init__.py +0 -0
  46. ai/chronon/resources/gcp/__init__.py +0 -0
  47. ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
  48. ai/chronon/resources/gcp/group_bys/test/data.py +13 -17
  49. ai/chronon/resources/gcp/joins/__init__.py +0 -0
  50. ai/chronon/resources/gcp/joins/test/data.py +4 -8
  51. ai/chronon/resources/gcp/sources/__init__.py +0 -0
  52. ai/chronon/resources/gcp/sources/test/data.py +9 -6
  53. ai/chronon/resources/gcp/teams.py +9 -21
  54. ai/chronon/source.py +2 -4
  55. ai/chronon/staging_query.py +60 -19
  56. ai/chronon/types.py +3 -2
  57. ai/chronon/utils.py +21 -68
  58. ai/chronon/windows.py +2 -4
  59. {awx_zipline_ai-0.2.0.dist-info → awx_zipline_ai-0.3.0.dist-info}/METADATA +47 -24
  60. awx_zipline_ai-0.3.0.dist-info/RECORD +96 -0
  61. awx_zipline_ai-0.3.0.dist-info/top_level.txt +4 -0
  62. gen_thrift/__init__.py +0 -0
  63. {ai/chronon → gen_thrift}/api/ttypes.py +327 -197
  64. {ai/chronon/api → gen_thrift}/common/ttypes.py +9 -39
  65. gen_thrift/eval/ttypes.py +660 -0
  66. {ai/chronon → gen_thrift}/hub/ttypes.py +12 -131
  67. {ai/chronon → gen_thrift}/observability/ttypes.py +343 -180
  68. {ai/chronon → gen_thrift}/planner/ttypes.py +326 -45
  69. ai/chronon/eval/__init__.py +0 -122
  70. ai/chronon/eval/query_parsing.py +0 -19
  71. ai/chronon/eval/sample_tables.py +0 -100
  72. ai/chronon/eval/table_scan.py +0 -186
  73. ai/chronon/orchestration/ttypes.py +0 -4406
  74. ai/chronon/resources/gcp/README.md +0 -174
  75. ai/chronon/resources/gcp/zipline-cli-install.sh +0 -54
  76. awx_zipline_ai-0.2.0.dist-info/RECORD +0 -93
  77. awx_zipline_ai-0.2.0.dist-info/licenses/LICENSE +0 -202
  78. awx_zipline_ai-0.2.0.dist-info/top_level.txt +0 -3
  79. /jars/__init__.py → /__init__.py +0 -0
  80. {awx_zipline_ai-0.2.0.dist-info → awx_zipline_ai-0.3.0.dist-info}/WHEEL +0 -0
  81. {awx_zipline_ai-0.2.0.dist-info → awx_zipline_ai-0.3.0.dist-info}/entry_points.txt +0 -0
  82. {ai/chronon → gen_thrift}/api/__init__.py +0 -0
  83. {ai/chronon/api/common → gen_thrift/api}/constants.py +0 -0
  84. {ai/chronon/api → gen_thrift}/common/__init__.py +0 -0
  85. {ai/chronon/api → gen_thrift/common}/constants.py +0 -0
  86. {ai/chronon/fetcher → gen_thrift/eval}/__init__.py +0 -0
  87. {ai/chronon/fetcher → gen_thrift/eval}/constants.py +0 -0
  88. {ai/chronon/hub → gen_thrift/fetcher}/__init__.py +0 -0
  89. {ai/chronon/hub → gen_thrift/fetcher}/constants.py +0 -0
  90. {ai/chronon → gen_thrift}/fetcher/ttypes.py +0 -0
  91. {ai/chronon/observability → gen_thrift/hub}/__init__.py +0 -0
  92. {ai/chronon/observability → gen_thrift/hub}/constants.py +0 -0
  93. {ai/chronon/orchestration → gen_thrift/observability}/__init__.py +0 -0
  94. {ai/chronon/orchestration → gen_thrift/observability}/constants.py +0 -0
  95. {ai/chronon → gen_thrift}/planner/__init__.py +0 -0
  96. {ai/chronon → gen_thrift}/planner/constants.py +0 -0
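
Note on the renames above: the generated Thrift modules move from the ai.chronon namespace (ai/chronon/api, ai/chronon/hub, ai/chronon/observability, ai/chronon/planner, ...) to a new top-level gen_thrift package. As a hedged illustration of what that means for code importing these types directly (paths taken from the import changes shown in the diffs below; verify against the installed 0.3.0 wheel):

    # Hypothetical downstream migration sketch, not part of the package itself.
    # 0.2.0 import paths:
    #   from ai.chronon.api.ttypes import GroupBy, Join
    #   from ai.chronon.api.common.ttypes import TimeUnit
    # 0.3.0 import paths:
    from gen_thrift.api.ttypes import GroupBy, Join
    from gen_thrift.common.ttypes import TimeUnit
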
agent/ttypes.py CHANGED
@@ -12,7 +12,7 @@ from thrift.TRecursive import fix_spec
 from uuid import UUID
 
 import sys
-import ai.chronon.api.common.ttypes
+import gen_thrift.common.ttypes
 
 from thrift.transport import TTransport
 all_structs = []
@@ -39,7 +39,7 @@ class JobStatusType(object):
     RUNNING = 2
     SUCCEEDED = 3
     FAILED = 4
-    STOPPED = 5
+    CANCELLED = 5
 
     _VALUES_TO_NAMES = {
         0: "UNKNOWN",
@@ -47,7 +47,7 @@ class JobStatusType(object):
         2: "RUNNING",
         3: "SUCCEEDED",
         4: "FAILED",
-        5: "STOPPED",
+        5: "CANCELLED",
     }
 
     _NAMES_TO_VALUES = {
@@ -56,7 +56,7 @@ class JobStatusType(object):
         "RUNNING": 2,
         "SUCCEEDED": 3,
         "FAILED": 4,
-        "STOPPED": 5,
+        "CANCELLED": 5,
     }
 
 
@@ -1388,7 +1388,7 @@ class PartitionListingPutRequest(object):
                     _val84 = []
                     (_etype88, _size85) = iprot.readListBegin()
                     for _i89 in range(_size85):
-                        _elem90 = ai.chronon.api.common.ttypes.DateRange()
+                        _elem90 = gen_thrift.common.ttypes.DateRange()
                         _elem90.read(iprot)
                         _val84.append(_elem90)
                     iprot.readListEnd()
@@ -1672,7 +1672,7 @@ JobInfo.thrift_spec = (
 all_structs.append(PartitionListingPutRequest)
 PartitionListingPutRequest.thrift_spec = (
     None, # 0
-    (1, TType.MAP, 'partitions', (TType.STRUCT, [PartitionListingJob, None], TType.LIST, (TType.STRUCT, [ai.chronon.api.common.ttypes.DateRange, None], False), False), None, ), # 1
+    (1, TType.MAP, 'partitions', (TType.STRUCT, [PartitionListingJob, None], TType.LIST, (TType.STRUCT, [gen_thrift.common.ttypes.DateRange, None], False), False), None, ), # 1
     (2, TType.MAP, 'errors', (TType.STRUCT, [PartitionListingJob, None], TType.STRING, 'UTF8', False), None, ), # 2
 )
 all_structs.append(JobInfoPutRequest)

ai/chronon/airflow_helpers.py CHANGED
@@ -2,9 +2,10 @@ import json
 import math
 from typing import OrderedDict
 
+from gen_thrift.api.ttypes import GroupBy, Join
+from gen_thrift.common.ttypes import TimeUnit
+
 import ai.chronon.utils as utils
-from ai.chronon.api.common.ttypes import TimeUnit
-from ai.chronon.api.ttypes import GroupBy, Join
 from ai.chronon.constants import (
     AIRFLOW_DEPENDENCIES_KEY,
     AIRFLOW_LABEL_DEPENDENCIES_KEY,
@@ -55,6 +56,7 @@ def _get_partition_col_from_query(query):
         return query.partitionColumn
     return None
 
+
 def _get_additional_subPartitionsToWaitFor_from_query(query):
     """Gets additional subPartitionsToWaitFor from query if available"""
     if query:
@@ -80,7 +82,8 @@ def _get_airflow_deps_from_source(source, partition_column=None):
         tables = [source.events.table]
         # Use partition column from query if available, otherwise use the provided one
         source_partition_column, additional_partitions = (
-            _get_partition_col_from_query(source.events.query) or partition_column, _get_additional_subPartitionsToWaitFor_from_query(source.events.query)
+            _get_partition_col_from_query(source.events.query) or partition_column,
+            _get_additional_subPartitionsToWaitFor_from_query(source.events.query),
         )
 
     elif source.entities:
@@ -89,7 +92,8 @@ def _get_airflow_deps_from_source(source, partition_column=None):
         if source.entities.mutationTable:
             tables.append(source.entities.mutationTable)
         source_partition_column, additional_partitions = (
-            _get_partition_col_from_query(source.entities.query) or partition_column, _get_additional_subPartitionsToWaitFor_from_query(source.entities.query)
+            _get_partition_col_from_query(source.entities.query) or partition_column,
+            _get_additional_subPartitionsToWaitFor_from_query(source.entities.query),
         )
     elif source.joinSource:
         # TODO: Handle joinSource -- it doesn't work right now because the metadata isn't set on joinSource at this point
@@ -99,15 +103,14 @@ def _get_airflow_deps_from_source(source, partition_column=None):
         return []
 
     return [
-        create_airflow_dependency(table, source_partition_column, additional_partitions) for table in tables
+        create_airflow_dependency(table, source_partition_column, additional_partitions)
+        for table in tables
     ]
 
 
 def extract_default_partition_column(obj):
     try:
-        return obj.metaData.executionInfo.conf.common.get(
-            "spark.chronon.partition.column"
-        )
+        return obj.metaData.executionInfo.conf.common.get("spark.chronon.partition.column")
     except Exception:
         # Error handling occurs in `create_airflow_dependency`
         return None
@@ -124,9 +127,9 @@ def _get_distinct_day_windows(group_by):
         if time_unit == TimeUnit.DAYS:
             windows.append(length)
         elif time_unit == TimeUnit.HOURS:
-            windows.append(math.ceil(length/24))
+            windows.append(math.ceil(length / 24))
         elif time_unit == TimeUnit.MINUTES:
-            windows.append(math.ceil(length/(24*60)))
+            windows.append(math.ceil(length / (24 * 60)))
     return set(windows)
 
 
@@ -137,9 +140,7 @@ def _set_join_deps(join):
 
     # Handle left source
     left_query = utils.get_query(join.left)
-    left_partition_column = (
-        _get_partition_col_from_query(left_query) or default_partition_col
-    )
+    left_partition_column = _get_partition_col_from_query(left_query) or default_partition_col
     deps.extend(_get_airflow_deps_from_source(join.left, left_partition_column))
 
     # Handle right parts (join parts)
@@ -149,12 +150,9 @@ def _set_join_deps(join):
         for source in join_part.groupBy.sources:
             source_query = utils.get_query(source)
             source_partition_column = (
-                _get_partition_col_from_query(source_query)
-                or default_partition_col
-            )
-            deps.extend(
-                _get_airflow_deps_from_source(source, source_partition_column)
+                _get_partition_col_from_query(source_query) or default_partition_col
             )
+            deps.extend(_get_airflow_deps_from_source(source, source_partition_column))
 
     label_deps = []
     # Handle label parts
@@ -162,7 +160,6 @@ def _set_join_deps(join):
     join_output_table = utils.output_table_name(join, full_name=True)
     partition_column = join.metaData.executionInfo.conf.common[PARTITION_COLUMN_KEY]
 
-
     # set the dependencies on the label sources
     for label_part in join.labelParts.labels:
         group_by = label_part.groupBy
@@ -171,21 +168,21 @@ def _set_join_deps(join):
             windows = _get_distinct_day_windows(group_by)
             for window in windows:
                 label_deps.append(
-                    create_airflow_dependency(join_output_table, partition_column, offset=-1 * window)
+                    create_airflow_dependency(
+                        join_output_table, partition_column, offset=-1 * window
+                    )
                 )
 
         if group_by and group_by.sources:
             for source in label_part.groupBy.sources:
                 source_query = utils.get_query(source)
                 source_partition_column = (
-                    _get_partition_col_from_query(source_query)
-                    or default_partition_col
+                    _get_partition_col_from_query(source_query) or default_partition_col
                 )
                 label_deps.extend(
                     _get_airflow_deps_from_source(source, source_partition_column)
                 )
 
-
     # Update the metadata customJson with dependencies
     _dedupe_and_set_airflow_deps_json(join, deps, AIRFLOW_DEPENDENCIES_KEY)
 
ai/chronon/cli/__init__.py (File without changes)
ai/chronon/cli/compile/__init__.py (File without changes)

ai/chronon/cli/compile/column_hashing.py CHANGED
@@ -3,12 +3,15 @@ import re
 from collections import defaultdict
 from typing import Dict, List
 
-from ai.chronon.api.ttypes import Derivation, ExternalPart, GroupBy, Join, Source
+from gen_thrift.api.ttypes import Derivation, ExternalPart, GroupBy, Join, Source
+
 from ai.chronon.group_by import get_output_col_names
 
 
 # Returns a map of output column to semantic hash, including derivations
-def compute_group_by_columns_hashes(group_by: GroupBy, exclude_keys: bool = False) -> Dict[str, str]:
+def compute_group_by_columns_hashes(
+    group_by: GroupBy, exclude_keys: bool = False
+) -> Dict[str, str]:
     """
     From the group_by object, get the final output columns after derivations.
     """
@@ -24,7 +27,7 @@ def compute_group_by_columns_hashes(group_by: GroupBy, exclude_keys: bool = Fals
     group_by_minor_version_suffix = f"__{group_by.metaData.version}"
     group_by_major_version = group_by.metaData.name
     if group_by_major_version.endswith(group_by_minor_version_suffix):
-        group_by_major_version = group_by_major_version[:-len(group_by_minor_version_suffix)]
+        group_by_major_version = group_by_major_version[: -len(group_by_minor_version_suffix)]
     base_semantics.append(f"group_by_name:{group_by_major_version}")
 
     # Compute the semantic hash for each output column
@@ -87,7 +90,9 @@ def get_pre_derived_source_keys(source: Source) -> Dict[str, str]:
     base_semantics = _extract_source_semantic_info(source)
     source_keys_to_hashes = {}
     for key, expression in extract_selects(source).items():
-        source_keys_to_hashes[key] = _compute_semantic_hash(base_semantics + [f"select:{key}={expression}"])
+        source_keys_to_hashes[key] = _compute_semantic_hash(
+            base_semantics + [f"select:{key}={expression}"]
+        )
     return source_keys_to_hashes
 
 
@@ -101,7 +106,6 @@ def extract_selects(source: Source) -> Dict[str, str]:
 
 
 def get_pre_derived_join_internal_features(join: Join) -> Dict[str, str]:
-
     # Get the base semantic fields from join left side (without key columns)
     join_base_semantic_fields = _extract_source_semantic_info(join.left)
 
@@ -109,7 +113,9 @@ def get_pre_derived_join_internal_features(join: Join) -> Dict[str, str]:
     for jp in join.joinParts:
         # Build key mapping semantics - include left side key expressions
        if jp.keyMapping:
-            key_mapping_semantics = ["join_keys:" + ",".join(f"{k}:{v}" for k, v in sorted(jp.keyMapping.items()))]
+            key_mapping_semantics = [
+                "join_keys:" + ",".join(f"{k}:{v}" for k, v in sorted(jp.keyMapping.items()))
+            ]
         else:
             key_mapping_semantics = []
 
@@ -133,7 +139,9 @@ def get_pre_derived_join_internal_features(join: Join) -> Dict[str, str]:
         # These semantics apply to all features in the joinPart
         jp_base_semantics = key_mapping_semantics + left_key_expressions + join_base_semantic_fields
 
-        pre_derived_group_by_features = get_pre_derived_group_by_features(jp.groupBy, jp_base_semantics)
+        pre_derived_group_by_features = get_pre_derived_group_by_features(
+            jp.groupBy, jp_base_semantics
+        )
 
         if jp.groupBy.derivations:
             derived_group_by_features = build_derived_columns(
@@ -164,7 +172,9 @@ def get_pre_derived_group_by_columns(group_by: GroupBy) -> Dict[str, str]:
     return output_columns_to_hashes
 
 
-def get_pre_derived_group_by_features(group_by: GroupBy, additional_semantic_fields=None) -> Dict[str, str]:
+def get_pre_derived_group_by_features(
+    group_by: GroupBy, additional_semantic_fields=None
+) -> Dict[str, str]:
     # Get the base semantic fields that apply to all aggs
     if additional_semantic_fields is None:
         additional_semantic_fields = []
@@ -174,9 +184,13 @@ def get_pre_derived_group_by_features(group_by: GroupBy, additional_semantic_fie
     # For group_bys with aggregations, aggregated columns
     if group_by.aggregations:
         for agg in group_by.aggregations:
-            input_expression_str = ",".join(get_input_expression_across_sources(group_by, agg.inputColumn))
+            input_expression_str = ",".join(
+                get_input_expression_across_sources(group_by, agg.inputColumn)
+            )
             for output_col_name in get_output_col_names(agg):
-                output_columns[output_col_name] = _compute_semantic_hash(base_semantics + [input_expression_str] + additional_semantic_fields)
+                output_columns[output_col_name] = _compute_semantic_hash(
+                    base_semantics + [input_expression_str] + additional_semantic_fields
+                )
     # For group_bys without aggregations, selected fields from query
     else:
         combined_selects = defaultdict(set)
@@ -190,7 +204,10 @@ def get_pre_derived_group_by_features(group_by: GroupBy, additional_semantic_fie
         unified_selects = {key: ",".join(sorted(vals)) for key, vals in combined_selects.items()}
 
         # now compute the hashes on base semantics + expression
-        selected_hashes = {key: _compute_semantic_hash(base_semantics + [val] + additional_semantic_fields) for key, val in unified_selects.items()}
+        selected_hashes = {
+            key: _compute_semantic_hash(base_semantics + [val] + additional_semantic_fields)
+            for key, val in unified_selects.items()
+        }
         output_columns.update(selected_hashes)
     return output_columns
 
@@ -262,11 +279,13 @@ def _compute_semantic_hash(components: List[str]) -> str:
     # Sort components to ensure consistent ordering
     sorted_components = sorted(components)
     hash_input = "|".join(sorted_components)
-    return hashlib.md5(hash_input.encode('utf-8')).hexdigest()
+    return hashlib.md5(hash_input.encode("utf-8")).hexdigest()
 
 
 def build_derived_columns(
-    base_columns_to_hashes: Dict[str, str], derivations: List[Derivation], additional_semantic_fields: List[str]
+    base_columns_to_hashes: Dict[str, str],
+    derivations: List[Derivation],
+    additional_semantic_fields: List[str],
 ) -> Dict[str, str]:
     """
     Build the derived columns from pre-derived columns and derivations.
@@ -279,20 +298,24 @@ def build_derived_columns(
     output_columns.update(base_columns_to_hashes)
     for derivation in derivations:
         if base_columns_to_hashes.get(derivation.expression):
-            # don't change the semantics if you're just passing a base column through derivations
-            output_columns[derivation.name] = base_columns_to_hashes[derivation.expression]
+            # don't change the semantics if you're just passing a base column through derivations
+            output_columns[derivation.name] = base_columns_to_hashes[derivation.expression]
         if derivation.name != "*":
            # Identify base fields present within the derivation to include in the semantic hash
            # We go long to short to avoid taking both a windowed feature and the unwindowed feature
            # i.e. f_7d and f
            derivation_expression = derivation.expression
            base_col_semantic_fields = []
-            tokens = re.findall(r'\b\w+\b', derivation_expression)
+            tokens = re.findall(r"\b\w+\b", derivation_expression)
            for token in tokens:
                if token in base_columns_to_hashes:
                    base_col_semantic_fields.append(base_columns_to_hashes[token])
 
-            output_columns[derivation.name] = _compute_semantic_hash(additional_semantic_fields + [f"derivation:{derivation.expression}"] + base_col_semantic_fields)
+            output_columns[derivation.name] = _compute_semantic_hash(
+                additional_semantic_fields
+                + [f"derivation:{derivation.expression}"]
+                + base_col_semantic_fields
+            )
     return output_columns
 
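
The column_hashing.py changes above are mostly line-wrapping, but they expose the semantic-hash scheme clearly: semantic components are sorted, joined with "|", and MD5-hashed, so a column's hash does not depend on the order in which its semantic fields were collected. A minimal sketch of that scheme (illustration only, not the module's full API):

    import hashlib

    def semantic_hash(components):
        # Sorting first makes the digest order-independent, matching _compute_semantic_hash above.
        return hashlib.md5("|".join(sorted(components)).encode("utf-8")).hexdigest()

    assert semantic_hash(["b", "a"]) == semantic_hash(["a", "b"])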
 
ai/chronon/cli/compile/compile_context.py CHANGED
@@ -2,14 +2,14 @@ import os
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Type
 
+from gen_thrift.api.ttypes import ConfType, GroupBy, Join, MetaData, Model, StagingQuery, Team
+
 import ai.chronon.cli.compile.parse_teams as teams
-from ai.chronon.api.ttypes import GroupBy, Join, MetaData, Model, StagingQuery, Team
 from ai.chronon.cli.compile.conf_validator import ConfValidator
 from ai.chronon.cli.compile.display.compile_status import CompileStatus
 from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
 from ai.chronon.cli.compile.serializer import file2thrift
 from ai.chronon.cli.logger import get_logger, require
-from ai.chronon.orchestration.ttypes import ConfType
 
 logger = get_logger()
 
@@ -23,11 +23,11 @@ class ConfigInfo:
 
 @dataclass
 class CompileContext:
-
-    def __init__(self):
+    def __init__(self, ignore_python_errors: bool = False):
         self.chronon_root: str = os.getenv("CHRONON_ROOT", os.getcwd())
         self.teams_dict: Dict[str, Team] = teams.load_teams(self.chronon_root)
         self.compile_dir: str = "compiled"
+        self.ignore_python_errors: bool = ignore_python_errors
 
         self.config_infos: List[ConfigInfo] = [
             ConfigInfo(folder_name="joins", cls=Join, config_type=ConfType.JOIN),
@@ -42,7 +42,9 @@ class CompileContext:
                 config_type=ConfType.STAGING_QUERY,
             ),
             ConfigInfo(folder_name="models", cls=Model, config_type=ConfType.MODEL),
-            ConfigInfo(folder_name="teams_metadata", cls=MetaData, config_type=None), # only for team metadata
+            ConfigInfo(
+                folder_name="teams_metadata", cls=MetaData, config_type=None
+            ), # only for team metadata
         ]
 
         self.compile_status = CompileStatus(use_live=False)
@@ -52,13 +54,12 @@ class CompileContext:
             cls = config_info.cls
             self.existing_confs[cls] = self._parse_existing_confs(cls)
 
-
-
         self.validator: ConfValidator = ConfValidator(
             input_root=self.chronon_root,
             output_root=self.compile_dir,
             existing_gbs=self.existing_confs[GroupBy],
             existing_joins=self.existing_confs[Join],
+            existing_staging_queries=self.existing_confs[StagingQuery],
         )
 
     def input_dir(self, cls: type) -> str:
@@ -93,9 +94,7 @@ class CompileContext:
             return os.path.join(self.chronon_root, self.compile_dir)
         else:
             config_info = self.config_info_for_class(cls)
-            return os.path.join(
-                self.chronon_root, self.compile_dir, config_info.folder_name
-            )
+            return os.path.join(self.chronon_root, self.compile_dir, config_info.folder_name)
 
     def staging_output_path(self, compiled_obj: CompiledObj):
         """
@@ -121,7 +120,6 @@ class CompileContext:
         require(False, f"Class {cls} not found in CONFIG_INFOS")
 
     def _parse_existing_confs(self, obj_class: type) -> Dict[str, object]:
-
         result = {}
 
         output_dir = self.output_dir(obj_class)
@@ -131,9 +129,7 @@ class CompileContext:
             return result
 
         for sub_root, _sub_dirs, sub_files in os.walk(output_dir):
-
             for f in sub_files:
-
                 if f.startswith("."): # ignore hidden files - such as .DS_Store
                     continue
 
@@ -155,7 +151,9 @@ class CompileContext:
                     )
                     self.compile_status.add_existing_object_update_display(compiled_obj)
                 elif isinstance(obj, MetaData):
-                    team_metadata_name = '.'.join(full_path.split('/')[-2:]) # use the name of the file as team metadata won't have name
+                    team_metadata_name = ".".join(
+                        full_path.split("/")[-2:]
+                    ) # use the name of the file as team metadata won't have name
                     result[team_metadata_name] = obj
                     compiled_obj = CompiledObj(
                         name=team_metadata_name,
@@ -167,9 +165,7 @@ class CompileContext:
                     )
                     self.compile_status.add_existing_object_update_display(compiled_obj)
                 else:
-                    logger.errors(
-                        f"Parsed object from {full_path} has no metaData attribute"
-                    )
+                    logger.errors(f"Parsed object from {full_path} has no metaData attribute")
 
             except Exception as e:
                 print(f"Failed to parse file {full_path}: {str(e)}", e)

ai/chronon/cli/compile/compiler.py CHANGED
@@ -4,6 +4,8 @@ import traceback
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple
 
+from gen_thrift.api.ttypes import ConfType
+
 import ai.chronon.cli.compile.display.compiled_obj
 import ai.chronon.cli.compile.parse_configs as parser
 import ai.chronon.cli.logger as logger
@@ -12,7 +14,6 @@ from ai.chronon.cli.compile.compile_context import CompileContext, ConfigInfo
 from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
 from ai.chronon.cli.compile.display.console import console
 from ai.chronon.cli.compile.parse_teams import merge_team_execution_info
-from ai.chronon.orchestration.ttypes import ConfType
 from ai.chronon.types import MetaData
 
 logger = logger.get_logger()
@@ -26,42 +27,72 @@ class CompileResult:
 
 
 class Compiler:
-
     def __init__(self, compile_context: CompileContext):
         self.compile_context = compile_context
 
     def compile(self) -> Dict[ConfType, CompileResult]:
+        # Clean staging directory at the start to ensure fresh compilation
+        staging_dir = self.compile_context.staging_output_dir()
+        if os.path.exists(staging_dir):
+            shutil.rmtree(staging_dir)
 
         config_infos = self.compile_context.config_infos
 
         compile_results = {}
+        all_compiled_objects = [] # Collect all compiled objects for change validation
 
         for config_info in config_infos:
-            configs = self._compile_class_configs(config_info)
-
+            configs, compiled_objects = self._compile_class_configs(config_info)
             compile_results[config_info.config_type] = configs
-        self._compile_team_metadata()
 
-        # check if staging_output_dir exists
-        staging_dir = self.compile_context.staging_output_dir()
-        if os.path.exists(staging_dir):
-            # replace staging_output_dir to output_dir
-            output_dir = self.compile_context.output_dir()
-            if os.path.exists(output_dir):
-                shutil.rmtree(output_dir)
-            shutil.move(staging_dir, output_dir)
-        else:
-            print(
-                f"Staging directory {staging_dir} does not exist. "
-                "Happens when every chronon config fails to compile or when no chronon configs exist."
-            )
+            # Collect compiled objects for change validation
+            all_compiled_objects.extend(compiled_objects)
+
+        # Validate changes once after all classes have been processed
+        self.compile_context.validator.validate_changes(all_compiled_objects)
 
-        # TODO: temporarily just print out the final results of the compile until live fix is implemented:
-        # https://github.com/Textualize/rich/pull/3637
-        console.print(self.compile_context.compile_status.render())
+        # Show the nice display first
+        console.print(
+            self.compile_context.compile_status.render(self.compile_context.ignore_python_errors)
+        )
+
+        # Check for confirmation before finalizing files
+        self.compile_context.validator.check_pending_changes_confirmation(
+            self.compile_context.compile_status
+        )
+
+        # Only proceed with file operations if there are no compilation errors
+        if not self._has_compilation_errors() or self.compile_context.ignore_python_errors:
+            self._compile_team_metadata()
+
+            # check if staging_output_dir exists
+            staging_dir = self.compile_context.staging_output_dir()
+            if os.path.exists(staging_dir):
+                # replace staging_output_dir to output_dir
+                output_dir = self.compile_context.output_dir()
+                if os.path.exists(output_dir):
+                    shutil.rmtree(output_dir)
+                shutil.move(staging_dir, output_dir)
+            else:
+                print(
+                    f"Staging directory {staging_dir} does not exist. "
+                    "Happens when every chronon config fails to compile or when no chronon configs exist."
+                )
+        else:
+            # Clean up staging directory when there are errors (don't move to output)
+            staging_dir = self.compile_context.staging_output_dir()
+            if os.path.exists(staging_dir):
+                shutil.rmtree(staging_dir)
 
         return compile_results
 
+    def _has_compilation_errors(self):
+        """Check if there are any compilation errors across all class trackers."""
+        for tracker in self.compile_context.compile_status.cls_to_tracker.values():
+            if tracker.files_to_errors:
+                return True
+        return False
+
     def _compile_team_metadata(self):
         """
         Compile the team metadata and return the compiled object.
@@ -87,17 +118,14 @@ class Compiler:
         # Done writing team metadata, close the class
         self.compile_context.compile_status.close_cls(MetaData.__name__)
 
-    def _compile_class_configs(self, config_info: ConfigInfo) -> CompileResult:
-
-        compile_result = CompileResult(
-            config_info=config_info, obj_dict={}, error_dict={}
-        )
+    def _compile_class_configs(
+        self, config_info: ConfigInfo
+    ) -> Tuple[CompileResult, List[CompiledObj]]:
+        compile_result = CompileResult(config_info=config_info, obj_dict={}, error_dict={})
 
         input_dir = self.compile_context.input_dir(config_info.cls)
 
-        compiled_objects = parser.from_folder(
-            config_info.cls, input_dir, self.compile_context
-        )
+        compiled_objects = parser.from_folder(config_info.cls, input_dir, self.compile_context)
 
         objects, errors = self._write_objects_in_folder(compiled_objects)
 
@@ -109,20 +137,17 @@
 
         self.compile_context.compile_status.close_cls(config_info.cls.__name__)
 
-        return compile_result
+        return compile_result, compiled_objects
 
     def _write_objects_in_folder(
         self,
         compiled_objects: List[ai.chronon.cli.compile.display.compiled_obj.CompiledObj],
     ) -> Tuple[Dict[str, Any], Dict[str, List[BaseException]]]:
-
         error_dict = {}
         object_dict = {}
 
         for co in compiled_objects:
-
             if co.obj:
-
                 if co.errors:
                     error_dict[co.name] = co.errors
 
@@ -130,9 +155,7 @@
                     self.compile_context.compile_status.print_live_console(
                         f"Error processing conf {co.name}: {error}"
                     )
-                    traceback.print_exception(
-                        type(error), error, error.__traceback__
-                    )
+                    traceback.print_exception(type(error), error, error.__traceback__)
 
             else:
                 self._write_object(co)