awx-zipline-ai 0.2.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Note: this version of awx-zipline-ai has been flagged as potentially problematic.

Files changed (96):
  1. agent/ttypes.py +6 -6
  2. ai/chronon/airflow_helpers.py +20 -23
  3. ai/chronon/cli/__init__.py +0 -0
  4. ai/chronon/cli/compile/__init__.py +0 -0
  5. ai/chronon/cli/compile/column_hashing.py +40 -17
  6. ai/chronon/cli/compile/compile_context.py +13 -17
  7. ai/chronon/cli/compile/compiler.py +59 -36
  8. ai/chronon/cli/compile/conf_validator.py +251 -99
  9. ai/chronon/cli/compile/display/__init__.py +0 -0
  10. ai/chronon/cli/compile/display/class_tracker.py +6 -16
  11. ai/chronon/cli/compile/display/compile_status.py +10 -10
  12. ai/chronon/cli/compile/display/diff_result.py +79 -14
  13. ai/chronon/cli/compile/fill_templates.py +3 -8
  14. ai/chronon/cli/compile/parse_configs.py +10 -17
  15. ai/chronon/cli/compile/parse_teams.py +38 -34
  16. ai/chronon/cli/compile/serializer.py +3 -9
  17. ai/chronon/cli/compile/version_utils.py +42 -0
  18. ai/chronon/cli/git_utils.py +2 -13
  19. ai/chronon/cli/logger.py +0 -2
  20. ai/chronon/constants.py +1 -1
  21. ai/chronon/group_by.py +47 -47
  22. ai/chronon/join.py +46 -32
  23. ai/chronon/logger.py +1 -2
  24. ai/chronon/model.py +9 -4
  25. ai/chronon/query.py +2 -2
  26. ai/chronon/repo/__init__.py +1 -2
  27. ai/chronon/repo/aws.py +17 -31
  28. ai/chronon/repo/cluster.py +121 -50
  29. ai/chronon/repo/compile.py +14 -8
  30. ai/chronon/repo/constants.py +1 -1
  31. ai/chronon/repo/default_runner.py +32 -54
  32. ai/chronon/repo/explore.py +70 -73
  33. ai/chronon/repo/extract_objects.py +6 -9
  34. ai/chronon/repo/gcp.py +89 -88
  35. ai/chronon/repo/gitpython_utils.py +3 -2
  36. ai/chronon/repo/hub_runner.py +145 -55
  37. ai/chronon/repo/hub_uploader.py +2 -1
  38. ai/chronon/repo/init.py +12 -5
  39. ai/chronon/repo/join_backfill.py +19 -5
  40. ai/chronon/repo/run.py +42 -39
  41. ai/chronon/repo/serializer.py +4 -12
  42. ai/chronon/repo/utils.py +72 -63
  43. ai/chronon/repo/zipline.py +3 -19
  44. ai/chronon/repo/zipline_hub.py +211 -39
  45. ai/chronon/resources/__init__.py +0 -0
  46. ai/chronon/resources/gcp/__init__.py +0 -0
  47. ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
  48. ai/chronon/resources/gcp/group_bys/test/data.py +13 -17
  49. ai/chronon/resources/gcp/joins/__init__.py +0 -0
  50. ai/chronon/resources/gcp/joins/test/data.py +4 -8
  51. ai/chronon/resources/gcp/sources/__init__.py +0 -0
  52. ai/chronon/resources/gcp/sources/test/data.py +9 -6
  53. ai/chronon/resources/gcp/teams.py +9 -21
  54. ai/chronon/source.py +2 -4
  55. ai/chronon/staging_query.py +60 -19
  56. ai/chronon/types.py +3 -2
  57. ai/chronon/utils.py +21 -68
  58. ai/chronon/windows.py +2 -4
  59. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.0.dist-info}/METADATA +47 -24
  60. awx_zipline_ai-0.3.0.dist-info/RECORD +96 -0
  61. awx_zipline_ai-0.3.0.dist-info/top_level.txt +4 -0
  62. gen_thrift/__init__.py +0 -0
  63. {ai/chronon → gen_thrift}/api/ttypes.py +327 -197
  64. {ai/chronon/api → gen_thrift}/common/ttypes.py +9 -39
  65. gen_thrift/eval/ttypes.py +660 -0
  66. {ai/chronon → gen_thrift}/hub/ttypes.py +12 -131
  67. {ai/chronon → gen_thrift}/observability/ttypes.py +343 -180
  68. {ai/chronon → gen_thrift}/planner/ttypes.py +326 -45
  69. ai/chronon/eval/__init__.py +0 -122
  70. ai/chronon/eval/query_parsing.py +0 -19
  71. ai/chronon/eval/sample_tables.py +0 -100
  72. ai/chronon/eval/table_scan.py +0 -186
  73. ai/chronon/orchestration/ttypes.py +0 -4406
  74. ai/chronon/resources/gcp/README.md +0 -174
  75. ai/chronon/resources/gcp/zipline-cli-install.sh +0 -54
  76. awx_zipline_ai-0.2.1.dist-info/RECORD +0 -93
  77. awx_zipline_ai-0.2.1.dist-info/licenses/LICENSE +0 -202
  78. awx_zipline_ai-0.2.1.dist-info/top_level.txt +0 -3
  79. /jars/__init__.py → /__init__.py +0 -0
  80. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.0.dist-info}/WHEEL +0 -0
  81. {awx_zipline_ai-0.2.1.dist-info → awx_zipline_ai-0.3.0.dist-info}/entry_points.txt +0 -0
  82. {ai/chronon → gen_thrift}/api/__init__.py +0 -0
  83. {ai/chronon/api/common → gen_thrift/api}/constants.py +0 -0
  84. {ai/chronon/api → gen_thrift}/common/__init__.py +0 -0
  85. {ai/chronon/api → gen_thrift/common}/constants.py +0 -0
  86. {ai/chronon/fetcher → gen_thrift/eval}/__init__.py +0 -0
  87. {ai/chronon/fetcher → gen_thrift/eval}/constants.py +0 -0
  88. {ai/chronon/hub → gen_thrift/fetcher}/__init__.py +0 -0
  89. {ai/chronon/hub → gen_thrift/fetcher}/constants.py +0 -0
  90. {ai/chronon → gen_thrift}/fetcher/ttypes.py +0 -0
  91. {ai/chronon/observability → gen_thrift/hub}/__init__.py +0 -0
  92. {ai/chronon/observability → gen_thrift/hub}/constants.py +0 -0
  93. {ai/chronon/orchestration → gen_thrift/observability}/__init__.py +0 -0
  94. {ai/chronon/orchestration → gen_thrift/observability}/constants.py +0 -0
  95. {ai/chronon → gen_thrift}/planner/__init__.py +0 -0
  96. {ai/chronon → gen_thrift}/planner/constants.py +0 -0
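The dominant structural change in this release is visible in the rename entries above: the generated Thrift modules move from the ai.chronon.* namespace into a new top-level gen_thrift package, and the old ai/chronon/orchestration and ai/chronon/eval modules are removed. A minimal sketch of what the import migration looks like for downstream code, assuming the generated module contents are unchanged by the move; the try/except shim below is illustrative and not part of the package:

# Hypothetical compatibility shim for code that must run against both versions.
try:
    # 0.3.0+: generated Thrift code lives under the top-level gen_thrift package
    import gen_thrift.common.ttypes as common
    from gen_thrift.api.ttypes import Accuracy, Aggregation, Derivation
except ImportError:
    # 0.2.1 and earlier: the same generated modules lived under ai.chronon
    import ai.chronon.api.common.ttypes as common
    from ai.chronon.api.ttypes import Accuracy, Aggregation, Derivation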
ai/chronon/cli/compile/conf_validator.py

@@ -1,5 +1,4 @@
-"""Object for checking whether a Chronon API thrift object is consistent with other
-"""
+"""Object for checking whether a Chronon API thrift object is consistent with other"""

 # Copyright (C) 2023 The Chronon Authors.
 #
@@ -18,12 +17,14 @@
 import json
 import logging
 import re
+import sys
 import textwrap
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import Dict, List, Tuple

-import ai.chronon.api.common.ttypes as common
-from ai.chronon.api.ttypes import (
+import gen_thrift.common.ttypes as common
+from gen_thrift.api.ttypes import (
     Accuracy,
     Aggregation,
     Derivation,
@@ -33,6 +34,7 @@ from ai.chronon.api.ttypes import (
     JoinPart,
     Source,
 )
+
 from ai.chronon.cli.compile.column_hashing import (
     compute_group_by_columns_hashes,
     get_pre_derived_group_by_columns,
@@ -40,6 +42,7 @@ from ai.chronon.cli.compile.column_hashing import (
     get_pre_derived_join_features,
     get_pre_derived_source_keys,
 )
+from ai.chronon.cli.compile.version_utils import is_version_change
 from ai.chronon.logger import get_logger
 from ai.chronon.repo.serializer import thrift_simple_json
 from ai.chronon.utils import get_query, get_root_source
@@ -49,6 +52,20 @@ SKIPPED_FIELDS = frozenset(["metaData"])
 EXTERNAL_KEY = "onlineExternalParts"


+@dataclass
+class ConfigChange:
+    """Represents a change to a compiled config object."""
+
+    name: str
+    obj_type: str
+    online: bool = False
+    production: bool = False
+    base_name: str = None
+    old_version: int = None
+    new_version: int = None
+    is_version_change: bool = False
+
+
 def _filter_skipped_fields_from_join(json_obj: Dict, skipped_fields):
     for join_part in json_obj["joinParts"]:
         group_by = join_part["groupBy"]
@@ -91,7 +108,6 @@ def _group_by_has_hourly_windows(groupBy: GroupBy) -> bool:
         return False

     for agg in aggs:
-
         if not agg.windows:
             return False

@@ -103,9 +119,7 @@ def _group_by_has_hourly_windows(groupBy: GroupBy) -> bool:


 def detect_feature_name_collisions(
-    group_bys: List[Tuple[GroupBy, str]],
-    entity_set_type: str,
-    name: str
+    group_bys: List[Tuple[GroupBy, str]], entity_set_type: str, name: str
 ) -> BaseException | None:
     # Build a map of output_column -> set of group_by name
     output_col_to_gbs = {}
@@ -115,7 +129,7 @@ def detect_feature_name_collisions(
             prefix_str = jp_prefix_str + key_str + "_"
         cols = compute_group_by_columns_hashes(gb, exclude_keys=True)
         if not cols:
-            print('HERE')
+            print("HERE")
         cols = {
             f"{prefix_str}{base_col}"
             for base_col in list(compute_group_by_columns_hashes(gb, exclude_keys=True).keys())
@@ -127,19 +141,13 @@ def detect_feature_name_collisions(
             output_col_to_gbs[col].add(gb_name)

     # Find output columns that are produced by more than one group_by
-    collisions = {
-        col: gb_names
-        for col, gb_names in output_col_to_gbs.items()
-        if len(gb_names) > 1
-    }
+    collisions = {col: gb_names for col, gb_names in output_col_to_gbs.items() if len(gb_names) > 1}

     if not collisions:
         return None  # no collisions

     # Assemble an error message listing the conflicting GroupBys by name
-    lines = [
-        f"{entity_set_type} for Join: {name} has the following output name collisions:\n"
-    ]
+    lines = [f"{entity_set_type} for Join: {name} has the following output name collisions:\n"]
     for col, gb_names in collisions.items():
         names_str = ", ".join(sorted(gb_names))
         lines.append(f" - [{col}] has collisions from: [{names_str}]")
@@ -161,9 +169,9 @@ class ConfValidator(object):
         output_root,
         existing_gbs,
         existing_joins,
+        existing_staging_queries,
         log_level=logging.INFO,
     ):
-
         self.chronon_root_path = input_root
         self.output_root = output_root

@@ -176,8 +184,10 @@ class ConfValidator(object):
         self.old_objs = defaultdict(dict)
         self.old_group_bys = existing_gbs
         self.old_joins = existing_joins
+        self.old_staging_queries = existing_staging_queries
         self.old_objs["GroupBy"] = self.old_group_bys
         self.old_objs["Join"] = self.old_joins
+        self.old_objs["StagingQuery"] = self.old_staging_queries

     def _get_old_obj(self, obj_class: type, obj_name: str) -> object:
         """
@@ -223,9 +233,7 @@ class ConfValidator(object):
         elif not any(
             join.metaData.online for join in self._get_old_joins_with_group_by(obj)
         ) and not _is_batch_upload_needed(obj):
-            reasons.append(
-                "is not marked online/production nor is included in any online join"
-            )
+            reasons.append("is not marked online/production nor is included in any online join")
         return reasons

     def validate_obj(self, obj: object) -> List[BaseException]:
@@ -241,13 +249,9 @@
             return self._validate_join(obj)
         return []

-    def _has_diff(
-        self, obj: object, old_obj: object, skipped_fields=SKIPPED_FIELDS
-    ) -> bool:
+    def _has_diff(self, obj: object, old_obj: object, skipped_fields=SKIPPED_FIELDS) -> bool:
         new_json = {
-            k: v
-            for k, v in json.loads(thrift_simple_json(obj)).items()
-            if k not in skipped_fields
+            k: v for k, v in json.loads(thrift_simple_json(obj)).items() if k not in skipped_fields
         }
         old_json = {
             k: v
@@ -257,6 +261,7 @@
         if isinstance(obj, Join):
             _filter_skipped_fields_from_join(new_json, skipped_fields)
             _filter_skipped_fields_from_join(old_json, skipped_fields)
+
         return new_json != old_json

     def safe_to_overwrite(self, obj: object) -> bool:
@@ -264,11 +269,7 @@
         to materialize and overwrite the old conf.
         """
         old_obj = self._get_old_obj(type(obj), obj.metaData.name)
-        return (
-            not old_obj
-            or not self._has_diff(obj, old_obj)
-            or not old_obj.metaData.online
-        )
+        return not old_obj or not self._has_diff(obj, old_obj) or not old_obj.metaData.online

     def _validate_derivations(
         self, pre_derived_cols: List[str], derivations: List[Derivation]
@@ -295,23 +296,26 @@
             if wild_card_derivation_included:
                 if derivation.expression in derived_columns:
                     derived_columns.remove(derivation.expression)
-            if (
-                derivation.expression not in pre_derived_cols
-                and derivation.expression not in ("ds", "ts")
+            if derivation.expression not in pre_derived_cols and derivation.expression not in (
+                "ds",
+                "ts",
             ):
                 errors.append(
-                    ValueError("Incorrect derivation expression {}, expression not found in pre-derived columns {}"
-                        .format(
-                            derivation.expression, pre_derived_cols
-                        ))
+                    ValueError(
+                        "Incorrect derivation expression {}, expression not found in pre-derived columns {}".format(
+                            derivation.expression, pre_derived_cols
+                        )
+                    )
                 )
             if derivation.name != "*":
                 if derivation.name in derived_columns:
                     errors.append(
-                        ValueError("Incorrect derivation name {} due to output column name conflict".format(
-                            derivation.name
+                        ValueError(
+                            "Incorrect derivation name {} due to output column name conflict".format(
+                                derivation.name
+                            )
                         )
-                    ))
+                    )
                 else:
                     derived_columns.add(derivation.name)
         return errors
@@ -344,17 +348,20 @@ class ConfValidator(object):
             err_string += f"\n- The following keys in your key_mapping: {str(key_map_keys_missing_from_left)} for JoinPart {group_by_name} are not included in the left side of the join: {left_cols_as_str}"

         # Right side of key mapping should only include keys in GroupBy
-        keys_missing_from_key_map_values = [v for v in key_mapping.values() if v not in join_part.groupBy.keyColumns]
+        keys_missing_from_key_map_values = [
+            v for v in key_mapping.values() if v not in join_part.groupBy.keyColumns
+        ]
         if keys_missing_from_key_map_values:
             err_string += f"\n- The following values in your key_mapping: {str(keys_missing_from_key_map_values)} for JoinPart {group_by_name} do not cover any group by key columns: {join_part.groupBy.keyColumns}"

         if key_map_keys_missing_from_left or keys_missing_from_key_map_values:
-            err_string += "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"
+            err_string += (
+                "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"
+            )

         if err_string:
             return ValueError(err_string)

-
     def _validate_keys(self, join: Join) -> List[BaseException]:
         left = join.left

@@ -375,7 +382,7 @@
         errors = []

         if left_cols:
-            join_parts = join.joinParts
+            join_parts = list(join.joinParts)  # Create a copy to avoid modifying the original

             # Add label_parts to join_parts to validate if set
             label_parts = join.labelParts
@@ -391,7 +398,6 @@

         return errors

-
     def _validate_join(self, join: Join) -> List[BaseException]:
         """
         Validate join's status with materialized versions of group_bys
@@ -402,7 +408,9 @@
         """
         included_group_bys_and_prefixes = [(rp.groupBy, rp.prefix) for rp in join.joinParts]
         # TODO: Remove label parts check in future PR that deprecates label_parts
-        included_label_parts_and_prefixes = [(lp.groupBy, lp.prefix) for lp in join.labelParts.labels] if join.labelParts else []
+        included_label_parts_and_prefixes = (
+            [(lp.groupBy, lp.prefix) for lp in join.labelParts.labels] if join.labelParts else []
+        )
         included_group_bys = [tup[0] for tup in included_group_bys_and_prefixes]

         offline_included_group_bys = [
@@ -422,9 +430,7 @@
             if group_by.metaData.production is False
         ]
         # Check if the underlying groupBy is valid
-        group_by_errors = [
-            self._validate_group_by(group_by) for group_by in included_group_bys
-        ]
+        group_by_errors = [self._validate_group_by(group_by) for group_by in included_group_bys]
         errors += [
             ValueError(f"join {join.metaData.name}'s underlying {error}")
             for errors in group_by_errors
@@ -433,18 +439,22 @@
         # Check if the production join is using non production groupBy
         if join.metaData.production and non_prod_old_group_bys:
             errors.append(
-                ValueError("join {} is production but includes the following non production group_bys: {}".format(
-                    join.metaData.name, ", ".join(non_prod_old_group_bys)
+                ValueError(
+                    "join {} is production but includes the following non production group_bys: {}".format(
+                        join.metaData.name, ", ".join(non_prod_old_group_bys)
+                    )
                 )
-            ))
+            )
         # Check if the online join is using the offline groupBy
         if join.metaData.online:
             if offline_included_group_bys:
                 errors.append(
-                    ValueError("join {} is online but includes the following offline group_bys: {}".format(
-                        join.metaData.name, ", ".join(offline_included_group_bys)
+                    ValueError(
+                        "join {} is online but includes the following offline group_bys: {}".format(
+                            join.metaData.name, ", ".join(offline_included_group_bys)
+                        )
                     )
-                ))
+                )
         # Only validate the join derivation when the underlying groupBy is valid
         group_by_correct = all(not errors for errors in group_by_errors)
         if join.derivations and group_by_correct:
@@ -457,22 +467,23 @@
             columns = features + keys
             errors.extend(self._validate_derivations(columns, join.derivations))

-
         errors.extend(self._validate_keys(join))

         # If the join is using "short" names, ensure that there are no collisions
         if join.useLongNames is False:
-            right_part_collisions = detect_feature_name_collisions(included_group_bys_and_prefixes, "right parts", join.metaData.name)
+            right_part_collisions = detect_feature_name_collisions(
+                included_group_bys_and_prefixes, "right parts", join.metaData.name
+            )
             if right_part_collisions:
                 errors.append(right_part_collisions)

-            label_part_collisions = detect_feature_name_collisions(included_label_parts_and_prefixes, "label parts", join.metaData.name)
+            label_part_collisions = detect_feature_name_collisions(
+                included_label_parts_and_prefixes, "label parts", join.metaData.name
+            )
             if label_part_collisions:
                 errors.append(label_part_collisions)

         return errors
-
-

     def _validate_group_by(self, group_by: GroupBy) -> List[BaseException]:
         """
@@ -483,17 +494,11 @@
         List of validation errors.
         """
         joins = self._get_old_joins_with_group_by(group_by)
-        online_joins = [
-            join.metaData.name for join in joins if join.metaData.online is True
-        ]
-        prod_joins = [
-            join.metaData.name for join in joins if join.metaData.production is True
-        ]
+        online_joins = [join.metaData.name for join in joins if join.metaData.online is True]
+        prod_joins = [join.metaData.name for join in joins if join.metaData.production is True]
         errors = []

-        non_temporal = (
-            group_by.accuracy is None or group_by.accuracy == Accuracy.SNAPSHOT
-        )
+        non_temporal = group_by.accuracy is None or group_by.accuracy == Accuracy.SNAPSHOT

         no_topic = not _group_by_has_topic(group_by)
         has_hourly_windows = _group_by_has_hourly_windows(group_by)
@@ -501,48 +506,55 @@
         # batch features cannot contain hourly windows
         if (no_topic and non_temporal) and has_hourly_windows:
             errors.append(
-                ValueError(f"group_by {group_by.metaData.name} is defined to be daily refreshed but contains "
-                           f"hourly windows. "
-                )
+                ValueError(
+                    f"group_by {group_by.metaData.name} is defined to be daily refreshed but contains "
+                    f"hourly windows. "
+                )
             )

         def _validate_bounded_event_source():
             if group_by.aggregations is None:
                 return
-
+
             unbounded_event_sources = [
                 str(src)
-                for src in group_by.sources
-                if isinstance(get_root_source(src), EventSource) and get_query(src).startPartition is None
+                for src in group_by.sources
+                if isinstance(get_root_source(src), EventSource)
+                and get_query(src).startPartition is None
             ]

             if not unbounded_event_sources:
                 return
-
-            unwindowed_aggregations = [str(agg) for agg in group_by.aggregations if agg.windows is None]
+
+            unwindowed_aggregations = [
+                str(agg) for agg in group_by.aggregations if agg.windows is None
+            ]

             if not unwindowed_aggregations:
                 return
-
+
             nln = "\n"

             errors.append(
                 ValueError(
                     f"""group_by {group_by.metaData.name} uses unwindowed aggregations [{nln}{f",{nln}".join(unwindowed_aggregations)}{nln}]
                     on unbounded event sources: [{nln}{f",{nln}".join(unbounded_event_sources)}{nln}].
-                    Please set a start_partition on the source, or a window on the aggregation.""")
+                    Please set a start_partition on the source, or a window on the aggregation."""
                 )
-
+            )
+
         _validate_bounded_event_source()

         # group by that are marked explicitly offline should not be present in
         # materialized online joins.
         if group_by.metaData.online is False and online_joins:
             errors.append(
-                ValueError("group_by {} is explicitly marked offline but included in "
-                    "the following online joins: {}".format(
-                        group_by.metaData.name, ", ".join(online_joins)
-                ))
+                ValueError(
+                    "group_by {} is explicitly marked offline but included in "
+                    "the following online joins: {}".format(
+                        group_by.metaData.name, ", ".join(online_joins)
+                    )
+                )
             )
         # group by that are marked explicitly non-production should not be
         # present in materialized production joins.
@@ -551,8 +563,9 @@
             errors.append(
                 ValueError(
                     "group_by {} is explicitly marked as non-production but included in the following production "
-                    "joins: {}".format(group_by.metaData.name, ", ".join(prod_joins))
-                ))
+                    "joins: {}".format(group_by.metaData.name, ", ".join(prod_joins))
+                )
+            )
         # if the group by is included in any of materialized production join,
         # set it to production in the materialized output.
         else:
@@ -569,22 +582,161 @@

         for source in group_by.sources:
             src: Source = source
-            if (
-                src.events
-                and src.events.isCumulative
-                and (src.events.query.timeColumn is None)
-            ):
+            if src.events and src.events.isCumulative and (src.events.query.timeColumn is None):
                 errors.append(
-                    ValueError("Please set query.timeColumn for Cumulative Events Table: {}".format(
-                        src.events.table
-                    ))
+                    ValueError(
+                        "Please set query.timeColumn for Cumulative Events Table: {}".format(
+                            src.events.table
+                        )
+                    )
                 )
-            elif (
-                src.joinSource
-            ):
+            elif src.joinSource:
                 join_obj = src.joinSource.join
                 if join_obj.metaData.name is None or join_obj.metaData.team is None:
                     errors.append(
                         ValueError(f"Join must be defined with team and name: {join_obj}")
                     )
         return errors
+
+    def validate_changes(self, results):
+        """
+        Validate changes in compiled objects against existing materialized objects.
+
+        Args:
+            results: List of CompiledObj from parse_configs
+
+        Returns:
+            None (exits on user cancellation)
+        """
+
+        # Filter out results with errors and only process GroupBy/Join
+        valid_results = []
+        for result in results:
+            if result.errors:
+                continue  # Skip results with errors
+            if result.obj_type not in ["GroupBy", "Join", "StagingQuery"]:
+                continue  # Skip non-GroupBy/Join objects
+            valid_results.append(result)

+        # Categorize changes
+        changed_objects = {"changed": [], "deleted": [], "added": []}
+
+        # Process each valid result
+        for result in valid_results:
+            obj_name = result.obj.metaData.name
+            obj_type = result.obj_type
+
+            # Check if object exists in old objects
+            old_obj = self._get_old_obj(type(result.obj), obj_name)
+
+            if old_obj is None:
+                # New object
+                change = self._create_config_change(result.obj, obj_type)
+                changed_objects["added"].append(change)
+            elif self._has_diff(result.obj, old_obj):
+                # Modified object
+                change = self._create_config_change(result.obj, obj_type)
+                changed_objects["changed"].append(change)
+
+        # Check for deleted objects
+        # Build set of current objects from valid_results for lookup
+        current_objects = {(result.obj_type, result.obj.metaData.name) for result in valid_results}
+
+        for obj_type in ["GroupBy", "Join", "StagingQuery"]:
+            old_objs = self.old_objs.get(obj_type, {})
+            for obj_name, old_obj in old_objs.items():
+                if (obj_type, obj_name) not in current_objects:
+                    # Object was deleted
+                    change = self._create_config_change(old_obj, obj_type)
+                    changed_objects["deleted"].append(change)
+
+        # Store changes for later confirmation check
+        self._pending_changes = {
+            "changed": changed_objects["changed"],
+            "deleted": changed_objects["deleted"],
+            "added": changed_objects["added"],
+        }
+
+    def _filter_non_version_changes(self, existing_changes, added_changes):
+        """Filter out version changes from existing changes.
+
+        Returns list of changes that are NOT version bumps and require confirmation.
+        """
+        added_names = {change.name for change in added_changes}
+        non_version_changes = []
+
+        for change in existing_changes:
+            # Check if this deleted config has a corresponding added config that represents a version bump
+            is_version_bump = any(
+                is_version_change(change.name, added_name) for added_name in added_names
+            )
+
+            if not is_version_bump:
+                non_version_changes.append(change)
+
+        return non_version_changes
+
+    def check_pending_changes_confirmation(self, compile_status):
+        """Check if user confirmation is needed for pending changes after display."""
+        from ai.chronon.cli.compile.display.console import console
+
+        # Skip confirmation if there are compilation errors
+        if self._has_compilation_errors(compile_status):
+            return  # Don't prompt when there are errors
+
+        if not hasattr(self, "_pending_changes"):
+            return  # No pending changes
+
+        # Check if we need user confirmation (only for non-version-bump changes)
+        non_version_changes = self._filter_non_version_changes(
+            self._pending_changes["changed"] + self._pending_changes["deleted"],
+            self._pending_changes["added"],
+        )
+
+        if non_version_changes:
+            if not self._prompt_user_confirmation():
+                console.print("❌ Compilation cancelled by user.")
+                sys.exit(1)
+
+    def _has_compilation_errors(self, compile_status):
+        """Check if there are any compilation errors across all class trackers."""
+        for tracker in compile_status.cls_to_tracker.values():
+            if tracker.files_to_errors:
+                return True
+        return False
+
+    def _create_config_change(self, obj, obj_type):
+        """Create a ConfigChange object from a thrift object."""
+        return ConfigChange(
+            name=obj.metaData.name,
+            obj_type=obj_type,
+            online=obj.metaData.online if obj.metaData.online else False,
+            production=obj.metaData.production if obj.metaData.production else False,
+        )
+
+    def _prompt_user_confirmation(self) -> bool:
+        """
+        Prompt user for Y/N confirmation to proceed with overwriting existing configs.
+        Returns True if user confirms, False otherwise.
+        """
+        from ai.chronon.cli.compile.display.console import console
+
+        console.print(
+            "\n❗ [bold red3]Some configs are changing in-place (changing semantics without changing the version).[/bold red3]"
+        )
+        console.print(
+            "[dim]Note that changes can be caused by directly modifying a config, or by changing the version of an upstream object which changes the input source.[/dim]"
+        )
+        console.print("")
+        console.print(
+            "[dim]This can be a safe operation if you are sure that the configs that are unaffected by the change (i.e. upstream config is versioning up in a way that does not effect the downstream), or if the objects in question are not yet used by critical workloads.[/dim]"
+        )
+        console.print("")
+        console.print("❓ [bold]Do you want to proceed? (y/N):[/bold]", end=" ")
+
+        try:
+            response = input().strip().lower()
+            return response in ["y", "yes"]
+        except (EOFError, KeyboardInterrupt):
+            console.print("\n❌ Compilation cancelled.")
+            return False
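The bulk of the new surface in conf_validator.py is this change-tracking flow: validate_changes buckets compiled configs into added/changed/deleted, _filter_non_version_changes uses the new is_version_change helper to exempt plain version bumps, and check_pending_changes_confirmation prompts before anything changes semantics in place. The diff does not include the body of version_utils.is_version_change (version_utils.py is a new file, +42 lines), so the sketch below is only an assumption about its behavior, inferred from how it is called with an old and a new config name and from the base_name/old_version/new_version fields on ConfigChange; the trailing "__N" suffix convention is hypothetical:

# Assumed shape of version_utils.is_version_change: true when two config
# names differ only by a trailing numeric version suffix (convention assumed,
# not confirmed by this diff).
import re

_VERSIONED = re.compile(r"^(?P<base>.+)__(?P<version>\d+)$")

def is_version_change(old_name: str, new_name: str) -> bool:
    old_m = _VERSIONED.match(old_name)
    new_m = _VERSIONED.match(new_name)
    if not (old_m and new_m):
        return False
    # Same base name, different version => a version bump, not an in-place edit
    return (
        old_m.group("base") == new_m.group("base")
        and old_m.group("version") != new_m.group("version")
    )

Under this reading, a deleted team.checkout.features__0 paired with an added team.checkout.features__1 would be classified as a version bump and skipped by _filter_non_version_changes, so no y/N prompt fires; any other in-place change or deletion triggers the confirmation prompt above.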
ai/chronon/cli/compile/display/class_tracker.py
@@ -24,9 +24,7 @@ class ClassTracker:
         self.existing_objs[obj.name] = obj

     def add(self, compiled: CompiledObj) -> None:
-
         if compiled.errors:
-
             if compiled.file not in self.files_to_errors:
                 self.files_to_errors[compiled.file] = []

@@ -43,12 +41,10 @@ class ClassTracker:

     def _update_diff(self, compiled: CompiledObj) -> None:
         if compiled.name in self.existing_objs:
-
             existing_json = self.existing_objs[compiled.name].tjson
             new_json = compiled.tjson

             if existing_json != new_json:
-
                 diff = difflib.unified_diff(
                     existing_json.splitlines(keepends=True),
                     new_json.splitlines(keepends=True),
@@ -74,16 +70,7 @@ class ClassTracker:
         text = Text(overflow="fold", no_wrap=False)

         if self.existing_objs:
-            text.append(
-                f" Parsed {len(self.existing_objs)} previously compiled objects.\n"
-            )
-
-            if self.files_to_obj:
-                text.append(" Compiled ")
-                text.append(f"{len(self.new_objs)} ", style="bold green")
-                text.append("objects from ")
-                text.append(f"{len(self.files_to_obj)} ", style="bold green")
-                text.append("files.\n")
+            text.append(f" Parsed {len(self.existing_objs)} previously compiled objects.\n")

         if self.files_to_errors:
             text.append(" Failed to compile ")
@@ -99,7 +86,7 @@ class ClassTracker:
         for file, errors in self.files_to_errors.items():
             text.append(" ERROR ", style="bold red")
             text.append(f"- {file}:\n")
-
+
             for error in errors:
                 # Format each error properly, handling newlines
                 error_msg = str(error)
@@ -108,5 +95,8 @@
         return text

     # doesn't make sense to show deletes until the very end of compilation
-    def diff(self) -> Text:
+    def diff(self, ignore_python_errors: bool = False) -> Text:
+        # Don't show diff if there are compile errors - it's confusing
+        if self.files_to_errors and not ignore_python_errors:
+            return Text("\n❗Please fix python errors then retry compilation.\n", style="dim cyan")
         return self.diff_result.render(deleted_names=self.deleted_names)
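Net effect of the ClassTracker change: when any file failed to compile, diff() now returns a fix-your-errors notice instead of a config diff, with an opt-out flag for callers that want the diff anyway. A small usage sketch, where `tracker` stands in for a ClassTracker instance populated during compilation:

# Default call: if tracker.files_to_errors is non-empty, this returns the
# "Please fix python errors then retry compilation." notice, not a diff.
summary = tracker.diff()

# Opt-out: render the diff even while python errors are outstanding.
full_diff = tracker.diff(ignore_python_errors=True)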