squirrels 0.4.1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- squirrels/__init__.py +10 -6
- squirrels/_api_response_models.py +93 -44
- squirrels/_api_server.py +571 -219
- squirrels/_auth.py +451 -0
- squirrels/_command_line.py +61 -20
- squirrels/_connection_set.py +38 -25
- squirrels/_constants.py +44 -34
- squirrels/_dashboards_io.py +34 -16
- squirrels/_exceptions.py +57 -0
- squirrels/_initializer.py +117 -44
- squirrels/_manifest.py +124 -62
- squirrels/_model_builder.py +111 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +860 -354
- squirrels/_package_loader.py +8 -4
- squirrels/_parameter_configs.py +45 -65
- squirrels/_parameter_sets.py +15 -13
- squirrels/_project.py +561 -0
- squirrels/_py_module.py +4 -3
- squirrels/_seeds.py +35 -16
- squirrels/_sources.py +106 -0
- squirrels/_utils.py +166 -63
- squirrels/_version.py +1 -1
- squirrels/arguments/init_time_args.py +78 -15
- squirrels/arguments/run_time_args.py +62 -101
- squirrels/dashboards.py +4 -4
- squirrels/data_sources.py +94 -162
- squirrels/dataset_result.py +86 -0
- squirrels/dateutils.py +4 -4
- squirrels/package_data/base_project/.env +30 -0
- squirrels/package_data/base_project/.env.example +30 -0
- squirrels/package_data/base_project/.gitignore +3 -2
- squirrels/package_data/base_project/assets/expenses.db +0 -0
- squirrels/package_data/base_project/connections.yml +11 -3
- squirrels/package_data/base_project/dashboards/dashboard_example.py +15 -13
- squirrels/package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/package_data/base_project/docker/.dockerignore +5 -2
- squirrels/package_data/base_project/docker/Dockerfile +3 -3
- squirrels/package_data/base_project/docker/compose.yml +1 -1
- squirrels/package_data/base_project/duckdb_init.sql +9 -0
- squirrels/package_data/base_project/macros/macros_example.sql +15 -0
- squirrels/package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/package_data/base_project/models/builds/build_example.yml +55 -0
- squirrels/package_data/base_project/models/dbviews/dbview_example.sql +12 -22
- squirrels/package_data/base_project/models/dbviews/dbview_example.yml +26 -0
- squirrels/package_data/base_project/models/federates/federate_example.py +38 -15
- squirrels/package_data/base_project/models/federates/federate_example.sql +16 -2
- squirrels/package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/package_data/base_project/models/sources.yml +39 -0
- squirrels/package_data/base_project/parameters.yml +36 -21
- squirrels/package_data/base_project/pyconfigs/connections.py +6 -11
- squirrels/package_data/base_project/pyconfigs/context.py +20 -33
- squirrels/package_data/base_project/pyconfigs/parameters.py +19 -21
- squirrels/package_data/base_project/pyconfigs/user.py +23 -0
- squirrels/package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/package_data/base_project/seeds/seed_subcategories.csv +15 -15
- squirrels/package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/package_data/base_project/squirrels.yml.j2 +17 -40
- squirrels/parameters.py +20 -20
- {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info}/METADATA +31 -32
- squirrels-0.5.0rc0.dist-info/RECORD +70 -0
- {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info}/WHEEL +1 -1
- squirrels-0.5.0rc0.dist-info/entry_points.txt +3 -0
- {squirrels-0.4.1.dist-info → squirrels-0.5.0rc0.dist-info/licenses}/LICENSE +1 -1
- squirrels/_authenticator.py +0 -85
- squirrels/_environcfg.py +0 -84
- squirrels/package_data/assets/favicon.ico +0 -0
- squirrels/package_data/assets/index.css +0 -1
- squirrels/package_data/assets/index.js +0 -58
- squirrels/package_data/base_project/dashboards.yml +0 -10
- squirrels/package_data/base_project/env.yml +0 -29
- squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
- squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
- squirrels/package_data/templates/index.html +0 -18
- squirrels/project.py +0 -378
- squirrels/user_base.py +0 -55
- squirrels-0.4.1.dist-info/RECORD +0 -60
- squirrels-0.4.1.dist-info/entry_points.txt +0 -4
squirrels/_sources.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
from pydantic import BaseModel, Field, model_validator
|
|
3
|
+
import time, sqlglot
|
|
4
|
+
|
|
5
|
+
from . import _utils as u, _constants as c, _model_configs as mc
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Hints used when refreshing a source incrementally rather than reloading it in full.
#   increasing_column: column whose current max is used to select only newer rows
#                      (rows whose value is greater than that max); None disables this.
#   strictly_increasing: see the field description below; not read by the code in this module.
#   selective_overwrite_value: arbitrary value; its use is not visible in this module.
class UpdateHints(BaseModel):
    increasing_column: str | None = None
    strictly_increasing: bool = Field(True, description="Delete the max value of the increasing column, ignored if value is set")
    selective_overwrite_value: Any = None
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Configuration for one named source from the sources file. Connection settings come from
# mc.ConnectionInterface; the ``columns`` list (items with ``name`` and ``type``) comes from mc.ModelConfig.
class Source(mc.ConnectionInterface, mc.ModelConfig):
    table: str | None = Field(default=None)
    load_to_duckdb: bool = Field(default=False, description="Whether to load the data to DuckDB")
    primary_key: list[str] = Field(default_factory=list)
    update_hints: UpdateHints = Field(default_factory=UpdateHints)

    def finalize_table(self, source_name: str):
        """Default ``table`` to the source name when none was configured; returns self for chaining."""
        if self.table is None:
            self.table = source_name
        return self

    def get_table(self) -> str:
        """Return the table name. ``finalize_table`` must have been called first."""
        assert self.table is not None, "Table must be set"
        return self.table

    def get_cols_for_create_table_stmt(self) -> str:
        """
        Build the column-definition part of a CREATE TABLE statement.

        Example: ``"a INTEGER, b VARCHAR, PRIMARY KEY (a)"``. The PRIMARY KEY clause appears only when
        ``primary_key`` is non-empty.
        """
        cols_clause = ", ".join([f"{col.name} {col.type}" for col in self.columns])
        primary_key_clause = f", PRIMARY KEY ({', '.join(self.primary_key)})" if self.primary_key else ""
        return f"{cols_clause}{primary_key_clause}"

    def get_cols_for_insert_stmt(self) -> str:
        """Return the comma-separated column names, in configured order."""
        return ", ".join([col.name for col in self.columns])

    def get_max_incr_col_query(self, source_name: str) -> str:
        """
        Build a query for the current max of the increasing column in table ``source_name``.

        NOTE(review): this queries ``source_name`` rather than ``self.table``; presumably the already-loaded
        copy is named after the source. Confirm against callers.
        """
        return f"SELECT max({self.update_hints.increasing_column}) FROM {source_name}"

    def get_query_for_insert(self, dialect: str, conn_name: str, table_name: str, max_value_of_increasing_col: Any | None, *, full_refresh: bool = True) -> str:
        """
        Build the DuckDB query that selects the rows to load from the external table.

        Arguments:
            dialect: SQL dialect of the external database. For "postgres" and "mysql" the filter is pushed
                down through the ``{dialect}_query`` table function.
            conn_name: Connection name; the external database is referenced as ``db_{conn_name}``
            table_name: Table name in the external database
            max_value_of_increasing_col: Max value of the increasing column already loaded, or None
            full_refresh: When True (the default), select every row

        Returns:
            A query selecting all rows when ``full_refresh`` is set or no max value is known. Otherwise it
            selects only rows whose increasing column is greater than ``max_value_of_increasing_col``.

        Raises:
            ConfigurationError: If an incremental query is needed but the increasing column is not one of
                the configured ``columns``
        """
        select_cols = self.get_cols_for_insert_stmt()
        if full_refresh or max_value_of_increasing_col is None:
            return f"SELECT {select_cols} FROM db_{conn_name}.{table_name}"

        increasing_col = self.update_hints.increasing_column
        matching_types = [col.type for col in self.columns if col.name == increasing_col]
        if not matching_types:
            raise u.ConfigurationError(
                f"Increasing column '{increasing_col}' must be one of the configured columns for table '{table_name}'"
            )
        increasing_col_type = matching_types[0]

        # Double embedded single quotes so the value cannot terminate the SQL string literal early
        max_value_literal = f"{max_value_of_increasing_col}".replace("'", "''")
        where_cond = f"{increasing_col}::{increasing_col_type} > '{max_value_literal}'::{increasing_col_type}"

        if dialect in ['postgres', 'mysql']:
            # The query is transpiled from DuckDB syntax to the target dialect. Its quotes are then doubled
            # again, because the whole query is itself embedded in a string literal argument.
            pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"
            transpiled_query = sqlglot.transpile(pushdown_query, read='duckdb', write=dialect)[0].replace("'", "''")
            return f"FROM {dialect}_query('db_{conn_name}', '{transpiled_query}')"

        return f"SELECT {select_cols} FROM db_{conn_name}.{table_name} WHERE {where_cond}"

    def get_insert_replace_clause(self) -> str:
        """Return "OR REPLACE" when a primary key is configured, otherwise an empty string."""
        return "" if len(self.primary_key) == 0 else "OR REPLACE"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# All sources declared in the sources file, keyed by source name.
