awx-zipline-ai 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. agent/__init__.py +1 -0
  2. agent/constants.py +15 -0
  3. agent/ttypes.py +1684 -0
  4. ai/__init__.py +0 -0
  5. ai/chronon/__init__.py +0 -0
  6. ai/chronon/airflow_helpers.py +251 -0
  7. ai/chronon/api/__init__.py +1 -0
  8. ai/chronon/api/common/__init__.py +1 -0
  9. ai/chronon/api/common/constants.py +15 -0
  10. ai/chronon/api/common/ttypes.py +1844 -0
  11. ai/chronon/api/constants.py +15 -0
  12. ai/chronon/api/ttypes.py +3624 -0
  13. ai/chronon/cli/compile/column_hashing.py +313 -0
  14. ai/chronon/cli/compile/compile_context.py +177 -0
  15. ai/chronon/cli/compile/compiler.py +160 -0
  16. ai/chronon/cli/compile/conf_validator.py +590 -0
  17. ai/chronon/cli/compile/display/class_tracker.py +112 -0
  18. ai/chronon/cli/compile/display/compile_status.py +95 -0
  19. ai/chronon/cli/compile/display/compiled_obj.py +12 -0
  20. ai/chronon/cli/compile/display/console.py +3 -0
  21. ai/chronon/cli/compile/display/diff_result.py +46 -0
  22. ai/chronon/cli/compile/fill_templates.py +40 -0
  23. ai/chronon/cli/compile/parse_configs.py +141 -0
  24. ai/chronon/cli/compile/parse_teams.py +238 -0
  25. ai/chronon/cli/compile/serializer.py +115 -0
  26. ai/chronon/cli/git_utils.py +156 -0
  27. ai/chronon/cli/logger.py +61 -0
  28. ai/chronon/constants.py +3 -0
  29. ai/chronon/eval/__init__.py +122 -0
  30. ai/chronon/eval/query_parsing.py +19 -0
  31. ai/chronon/eval/sample_tables.py +100 -0
  32. ai/chronon/eval/table_scan.py +186 -0
  33. ai/chronon/fetcher/__init__.py +1 -0
  34. ai/chronon/fetcher/constants.py +15 -0
  35. ai/chronon/fetcher/ttypes.py +127 -0
  36. ai/chronon/group_by.py +692 -0
  37. ai/chronon/hub/__init__.py +1 -0
  38. ai/chronon/hub/constants.py +15 -0
  39. ai/chronon/hub/ttypes.py +1228 -0
  40. ai/chronon/join.py +566 -0
  41. ai/chronon/logger.py +24 -0
  42. ai/chronon/model.py +35 -0
  43. ai/chronon/observability/__init__.py +1 -0
  44. ai/chronon/observability/constants.py +15 -0
  45. ai/chronon/observability/ttypes.py +2192 -0
  46. ai/chronon/orchestration/__init__.py +1 -0
  47. ai/chronon/orchestration/constants.py +15 -0
  48. ai/chronon/orchestration/ttypes.py +4406 -0
  49. ai/chronon/planner/__init__.py +1 -0
  50. ai/chronon/planner/constants.py +15 -0
  51. ai/chronon/planner/ttypes.py +1686 -0
  52. ai/chronon/query.py +126 -0
  53. ai/chronon/repo/__init__.py +40 -0
  54. ai/chronon/repo/aws.py +298 -0
  55. ai/chronon/repo/cluster.py +65 -0
  56. ai/chronon/repo/compile.py +56 -0
  57. ai/chronon/repo/constants.py +164 -0
  58. ai/chronon/repo/default_runner.py +291 -0
  59. ai/chronon/repo/explore.py +421 -0
  60. ai/chronon/repo/extract_objects.py +137 -0
  61. ai/chronon/repo/gcp.py +585 -0
  62. ai/chronon/repo/gitpython_utils.py +14 -0
  63. ai/chronon/repo/hub_runner.py +171 -0
  64. ai/chronon/repo/hub_uploader.py +108 -0
  65. ai/chronon/repo/init.py +53 -0
  66. ai/chronon/repo/join_backfill.py +105 -0
  67. ai/chronon/repo/run.py +293 -0
  68. ai/chronon/repo/serializer.py +141 -0
  69. ai/chronon/repo/team_json_utils.py +46 -0
  70. ai/chronon/repo/utils.py +472 -0
  71. ai/chronon/repo/zipline.py +51 -0
  72. ai/chronon/repo/zipline_hub.py +105 -0
  73. ai/chronon/resources/gcp/README.md +174 -0
  74. ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
  75. ai/chronon/resources/gcp/group_bys/test/data.py +34 -0
  76. ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
  77. ai/chronon/resources/gcp/joins/test/data.py +30 -0
  78. ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
  79. ai/chronon/resources/gcp/sources/test/data.py +23 -0
  80. ai/chronon/resources/gcp/teams.py +70 -0
  81. ai/chronon/resources/gcp/zipline-cli-install.sh +54 -0
  82. ai/chronon/source.py +88 -0
  83. ai/chronon/staging_query.py +185 -0
  84. ai/chronon/types.py +57 -0
  85. ai/chronon/utils.py +557 -0
  86. ai/chronon/windows.py +50 -0
  87. awx_zipline_ai-0.2.0.dist-info/METADATA +173 -0
  88. awx_zipline_ai-0.2.0.dist-info/RECORD +93 -0
  89. awx_zipline_ai-0.2.0.dist-info/WHEEL +5 -0
  90. awx_zipline_ai-0.2.0.dist-info/entry_points.txt +2 -0
  91. awx_zipline_ai-0.2.0.dist-info/licenses/LICENSE +202 -0
  92. awx_zipline_ai-0.2.0.dist-info/top_level.txt +3 -0
  93. jars/__init__.py +0 -0
