squirrels-0.4.1-py3-none-any.whl → squirrels-0.5.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (125)
  1. dateutils/__init__.py +6 -0
  2. dateutils/_enums.py +25 -0
  3. squirrels/dateutils.py → dateutils/_implementation.py +58 -111
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +13 -11
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +271 -0
  8. squirrels/_api_routes/base.py +165 -0
  9. squirrels/_api_routes/dashboards.py +150 -0
  10. squirrels/_api_routes/data_management.py +145 -0
  11. squirrels/_api_routes/datasets.py +257 -0
  12. squirrels/_api_routes/oauth2.py +298 -0
  13. squirrels/_api_routes/project.py +252 -0
  14. squirrels/_api_server.py +256 -450
  15. squirrels/_arguments/__init__.py +0 -0
  16. squirrels/_arguments/init_time_args.py +108 -0
  17. squirrels/_arguments/run_time_args.py +147 -0
  18. squirrels/_auth.py +960 -0
  19. squirrels/_command_line.py +126 -45
  20. squirrels/_compile_prompts.py +147 -0
  21. squirrels/_connection_set.py +48 -26
  22. squirrels/_constants.py +68 -38
  23. squirrels/_dashboards.py +160 -0
  24. squirrels/_data_sources.py +570 -0
  25. squirrels/_dataset_types.py +84 -0
  26. squirrels/_exceptions.py +29 -0
  27. squirrels/_initializer.py +177 -80
  28. squirrels/_logging.py +115 -0
  29. squirrels/_manifest.py +208 -79
  30. squirrels/_model_builder.py +69 -0
  31. squirrels/_model_configs.py +74 -0
  32. squirrels/_model_queries.py +52 -0
  33. squirrels/_models.py +926 -367
  34. squirrels/_package_data/base_project/.env +42 -0
  35. squirrels/_package_data/base_project/.env.example +42 -0
  36. squirrels/_package_data/base_project/assets/expenses.db +0 -0
  37. squirrels/_package_data/base_project/connections.yml +16 -0
  38. squirrels/_package_data/base_project/dashboards/dashboard_example.py +34 -0
  39. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
  40. squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +5 -2
  41. squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +3 -3
  42. squirrels/{package_data → _package_data}/base_project/docker/compose.yml +1 -1
  43. squirrels/_package_data/base_project/duckdb_init.sql +10 -0
  44. squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +3 -2
  45. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  46. squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
  47. squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
  48. squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
  49. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +12 -0
  50. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +26 -0
  51. squirrels/_package_data/base_project/models/federates/federate_example.py +37 -0
  52. squirrels/_package_data/base_project/models/federates/federate_example.sql +19 -0
  53. squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
  54. squirrels/_package_data/base_project/models/sources.yml +38 -0
  55. squirrels/{package_data → _package_data}/base_project/parameters.yml +56 -40
  56. squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
  57. squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +21 -40
  58. squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
  59. squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
  60. squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
  61. squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
  62. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
  63. squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
  64. squirrels/_package_data/templates/dataset_results.html +112 -0
  65. squirrels/_package_data/templates/oauth_login.html +271 -0
  66. squirrels/_package_data/templates/squirrels_studio.html +20 -0
  67. squirrels/_package_loader.py +8 -4
  68. squirrels/_parameter_configs.py +104 -103
  69. squirrels/_parameter_options.py +348 -0
  70. squirrels/_parameter_sets.py +57 -47
  71. squirrels/_parameters.py +1664 -0
  72. squirrels/_project.py +721 -0
  73. squirrels/_py_module.py +7 -5
  74. squirrels/_schemas/__init__.py +0 -0
  75. squirrels/_schemas/auth_models.py +167 -0
  76. squirrels/_schemas/query_param_models.py +75 -0
  77. squirrels/{_api_response_models.py → _schemas/response_models.py} +126 -47
  78. squirrels/_seeds.py +35 -16
  79. squirrels/_sources.py +110 -0
  80. squirrels/_utils.py +248 -73
  81. squirrels/_version.py +1 -1
  82. squirrels/arguments.py +7 -0
  83. squirrels/auth.py +4 -0
  84. squirrels/connections.py +3 -0
  85. squirrels/dashboards.py +2 -81
  86. squirrels/data_sources.py +14 -631
  87. squirrels/parameter_options.py +13 -348
  88. squirrels/parameters.py +14 -1266
  89. squirrels/types.py +16 -0
  90. squirrels-0.5.0.dist-info/METADATA +113 -0
  91. squirrels-0.5.0.dist-info/RECORD +97 -0
  92. {squirrels-0.4.1.dist-info → squirrels-0.5.0.dist-info}/WHEEL +1 -1
  93. squirrels-0.5.0.dist-info/entry_points.txt +3 -0
  94. {squirrels-0.4.1.dist-info → squirrels-0.5.0.dist-info/licenses}/LICENSE +1 -1
  95. squirrels/_authenticator.py +0 -85
  96. squirrels/_dashboards_io.py +0 -61
  97. squirrels/_environcfg.py +0 -84
  98. squirrels/arguments/init_time_args.py +0 -40
  99. squirrels/arguments/run_time_args.py +0 -208
  100. squirrels/package_data/assets/favicon.ico +0 -0
  101. squirrels/package_data/assets/index.css +0 -1
  102. squirrels/package_data/assets/index.js +0 -58
  103. squirrels/package_data/base_project/assets/expenses.db +0 -0
  104. squirrels/package_data/base_project/connections.yml +0 -7
  105. squirrels/package_data/base_project/dashboards/dashboard_example.py +0 -32
  106. squirrels/package_data/base_project/dashboards.yml +0 -10
  107. squirrels/package_data/base_project/env.yml +0 -29
  108. squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
  109. squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -22
  110. squirrels/package_data/base_project/models/federates/federate_example.py +0 -21
  111. squirrels/package_data/base_project/models/federates/federate_example.sql +0 -3
  112. squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
  113. squirrels/package_data/base_project/pyconfigs/connections.py +0 -19
  114. squirrels/package_data/base_project/pyconfigs/parameters.py +0 -95
  115. squirrels/package_data/base_project/seeds/seed_subcategories.csv +0 -15
  116. squirrels/package_data/base_project/squirrels.yml.j2 +0 -94
  117. squirrels/package_data/templates/index.html +0 -18
  118. squirrels/project.py +0 -378
  119. squirrels/user_base.py +0 -55
  120. squirrels-0.4.1.dist-info/METADATA +0 -117
  121. squirrels-0.4.1.dist-info/RECORD +0 -60
  122. squirrels-0.4.1.dist-info/entry_points.txt +0 -4
  123. /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
  124. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
  125. /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
