squirrels 0.5.0b3__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- squirrels/__init__.py +4 -0
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +440 -792
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/{_init_time_args.py → init_time_args.py} +23 -43
- squirrels/_arguments/{_run_time_args.py → run_time_args.py} +32 -68
- squirrels/_auth.py +590 -264
- squirrels/_command_line.py +130 -58
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +16 -15
- squirrels/_constants.py +36 -11
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +40 -34
- squirrels/_dataset_types.py +16 -11
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +9 -37
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +7 -6
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +155 -77
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +11 -55
- squirrels/_model_configs.py +5 -5
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +276 -143
- squirrels/_package_data/base_project/.env +1 -24
- squirrels/_package_data/base_project/.env.example +31 -17
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +13 -7
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +6 -6
- squirrels/_package_data/base_project/docker/Dockerfile +2 -2
- squirrels/_package_data/base_project/docker/compose.yml +1 -1
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +2 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +27 -17
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +7 -7
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +8 -3
- squirrels/_package_data/base_project/pyconfigs/context.py +26 -14
- squirrels/_package_data/base_project/pyconfigs/parameters.py +124 -81
- squirrels/_package_data/base_project/pyconfigs/user.py +48 -15
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +21 -31
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +41 -30
- squirrels/_parameters.py +560 -123
- squirrels/_project.py +487 -277
- squirrels/_py_module.py +71 -10
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +52 -13
- squirrels/_sources.py +29 -23
- squirrels/_utils.py +221 -42
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -2
- squirrels/auth.py +4 -0
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +10 -3
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -1
- squirrels/_api_response_models.py +0 -190
- squirrels/_dashboard_types.py +0 -82
- squirrels/_dashboards_io.py +0 -79
- squirrels-0.5.0b3.dist-info/METADATA +0 -110
- squirrels-0.5.0b3.dist-info/RECORD +0 -80
- /squirrels/_package_data/base_project/{assets → resources}/expenses.db +0 -0
- /squirrels/_package_data/base_project/{assets → resources}/weather.db +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/licenses/LICENSE +0 -0
squirrels/_seeds.py
CHANGED
|
@@ -1,7 +1,15 @@
|
|
|
1
1
|
from dataclasses import dataclass
|
|
2
|
-
import os
|
|
2
|
+
import os
|
|
3
|
+
import re
|
|
4
|
+
import time
|
|
5
|
+
import glob
|
|
6
|
+
import json
|
|
3
7
|
|
|
8
|
+
import polars as pl
|
|
9
|
+
|
|
10
|
+
from ._exceptions import ConfigurationError
|
|
4
11
|
from . import _utils as u, _constants as c, _model_configs as mc
|
|
12
|
+
from ._env_vars import SquirrelsEnvVars
|
|
5
13
|
|
|
6
14
|
|
|
7
15
|
@dataclass
|
|
@@ -13,21 +21,47 @@ class Seed:
|
|
|
13
21
|
if self.config.cast_column_types:
|
|
14
22
|
exprs = []
|
|
15
23
|
for col_config in self.config.columns:
|
|
16
|
-
|
|
17
|
-
|
|
24
|
+
col_type = col_config.type.lower()
|
|
25
|
+
if col_type.startswith("decimal"):
|
|
26
|
+
polars_dtype = self._parse_decimal_type(col_type)
|
|
27
|
+
else:
|
|
28
|
+
try:
|
|
29
|
+
polars_dtype = u.sqrl_dtypes_to_polars_dtypes[col_type]
|
|
30
|
+
except KeyError as e:
|
|
31
|
+
raise ConfigurationError(f"Unknown column type: '{col_type}'") from e
|
|
32
|
+
|
|
18
33
|
exprs.append(pl.col(col_config.name).cast(polars_dtype))
|
|
19
34
|
|
|
20
35
|
self.df = self.df.with_columns(*exprs)
|
|
21
36
|
|
|
37
|
+
@staticmethod
|
|
38
|
+
def _parse_decimal_type(col_type: str) -> pl.Decimal:
|
|
39
|
+
"""Parse a decimal type string and return the appropriate polars Decimal type.
|
|
40
|
+
|
|
41
|
+
Supports formats: "decimal" or "decimal(precision, scale)"
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
# Match decimal(precision, scale) pattern
|
|
45
|
+
match = re.match(r"decimal\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", col_type)
|
|
46
|
+
if match:
|
|
47
|
+
precision = int(match.group(1))
|
|
48
|
+
scale = int(match.group(2))
|
|
49
|
+
return pl.Decimal(precision=precision, scale=scale)
|
|
50
|
+
|
|
51
|
+
if col_type == "decimal":
|
|
52
|
+
return pl.Decimal(precision=18, scale=2)
|
|
53
|
+
|
|
54
|
+
raise ConfigurationError(f"Unknown column type: '{col_type}'")
|
|
55
|
+
|
|
22
56
|
|
|
23
57
|
@dataclass
|
|
24
58
|
class Seeds:
|
|
25
59
|
_data: dict[str, Seed]
|
|
26
|
-
|
|
60
|
+
|
|
27
61
|
def run_query(self, sql_query: str) -> pl.DataFrame:
|
|
28
62
|
dataframes = {key: seed.df for key, seed in self._data.items()}
|
|
29
63
|
return u.run_sql_on_dataframes(sql_query, dataframes)
|
|
30
|
-
|
|
64
|
+
|
|
31
65
|
def get_dataframes(self) -> dict[str, Seed]:
|
|
32
66
|
return self._data.copy()
|
|
33
67
|
|
|
@@ -35,13 +69,14 @@ class Seeds:
|
|
|
35
69
|
class SeedsIO:
|
|
36
70
|
|
|
37
71
|
@classmethod
|
|
38
|
-
def load_files(cls, logger: u.Logger,
|
|
72
|
+
def load_files(cls, logger: u.Logger, env_vars: SquirrelsEnvVars) -> Seeds:
|
|
39
73
|
start = time.time()
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
74
|
+
project_path = env_vars.project_path
|
|
75
|
+
infer_schema_setting: bool = env_vars.seeds_infer_schema
|
|
76
|
+
na_values_setting: list[str] = env_vars.seeds_na_values
|
|
77
|
+
|
|
43
78
|
seeds_dict = {}
|
|
44
|
-
csv_files = glob.glob(os.path.join(
|
|
79
|
+
csv_files = glob.glob(os.path.join(project_path, c.SEEDS_FOLDER, '**/*.csv'), recursive=True)
|
|
45
80
|
for csv_file in csv_files:
|
|
46
81
|
config_file = os.path.splitext(csv_file)[0] + '.yml'
|
|
47
82
|
config_dict = u.load_yaml_config(config_file) if os.path.exists(config_file) else {}
|
|
@@ -49,10 +84,14 @@ class SeedsIO:
|
|
|
49
84
|
|
|
50
85
|
file_stem = os.path.splitext(os.path.basename(csv_file))[0]
|
|
51
86
|
infer_schema = not config.cast_column_types and infer_schema_setting
|
|
52
|
-
df = pl.read_csv(
|
|
53
|
-
|
|
87
|
+
df = pl.read_csv(
|
|
88
|
+
csv_file, try_parse_dates=True,
|
|
89
|
+
infer_schema=infer_schema,
|
|
90
|
+
null_values=na_values_setting
|
|
91
|
+
).lazy()
|
|
92
|
+
|
|
54
93
|
seeds_dict[file_stem] = Seed(config, df)
|
|
55
|
-
|
|
94
|
+
|
|
56
95
|
seeds = Seeds(seeds_dict)
|
|
57
96
|
logger.log_activity_time("loading seed files", start)
|
|
58
97
|
return seeds
|
squirrels/_sources.py
CHANGED
|
@@ -1,19 +1,20 @@
|
|
|
1
1
|
from typing import Any
|
|
2
2
|
from pydantic import BaseModel, Field, model_validator
|
|
3
|
-
import time,
|
|
3
|
+
import time, yaml
|
|
4
4
|
|
|
5
5
|
from . import _utils as u, _constants as c, _model_configs as mc
|
|
6
|
+
from ._env_vars import SquirrelsEnvVars
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class UpdateHints(BaseModel):
|
|
9
10
|
increasing_column: str | None = Field(default=None)
|
|
10
|
-
strictly_increasing: bool = Field(default=True, description="Delete the max value of the increasing column, ignored if
|
|
11
|
-
selective_overwrite_value: Any = Field(default=None)
|
|
11
|
+
strictly_increasing: bool = Field(default=True, description="Delete the max value of the increasing column, ignored if selective_overwrite_value is set")
|
|
12
|
+
selective_overwrite_value: Any = Field(default=None, description="Delete all values of the increasing column greater than or equal to this value")
|
|
12
13
|
|
|
13
14
|
|
|
14
15
|
class Source(mc.ConnectionInterface, mc.ModelConfig):
|
|
15
16
|
table: str | None = Field(default=None)
|
|
16
|
-
|
|
17
|
+
load_to_vdl: bool = Field(default=False, description="Whether to load the data to the 'virtual data lake' (VDL)")
|
|
17
18
|
primary_key: list[str] = Field(default_factory=list)
|
|
18
19
|
update_hints: UpdateHints = Field(default_factory=UpdateHints)
|
|
19
20
|
|
|
@@ -28,34 +29,28 @@ class Source(mc.ConnectionInterface, mc.ModelConfig):
|
|
|
28
29
|
|
|
29
30
|
def get_cols_for_create_table_stmt(self) -> str:
|
|
30
31
|
cols_clause = ", ".join([f"{col.name} {col.type}" for col in self.columns])
|
|
31
|
-
|
|
32
|
-
return f"{cols_clause}{primary_key_clause}"
|
|
33
|
-
|
|
34
|
-
def get_cols_for_insert_stmt(self) -> str:
|
|
35
|
-
return ", ".join([col.name for col in self.columns])
|
|
32
|
+
return cols_clause
|
|
36
33
|
|
|
37
34
|
def get_max_incr_col_query(self, source_name: str) -> str:
|
|
38
35
|
return f"SELECT max({self.update_hints.increasing_column}) FROM {source_name}"
|
|
39
36
|
|
|
40
|
-
def
|
|
41
|
-
select_cols = self.
|
|
37
|
+
def get_query_for_upsert(self, dialect: str, conn_name: str, table_name: str, max_value_of_increasing_col: Any | None, *, full_refresh: bool = True) -> str:
|
|
38
|
+
select_cols = ", ".join([col.name for col in self.columns])
|
|
42
39
|
if full_refresh or max_value_of_increasing_col is None:
|
|
43
40
|
return f"SELECT {select_cols} FROM db_{conn_name}.{table_name}"
|
|
44
41
|
|
|
45
42
|
increasing_col = self.update_hints.increasing_column
|
|
46
43
|
increasing_col_type = next(col.type for col in self.columns if col.name == increasing_col)
|
|
47
44
|
where_cond = f"{increasing_col}::{increasing_col_type} > '{max_value_of_increasing_col}'::{increasing_col_type}"
|
|
48
|
-
pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"
|
|
49
45
|
|
|
50
|
-
if
|
|
51
|
-
|
|
52
|
-
|
|
46
|
+
# TODO: figure out if using pushdown query is worth it
|
|
47
|
+
# if dialect in ['postgres', 'mysql']:
|
|
48
|
+
# pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"
|
|
49
|
+
# transpiled_query = sqlglot.transpile(pushdown_query, read='duckdb', write=dialect)[0].replace("'", "''")
|
|
50
|
+
# return f"FROM {dialect}_query('db_{conn_name}', '{transpiled_query}')"
|
|
53
51
|
|
|
54
52
|
return f"SELECT {select_cols} FROM db_{conn_name}.{table_name} WHERE {where_cond}"
|
|
55
53
|
|
|
56
|
-
def get_insert_replace_clause(self) -> str:
|
|
57
|
-
return "" if len(self.primary_key) == 0 else "OR REPLACE"
|
|
58
|
-
|
|
59
54
|
|
|
60
55
|
class Sources(BaseModel):
|
|
61
56
|
sources: dict[str, Source] = Field(default_factory=dict)
|
|
@@ -85,20 +80,31 @@ class Sources(BaseModel):
|
|
|
85
80
|
raise u.ConfigurationError(f"Column '{col.name}' in source '{source_name}' must have a type specified")
|
|
86
81
|
return self
|
|
87
82
|
|
|
88
|
-
def finalize_null_fields(self, env_vars:
|
|
83
|
+
def finalize_null_fields(self, env_vars: SquirrelsEnvVars):
|
|
84
|
+
default_conn_name = env_vars.connections_default_name_used
|
|
89
85
|
for source_name, source in self.sources.items():
|
|
90
|
-
source.finalize_connection(
|
|
86
|
+
source.finalize_connection(default_conn_name=default_conn_name)
|
|
91
87
|
source.finalize_table(source_name)
|
|
92
88
|
return self
|
|
93
89
|
|
|
94
90
|
|
|
95
91
|
class SourcesIO:
|
|
96
92
|
@classmethod
|
|
97
|
-
def load_file(cls, logger: u.Logger,
|
|
93
|
+
def load_file(cls, logger: u.Logger, env_vars: SquirrelsEnvVars, env_vars_unformatted: dict[str, str]) -> Sources:
|
|
98
94
|
start = time.time()
|
|
99
95
|
|
|
100
|
-
sources_path = u.Path(
|
|
101
|
-
|
|
96
|
+
sources_path = u.Path(env_vars.project_path, c.MODELS_FOLDER, c.SOURCES_FILE)
|
|
97
|
+
if sources_path.exists():
|
|
98
|
+
raw_content = u.read_file(sources_path)
|
|
99
|
+
rendered = u.render_string(raw_content, project_path=env_vars.project_path, env_vars=env_vars_unformatted)
|
|
100
|
+
sources_data = yaml.safe_load(rendered) or {}
|
|
101
|
+
else:
|
|
102
|
+
sources_data = {}
|
|
103
|
+
|
|
104
|
+
if not isinstance(sources_data, dict):
|
|
105
|
+
raise u.ConfigurationError(
|
|
106
|
+
f"Parsed content from YAML file must be a dictionary. Got: {sources_data}"
|
|
107
|
+
)
|
|
102
108
|
|
|
103
109
|
sources = Sources(**sources_data).finalize_null_fields(env_vars)
|
|
104
110
|
|
squirrels/_utils.py
CHANGED
|
@@ -1,18 +1,16 @@
|
|
|
1
|
-
from typing import Sequence, Optional, Union, TypeVar, Callable,
|
|
1
|
+
from typing import Sequence, Optional, Union, TypeVar, Callable, Iterable, Literal, Any
|
|
2
2
|
from datetime import datetime
|
|
3
3
|
from pathlib import Path
|
|
4
|
-
from functools import lru_cache
|
|
5
|
-
from pydantic import BaseModel
|
|
6
4
|
import os, time, logging, json, duckdb, polars as pl, yaml
|
|
7
5
|
import jinja2 as j2, jinja2.nodes as j2_nodes
|
|
8
|
-
import sqlglot, sqlglot.expressions, asyncio
|
|
6
|
+
import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
|
|
9
7
|
|
|
10
8
|
from . import _constants as c
|
|
11
9
|
from ._exceptions import ConfigurationError
|
|
12
10
|
|
|
13
11
|
FilePath = Union[str, Path]
|
|
14
12
|
|
|
15
|
-
# Polars
|
|
13
|
+
# Polars <-> Squirrels dtypes mappings (except Decimal)
|
|
16
14
|
polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
|
|
17
15
|
pl.String: ["string", "varchar", "char", "text"],
|
|
18
16
|
pl.Int8: ["tinyint", "int1"],
|
|
@@ -20,7 +18,7 @@ polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
|
|
|
20
18
|
pl.Int32: ["integer", "int", "int4"],
|
|
21
19
|
pl.Int64: ["bigint", "long", "int8"],
|
|
22
20
|
pl.Float32: ["float", "float4", "real"],
|
|
23
|
-
pl.Float64: ["double", "float8"
|
|
21
|
+
pl.Float64: ["double", "float8"],
|
|
24
22
|
pl.Boolean: ["boolean", "bool", "logical"],
|
|
25
23
|
pl.Date: ["date"],
|
|
26
24
|
pl.Time: ["time"],
|
|
@@ -29,18 +27,28 @@ polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
|
|
|
29
27
|
pl.Binary: ["blob", "binary", "varbinary"]
|
|
30
28
|
}
|
|
31
29
|
|
|
32
|
-
sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {
|
|
30
|
+
sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {
|
|
31
|
+
sqrl_type: k for k, v in polars_dtypes_to_sqrl_dtypes.items() for sqrl_type in v
|
|
32
|
+
}
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
## Other utility classes
|
|
36
36
|
|
|
37
37
|
class Logger(logging.Logger):
|
|
38
|
-
def
|
|
38
|
+
def info(self, msg: str, *, data: dict[str, Any] = {}, **kwargs) -> None:
|
|
39
|
+
super().info(msg, extra={"data": data}, **kwargs)
|
|
40
|
+
|
|
41
|
+
def log_activity_time(self, activity: str, start_timestamp: float, *, additional_data: dict[str, Any] = {}) -> None:
|
|
39
42
|
end_timestamp = time.time()
|
|
40
43
|
time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
|
|
41
|
-
data = {
|
|
42
|
-
|
|
43
|
-
|
|
44
|
+
data = {
|
|
45
|
+
"activity": activity,
|
|
46
|
+
"start_timestamp": start_timestamp,
|
|
47
|
+
"end_timestamp": end_timestamp,
|
|
48
|
+
"time_taken_ms": time_taken,
|
|
49
|
+
**additional_data
|
|
50
|
+
}
|
|
51
|
+
self.info(f'Time taken for "{activity}": {time_taken}ms', data=data)
|
|
44
52
|
|
|
45
53
|
|
|
46
54
|
class EnvironmentWithMacros(j2.Environment):
|
|
@@ -85,15 +93,7 @@ class EnvironmentWithMacros(j2.Environment):
|
|
|
85
93
|
|
|
86
94
|
## Utility functions/variables
|
|
87
95
|
|
|
88
|
-
def
|
|
89
|
-
end_timestamp = time.time()
|
|
90
|
-
time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
|
|
91
|
-
data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
|
|
92
|
-
info = { "request_id": request_id } if request_id else {}
|
|
93
|
-
logger.debug(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
def render_string(raw_str: str, *, base_path: str = ".", **kwargs) -> str:
|
|
96
|
+
def render_string(raw_str: str, *, project_path: str = ".", **kwargs) -> str:
|
|
97
97
|
"""
|
|
98
98
|
Given a template string, render it with the given keyword arguments
|
|
99
99
|
|
|
@@ -104,7 +104,7 @@ def render_string(raw_str: str, *, base_path: str = ".", **kwargs) -> str:
|
|
|
104
104
|
Returns:
|
|
105
105
|
The rendered string
|
|
106
106
|
"""
|
|
107
|
-
j2_env = j2.Environment(loader=j2.FileSystemLoader(
|
|
107
|
+
j2_env = j2.Environment(loader=j2.FileSystemLoader(project_path))
|
|
108
108
|
template = j2_env.from_string(raw_str)
|
|
109
109
|
return template.render(kwargs)
|
|
110
110
|
|
|
@@ -128,7 +128,7 @@ def read_file(filepath: FilePath) -> str:
|
|
|
128
128
|
|
|
129
129
|
def normalize_name(name: str) -> str:
|
|
130
130
|
"""
|
|
131
|
-
Normalizes names to the convention of the squirrels manifest file.
|
|
131
|
+
Normalizes names to the convention of the squirrels manifest file (with underscores instead of dashes).
|
|
132
132
|
|
|
133
133
|
Arguments:
|
|
134
134
|
name: The name to normalize.
|
|
@@ -141,7 +141,7 @@ def normalize_name(name: str) -> str:
|
|
|
141
141
|
|
|
142
142
|
def normalize_name_for_api(name: str) -> str:
|
|
143
143
|
"""
|
|
144
|
-
Normalizes names to the REST API convention.
|
|
144
|
+
Normalizes names to the REST API convention (with dashes instead of underscores).
|
|
145
145
|
|
|
146
146
|
Arguments:
|
|
147
147
|
name: The name to normalize.
|
|
@@ -196,8 +196,10 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
|
|
|
196
196
|
return processor(input_val)
|
|
197
197
|
|
|
198
198
|
|
|
199
|
-
|
|
200
|
-
|
|
199
|
+
def _read_duckdb_init_sql(
|
|
200
|
+
*,
|
|
201
|
+
datalake_db_path: str | None = None,
|
|
202
|
+
) -> str:
|
|
201
203
|
"""
|
|
202
204
|
Reads and caches the duckdb init file content.
|
|
203
205
|
Returns None if file doesn't exist or is empty.
|
|
@@ -212,35 +214,38 @@ def _read_duckdb_init_sql() -> tuple[str, Path | None]:
|
|
|
212
214
|
if Path(c.DUCKDB_INIT_FILE).exists():
|
|
213
215
|
with open(c.DUCKDB_INIT_FILE, 'r') as f:
|
|
214
216
|
init_contents.append(f.read())
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
return init_sql
|
|
217
|
+
|
|
218
|
+
if datalake_db_path:
|
|
219
|
+
attach_stmt = f"ATTACH '{datalake_db_path}' AS vdl (READ_ONLY);"
|
|
220
|
+
init_contents.append(attach_stmt)
|
|
221
|
+
use_stmt = f"USE vdl;"
|
|
222
|
+
init_contents.append(use_stmt)
|
|
223
|
+
|
|
224
|
+
init_sql = "\n\n".join(init_contents).strip()
|
|
225
|
+
return init_sql
|
|
224
226
|
except Exception as e:
|
|
225
227
|
raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e
|
|
226
228
|
|
|
227
|
-
def create_duckdb_connection(
|
|
229
|
+
def create_duckdb_connection(
|
|
230
|
+
db_path: str | Path = ":memory:",
|
|
231
|
+
*,
|
|
232
|
+
datalake_db_path: str | None = None
|
|
233
|
+
) -> duckdb.DuckDBPyConnection:
|
|
228
234
|
"""
|
|
229
235
|
Creates a DuckDB connection and initializes it with statements from duckdb init file
|
|
230
236
|
|
|
231
237
|
Arguments:
|
|
232
238
|
filepath: Path to the DuckDB database file. Defaults to in-memory database.
|
|
233
|
-
|
|
239
|
+
datalake_db_path: The path to the VDL catalog database if applicable. If exists, this is attached as 'vdl' (READ_ONLY). Default is None.
|
|
234
240
|
|
|
235
241
|
Returns:
|
|
236
242
|
A DuckDB connection (which must be closed after use)
|
|
237
243
|
"""
|
|
238
|
-
conn = duckdb.connect(
|
|
244
|
+
conn = duckdb.connect(db_path)
|
|
239
245
|
|
|
240
246
|
try:
|
|
241
|
-
init_sql
|
|
242
|
-
|
|
243
|
-
conn.execute(init_sql)
|
|
247
|
+
init_sql = _read_duckdb_init_sql(datalake_db_path=datalake_db_path)
|
|
248
|
+
conn.execute(init_sql)
|
|
244
249
|
except Exception as e:
|
|
245
250
|
conn.close()
|
|
246
251
|
raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {str(e)}") from e
|
|
@@ -272,6 +277,114 @@ def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -
|
|
|
272
277
|
return result_df
|
|
273
278
|
|
|
274
279
|
|
|
280
|
+
async def run_polars_sql_on_dataframes(
|
|
281
|
+
sql_query: str, dataframes: dict[str, pl.LazyFrame], *, timeout_seconds: float = 2.0, max_rows: int | None = None
|
|
282
|
+
) -> pl.DataFrame:
|
|
283
|
+
"""
|
|
284
|
+
Runs a SQL query against a collection of dataframes using Polars SQL (more secure than DuckDB for user input).
|
|
285
|
+
|
|
286
|
+
Arguments:
|
|
287
|
+
sql_query: The SQL query to run (Polars SQL dialect)
|
|
288
|
+
dataframes: A dictionary of table names to their polars LazyFrame
|
|
289
|
+
timeout_seconds: Maximum execution time in seconds (default 2.0)
|
|
290
|
+
max_rows: Maximum number of rows to collect. Collects at most max_rows + 1 rows
|
|
291
|
+
to allow overflow detection without loading unbounded results into memory.
|
|
292
|
+
|
|
293
|
+
Returns:
|
|
294
|
+
The result as a polars DataFrame from running the query (limited to max_rows + 1)
|
|
295
|
+
|
|
296
|
+
Raises:
|
|
297
|
+
ConfigurationError: If the query is invalid or insecure
|
|
298
|
+
"""
|
|
299
|
+
# Validate the SQL query
|
|
300
|
+
_validate_sql_query_security(sql_query, dataframes)
|
|
301
|
+
|
|
302
|
+
# Execute with timeout
|
|
303
|
+
try:
|
|
304
|
+
loop = asyncio.get_event_loop()
|
|
305
|
+
result = await asyncio.wait_for(
|
|
306
|
+
loop.run_in_executor(None, _run_polars_sql_sync, sql_query, dataframes, max_rows),
|
|
307
|
+
timeout=timeout_seconds
|
|
308
|
+
)
|
|
309
|
+
return result
|
|
310
|
+
except asyncio.TimeoutError as e:
|
|
311
|
+
raise ConfigurationError(f"SQL query execution exceeded timeout of {timeout_seconds} seconds") from e
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _run_polars_sql_sync(sql_query: str, dataframes: dict[str, pl.LazyFrame], max_rows: int | None) -> pl.DataFrame:
|
|
315
|
+
"""
|
|
316
|
+
Synchronous execution of Polars SQL.
|
|
317
|
+
|
|
318
|
+
Arguments:
|
|
319
|
+
sql_query: The SQL query to run
|
|
320
|
+
dataframes: A dictionary of table names to their polars LazyFrame
|
|
321
|
+
max_rows: Maximum number of rows to collect.
|
|
322
|
+
"""
|
|
323
|
+
ctx = pl.SQLContext(**dataframes)
|
|
324
|
+
result = ctx.execute(sql_query, eager=False)
|
|
325
|
+
if max_rows is not None:
|
|
326
|
+
result = result.limit(max_rows)
|
|
327
|
+
return result.collect()
|
|
328
|
+
|
|
329
|
+
|
|
330
|
+
def _validate_sql_query_security(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> None:
|
|
331
|
+
"""
|
|
332
|
+
Validates that a SQL query is safe to execute.
|
|
333
|
+
|
|
334
|
+
Enforces:
|
|
335
|
+
- Single statement only
|
|
336
|
+
- Read-only operations (SELECT/WITH/UNION)
|
|
337
|
+
- Table references limited to registered frames (excluding CTE names)
|
|
338
|
+
|
|
339
|
+
Arguments:
|
|
340
|
+
sql_query: The SQL query to validate
|
|
341
|
+
dataframes: Dictionary of allowed table names
|
|
342
|
+
|
|
343
|
+
Raises:
|
|
344
|
+
ConfigurationError: If validation fails
|
|
345
|
+
"""
|
|
346
|
+
try:
|
|
347
|
+
parsed = sqlglot.parse(sql_query)
|
|
348
|
+
except Exception as e:
|
|
349
|
+
raise ConfigurationError(f"Failed to parse SQL query: {str(e)}") from e
|
|
350
|
+
|
|
351
|
+
# Enforce single statement
|
|
352
|
+
if len(parsed) != 1:
|
|
353
|
+
raise ConfigurationError(f"Only single SQL statements are allowed. Found {len(parsed)} statements.")
|
|
354
|
+
|
|
355
|
+
statement = parsed[0]
|
|
356
|
+
|
|
357
|
+
# Enforce read-only: allow SELECT, WITH (CTE), UNION, INTERSECT, EXCEPT
|
|
358
|
+
allowed_types = (
|
|
359
|
+
sqlglot.expressions.Select,
|
|
360
|
+
sqlglot.expressions.Union,
|
|
361
|
+
sqlglot.expressions.Intersect,
|
|
362
|
+
sqlglot.expressions.Except,
|
|
363
|
+
)
|
|
364
|
+
|
|
365
|
+
if not isinstance(statement, allowed_types):
|
|
366
|
+
raise ConfigurationError(
|
|
367
|
+
f"Only read-only SQL statements (SELECT, WITH, UNION, INTERSECT, EXCEPT) are allowed. "
|
|
368
|
+
f"Found: {type(statement).__name__}"
|
|
369
|
+
)
|
|
370
|
+
|
|
371
|
+
# Collect CTE names (these are temporary tables created by WITH clauses)
|
|
372
|
+
cte_names: set[str] = set()
|
|
373
|
+
for cte in statement.find_all(sqlglot.expressions.CTE):
|
|
374
|
+
if cte.alias:
|
|
375
|
+
cte_names.add(cte.alias)
|
|
376
|
+
|
|
377
|
+
# Validate table references (excluding CTE names)
|
|
378
|
+
allowed_tables = set(dataframes.keys()) | cte_names
|
|
379
|
+
for table in statement.find_all(sqlglot.expressions.Table):
|
|
380
|
+
table_name = table.name
|
|
381
|
+
if table_name not in allowed_tables:
|
|
382
|
+
raise ConfigurationError(
|
|
383
|
+
f"Table reference '{table_name}' is not allowed. "
|
|
384
|
+
f"Only the following tables are available: {sorted(dataframes.keys())}"
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
|
|
275
388
|
def load_yaml_config(filepath: FilePath) -> dict:
|
|
276
389
|
"""
|
|
277
390
|
Loads a YAML config file
|
|
@@ -284,7 +397,13 @@ def load_yaml_config(filepath: FilePath) -> dict:
|
|
|
284
397
|
"""
|
|
285
398
|
try:
|
|
286
399
|
with open(filepath, 'r') as f:
|
|
287
|
-
|
|
400
|
+
content = yaml.safe_load(f)
|
|
401
|
+
content = content if content else {}
|
|
402
|
+
|
|
403
|
+
if not isinstance(content, dict):
|
|
404
|
+
raise yaml.YAMLError(f"Parsed content from YAML file must be a dictionary. Got: {content}")
|
|
405
|
+
|
|
406
|
+
return content
|
|
288
407
|
except yaml.YAMLError as e:
|
|
289
408
|
raise ConfigurationError(f"Failed to parse yaml file: {filepath}") from e
|
|
290
409
|
|
|
@@ -308,7 +427,7 @@ def run_duckdb_stmt(
|
|
|
308
427
|
redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")
|
|
309
428
|
|
|
310
429
|
for_model_name = f" for model '{model_name}'" if model_name is not None else ""
|
|
311
|
-
logger.
|
|
430
|
+
logger.debug(f"Running SQL statement{for_model_name}:\n{redacted_stmt}")
|
|
312
431
|
try:
|
|
313
432
|
return duckdb_conn.execute(stmt, params)
|
|
314
433
|
except duckdb.ParserException as e:
|
|
@@ -359,3 +478,63 @@ async def asyncio_gather(coroutines: list):
|
|
|
359
478
|
# Wait for tasks to be cancelled
|
|
360
479
|
await asyncio.gather(*tasks, return_exceptions=True)
|
|
361
480
|
raise
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def hash_string(input_str: str, salt: str) -> str:
|
|
484
|
+
"""
|
|
485
|
+
Hashes a string using SHA-256
|
|
486
|
+
"""
|
|
487
|
+
return hashlib.sha256((input_str + salt).encode()).hexdigest()
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
T = TypeVar('T')
|
|
491
|
+
def call_func(func: Callable[..., T], **kwargs) -> T:
|
|
492
|
+
"""
|
|
493
|
+
Calls a function with the given arguments if func expects arguments, otherwise calls func without arguments
|
|
494
|
+
"""
|
|
495
|
+
sig = inspect.signature(func)
|
|
496
|
+
# Filter kwargs to only include parameters that the function accepts
|
|
497
|
+
filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
|
|
498
|
+
return func(**filtered_kwargs)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
def generate_pkce_challenge(code_verifier: str) -> str:
|
|
502
|
+
"""Generate PKCE code challenge from code verifier"""
|
|
503
|
+
# Generate SHA256 hash of code_verifier
|
|
504
|
+
verifier_hash = hashlib.sha256(code_verifier.encode('utf-8')).digest()
|
|
505
|
+
# Base64 URL encode (without padding)
|
|
506
|
+
expected_challenge = base64.urlsafe_b64encode(verifier_hash).decode('utf-8').rstrip('=')
|
|
507
|
+
return expected_challenge
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def to_title_case(input_str: str) -> str:
|
|
511
|
+
"""Convert a string to title case"""
|
|
512
|
+
spaced_str = input_str.replace('_', ' ').replace('-', ' ')
|
|
513
|
+
return spaced_str.title()
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def to_bool(val: object) -> bool:
|
|
517
|
+
"""Convert common truthy/falsey representations to a boolean.
|
|
518
|
+
|
|
519
|
+
Accepted truthy values (case-insensitive): "1", "true", "t", "yes", "y", "on".
|
|
520
|
+
All other values are considered falsey. None is falsey.
|
|
521
|
+
"""
|
|
522
|
+
if isinstance(val, bool):
|
|
523
|
+
return val
|
|
524
|
+
if val is None:
|
|
525
|
+
return False
|
|
526
|
+
s = str(val).strip().lower()
|
|
527
|
+
return s in ("1", "true", "t", "yes", "y", "on")
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
ACCESS_LEVEL = Literal["admin", "member", "guest"]
|
|
531
|
+
|
|
532
|
+
def get_access_level_rank(access_level: ACCESS_LEVEL) -> int:
|
|
533
|
+
"""Get the rank of an access level. Lower ranks have more privileges."""
|
|
534
|
+
return { "admin": 1, "member": 2, "guest": 3 }.get(access_level.lower(), 1)
|
|
535
|
+
|
|
536
|
+
def user_has_elevated_privileges(user_access_level: ACCESS_LEVEL, required_access_level: ACCESS_LEVEL) -> bool:
|
|
537
|
+
"""Check if a user has privilege to access a resource"""
|
|
538
|
+
user_access_level_rank = get_access_level_rank(user_access_level)
|
|
539
|
+
required_access_level_rank = get_access_level_rank(required_access_level)
|
|
540
|
+
return user_access_level_rank <= required_access_level_rank
|
squirrels/_version.py
CHANGED
squirrels/arguments.py
CHANGED
|
@@ -1,2 +1,7 @@
|
|
|
1
|
-
from ._arguments.
|
|
2
|
-
from ._arguments.
|
|
1
|
+
from ._arguments.init_time_args import ConnectionsArgs, AuthProviderArgs, ParametersArgs, BuildModelArgs
|
|
2
|
+
from ._arguments.run_time_args import ContextArgs, ModelArgs, DashboardArgs
|
|
3
|
+
|
|
4
|
+
__all__ = [
|
|
5
|
+
"ConnectionsArgs", "AuthProviderArgs", "ParametersArgs", "BuildModelArgs",
|
|
6
|
+
"ContextArgs", "ModelArgs", "DashboardArgs"
|
|
7
|
+
]
|
squirrels/auth.py
ADDED
squirrels/connections.py
CHANGED
squirrels/dashboards.py
CHANGED
squirrels/data_sources.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from ._data_sources import (
|
|
2
|
+
SourceEnum,
|
|
2
3
|
SelectDataSource,
|
|
3
4
|
DateDataSource,
|
|
4
5
|
DateRangeDataSource,
|
|
@@ -6,3 +7,8 @@ from ._data_sources import (
|
|
|
6
7
|
NumberRangeDataSource,
|
|
7
8
|
TextDataSource
|
|
8
9
|
)
|
|
10
|
+
|
|
11
|
+
__all__ = [
|
|
12
|
+
"SourceEnum", "SelectDataSource", "DateDataSource", "DateRangeDataSource",
|
|
13
|
+
"NumberDataSource", "NumberRangeDataSource", "TextDataSource"
|
|
14
|
+
]
|
squirrels/parameter_options.py
CHANGED
|
@@ -6,3 +6,8 @@ from ._parameter_options import (
|
|
|
6
6
|
NumberRangeParameterOption,
|
|
7
7
|
TextParameterOption
|
|
8
8
|
)
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"SelectParameterOption", "DateParameterOption", "DateRangeParameterOption",
|
|
12
|
+
"NumberParameterOption", "NumberRangeParameterOption", "TextParameterOption"
|
|
13
|
+
]
|
squirrels/parameters.py
CHANGED