@@ -0,0 +1,95 @@
1
+ from collections import OrderedDict
2
+ from typing import Dict
3
+
4
+ from rich.live import Live
5
+ from rich.text import Text
6
+
7
+ from ai.chronon.cli.compile.display.class_tracker import ClassTracker
8
+ from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
9
+
10
+
11
class CompileStatus:
    """
    Uses rich ui - to consolidate and sink the overview of the compile process to the bottom.

    Keeps one ClassTracker per object type (e.g. GroupBy, Join) and renders
    their statuses, errors, and diffs into a single rich Text block.
    """

    def __init__(self, use_live: bool = False):
        self.cls_to_tracker: Dict[str, ClassTracker] = OrderedDict()
        self.use_live = use_live
        # we need vertical_overflow to be visible as the output gets cut off when our
        # output goes past the terminal window
        # but then we start seeing duplicates: https://github.com/Textualize/rich/issues/3263
        if self.use_live:
            self.live = Live(refresh_per_second=50, vertical_overflow="visible")
            self.live.start()

    def print_live_console(self, msg: str):
        """Print a message through the live console. No-op when live mode is off."""
        if self.use_live:
            self.live.console.print(msg)

    def add_object_update_display(
        self, compiled: CompiledObj, obj_type: str = None
    ) -> None:
        """Register a newly compiled object under its type and refresh the display."""

        if compiled.obj_type is not None and obj_type is not None:
            assert (
                compiled.obj_type == obj_type
            ), f"obj_type mismatch: {compiled.obj_type} != {obj_type}"

        # Fall back to the object's own type so we never track entries under a
        # None key when the caller omits obj_type.
        if obj_type is None:
            obj_type = compiled.obj_type

        if obj_type not in self.cls_to_tracker:
            self.cls_to_tracker[obj_type] = ClassTracker()

        self.cls_to_tracker[obj_type].add(compiled)

        self._update_display()

    def add_existing_object_update_display(self, existing_obj: CompiledObj) -> None:
        """Register an already-compiled (unchanged) object and refresh the display."""

        obj_type = existing_obj.obj_type

        if obj_type not in self.cls_to_tracker:
            self.cls_to_tracker[obj_type] = ClassTracker()

        self.cls_to_tracker[obj_type].add_existing(existing_obj)

        self._update_display()

    def close_cls(self, obj_type: str) -> None:
        """Mark the tracker for obj_type as finished, if one exists."""
        if obj_type in self.cls_to_tracker:
            self.cls_to_tracker[obj_type].close()
            self._update_display()

    def close(self) -> None:
        """Perform a final display refresh and stop the live renderer if enabled."""
        self._update_display()
        if self.use_live:
            self.live.stop()

    def render(self) -> Text:
        """Render every tracker's status, errors, and diff into one Text block."""
        text = Text(overflow="fold", no_wrap=False)

        for obj_type, tracker in self.cls_to_tracker.items():
            text.append(f"\n{obj_type}-s:\n", style="cyan")

            status = tracker.to_status()
            if status:
                text.append(status)

            errors = tracker.to_errors()
            if errors:
                text.append(errors)

            diff = tracker.diff()
            if diff:
                text.append(diff)

            text.append("\n")
        return text

    def _update_display(self):
        # TODO: add this after live_crop is implemented
        # text = self.display_text()
        # if self.use_live:
        #     self.live.update(text, refresh=True)
        # return text
        pass
