awx-zipline-ai 0.2.0 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent/__init__.py +1 -0
- agent/constants.py +15 -0
- agent/ttypes.py +1684 -0
- ai/__init__.py +0 -0
- ai/chronon/__init__.py +0 -0
- ai/chronon/airflow_helpers.py +251 -0
- ai/chronon/api/__init__.py +1 -0
- ai/chronon/api/common/__init__.py +1 -0
- ai/chronon/api/common/constants.py +15 -0
- ai/chronon/api/common/ttypes.py +1844 -0
- ai/chronon/api/constants.py +15 -0
- ai/chronon/api/ttypes.py +3624 -0
- ai/chronon/cli/compile/column_hashing.py +313 -0
- ai/chronon/cli/compile/compile_context.py +177 -0
- ai/chronon/cli/compile/compiler.py +160 -0
- ai/chronon/cli/compile/conf_validator.py +590 -0
- ai/chronon/cli/compile/display/class_tracker.py +112 -0
- ai/chronon/cli/compile/display/compile_status.py +95 -0
- ai/chronon/cli/compile/display/compiled_obj.py +12 -0
- ai/chronon/cli/compile/display/console.py +3 -0
- ai/chronon/cli/compile/display/diff_result.py +46 -0
- ai/chronon/cli/compile/fill_templates.py +40 -0
- ai/chronon/cli/compile/parse_configs.py +141 -0
- ai/chronon/cli/compile/parse_teams.py +238 -0
- ai/chronon/cli/compile/serializer.py +115 -0
- ai/chronon/cli/git_utils.py +156 -0
- ai/chronon/cli/logger.py +61 -0
- ai/chronon/constants.py +3 -0
- ai/chronon/eval/__init__.py +122 -0
- ai/chronon/eval/query_parsing.py +19 -0
- ai/chronon/eval/sample_tables.py +100 -0
- ai/chronon/eval/table_scan.py +186 -0
- ai/chronon/fetcher/__init__.py +1 -0
- ai/chronon/fetcher/constants.py +15 -0
- ai/chronon/fetcher/ttypes.py +127 -0
- ai/chronon/group_by.py +692 -0
- ai/chronon/hub/__init__.py +1 -0
- ai/chronon/hub/constants.py +15 -0
- ai/chronon/hub/ttypes.py +1228 -0
- ai/chronon/join.py +566 -0
- ai/chronon/logger.py +24 -0
- ai/chronon/model.py +35 -0
- ai/chronon/observability/__init__.py +1 -0
- ai/chronon/observability/constants.py +15 -0
- ai/chronon/observability/ttypes.py +2192 -0
- ai/chronon/orchestration/__init__.py +1 -0
- ai/chronon/orchestration/constants.py +15 -0
- ai/chronon/orchestration/ttypes.py +4406 -0
- ai/chronon/planner/__init__.py +1 -0
- ai/chronon/planner/constants.py +15 -0
- ai/chronon/planner/ttypes.py +1686 -0
- ai/chronon/query.py +126 -0
- ai/chronon/repo/__init__.py +40 -0
- ai/chronon/repo/aws.py +298 -0
- ai/chronon/repo/cluster.py +65 -0
- ai/chronon/repo/compile.py +56 -0
- ai/chronon/repo/constants.py +164 -0
- ai/chronon/repo/default_runner.py +291 -0
- ai/chronon/repo/explore.py +421 -0
- ai/chronon/repo/extract_objects.py +137 -0
- ai/chronon/repo/gcp.py +585 -0
- ai/chronon/repo/gitpython_utils.py +14 -0
- ai/chronon/repo/hub_runner.py +171 -0
- ai/chronon/repo/hub_uploader.py +108 -0
- ai/chronon/repo/init.py +53 -0
- ai/chronon/repo/join_backfill.py +105 -0
- ai/chronon/repo/run.py +293 -0
- ai/chronon/repo/serializer.py +141 -0
- ai/chronon/repo/team_json_utils.py +46 -0
- ai/chronon/repo/utils.py +472 -0
- ai/chronon/repo/zipline.py +51 -0
- ai/chronon/repo/zipline_hub.py +105 -0
- ai/chronon/resources/gcp/README.md +174 -0
- ai/chronon/resources/gcp/group_bys/test/__init__.py +0 -0
- ai/chronon/resources/gcp/group_bys/test/data.py +34 -0
- ai/chronon/resources/gcp/joins/test/__init__.py +0 -0
- ai/chronon/resources/gcp/joins/test/data.py +30 -0
- ai/chronon/resources/gcp/sources/test/__init__.py +0 -0
- ai/chronon/resources/gcp/sources/test/data.py +23 -0
- ai/chronon/resources/gcp/teams.py +70 -0
- ai/chronon/resources/gcp/zipline-cli-install.sh +54 -0
- ai/chronon/source.py +88 -0
- ai/chronon/staging_query.py +185 -0
- ai/chronon/types.py +57 -0
- ai/chronon/utils.py +557 -0
- ai/chronon/windows.py +50 -0
- awx_zipline_ai-0.2.0.dist-info/METADATA +173 -0
- awx_zipline_ai-0.2.0.dist-info/RECORD +93 -0
- awx_zipline_ai-0.2.0.dist-info/WHEEL +5 -0
- awx_zipline_ai-0.2.0.dist-info/entry_points.txt +2 -0
- awx_zipline_ai-0.2.0.dist-info/licenses/LICENSE +202 -0
- awx_zipline_ai-0.2.0.dist-info/top_level.txt +3 -0
- jars/__init__.py +0 -0
ai/chronon/cli/git_utils.py
ADDED
@@ -0,0 +1,156 @@
+import subprocess
+import sys
+from pathlib import Path
+from typing import List, Optional
+
+from ai.chronon.cli.logger import get_logger
+
+logger = get_logger()
+
+
+def get_current_branch() -> str:
+
+    try:
+        subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL)
+
+        return (
+            subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+            .decode("utf-8")
+            .strip()
+        )
+
+    except subprocess.CalledProcessError as e:
+
+        try:
+            head_file = Path(".git/HEAD").resolve()
+
+            if head_file.exists():
+                content = head_file.read_text().strip()
+
+                if content.startswith("ref: refs/heads/"):
+                    return content.split("/")[-1]
+
+        except Exception:
+            pass
+
+        print(
+            f"⛔ Error: {e.stderr.decode('utf-8') if e.stderr else 'Not a git repository or no commits'}",
+            file=sys.stderr,
+        )
+
+        raise
+
+
+def get_fork_point(base_branch: str = "main") -> str:
+    try:
+
+        return (
+            subprocess.check_output(["git", "merge-base", base_branch, "HEAD"])
+            .decode("utf-8")
+            .strip()
+        )
+
+    except subprocess.CalledProcessError as e:
+        print(
+            f"⛔ Error: {e.stderr.decode('utf-8') if e.stderr else f'Could not determine fork point from {base_branch}'}",
+            file=sys.stderr,
+        )
+        raise
+
+
+def get_file_content_at_commit(file_path: str, commit: str) -> Optional[str]:
+    try:
+        return subprocess.check_output(["git", "show", f"{commit}:{file_path}"]).decode(
+            "utf-8"
+        )
+    except subprocess.CalledProcessError:
+        return None
+
+
+def get_current_file_content(file_path: str) -> Optional[str]:
+    try:
+        return Path(file_path).read_text()
+    except Exception:
+        return None
+
+
+def get_changes_since_commit(path: str, commit: Optional[str] = None) -> List[str]:
+
+    path = Path(path).resolve()
+    if not path.exists():
+        print(f"⛔ Error: Path does not exist: {path}", file=sys.stderr)
+        raise ValueError(f"Path does not exist: {path}")
+
+    try:
+        subprocess.check_output(["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL)
+        commit_range = f"{commit}..HEAD" if commit else "HEAD"
+
+        changes = (
+            subprocess.check_output(
+                ["git", "diff", "--name-only", commit_range, "--", str(path)]
+            )
+            .decode("utf-8")
+            .splitlines()
+        )
+
+    except subprocess.CalledProcessError:
+
+        changes = (
+            subprocess.check_output(["git", "diff", "--name-only", "--", str(path)])
+            .decode("utf-8")
+            .splitlines()
+        )
+
+    try:
+
+        untracked = (
+            subprocess.check_output(
+                ["git", "ls-files", "--others", "--exclude-standard", str(path)]
+            )
+            .decode("utf-8")
+            .splitlines()
+        )
+
+        changes.extend(untracked)
+
+    except subprocess.CalledProcessError as e:
+
+        print(
+            f"⛔ Error: {e.stderr.decode('utf-8') if e.stderr else 'Failed to get untracked files'}",
+            file=sys.stderr,
+        )
+
+        raise
+
+    logger.info(f"Changes since commit: {changes}")
+
+    return [change for change in changes if change.strip()]
+
+
+def get_changes_since_fork(path: str, base_branch: str = "main") -> List[str]:
+    try:
+        fork_point = get_fork_point(base_branch)
+        path = Path(path).resolve()
+
+        # Get all potential changes
+        changed_files = set(get_changes_since_commit(str(path), fork_point))
+
+        # Filter out files that are identical to fork point
+        real_changes = []
+        for file in changed_files:
+            fork_content = get_file_content_at_commit(file, fork_point)
+            current_content = get_current_file_content(file)
+
+            if fork_content != current_content:
+                real_changes.append(file)
+
+        logger.info(f"Changes since fork: {real_changes}")
+
+        return real_changes
+
+    except subprocess.CalledProcessError as e:
+        print(
+            f"⛔ Error: {e.stderr.decode('utf-8') if e.stderr else f'Failed to get changes since fork from {base_branch}'}",
+            file=sys.stderr,
+        )
+        raise
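
For orientation, a minimal usage sketch of the helpers above, assuming it runs from the root of a git checkout that has a "main" branch; the "group_bys" path is illustrative:

# Hypothetical usage of the git_utils helpers; run from inside a git checkout.
from ai.chronon.cli.git_utils import get_changes_since_fork, get_current_branch

branch = get_current_branch()
print(f"On branch: {branch}")

# Files under the (illustrative) "group_bys" directory that differ from the
# merge-base with main, including untracked files.
for changed_file in get_changes_since_fork("group_bys", base_branch="main"):
    print(changed_file)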
ai/chronon/cli/logger.py
ADDED
@@ -0,0 +1,61 @@
+import logging
+import sys
+from datetime import datetime
+
+TIME_COLOR = "\033[36m" # Cyan
+LEVEL_COLORS = {
+    logging.DEBUG: "\033[36m", # Cyan
+    logging.INFO: "\033[32m", # Green
+    logging.WARNING: "\033[33m", # Yellow
+    logging.ERROR: "\033[31m", # Red
+    logging.CRITICAL: "\033[41m", # White on Red
+}
+FILE_COLOR = "\033[35m" # Purple
+RESET = "\033[0m"
+
+
+class ColorFormatter(logging.Formatter):
+
+    def format(self, record):
+
+        time_str = datetime.fromtimestamp(record.created).strftime("%H:%M:%S")
+        level_color = LEVEL_COLORS.get(record.levelno)
+
+        return (
+            f"{TIME_COLOR}{time_str}{RESET} "
+            f"{level_color}{record.levelname}{RESET} "
+            f"{FILE_COLOR}{record.filename}:{record.lineno}{RESET} - "
+            f"{record.getMessage()}"
+        )
+
+
+def get_logger(log_level=logging.INFO):
+    logger = logging.getLogger(__name__)
+
+    # no need to reset if a handler already exists
+    if not logger.hasHandlers():
+        handler = logging.StreamHandler(sys.stdout)
+        handler.setFormatter(ColorFormatter())
+
+        logger.addHandler(handler)
+        logger.setLevel(log_level)
+
+    return logger
+
+
+def red(text):
+    return f"\033[1;91m{text}\033[0m"
+
+
+def green(text):
+    return f"\033[1;92m{text}\033[0m"
+
+
+def require(cond, message):
+    if not cond:
+        print(f"X: {message}")
+        sys.exit(1)
+
+
+def done(cond, message):
+    print(f"DONE: {message}")
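
A minimal sketch of how this logger might be used; the messages are invented and the exact rendering depends on the terminal's ANSI color support:

# Each line is rendered by ColorFormatter as "<time> <LEVEL> <file>:<line> - <message>".
import logging

from ai.chronon.cli.logger import get_logger, green, red, require

logger = get_logger(logging.DEBUG)
logger.info(green("compilation succeeded"))
logger.warning(red("falling back to defaults"))

# require() is a no-op when the condition holds; otherwise it prints the
# message and exits the process with status 1.
require(True, "this condition holds, so execution continues")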
ai/chronon/eval/__init__.py
ADDED
@@ -0,0 +1,122 @@
+from typing import Any, List
+
+from pyspark.sql import DataFrame, SparkSession
+
+import ai.chronon.api.ttypes as chronon
+from ai.chronon.eval.query_parsing import get_tables_from_query
+from ai.chronon.eval.sample_tables import sample_tables, sample_with_query
+from ai.chronon.eval.table_scan import (
+    TableScan,
+    clean_table_name,
+    table_scans_in_group_by,
+    table_scans_in_join,
+    table_scans_in_source,
+)
+
+
+def eval(obj: Any) -> List[DataFrame]:
+
+    if isinstance(obj, chronon.Source):
+        return _run_table_scans(table_scans_in_source(obj))
+
+    elif isinstance(obj, chronon.GroupBy):
+        return _run_table_scans(table_scans_in_group_by(obj))
+
+    elif isinstance(obj, chronon.Join):
+        return _run_table_scans(table_scans_in_join(obj))
+
+    elif isinstance(obj, chronon.StagingQuery):
+        return _sample_and_eval_query(_render_staging_query(obj))
+
+    elif isinstance(obj, str):
+        has_white_spaces = any(char.isspace() for char in obj)
+        if has_white_spaces:
+            return _sample_and_eval_query(obj)
+        else:
+            return _sample_and_eval_query(f"SELECT * FROM {obj} LIMIT 1000")
+
+    elif isinstance(obj, chronon.Model):
+        _run_table_scans(table_scans_in_source(obj.source))
+
+    else:
+        raise Exception(f"Unsupported object type for: {obj}")
+
+
+def _sample_and_eval_query(query: str) -> DataFrame:
+
+    table_names = get_tables_from_query(query)
+    sample_tables(table_names)
+
+    clean_query = query
+    for table_name in table_names:
+        clean_name = clean_table_name(table_name)
+        clean_query = clean_query.replace(table_name, clean_name)
+
+    return _run_query(clean_query)
+
+
+def _run_query(query: str) -> DataFrame:
+    spark = _get_spark()
+    return spark.sql(query)
+
+
+def _sample_table_scan(table_scan: TableScan) -> str:
+    table = table_scan.table
+    output_path = table_scan.output_path()
+    query = table_scan.raw_scan_query(local_table_view=False)
+    return sample_with_query(table, query, output_path)
+
+
+def _run_table_scans(table_scans: List[TableScan]) -> List[DataFrame]:
+    spark = _get_spark()
+    df_list = []
+
+    for table_scan in table_scans:
+        output_path = table_scan.output_path()
+
+        status = " (exists)" if output_path.exists() else ""
+        print(
+            f"table: {table_scan.table}\n"
+            f"view: {table_scan.view_name()}\n"
+            f"local_file: {output_path}{status}\n"
+        )
+
+    for table_scan in table_scans:
+
+        view_name = table_scan.view_name()
+        output_path = _sample_table_scan(table_scan)
+
+        print(f"Creating view {view_name} from parquet file {output_path}")
+        df = spark.read.parquet(str(output_path))
+        df.createOrReplaceTempView(view_name)
+
+        scan_query = table_scan.scan_query(local_table_view=True)
+        print(f"Scanning {table_scan.table} with query: \n{scan_query}\n")
+        df = spark.sql(scan_query)
+        df.show(5)
+        df_list.append(df)
+
+    return df_list
+
+
+_spark: SparkSession = None
+
+
+def _get_spark() -> SparkSession:
+    global _spark
+    if not _spark:
+        _spark = (
+            SparkSession.builder.appName("Chronon Evaluator")
+            .config("spark.driver.bindAddress", "127.0.0.1")
+            .config("spark.driver.host", "127.0.0.1")
+            .config("spark.sql.parquet.columnarReaderBatchSize", "16")
+            .config("spark.executor.memory", "4g")
+            .config("spark.driver.memory", "4g")
+            .config("spark.driver.maxResultSize", "2g")
+            .getOrCreate()
+        )
+    return _spark
+
+
+def _render_staging_query(staging_query: chronon.StagingQuery) -> str:
+    raise NotImplementedError("Staging query evals are not yet implemented")
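
A hedged usage sketch of the evaluator above; the project and table names are placeholders, and it assumes a local pyspark installation plus BigQuery credentials, since sampling defaults to the BigQuery engine:

import os

import ai.chronon.eval as chronon_eval

# Assumption: sampling goes through BigQuery, which needs a project id.
os.environ.setdefault("GCP_PROJECT_ID", "my-project")

# A bare table name is expanded to "SELECT * FROM <table> LIMIT 1000";
# the referenced tables are sampled into local parquet under local_warehouse,
# table names are rewritten to their sanitized local forms, and the query is
# then evaluated with a local Spark session.
chronon_eval.eval("my_dataset.user_events")

# A SQL string is handled the same way, with every referenced table sampled first.
chronon_eval.eval("SELECT user_id, COUNT(*) AS n FROM my_dataset.user_events GROUP BY user_id")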
ai/chronon/eval/query_parsing.py
ADDED
@@ -0,0 +1,19 @@
+from typing import List
+
+
+def get_tables_from_query(sql_query) -> List[str]:
+    import sqlglot
+
+    # Parse the query
+    parsed = sqlglot.parse_one(sql_query, dialect="bigquery")
+
+    # Extract all table references
+    tables = parsed.find_all(sqlglot.exp.Table)
+
+    table_names = []
+    for table in tables:
+        name_parts = [part for part in [table.catalog, table.db, table.name] if part]
+        table_name = ".".join(name_parts)
+        table_names.append(table_name)
+
+    return table_names
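
A small self-contained example of the extraction above, assuming sqlglot is installed; the project, dataset, and table names are made up:

from ai.chronon.eval.query_parsing import get_tables_from_query

sql = """
SELECT a.user_id, b.country
FROM my_project.analytics.events AS a
JOIN my_project.analytics.users AS b
  ON a.user_id = b.user_id
"""

print(get_tables_from_query(sql))
# Expected to contain both fully qualified names, e.g.
# ['my_project.analytics.events', 'my_project.analytics.users'] (order may vary).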
ai/chronon/eval/sample_tables.py
ADDED
@@ -0,0 +1,100 @@
+import os
+from typing import List
+
+from ai.chronon.eval.table_scan import local_warehouse
+
+
+def sample_with_query(table, query, output_path) -> str:
+    # if file exists, skip
+    if os.path.exists(output_path):
+        print(f"File {output_path} already exists. Skipping sampling.")
+        return output_path
+
+    raw_scan_query = query
+    print(f"Sampling {table} with query: {raw_scan_query}")
+
+    _sample_internal(raw_scan_query, output_path)
+    return output_path
+
+
+def sample_tables(table_names: List[str]) -> None:
+
+    for table in table_names:
+        query = f"SELECT * FROM {table} LIMIT 10000"
+        sample_with_query(table, query, local_warehouse / f"{table}.parquet")
+
+
+_sampling_engine = os.getenv("CHRONON_SAMPLING_ENGINE", "bigquery")
+
+
+def _sample_internal(query, output_path) -> str:
+    if _sampling_engine == "bigquery":
+        _sample_bigquery(query, output_path)
+    elif _sampling_engine == "trino":
+        _sample_trino(query, output_path)
+    else:
+        raise ValueError("Invalid sampling engine")
+
+
+def _sample_trino(query, output_path):
+    raise NotImplementedError("Trino sampling is not yet implemented")
+
+
+def _sample_bigquery(query, output_path):
+
+    from google.cloud import bigquery
+
+    project_id = os.getenv("GCP_PROJECT_ID")
+    assert project_id, "Please set the GCP_PROJECT_ID environment variable"
+
+    client = bigquery.Client(project=project_id)
+
+    results = client.query_and_wait(query)
+
+    df = results.to_dataframe()
+    df.to_parquet(output_path)
+
+
+def _sample_bigquery_fast(query, destination_path):
+    import os
+
+    import pyarrow.parquet as pq
+    from google.cloud import bigquery
+    from google.cloud.bigquery_storage import BigQueryReadClient
+    from google.cloud.bigquery_storage_v1.types import DataFormat, ReadSession
+
+    project_id = os.getenv("GCP_PROJECT_ID")
+    assert project_id, "Please set the GCP_PROJECT_ID environment variable"
+
+    client = bigquery.Client(project=project_id)
+    bqstorage_client = BigQueryReadClient()
+
+    # Create query job
+    query_job = client.query(query)
+    table_ref = query_job.destination
+
+    # Create read session
+    read_session = ReadSession()
+    read_session.table = table_ref.to_bqstorage()
+    read_session.data_format = DataFormat.ARROW
+
+    print("Fetching from BigQuery... (this might take a while)")
+
+    session = bqstorage_client.create_read_session(
+        parent=f"projects/{client.project}",
+        read_session=read_session,
+        max_stream_count=1,
+    )
+
+    print("Writing to local parquet file...")
+
+    # Read using Arrow
+    stream = bqstorage_client.read_rows(session.streams[0].name)
+    table = stream.to_arrow(read_session=session)
+
+    # Write to Parquet directly
+    pq.write_table(table, destination_path)
+
+    print(f"Wrote results to {destination_path}")
+
+    return destination_path
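
A sketch of driving the sampler directly, assuming the BigQuery engine; since CHRONON_SAMPLING_ENGINE is read at import time, the environment variables are set before the import. Project and dataset names are placeholders:

import os

os.environ["GCP_PROJECT_ID"] = "my-project"          # required by the BigQuery path
os.environ["CHRONON_SAMPLING_ENGINE"] = "bigquery"   # default engine; "trino" is not implemented yet

from ai.chronon.eval.sample_tables import sample_tables

# Pulls up to 10,000 rows per table into <CHRONON_ROOT>/local_warehouse/<table>.parquet,
# skipping any table whose parquet file already exists.
sample_tables(["my_dataset.user_events", "my_dataset.users"])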
ai/chronon/eval/table_scan.py
ADDED
@@ -0,0 +1,186 @@
+import hashlib
+import os
+import re
+from dataclasses import dataclass
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import List, Tuple
+
+import ai.chronon.api.ttypes as chronon
+
+
+def clean_table_name(name: str) -> str:
+    return re.sub(r"[^a-zA-Z0-9_]", "_", name)
+
+
+local_warehouse = Path(os.getenv("CHRONON_ROOT", os.getcwd())) / "local_warehouse"
+limit = int(os.getenv("SAMPLE_LIMIT", "100"))
+# create local_warehouse if it doesn't exist
+local_warehouse.mkdir(parents=True, exist_ok=True)
+
+
+@dataclass
+class TableScan:
+    table: str
+    partition_col: str
+    partition_date: str
+    query: chronon.Query
+    is_mutations: bool = False
+
+    def output_path(self) -> str:
+        return Path(local_warehouse) / f"{self.view_name()}.parquet"
+
+    def view_name(self) -> str:
+        return clean_table_name(self.table) + "_" + self.where_id()
+
+    def table_name(self, local_table_view) -> str:
+        return self.view_name() if local_table_view else self.table
+
+    def where_id(self) -> str:
+        return "_" + hashlib.md5(self.where_block().encode()).hexdigest()[:3]
+
+    def where_block(self) -> str:
+        wheres = []
+        partition_scan = f"{self.partition_col} = '{self.partition_date}'"
+        wheres.append(partition_scan)
+
+        if self.query.wheres:
+            wheres.extend(self.query.wheres)
+
+        return " AND\n ".join([f"({where})" for where in wheres])
+
+    def raw_scan_query(self, local_table_view: bool = True) -> str:
+        return f"""
+        SELECT * FROM {self.table_name(local_table_view)}
+        WHERE
+            {self.where_block()}
+        LIMIT {limit}
+        """
+
+    def scan_query(self, local_table_view=True) -> str:
+        selects = []
+        base_selects = self.query.selects.copy()
+
+        if self.is_mutations:
+            base_selects["is_before"] = coalesce(self.query.reversalColumn, "is_before")
+            base_selects["mutation_ts"] = coalesce(
+                self.query.mutationTimeColumn, "mutation_ts"
+            )
+
+        if self.query.timeColumn:
+            base_selects["ts"] = coalesce(self.query.timeColumn, "ts")
+
+        for k, v in base_selects.items():
+            selects.append(f"{v} as {k}")
+        select_clauses = ",\n ".join(selects)
+
+        return f"""
+        SELECT
+            {select_clauses}
+        FROM
+            {self.table_name(local_table_view)}
+        WHERE
+            {self.where_block()}
+        LIMIT
+            {limit}
+        """
+
+
+# TODO: use teams.py to get the default date column
+DEFAULT_DATE_COLUMN = "_date"
+DEFAULT_DATE_FORMAT = "%Y-%m-%d"
+
+two_days_ago = (datetime.now() - timedelta(days=2)).strftime(DEFAULT_DATE_FORMAT)
+
+_sample_date = os.getenv("SAMPLE_DATE", two_days_ago)
+
+
+def get_date(query: chronon.Query) -> Tuple[str, str]:
+    assert query and query.selects, "please specify source.query.selects"
+
+    partition_col = query.selects.get("ds", DEFAULT_DATE_COLUMN)
+    partition_date = coalesce(query.endPartition, _sample_date)
+
+    return (partition_col, partition_date)
+
+
+def coalesce(*args):
+    for arg in args:
+        if arg:
+            return arg
+
+
+def table_scans_in_source(source: chronon.Source) -> List[TableScan]:
+    result = []
+
+    if not source:
+        return result
+
+    if source.entities:
+        query: chronon.Query = source.entities.query
+        col, date = get_date(query)
+
+        snapshot = TableScan(source.entities.snapshotTable, col, date, query)
+        result.append(snapshot)
+
+        if source.entities.mutationTable:
+            mutations = TableScan(source.entities.mutationTable, col, date, query, True)
+            result.append(mutations)
+
+    if source.events:
+        query = source.events.query
+        col, date = get_date(query)
+        table = TableScan(source.events.table, col, date, query)
+        result.append(table)
+
+    if source.joinSource:
+        result.extend(table_scans_in_source(source.joinSource.join.left))
+
+    return result
+
+
+def table_scans_in_sources(sources: List[chronon.Source]) -> List[TableScan]:
+    result = []
+
+    for source in sources:
+        result.extend(table_scans_in_source(source))
+
+    return result
+
+
+def table_scans_in_group_by(gb: chronon.GroupBy) -> List[TableScan]:
+    if not gb:
+        return []
+
+    return table_scans_in_sources(gb.sources)
+
+
+def table_scans_in_join(join: chronon.Join) -> List[TableScan]:
+
+    result = []
+
+    if not join:
+        return result
+
+    result.extend(table_scans_in_source(join.left))
+
+    parts: List[chronon.JoinPart] = join.joinParts
+    if parts:
+        for part in parts:
+            result.extend(table_scans_in_group_by(part.groupBy))
+
+    bootstraps: List[chronon.BootstrapPart] = join.bootstrapParts
+    if bootstraps:
+        for bootstrap in bootstraps:
+            query = bootstrap.query
+            col, date = get_date(query)
+            bootstrap = TableScan(bootstrap.table, col, date, query)
+
+            result.append(bootstrap)
+
+    if join.labelParts:
+        labelParts: List[chronon.JoinPart] = join.labelParts.labels
+        for part in labelParts:
+            result.extend(table_scans_in_sources(part.groupBy))
+
+    return result
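
A minimal sketch of constructing a TableScan by hand; the table, columns, and dates are invented for illustration, and the Query fields used are only those the module itself reads (selects, wheres, timeColumn):

import ai.chronon.api.ttypes as chronon
from ai.chronon.eval.table_scan import TableScan

query = chronon.Query(
    selects={"user_id": "user_id", "amount": "purchase_amount"},
    wheres=["purchase_amount > 0"],
    timeColumn="event_ts",
)

scan = TableScan(
    table="my_dataset.purchases",
    partition_col="ds",
    partition_date="2025-01-01",
    query=query,
)

# Sanitized table name plus a short hash derived from the WHERE block.
print(scan.view_name())

# SELECT with renamed columns FROM the local view name, filtered by the
# partition date and the custom wheres, capped at SAMPLE_LIMIT rows.
print(scan.scan_query(local_table_view=True))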
@@ -0,0 +1 @@
+__all__ = ['ttypes', 'constants']
@@ -0,0 +1,15 @@
+#
+# Autogenerated by Thrift Compiler (0.22.0)
+#
+# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
+#
+# options string: py
+#
+
+from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, TApplicationException
+from thrift.protocol.TProtocol import TProtocolException
+from thrift.TRecursive import fix_spec
+from uuid import UUID
+
+import sys
+from .ttypes import *