squirrels/_sources.py ADDED
@@ -0,0 +1,110 @@
1
+ from typing import Any
2
+ from pydantic import BaseModel, Field, model_validator
3
+ import time, sqlglot, yaml
4
+
5
+ from . import _utils as u, _constants as c, _model_configs as mc
6
+
7
+
8
class UpdateHints(BaseModel):
    # Hints used when incrementally updating a source (consumed by Source.get_query_for_upsert)
    increasing_column: str | None = Field(None)
    strictly_increasing: bool = Field(
        True,
        description="Delete the max value of the increasing column, ignored if selective_overwrite_value is set",
    )
    selective_overwrite_value: Any = Field(
        None,
        description="Delete all values of the increasing column greater than or equal to this value",
    )
12
+
13
+
14
class Source(mc.ConnectionInterface, mc.ModelConfig):
    # Table name in the source connection; defaults to the source name via finalize_table
    table: str | None = Field(default=None)
    load_to_vdl: bool = Field(default=False, description="Whether to load the data to the 'virtual data lake' (VDL)")
    primary_key: list[str] = Field(default_factory=list)
    update_hints: UpdateHints = Field(default_factory=UpdateHints)

    def finalize_table(self, source_name: str):
        """
        Default the table name to the source name when no table was configured.

        Returns:
            self, to allow chaining
        """
        if self.table is None:
            self.table = source_name
        return self

    def get_table(self) -> str:
        """
        Return the table name.

        Raises:
            AssertionError: If the table has not been set (e.g. finalize_table was never called).
                Raised explicitly so the check is not stripped when Python runs with -O.
        """
        if self.table is None:
            raise AssertionError("Table must be set")
        return self.table

    def get_cols_for_create_table_stmt(self) -> str:
        """Return the comma-separated "name type" column definitions of this source."""
        return ", ".join(f"{col.name} {col.type}" for col in self.columns)

    def get_max_incr_col_query(self, source_name: str) -> str:
        """Return a query selecting the max value of the increasing column from the table `source_name`."""
        return f"SELECT max({self.update_hints.increasing_column}) FROM {source_name}"

    def get_query_for_upsert(
        self, dialect: str, conn_name: str, table_name: str, max_value_of_increasing_col: Any | None, *,
        full_refresh: bool = True
    ) -> str:
        """
        Build the SELECT that reads this source's columns from `db_{conn_name}.{table_name}`.

        Arguments:
            dialect: SQL dialect of the connection (currently unused; reserved for query pushdown)
            conn_name: Connection name, used as the `db_{conn_name}` prefix
            table_name: Table to read from
            max_value_of_increasing_col: Current max value of the increasing column, or None
                (presumably obtained via get_max_incr_col_query — confirm with callers)
            full_refresh: If True, all rows are selected

        Returns:
            A SELECT of all configured columns. Unless full_refresh is True or the max value is None,
            only rows whose increasing column is greater than the max value are selected.

        Raises:
            ConfigurationError: For an incremental query, if update_hints.increasing_column is not
                one of the source's columns
        """
        select_cols = ", ".join(col.name for col in self.columns)
        base_query = f"SELECT {select_cols} FROM db_{conn_name}.{table_name}"
        if full_refresh or max_value_of_increasing_col is None:
            return base_query

        increasing_col = self.update_hints.increasing_column
        matching_types = [col.type for col in self.columns if col.name == increasing_col]
        if not matching_types:
            # Previously a bare next() leaked StopIteration here
            raise u.ConfigurationError(
                f"update_hints.increasing_column ({increasing_col!r}) must name one of the columns "
                f"of the source to run an incremental update"
            )
        increasing_col_type = matching_types[0]

        # Double any single quotes so the value is always a well-formed SQL string literal
        escaped_value = str(max_value_of_increasing_col).replace("'", "''")
        where_cond = f"{increasing_col}::{increasing_col_type} > '{escaped_value}'::{increasing_col_type}"

        # TODO: figure out if using a pushdown query (e.g. `{dialect}_query(...)` for postgres/mysql) is worth it
        return f"{base_query} WHERE {where_cond}"