@@ -0,0 +1,12 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, List, Optional
3
+
4
+
5
@dataclass
class CompiledObj:
    """Result of compiling a single chronon object (or a failed file)."""

    # Fully qualified object name (None when compilation of the file failed).
    name: str
    # The compiled object itself (None on failure).
    obj: Any
    # Path of the python file the object was parsed from.
    file: str
    # Validation/compilation errors; None when compilation succeeded cleanly.
    errors: Optional[List[Exception]]
    # Class name of the object, e.g. "GroupBy" or "Join".
    obj_type: str
    # Thrift simple-json serialization of `obj` (None on failure).
    tjson: str
@@ -0,0 +1,3 @@
1
from rich.console import Console

# Single shared Console instance reused across the compile display modules.
console = Console()
@@ -0,0 +1,46 @@
1
+ from typing import List
2
+
3
+ from rich.text import Text
4
+
5
+
6
class DiffResult:
    """Accumulates names of added/updated objects and renders a colored summary."""

    def __init__(self):
        self.added: List[str] = []
        self.updated: List[str] = []

    def render(self, deleted_names: List[str], indent=" ") -> Text:
        """Render added, updated, and deleted names sorted by name, one per line.

        Falls back to a dim "no changes" line when there is nothing to show.
        """

        # Pair each name with a freshly built signage Text of the right color.
        entries = [(Text("Added", style="dim green"), name) for name in self.added]
        entries += [(Text("Updated", style="dim yellow"), name) for name in self.updated]

        if deleted_names:
            entries += [(Text("Deleted", style="red"), name) for name in deleted_names]

        entries.sort(key=lambda pair: pair[1])

        rendered = Text(overflow="fold", no_wrap=False)
        for signage, name in entries:
            rendered.append(indent)
            rendered.append(signage)
            rendered.append(" ")
            rendered.append(name)
            rendered.append("\n")

        if not rendered:
            return Text(indent + "No new changes detected\n", style="dim")

        return rendered