class Sources(BaseModel):
    sources: dict[str, Source] = Field(default_factory=dict)

    @model_validator(mode="before")
    @classmethod
    def convert_sources_list_to_dict(cls, data: dict[str, Any]) -> dict[str, Any]:
        """
        Accept ``sources`` as a list of mappings that each have a ``name`` key, and re-key it by name.

        Input that is not a dict, or whose ``sources`` is not a list, is returned unchanged for pydantic
        to validate. The input mapping and its entries are not mutated.

        Raises:
            ConfigurationError: If a list entry is not a mapping with a ``name`` key, or a name is repeated
        """
        if not isinstance(data, dict) or not isinstance(data.get("sources"), list):
            return data

        sources_dict: dict[str, Any] = {}
        for source in data["sources"]:
            if not (isinstance(source, dict) and "name" in source):
                raise u.ConfigurationError(f"All sources must have a name field in sources file")
            name = source["name"]
            if name in sources_dict:
                raise u.ConfigurationError(f"Duplicate source name found: {name}")
            # Copy without "name" instead of popping it, so the caller's parsed YAML stays intact
            sources_dict[name] = {key: value for key, value in source.items() if key != "name"}

        return {**data, "sources": sources_dict}

    @model_validator(mode="after")
    def validate_column_types(self):
        """Require every column of every source to declare a non-empty type."""
        for source_name, source in self.sources.items():
            for col in source.columns:
                if not col.type:
                    raise u.ConfigurationError(f"Column '{col.name}' in source '{source_name}' must have a type specified")
        return self

    def finalize_null_fields(self, env_vars: dict[str, str]):
        """Finalize each source's connection (using ``env_vars``) and default its table name; returns self."""
        for source_name, source in self.sources.items():
            source.finalize_connection(env_vars)
            source.finalize_table(source_name)
        return self
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SourcesIO:
    """Loads the project's sources file into a finalized ``Sources`` model."""

    @classmethod
    def load_file(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Sources:
        """
        Load ``<base_path>/<MODELS_FOLDER>/<SOURCES_FILE>`` and finalize its sources.

        A missing or empty sources file yields an empty ``Sources`` collection.

        Arguments:
            logger: Logger used to record how long loading took
            base_path: Project root directory
            env_vars: Environment variables passed to each source's ``finalize_connection``

        Returns:
            The validated and finalized ``Sources`` model
        """
        start = time.time()

        sources_path = u.Path(base_path, c.MODELS_FOLDER, c.SOURCES_FILE)
        # An empty YAML file parses to None; treat it the same as a missing file
        sources_data = (u.load_yaml_config(sources_path) if sources_path.exists() else None) or {}

        sources = Sources(**sources_data).finalize_null_fields(env_vars)

        logger.log_activity_time("loading sources", start)
        return sources
|
squirrels/_utils.py
CHANGED
|
@@ -1,35 +1,35 @@
|
|
|
1
|
-
from typing import Sequence, Optional, Union, TypeVar, Callable
|
|
2
|
-
from pathlib import Path
|
|
3
|
-
from pandas.api import types as pd_types
|
|
1
|
+
from typing import Sequence, Optional, Union, TypeVar, Callable, Any, Iterable
|
|
4
2
|
from datetime import datetime
|
|
5
|
-
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from functools import lru_cache
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
import os, time, logging, json, duckdb, polars as pl, yaml
|
|
6
7
|
import jinja2 as j2, jinja2.nodes as j2_nodes
|
|
8
|
+
import sqlglot, sqlglot.expressions, asyncio
|
|
7
9
|
|
|
8
10
|
from . import _constants as c
|
|
11
|
+
from ._exceptions import ConfigurationError
|
|
9
12
|
|
|
10
13
|
FilePath = Union[str, Path]
|
|
11
14
|
|
|
12
|
-
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
"""
|
|
17
|
-
|
|
18
|
-
"""
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
""
|
|
23
|
-
|
|
24
|
-
"""
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
new_message = f"\n" + message + f"\n{t}Produced error message:\n{t}{t}{error} (see above for more details on handled exception)"
|
|
31
|
-
super().__init__(new_message, *args)
|
|
32
|
-
self.error = error
|
|
15
|
+
# Polars: each polars dtype class mapped to the lowercase type names squirrels accepts for it
polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
    pl.String: ["string", "varchar", "char", "text"],
    pl.Int8: ["tinyint", "int1"],
    pl.Int16: ["smallint", "short", "int2"],
    pl.Int32: ["integer", "int", "int4"],
    pl.Int64: ["bigint", "long", "int8"],
    pl.Float32: ["float", "float4", "real"],
    pl.Float64: ["double", "float8"],
    pl.Boolean: ["boolean", "bool", "logical"],
    pl.Date: ["date"],
    pl.Time: ["time"],
    pl.Datetime: ["timestamp", "datetime"],
    pl.Duration: ["interval"],
    pl.Binary: ["blob", "binary", "varbinary"]
}

# Reverse lookup: squirrels type name -> polars dtype class
sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = dict(
    (sqrl_name, polars_dtype)
    for polars_dtype, sqrl_names in polars_dtypes_to_sqrl_dtypes.items()
    for sqrl_name in sqrl_names
)
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
## Other utility classes
|
|
@@ -40,7 +40,7 @@ class Logger(logging.Logger):
|
|
|
40
40
|
time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
|
|
41
41
|
data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
|
|
42
42
|
info = { "request_id": request_id } if request_id else {}
|
|
43
|
-
self.
|
|
43
|
+
self.info(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
|
|
44
44
|
|
|
45
45
|
|
|
46
46
|
class EnvironmentWithMacros(j2.Environment):
|
|
@@ -115,7 +115,6 @@ def read_file(filepath: FilePath) -> str:
|
|
|
115
115
|
|
|
116
116
|
Arguments:
|
|
117
117
|
filepath (str | pathlib.Path): The path to the file to read
|
|
118
|
-
is_required: If true, throw error if file doesn't exist
|
|
119
118
|
|
|
120
119
|
Returns:
|
|
121
120
|
Content of the file, or None if doesn't exist and not required
|
|
@@ -180,7 +179,7 @@ def load_json_or_comma_delimited_str_as_list(input_str: Union[str, Sequence]) ->
|
|
|
180
179
|
return [x.strip() for x in input_str.split(",")]
|
|
181
180
|
|
|
182
181
|
|
|
183
|
-
X
|
|
182
|
+
# Generic type variables for the input and output types in process_if_not_none's signature
X = TypeVar('X')
Y = TypeVar('Y')
|
|
184
183
|
def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) -> Optional[Y]:
|
|
185
184
|
"""
|
|
186
185
|
Given a input value and a function that processes the value, return the output of the function unless input is None
|
|
@@ -197,60 +196,164 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
|
|
|
197
196
|
return processor(input_val)
|
|
198
197
|
|
|
199
198
|
|
|
200
|
-
|
|
199
|
+
@lru_cache(maxsize=1)
def _read_duckdb_init_sql() -> tuple[str, Path | None]:
    """
    Reads and caches the combined DuckDB init SQL.

    Concatenates two files, in this order: the global init file at
    ``~/<GLOBAL_ENV_FOLDER>/<DUCKDB_INIT_FILE>``, and the project init file ``<DUCKDB_INIT_FILE>``
    resolved relative to the current working directory. Missing files are skipped. When the combined
    SQL is non-empty, it is also written to ``<TARGET_FOLDER>/<DUCKDB_INIT_FILE>``.

    The result is cached for the life of the process, so later edits to either file (or changes of
    working directory) are not picked up.

    Returns:
        A tuple ``(init_sql, target_init_path)``. When neither file exists or the combined content is
        blank, this is ``("", None)``.

    Raises:
        ConfigurationError: If reading either file or writing the target copy fails
    """
    try:
        init_contents: list[str] = []
        global_init_path = Path(os.path.expanduser('~'), c.GLOBAL_ENV_FOLDER, c.DUCKDB_INIT_FILE)
        if global_init_path.exists():
            init_contents.append(global_init_path.read_text(encoding="utf-8"))

        project_init_path = Path(c.DUCKDB_INIT_FILE)
        if project_init_path.exists():
            init_contents.append(project_init_path.read_text(encoding="utf-8"))

        init_sql = "\n".join(init_contents).strip()
        target_init_path = None
        if init_sql:
            target_init_path = Path(c.TARGET_FOLDER, c.DUCKDB_INIT_FILE)
            target_init_path.parent.mkdir(parents=True, exist_ok=True)
            target_init_path.write_text(init_sql, encoding="utf-8")

        return init_sql, target_init_path
    except Exception as e:
        raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e
|
|
226
|
+
|
|
227
|
+
def create_duckdb_connection(filepath: str | Path = ":memory:", *, read_only: bool = False) -> duckdb.DuckDBPyConnection:
    """
    Opens a DuckDB connection and runs the cached duckdb init SQL on it.

    Arguments:
        filepath: DuckDB database file to open. The default ":memory:" opens an in-memory database.
        read_only: Open the database in read-only mode when True. Defaults to False.

    Returns:
        An open DuckDB connection. The caller is responsible for closing it.

    Raises:
        ConfigurationError: If reading or executing the init SQL fails. The connection is closed first.
    """
    connection = duckdb.connect(filepath, read_only=read_only)
    try:
        sql_to_run = _read_duckdb_init_sql()[0]
        if sql_to_run:
            connection.execute(sql_to_run)
    except Exception as exc:
        # Don't leak the connection when initialisation fails
        connection.close()
        raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {str(exc)}") from exc
    return connection
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> pl.DataFrame:
    """
    Runs a SQL query against a collection of dataframes

    A fresh connection is opened via create_duckdb_connection, so the init SQL is applied. Each frame is
    registered under its dictionary key, the query is executed, and the connection is always closed.

    Arguments:
        sql_query: The SQL query to run; it may refer to the keys of ``dataframes`` as table names
        dataframes: A dictionary of table names to their polars LazyFrame

    Returns:
        The result as a polars Dataframe from running the query
    """
    conn = create_duckdb_connection()
    try:
        for table_name, frame in dataframes.items():
            conn.register(table_name, frame)
        # .pl() materializes the result before the finally block closes the connection
        return conn.sql(sql_query).pl()
    finally:
        conn.close()
|
|
227
273
|
|
|
228
274
|
|
|
229
|
-
def
|
|
275
|
+
def load_yaml_config(filepath: FilePath) -> dict:
    """
    Loads a YAML config file

    Arguments:
        filepath: The path to the YAML file (read as UTF-8)

    Returns:
        A dictionary representation of the YAML file. An empty or comment-only file yields an empty dict
        rather than None.

    Raises:
        ConfigurationError: If the file content is not valid YAML
    """
    try:
        with open(filepath, 'r', encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ConfigurationError(f"Failed to parse yaml file: {filepath}") from e
    # safe_load returns None for an empty document; honour the declared dict return type
    return {} if loaded is None else loaded
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def run_duckdb_stmt(
    logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None,
    redacted_values: Sequence[str] = ()
) -> duckdb.DuckDBPyConnection:
    """
    Runs a statement on a DuckDB connection, logging it first with sensitive values masked

    Arguments:
        logger: The logger to use
        duckdb_conn: The DuckDB connection
        stmt: The statement to run (executed unmodified)
        params: The parameters to bind, passed straight to ``execute``
        redacted_values: Substrings replaced by "[REDACTED]" in the logged statement only. Empty strings
            are ignored.

    Returns:
        The value returned by ``duckdb_conn.execute``

    Raises:
        duckdb.Error: Re-raised unchanged after the (redacted) failing statement is logged
    """
    redacted_stmt = stmt
    for value in redacted_values:
        # str.replace("", ...) would insert the marker between every character, so skip empty values
        if value:
            redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")

    logger.info(f"Running statement: {redacted_stmt}", extra={"data": {"params": params}})
    try:
        return duckdb_conn.execute(stmt, params)
    except duckdb.Error as e:
        logger.error(f"Failed to run statement: {redacted_stmt}", exc_info=e)
        raise
|
|
315
|
+
|
|
236
316
|
|
|
317
|
+
def get_current_time() -> str:
    """
    Returns the current local time as HH:MM:SS.mmm (milliseconds are truncated, not rounded)
    """
    now = datetime.now()
    return f"{now:%H:%M:%S}.{now.microsecond // 1000:03d}"
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def parse_dependent_tables(sql_query: str, all_table_names: Iterable[str]) -> tuple[set[str], sqlglot.Expression]:
    """
    Parses the dependent tables from a SQL query

    Arguments:
        sql_query: The SQL query to parse
        all_table_names: The names of all known tables; only references to these are reported. Any iterable
            is accepted, including a one-shot iterator, because it is consumed exactly once.

    Returns:
        A tuple of (the set of table names referenced by the query that appear in ``all_table_names``,
        the parsed sqlglot expression)
    """
    parsed = sqlglot.parse_one(sql_query)

    # Build the lookup set once. Rebuilding it per table reference was O(n*m), and it exhausted
    # one-shot iterators after the first reference, silently dropping later dependencies.
    known_tables = set(all_table_names)
    dependencies = {
        table.name
        for table in parsed.find_all(sqlglot.expressions.Table)
        if table.name in known_tables
    }

    return dependencies, parsed
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
async def asyncio_gather(coroutines: list):
    """
    Runs the given coroutines concurrently and returns their results in order.

    Each coroutine is wrapped in a task. If gathering fails for any reason, including cancellation of the
    caller, every unfinished task is cancelled and awaited before the original exception is re-raised.

    Arguments:
        coroutines: The coroutines to run

    Returns:
        A list with each coroutine's result, in the same order as ``coroutines``
    """
    pending = list(map(asyncio.create_task, coroutines))
    try:
        results = await asyncio.gather(*pending)
    except BaseException:
        unfinished = [task for task in pending if not task.done()]
        for task in unfinished:
            task.cancel()
        # Let the cancellations settle; their outcomes are collected, not raised
        await asyncio.gather(*pending, return_exceptions=True)
        raise
    return results
|
squirrels/arguments/init_time_args.py
CHANGED
|
@@ -1,40 +1,103 @@
|
|
|
1
|
-
from typing import
|
|
1
|
+
from typing import Any, Iterable, Callable
|
|
2
2
|
from dataclasses import dataclass
|
|
3
|
+
import polars as pl
|
|
3
4
|
|
|
5
|
+
from .. import _utils as u
|
|
4
6
|
|
|
5
7
|
@dataclass
class ConnectionsArgs:
    """
    Holds the project path together with the project variables and environment variables.

    The variable dicts are exposed through read-only properties that return shallow copies, so callers
    cannot add or remove keys in the stored mappings.
    """
    project_path: str
    _proj_vars: dict[str, Any]
    _env_vars: dict[str, str]

    @property
    def proj_vars(self) -> dict[str, Any]:
        """A shallow copy of the project variables."""
        snapshot = self._proj_vars.copy()
        return snapshot

    @property
    def env_vars(self) -> dict[str, str]:
        """A shallow copy of the environment variables."""
        snapshot = self._env_vars.copy()
        return snapshot
|
|
17
20
|
|
|
18
21
|
|
|
19
22
|
@dataclass
class ParametersArgs(ConnectionsArgs):
    """
    Same fields and properties as ConnectionsArgs, with nothing added.

    It is a separate type, presumably so the parameters step has its own argument class.
    """
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class _WithConnectionDictArgs(ConnectionsArgs):
    """ConnectionsArgs extended with a dictionary of pre-created connections or other shared objects."""
    _connections: dict[str, Any]

    @property
    def connections(self) -> dict[str, Any]:
        """
        A dictionary of connection keys to SQLAlchemy Engines for database connections.

        Can also be used to store other in-memory objects in advance such as ML models.

        A shallow copy is returned, so adding or removing keys does not affect the stored dictionary.
        """
        snapshot = self._connections.copy()
        return snapshot
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class BuildModelArgs(_WithConnectionDictArgs):
    """
    Arguments given to Python build models.

    Contains the connection arguments, the connections dict, the names of the models this model depends
    on, and callables for reading dependency results and running external SQL.
    """

    def __init__(
        self, conn_args: ConnectionsArgs, _connections: dict[str, Any],
        dependencies: Iterable[str],
        ref: Callable[[str], pl.LazyFrame],
        run_external_sql: Callable[[str, str], pl.DataFrame]
    ):
        super().__init__(conn_args.project_path, conn_args.proj_vars, conn_args.env_vars, _connections)
        # Materialize now. `dependencies` may be a one-shot iterator, and it is read more than once
        # (by the `dependencies` property and by run_sql_on_dataframes).
        self._dependencies = tuple(dependencies)
        self._ref = ref
        self._run_external_sql = run_external_sql

    @property
    def dependencies(self) -> set[str]:
        """
        The set of dependent data model names
        """
        return set(self._dependencies)

    def ref(self, model: str) -> pl.LazyFrame:
        """
        Returns the result (as a polars LazyFrame) of a dependent model (predefined in "dependencies" function)

        Note: This is different behaviour than the "ref" function for SQL models, which figures out the dependent models for you,
        and returns a string for the table/view name instead of a polars LazyFrame.

        Arguments:
            model: The model name

        Returns:
            A polars LazyFrame
        """
        return self._ref(model)

    def run_external_sql(self, connection_name: str, sql_query: str, **kwargs) -> pl.DataFrame:
        """
        Runs a SQL query against an external database, with option to specify the connection name. Placeholder values are provided automatically

        This delegates to the ``run_external_sql`` callable supplied at construction, passing
        ``(sql_query, connection_name)`` in that order. Extra keyword arguments are accepted but ignored.

        Arguments:
            connection_name: The connection name for the database. How a None value is resolved, if at all,
                is up to the injected callable (presumably the model's configured connection).
            sql_query: The SQL query. Can be parameterized with placeholders

        Returns:
            The query result as a polars DataFrame
        """
        return self._run_external_sql(sql_query, connection_name)

    def run_sql_on_dataframes(self, sql_query: str, *, dataframes: dict[str, pl.LazyFrame] | None = None, **kwargs) -> pl.DataFrame:
        """
        Uses a dictionary of dataframes to execute a SQL query in an embedded in-memory DuckDB database

        Arguments:
            sql_query: The SQL query to run
            dataframes: A dictionary of table names to their polars LazyFrame. If None, uses the results of
                all dependent models, keyed by model name. Extra keyword arguments are ignored.

        Returns:
            The result as a polars DataFrame from running the query
        """
        if dataframes is None:
            dataframes = {x: self.ref(x) for x in self._dependencies}

        return u.run_sql_on_dataframes(sql_query, dataframes)
|