52
+
53
+
54
class Sources(BaseModel):
    # Source configs keyed by source name
    sources: dict[str, Source] = Field(default_factory=dict)

    @model_validator(mode="before")
    @classmethod
    def convert_sources_list_to_dict(cls, data: Any) -> Any:
        """
        Accept "sources" as a list of configs (each with a "name" key) and convert it to a dict keyed by name.

        The input is not mutated; a new dict is returned. Input that is not a dict, or whose "sources"
        value is not a list, is passed through unchanged for pydantic to validate.

        Raises:
            ConfigurationError: If a source has no "name" field, or two sources share a name
        """
        if not isinstance(data, dict) or not isinstance(data.get("sources"), list):
            return data

        sources_dict: dict[str, Any] = {}
        for source in data["sources"]:
            if not (isinstance(source, dict) and "name" in source):
                raise u.ConfigurationError("All sources must have a name field in sources file")
            name = source["name"]
            if name in sources_dict:
                raise u.ConfigurationError(f"Duplicate source name found: {name}")
            # Copy the config without "name" instead of popping it, so the caller's data is left intact
            sources_dict[name] = {key: value for key, value in source.items() if key != "name"}

        return {**data, "sources": sources_dict}

    @model_validator(mode="after")
    def validate_column_types(self):
        """
        Ensure every column of every source declares a type.

        Raises:
            ConfigurationError: If any column has an empty or missing type
        """
        for source_name, source in self.sources.items():
            for col in source.columns:
                if not col.type:
                    raise u.ConfigurationError(f"Column '{col.name}' in source '{source_name}' must have a type specified")
        return self

    def finalize_null_fields(self, env_vars: dict[str, str]):
        """
        Fill in fields left unset in the config of each source.

        For each source this calls finalize_connection(env_vars), then defaults the table name
        to the source name.

        Returns:
            self, to allow chaining
        """
        for source_name, source in self.sources.items():
            source.finalize_connection(env_vars)
            source.finalize_table(source_name)
        return self