@@ -0,0 +1,40 @@
1
+ from ai.chronon import utils
2
+ from ai.chronon.api.ttypes import Join, Team
3
+ from ai.chronon.cli.compile.compile_context import CompileContext
4
+
5
+
6
+ def _fill_template(table, obj, namespace):
7
+
8
+ if table:
9
+ table = table.replace(
10
+ "{{ logged_table }}", utils.log_table_name(obj, full_name=True)
11
+ )
12
+ table = table.replace("{{ db }}", namespace)
13
+
14
+ return table
15
+
16
+
17
def set_templated_values(obj, cls, compile_context: CompileContext):
    """Fill {{ ... }} template placeholders on a compiled object, in place.

    Substitutes namespace and table-name values into bootstrap tables,
    metadata dependencies, and (for Joins) label-part dependencies.
    """

    team_obj: Team = compile_context.teams_dict[obj.team]
    namespace = team_obj.outputNamespace

    if cls == Join and obj.bootstrapParts:

        for bootstrap in obj.bootstrapParts:
            bootstrap.table = _fill_template(bootstrap.table, obj, namespace)

    if obj.metaData.dependencies:
        obj.metaData.dependencies = [
            _fill_template(dep, obj, namespace) for dep in obj.metaData.dependencies
        ]

    # Guard against missing label metaData / dependencies - iterating None would
    # raise TypeError (the metaData.dependencies branch above is guarded the same way).
    if (
        cls == Join
        and obj.labelParts
        and obj.labelParts.metaData
        and obj.labelParts.metaData.dependencies
    ):

        obj.labelParts.metaData.dependencies = [
            label_dep.replace(
                "{{ join_backfill_table }}",
                utils.output_table_name(obj, full_name=True),
            )
            for label_dep in obj.labelParts.metaData.dependencies
        ]
@@ -0,0 +1,141 @@
1
+ import copy
2
+ import glob
3
+ import importlib
4
+ import os
5
+ from typing import Any, List
6
+
7
+ from ai.chronon import airflow_helpers
8
+ from ai.chronon.api.ttypes import GroupBy, Join
9
+ from ai.chronon.cli.compile import parse_teams, serializer
10
+ from ai.chronon.cli.compile.compile_context import CompileContext
11
+ from ai.chronon.cli.compile.display.compiled_obj import CompiledObj
12
+ from ai.chronon.cli.logger import get_logger
13
+
14
+ logger = get_logger()
15
+
16
def from_folder(
    cls: type, input_dir: str, compile_context: CompileContext
) -> List[CompiledObj]:
    """
    Recursively consumes a folder, and constructs a map of
    object qualifier to StagingQuery, GroupBy, or Join

    Failures are captured per file: a file that raises yields one CompiledObj
    carrying the exception instead of aborting the whole folder scan.
    """

    python_files = glob.glob(os.path.join(input_dir, "**/*.py"), recursive=True)

    results = []

    for f in python_files:

        try:
            # Import the file and pull out every module-level instance of `cls`.
            results_dict = from_file(f, cls, input_dir)

            for name, obj in results_dict.items():
                parse_teams.update_metadata(obj, compile_context.teams_dict)
                # Populate columnHashes field with semantic hashes
                populate_column_hashes(obj)

                # Airflow deps must be set AFTER updating metadata
                airflow_helpers.set_airflow_deps(obj)

                obj.metaData.sourceFile = os.path.relpath(f, compile_context.chronon_root)

                tjson = serializer.thrift_simple_json(obj)

                # Perform validation
                errors = compile_context.validator.validate_obj(obj)

                result = CompiledObj(
                    name=name,
                    obj=obj,
                    file=f,
                    errors=errors if len(errors) > 0 else None,
                    obj_type=cls.__name__,
                    tjson=tjson,
                )
                results.append(result)

                compile_context.compile_status.add_object_update_display(
                    result, cls.__name__
                )

        except Exception as e:
            # One error placeholder per failing file - name/obj/tjson unknown here.
            result = CompiledObj(
                name=None,
                obj=None,
                file=f,
                errors=[e],
                obj_type=cls.__name__,
                tjson=None,
            )

            results.append(result)

            compile_context.compile_status.add_object_update_display(
                result, cls.__name__
            )

    return results
