squirrels 0.5.0b3__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93) hide show
  1. squirrels/__init__.py +4 -0
  2. squirrels/_api_routes/__init__.py +5 -0
  3. squirrels/_api_routes/auth.py +337 -0
  4. squirrels/_api_routes/base.py +196 -0
  5. squirrels/_api_routes/dashboards.py +156 -0
  6. squirrels/_api_routes/data_management.py +148 -0
  7. squirrels/_api_routes/datasets.py +220 -0
  8. squirrels/_api_routes/project.py +289 -0
  9. squirrels/_api_server.py +440 -792
  10. squirrels/_arguments/__init__.py +0 -0
  11. squirrels/_arguments/{_init_time_args.py → init_time_args.py} +23 -43
  12. squirrels/_arguments/{_run_time_args.py → run_time_args.py} +32 -68
  13. squirrels/_auth.py +590 -264
  14. squirrels/_command_line.py +130 -58
  15. squirrels/_compile_prompts.py +147 -0
  16. squirrels/_connection_set.py +16 -15
  17. squirrels/_constants.py +36 -11
  18. squirrels/_dashboards.py +179 -0
  19. squirrels/_data_sources.py +40 -34
  20. squirrels/_dataset_types.py +16 -11
  21. squirrels/_env_vars.py +209 -0
  22. squirrels/_exceptions.py +9 -37
  23. squirrels/_http_error_responses.py +52 -0
  24. squirrels/_initializer.py +7 -6
  25. squirrels/_logging.py +121 -0
  26. squirrels/_manifest.py +155 -77
  27. squirrels/_mcp_server.py +578 -0
  28. squirrels/_model_builder.py +11 -55
  29. squirrels/_model_configs.py +5 -5
  30. squirrels/_model_queries.py +1 -1
  31. squirrels/_models.py +276 -143
  32. squirrels/_package_data/base_project/.env +1 -24
  33. squirrels/_package_data/base_project/.env.example +31 -17
  34. squirrels/_package_data/base_project/connections.yml +4 -3
  35. squirrels/_package_data/base_project/dashboards/dashboard_example.py +13 -7
  36. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +6 -6
  37. squirrels/_package_data/base_project/docker/Dockerfile +2 -2
  38. squirrels/_package_data/base_project/docker/compose.yml +1 -1
  39. squirrels/_package_data/base_project/duckdb_init.sql +1 -0
  40. squirrels/_package_data/base_project/models/builds/build_example.py +2 -2
  41. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
  42. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
  43. squirrels/_package_data/base_project/models/federates/federate_example.py +27 -17
  44. squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
  45. squirrels/_package_data/base_project/models/federates/federate_example.yml +7 -7
  46. squirrels/_package_data/base_project/models/sources.yml +5 -6
  47. squirrels/_package_data/base_project/parameters.yml +24 -38
  48. squirrels/_package_data/base_project/pyconfigs/connections.py +8 -3
  49. squirrels/_package_data/base_project/pyconfigs/context.py +26 -14
  50. squirrels/_package_data/base_project/pyconfigs/parameters.py +124 -81
  51. squirrels/_package_data/base_project/pyconfigs/user.py +48 -15
  52. squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
  53. squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
  54. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
  55. squirrels/_package_data/base_project/squirrels.yml.j2 +21 -31
  56. squirrels/_package_data/templates/login_successful.html +53 -0
  57. squirrels/_package_data/templates/squirrels_studio.html +22 -0
  58. squirrels/_parameter_configs.py +43 -22
  59. squirrels/_parameter_options.py +1 -1
  60. squirrels/_parameter_sets.py +41 -30
  61. squirrels/_parameters.py +560 -123
  62. squirrels/_project.py +487 -277
  63. squirrels/_py_module.py +71 -10
  64. squirrels/_request_context.py +33 -0
  65. squirrels/_schemas/__init__.py +0 -0
  66. squirrels/_schemas/auth_models.py +83 -0
  67. squirrels/_schemas/query_param_models.py +70 -0
  68. squirrels/_schemas/request_models.py +26 -0
  69. squirrels/_schemas/response_models.py +286 -0
  70. squirrels/_seeds.py +52 -13
  71. squirrels/_sources.py +29 -23
  72. squirrels/_utils.py +221 -42
  73. squirrels/_version.py +1 -3
  74. squirrels/arguments.py +7 -2
  75. squirrels/auth.py +4 -0
  76. squirrels/connections.py +2 -0
  77. squirrels/dashboards.py +3 -1
  78. squirrels/data_sources.py +6 -0
  79. squirrels/parameter_options.py +5 -0
  80. squirrels/parameters.py +5 -0
  81. squirrels/types.py +10 -3
  82. squirrels-0.6.0.post0.dist-info/METADATA +148 -0
  83. squirrels-0.6.0.post0.dist-info/RECORD +101 -0
  84. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -1
  85. squirrels/_api_response_models.py +0 -190
  86. squirrels/_dashboard_types.py +0 -82
  87. squirrels/_dashboards_io.py +0 -79
  88. squirrels-0.5.0b3.dist-info/METADATA +0 -110
  89. squirrels-0.5.0b3.dist-info/RECORD +0 -80
  90. /squirrels/_package_data/base_project/{assets → resources}/expenses.db +0 -0
  91. /squirrels/_package_data/base_project/{assets → resources}/weather.db +0 -0
  92. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +0 -0
  93. {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/licenses/LICENSE +0 -0
squirrels/_seeds.py CHANGED
@@ -1,7 +1,15 @@
1
1
  from dataclasses import dataclass
2
- import os, time, glob, polars as pl, json
2
+ import os
3
+ import re
4
+ import time
5
+ import glob
6
+ import json
3
7
 
8
+ import polars as pl
9
+
10
+ from ._exceptions import ConfigurationError
4
11
  from . import _utils as u, _constants as c, _model_configs as mc
12
+ from ._env_vars import SquirrelsEnvVars
5
13
 
6
14
 
7
15
  @dataclass
@@ -13,21 +21,47 @@ class Seed:
13
21
  if self.config.cast_column_types:
14
22
  exprs = []
15
23
  for col_config in self.config.columns:
16
- sqrl_dtype = "double" if col_config.type.lower().startswith("decimal") else col_config.type
17
- polars_dtype = u.sqrl_dtypes_to_polars_dtypes.get(sqrl_dtype, pl.String)
24
+ col_type = col_config.type.lower()
25
+ if col_type.startswith("decimal"):
26
+ polars_dtype = self._parse_decimal_type(col_type)
27
+ else:
28
+ try:
29
+ polars_dtype = u.sqrl_dtypes_to_polars_dtypes[col_type]
30
+ except KeyError as e:
31
+ raise ConfigurationError(f"Unknown column type: '{col_type}'") from e
32
+
18
33
  exprs.append(pl.col(col_config.name).cast(polars_dtype))
19
34
 
20
35
  self.df = self.df.with_columns(*exprs)
21
36
 
37
@staticmethod
def _parse_decimal_type(col_type: str) -> pl.Decimal:
    """Parse a decimal type string and return the matching polars Decimal type.

    Supported formats are "decimal" (mapped to precision 18, scale 2) and
    "decimal(precision, scale)". The whole string must match one of these
    formats; trailing characters after the closing parenthesis are rejected.

    Arguments:
        col_type: The column type string (the visible caller lowercases it first).

    Returns:
        A polars Decimal data type with the parsed (or default) precision and scale.

    Raises:
        ConfigurationError: If the string is neither "decimal" nor "decimal(precision, scale)".
    """
    normalized = col_type.strip()

    if normalized == "decimal":
        return pl.Decimal(precision=18, scale=2)

    # fullmatch (not match) so that input such as "decimal(10, 2)xyz" is not silently accepted
    match = re.fullmatch(r"decimal\s*\(\s*(\d+)\s*,\s*(\d+)\s*\)", normalized)
    if match is None:
        raise ConfigurationError(f"Unknown column type: '{col_type}'")

    precision, scale = (int(group) for group in match.groups())
    return pl.Decimal(precision=precision, scale=scale)
55
+
22
56
 
23
57
  @dataclass
24
58
  class Seeds:
25
59
  _data: dict[str, Seed]
26
-
60
+
27
61
  def run_query(self, sql_query: str) -> pl.DataFrame:
28
62
  dataframes = {key: seed.df for key, seed in self._data.items()}
29
63
  return u.run_sql_on_dataframes(sql_query, dataframes)
30
-
64
+
31
65
  def get_dataframes(self) -> dict[str, Seed]:
32
66
  return self._data.copy()
33
67
 
@@ -35,13 +69,14 @@ class Seeds:
35
69
  class SeedsIO:
36
70
 
37
71
  @classmethod
38
- def load_files(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Seeds:
72
+ def load_files(cls, logger: u.Logger, env_vars: SquirrelsEnvVars) -> Seeds:
39
73
  start = time.time()
40
- infer_schema_setting: bool = (env_vars.get(c.SQRL_SEEDS_INFER_SCHEMA, "true").lower() == "true")
41
- na_values_setting: list[str] = json.loads(env_vars.get(c.SQRL_SEEDS_NA_VALUES, "[]"))
42
-
74
+ project_path = env_vars.project_path
75
+ infer_schema_setting: bool = env_vars.seeds_infer_schema
76
+ na_values_setting: list[str] = env_vars.seeds_na_values
77
+
43
78
  seeds_dict = {}
44
- csv_files = glob.glob(os.path.join(base_path, c.SEEDS_FOLDER, '**/*.csv'), recursive=True)
79
+ csv_files = glob.glob(os.path.join(project_path, c.SEEDS_FOLDER, '**/*.csv'), recursive=True)
45
80
  for csv_file in csv_files:
46
81
  config_file = os.path.splitext(csv_file)[0] + '.yml'
47
82
  config_dict = u.load_yaml_config(config_file) if os.path.exists(config_file) else {}
@@ -49,10 +84,14 @@ class SeedsIO:
49
84
 
50
85
  file_stem = os.path.splitext(os.path.basename(csv_file))[0]
51
86
  infer_schema = not config.cast_column_types and infer_schema_setting
52
- df = pl.read_csv(csv_file, try_parse_dates=True, infer_schema=infer_schema, null_values=na_values_setting).lazy()
53
-
87
+ df = pl.read_csv(
88
+ csv_file, try_parse_dates=True,
89
+ infer_schema=infer_schema,
90
+ null_values=na_values_setting
91
+ ).lazy()
92
+
54
93
  seeds_dict[file_stem] = Seed(config, df)
55
-
94
+
56
95
  seeds = Seeds(seeds_dict)
57
96
  logger.log_activity_time("loading seed files", start)
58
97
  return seeds
squirrels/_sources.py CHANGED
@@ -1,19 +1,20 @@
1
1
  from typing import Any
2
2
  from pydantic import BaseModel, Field, model_validator
3
- import time, sqlglot
3
+ import time, yaml
4
4
 
5
5
  from . import _utils as u, _constants as c, _model_configs as mc
6
+ from ._env_vars import SquirrelsEnvVars
6
7
 
7
8
 
8
9
  class UpdateHints(BaseModel):
9
10
  increasing_column: str | None = Field(default=None)
10
- strictly_increasing: bool = Field(default=True, description="Delete the max value of the increasing column, ignored if value is set")
11
- selective_overwrite_value: Any = Field(default=None)
11
+ strictly_increasing: bool = Field(default=True, description="Delete the max value of the increasing column, ignored if selective_overwrite_value is set")
12
+ selective_overwrite_value: Any = Field(default=None, description="Delete all values of the increasing column greater than or equal to this value")
12
13
 
13
14
 
14
15
  class Source(mc.ConnectionInterface, mc.ModelConfig):
15
16
  table: str | None = Field(default=None)
16
- load_to_duckdb: bool = Field(default=False, description="Whether to load the data to DuckDB")
17
+ load_to_vdl: bool = Field(default=False, description="Whether to load the data to the 'virtual data lake' (VDL)")
17
18
  primary_key: list[str] = Field(default_factory=list)
18
19
  update_hints: UpdateHints = Field(default_factory=UpdateHints)
19
20
 
@@ -28,34 +29,28 @@ class Source(mc.ConnectionInterface, mc.ModelConfig):
28
29
 
29
30
  def get_cols_for_create_table_stmt(self) -> str:
30
31
  cols_clause = ", ".join([f"{col.name} {col.type}" for col in self.columns])
31
- primary_key_clause = f", PRIMARY KEY ({', '.join(self.primary_key)})" if self.primary_key else ""
32
- return f"{cols_clause}{primary_key_clause}"
33
-
34
- def get_cols_for_insert_stmt(self) -> str:
35
- return ", ".join([col.name for col in self.columns])
32
+ return cols_clause
36
33
 
37
34
  def get_max_incr_col_query(self, source_name: str) -> str:
38
35
  return f"SELECT max({self.update_hints.increasing_column}) FROM {source_name}"
39
36
 
40
- def get_query_for_insert(self, dialect: str, conn_name: str, table_name: str, max_value_of_increasing_col: Any | None, *, full_refresh: bool = True) -> str:
41
- select_cols = self.get_cols_for_insert_stmt()
37
+ def get_query_for_upsert(self, dialect: str, conn_name: str, table_name: str, max_value_of_increasing_col: Any | None, *, full_refresh: bool = True) -> str:
38
+ select_cols = ", ".join([col.name for col in self.columns])
42
39
  if full_refresh or max_value_of_increasing_col is None:
43
40
  return f"SELECT {select_cols} FROM db_{conn_name}.{table_name}"
44
41
 
45
42
  increasing_col = self.update_hints.increasing_column
46
43
  increasing_col_type = next(col.type for col in self.columns if col.name == increasing_col)
47
44
  where_cond = f"{increasing_col}::{increasing_col_type} > '{max_value_of_increasing_col}'::{increasing_col_type}"
48
- pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"
49
45
 
50
- if dialect in ['postgres', 'mysql']:
51
- transpiled_query = sqlglot.transpile(pushdown_query, read='duckdb', write=dialect)[0].replace("'", "''")
52
- return f"FROM {dialect}_query('db_{conn_name}', '{transpiled_query}')"
46
+ # TODO: figure out if using pushdown query is worth it
47
+ # if dialect in ['postgres', 'mysql']:
48
+ # pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"
49
+ # transpiled_query = sqlglot.transpile(pushdown_query, read='duckdb', write=dialect)[0].replace("'", "''")
50
+ # return f"FROM {dialect}_query('db_{conn_name}', '{transpiled_query}')"
53
51
 
54
52
  return f"SELECT {select_cols} FROM db_{conn_name}.{table_name} WHERE {where_cond}"
55
53
 
56
- def get_insert_replace_clause(self) -> str:
57
- return "" if len(self.primary_key) == 0 else "OR REPLACE"
58
-
59
54
 
60
55
  class Sources(BaseModel):
61
56
  sources: dict[str, Source] = Field(default_factory=dict)
@@ -85,20 +80,31 @@ class Sources(BaseModel):
85
80
  raise u.ConfigurationError(f"Column '{col.name}' in source '{source_name}' must have a type specified")
86
81
  return self
87
82
 
88
- def finalize_null_fields(self, env_vars: dict[str, str]):
83
+ def finalize_null_fields(self, env_vars: SquirrelsEnvVars):
84
+ default_conn_name = env_vars.connections_default_name_used
89
85
  for source_name, source in self.sources.items():
90
- source.finalize_connection(env_vars)
86
+ source.finalize_connection(default_conn_name=default_conn_name)
91
87
  source.finalize_table(source_name)
92
88
  return self
93
89
 
94
90
 
95
91
  class SourcesIO:
96
92
  @classmethod
97
- def load_file(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Sources:
93
+ def load_file(cls, logger: u.Logger, env_vars: SquirrelsEnvVars, env_vars_unformatted: dict[str, str]) -> Sources:
98
94
  start = time.time()
99
95
 
100
- sources_path = u.Path(base_path, c.MODELS_FOLDER, c.SOURCES_FILE)
101
- sources_data = u.load_yaml_config(sources_path) if sources_path.exists() else {}
96
+ sources_path = u.Path(env_vars.project_path, c.MODELS_FOLDER, c.SOURCES_FILE)
97
+ if sources_path.exists():
98
+ raw_content = u.read_file(sources_path)
99
+ rendered = u.render_string(raw_content, project_path=env_vars.project_path, env_vars=env_vars_unformatted)
100
+ sources_data = yaml.safe_load(rendered) or {}
101
+ else:
102
+ sources_data = {}
103
+
104
+ if not isinstance(sources_data, dict):
105
+ raise u.ConfigurationError(
106
+ f"Parsed content from YAML file must be a dictionary. Got: {sources_data}"
107
+ )
102
108
 
103
109
  sources = Sources(**sources_data).finalize_null_fields(env_vars)
104
110
 
squirrels/_utils.py CHANGED
@@ -1,18 +1,16 @@
1
- from typing import Sequence, Optional, Union, TypeVar, Callable, Any, Iterable
1
+ from typing import Sequence, Optional, Union, TypeVar, Callable, Iterable, Literal, Any
2
2
  from datetime import datetime
3
3
  from pathlib import Path
4
- from functools import lru_cache
5
- from pydantic import BaseModel
6
4
  import os, time, logging, json, duckdb, polars as pl, yaml
7
5
  import jinja2 as j2, jinja2.nodes as j2_nodes
8
- import sqlglot, sqlglot.expressions, asyncio
6
+ import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
9
7
 
10
8
  from . import _constants as c
11
9
  from ._exceptions import ConfigurationError
12
10
 
13
11
  FilePath = Union[str, Path]
14
12
 
15
- # Polars
13
+ # Polars <-> Squirrels dtypes mappings (except Decimal)
16
14
  polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
17
15
  pl.String: ["string", "varchar", "char", "text"],
18
16
  pl.Int8: ["tinyint", "int1"],
@@ -20,7 +18,7 @@ polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
20
18
  pl.Int32: ["integer", "int", "int4"],
21
19
  pl.Int64: ["bigint", "long", "int8"],
22
20
  pl.Float32: ["float", "float4", "real"],
23
- pl.Float64: ["double", "float8", "decimal"], # Note: Polars Decimal type is considered unstable, so we use Float64 for "decimal"
21
+ pl.Float64: ["double", "float8"],
24
22
  pl.Boolean: ["boolean", "bool", "logical"],
25
23
  pl.Date: ["date"],
26
24
  pl.Time: ["time"],
@@ -29,18 +27,28 @@ polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
29
27
  pl.Binary: ["blob", "binary", "varbinary"]
30
28
  }
31
29
 
32
- sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {sqrl_type: k for k, v in polars_dtypes_to_sqrl_dtypes.items() for sqrl_type in v}
30
+ sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {
31
+ sqrl_type: k for k, v in polars_dtypes_to_sqrl_dtypes.items() for sqrl_type in v
32
+ }
33
33
 
34
34
 
35
35
  ## Other utility classes
36
36
 
37
37
  class Logger(logging.Logger):
38
- def log_activity_time(self, activity: str, start_timestamp: float, *, request_id: str | None = None) -> None:
38
+ def info(self, msg: str, *, data: dict[str, Any] = {}, **kwargs) -> None:
39
+ super().info(msg, extra={"data": data}, **kwargs)
40
+
41
+ def log_activity_time(self, activity: str, start_timestamp: float, *, additional_data: dict[str, Any] = {}) -> None:
39
42
  end_timestamp = time.time()
40
43
  time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
41
- data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
42
- info = { "request_id": request_id } if request_id else {}
43
- self.info(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
44
+ data = {
45
+ "activity": activity,
46
+ "start_timestamp": start_timestamp,
47
+ "end_timestamp": end_timestamp,
48
+ "time_taken_ms": time_taken,
49
+ **additional_data
50
+ }
51
+ self.info(f'Time taken for "{activity}": {time_taken}ms', data=data)
44
52
 
45
53
 
46
54
  class EnvironmentWithMacros(j2.Environment):
@@ -85,15 +93,7 @@ class EnvironmentWithMacros(j2.Environment):
85
93
 
86
94
  ## Utility functions/variables
87
95
 
88
- def log_activity_time(logger: logging.Logger, activity: str, start_timestamp: float, *, request_id: str | None = None) -> None:
89
- end_timestamp = time.time()
90
- time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
91
- data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
92
- info = { "request_id": request_id } if request_id else {}
93
- logger.debug(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
94
-
95
-
96
- def render_string(raw_str: str, *, base_path: str = ".", **kwargs) -> str:
96
+ def render_string(raw_str: str, *, project_path: str = ".", **kwargs) -> str:
97
97
  """
98
98
  Given a template string, render it with the given keyword arguments
99
99
 
@@ -104,7 +104,7 @@ def render_string(raw_str: str, *, base_path: str = ".", **kwargs) -> str:
104
104
  Returns:
105
105
  The rendered string
106
106
  """
107
- j2_env = j2.Environment(loader=j2.FileSystemLoader(base_path))
107
+ j2_env = j2.Environment(loader=j2.FileSystemLoader(project_path))
108
108
  template = j2_env.from_string(raw_str)
109
109
  return template.render(kwargs)
110
110
 
@@ -128,7 +128,7 @@ def read_file(filepath: FilePath) -> str:
128
128
 
129
129
  def normalize_name(name: str) -> str:
130
130
  """
131
- Normalizes names to the convention of the squirrels manifest file.
131
+ Normalizes names to the convention of the squirrels manifest file (with underscores instead of dashes).
132
132
 
133
133
  Arguments:
134
134
  name: The name to normalize.
@@ -141,7 +141,7 @@ def normalize_name(name: str) -> str:
141
141
 
142
142
  def normalize_name_for_api(name: str) -> str:
143
143
  """
144
- Normalizes names to the REST API convention.
144
+ Normalizes names to the REST API convention (with dashes instead of underscores).
145
145
 
146
146
  Arguments:
147
147
  name: The name to normalize.
@@ -196,8 +196,10 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
196
196
  return processor(input_val)
197
197
 
198
198
 
199
- @lru_cache(maxsize=1)
200
- def _read_duckdb_init_sql() -> tuple[str, Path | None]:
199
+ def _read_duckdb_init_sql(
200
+ *,
201
+ datalake_db_path: str | None = None,
202
+ ) -> str:
201
203
  """
202
204
  Reads and caches the duckdb init file content.
203
205
  Returns None if file doesn't exist or is empty.
@@ -212,35 +214,38 @@ def _read_duckdb_init_sql() -> tuple[str, Path | None]:
212
214
  if Path(c.DUCKDB_INIT_FILE).exists():
213
215
  with open(c.DUCKDB_INIT_FILE, 'r') as f:
214
216
  init_contents.append(f.read())
215
-
216
- init_sql = "\n".join(init_contents).strip()
217
- target_init_path = None
218
- if init_sql:
219
- target_init_path = Path(c.TARGET_FOLDER, c.DUCKDB_INIT_FILE)
220
- target_init_path.parent.mkdir(parents=True, exist_ok=True)
221
- target_init_path.write_text(init_sql)
222
-
223
- return init_sql, target_init_path
217
+
218
+ if datalake_db_path:
219
+ attach_stmt = f"ATTACH '{datalake_db_path}' AS vdl (READ_ONLY);"
220
+ init_contents.append(attach_stmt)
221
+ use_stmt = f"USE vdl;"
222
+ init_contents.append(use_stmt)
223
+
224
+ init_sql = "\n\n".join(init_contents).strip()
225
+ return init_sql
224
226
  except Exception as e:
225
227
  raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e
226
228
 
227
- def create_duckdb_connection(filepath: str | Path = ":memory:", *, read_only: bool = False) -> duckdb.DuckDBPyConnection:
229
+ def create_duckdb_connection(
230
+ db_path: str | Path = ":memory:",
231
+ *,
232
+ datalake_db_path: str | None = None
233
+ ) -> duckdb.DuckDBPyConnection:
228
234
  """
229
235
  Creates a DuckDB connection and initializes it with statements from duckdb init file
230
236
 
231
237
  Arguments:
232
238
  filepath: Path to the DuckDB database file. Defaults to in-memory database.
233
- read_only: Whether to open the database in read-only mode. Defaults to False.
239
+ datalake_db_path: The path to the VDL catalog database if applicable. If exists, this is attached as 'vdl' (READ_ONLY). Default is None.
234
240
 
235
241
  Returns:
236
242
  A DuckDB connection (which must be closed after use)
237
243
  """
238
- conn = duckdb.connect(filepath, read_only=read_only)
244
+ conn = duckdb.connect(db_path)
239
245
 
240
246
  try:
241
- init_sql, _ = _read_duckdb_init_sql()
242
- if init_sql:
243
- conn.execute(init_sql)
247
+ init_sql = _read_duckdb_init_sql(datalake_db_path=datalake_db_path)
248
+ conn.execute(init_sql)
244
249
  except Exception as e:
245
250
  conn.close()
246
251
  raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {str(e)}") from e
@@ -272,6 +277,114 @@ def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -
272
277
  return result_df
273
278
 
274
279
 
280
async def run_polars_sql_on_dataframes(
    sql_query: str, dataframes: dict[str, pl.LazyFrame], *, timeout_seconds: float = 2.0, max_rows: int | None = None
) -> pl.DataFrame:
    """
    Runs a SQL query against a collection of dataframes using Polars SQL (more secure than DuckDB for user input).

    Arguments:
        sql_query: The SQL query to run (Polars SQL dialect)
        dataframes: A dictionary of table names to their polars LazyFrame
        timeout_seconds: Maximum time in seconds to wait for the result (default 2.0)
        max_rows: Row limit forwarded to _run_polars_sql_sync, which applies it as a
            limit on the lazy result before collecting. None means no limit.
            NOTE(review): earlier documentation promised "max_rows + 1" rows for overflow
            detection, but the helper limits to exactly max_rows - callers that need
            overflow detection must pass their limit + 1. Confirm intent.

    Returns:
        The result as a polars DataFrame from running the query

    Raises:
        ConfigurationError: If the query is invalid or insecure, or if waiting for the
            result exceeds timeout_seconds
    """
    # Validate the SQL query before any work is scheduled
    _validate_sql_query_security(sql_query, dataframes)

    # Inside a coroutine there is always a running loop; get_running_loop() is the
    # recommended accessor and never creates a loop implicitly
    loop = asyncio.get_running_loop()
    query_future = loop.run_in_executor(None, _run_polars_sql_sync, sql_query, dataframes, max_rows)

    # NOTE: on timeout, wait_for stops awaiting the future, but the worker thread
    # already running the query is not interrupted and runs to completion
    try:
        return await asyncio.wait_for(query_future, timeout=timeout_seconds)
    except asyncio.TimeoutError as e:
        raise ConfigurationError(f"SQL query execution exceeded timeout of {timeout_seconds} seconds") from e
312
+
313
+
314
def _run_polars_sql_sync(sql_query: str, dataframes: dict[str, pl.LazyFrame], max_rows: int | None) -> pl.DataFrame:
    """
    Executes a Polars SQL query synchronously and collects the result.

    Arguments:
        sql_query: The SQL query to run
        dataframes: A dictionary of table names to their polars LazyFrame
        max_rows: If not None, the lazy result is limited to this many rows before collecting.

    Returns:
        The collected polars DataFrame
    """
    # Each dictionary entry is registered as a table in the SQL context
    lazy_result = pl.SQLContext(**dataframes).execute(sql_query, eager=False)
    if max_rows is None:
        return lazy_result.collect()
    return lazy_result.limit(max_rows).collect()
328
+
329
+
330
def _validate_sql_query_security(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> None:
    """
    Checks that a SQL query is acceptable to run against the registered dataframes.

    The query is rejected unless all of the following hold:
    - It parses (with sqlglot) into exactly one statement
    - The statement is a SELECT, UNION, INTERSECT, or EXCEPT expression
    - Every table reference names either a key of "dataframes" or a CTE alias
      defined within the statement

    Arguments:
        sql_query: The SQL query to validate
        dataframes: Dictionary whose keys are the allowed table names

    Raises:
        ConfigurationError: If parsing fails or any of the checks above fails
    """
    exp = sqlglot.expressions

    try:
        statements = sqlglot.parse(sql_query)
    except Exception as e:
        raise ConfigurationError(f"Failed to parse SQL query: {str(e)}") from e

    # Exactly one statement is permitted
    statement_count = len(statements)
    if statement_count != 1:
        raise ConfigurationError(f"Only single SQL statements are allowed. Found {statement_count} statements.")

    (statement,) = statements

    # Only query-type root expressions are permitted
    read_only_types = (exp.Select, exp.Union, exp.Intersect, exp.Except)
    if not isinstance(statement, read_only_types):
        raise ConfigurationError(
            f"Only read-only SQL statements (SELECT, WITH, UNION, INTERSECT, EXCEPT) are allowed. "
            f"Found: {type(statement).__name__}"
        )

    # CTE aliases are names the query defines for itself, so they are allowed as table references
    cte_names: set[str] = {cte.alias for cte in statement.find_all(exp.CTE) if cte.alias}
    allowed_tables = cte_names.union(dataframes)

    # Only the unqualified table name is compared against the allowed names
    for table in statement.find_all(exp.Table):
        table_name = table.name
        if table_name in allowed_tables:
            continue
        raise ConfigurationError(
            f"Table reference '{table_name}' is not allowed. "
            f"Only the following tables are available: {sorted(dataframes)}"
        )
386
+
387
+
275
388
  def load_yaml_config(filepath: FilePath) -> dict:
276
389
  """
277
390
  Loads a YAML config file
@@ -284,7 +397,13 @@ def load_yaml_config(filepath: FilePath) -> dict:
284
397
  """
285
398
  try:
286
399
  with open(filepath, 'r') as f:
287
- return yaml.safe_load(f)
400
+ content = yaml.safe_load(f)
401
+ content = content if content else {}
402
+
403
+ if not isinstance(content, dict):
404
+ raise yaml.YAMLError(f"Parsed content from YAML file must be a dictionary. Got: {content}")
405
+
406
+ return content
288
407
  except yaml.YAMLError as e:
289
408
  raise ConfigurationError(f"Failed to parse yaml file: {filepath}") from e
290
409
 
@@ -308,7 +427,7 @@ def run_duckdb_stmt(
308
427
  redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")
309
428
 
310
429
  for_model_name = f" for model '{model_name}'" if model_name is not None else ""
311
- logger.info(f"Running SQL statement{for_model_name}:\n{redacted_stmt}", extra={"data": {"params": params}})
430
+ logger.debug(f"Running SQL statement{for_model_name}:\n{redacted_stmt}")
312
431
  try:
313
432
  return duckdb_conn.execute(stmt, params)
314
433
  except duckdb.ParserException as e:
@@ -359,3 +478,63 @@ async def asyncio_gather(coroutines: list):
359
478
  # Wait for tasks to be cancelled
360
479
  await asyncio.gather(*tasks, return_exceptions=True)
361
480
  raise
481
+
482
+
483
def hash_string(input_str: str, salt: str) -> str:
    """
    Returns the SHA-256 hex digest of the input string with the salt appended.

    Arguments:
        input_str: The string to hash
        salt: The salt, concatenated after input_str before hashing

    Returns:
        The hexadecimal digest as a string
    """
    hasher = hashlib.sha256()
    hasher.update(f"{input_str}{salt}".encode())
    return hasher.hexdigest()
488
+
489
+
490
T = TypeVar('T')
def call_func(func: Callable[..., T], **kwargs) -> T:
    """
    Calls a function with only the keyword arguments it can accept.

    If func declares a "**kwargs" catch-all, every keyword argument is passed through.
    Otherwise, keyword arguments whose names do not match a parameter that can be
    passed by keyword are dropped (so a function without parameters is called
    without arguments).

    Arguments:
        func: The function to call
        kwargs: The candidate keyword arguments

    Returns:
        The return value of func
    """
    parameters = inspect.signature(func).parameters.values()

    # A catch-all "**kwargs" parameter means the function accepts any keyword argument
    if any(param.kind is inspect.Parameter.VAR_KEYWORD for param in parameters):
        return func(**kwargs)

    # Names of "*args" / positional-only parameters cannot be supplied by keyword
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    accepted_names = {param.name for param in parameters if param.kind in keyword_kinds}
    filtered_kwargs = {k: v for k, v in kwargs.items() if k in accepted_names}
    return func(**filtered_kwargs)
499
+
500
+
501
def generate_pkce_challenge(code_verifier: str) -> str:
    """
    Derives the PKCE code challenge for a code verifier.

    The challenge is the SHA-256 digest of the UTF-8 encoded verifier, encoded as
    URL-safe base64 with the trailing "=" padding removed.

    Arguments:
        code_verifier: The PKCE code verifier

    Returns:
        The code challenge string
    """
    digest = hashlib.sha256(code_verifier.encode('utf-8')).digest()
    encoded = base64.urlsafe_b64encode(digest)
    # Padding only ever appears at the end, so it can be stripped before decoding
    return encoded.rstrip(b'=').decode('utf-8')
508
+
509
+
510
def to_title_case(input_str: str) -> str:
    """
    Converts a string to title case, treating underscores and dashes as spaces.

    Arguments:
        input_str: The string to convert

    Returns:
        The string with "_" and "-" replaced by spaces and str.title() applied
    """
    separators_to_spaces = str.maketrans({'_': ' ', '-': ' '})
    return input_str.translate(separators_to_spaces).title()
514
+
515
+
516
def to_bool(val: object) -> bool:
    """Interpret a value as a boolean.

    Booleans are returned unchanged and None is False. Any other value is converted
    with str(), stripped of surrounding whitespace, and lowercased; it is True only
    if the result is one of "1", "true", "t", "yes", "y", "on".
    """
    if val is None:
        return False
    if isinstance(val, bool):
        return val

    truthy_strings = {"1", "true", "t", "yes", "y", "on"}
    normalized = str(val).strip().lower()
    return normalized in truthy_strings
528
+
529
+
530
ACCESS_LEVEL = Literal["admin", "member", "guest"]

# Lower rank means more privileges
_ACCESS_LEVEL_RANKS: dict[str, int] = {"admin": 1, "member": 2, "guest": 3}


def get_access_level_rank(access_level: ACCESS_LEVEL, *, default: int = 1) -> int:
    """
    Get the rank of an access level. Lower ranks have more privileges.

    Arguments:
        access_level: The access level (matched case-insensitively)
        default: The rank returned for an unrecognized access level. Defaults to 1 (the
            "admin" rank), which is only a safe fallback when the rank is used as a
            requirement - not when it describes what a user is allowed to do.

    Returns:
        The rank: 1 for "admin", 2 for "member", 3 for "guest"
    """
    return _ACCESS_LEVEL_RANKS.get(access_level.lower(), default)


def user_has_elevated_privileges(user_access_level: ACCESS_LEVEL, required_access_level: ACCESS_LEVEL) -> bool:
    """
    Check if a user has privilege to access a resource.

    Unrecognized values fail closed: an unrecognized user access level ranks below
    "guest" (so access is denied), and an unrecognized required access level is
    treated as requiring "admin".

    Arguments:
        user_access_level: The access level of the user
        required_access_level: The minimum access level the resource requires

    Returns:
        True if the user's access level is at least as privileged as the required one
    """
    # An unknown user level must never be mistaken for "admin"
    least_privileged_rank = max(_ACCESS_LEVEL_RANKS.values()) + 1
    user_access_level_rank = get_access_level_rank(user_access_level, default=least_privileged_rank)
    required_access_level_rank = get_access_level_rank(required_access_level)
    return user_access_level_rank <= required_access_level_rank
squirrels/_version.py CHANGED
@@ -1,3 +1 @@
1
- __version__ = '0.5.0'
2
-
3
- sq_major_version, sq_minor_version, sq_patch_version = __version__.split('.')[:3]
1
+ __version__ = '0.6.0'
squirrels/arguments.py CHANGED
@@ -1,2 +1,7 @@
1
- from ._arguments._init_time_args import ConnectionsArgs, ParametersArgs, BuildModelArgs
2
- from ._arguments._run_time_args import ContextArgs, ModelArgs, DashboardArgs
1
+ from ._arguments.init_time_args import ConnectionsArgs, AuthProviderArgs, ParametersArgs, BuildModelArgs
2
+ from ._arguments.run_time_args import ContextArgs, ModelArgs, DashboardArgs
3
+
4
+ __all__ = [
5
+ "ConnectionsArgs", "AuthProviderArgs", "ParametersArgs", "BuildModelArgs",
6
+ "ContextArgs", "ModelArgs", "DashboardArgs"
7
+ ]
squirrels/auth.py ADDED
@@ -0,0 +1,4 @@
1
+ from ._schemas.auth_models import CustomUserFields, RegisteredUser
2
+ from ._auth import ProviderConfigs, provider
3
+
4
+ __all__ = ["CustomUserFields", "RegisteredUser", "ProviderConfigs", "provider"]
squirrels/connections.py CHANGED
@@ -1 +1,3 @@
1
1
  from ._manifest import ConnectionProperties, ConnectionTypeEnum
2
+
3
+ __all__ = ["ConnectionProperties", "ConnectionTypeEnum"]
squirrels/dashboards.py CHANGED
@@ -1 +1,3 @@
1
- from ._dashboard_types import PngDashboard, HtmlDashboard
1
+ from ._dashboards import PngDashboard, HtmlDashboard
2
+
3
+ __all__ = ["PngDashboard", "HtmlDashboard"]
squirrels/data_sources.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from ._data_sources import (
2
+ SourceEnum,
2
3
  SelectDataSource,
3
4
  DateDataSource,
4
5
  DateRangeDataSource,
@@ -6,3 +7,8 @@ from ._data_sources import (
6
7
  NumberRangeDataSource,
7
8
  TextDataSource
8
9
  )
10
+
11
+ __all__ = [
12
+ "SourceEnum", "SelectDataSource", "DateDataSource", "DateRangeDataSource",
13
+ "NumberDataSource", "NumberRangeDataSource", "TextDataSource"
14
+ ]
@@ -6,3 +6,8 @@ from ._parameter_options import (
6
6
  NumberRangeParameterOption,
7
7
  TextParameterOption
8
8
  )
9
+
10
+ __all__ = [
11
+ "SelectParameterOption", "DateParameterOption", "DateRangeParameterOption",
12
+ "NumberParameterOption", "NumberRangeParameterOption", "TextParameterOption"
13
+ ]
squirrels/parameters.py CHANGED
@@ -7,3 +7,8 @@ from ._parameters import (
7
7
  NumberRangeParameter,
8
8
  TextParameter
9
9
  )
10
+
11
+ __all__ = [
12
+ "SingleSelectParameter", "MultiSelectParameter", "DateParameter", "DateRangeParameter",
13
+ "NumberParameter", "NumberRangeParameter", "TextParameter"
14
+ ]