awx-zipline-ai 0.2.0 (awx_zipline_ai-0.2.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. agent/__init__.py +1 -0
  2. agent/constants.py +15 -0
  3. agent/ttypes.py +1684 -0
  4. ai/__init__.py +0 -0
  5. ai/chronon/__init__.py +0 -0
  6. ai/chronon/airflow_helpers.py +251 -0
  7. ai/chronon/api/__init__.py +1 -0
  8. ai/chronon/api/common/__init__.py +1 -0
  9. ai/chronon/api/common/constants.py +15 -0
  10. ai/chronon/api/common/ttypes.py +1844 -0
  11. ai/chronon/api/constants.py +15 -0
  12. ai/chronon/api/ttypes.py +3624 -0
  13. ai/chronon/cli/compile/column_hashing.py +313 -0
  14. ai/chronon/cli/compile/compile_context.py +177 -0
  15. ai/chronon/cli/compile/compiler.py +160 -0
  16. ai/chronon/cli/compile/conf_validator.py +590 -0
  17. ai/chronon/cli/compile/display/class_tracker.py +112 -0
  18. ai/chronon/cli/compile/display/compile_status.py +95 -0
  19. ai/chronon/cli/compile/display/compiled_obj.py +12 -0
  20. ai/chronon/cli/compile/display/console.py +3 -0
  21. ai/chronon/cli/compile/display/diff_result.py +46 -0
  22. ai/chronon/cli/compile/fill_templates.py +40 -0
  23. ai/chronon/cli/compile/parse_configs.py +141 -0
  24. ai/chronon/cli/compile/parse_teams.py +238 -0
  25. ai/chronon/cli/compile/serializer.py +115 -0
  26. ai/chronon/cli/git_utils.py +156 -0
  27. ai/chronon/cli/logger.py +61 -0
  28. ai/chronon/constants.py +3 -0
  29. ai/chronon/eval/__init__.py +122 -0
  30. ai/chronon/eval/query_parsing.py +19 -0
  31. ai/chronon/eval/sample_tables.py +100 -0
  32. ai/chronon/eval/table_scan.py +186 -0
  33. ai/chronon/fetcher/__init__.py +1 -0
  34. ai/chronon/fetcher/constants.py +15 -0
  35. ai/chronon/fetcher/ttypes.py +127 -0
  36. ai/chronon/group_by.py +692 -0
  37. ai/chronon/hub/__init__.py +1 -0
  38. ai/chronon/hub/constants.py +15 -0
  39. ai/chronon/hub/ttypes.py +1228 -0
  40. ai/chronon/join.py +566 -0
  41. ai/chronon/logger.py +24 -0
  42. ai/chronon/model.py +35 -0
  43. ai/chronon/observability/__init__.py +1 -0
  44. ai/chronon/observability/constants.py +15 -0
  45. ai/chronon/observability/ttypes.py +2192 -0
  46. ai/chronon/orchestration/__init__.py +1 -0
  47. ai/chronon/orchestration/constants.py +15 -0
  48. ai/chronon/orchestration/ttypes.py +4406 -0
  49. ai/chronon/planner/__init__.py +1 -0
  50. ai/chronon/planner/constants.py +15 -0
  51. ai/chronon/planner/ttypes.py +1686 -0
  52. ai/chronon/query.py +126 -0
  53. ai/chronon/repo/__init__.py +40 -0
  54. ai/chronon/repo/aws.py +298 -0
  55. ai/chronon/repo/cluster.py +65 -0
  56. ai/chronon/repo/compile.py +56 -0
  57. ai/chronon/repo/constants.py +164 -0
  58. ai/chronon/repo/default_runner.py +291 -0
  59. ai/chronon/repo/explore.py +421 -0
  60. ai/chronon/repo/extract_objects.py +137 -0
  61. ai/chronon/repo/gcp.py +585 -0
  62. ai/chronon/repo/gitpython_utils.py +14 -0
  63. ai/chronon/repo/hub_runner.py +171 -0
  64. ai/chronon/repo/hub_uploader.py +108 -0
  65. ai/chronon/repo/init.py +53 -0
  66. ai/chronon/repo/join_backfill.py +105 -0
  67. ai/chronon/repo/run.py +293 -0
  68. ai/chronon/repo/serializer.py +141 -0
  69. ai/chronon/repo/team_json_utils.py +46 -0
  70. ai/chronon/repo/utils.py +472 -0
  71. ai/chronon/repo/zipline.py +51 -0
  72. ai/chronon/repo/zipline_hub.py +105 -0
  73. ai/chronon/resources/gcp/README.md +174 -0
  74. ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
  75. ai/chronon/resources/gcp/group_bys/test/data.py +34 -0
  76. ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
  77. ai/chronon/resources/gcp/joins/test/data.py +30 -0
  78. ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
  79. ai/chronon/resources/gcp/sources/test/data.py +23 -0
  80. ai/chronon/resources/gcp/teams.py +70 -0
  81. ai/chronon/resources/gcp/zipline-cli-install.sh +54 -0
  82. ai/chronon/source.py +88 -0
  83. ai/chronon/staging_query.py +185 -0
  84. ai/chronon/types.py +57 -0
  85. ai/chronon/utils.py +557 -0
  86. ai/chronon/windows.py +50 -0
  87. awx_zipline_ai-0.2.0.dist-info/METADATA +173 -0
  88. awx_zipline_ai-0.2.0.dist-info/RECORD +93 -0
  89. awx_zipline_ai-0.2.0.dist-info/WHEEL +5 -0
  90. awx_zipline_ai-0.2.0.dist-info/entry_points.txt +2 -0
  91. awx_zipline_ai-0.2.0.dist-info/licenses/LICENSE +202 -0
  92. awx_zipline_ai-0.2.0.dist-info/top_level.txt +3 -0
  93. jars/__init__.py +0 -0
@@ -0,0 +1,156 @@
+ import subprocess
+ import sys
+ from pathlib import Path
+ from typing import List, Optional
+
+ from ai.chronon.cli.logger import get_logger
+
+ logger = get_logger()
+
+
+ def get_current_branch() -> str:
+
+     try:
+         subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL)
+
+         return (
+             subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+             .decode("utf-8")
+             .strip()
+         )
+
+     except subprocess.CalledProcessError as e:
+
+         try:
+             head_file = Path(".git/HEAD").resolve()
+
+             if head_file.exists():
+                 content = head_file.read_text().strip()
+
+                 if content.startswith("ref: refs/heads/"):
+                     return content.split("/")[-1]
+
+         except Exception:
+             pass
+
+         print(
+             f"⛔ Error: {e.stderr.decode('utf-8') if e.stderr else 'Not a git repository or no commits'}",
+             file=sys.stderr,
+         )
+
+         raise
+
+
+ def get_fork_point(base_branch: str = "main") -> str:
+     try:
+
+         return (
+             subprocess.check_output(["git", "merge-base", base_branch, "HEAD"])
+             .decode("utf-8")
+             .strip()
+         )
+
+     except subprocess.CalledProcessError as e:
+         print(
+             f"⛔ Error: {e.stderr.decode('utf-8') if e.stderr else f'Could not determine fork point from {base_branch}'}",
+             file=sys.stderr,
+         )
+         raise
+
+
+ def get_file_content_at_commit(file_path: str, commit: str) -> Optional[str]:
+     try:
+         return subprocess.check_output(["git", "show", f"{commit}:{file_path}"]).decode(
+             "utf-8"
+         )
+     except subprocess.CalledProcessError:
+         return None
+
+
+ def get_current_file_content(file_path: str) -> Optional[str]:
+     try:
+         return Path(file_path).read_text()
+     except Exception:
+         return None
+
+
+ def get_changes_since_commit(path: str, commit: Optional[str] = None) -> List[str]:
+
+     path = Path(path).resolve()
+     if not path.exists():
+         print(f"⛔ Error: Path does not exist: {path}", file=sys.stderr)
+         raise ValueError(f"Path does not exist: {path}")
+
+     try:
+         subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL)
+         commit_range = f"{commit}..HEAD" if commit else "HEAD"
+
+         changes = (
+             subprocess.check_output(
+                 ["git", "diff", "--name-only", commit_range, "--", str(path)]
+             )
+             .decode("utf-8")
+             .splitlines()
+         )
+
+     except subprocess.CalledProcessError:
+
+         changes = (
+             subprocess.check_output(["git", "diff", "--name-only", "--", str(path)])
+             .decode("utf-8")
+             .splitlines()
+         )
+
+     try:
+
+         untracked = (
+             subprocess.check_output(
+                 ["git", "ls-files", "--others", "--exclude-standard", str(path)]
+             )
+             .decode("utf-8")
+             .splitlines()
+         )
+
+         changes.extend(untracked)
+
+     except subprocess.CalledProcessError as e:
+
+         print(
+             f"⛔ Error: {e.stderr.decode('utf-8') if e.stderr else 'Failed to get untracked files'}",
+             file=sys.stderr,
+         )
+
+         raise
+
+     logger.info(f"Changes since commit: {changes}")
+
+     return [change for change in changes if change.strip()]
+
+
+ def get_changes_since_fork(path: str, base_branch: str = "main") -> List[str]:
+     try:
+         fork_point = get_fork_point(base_branch)
+         path = Path(path).resolve()
+
+         # Get all potential changes
+         changed_files = set(get_changes_since_commit(str(path), fork_point))
+
+         # Filter out files that are identical to fork point
+         real_changes = []
+         for file in changed_files:
+             fork_content = get_file_content_at_commit(file, fork_point)
+             current_content = get_current_file_content(file)
+
+             if fork_content != current_content:
+                 real_changes.append(file)
+
+         logger.info(f"Changes since fork: {real_changes}")
+
+         return real_changes
+
+     except subprocess.CalledProcessError as e:
+         print(
+             f"⛔ Error: {e.stderr.decode('utf-8') if e.stderr else f'Failed to get changes since fork from {base_branch}'}",
+             file=sys.stderr,
+         )
+         raise
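
The hunk above (156 added lines, apparently ai/chronon/cli/git_utils.py from the file list) exposes small git helpers built on subprocess. A minimal usage sketch, assuming a git checkout with a main branch and a hypothetical group_bys directory:

    from ai.chronon.cli.git_utils import get_current_branch, get_changes_since_fork

    branch = get_current_branch()   # e.g. "feature/my-change"
    # files under group_bys/ whose content actually differs from the merge-base with main
    changed = get_changes_since_fork("group_bys", base_branch="main")
    print(branch, changed)
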
@@ -0,0 +1,61 @@
+ import logging
+ import sys
+ from datetime import datetime
+
+ TIME_COLOR = "\033[36m" # Cyan
+ LEVEL_COLORS = {
+     logging.DEBUG: "\033[36m", # Cyan
+     logging.INFO: "\033[32m", # Green
+     logging.WARNING: "\033[33m", # Yellow
+     logging.ERROR: "\033[31m", # Red
+     logging.CRITICAL: "\033[41m", # White on Red
+ }
+ FILE_COLOR = "\033[35m" # Purple
+ RESET = "\033[0m"
+
+
+ class ColorFormatter(logging.Formatter):
+
+     def format(self, record):
+
+         time_str = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
+         level_color = LEVEL_COLORS.get(record.levelno)
+
+         return (
+             f"{TIME_COLOR}{time_str}{RESET} "
+             f"{level_color}{record.levelname}{RESET} "
+             f"{FILE_COLOR}{record.filename}:{record.lineno}{RESET} - "
+             f"{record.getMessage()}"
+         )
+
+
+ def get_logger(log_level=logging.INFO):
+     logger = logging.getLogger(__name__)
+
+     # no need to reset if a handler already exists
+     if not logger.hasHandlers():
+         handler = logging.StreamHandler(sys.stdout)
+         handler.setFormatter(ColorFormatter())
+
+         logger.addHandler(handler)
+         logger.setLevel(log_level)
+
+     return logger
+
+
+ def red(text):
+     return f"\033[1;91m{text}\033[0m"
+
+
+ def green(text):
+     return f"\033[1;92m{text}\033[0m"
+
+
+ def require(cond, message):
+     if not cond:
+         print(f"X: {message}")
+         sys.exit(1)
+
+
+ def done(cond, message):
+     print(f"DONE: {message}")
@@ -0,0 +1,3 @@
+ AIRFLOW_DEPENDENCIES_KEY = "airflowDependencies"
+ AIRFLOW_LABEL_DEPENDENCIES_KEY = "airflowLabelDependencies"
+ PARTITION_COLUMN_KEY = "spark.chronon.partition.column"
@@ -0,0 +1,122 @@
+ from typing import Any, List
+
+ from pyspark.sql import DataFrame, SparkSession
+
+ import ai.chronon.api.ttypes as chronon
+ from ai.chronon.eval.query_parsing import get_tables_from_query
+ from ai.chronon.eval.sample_tables import sample_tables, sample_with_query
+ from ai.chronon.eval.table_scan import (
+     TableScan,
+     clean_table_name,
+     table_scans_in_group_by,
+     table_scans_in_join,
+     table_scans_in_source,
+ )
+
+
+ def eval(obj: Any) -> List[DataFrame]:
+
+     if isinstance(obj, chronon.Source):
+         return _run_table_scans(table_scans_in_source(obj))
+
+     elif isinstance(obj, chronon.GroupBy):
+         return _run_table_scans(table_scans_in_group_by(obj))
+
+     elif isinstance(obj, chronon.Join):
+         return _run_table_scans(table_scans_in_join(obj))
+
+     elif isinstance(obj, chronon.StagingQuery):
+         return _sample_and_eval_query(_render_staging_query(obj))
+
+     elif isinstance(obj, str):
+         has_white_spaces = any(char.isspace() for char in obj)
+         if has_white_spaces:
+             return _sample_and_eval_query(obj)
+         else:
+             return _sample_and_eval_query(f"SELECT * FROM {obj} LIMIT 1000")
+
+     elif isinstance(obj, chronon.Model):
+         return _run_table_scans(table_scans_in_source(obj.source))
+
+     else:
+         raise Exception(f"Unsupported object type for: {obj}")
+
+
+ def _sample_and_eval_query(query: str) -> DataFrame:
+
+     table_names = get_tables_from_query(query)
+     sample_tables(table_names)
+
+     clean_query = query
+     for table_name in table_names:
+         clean_name = clean_table_name(table_name)
+         clean_query = clean_query.replace(table_name, clean_name)
+
+     return _run_query(clean_query)
+
+
+ def _run_query(query: str) -> DataFrame:
+     spark = _get_spark()
+     return spark.sql(query)
+
+
+ def _sample_table_scan(table_scan: TableScan) -> str:
+     table = table_scan.table
+     output_path = table_scan.output_path()
+     query = table_scan.raw_scan_query(local_table_view=False)
+     return sample_with_query(table, query, output_path)
+
+
+ def _run_table_scans(table_scans: List[TableScan]) -> List[DataFrame]:
+     spark = _get_spark()
+     df_list = []
+
+     for table_scan in table_scans:
+         output_path = table_scan.output_path()
+
+         status = " (exists)" if output_path.exists() else ""
+         print(
+             f"table: {table_scan.table}\n"
+             f"view: {table_scan.view_name()}\n"
+             f"local_file: {output_path}{status}\n"
+         )
+
+     for table_scan in table_scans:
+
+         view_name = table_scan.view_name()
+         output_path = _sample_table_scan(table_scan)
+
+         print(f"Creating view {view_name} from parquet file {output_path}")
+         df = spark.read.parquet(str(output_path))
+         df.createOrReplaceTempView(view_name)
+
+         scan_query = table_scan.scan_query(local_table_view=True)
+         print(f"Scanning {table_scan.table} with query: \n{scan_query}\n")
+         df = spark.sql(scan_query)
+         df.show(5)
+         df_list.append(df)
+
+     return df_list
+
+
+ _spark: SparkSession = None
+
+
+ def _get_spark() -> SparkSession:
+     global _spark
+     if not _spark:
+         _spark = (
+             SparkSession.builder.appName("Chronon Evaluator")
+             .config("spark.driver.bindAddress", "127.0.0.1")
+             .config("spark.driver.host", "127.0.0.1")
+             .config("spark.sql.parquet.columnarReaderBatchSize", "16")
+             .config("spark.executor.memory", "4g")
+             .config("spark.driver.memory", "4g")
+             .config("spark.driver.maxResultSize", "2g")
+             .getOrCreate()
+         )
+     return _spark
+
+
+ def _render_staging_query(staging_query: chronon.StagingQuery) -> str:
+     raise NotImplementedError("Staging query evals are not yet implemented")
@@ -0,0 +1,19 @@
+ from typing import List
+
+
+ def get_tables_from_query(sql_query) -> List[str]:
+     import sqlglot
+
+     # Parse the query
+     parsed = sqlglot.parse_one(sql_query, dialect="bigquery")
+
+     # Extract all table references
+     tables = parsed.find_all(sqlglot.exp.Table)
+
+     table_names = []
+     for table in tables:
+         name_parts = [part for part in [table.catalog, table.db, table.name] if part]
+         table_name = ".".join(name_parts)
+         table_names.append(table_name)
+
+     return table_names
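
For illustration, the parser above resolves fully qualified BigQuery identifiers into catalog.db.name strings; with a hypothetical query:

    get_tables_from_query(
        "SELECT e.id FROM proj.ds.events e JOIN proj.ds.users u ON e.user_id = u.id"
    )
    # -> ["proj.ds.events", "proj.ds.users"]  (order may vary)
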
@@ -0,0 +1,100 @@
+ import os
+ from typing import List
+
+ from ai.chronon.eval.table_scan import local_warehouse
+
+
+ def sample_with_query(table, query, output_path) -> str:
+     # if file exists, skip
+     if os.path.exists(output_path):
+         print(f"File {output_path} already exists. Skipping sampling.")
+         return output_path
+
+     raw_scan_query = query
+     print(f"Sampling {table} with query: {raw_scan_query}")
+
+     _sample_internal(raw_scan_query, output_path)
+     return output_path
+
+
+ def sample_tables(table_names: List[str]) -> None:
+
+     for table in table_names:
+         query = f"SELECT * FROM {table} LIMIT 10000"
+         sample_with_query(table, query, local_warehouse / f"{table}.parquet")
+
+
+ _sampling_engine = os.getenv("CHRONON_SAMPLING_ENGINE", "bigquery")
+
+
+ def _sample_internal(query, output_path) -> str:
+     if _sampling_engine == "bigquery":
+         _sample_bigquery(query, output_path)
+     elif _sampling_engine == "trino":
+         _sample_trino(query, output_path)
+     else:
+         raise ValueError("Invalid sampling engine")
+
+
+ def _sample_trino(query, output_path):
+     raise NotImplementedError("Trino sampling is not yet implemented")
+
+
+ def _sample_bigquery(query, output_path):
+
+     from google.cloud import bigquery
+
+     project_id = os.getenv("GCP_PROJECT_ID")
+     assert project_id, "Please set the GCP_PROJECT_ID environment variable"
+
+     client = bigquery.Client(project=project_id)
+
+     results = client.query_and_wait(query)
+
+     df = results.to_dataframe()
+     df.to_parquet(output_path)
+
+
+ def _sample_bigquery_fast(query, destination_path):
+     import os
+
+     import pyarrow.parquet as pq
+     from google.cloud import bigquery
+     from google.cloud.bigquery_storage import BigQueryReadClient
+     from google.cloud.bigquery_storage_v1.types import DataFormat, ReadSession
+
+     project_id = os.getenv("GCP_PROJECT_ID")
+     assert project_id, "Please set the GCP_PROJECT_ID environment variable"
+
+     client = bigquery.Client(project=project_id)
+     bqstorage_client = BigQueryReadClient()
+
+     # Create query job
+     query_job = client.query(query)
+     table_ref = query_job.destination
+
+     # Create read session
+     read_session = ReadSession()
+     read_session.table = table_ref.to_bqstorage()
+     read_session.data_format = DataFormat.ARROW
+
+     print("Fetching from BigQuery... (this might take a while)")
+
+     session = bqstorage_client.create_read_session(
+         parent=f"projects/{client.project}",
+         read_session=read_session,
+         max_stream_count=1,
+     )
+
+     print("Writing to local parquet file...")
+
+     # Read using Arrow
+     stream = bqstorage_client.read_rows(session.streams[0].name)
+     table = stream.to_arrow(read_session=session)
+
+     # Write to Parquet directly
+     pq.write_table(table, destination_path)
+
+     print(f"Wrote results to {destination_path}")
+
+     return destination_path
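
Sampling goes through BigQuery by default (CHRONON_SAMPLING_ENGINE defaults to "bigquery"; "trino" raises NotImplementedError) and caches each result as a parquet file under local_warehouse. A sketch of the required setup, with a made-up project and table:

    import os

    # the engine is read at module import time, so set these before importing the module
    os.environ["GCP_PROJECT_ID"] = "my-gcp-project"
    os.environ["CHRONON_SAMPLING_ENGINE"] = "bigquery"

    from ai.chronon.eval.sample_tables import sample_tables

    # pulls up to 10,000 rows per table and writes local_warehouse/<table>.parquet
    sample_tables(["myproject.my_dataset.events"])
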
@@ -0,0 +1,186 @@
+ import hashlib
+ import os
+ import re
+ from dataclasses import dataclass
+ from datetime import datetime, timedelta
+ from pathlib import Path
+ from typing import List, Tuple
+
+ import ai.chronon.api.ttypes as chronon
+
+
+ def clean_table_name(name: str) -> str:
+     return re.sub(r"[^a-zA-Z0-9_]", "_", name)
+
+
+ local_warehouse = Path(os.getenv("CHRONON_ROOT", os.getcwd())) / "local_warehouse"
+ limit = int(os.getenv("SAMPLE_LIMIT", "100"))
+ # create local_warehouse if it doesn't exist
+ local_warehouse.mkdir(parents=True, exist_ok=True)
+
+
+ @dataclass
+ class TableScan:
+     table: str
+     partition_col: str
+     partition_date: str
+     query: chronon.Query
+     is_mutations: bool = False
+
+     def output_path(self) -> Path:
+         return Path(local_warehouse) / f"{self.view_name()}.parquet"
+
+     def view_name(self) -> str:
+         return clean_table_name(self.table) + "_" + self.where_id()
+
+     def table_name(self, local_table_view) -> str:
+         return self.view_name() if local_table_view else self.table
+
+     def where_id(self) -> str:
+         return "_" + hashlib.md5(self.where_block().encode()).hexdigest()[:3]
+
+     def where_block(self) -> str:
+         wheres = []
+         partition_scan = f"{self.partition_col} = '{self.partition_date}'"
+         wheres.append(partition_scan)
+
+         if self.query.wheres:
+             wheres.extend(self.query.wheres)
+
+         return " AND\n ".join([f"({where})" for where in wheres])
+
+     def raw_scan_query(self, local_table_view: bool = True) -> str:
+         return f"""
+         SELECT * FROM {self.table_name(local_table_view)}
+         WHERE
+             {self.where_block()}
+         LIMIT {limit}
+         """
+
+     def scan_query(self, local_table_view=True) -> str:
+         selects = []
+         base_selects = self.query.selects.copy()
+
+         if self.is_mutations:
+             base_selects["is_before"] = coalesce(self.query.reversalColumn, "is_before")
+             base_selects["mutation_ts"] = coalesce(
+                 self.query.mutationTimeColumn, "mutation_ts"
+             )
+
+         if self.query.timeColumn:
+             base_selects["ts"] = coalesce(self.query.timeColumn, "ts")
+
+         for k, v in base_selects.items():
+             selects.append(f"{v} as {k}")
+         select_clauses = ",\n ".join(selects)
+
+         return f"""
+         SELECT
+             {select_clauses}
+         FROM
+             {self.table_name(local_table_view)}
+         WHERE
+             {self.where_block()}
+         LIMIT
+             {limit}
+         """
+
+
+ # TODO: use teams.py to get the default date column
+ DEFAULT_DATE_COLUMN = "_date"
+ DEFAULT_DATE_FORMAT = "%Y-%m-%d"
+
+ two_days_ago = (datetime.now() - timedelta(days=2)).strftime(DEFAULT_DATE_FORMAT)
+
+ _sample_date = os.getenv("SAMPLE_DATE", two_days_ago)
+
+
+ def get_date(query: chronon.Query) -> Tuple[str, str]:
+     assert query and query.selects, "please specify source.query.selects"
+
+     partition_col = query.selects.get("ds", DEFAULT_DATE_COLUMN)
+     partition_date = coalesce(query.endPartition, _sample_date)
+
+     return (partition_col, partition_date)
+
+
+ def coalesce(*args):
+     for arg in args:
+         if arg:
+             return arg
+
+
+ def table_scans_in_source(source: chronon.Source) -> List[TableScan]:
+     result = []
+
+     if not source:
+         return result
+
+     if source.entities:
+         query: chronon.Query = source.entities.query
+         col, date = get_date(query)
+
+         snapshot = TableScan(source.entities.snapshotTable, col, date, query)
+         result.append(snapshot)
+
+         if source.entities.mutationTable:
+             mutations = TableScan(source.entities.mutationTable, col, date, query, True)
+             result.append(mutations)
+
+     if source.events:
+         query = source.events.query
+         col, date = get_date(query)
+         table = TableScan(source.events.table, col, date, query)
+         result.append(table)
+
+     if source.joinSource:
+         result.extend(table_scans_in_source(source.joinSource.join.left))
+
+     return result
+
+
+ def table_scans_in_sources(sources: List[chronon.Source]) -> List[TableScan]:
+     result = []
+
+     for source in sources:
+         result.extend(table_scans_in_source(source))
+
+     return result
+
+
+ def table_scans_in_group_by(gb: chronon.GroupBy) -> List[TableScan]:
+     if not gb:
+         return []
+
+     return table_scans_in_sources(gb.sources)
+
+
+ def table_scans_in_join(join: chronon.Join) -> List[TableScan]:
+
+     result = []
+
+     if not join:
+         return result
+
+     result.extend(table_scans_in_source(join.left))
+
+     parts: List[chronon.JoinPart] = join.joinParts
+     if parts:
+         for part in parts:
+             result.extend(table_scans_in_group_by(part.groupBy))
+
+     bootstraps: List[chronon.BootstrapPart] = join.bootstrapParts
+     if bootstraps:
+         for bootstrap in bootstraps:
+             query = bootstrap.query
+             col, date = get_date(query)
+             bootstrap = TableScan(bootstrap.table, col, date, query)
+
+             result.append(bootstrap)
+
+     if join.labelParts:
+         labelParts: List[chronon.JoinPart] = join.labelParts.labels
+         for part in labelParts:
+             result.extend(table_scans_in_group_by(part.groupBy))
+
+     return result
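
A TableScan pairs a physical table with one partition of its Query and renders the bounded SQL used for sampling. A minimal sketch, assuming the thrift Query struct accepts these fields as keyword arguments and using made-up table and column names:

    import ai.chronon.api.ttypes as chronon
    from ai.chronon.eval.table_scan import TableScan

    q = chronon.Query(
        selects={"user_id": "user_id", "amount": "CAST(price AS DOUBLE)"},
        timeColumn="ts",
    )
    scan = TableScan(
        table="proj.ds.events",
        partition_col="_date",
        partition_date="2025-01-01",
        query=q,
    )
    # renders SELECT user_id as user_id, ... FROM proj.ds.events
    # WHERE (_date = '2025-01-01') LIMIT 100 (SAMPLE_LIMIT defaults to 100)
    print(scan.scan_query(local_table_view=False))
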
@@ -0,0 +1 @@
+ __all__ = ['ttypes', 'constants']
@@ -0,0 +1,15 @@
+ #
+ # Autogenerated by Thrift Compiler (0.22.0)
+ #
+ # DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
+ #
+ # options string: py
+ #
+
+ from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException
+ from thrift.protocol.TProtocol import TProtocolException
+ from thrift.TRecursive import fix_spec
+ from uuid import UUID
+
+ import sys
+ from .ttypes import *