79
+
80
+
81
def from_file(file_path: str, cls: type, input_dir: str):
    """Import a python config file and extract all module-level instances of `cls`.

    Returns a dict mapping the versioned, fully qualified object name
    (e.g. "team.module.var__1") to a deep copy of the object, with
    metaData.name and metaData.team populated.
    """

    # this is where the python path should have been set to
    chronon_root = os.path.dirname(input_dir)
    rel_path = os.path.relpath(file_path, chronon_root)

    rel_path_without_extension = os.path.splitext(rel_path)[0]

    # use os.sep so windows path separators are handled too
    module_name = rel_path_without_extension.replace(os.sep, ".")

    # first component is the conf type (e.g. "group_bys"); the rest is team/module
    _, mod_path = module_name.split(".", 1)

    module = importlib.import_module(module_name)

    result = {}

    # iterate over a snapshot - importing can mutate module __dict__
    for var_name, obj in list(module.__dict__.items()):

        if isinstance(obj, cls):

            copied_obj = copy.deepcopy(obj)

            name = f"{mod_path}.{var_name}"

            # Version suffix keeps different versions of the same conf distinct.
            # NOTE(review): this appends "__None" when version is unset - confirm intended.
            name = name + "__" + str(copied_obj.metaData.version)

            copied_obj.metaData.name = name
            copied_obj.metaData.team = mod_path.split(".")[0]

            result[name] = copied_obj

    return result
115
+
116
def populate_column_hashes(obj: Any):
    """
    Populate the columnHashes field in the object's metadata with semantic hashes
    for each output column.
    """
    # Import here to avoid circular imports
    from ai.chronon.cli.compile.column_hashing import (
        compute_group_by_columns_hashes,
        compute_join_column_hashes,
    )

    if isinstance(obj, GroupBy):
        # GroupBy: hash every output column, keys included.
        obj.metaData.columnHashes = compute_group_by_columns_hashes(
            obj, exclude_keys=False
        )

    elif isinstance(obj, Join):
        # Join: hash the join's own columns, then each join part's group-by.
        obj.metaData.columnHashes = compute_join_column_hashes(obj)

        for join_part in obj.joinParts or []:
            part_gb = join_part.groupBy
            part_gb.metaData.columnHashes = compute_group_by_columns_hashes(part_gb)
@@ -0,0 +1,238 @@
1
+ import importlib
2
+ import importlib.util
3
+ import os
4
+ import sys
5
+ from copy import deepcopy
6
+ from enum import Enum
7
+ from typing import Any, Dict, Optional, Union
8
+
9
+ from ai.chronon.api.common.ttypes import (
10
+ ClusterConfigProperties,
11
+ ConfigProperties,
12
+ EnvironmentVariables,
13
+ ExecutionInfo,
14
+ )
15
+ from ai.chronon.api.ttypes import Join, MetaData, Team
16
+ from ai.chronon.cli.compile.display.console import console
17
+ from ai.chronon.cli.logger import get_logger
18
+
19
+ logger = get_logger()
20
+
21
+ _DEFAULT_CONF_TEAM = "default"
22
+
23
+
24
def import_module_from_file(file_path):
    """Load and execute a python module directly from a file path.

    Registers the module in sys.modules under its bare file name (without the
    extension) and returns the executed module object.
    """
    # Derive the module name portably; os.path.splitext strips only the final
    # extension, unlike replace(".py", "") which would mangle names containing
    # ".py" in the middle.
    module_name = os.path.splitext(os.path.basename(file_path))[0]

    # Create the module spec
    spec = importlib.util.spec_from_file_location(module_name, file_path)

    # Create the module based on the spec
    module = importlib.util.module_from_spec(spec)

    # Add the module to sys.modules so imports inside the module resolve it
    sys.modules[module_name] = module

    # Execute the module
    spec.loader.exec_module(module)

    return module