87
+
88
+
89
class SourcesIO:
    @classmethod
    def load_file(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Sources:
        """
        Load, render, parse and validate the project's sources file.

        The file at {base_path}/{c.MODELS_FOLDER}/{c.SOURCES_FILE} is rendered with u.render_string
        (given base_path and env_vars) before being parsed as YAML. A missing or empty file yields no sources.

        Arguments:
            logger: Logger used to record how long loading took
            base_path: Root folder of the project
            env_vars: Environment variables made available while rendering and finalizing sources

        Returns:
            The validated Sources, with connection and table fields finalized

        Raises:
            ConfigurationError: If the file is not valid YAML, is not a mapping, or fails validation
        """
        start = time.time()

        sources_path = u.Path(base_path, c.MODELS_FOLDER, c.SOURCES_FILE)
        if sources_path.exists():
            raw_content = u.read_file(sources_path)
            rendered = u.render_string(raw_content, base_path=base_path, env_vars=env_vars)
            try:
                sources_data = yaml.safe_load(rendered) or {}
            except yaml.YAMLError as e:
                # Surface parse failures as configuration errors, like u.load_yaml_config does
                raise u.ConfigurationError(f"Failed to parse yaml file: {sources_path}") from e
        else:
            sources_data = {}

        if not isinstance(sources_data, dict):
            raise u.ConfigurationError(
                f"Parsed content from YAML file must be a dictionary. Got: {sources_data}"
            )

        sources = Sources(**sources_data).finalize_null_fields(env_vars)

        logger.log_activity_time("loading sources", start)
        return sources
squirrels/_utils.py CHANGED
@@ -1,35 +1,33 @@
1
- from typing import Sequence, Optional, Union, TypeVar, Callable
2
- from pathlib import Path
3
- from pandas.api import types as pd_types
1
+ from typing import Sequence, Optional, Union, TypeVar, Callable, Iterable, Literal, Any
4
2
  from datetime import datetime
5
- import os, time, logging, json, sqlite3, pandas as pd
3
+ from pathlib import Path
4
+ import os, time, logging, json, duckdb, polars as pl, yaml
6
5
  import jinja2 as j2, jinja2.nodes as j2_nodes
6
+ import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
7
7
 
8
8
  from . import _constants as c
9
+ from ._exceptions import ConfigurationError
9
10
 
10
11
  FilePath = Union[str, Path]
11
12
 
12
-
13
- ## Custom Exceptions
14
-
15
- class InvalidInputError(Exception):
16
- """
17
- Use this exception when the error is due to providing invalid inputs to the REST API
18
- """
19
- pass
20
-
21
- class ConfigurationError(Exception):
22
- """
23
- Use this exception when the server error is due to errors in the squirrels project instead of the squirrels framework/library
24
- """
25
- pass
26
-
27
- class FileExecutionError(Exception):
28
- def __init__(self, message: str, error: Exception, *args) -> None:
29
- t = " "
30
- new_message = f"\n" + message + f"\n{t}Produced error message:\n{t}{t}{error} (see above for more details on handled exception)"
31
- super().__init__(new_message, *args)
32
- self.error = error
13
# Polars: each polars dtype and the squirrels type names (aliases) that map to it
polars_dtypes_to_sqrl_dtypes: dict[type[pl.DataType], list[str]] = {
    pl.String: ["string", "varchar", "char", "text"],
    pl.Int8: ["tinyint", "int1"],
    pl.Int16: ["smallint", "short", "int2"],
    pl.Int32: ["integer", "int", "int4"],
    pl.Int64: ["bigint", "long", "int8"],
    pl.Float32: ["float", "float4", "real"],
    pl.Float64: ["double", "float8", "decimal"],  # Note: Polars Decimal type is considered unstable, so we use Float64 for "decimal"
    pl.Boolean: ["boolean", "bool", "logical"],
    pl.Date: ["date"],
    pl.Time: ["time"],
    pl.Datetime: ["timestamp", "datetime"],
    pl.Duration: ["interval"],
    pl.Binary: ["blob", "binary", "varbinary"],
}

# Reverse lookup: squirrels type name -> polars dtype
sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = dict(
    (type_name, polars_dtype)
    for polars_dtype, type_names in polars_dtypes_to_sqrl_dtypes.items()
    for type_name in type_names
)
33
31
 
34
32
 
35
33
  ## Other utility classes
@@ -40,7 +38,7 @@ class Logger(logging.Logger):
40
38
  time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
41
39
  data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
42
40
  info = { "request_id": request_id } if request_id else {}
43
- self.debug(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
41
+ self.info(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
44
42
 
45
43
 
46
44
  class EnvironmentWithMacros(j2.Environment):
@@ -85,14 +83,6 @@ class EnvironmentWithMacros(j2.Environment):
85
83
 
86
84
  ## Utility functions/variables
87
85
 
88
- def log_activity_time(logger: logging.Logger, activity: str, start_timestamp: float, *, request_id: str | None = None) -> None:
89
- end_timestamp = time.time()
90
- time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
91
- data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
92
- info = { "request_id": request_id } if request_id else {}
93
- logger.debug(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
94
-
95
-
96
86
  def render_string(raw_str: str, *, base_path: str = ".", **kwargs) -> str:
97
87
  """
98
88
  Given a template string, render it with the given keyword arguments
@@ -115,7 +105,6 @@ def read_file(filepath: FilePath) -> str:
115
105
 
116
106
  Arguments:
117
107
  filepath (str | pathlib.Path): The path to the file to read
118
- is_required: If true, throw error if file doesn't exist
119
108
 
120
109
  Returns:
121
110
  Content of the file, or None if doesn't exist and not required
@@ -129,7 +118,7 @@ def read_file(filepath: FilePath) -> str:
129
118
 
130
119
  def normalize_name(name: str) -> str:
131
120
  """
132
- Normalizes names to the convention of the squirrels manifest file.
121
+ Normalizes names to the convention of the squirrels manifest file (with underscores instead of dashes).
133
122
 
134
123
  Arguments:
135
124
  name: The name to normalize.
@@ -142,7 +131,7 @@ def normalize_name(name: str) -> str:
142
131
 
143
132
  def normalize_name_for_api(name: str) -> str:
144
133
  """
145
- Normalizes names to the REST API convention.
134
+ Normalizes names to the REST API convention (with dashes instead of underscores).
146
135
 
147
136
  Arguments:
148
137
  name: The name to normalize.
@@ -180,7 +169,7 @@ def load_json_or_comma_delimited_str_as_list(input_str: Union[str, Sequence]) ->
180
169
  return [x.strip() for x in input_str.split(",")]
181
170
 
182
171
 
183
- X, Y = TypeVar('X'), TypeVar('Y')
172
+ X = TypeVar('X'); Y = TypeVar('Y')
184
173
  def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) -> Optional[Y]:
185
174
  """
186
175
  Given a input value and a function that processes the value, return the output of the function unless input is None
@@ -197,60 +186,246 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
197
186
  return processor(input_val)
198
187
 
199
188
 
200
- def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pd.DataFrame], do_use_duckdb: bool) -> pd.DataFrame:
189
def _read_duckdb_init_sql(
    *,
    datalake_db_path: str | None = None,
) -> str:
    """
    Build the SQL script used to initialize new DuckDB connections.

    The script joins the following parts with blank lines, in this order:
      1. the global init file ~/{c.GLOBAL_ENV_FOLDER}/{c.DUCKDB_INIT_FILE}, if it exists
      2. {c.DUCKDB_INIT_FILE} relative to the current working directory, if it exists
      3. an ATTACH of the VDL catalog as 'vdl' (READ_ONLY), if datalake_db_path is truthy

    Nothing is cached; the files are re-read on every call.

    Arguments:
        datalake_db_path: Path to the VDL catalog database to attach, or None to skip

    Returns:
        The combined script with surrounding whitespace stripped ("" when there is nothing to run)

    Raises:
        ConfigurationError: If an init file cannot be read
    """
    try:
        init_contents: list[str] = []

        global_init_path = Path(os.path.expanduser('~'), c.GLOBAL_ENV_FOLDER, c.DUCKDB_INIT_FILE)
        if global_init_path.exists():
            # Explicit encoding so the result does not depend on the platform's locale
            init_contents.append(global_init_path.read_text(encoding='utf-8'))

        project_init_path = Path(c.DUCKDB_INIT_FILE)
        if project_init_path.exists():
            init_contents.append(project_init_path.read_text(encoding='utf-8'))

        if datalake_db_path:
            # Double single quotes so paths containing "'" still form a valid SQL string literal
            escaped_path = datalake_db_path.replace("'", "''")
            init_contents.append(f"ATTACH '{escaped_path}' AS vdl (READ_ONLY);")

        return "\n\n".join(init_contents).strip()
    except Exception as e:
        raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e
216
+
217
def create_duckdb_connection(
    db_path: str | Path = ":memory:",
    *,
    datalake_db_path: str | None = None
) -> duckdb.DuckDBPyConnection:
    """
    Create a DuckDB connection and initialize it with the statements from the duckdb init file(s).

    Arguments:
        db_path: Path to the DuckDB database file. Defaults to an in-memory database.
        datalake_db_path: The path to the VDL catalog database, if applicable. If given, it is attached as 'vdl' (READ_ONLY). Default is None.

    Returns:
        A DuckDB connection (which must be closed after use)

    Raises:
        ConfigurationError: If the init SQL cannot be read or fails to execute. The connection is closed before raising.
    """
    conn = duckdb.connect(db_path)

    try:
        init_sql = _read_duckdb_init_sql(datalake_db_path=datalake_db_path)
        # No init files and no VDL to attach: there is nothing to run, so skip the execute call
        if init_sql:
            conn.execute(init_sql)
    except Exception as e:
        conn.close()
        raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {str(e)}") from e

    return conn
242
+
243
+
244
def run_sql_on_dataframes(sql_query: str, dataframes: dict[str, pl.LazyFrame]) -> pl.DataFrame:
    """
    Execute a SQL query over polars frames using a fresh, initialized DuckDB connection.

    Each frame is registered on the connection under its dictionary key, so the query can refer
    to it by that name. The connection is closed when done, whether or not the query succeeds.

    Arguments:
        sql_query: The SQL query to run
        dataframes: Mapping of table name to polars LazyFrame

    Returns:
        The query result as a polars DataFrame
    """
    with create_duckdb_connection() as duckdb_conn:
        for view_name, frame in dataframes.items():
            duckdb_conn.register(view_name, frame)
        return duckdb_conn.sql(sql_query).pl()
227
266
 
228
267
 
229
- def df_to_json0(df: pd.DataFrame, dimensions: list[str] | None = None) -> dict:
268
def load_yaml_config(filepath: FilePath) -> dict:
    """
    Load a YAML config file whose top-level value must be a mapping.

    An empty file (or one containing only `null`) is treated as an empty dict. Other non-mapping
    values, including falsy ones like `[]` or `false`, are rejected instead of being silently
    replaced with {}.

    Arguments:
        filepath: The path to the YAML file (read as UTF-8)

    Returns:
        A dictionary representation of the YAML file

    Raises:
        ConfigurationError: If the file is not valid YAML, or its top-level value is not a dictionary
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            content = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ConfigurationError(f"Failed to parse yaml file: {filepath}. Error: {e}") from e

    if content is None:
        content = {}

    if not isinstance(content, dict):
        raise ConfigurationError(
            f"Failed to parse yaml file: {filepath}. "
            f"Parsed content from YAML file must be a dictionary. Got: {content}"
        )

    return content
289
+
290
+
291
def run_duckdb_stmt(
    logger: Logger, duckdb_conn: duckdb.DuckDBPyConnection, stmt: str, *, params: dict[str, Any] | None = None,
    model_name: str | None = None, redacted_values: Iterable[str] = ()
) -> duckdb.DuckDBPyConnection:
    """
    Run a statement on a DuckDB connection, logging it at debug level (with secrets redacted) first.

    Arguments:
        logger: The logger to use
        duckdb_conn: The DuckDB connection
        stmt: The statement to run
        params: The parameters to pass along with the statement, if any
        model_name: Name of the model the statement belongs to; only used in log messages
        redacted_values: Substrings to replace with "[REDACTED]" in the logged statement. Empty strings are ignored.

    Returns:
        The result of `duckdb_conn.execute(stmt, params)`

    Raises:
        duckdb.ParserException: Logged at error level (redacted), then re-raised
    """
    redacted_stmt = stmt
    for value in redacted_values:
        # str.replace("", x) inserts x between every character, which would garble the log
        if value:
            redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")

    for_model_name = f" for model '{model_name}'" if model_name is not None else ""
    logger.debug(f"Running SQL statement{for_model_name}:\n{redacted_stmt}", extra={"data": {"params": params}})
    try:
        return duckdb_conn.execute(stmt, params)
    except duckdb.ParserException as e:
        logger.error(f"Failed to run statement: {redacted_stmt}", exc_info=e)
        raise
316
+
317
+
318
def get_current_time() -> str:
    """
    Return the current local time as "HH:MM:SS.mmm", with milliseconds truncated (not rounded).
    """
    now = datetime.now()
    milliseconds = now.microsecond // 1000
    return f"{now:%H:%M:%S}.{milliseconds:03d}"
323
+
324
+
325
def parse_dependent_tables(sql_query: str, all_table_names: Iterable[str]) -> tuple[set[str], sqlglot.Expression]:
    """
    Find which known tables a SQL query references.

    Arguments:
        sql_query: The SQL query to parse
        all_table_names: Names of all known tables; only references to these are reported.
            Any iterable works, including a one-shot generator.

    Returns:
        A tuple of (set of known table names referenced by the query, the parsed sqlglot expression)
    """
    parsed = sqlglot.parse_one(sql_query)

    # Materialize the names once. Rebuilding the set per table reference was O(n*m), and with a
    # generator every reference after the first saw an already-exhausted, empty iterable.
    known_tables = set(all_table_names)

    dependencies = {
        table.name
        for table in parsed.find_all(sqlglot.expressions.Table)
        if table.name in known_tables
    }
    return dependencies, parsed
346
+
347
+
348
async def asyncio_gather(coroutines: list):
    """
    Run the coroutines concurrently as tasks and return their results in order.

    If awaiting them fails (including by cancellation), every unfinished task is cancelled and
    awaited before the original exception propagates, so no task is left running.
    """
    pending = list(map(asyncio.create_task, coroutines))

    try:
        results = await asyncio.gather(*pending)
    except BaseException:
        unfinished = [task for task in pending if not task.done()]
        for task in unfinished:
            task.cancel()
        # Let the cancellations settle; their outcomes are intentionally ignored
        await asyncio.gather(*pending, return_exceptions=True)
        raise

    return results
361
+
362
+
363
def hash_string(input_str: str, salt: str) -> str:
    """
    Return the SHA-256 hex digest of `input_str` followed by `salt`.
    """
    hasher = hashlib.sha256()
    hasher.update((input_str + salt).encode())
    return hasher.hexdigest()
368
+
369
+
370
T = TypeVar('T')
def call_func(func: Callable[..., T], **kwargs) -> T:
    """
    Call `func`, forwarding only the keyword arguments whose names appear in its signature.

    Keywords the function does not declare are dropped, so a function with no parameters is
    called with no arguments.
    """
    accepted_names = inspect.signature(func).parameters
    applicable = {name: value for name, value in kwargs.items() if name in accepted_names}
    return func(**applicable)
379
+
380
+
381
def generate_pkce_challenge(code_verifier: str) -> str:
    """
    Generate the PKCE "S256" code challenge for a code verifier.

    The challenge is the URL-safe base64 encoding of SHA-256(code_verifier), with trailing "=" padding removed.
    """
    verifier_hash = hashlib.sha256(code_verifier.encode('utf-8')).digest()
    return base64.urlsafe_b64encode(verifier_hash).decode('utf-8').rstrip('=')


def validate_pkce_challenge(code_verifier: str, code_challenge: str) -> bool:
    """
    Validate a PKCE code verifier against a code challenge.

    Returns:
        True if the S256 challenge computed from `code_verifier` equals `code_challenge`
    """
    import hmac

    expected_challenge = generate_pkce_challenge(code_verifier)
    # Constant-time comparison, so response timing does not reveal how much of the challenge matched.
    # Compare bytes: hmac.compare_digest raises TypeError for non-ASCII str arguments, and the
    # challenge value may be client-supplied.
    return hmac.compare_digest(expected_challenge.encode('utf-8'), code_challenge.encode('utf-8'))
394
+
395
+
396
+ def get_scheme(hostname: str | None) -> str:
397
+ """Get the scheme of the request"""
398
+ return "http" if hostname in ("localhost", "127.0.0.1") else "https"
399
+
400
+
401
def to_title_case(input_str: str) -> str:
    """Replace underscores and dashes with spaces, then title-case the result."""
    separators_to_spaces = str.maketrans("_-", "  ")
    return input_str.translate(separators_to_spaces).title()
405
+
406
+
407
def to_bool(val: object) -> bool:
    """Interpret common truthy/falsey representations as a boolean.

    Booleans are returned unchanged and None is False. Any other value is stringified, stripped and
    lower-cased; it is True only if it is one of "1", "true", "t", "yes", "y", "on".
    """
    if val is None:
        return False
    if isinstance(val, bool):
        return val
    normalized = str(val).strip().lower()
    return normalized in {"1", "true", "t", "yes", "y", "on"}
419
+
420
+
421
ACCESS_LEVEL = Literal["admin", "member", "guest"]

def get_access_level_rank(access_level: ACCESS_LEVEL, *, default: int = 1) -> int:
    """
    Get the rank of an access level. Lower ranks have more privileges.

    Arguments:
        access_level: "admin" (1), "member" (2) or "guest" (3), case-insensitive
        default: Rank returned for unrecognized access levels. The default of 1 (the admin rank) is the
            safe choice for a *required* level, but not for a user's own level.

    Returns:
        The numeric rank
    """
    return {"admin": 1, "member": 2, "guest": 3}.get(access_level.lower(), default)

def user_has_elevated_privileges(user_access_level: ACCESS_LEVEL, required_access_level: ACCESS_LEVEL) -> bool:
    """
    Check if a user has privilege to access a resource.

    Unrecognized values fail closed. An unknown user level ranks below "guest" (no privileges),
    while an unknown required level ranks as "admin" (most restrictive).
    """
    # Rank lower than every known level; previously an unknown user level ranked as admin
    no_privileges_rank = 4
    user_access_level_rank = get_access_level_rank(user_access_level, default=no_privileges_rank)
    required_access_level_rank = get_access_level_rank(required_access_level)
    return user_access_level_rank <= required_access_level_rank
squirrels/_version.py CHANGED
@@ -1,3 +1,3 @@
1
- __version__ = '0.4.1'
1
+ __version__ = '0.5.0'
2
2
 
3
3
  sq_major_version, sq_minor_version, sq_patch_version = __version__.split('.')[:3]
squirrels/arguments.py ADDED
@@ -0,0 +1,7 @@
1
+ from ._arguments.init_time_args import ConnectionsArgs, AuthProviderArgs, ParametersArgs, BuildModelArgs
2
+ from ._arguments.run_time_args import ContextArgs, ModelArgs, DashboardArgs
3
+
4
+ __all__ = [
5
+ "ConnectionsArgs", "AuthProviderArgs", "ParametersArgs", "BuildModelArgs",
6
+ "ContextArgs", "ModelArgs", "DashboardArgs"
7
+ ]
squirrels/auth.py ADDED
@@ -0,0 +1,4 @@
1
+ from ._schemas.auth_models import CustomUserFields, RegisteredUser
2
+ from ._auth import ProviderConfigs, provider
3
+
4
+ __all__ = ["CustomUserFields", "RegisteredUser", "ProviderConfigs", "provider"]
squirrels/connections.py ADDED
@@ -0,0 +1,3 @@
1
+ from ._manifest import ConnectionProperties, ConnectionTypeEnum
2
+
3
+ __all__ = ["ConnectionProperties", "ConnectionTypeEnum"]
squirrels/dashboards.py CHANGED
@@ -1,82 +1,3 @@
1
- import matplotlib.figure as _figure, io as _io, abc as _abc, typing as _t
1
+ from ._dashboards import PngDashboard, HtmlDashboard
2
2
 
3
- from . import _constants as _c
4
-
5
-
6
- class Dashboard(metaclass=_abc.ABCMeta):
7
- """
8
- Abstract parent class for all Dashboard classes.
9
- """
10
-
11
- @property
12
- @_abc.abstractmethod
13
- def _content(self) -> bytes | str:
14
- pass
15
-
16
- @property
17
- @_abc.abstractmethod
18
- def _format(self) -> str:
19
- pass
20
-
21
-
22
- class PngDashboard(Dashboard):
23
- """
24
- Instantiate a Dashboard in PNG format from a matplotlib figure or bytes
25
- """
26
-
27
- def __init__(self, content: _figure.Figure | _io.BytesIO | bytes) -> None:
28
- """
29
- Constructor for PngDashboard
30
-
31
- Arguments:
32
- content: The content of the dashboard as a matplotlib.figure.Figure or bytes
33
- """
34
- if isinstance(content, _figure.Figure):
35
- buffer = _io.BytesIO()
36
- content.savefig(buffer, format=_c.PNG)
37
- content = buffer.getvalue()
38
-
39
- if isinstance(content, _io.BytesIO):
40
- content = content.getvalue()
41
-
42
- self.__content = content
43
-
44
- @property
45
- def _content(self) -> bytes:
46
- return self.__content
47
-
48
- @property
49
- def _format(self) -> _t.Literal['png']:
50
- return _c.PNG
51
-
52
- def _repr_png_(self):
53
- return self._content
54
-
55
-
56
- class HtmlDashboard(Dashboard):
57
- """
58
- Instantiate a Dashboard from an HTML string
59
- """
60
-
61
- def __init__(self, content: _io.StringIO | str) -> None:
62
- """
63
- Constructor for HtmlDashboard
64
-
65
- Arguments:
66
- content: The content of the dashboard as HTML string
67
- """
68
- if isinstance(content, _io.StringIO):
69
- content = content.getvalue()
70
-
71
- self.__content = content
72
-
73
- @property
74
- def _content(self) -> str:
75
- return self.__content
76
-
77
- @property
78
- def _format(self) -> _t.Literal['html']:
79
- return _c.HTML
80
-
81
- def _repr_html_(self):
82
- return self._content
3
+ __all__ = ["PngDashboard", "HtmlDashboard"]