squirrels-0.5.0b4-py3-none-any.whl → squirrels-0.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of squirrels might be problematic.
- squirrels/__init__.py +2 -0
- squirrels/_api_routes/auth.py +83 -74
- squirrels/_api_routes/base.py +58 -41
- squirrels/_api_routes/dashboards.py +37 -21
- squirrels/_api_routes/data_management.py +72 -27
- squirrels/_api_routes/datasets.py +107 -84
- squirrels/_api_routes/oauth2.py +11 -13
- squirrels/_api_routes/project.py +71 -33
- squirrels/_api_server.py +130 -63
- squirrels/_arguments/run_time_args.py +9 -9
- squirrels/_auth.py +117 -162
- squirrels/_command_line.py +68 -32
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +11 -2
- squirrels/_constants.py +22 -8
- squirrels/_data_sources.py +38 -32
- squirrels/_dataset_types.py +2 -4
- squirrels/_initializer.py +1 -1
- squirrels/_logging.py +117 -0
- squirrels/_manifest.py +125 -58
- squirrels/_model_builder.py +10 -54
- squirrels/_models.py +224 -108
- squirrels/_package_data/base_project/.env +15 -4
- squirrels/_package_data/base_project/.env.example +14 -3
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +2 -2
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +4 -4
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +22 -15
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +1 -1
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +5 -1
- squirrels/_package_data/base_project/pyconfigs/context.py +23 -12
- squirrels/_package_data/base_project/pyconfigs/parameters.py +68 -33
- squirrels/_package_data/base_project/pyconfigs/user.py +11 -18
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +18 -28
- squirrels/_package_data/templates/squirrels_studio.html +20 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +8 -10
- squirrels/_project.py +351 -234
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/auth_models.py +32 -9
- squirrels/_schemas/query_param_models.py +9 -1
- squirrels/_schemas/response_models.py +36 -10
- squirrels/_seeds.py +1 -1
- squirrels/_sources.py +23 -19
- squirrels/_utils.py +83 -35
- squirrels/_version.py +1 -1
- squirrels/arguments.py +5 -0
- squirrels/auth.py +4 -1
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +6 -1
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/METADATA +28 -13
- squirrels-0.5.1.dist-info/RECORD +98 -0
- squirrels-0.5.0b4.dist-info/RECORD +0 -94
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/licenses/LICENSE +0 -0
squirrels/_request_context.py
ADDED
@@ -0,0 +1,33 @@
+"""
+Request context management using ContextVars for request-scoped data.
+Provides thread-safe and async-safe access to request IDs throughout the request lifecycle.
+"""
+from contextvars import ContextVar
+import uuid
+import base64
+
+# ContextVar for storing the current request ID
+_request_id: ContextVar[str | None] = ContextVar("request_id", default=None)
+
+
+def get_request_id() -> str | None:
+    """
+    Get the current request ID from the context.
+
+    Returns:
+        The request ID string if available, None otherwise (e.g., in background tasks).
+    """
+    return _request_id.get()
+
+
+def set_request_id() -> str:
+    """
+    Set a new request ID in the context.
+    Uses base64 URL-safe encoding of UUID bytes to create a shorter ID (22 chars vs 36).
+
+    Returns:
+        The request ID that was set.
+    """
+    request_id = base64.urlsafe_b64encode(uuid.uuid4().bytes).decode().rstrip('=')
+    _request_id.set(request_id)
+    return request_id
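For context, a minimal sketch of how these new request-ID helpers could be used; the ASGI middleware and the X-Request-ID header below are assumptions for illustration, not part of this release.

# Hypothetical usage sketch -- the middleware and header name are assumptions, not squirrels code.
from starlette.middleware.base import BaseHTTPMiddleware

from squirrels._request_context import get_request_id, set_request_id

class RequestIdMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        request_id = set_request_id()                    # bind a fresh 22-char ID to this request's context
        response = await call_next(request)
        response.headers["X-Request-ID"] = request_id    # echo the ID back to the caller (assumed convention)
        return response

# Anywhere downstream (e.g., log formatting), get_request_id() returns the same ID,
# or None outside a request context such as a background task.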
squirrels/_schemas/auth_models.py
CHANGED
@@ -1,19 +1,37 @@
-from typing import Callable, Any
+from typing import Callable, Any, Literal
 from datetime import datetime
 from pydantic import BaseModel, ConfigDict, Field, field_serializer


-class
+class CustomUserFields(BaseModel):
+    """
+    Extend this class to add custom user fields.
+    - Only the following types are supported: [str, int, float, bool, typing.Literal]
+    - Add "| None" after the type to make it nullable.
+    - Always set a default value for the column (use None if default is null).
+    """
+    pass
+
+
+class AbstractUser(BaseModel):
     model_config = ConfigDict(from_attributes=True)
     username: str
-
-
-    @classmethod
-    def dropped_columns(cls):
-        return []
+    access_level: Literal["admin", "member", "guest"]
+    custom_fields: CustomUserFields

     def __hash__(self):
         return hash(self.username)
+
+    def __str__(self):
+        return self.username
+
+
+class GuestUser(AbstractUser):
+    access_level: Literal["guest"] = "guest"
+
+
+class RegisteredUser(AbstractUser):
+    access_level: Literal["admin", "member"] = "member"


 class ApiKey(BaseModel):
@@ -40,9 +58,14 @@ class UserField(BaseModel):
 class ProviderConfigs(BaseModel):
     client_id: str
     client_secret: str
-
+    server_url: str
+    server_metadata_path: str = Field(default="/.well-known/openid-configuration")
     client_kwargs: dict = Field(default_factory=dict)
-    get_user: Callable[[dict],
+    get_user: Callable[[dict], RegisteredUser]
+
+    @property
+    def server_metadata_url(self) -> str:
+        return f"{self.server_url}{self.server_metadata_path}"


 class AuthProvider(BaseModel):
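As a hedged illustration of the new user model: a project might subclass CustomUserFields and construct users as sketched below. The field names and values are invented, and the exact place squirrels expects the subclass (e.g. pyconfigs/user.py) is not shown in this hunk.

# Illustration only -- field names, values, and wiring are assumptions.
from typing import Literal

from squirrels._schemas.auth_models import CustomUserFields, GuestUser, RegisteredUser

class MyUserFields(CustomUserFields):
    department: str = "unknown"                           # supported type with a default
    seniority: Literal["junior", "senior"] | None = None  # nullable via "| None"

admin = RegisteredUser(username="alice", access_level="admin", custom_fields=MyUserFields(department="finance"))
guest = GuestUser(username="anon", custom_fields=MyUserFields())

print(admin.access_level)  # "admin"
print(str(guest))          # "anon", from the new __str__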
squirrels/_schemas/query_param_models.py
CHANGED
@@ -40,7 +40,7 @@ def get_query_models_for_dataset(widget_parameters: list[str] | None, param_fiel
     predefined_params = [
         APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dataset"),
         APIParamFieldInfo("x_orientation", str, default="records", description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'"),
-        APIParamFieldInfo("
+        APIParamFieldInfo("x_sql_query", str, description="Optional DuckDB SQL to transform the final dataset. Use table name 'result' to reference the dataset."),
         APIParamFieldInfo("x_offset", int, default=0, description="The number of rows to skip before returning data (applied after data caching)"),
         APIParamFieldInfo("x_limit", int, default=1000, description="The maximum number of rows to return (applied after data caching and offset)"),
     ]
@@ -65,3 +65,11 @@ def get_query_models_for_querying_models(param_fields: dict):
         APIParamFieldInfo("x_sql_query", str, description="The SQL query to execute on the data models"),
     ]
     return _get_query_models_helper(None, predefined_params, param_fields)
+
+
+def get_query_models_for_compiled_models(param_fields: dict):
+    """Generate query models for fetching compiled model SQL"""
+    predefined_params = [
+        APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the model"),
+    ]
+    return _get_query_models_helper(None, predefined_params, param_fields)
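To make the new x_sql_query dataset parameter concrete, a hedged sketch of building a request URL that post-processes the result with DuckDB SQL; the dataset path and column names are invented, and only the 'result' table-name convention comes from the parameter description above.

# Illustration only -- dataset path and columns are assumptions; 'result' is the documented table name.
import urllib.parse

params = {
    "x_orientation": "records",
    "x_limit": 100,
    "x_sql_query": "SELECT category, SUM(amount) AS total FROM result GROUP BY category ORDER BY total DESC",
}
url = "/squirrels/v0/myproject/v1/dataset/mydataset?" + urllib.parse.urlencode(params)
print(url)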
squirrels/_schemas/response_models.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Annotated, Literal
+from typing import Annotated, Literal, Any
 from pydantic import BaseModel, Field
 from datetime import date

@@ -99,6 +99,14 @@ parameters_path_description = "The API path to the parameters for the dataset /
 metadata_path_description = "The API path to the metadata (i.e., description and schema) for the dataset"
 result_path_description = "The API path to the results for the dataset / dashboard"

+class ConfigurableDefaultModel(BaseModel):
+    name: str
+    default: str
+
+class ConfigurableItemModel(ConfigurableDefaultModel):
+    label: str
+    description: str
+
 class ColumnModel(BaseModel):
     name: Annotated[str, Field(examples=["mycol"], description="Name of column")]
     type: Annotated[str, Field(examples=["string", "integer", "boolean", "datetime"], description='Column type (such as "string", "integer", "boolean", "datetime", etc.)')]
@@ -118,18 +126,19 @@ class DatasetItemModel(BaseModel):
     name: Annotated[str, Field(examples=["mydataset"], description=name_description)]
     label: Annotated[str, Field(examples=["My Dataset"], description=label_description)]
     description: Annotated[str, Field(examples=[""], description=description_description)]
-
+    configurables: Annotated[list[ConfigurableDefaultModel], Field(default_factory=list, description="The list of configurables with their default values")]
+    parameters: Annotated[list[str], Field(examples=["myparam1", "myparam2"], description="The list of parameter names used by the dataset. If the list is empty, the dataset does not accept any parameters.")]
     data_schema: Annotated[SchemaWithConditionModel, Field(alias="schema", description="JSON object describing the schema of the dataset")]
-    parameters_path: Annotated[str, Field(examples=["/squirrels
-    result_path: Annotated[str, Field(examples=["/squirrels
-
+    parameters_path: Annotated[str, Field(examples=["/squirrels/v0/myproject/v1/dataset/mydataset/parameters"], description=parameters_path_description)]
+    result_path: Annotated[str, Field(examples=["/squirrels/v0/myproject/v1/dataset/mydataset"], description=result_path_description)]
+
 class DashboardItemModel(ParametersModel):
     name: Annotated[str, Field(examples=["mydashboard"], description=name_description)]
     label: Annotated[str, Field(examples=["My Dashboard"], description=label_description)]
     description: Annotated[str, Field(examples=[""], description=description_description)]
     parameters: Annotated[list[str], Field(examples=["myparam1", "myparam2"], description="The list of parameter names used by the dashboard")]
-    parameters_path: Annotated[str, Field(examples=["/squirrels
-    result_path: Annotated[str, Field(examples=["/squirrels
+    parameters_path: Annotated[str, Field(examples=["/squirrels/v0/myproject/v1/dashboard/mydashboard/parameters"], description=parameters_path_description)]
+    result_path: Annotated[str, Field(examples=["/squirrels/v0/myproject/v1/dashboard/mydashboard"], description=result_path_description)]
     result_format: Annotated[str, Field(examples=["png", "html"], description="The format of the dashboard's result API response (one of 'png' or 'html')")]

 ModelConfigType = mc.ModelConfig | s.Source | mc.SeedConfig | mc.BuildModelConfig | mc.DbviewModelConfig | mc.FederateModelConfig
@@ -155,13 +164,16 @@ class LineageRelation(BaseModel):
     source: LineageNode
     target: LineageNode

-class
-    parameters: Annotated[ParametersListType, Field(description="The list of all parameters in the project")]
+class CatalogModelForTool(BaseModel):
+    parameters: Annotated[ParametersListType, Field(description="The list of all parameters in the project. It is possible that not all parameters are used by a dataset.")]
     datasets: Annotated[list[DatasetItemModel], Field(description="The list of accessible datasets")]
+
+class CatalogModel(CatalogModelForTool):
     dashboards: Annotated[list[DashboardItemModel], Field(description="The list of accessible dashboards")]
     connections: Annotated[list[ConnectionItemModel], Field(description="The list of connections in the project (only provided for admin users)")]
     models: Annotated[list[DataModelItem], Field(description="The list of data models in the project (only provided for admin users)")]
     lineage: Annotated[list[LineageRelation], Field(description="The lineage information between data assets (only provided for admin users)")]
+    configurables: Annotated[list[ConfigurableItemModel], Field(description="The list of configurables (only provided for admin users)")]


 ## Dataset Results Response Models
@@ -180,15 +192,29 @@ class DatasetResultModel(BaseModel):
     )]


+## Compiled Query Response Model
+
+class CompiledQueryModel(BaseModel):
+    language: Annotated[Literal["sql", "python"], Field(examples=["sql"], description="The language of the data model query: 'sql' or 'python'")]
+    definition: Annotated[str, Field("", description="The compiled SQL or Python definition of the data model.")]
+    placeholders: Annotated[dict[str, Any], Field({}, description="The placeholders for the data model.")]
+
+
 ## Project Metadata Response Models

 class ProjectVersionModel(BaseModel):
     major_version: Annotated[int, Field(examples=[1])]
-    data_catalog_path: Annotated[str, Field(examples=["/squirrels
+    data_catalog_path: Annotated[str, Field(examples=["/squirrels/v0/project/myproject/v1/data-catalog"])]

 class ProjectModel(BaseModel):
     name: Annotated[str, Field(examples=["myproject"])]
     version: Annotated[str, Field(examples=["v1"])]
     label: Annotated[str, Field(examples=["My Project"])]
     description: Annotated[str, Field(examples=["My project description"])]
+    elevated_access_level: Annotated[Literal["admin", "member", "guest"], Field(
+        examples=["admin"], description="The access level required to access elevated features (such as configurables and data lineage)"
+    )]
+    redoc_path: Annotated[str, Field(examples=["/squirrels/v0/project/myproject/v1/redoc"])]
+    swagger_path: Annotated[str, Field(examples=["/squirrels/v0/project/myproject/v1/docs"])]
+    mcp_server_path: Annotated[str, Field(examples=["/squirrels/v0/project/myproject/v1/mcp"])]
     squirrels_version: Annotated[str, Field(examples=["0.1.0"])]
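A hedged sketch of what the new CompiledQueryModel carries, using invented values; only the field names and defaults come from the model definition above.

# Illustration only -- the SQL and placeholder values are invented.
from squirrels._schemas.response_models import CompiledQueryModel

compiled = CompiledQueryModel(
    language="sql",
    definition="SELECT * FROM orders WHERE order_date >= :start_date",
    placeholders={"start_date": "2024-01-01"},
)
print(compiled.model_dump())  # {'language': 'sql', 'definition': ..., 'placeholders': {...}}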
squirrels/_seeds.py
CHANGED
@@ -37,7 +37,7 @@ class SeedsIO:
     @classmethod
     def load_files(cls, logger: u.Logger, base_path: str, env_vars: dict[str, str]) -> Seeds:
         start = time.time()
-        infer_schema_setting: bool = (env_vars.get(c.SQRL_SEEDS_INFER_SCHEMA, "true")
+        infer_schema_setting: bool = u.to_bool(env_vars.get(c.SQRL_SEEDS_INFER_SCHEMA, "true"))
         na_values_setting: list[str] = json.loads(env_vars.get(c.SQRL_SEEDS_NA_VALUES, "[]"))

         seeds_dict = {}
squirrels/_sources.py
CHANGED
@@ -1,19 +1,19 @@
 from typing import Any
 from pydantic import BaseModel, Field, model_validator
-import time, sqlglot
+import time, sqlglot, yaml

 from . import _utils as u, _constants as c, _model_configs as mc


 class UpdateHints(BaseModel):
     increasing_column: str | None = Field(default=None)
-    strictly_increasing: bool = Field(default=True, description="Delete the max value of the increasing column, ignored if
-    selective_overwrite_value: Any = Field(default=None)
+    strictly_increasing: bool = Field(default=True, description="Delete the max value of the increasing column, ignored if selective_overwrite_value is set")
+    selective_overwrite_value: Any = Field(default=None, description="Delete all values of the increasing column greater than or equal to this value")


 class Source(mc.ConnectionInterface, mc.ModelConfig):
     table: str | None = Field(default=None)
-
+    load_to_vdl: bool = Field(default=False, description="Whether to load the data to the 'virtual data lake' (VDL)")
     primary_key: list[str] = Field(default_factory=list)
     update_hints: UpdateHints = Field(default_factory=UpdateHints)

@@ -28,34 +28,28 @@ class Source(mc.ConnectionInterface, mc.ModelConfig):

     def get_cols_for_create_table_stmt(self) -> str:
         cols_clause = ", ".join([f"{col.name} {col.type}" for col in self.columns])
-
-        return f"{cols_clause}{primary_key_clause}"
-
-    def get_cols_for_insert_stmt(self) -> str:
-        return ", ".join([col.name for col in self.columns])
+        return cols_clause

     def get_max_incr_col_query(self, source_name: str) -> str:
         return f"SELECT max({self.update_hints.increasing_column}) FROM {source_name}"

-    def
-        select_cols = self.
+    def get_query_for_upsert(self, dialect: str, conn_name: str, table_name: str, max_value_of_increasing_col: Any | None, *, full_refresh: bool = True) -> str:
+        select_cols = ", ".join([col.name for col in self.columns])
         if full_refresh or max_value_of_increasing_col is None:
             return f"SELECT {select_cols} FROM db_{conn_name}.{table_name}"

         increasing_col = self.update_hints.increasing_column
         increasing_col_type = next(col.type for col in self.columns if col.name == increasing_col)
         where_cond = f"{increasing_col}::{increasing_col_type} > '{max_value_of_increasing_col}'::{increasing_col_type}"
-        pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"

-        if
-
-
+        # TODO: figure out if using pushdown query is worth it
+        # if dialect in ['postgres', 'mysql']:
+        #     pushdown_query = f"SELECT {select_cols} FROM {table_name} WHERE {where_cond}"
+        #     transpiled_query = sqlglot.transpile(pushdown_query, read='duckdb', write=dialect)[0].replace("'", "''")
+        #     return f"FROM {dialect}_query('db_{conn_name}', '{transpiled_query}')"

         return f"SELECT {select_cols} FROM db_{conn_name}.{table_name} WHERE {where_cond}"

-    def get_insert_replace_clause(self) -> str:
-        return "" if len(self.primary_key) == 0 else "OR REPLACE"
-

 class Sources(BaseModel):
     sources: dict[str, Source] = Field(default_factory=dict)
@@ -98,7 +92,17 @@ class SourcesIO:
         start = time.time()

         sources_path = u.Path(base_path, c.MODELS_FOLDER, c.SOURCES_FILE)
-
+        if sources_path.exists():
+            raw_content = u.read_file(sources_path)
+            rendered = u.render_string(raw_content, base_path=base_path, env_vars=env_vars)
+            sources_data = yaml.safe_load(rendered) or {}
+        else:
+            sources_data = {}
+
+        if not isinstance(sources_data, dict):
+            raise u.ConfigurationError(
+                f"Parsed content from YAML file must be a dictionary. Got: {sources_data}"
+            )

         sources = Sources(**sources_data).finalize_null_fields(env_vars)

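For a sense of what get_query_for_upsert now builds, a hedged sketch of the two query shapes; the connection name, table, and columns are invented, and the shapes mirror the return statements above.

# Illustration only -- names are assumptions; the query shapes come from the code above.
# Full refresh (or no previous max value of the increasing column):
full_refresh_sql = "SELECT order_id, amount, updated_at FROM db_myconn.orders"

# Incremental load, keeping only rows past the last loaded maximum of the increasing column:
incremental_sql = (
    "SELECT order_id, amount, updated_at FROM db_myconn.orders "
    "WHERE updated_at::TIMESTAMP > '2024-01-01'::TIMESTAMP"
)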
squirrels/_utils.py
CHANGED
@@ -1,7 +1,6 @@
-from typing import Sequence, Optional, Union, TypeVar, Callable,
+from typing import Sequence, Optional, Union, TypeVar, Callable, Iterable, Literal, Any
 from datetime import datetime
 from pathlib import Path
-from functools import lru_cache
 import os, time, logging, json, duckdb, polars as pl, yaml
 import jinja2 as j2, jinja2.nodes as j2_nodes
 import sqlglot, sqlglot.expressions, asyncio, hashlib, inspect, base64
@@ -34,12 +33,20 @@ sqrl_dtypes_to_polars_dtypes: dict[str, type[pl.DataType]] = {sqrl_type: k for k
 ## Other utility classes

 class Logger(logging.Logger):
-    def
+    def info(self, msg: str, *, data: dict[str, Any] = {}, **kwargs) -> None:
+        super().info(msg, extra={"data": data}, **kwargs)
+
+    def log_activity_time(self, activity: str, start_timestamp: float, *, additional_data: dict[str, Any] = {}) -> None:
         end_timestamp = time.time()
         time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
-        data = {
-
-
+        data = {
+            "activity": activity,
+            "start_timestamp": start_timestamp,
+            "end_timestamp": end_timestamp,
+            "time_taken_ms": time_taken,
+            **additional_data
+        }
+        self.info(f'Time taken for "{activity}": {time_taken}ms', data=data)


 class EnvironmentWithMacros(j2.Environment):
@@ -84,14 +91,6 @@ class EnvironmentWithMacros(j2.Environment):

 ## Utility functions/variables

-def log_activity_time(logger: logging.Logger, activity: str, start_timestamp: float, *, request_id: str | None = None) -> None:
-    end_timestamp = time.time()
-    time_taken = round((end_timestamp-start_timestamp) * 10**3, 3)
-    data = { "activity": activity, "start_timestamp": start_timestamp, "end_timestamp": end_timestamp, "time_taken_ms": time_taken }
-    info = { "request_id": request_id } if request_id else {}
-    logger.debug(f'Time taken for "{activity}": {time_taken}ms', extra={"data": data, "info": info})
-
-
 def render_string(raw_str: str, *, base_path: str = ".", **kwargs) -> str:
     """
     Given a template string, render it with the given keyword arguments
@@ -127,7 +126,7 @@ def read_file(filepath: FilePath) -> str:

 def normalize_name(name: str) -> str:
     """
-    Normalizes names to the convention of the squirrels manifest file.
+    Normalizes names to the convention of the squirrels manifest file (with underscores instead of dashes).

     Arguments:
         name: The name to normalize.
@@ -140,7 +139,7 @@ def normalize_name(name: str) -> str:

 def normalize_name_for_api(name: str) -> str:
     """
-    Normalizes names to the REST API convention.
+    Normalizes names to the REST API convention (with dashes instead of underscores).

     Arguments:
         name: The name to normalize.
@@ -195,8 +194,10 @@ def process_if_not_none(input_val: Optional[X], processor: Callable[[X], Y]) ->
         return processor(input_val)


-
-
+def _read_duckdb_init_sql(
+    *,
+    datalake_db_path: str | None = None,
+) -> str:
     """
     Reads and caches the duckdb init file content.
     Returns None if file doesn't exist or is empty.
@@ -211,35 +212,38 @@ def _read_duckdb_init_sql() -> tuple[str, Path | None]:
         if Path(c.DUCKDB_INIT_FILE).exists():
             with open(c.DUCKDB_INIT_FILE, 'r') as f:
                 init_contents.append(f.read())
-
-
-
-
-
-
-
-
-        return init_sql
+
+        if datalake_db_path:
+            attach_stmt = f"ATTACH '{datalake_db_path}' AS vdl (READ_ONLY);"
+            init_contents.append(attach_stmt)
+            use_stmt = f"USE vdl;"
+            init_contents.append(use_stmt)
+
+        init_sql = "\n\n".join(init_contents).strip()
+        return init_sql
     except Exception as e:
         raise ConfigurationError(f"Failed to read {c.DUCKDB_INIT_FILE}: {str(e)}") from e

-def create_duckdb_connection(
+def create_duckdb_connection(
+    db_path: str | Path = ":memory:",
+    *,
+    datalake_db_path: str | None = None
+) -> duckdb.DuckDBPyConnection:
     """
     Creates a DuckDB connection and initializes it with statements from duckdb init file

     Arguments:
         filepath: Path to the DuckDB database file. Defaults to in-memory database.
-
+        datalake_db_path: The path to the VDL catalog database if applicable. If exists, this is attached as 'vdl' (READ_ONLY). Default is None.

     Returns:
         A DuckDB connection (which must be closed after use)
     """
-    conn = duckdb.connect(
+    conn = duckdb.connect(db_path)

     try:
-        init_sql
-
-        conn.execute(init_sql)
+        init_sql = _read_duckdb_init_sql(datalake_db_path=datalake_db_path)
+        conn.execute(init_sql)
     except Exception as e:
         conn.close()
         raise ConfigurationError(f"Failed to execute {c.DUCKDB_INIT_FILE}: {str(e)}") from e
@@ -283,7 +287,13 @@ def load_yaml_config(filepath: FilePath) -> dict:
     """
     try:
         with open(filepath, 'r') as f:
-
+            content = yaml.safe_load(f)
+            content = content if content else {}
+
+            if not isinstance(content, dict):
+                raise yaml.YAMLError(f"Parsed content from YAML file must be a dictionary. Got: {content}")
+
+            return content
     except yaml.YAMLError as e:
         raise ConfigurationError(f"Failed to parse yaml file: {filepath}") from e

@@ -307,7 +317,7 @@ def run_duckdb_stmt(
         redacted_stmt = redacted_stmt.replace(value, "[REDACTED]")

     for_model_name = f" for model '{model_name}'" if model_name is not None else ""
-    logger.
+    logger.debug(f"Running SQL statement{for_model_name}:\n{redacted_stmt}")
     try:
         return duckdb_conn.execute(stmt, params)
     except duckdb.ParserException as e:
@@ -391,3 +401,41 @@ def validate_pkce_challenge(code_verifier: str, code_challenge: str) -> bool:
     # Generate expected challenge
     expected_challenge = generate_pkce_challenge(code_verifier)
     return expected_challenge == code_challenge
+
+
+def get_scheme(hostname: str | None) -> str:
+    """Get the scheme of the request"""
+    return "http" if hostname in ("localhost", "127.0.0.1") else "https"
+
+
+def to_title_case(input_str: str) -> str:
+    """Convert a string to title case"""
+    spaced_str = input_str.replace('_', ' ').replace('-', ' ')
+    return spaced_str.title()
+
+
+def to_bool(val: object) -> bool:
+    """Convert common truthy/falsey representations to a boolean.
+
+    Accepted truthy values (case-insensitive): "1", "true", "t", "yes", "y", "on".
+    All other values are considered falsey. None is falsey.
+    """
+    if isinstance(val, bool):
+        return val
+    if val is None:
+        return False
+    s = str(val).strip().lower()
+    return s in ("1", "true", "t", "yes", "y", "on")
+
+
+ACCESS_LEVEL = Literal["admin", "member", "guest"]
+
+def get_access_level_rank(access_level: ACCESS_LEVEL) -> int:
+    """Get the rank of an access level. Lower ranks have more privileges."""
+    return { "admin": 1, "member": 2, "guest": 3 }.get(access_level.lower(), 1)
+
+def user_has_elevated_privileges(user_access_level: ACCESS_LEVEL, required_access_level: ACCESS_LEVEL) -> bool:
+    """Check if a user has privilege to access a resource"""
+    user_access_level_rank = get_access_level_rank(user_access_level)
+    required_access_level_rank = get_access_level_rank(required_access_level)
+    return user_access_level_rank <= required_access_level_rank
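A quick hedged demonstration of the new helpers above, importing the private _utils module directly purely for illustration:

# Illustration only -- application code would not normally import the private module.
from squirrels import _utils as u

assert u.to_bool("Yes") is True     # "yes" is in the truthy set
assert u.to_bool("0") is False      # anything outside the truthy set is falsey
assert u.to_bool(None) is False

# Lower rank means more privilege: admin (1) < member (2) < guest (3).
assert u.user_has_elevated_privileges("admin", "member") is True    # admin may use member-level features
assert u.user_has_elevated_privileges("guest", "member") is False   # guest may not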
squirrels/_version.py
CHANGED
squirrels/arguments.py
CHANGED
@@ -1,2 +1,7 @@
 from ._arguments.init_time_args import ConnectionsArgs, AuthProviderArgs, ParametersArgs, BuildModelArgs
 from ._arguments.run_time_args import ContextArgs, ModelArgs, DashboardArgs
+
+__all__ = [
+    "ConnectionsArgs", "AuthProviderArgs", "ParametersArgs", "BuildModelArgs",
+    "ContextArgs", "ModelArgs", "DashboardArgs"
+]
squirrels/auth.py
CHANGED
squirrels/connections.py
CHANGED
squirrels/dashboards.py
CHANGED
squirrels/data_sources.py
CHANGED
@@ -1,4 +1,5 @@
 from ._data_sources import (
+    SourceEnum,
     SelectDataSource,
     DateDataSource,
     DateRangeDataSource,
@@ -6,3 +7,8 @@ from ._data_sources import (
     NumberRangeDataSource,
     TextDataSource
 )
+
+__all__ = [
+    "SourceEnum", "SelectDataSource", "DateDataSource", "DateRangeDataSource",
+    "NumberDataSource", "NumberRangeDataSource", "TextDataSource"
+]
squirrels/parameter_options.py
CHANGED
@@ -6,3 +6,8 @@ from ._parameter_options import (
     NumberRangeParameterOption,
     TextParameterOption
 )
+
+__all__ = [
+    "SelectParameterOption", "DateParameterOption", "DateRangeParameterOption",
+    "NumberParameterOption", "NumberRangeParameterOption", "TextParameterOption"
+]
squirrels/parameters.py
CHANGED
squirrels/types.py
CHANGED
@@ -8,4 +8,9 @@ from ._dataset_types import DatasetMetadata, DatasetResult

 from ._dashboards import Dashboard

-from ._parameter_configs import ParameterConfigBase
+from ._parameter_configs import ParameterConfigBase
+
+__all__ = [
+    "DataSource", "ParameterOption", "Parameter", "TextValue",
+    "DatasetMetadata", "DatasetResult", "Dashboard", "ParameterConfigBase"
+]