41
+
42
+
43
def load_teams(conf_root: str, print: bool = True) -> Dict[str, Team]:
    """Parse teams.py under `conf_root` into a name -> Team dict.

    Note: the `print` parameter (kept for backward compatibility although it
    shadows the builtin) toggles the informational console message.
    """
    teams_file = os.path.join(conf_root, "teams.py")

    assert os.path.exists(
        teams_file
    ), f"Team config file: {teams_file} not found. You might be running this from the wrong directory."

    team_module = import_module_from_file(teams_file)

    assert team_module is not None, (
        f"Team config file {teams_file} is not on the PYTHONPATH. You might need to add the your config "
        f"directory to the PYTHONPATH."
    )

    if print:
        console.print(
            f"Pulling configuration from [cyan italic]{teams_file}[/cyan italic]"
        )

    # Collect every module-level Team, stamping each with its variable name.
    team_dict = {}
    for name, candidate in vars(team_module).items():
        if not isinstance(candidate, Team):
            continue
        candidate.name = name
        team_dict[name] = candidate

    return team_dict
70
+
71
+
72
def update_metadata(obj: Any, team_dict: Dict[str, Team]):
    """Fill team-derived metadata (output namespace, execution info) on `obj` in place.

    Requires obj.metaData.team to be pre-populated (set by the compiler) and
    both the object's team and the "default" team to exist in teams.py.
    """
    assert obj is not None, "Cannot update metadata None object"

    metadata = obj.metaData

    assert obj.metaData is not None, "Cannot update empty metadata"

    name = obj.metaData.name
    team = obj.metaData.team

    assert (
        team is not None
    ), f"Team name is required in metadata for {name}. This usually set by compiler. Internal error."

    assert (
        team in team_dict
    ), f"Team '{team}' not found in teams.py. Please add an entry 🙏"

    assert (
        _DEFAULT_CONF_TEAM in team_dict
    ), f"'{_DEFAULT_CONF_TEAM}' team not found in teams.py, please add an entry 🙏."

    # Only set the outputNamespace if it hasn't been set already
    if not metadata.outputNamespace:
        metadata.outputNamespace = team_dict[team].outputNamespace

    if isinstance(obj, Join):
        # Joins propagate their namespace down to each join part / label group-by.
        join_namespace = obj.metaData.outputNamespace

        # set the metadata for each join part and labelParts
        def set_group_by_metadata(join_part_gb, output_namespace):
            if join_part_gb is not None:
                if join_part_gb.metaData:
                    # Only set the outputNamespace if it hasn't been set already
                    if not join_part_gb.metaData.outputNamespace:
                        join_part_gb.metaData.outputNamespace = output_namespace
                else:
                    # If there's no metaData at all, create it and set outputNamespace
                    join_part_gb.metaData = MetaData()
                    join_part_gb.metaData.outputNamespace = output_namespace

        if obj.joinParts:
            for jp in (obj.joinParts or []):
                jp.useLongNames = obj.useLongNames
                set_group_by_metadata(jp.groupBy, join_namespace)

        if obj.labelParts:
            for lb in (obj.labelParts.labels or []):
                lb.useLongNames = obj.useLongNames
                set_group_by_metadata(lb.groupBy, join_namespace)

    if metadata.executionInfo is None:
        metadata.executionInfo = ExecutionInfo()

    # Layer default-team / team / conf-level execution settings together.
    merge_team_execution_info(metadata, team_dict, team)
127
+
128
+
129
def merge_team_execution_info(metadata: MetaData, team_dict: Dict[str, Team], team_name: str):
    """Merge execution settings into metadata.executionInfo, in place.

    Precedence (lowest to highest): "default" team -> the object's team ->
    values already present on the metadata itself.
    """
    default_team = team_dict.get(_DEFAULT_CONF_TEAM)
    if not metadata.executionInfo:
        metadata.executionInfo = ExecutionInfo()

    exec_info = metadata.executionInfo
    team = team_dict[team_name]

    # (attribute on Team/ExecutionInfo, matching per-mode map attribute)
    merge_plan = (
        ("env", EnvOrConfigAttribute.ENV),
        ("conf", EnvOrConfigAttribute.CONFIG),
        ("clusterConf", EnvOrConfigAttribute.CLUSTER_CONFIG),
    )

    for attr, mode_attr in merge_plan:
        merged = _merge_mode_maps(
            getattr(default_team, attr) if default_team else {},
            getattr(team, attr),
            getattr(exec_info, attr),
            env_or_config_attribute=mode_attr,
        )
        setattr(exec_info, attr, merged)
