awx-zipline-ai 0.0.32 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96)
  1. __init__.py +0 -0
  2. agent/__init__.py +1 -0
  3. agent/constants.py +15 -0
  4. agent/ttypes.py +1684 -0
  5. ai/__init__.py +0 -0
  6. ai/chronon/__init__.py +0 -0
  7. ai/chronon/airflow_helpers.py +248 -0
  8. ai/chronon/cli/__init__.py +0 -0
  9. ai/chronon/cli/compile/__init__.py +0 -0
  10. ai/chronon/cli/compile/column_hashing.py +336 -0
  11. ai/chronon/cli/compile/compile_context.py +173 -0
  12. ai/chronon/cli/compile/compiler.py +183 -0
  13. ai/chronon/cli/compile/conf_validator.py +742 -0
  14. ai/chronon/cli/compile/display/__init__.py +0 -0
  15. ai/chronon/cli/compile/display/class_tracker.py +102 -0
  16. ai/chronon/cli/compile/display/compile_status.py +95 -0
  17. ai/chronon/cli/compile/display/compiled_obj.py +12 -0
  18. ai/chronon/cli/compile/display/console.py +3 -0
  19. ai/chronon/cli/compile/display/diff_result.py +111 -0
  20. ai/chronon/cli/compile/fill_templates.py +35 -0
  21. ai/chronon/cli/compile/parse_configs.py +134 -0
  22. ai/chronon/cli/compile/parse_teams.py +242 -0
  23. ai/chronon/cli/compile/serializer.py +109 -0
  24. ai/chronon/cli/compile/version_utils.py +42 -0
  25. ai/chronon/cli/git_utils.py +145 -0
  26. ai/chronon/cli/logger.py +59 -0
  27. ai/chronon/constants.py +3 -0
  28. ai/chronon/group_by.py +692 -0
  29. ai/chronon/join.py +580 -0
  30. ai/chronon/logger.py +23 -0
  31. ai/chronon/model.py +40 -0
  32. ai/chronon/query.py +126 -0
  33. ai/chronon/repo/__init__.py +39 -0
  34. ai/chronon/repo/aws.py +284 -0
  35. ai/chronon/repo/cluster.py +136 -0
  36. ai/chronon/repo/compile.py +62 -0
  37. ai/chronon/repo/constants.py +164 -0
  38. ai/chronon/repo/default_runner.py +269 -0
  39. ai/chronon/repo/explore.py +418 -0
  40. ai/chronon/repo/extract_objects.py +134 -0
  41. ai/chronon/repo/gcp.py +586 -0
  42. ai/chronon/repo/gitpython_utils.py +15 -0
  43. ai/chronon/repo/hub_runner.py +261 -0
  44. ai/chronon/repo/hub_uploader.py +109 -0
  45. ai/chronon/repo/init.py +60 -0
  46. ai/chronon/repo/join_backfill.py +119 -0
  47. ai/chronon/repo/run.py +296 -0
  48. ai/chronon/repo/serializer.py +133 -0
  49. ai/chronon/repo/team_json_utils.py +46 -0
  50. ai/chronon/repo/utils.py +481 -0
  51. ai/chronon/repo/zipline.py +35 -0
  52. ai/chronon/repo/zipline_hub.py +277 -0
  53. ai/chronon/resources/__init__.py +0 -0
  54. ai/chronon/resources/gcp/__init__.py +0 -0
  55. ai/chronon/resources/gcp/group_bys/__init__.py +0 -0
  56. ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
  57. ai/chronon/resources/gcp/group_bys/test/data.py +30 -0
  58. ai/chronon/resources/gcp/joins/__init__.py +0 -0
  59. ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
  60. ai/chronon/resources/gcp/joins/test/data.py +26 -0
  61. ai/chronon/resources/gcp/sources/__init__.py +0 -0
  62. ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
  63. ai/chronon/resources/gcp/sources/test/data.py +26 -0
  64. ai/chronon/resources/gcp/teams.py +58 -0
  65. ai/chronon/source.py +86 -0
  66. ai/chronon/staging_query.py +226 -0
  67. ai/chronon/types.py +58 -0
  68. ai/chronon/utils.py +510 -0
  69. ai/chronon/windows.py +48 -0
  70. awx_zipline_ai-0.0.32.dist-info/METADATA +197 -0
  71. awx_zipline_ai-0.0.32.dist-info/RECORD +96 -0
  72. awx_zipline_ai-0.0.32.dist-info/WHEEL +5 -0
  73. awx_zipline_ai-0.0.32.dist-info/entry_points.txt +2 -0
  74. awx_zipline_ai-0.0.32.dist-info/top_level.txt +4 -0
  75. gen_thrift/__init__.py +0 -0
  76. gen_thrift/api/__init__.py +1 -0
  77. gen_thrift/api/constants.py +15 -0
  78. gen_thrift/api/ttypes.py +3754 -0
  79. gen_thrift/common/__init__.py +1 -0
  80. gen_thrift/common/constants.py +15 -0
  81. gen_thrift/common/ttypes.py +1814 -0
  82. gen_thrift/eval/__init__.py +1 -0
  83. gen_thrift/eval/constants.py +15 -0
  84. gen_thrift/eval/ttypes.py +660 -0
  85. gen_thrift/fetcher/__init__.py +1 -0
  86. gen_thrift/fetcher/constants.py +15 -0
  87. gen_thrift/fetcher/ttypes.py +127 -0
  88. gen_thrift/hub/__init__.py +1 -0
  89. gen_thrift/hub/constants.py +15 -0
  90. gen_thrift/hub/ttypes.py +1109 -0
  91. gen_thrift/observability/__init__.py +1 -0
  92. gen_thrift/observability/constants.py +15 -0
  93. gen_thrift/observability/ttypes.py +2355 -0
  94. gen_thrift/planner/__init__.py +1 -0
  95. gen_thrift/planner/constants.py +15 -0
  96. gen_thrift/planner/ttypes.py +1967 -0
ai/__init__.py ADDED
File without changes
ai/chronon/__init__.py ADDED
File without changes
ai/chronon/airflow_helpers.py ADDED
@@ -0,0 +1,248 @@
+ import json
+ import math
+ from typing import OrderedDict
+
+ from gen_thrift.api.ttypes import GroupBy, Join
+ from gen_thrift.common.ttypes import TimeUnit
+
+ import ai.chronon.utils as utils
+ from ai.chronon.constants import (
+     AIRFLOW_DEPENDENCIES_KEY,
+     AIRFLOW_LABEL_DEPENDENCIES_KEY,
+     PARTITION_COLUMN_KEY,
+ )
+
+
+ def create_airflow_dependency(table, partition_column, additional_partitions=None, offset=0):
+     """
+     Create an Airflow dependency object for a table.
+
+     Args:
+         table: The table name (with namespace)
+         partition_column: The partition column to use (defaults to 'ds')
+         additional_partitions: Additional partitions to include in the dependency
+
+     Returns:
+         A dictionary with name and spec for the Airflow dependency
+     """
+     assert (
+         partition_column is not None
+     ), """Partition column must be provided via the spark.chronon.partition.column
+     config. This can be set as a default in teams.py, or at the individual config level. For example:
+     ```
+     Team(
+         conf=ConfigProperties(
+             common={
+                 "spark.chronon.partition.column": "_test_column",
+             }
+         )
+     )
+     ```
+     """
+
+     additional_partitions_str = ""
+     if additional_partitions:
+         additional_partitions_str = "/" + "/".join(additional_partitions)
+
+     return {
+         "name": f"wf_{utils.sanitize(table)}_with_offset_{offset}",
+         "spec": f"{table}/{partition_column}={{{{ macros.ds_add(ds, {offset}) }}}}{additional_partitions_str}",
+     }
+
+
+ def _get_partition_col_from_query(query):
+     """Gets partition column from query if available"""
+     if query:
+         return query.partitionColumn
+     return None
+
+
+ def _get_additional_subPartitionsToWaitFor_from_query(query):
+     """Gets additional subPartitionsToWaitFor from query if available"""
+     if query:
+         return query.subPartitionsToWaitFor
+     return None
+
+
+ def _get_airflow_deps_from_source(source, partition_column=None):
+     """
+     Given a source, return a list of Airflow dependencies.
+
+     Args:
+         source: The source object (events, entities, or joinSource)
+         partition_column: The partition column to use
+
+     Returns:
+         A list of Airflow dependency objects
+     """
+     tables = []
+     additional_partitions = None
+     # Assumes source has already been normalized
+     if source.events:
+         tables = [source.events.table]
+         # Use partition column from query if available, otherwise use the provided one
+         source_partition_column, additional_partitions = (
+             _get_partition_col_from_query(source.events.query) or partition_column,
+             _get_additional_subPartitionsToWaitFor_from_query(source.events.query),
+         )
+
+     elif source.entities:
+         # Given the setup of Query, we currently mandate the same partition column for snapshot and mutations tables
+         tables = [source.entities.snapshotTable]
+         if source.entities.mutationTable:
+             tables.append(source.entities.mutationTable)
+         source_partition_column, additional_partitions = (
+             _get_partition_col_from_query(source.entities.query) or partition_column,
+             _get_additional_subPartitionsToWaitFor_from_query(source.entities.query),
+         )
+     elif source.joinSource:
+         # TODO: Handle joinSource -- it doesn't work right now because the metadata isn't set on joinSource at this point
+         return []
+     else:
+         # Unknown source type
+         return []
+
+     return [
+         create_airflow_dependency(table, source_partition_column, additional_partitions)
+         for table in tables
+     ]
+
+
+ def extract_default_partition_column(obj):
+     try:
+         return obj.metaData.executionInfo.conf.common.get("spark.chronon.partition.column")
+     except Exception:
+         # Error handling occurs in `create_airflow_dependency`
+         return None
+
+
+ def _get_distinct_day_windows(group_by):
+     windows = []
+     aggs = group_by.aggregations
+     if aggs:
+         for agg in aggs:
+             for window in agg.windows:
+                 time_unit = window.timeUnit
+                 length = window.length
+                 if time_unit == TimeUnit.DAYS:
+                     windows.append(length)
+                 elif time_unit == TimeUnit.HOURS:
+                     windows.append(math.ceil(length / 24))
+                 elif time_unit == TimeUnit.MINUTES:
+                     windows.append(math.ceil(length / (24 * 60)))
+     return set(windows)
+
+
+ def _set_join_deps(join):
+     default_partition_col = extract_default_partition_column(join)
+
+     deps = []
+
+     # Handle left source
+     left_query = utils.get_query(join.left)
+     left_partition_column = _get_partition_col_from_query(left_query) or default_partition_col
+     deps.extend(_get_airflow_deps_from_source(join.left, left_partition_column))
+
+     # Handle right parts (join parts)
+     if join.joinParts:
+         for join_part in join.joinParts:
+             if join_part.groupBy and join_part.groupBy.sources:
+                 for source in join_part.groupBy.sources:
+                     source_query = utils.get_query(source)
+                     source_partition_column = (
+                         _get_partition_col_from_query(source_query) or default_partition_col
+                     )
+                     deps.extend(_get_airflow_deps_from_source(source, source_partition_column))
+
+     label_deps = []
+     # Handle label parts
+     if join.labelParts and join.labelParts.labels:
+         join_output_table = utils.output_table_name(join, full_name=True)
+         partition_column = join.metaData.executionInfo.conf.common[PARTITION_COLUMN_KEY]
+
+         # set the dependencies on the label sources
+         for label_part in join.labelParts.labels:
+             group_by = label_part.groupBy
+
+             # set the dependency on the join output -- one for each distinct window offset
+             windows = _get_distinct_day_windows(group_by)
+             for window in windows:
+                 label_deps.append(
+                     create_airflow_dependency(
+                         join_output_table, partition_column, offset=-1 * window
+                     )
+                 )
+
+             if group_by and group_by.sources:
+                 for source in label_part.groupBy.sources:
+                     source_query = utils.get_query(source)
+                     source_partition_column = (
+                         _get_partition_col_from_query(source_query) or default_partition_col
+                     )
+                     label_deps.extend(
+                         _get_airflow_deps_from_source(source, source_partition_column)
+                     )
+
+     # Update the metadata customJson with dependencies
+     _dedupe_and_set_airflow_deps_json(join, deps, AIRFLOW_DEPENDENCIES_KEY)
+
+     # Update the metadata customJson with label join deps
+     if label_deps:
+         _dedupe_and_set_airflow_deps_json(join, label_deps, AIRFLOW_LABEL_DEPENDENCIES_KEY)
+
+     # Set the t/f flag for label_join
+     _set_label_join_flag(join)
+
+
+ def _set_group_by_deps(group_by):
+     if not group_by.sources:
+         return
+
+     default_partition_col = extract_default_partition_column(group_by)
+
+     deps = []
+
+     # Process each source in the group_by
+     for source in group_by.sources:
+         source_query = utils.get_query(source)
+         source_partition_column = (
+             _get_partition_col_from_query(source_query) or default_partition_col
+         )
+         deps.extend(_get_airflow_deps_from_source(source, source_partition_column))
+
+     # Update the metadata customJson with dependencies
+     _dedupe_and_set_airflow_deps_json(group_by, deps, AIRFLOW_DEPENDENCIES_KEY)
+
+
+ def _set_label_join_flag(join):
+     existing_json = join.metaData.customJson or "{}"
+     json_map = json.loads(existing_json)
+     label_join_flag = False
+     if join.labelParts:
+         label_join_flag = True
+     json_map["label_join"] = label_join_flag
+     join.metaData.customJson = json.dumps(json_map)
+
+
+ def _dedupe_and_set_airflow_deps_json(obj, deps, custom_json_key):
+     sorted_items = [tuple(sorted(d.items())) for d in deps]
+     # Use OrderedDict for re-producible ordering of dependencies
+     unique = [OrderedDict(t) for t in sorted_items]
+     existing_json = obj.metaData.customJson or "{}"
+     json_map = json.loads(existing_json)
+     json_map[custom_json_key] = unique
+     obj.metaData.customJson = json.dumps(json_map)
+
+
+ def set_airflow_deps(obj):
+     """
+     Set Airflow dependencies for a Chronon object.
+
+     Args:
+         obj: A Join, GroupBy
+     """
+     # StagingQuery dependency setting is handled directly in object init
+     if isinstance(obj, Join):
+         _set_join_deps(obj)
+     elif isinstance(obj, GroupBy):
+         _set_group_by_deps(obj)
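
For orientation, the dependency objects built in this file are plain dicts with a `name` and an Airflow sensor `spec`, which get serialized into the config's `customJson`. A minimal usage sketch of the expected shape follows; the table and partition column names are made-up examples, not values shipped in this package:

```python
# Hypothetical usage sketch of create_airflow_dependency from the file above.
# "data.purchases" and "ds" are illustrative names, not package defaults.
from ai.chronon.airflow_helpers import create_airflow_dependency

dep = create_airflow_dependency("data.purchases", "ds", offset=-1)
# Expected shape (the exact "name" value depends on utils.sanitize):
# {
#     "name": "wf_data_purchases_with_offset_-1",
#     "spec": "data.purchases/ds={{ macros.ds_add(ds, -1) }}",
# }
print(dep["spec"])
```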
ai/chronon/cli/__init__.py ADDED
File without changes
ai/chronon/cli/compile/__init__.py ADDED
File without changes
ai/chronon/cli/compile/column_hashing.py ADDED
@@ -0,0 +1,336 @@
+ import hashlib
+ import re
+ from collections import defaultdict
+ from typing import Dict, List
+
+ from gen_thrift.api.ttypes import Derivation, ExternalPart, GroupBy, Join, Source
+
+ from ai.chronon.group_by import get_output_col_names
+
+
+ # Returns a map of output column to semantic hash, including derivations
+ def compute_group_by_columns_hashes(
+     group_by: GroupBy, exclude_keys: bool = False
+ ) -> Dict[str, str]:
+     """
+     From the group_by object, get the final output columns after derivations.
+     """
+     # Get the output columns and their input expressions
+     output_to_input = get_pre_derived_group_by_columns(group_by)
+
+     # Get the base semantic fields that apply to all columns
+     base_semantics = []
+     for source in group_by.sources:
+         base_semantics.extend(_extract_source_semantic_info(source, group_by.keyColumns))
+
+     # Add the major version to semantic fields
+     group_by_minor_version_suffix = f"__{group_by.metaData.version}"
+     group_by_major_version = group_by.metaData.name
+     if group_by_major_version.endswith(group_by_minor_version_suffix):
+         group_by_major_version = group_by_major_version[: -len(group_by_minor_version_suffix)]
+     base_semantics.append(f"group_by_name:{group_by_major_version}")
+
+     # Compute the semantic hash for each output column
+     output_to_hash = {}
+     for output_col, input_expr in output_to_input.items():
+         semantic_components = base_semantics + [f"input_expr:{input_expr}"]
+         semantic_hash = _compute_semantic_hash(semantic_components)
+         output_to_hash[output_col] = semantic_hash
+
+     if exclude_keys:
+         output = {k: v for k, v in output_to_hash.items() if k not in group_by.keyColumns}
+     else:
+         output = output_to_hash
+     if group_by.derivations:
+         derived = build_derived_columns(output, group_by.derivations, base_semantics)
+         if not exclude_keys:
+             # We need to add keys back at this point
+             for key in group_by.keyColumns:
+                 if key not in derived:
+                     derived[key] = output_to_input.get(key)
+         return derived
+     else:
+         return output
+
+
+ def compute_join_column_hashes(join: Join) -> Dict[str, str]:
+     """
+     From the join object, get the final output columns -> semantic hash after derivations.
+     """
+     # Get the base semantics from the left side (table and key expression)
+     base_semantic_fields = []
+
+     output_columns = get_pre_derived_join_features(join) | get_pre_derived_source_keys(join.left)
+
+     if join.derivations:
+         return build_derived_columns(output_columns, join.derivations, base_semantic_fields)
+     else:
+         return output_columns
+
+
+ def get_pre_derived_join_features(join: Join) -> Dict[str, str]:
+     return get_pre_derived_join_internal_features(join) | get_pre_derived_external_features(join)
+
+
+ def get_pre_derived_external_features(join: Join) -> Dict[str, str]:
+     external_cols = []
+     if join.onlineExternalParts:
+         for external_part in join.onlineExternalParts:
+             original_external_columns = [
+                 param.name for param in external_part.source.valueSchema.params
+             ]
+             prefix = get_external_part_full_name(external_part) + "_"
+             for col in original_external_columns:
+                 external_cols.append(prefix + col)
+     # No meaningful semantic information on external columns, so we just return the column names as a self map
+     return {x: x for x in external_cols}
+
+
+ def get_pre_derived_source_keys(source: Source) -> Dict[str, str]:
+     base_semantics = _extract_source_semantic_info(source)
+     source_keys_to_hashes = {}
+     for key, expression in extract_selects(source).items():
+         source_keys_to_hashes[key] = _compute_semantic_hash(
+             base_semantics + [f"select:{key}={expression}"]
+         )
+     return source_keys_to_hashes
+
+
+ def extract_selects(source: Source) -> Dict[str, str]:
+     if source.events:
+         return source.events.query.selects
+     elif source.entities:
+         return source.entities.query.selects
+     elif source.joinSource:
+         return source.joinSource.query.selects
+
+
+ def get_pre_derived_join_internal_features(join: Join) -> Dict[str, str]:
+     # Get the base semantic fields from join left side (without key columns)
+     join_base_semantic_fields = _extract_source_semantic_info(join.left)
+
+     internal_features = {}
+     for jp in join.joinParts:
+         # Build key mapping semantics - include left side key expressions
+         if jp.keyMapping:
+             key_mapping_semantics = [
+                 "join_keys:" + ",".join(f"{k}:{v}" for k, v in sorted(jp.keyMapping.items()))
+             ]
+         else:
+             key_mapping_semantics = []
+
+         # Include left side key expressions that this join part uses
+         left_key_expressions = []
+         left_selects = extract_selects(join.left)
+         for gb_key in jp.groupBy.keyColumns:
+             # Get the mapped key name or use the original key name
+             left_key_name = gb_key
+             if jp.keyMapping:
+                 # Find the left side key that maps to this groupby key
+                 for left_key, mapped_key in jp.keyMapping.items():
+                     if mapped_key == gb_key:
+                         left_key_name = left_key
+                         break
+
+             # Add the left side expression for this key
+             left_key_expr = left_selects.get(left_key_name, left_key_name)
+             left_key_expressions.append(f"left_key:{left_key_name}={left_key_expr}")
+
+         # These semantics apply to all features in the joinPart
+         jp_base_semantics = key_mapping_semantics + left_key_expressions + join_base_semantic_fields
+
+         pre_derived_group_by_features = get_pre_derived_group_by_features(
+             jp.groupBy, jp_base_semantics
+         )
+
+         if jp.groupBy.derivations:
+             derived_group_by_features = build_derived_columns(
+                 pre_derived_group_by_features, jp.groupBy.derivations, jp_base_semantics
+             )
+         else:
+             derived_group_by_features = pre_derived_group_by_features
+
+         for col, semantic_hash in derived_group_by_features.items():
+             prefix = jp.prefix + "_" if jp.prefix else ""
+             if join.useLongNames:
+                 gb_prefix = prefix + jp.groupBy.metaData.name.replace(".", "_")
+             else:
+                 key_str = "_".join(jp.groupBy.keyColumns)
+                 gb_prefix = prefix + key_str
+             internal_features[gb_prefix + "_" + col] = semantic_hash
+     return internal_features
+
+
+ def get_pre_derived_group_by_columns(group_by: GroupBy) -> Dict[str, str]:
+     output_columns_to_hashes = get_pre_derived_group_by_features(group_by)
+     for key_column in group_by.keyColumns:
+         key_expressions = []
+         for source in group_by.sources:
+             source_selects = extract_selects(source)  # Map[str, str]
+             key_expressions.append(source_selects.get(key_column, key_column))
+         output_columns_to_hashes[key_column] = _compute_semantic_hash(sorted(key_expressions))
+     return output_columns_to_hashes
+
+
+ def get_pre_derived_group_by_features(
+     group_by: GroupBy, additional_semantic_fields=None
+ ) -> Dict[str, str]:
+     # Get the base semantic fields that apply to all aggs
+     if additional_semantic_fields is None:
+         additional_semantic_fields = []
+     base_semantics = _get_base_group_by_semantic_fields(group_by)
+
+     output_columns = {}
+     # For group_bys with aggregations, aggregated columns
+     if group_by.aggregations:
+         for agg in group_by.aggregations:
+             input_expression_str = ",".join(
+                 get_input_expression_across_sources(group_by, agg.inputColumn)
+             )
+             for output_col_name in get_output_col_names(agg):
+                 output_columns[output_col_name] = _compute_semantic_hash(
+                     base_semantics + [input_expression_str] + additional_semantic_fields
+                 )
+     # For group_bys without aggregations, selected fields from query
+     else:
+         combined_selects = defaultdict(set)
+
+         for source in group_by.sources:
+             source_selects = extract_selects(source)  # Map[str, str]
+             for key, val in source_selects.items():
+                 combined_selects[key].add(val)
+
+         # Build a unified map of key to select expression from all sources
+         unified_selects = {key: ",".join(sorted(vals)) for key, vals in combined_selects.items()}
+
+         # now compute the hashes on base semantics + expression
+         selected_hashes = {
+             key: _compute_semantic_hash(base_semantics + [val] + additional_semantic_fields)
+             for key, val in unified_selects.items()
+         }
+         output_columns.update(selected_hashes)
+     return output_columns
+
+
+ def _get_base_group_by_semantic_fields(group_by: GroupBy) -> List[str]:
+     """
+     Extract base semantic fields from the group_by object.
+     This includes source table names, key columns, and any other relevant metadata.
+     """
+     base_semantics = []
+     for source in group_by.sources:
+         base_semantics.extend(_extract_source_semantic_info(source, group_by.keyColumns))
+
+     return sorted(base_semantics)
+
+
+ def _extract_source_semantic_info(source: Source, key_columns: List[str] = None) -> List[str]:
+     """
+     Extract source information for semantic hashing.
+     Returns list of semantic components.
+
+     Args:
+         source: The source to extract info from
+         key_columns: List of key columns to include expressions for. If None, includes all selects.
+     """
+     components = []
+
+     if source.events:
+         table = source.events.table
+         mutationTable = ""
+         query = source.events.query
+         cumulative = str(source.events.isCumulative or "")
+     elif source.entities:
+         table = source.entities.snapshotTable
+         mutationTable = source.entities.mutationTable
+         query = source.entities.query
+         cumulative = ""
+     elif source.joinSource:
+         table = source.joinSource.join.metaData.name
+         mutationTable = ""
+         query = source.joinSource.query
+         cumulative = ""
+
+     components.append(f"table:{table}")
+     components.append(f"mutation_table:{mutationTable}")
+     components.append(f"cumulative:{cumulative}")
+     components.append(f"filters:{query.wheres or ''}")
+
+     selects = query.selects or {}
+     if key_columns:
+         # Only include expressions for the specified key columns
+         for key_col in sorted(key_columns):
+             expr = selects.get(key_col, key_col)
+             components.append(f"select:{key_col}={expr}")
+
+     # Add time column expression
+     if query.timeColumn:
+         time_expr = selects.get(query.timeColumn, query.timeColumn)
+         components.append(f"time_column:{query.timeColumn}={time_expr}")
+
+     return sorted(components)
+
+
+ def _compute_semantic_hash(components: List[str]) -> str:
+     """
+     Compute semantic hash from a list of components.
+     Components should be ordered consistently to ensure reproducible hashes.
+     """
+     # Sort components to ensure consistent ordering
+     sorted_components = sorted(components)
+     hash_input = "|".join(sorted_components)
+     return hashlib.md5(hash_input.encode("utf-8")).hexdigest()
+
+
+ def build_derived_columns(
+     base_columns_to_hashes: Dict[str, str],
+     derivations: List[Derivation],
+     additional_semantic_fields: List[str],
+ ) -> Dict[str, str]:
+     """
+     Build the derived columns from pre-derived columns and derivations.
+     """
+     # if derivations contain star, then all columns are included except the columns which are renamed
+     output_columns = {}
+     if derivations:
+         found = any(derivation.expression == "*" for derivation in derivations)
+         if found:
+             output_columns.update(base_columns_to_hashes)
+         for derivation in derivations:
+             if base_columns_to_hashes.get(derivation.expression):
+                 # don't change the semantics if you're just passing a base column through derivations
+                 output_columns[derivation.name] = base_columns_to_hashes[derivation.expression]
+             if derivation.name != "*":
+                 # Identify base fields present within the derivation to include in the semantic hash
+                 # We go long to short to avoid taking both a windowed feature and the unwindowed feature
+                 # i.e. f_7d and f
+                 derivation_expression = derivation.expression
+                 base_col_semantic_fields = []
+                 tokens = re.findall(r"\b\w+\b", derivation_expression)
+                 for token in tokens:
+                     if token in base_columns_to_hashes:
+                         base_col_semantic_fields.append(base_columns_to_hashes[token])
+
+                 output_columns[derivation.name] = _compute_semantic_hash(
+                     additional_semantic_fields
+                     + [f"derivation:{derivation.expression}"]
+                     + base_col_semantic_fields
+                 )
+     return output_columns
+
+
+ def get_external_part_full_name(external_part: ExternalPart) -> str:
+     # The logic should be consistent with the full name logic defined
+     # in https://github.com/airbnb/chronon/blob/main/api/src/main/scala/ai/chronon/api/Extensions.scala#L677.
+     prefix = external_part.prefix + "_" if external_part.prefix else ""
+     name = external_part.source.metadata.name
+     sanitized_name = re.sub("[^a-zA-Z0-9_]", "_", name)
+     return "ext_" + prefix + sanitized_name
+
+
+ def get_input_expression_across_sources(group_by: GroupBy, input_col: str):
+     expressions = []
+     for source in group_by.sources:
+         selects = extract_selects(source)
+         expressions.extend(selects.get(input_col, input_col))
+     return sorted(expressions)
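
The semantic hash used throughout this module is an MD5 digest over the sorted, pipe-joined component strings, so equivalent component sets hash identically regardless of the order in which they were collected. A standalone sketch of that property (the component strings below are illustrative only, not taken from a real config):

```python
import hashlib

def semantic_hash(components):
    # Mirrors _compute_semantic_hash above: sort for stability, join with "|", then MD5.
    return hashlib.md5("|".join(sorted(components)).encode("utf-8")).hexdigest()

a = semantic_hash(["table:data.purchases", "filters:", "input_expr:amount"])
b = semantic_hash(["input_expr:amount", "table:data.purchases", "filters:"])
assert a == b  # collection order does not affect the resulting hash
```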