+ )
154
+
155
+
156
+ def _merge_maps(*maps: Optional[Dict[str, str]]):
157
+ """
158
+ Merges multiple maps into one - with the later maps overriding the earlier ones.
159
+ """
160
+
161
+ result = {}
162
+
163
+ for m in maps:
164
+
165
+ if m is None:
166
+ continue
167
+
168
+ for key, value in m.items():
169
+ result[key] = value
170
+
171
+ return result
172
+
173
+
174
class EnvOrConfigAttribute(str, Enum):
    """Attribute name, on the execution-info structs, that holds each per-mode map."""

    # EnvironmentVariables.modeEnvironments
    ENV = "modeEnvironments"
    # ConfigProperties.modeConfigs
    CONFIG = "modeConfigs"
    # ClusterConfigProperties.modeClusterConfigs
    CLUSTER_CONFIG = "modeClusterConfigs"
178
+
179
+
180
def _merge_mode_maps(
    *mode_maps: Optional[Union[EnvironmentVariables, ConfigProperties, ClusterConfigProperties]],
    env_or_config_attribute: EnvOrConfigAttribute,
):
    """
    Merges multiple environment variables into one - with the later maps overriding the earlier ones.

    Each mode map has a `common` dict plus a per-mode dict (selected by
    `env_or_config_attribute`; the enum subclasses str, so it works with getattr).
    Returns None when every input is falsy. The first surviving map is deep-copied,
    so inputs are not mutated.
    """

    # Merge `common` to each individual mode map. Creates a new map
    # NOTE(review): assumes the per-mode attribute is a dict, not None - confirm
    # the thrift structs always default it.
    def push_common_to_modes(mode_map: Union[EnvironmentVariables, ConfigProperties], mode_key: EnvOrConfigAttribute):
        final_mode_map = deepcopy(mode_map)
        common = final_mode_map.common
        modes = getattr(final_mode_map, mode_key)
        for _ in modes:
            modes[_] = _merge_maps(
                common, modes[_]
            )
        return final_mode_map

    filtered_mode_maps = [m for m in mode_maps if m]

    # Initialize the result with the first mode map
    result = None

    if len(filtered_mode_maps) >= 1:
        result = push_common_to_modes(filtered_mode_maps[0], env_or_config_attribute)

    # Merge each new mode map into the result
    for m in filtered_mode_maps[1:]:
        # We want to prepare the individual modes with `common` in incoming_mode_map
        incoming_mode_map = push_common_to_modes(m, env_or_config_attribute)

        # create new common
        incoming_common = incoming_mode_map.common
        new_common = _merge_maps(result.common, incoming_common)
        result.common = new_common

        current_modes = getattr(result, env_or_config_attribute)
        incoming_modes = getattr(incoming_mode_map, env_or_config_attribute)

        current_modes_keys = list(current_modes.keys())
        incoming_modes_keys = list(incoming_modes.keys())

        # union of mode names seen on either side
        all_modes_keys = list(set(current_modes_keys + incoming_modes_keys))
        for mode in all_modes_keys:
            current_mode = current_modes.get(mode, {})

            # if the incoming_mode is not found, we NEED to default to incoming_common
            incoming_mode = incoming_modes.get(mode, incoming_common)

            # first to last with later ones overriding the earlier ones
            # common -> current mode level -> incoming mode level

            new_mode = _merge_maps(
                new_common, current_mode, incoming_mode
            )
            current_modes[mode] = new_mode

    return result
@@ -0,0 +1,115 @@
1
+ # Copyright (C) 2023 The Chronon Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+
17
+ from thrift import TSerialization
18
+ from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
19
+ from thrift.protocol.TJSONProtocol import TSimpleJSONProtocolFactory
20
+ from thrift.Thrift import TType
21
+ from thrift.transport.TTransport import TMemoryBuffer
22
+
23
+
24
+ class ThriftJSONDecoder(json.JSONDecoder):
25
+ def __init__(self, *args, **kwargs):
26
+ self._thrift_class = kwargs.pop("thrift_class")
27
+ super(ThriftJSONDecoder, self).__init__(*args, **kwargs)
28
+
29
+ def decode(self, json_str):
30
+ if isinstance(json_str, dict):
31
+ dct = json_str
32
+ else:
33
+ dct = super(ThriftJSONDecoder, self).decode(json_str)
34
+ return self._convert(
35
+ dct, TType.STRUCT, (self._thrift_class, self._thrift_class.thrift_spec)
36
+ )
37
+
38
+ def _convert(self, val, ttype, ttype_info):
39
+ if ttype == TType.STRUCT:
40
+ (thrift_class, thrift_spec) = ttype_info
41
+ ret = thrift_class()
42
+ for field in thrift_spec:
43
+ if field is None:
44
+ continue
45
+ (_, field_ttype, field_name, field_ttype_info, dummy) = field
46
+ if field_name not in val:
47
+ continue
48
+ converted_val = self._convert(
49
+ val[field_name], field_ttype, field_ttype_info
50
+ )
51
+ setattr(ret, field_name, converted_val)
52
+ elif ttype == TType.LIST:
53
+ (element_ttype, element_ttype_info, _) = ttype_info
54
+ ret = [self._convert(x, element_ttype, element_ttype_info) for x in val]
55
+ elif ttype == TType.SET:
56
+ (element_ttype, element_ttype_info) = ttype_info
57
+ ret = set(
58
+ [self._convert(x, element_ttype, element_ttype_info) for x in val]
59
+ )
60
+ elif ttype == TType.MAP:
61
+ (key_ttype, key_ttype_info, val_ttype, val_ttype_info, _) = ttype_info
62
+ ret = dict(
63
+ [
64
+ (
65
+ self._convert(k, key_ttype, key_ttype_info),
66
+ self._convert(v, val_ttype, val_ttype_info),
67
+ )
68
+ for (k, v) in val.items()
69
+ ]
70
+ )
71
+ elif ttype == TType.STRING:
72
+ ret = str(val)
73
+ elif ttype == TType.DOUBLE:
74
+ ret = float(val)
75
+ elif ttype == TType.I64:
76
+ ret = int(val)
77
+ elif ttype == TType.I32 or ttype == TType.I16 or ttype == TType.BYTE:
78
+ ret = int(val)
79
+ elif ttype == TType.BOOL:
80
+ ret = bool(val)
81
+ else:
82
+ raise TypeError("Unrecognized thrift field type: %d" % ttype)
83
+ return ret
84
+
85
+
86
+ def json2thrift(json_str, thrift_class):
87
+ return json.loads(json_str, cls=ThriftJSONDecoder, thrift_class=thrift_class)
88
+
89
+
90
+ def json2binary(json_str, thrift_class):
91
+ thrift = json2thrift(json_str, thrift_class)
92
+ transport = TMemoryBuffer()
93
+ protocol = TBinaryProtocolAccelerated(transport)
94
+ thrift.write(protocol)
95
+ # Get the raw bytes representing the object in Thrift binary format
96
+ return transport.getvalue()
97
+
98
+
99
+ def file2thrift(path, thrift_class):
100
+ try:
101
+ with open(path, "r") as file:
102
+ return json2thrift(file.read(), thrift_class)
103
+ except json.decoder.JSONDecodeError as e:
104
+ raise Exception(
105
+ f"Error decoding file into a {thrift_class.__name__}: {path}. "
106
+ + f"Please double check that {path} represents a valid {thrift_class.__name__}."
107
+ ) from e
108
+
109
+
110
+ def thrift_simple_json(obj):
111
+ simple = TSerialization.serialize(
112
+ obj, protocol_factory=TSimpleJSONProtocolFactory()
113
+ )
114
+ parsed = json.loads(simple)
115
+ return json.dumps(parsed, indent=2, sort_keys=True)