squirrels 0.5.0b4__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- squirrels/__init__.py +2 -0
- squirrels/_api_routes/auth.py +83 -74
- squirrels/_api_routes/base.py +58 -41
- squirrels/_api_routes/dashboards.py +37 -21
- squirrels/_api_routes/data_management.py +72 -27
- squirrels/_api_routes/datasets.py +107 -84
- squirrels/_api_routes/oauth2.py +11 -13
- squirrels/_api_routes/project.py +71 -33
- squirrels/_api_server.py +130 -63
- squirrels/_arguments/run_time_args.py +9 -9
- squirrels/_auth.py +117 -162
- squirrels/_command_line.py +68 -32
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +11 -2
- squirrels/_constants.py +22 -8
- squirrels/_data_sources.py +38 -32
- squirrels/_dataset_types.py +2 -4
- squirrels/_initializer.py +1 -1
- squirrels/_logging.py +117 -0
- squirrels/_manifest.py +125 -58
- squirrels/_model_builder.py +10 -54
- squirrels/_models.py +224 -108
- squirrels/_package_data/base_project/.env +15 -4
- squirrels/_package_data/base_project/.env.example +14 -3
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +2 -2
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +4 -4
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +22 -15
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +1 -1
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +5 -1
- squirrels/_package_data/base_project/pyconfigs/context.py +23 -12
- squirrels/_package_data/base_project/pyconfigs/parameters.py +68 -33
- squirrels/_package_data/base_project/pyconfigs/user.py +11 -18
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +18 -28
- squirrels/_package_data/templates/squirrels_studio.html +20 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +8 -10
- squirrels/_project.py +351 -234
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/auth_models.py +32 -9
- squirrels/_schemas/query_param_models.py +9 -1
- squirrels/_schemas/response_models.py +36 -10
- squirrels/_seeds.py +1 -1
- squirrels/_sources.py +23 -19
- squirrels/_utils.py +83 -35
- squirrels/_version.py +1 -1
- squirrels/arguments.py +5 -0
- squirrels/auth.py +4 -1
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +6 -1
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/METADATA +28 -13
- squirrels-0.5.1.dist-info/RECORD +98 -0
- squirrels-0.5.0b4.dist-info/RECORD +0 -94
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/licenses/LICENSE +0 -0
squirrels/_api_routes/project.py
CHANGED
|
@@ -8,16 +8,17 @@ from fastapi.security import HTTPBearer
|
|
|
8
8
|
from mcp.server.fastmcp import FastMCP, Context
|
|
9
9
|
from dataclasses import asdict
|
|
10
10
|
from cachetools import TTLCache
|
|
11
|
+
from textwrap import dedent
|
|
11
12
|
import time
|
|
12
13
|
|
|
13
14
|
from .. import _utils as u, _constants as c
|
|
14
15
|
from .._schemas import response_models as rm
|
|
15
16
|
from .._parameter_sets import ParameterSet
|
|
16
17
|
from .._exceptions import ConfigurationError, InvalidInputError
|
|
17
|
-
from .._manifest import PermissionScope
|
|
18
|
+
from .._manifest import PermissionScope, AuthenticationEnforcement
|
|
18
19
|
from .._version import __version__
|
|
19
20
|
from .._schemas.query_param_models import get_query_models_for_parameters
|
|
20
|
-
from ..
|
|
21
|
+
from .._schemas.auth_models import AbstractUser
|
|
21
22
|
from .base import RouteBase
|
|
22
23
|
|
|
23
24
|
|
|
@@ -34,13 +35,13 @@ class ProjectRoutes(RouteBase):
|
|
|
34
35
|
|
|
35
36
|
async def _get_parameters_helper(
|
|
36
37
|
self, parameters_tuple: tuple[str, ...] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
|
|
37
|
-
user:
|
|
38
|
+
user: AbstractUser, selections: tuple[tuple[str, Any], ...]
|
|
38
39
|
) -> ParameterSet:
|
|
39
40
|
"""Helper for getting parameters"""
|
|
40
41
|
selections_dict = dict(selections)
|
|
41
42
|
if "x_parent_param" not in selections_dict:
|
|
42
43
|
if len(selections_dict) > 1:
|
|
43
|
-
raise InvalidInputError(400, "
|
|
44
|
+
raise InvalidInputError(400, "invalid_input_for_cascading_parameters", f"The parameters endpoint takes at most 1 widget parameter selection (unless x_parent_param is provided). Got {selections_dict}")
|
|
44
45
|
elif len(selections_dict) == 1:
|
|
45
46
|
parent_param = next(iter(selections_dict))
|
|
46
47
|
selections_dict["x_parent_param"] = parent_param
|
|
@@ -58,7 +59,7 @@ class ProjectRoutes(RouteBase):
|
|
|
58
59
|
|
|
59
60
|
async def _get_parameters_cachable(
|
|
60
61
|
self, parameters_tuple: tuple[str, ...] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
|
|
61
|
-
user:
|
|
62
|
+
user: AbstractUser, selections: tuple[tuple[str, Any], ...]
|
|
62
63
|
) -> ParameterSet:
|
|
63
64
|
"""Cachable version of parameters helper"""
|
|
64
65
|
return await self.do_cachable_action(
|
|
@@ -66,9 +67,13 @@ class ProjectRoutes(RouteBase):
|
|
|
66
67
|
)
|
|
67
68
|
|
|
68
69
|
def setup_routes(
|
|
69
|
-
self, app: FastAPI, mcp: FastMCP, project_metadata_path: str, project_name: str, project_version: str, param_fields: dict
|
|
70
|
+
self, app: FastAPI, mcp: FastMCP, project_metadata_path: str, project_name: str, project_version: str, project_label: str, param_fields: dict
|
|
70
71
|
):
|
|
71
72
|
"""Setup project metadata routes"""
|
|
73
|
+
|
|
74
|
+
elevated_access_level = self.project._elevated_access_level
|
|
75
|
+
if elevated_access_level != "admin":
|
|
76
|
+
self.logger.warning(f"{c.SQRL_PERMISSIONS_ELEVATED_ACCESS_LEVEL} has been set to a non-admin access level. For security reasons, DO NOT expose the APIs for this app publicly!")
|
|
72
77
|
|
|
73
78
|
# Project metadata endpoint
|
|
74
79
|
@app.get(project_metadata_path, tags=["Project Metadata"], response_class=JSONResponse)
|
|
@@ -78,37 +83,54 @@ class ProjectRoutes(RouteBase):
|
|
|
78
83
|
version=project_version,
|
|
79
84
|
label=self.manifest_cfg.project_variables.label,
|
|
80
85
|
description=self.manifest_cfg.project_variables.description,
|
|
86
|
+
elevated_access_level=elevated_access_level,
|
|
87
|
+
redoc_path=project_metadata_path + "/redoc",
|
|
88
|
+
swagger_path=project_metadata_path + "/docs",
|
|
89
|
+
mcp_server_path=project_metadata_path + "/mcp",
|
|
81
90
|
squirrels_version=__version__
|
|
82
91
|
)
|
|
83
92
|
|
|
84
93
|
# Data catalog endpoint
|
|
85
94
|
data_catalog_path = project_metadata_path + '/data-catalog'
|
|
86
95
|
|
|
87
|
-
async def get_data_catalog0(user:
|
|
96
|
+
async def get_data_catalog0(user: AbstractUser) -> rm.CatalogModel:
|
|
88
97
|
parameters = self.param_cfg_set.apply_selections(None, {}, user)
|
|
89
98
|
parameters_model = parameters.to_api_response_model0()
|
|
90
99
|
full_parameters_list = [p.name for p in parameters_model.parameters]
|
|
100
|
+
user_has_elevated_privileges = u.user_has_elevated_privileges(user.access_level, elevated_access_level)
|
|
91
101
|
|
|
92
102
|
dataset_items: list[rm.DatasetItemModel] = []
|
|
93
103
|
for name, config in self.manifest_cfg.datasets.items():
|
|
94
104
|
if self.authenticator.can_user_access_scope(user, config.scope):
|
|
95
|
-
|
|
105
|
+
name_for_api = u.normalize_name_for_api(name)
|
|
96
106
|
metadata = self.project.dataset_metadata(name).to_json()
|
|
97
107
|
parameters = config.parameters if config.parameters is not None else full_parameters_list
|
|
108
|
+
|
|
109
|
+
# Build dataset-specific configurables list
|
|
110
|
+
if user_has_elevated_privileges:
|
|
111
|
+
dataset_configurables_defaults = self.manifest_cfg.get_default_configurables(name)
|
|
112
|
+
dataset_configurables_list = [
|
|
113
|
+
rm.ConfigurableDefaultModel(name=name, default=default)
|
|
114
|
+
for name, default in dataset_configurables_defaults.items()
|
|
115
|
+
]
|
|
116
|
+
else:
|
|
117
|
+
dataset_configurables_list = []
|
|
118
|
+
|
|
98
119
|
dataset_items.append(rm.DatasetItemModel(
|
|
99
|
-
name=
|
|
120
|
+
name=name, label=config.label,
|
|
100
121
|
description=config.description,
|
|
101
122
|
schema=metadata["schema"], # type: ignore
|
|
123
|
+
configurables=dataset_configurables_list,
|
|
102
124
|
parameters=parameters,
|
|
103
|
-
parameters_path=f"{project_metadata_path}/dataset/{
|
|
104
|
-
result_path=f"{project_metadata_path}/dataset/{
|
|
125
|
+
parameters_path=f"{project_metadata_path}/dataset/{name_for_api}/parameters",
|
|
126
|
+
result_path=f"{project_metadata_path}/dataset/{name_for_api}"
|
|
105
127
|
))
|
|
106
128
|
|
|
107
129
|
dashboard_items: list[rm.DashboardItemModel] = []
|
|
108
130
|
for name, dashboard in self.project._dashboards.items():
|
|
109
131
|
config = dashboard.config
|
|
110
132
|
if self.authenticator.can_user_access_scope(user, config.scope):
|
|
111
|
-
|
|
133
|
+
name_for_api = u.normalize_name_for_api(name)
|
|
112
134
|
|
|
113
135
|
try:
|
|
114
136
|
dashboard_format = self.project._dashboards[name].get_dashboard_format()
|
|
@@ -121,19 +143,24 @@ class ProjectRoutes(RouteBase):
|
|
|
121
143
|
description=config.description,
|
|
122
144
|
result_format=dashboard_format,
|
|
123
145
|
parameters=parameters,
|
|
124
|
-
parameters_path=f"{project_metadata_path}/dashboard/{
|
|
125
|
-
result_path=f"{project_metadata_path}/dashboard/{
|
|
146
|
+
parameters_path=f"{project_metadata_path}/dashboard/{name_for_api}/parameters",
|
|
147
|
+
result_path=f"{project_metadata_path}/dashboard/{name_for_api}"
|
|
126
148
|
))
|
|
127
149
|
|
|
128
|
-
if
|
|
129
|
-
compiled_dag = await self.project._get_compiled_dag(user
|
|
150
|
+
if user_has_elevated_privileges:
|
|
151
|
+
compiled_dag = await self.project._get_compiled_dag(user)
|
|
130
152
|
connections_items = self.project._get_all_connections()
|
|
131
153
|
data_models = self.project._get_all_data_models(compiled_dag)
|
|
132
154
|
lineage_items = self.project._get_all_data_lineage(compiled_dag)
|
|
155
|
+
configurables_list = [
|
|
156
|
+
rm.ConfigurableItemModel(name=name, label=cfg.label, default=cfg.default, description=cfg.description)
|
|
157
|
+
for name, cfg in self.manifest_cfg.configurables.items()
|
|
158
|
+
]
|
|
133
159
|
else:
|
|
134
160
|
connections_items = []
|
|
135
161
|
data_models = []
|
|
136
162
|
lineage_items = []
|
|
163
|
+
configurables_list = []
|
|
137
164
|
|
|
138
165
|
return rm.CatalogModel(
|
|
139
166
|
parameters=parameters_model.parameters,
|
|
@@ -142,29 +169,40 @@ class ProjectRoutes(RouteBase):
|
|
|
142
169
|
connections=connections_items,
|
|
143
170
|
models=data_models,
|
|
144
171
|
lineage=lineage_items,
|
|
172
|
+
configurables=configurables_list,
|
|
145
173
|
)
|
|
146
174
|
|
|
147
175
|
@app.get(data_catalog_path, tags=["Project Metadata"], summary="Get catalog of datasets and dashboards available for user")
|
|
148
|
-
async def get_data_catalog(request: Request, user:
|
|
176
|
+
async def get_data_catalog(request: Request, user: AbstractUser = Depends(self.get_current_user)) -> rm.CatalogModel:
|
|
149
177
|
"""
|
|
150
178
|
Get catalog of datasets and dashboards available for the authenticated user.
|
|
151
179
|
|
|
152
180
|
For admin users, this endpoint will also return detailed information about all models and their lineage in the project.
|
|
153
181
|
"""
|
|
154
|
-
|
|
182
|
+
start = time.time()
|
|
183
|
+
|
|
184
|
+
# If authentication is required, require user to be authenticated to access catalog
|
|
185
|
+
if self.manifest_cfg.authentication.enforcement == AuthenticationEnforcement.REQUIRED and user.access_level == "guest":
|
|
186
|
+
raise InvalidInputError(401, "user_required", "Authentication is required to access the data catalog")
|
|
187
|
+
data_catalog = await get_data_catalog0(user)
|
|
188
|
+
|
|
189
|
+
self.logger.log_activity_time("GET REQUEST for DATA CATALOG", start)
|
|
190
|
+
return data_catalog
|
|
155
191
|
|
|
156
192
|
@mcp.tool(
|
|
157
|
-
name=f"
|
|
158
|
-
|
|
193
|
+
name=f"get_data_catalog_from_{project_name}",
|
|
194
|
+
title=f"Get Data Catalog (Project: {project_label})",
|
|
195
|
+
description=dedent(f"""
|
|
196
|
+
Use this tool to get the details of all datasets and parameters you can access in the Squirrels project '{project_name}'.
|
|
197
|
+
|
|
198
|
+
Unless the data catalog for this project has already been provided, use this tool at the start of each conversation.
|
|
199
|
+
""").strip()
|
|
159
200
|
)
|
|
160
|
-
async def get_data_catalog_tool(ctx: Context):
|
|
161
|
-
|
|
201
|
+
async def get_data_catalog_tool(ctx: Context) -> rm.CatalogModelForTool:
|
|
202
|
+
headers = self.get_headers_from_tool_ctx(ctx)
|
|
203
|
+
user = self.get_user_from_tool_headers(headers)
|
|
162
204
|
data_catalog = await get_data_catalog0(user)
|
|
163
|
-
|
|
164
|
-
"parameters": data_catalog.parameters,
|
|
165
|
-
"datasets": data_catalog.datasets,
|
|
166
|
-
}
|
|
167
|
-
return restricted_data_catalog
|
|
205
|
+
return rm.CatalogModelForTool(parameters=data_catalog.parameters, datasets=data_catalog.datasets)
|
|
168
206
|
|
|
169
207
|
# Project-level parameters endpoints
|
|
170
208
|
project_level_parameters_path = project_metadata_path + '/parameters'
|
|
@@ -177,9 +215,9 @@ class ProjectRoutes(RouteBase):
|
|
|
177
215
|
|
|
178
216
|
async def get_parameters_definition(
|
|
179
217
|
parameters_list: list[str] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
|
|
180
|
-
user, all_request_params: dict, params: dict
|
|
218
|
+
user: AbstractUser, all_request_params: dict, params: dict, *, headers: dict[str, str]
|
|
181
219
|
) -> rm.ParametersModel:
|
|
182
|
-
self._validate_request_params(all_request_params, params)
|
|
220
|
+
self._validate_request_params(all_request_params, params, headers)
|
|
183
221
|
|
|
184
222
|
get_parameters_function = self._get_parameters_helper if self.no_cache else self._get_parameters_cachable
|
|
185
223
|
selections = self.get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
|
|
@@ -193,9 +231,9 @@ class ProjectRoutes(RouteBase):
|
|
|
193
231
|
) -> rm.ParametersModel:
|
|
194
232
|
start = time.time()
|
|
195
233
|
result = await get_parameters_definition(
|
|
196
|
-
None, "project", "", PermissionScope.PUBLIC, user, dict(request.query_params), asdict(params)
|
|
234
|
+
None, "project", "", PermissionScope.PUBLIC, user, dict(request.query_params), asdict(params), headers=dict(request.headers)
|
|
197
235
|
)
|
|
198
|
-
self.log_activity_time("GET REQUEST for PROJECT PARAMETERS", start
|
|
236
|
+
self.logger.log_activity_time("GET REQUEST for PROJECT PARAMETERS", start)
|
|
199
237
|
return result
|
|
200
238
|
|
|
201
239
|
@app.post(project_level_parameters_path, tags=["Project Metadata"], description=parameters_description)
|
|
@@ -205,9 +243,9 @@ class ProjectRoutes(RouteBase):
|
|
|
205
243
|
start = time.time()
|
|
206
244
|
payload: dict = await request.json()
|
|
207
245
|
result = await get_parameters_definition(
|
|
208
|
-
None, "project", "", PermissionScope.PUBLIC, user, payload, params.model_dump()
|
|
246
|
+
None, "project", "", PermissionScope.PUBLIC, user, payload, params.model_dump(), headers=dict(request.headers)
|
|
209
247
|
)
|
|
210
|
-
self.log_activity_time("POST REQUEST for PROJECT PARAMETERS", start
|
|
248
|
+
self.logger.log_activity_time("POST REQUEST for PROJECT PARAMETERS", start)
|
|
211
249
|
return result
|
|
212
250
|
|
|
213
251
|
return get_parameters_definition
|
squirrels/_api_server.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
from fastapi import FastAPI, Request, status
|
|
2
|
-
from fastapi.responses import JSONResponse, RedirectResponse
|
|
2
|
+
from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse
|
|
3
3
|
from fastapi.security import HTTPBearer
|
|
4
|
+
from fastapi.templating import Jinja2Templates
|
|
5
|
+
from fastapi.staticfiles import StaticFiles
|
|
4
6
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
5
7
|
from starlette.responses import Response as StarletteResponse
|
|
6
8
|
from contextlib import asynccontextmanager
|
|
@@ -8,12 +10,13 @@ from argparse import Namespace
|
|
|
8
10
|
from pathlib import Path
|
|
9
11
|
from starlette.middleware.sessions import SessionMiddleware
|
|
10
12
|
from mcp.server.fastmcp import FastMCP
|
|
11
|
-
import io, time, mimetypes, traceback, uuid, asyncio,
|
|
13
|
+
import io, time, mimetypes, traceback, uuid, asyncio, contextlib
|
|
12
14
|
|
|
13
|
-
from . import _constants as c, _utils as u
|
|
15
|
+
from . import _constants as c, _utils as u, _parameter_sets as ps
|
|
14
16
|
from ._exceptions import InvalidInputError, ConfigurationError, FileExecutionError
|
|
15
17
|
from ._version import __version__, sq_major_version
|
|
16
18
|
from ._project import SquirrelsProject
|
|
19
|
+
from ._request_context import set_request_id
|
|
17
20
|
|
|
18
21
|
# Import route modules
|
|
19
22
|
from ._api_routes.auth import AuthRoutes
|
|
@@ -21,7 +24,9 @@ from ._api_routes.project import ProjectRoutes
|
|
|
21
24
|
from ._api_routes.datasets import DatasetRoutes
|
|
22
25
|
from ._api_routes.dashboards import DashboardRoutes
|
|
23
26
|
from ._api_routes.data_management import DataManagementRoutes
|
|
24
|
-
|
|
27
|
+
|
|
28
|
+
# # Disabled for now, a 'bring your own OAuth2 server' approach will be provided in the future
|
|
29
|
+
# from ._api_routes.oauth2 import OAuth2Routes
|
|
25
30
|
|
|
26
31
|
mimetypes.add_type('application/javascript', '.js')
|
|
27
32
|
|
|
@@ -32,10 +37,13 @@ class SmartCORSMiddleware(BaseHTTPMiddleware):
|
|
|
32
37
|
while still allowing all other origins without credentials.
|
|
33
38
|
"""
|
|
34
39
|
|
|
35
|
-
def __init__(self, app, allowed_credential_origins: list[str]
|
|
40
|
+
def __init__(self, app, allowed_credential_origins: list[str], configurables_as_headers: list[str]):
|
|
36
41
|
super().__init__(app)
|
|
37
|
-
|
|
38
|
-
|
|
42
|
+
|
|
43
|
+
allowed_predefined_headers = ["Authorization", "Content-Type", "x-api-key", "x-orientation", "x-verify-params"]
|
|
44
|
+
|
|
45
|
+
self.allowed_credential_origins = allowed_credential_origins
|
|
46
|
+
self.allowed_request_headers = ",".join(allowed_predefined_headers + configurables_as_headers)
|
|
39
47
|
|
|
40
48
|
async def dispatch(self, request: Request, call_next):
|
|
41
49
|
origin = request.headers.get("origin")
|
|
@@ -44,7 +52,7 @@ class SmartCORSMiddleware(BaseHTTPMiddleware):
|
|
|
44
52
|
if request.method == "OPTIONS":
|
|
45
53
|
response = StarletteResponse(status_code=200)
|
|
46
54
|
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
|
|
47
|
-
response.headers["Access-Control-Allow-Headers"] =
|
|
55
|
+
response.headers["Access-Control-Allow-Headers"] = self.allowed_request_headers
|
|
48
56
|
|
|
49
57
|
else:
|
|
50
58
|
# Call the next middleware/route
|
|
@@ -54,7 +62,7 @@ class SmartCORSMiddleware(BaseHTTPMiddleware):
|
|
|
54
62
|
response.headers["Access-Control-Expose-Headers"] = "Applied-Username"
|
|
55
63
|
|
|
56
64
|
if origin:
|
|
57
|
-
scheme =
|
|
65
|
+
scheme = u.get_scheme(request.url.hostname)
|
|
58
66
|
request_origin = f"{scheme}://{request.url.netloc}"
|
|
59
67
|
# Check if this origin is in the whitelist or if origin matches the host origin
|
|
60
68
|
if origin == request_origin or origin in self.allowed_credential_origins:
|
|
@@ -93,11 +101,14 @@ class ApiServer:
|
|
|
93
101
|
self.context_func = project._context_func
|
|
94
102
|
self.dashboards = project._dashboards
|
|
95
103
|
|
|
96
|
-
self.mcp = FastMCP(
|
|
104
|
+
self.mcp = FastMCP(
|
|
105
|
+
name="Squirrels",
|
|
106
|
+
stateless_http=True
|
|
107
|
+
)
|
|
97
108
|
|
|
98
109
|
# Initialize route modules
|
|
99
110
|
get_bearer_token = HTTPBearer(auto_error=False)
|
|
100
|
-
self.oauth2_routes = OAuth2Routes(get_bearer_token, project, no_cache)
|
|
111
|
+
# self.oauth2_routes = OAuth2Routes(get_bearer_token, project, no_cache)
|
|
101
112
|
self.auth_routes = AuthRoutes(get_bearer_token, project, no_cache)
|
|
102
113
|
self.project_routes = ProjectRoutes(get_bearer_token, project, no_cache)
|
|
103
114
|
self.dataset_routes = DatasetRoutes(get_bearer_token, project, no_cache)
|
|
@@ -105,37 +116,63 @@ class ApiServer:
|
|
|
105
116
|
self.data_management_routes = DataManagementRoutes(get_bearer_token, project, no_cache)
|
|
106
117
|
|
|
107
118
|
|
|
108
|
-
async def
|
|
109
|
-
"""
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
119
|
+
async def _refresh_datasource_params(self) -> None:
|
|
120
|
+
"""
|
|
121
|
+
Background task to periodically refresh datasource parameter options.
|
|
122
|
+
Runs every N minutes as configured by SQRL_PARAMETERS__DATASOURCE_REFRESH_MINUTES (default: 60).
|
|
123
|
+
"""
|
|
124
|
+
refresh_minutes_str = self.env_vars.get(c.SQRL_PARAMETERS_DATASOURCE_REFRESH_MINUTES, "60")
|
|
125
|
+
try:
|
|
126
|
+
refresh_minutes = int(refresh_minutes_str)
|
|
127
|
+
if refresh_minutes <= 0:
|
|
128
|
+
self.logger.info(f"The value of {c.SQRL_PARAMETERS_DATASOURCE_REFRESH_MINUTES} is: {refresh_minutes_str} minutes")
|
|
129
|
+
self.logger.info(f"Datasource parameter refresh is disabled since the refresh interval is not positive.")
|
|
130
|
+
return
|
|
131
|
+
except ValueError:
|
|
132
|
+
self.logger.warning(f"Invalid value for {c.SQRL_PARAMETERS_DATASOURCE_REFRESH_MINUTES}: {refresh_minutes_str}. Must be an integer. Disabling datasource parameter refresh.")
|
|
133
|
+
return
|
|
134
|
+
|
|
135
|
+
refresh_seconds = refresh_minutes * 60
|
|
136
|
+
self.logger.info(f"Starting datasource parameter refresh background task (every {refresh_minutes} minutes)")
|
|
137
|
+
|
|
114
138
|
while True:
|
|
115
139
|
try:
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
140
|
+
await asyncio.sleep(refresh_seconds)
|
|
141
|
+
self.logger.info("Refreshing datasource parameter options...")
|
|
142
|
+
|
|
143
|
+
# Fetch fresh dataframes from datasources in a thread pool to avoid blocking
|
|
144
|
+
loop = asyncio.get_running_loop()
|
|
145
|
+
default_conn_name = self.manifest_cfg.env_vars.get(c.SQRL_CONNECTIONS_DEFAULT_NAME_USED, "default")
|
|
146
|
+
df_dict = await loop.run_in_executor(
|
|
147
|
+
None,
|
|
148
|
+
ps.ParameterConfigsSetIO._get_df_dict_from_data_sources,
|
|
149
|
+
self.param_cfg_set,
|
|
150
|
+
default_conn_name,
|
|
151
|
+
self.seeds,
|
|
152
|
+
self.conn_set,
|
|
153
|
+
self.project._datalake_db_path
|
|
154
|
+
)
|
|
123
155
|
|
|
156
|
+
# Re-convert datasource parameters with fresh data
|
|
157
|
+
self.param_cfg_set._post_process_params(df_dict)
|
|
158
|
+
|
|
159
|
+
self.logger.info("Successfully refreshed datasource parameter options")
|
|
160
|
+
except asyncio.CancelledError:
|
|
161
|
+
self.logger.info("Datasource parameter refresh task cancelled")
|
|
162
|
+
break
|
|
124
163
|
except Exception as e:
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
await asyncio.sleep(1) # Check every second
|
|
164
|
+
self.logger.error(f"Error refreshing datasource parameter options: {e}", exc_info=True)
|
|
165
|
+
# Continue the loop even if there's an error
|
|
129
166
|
|
|
130
167
|
@asynccontextmanager
|
|
131
168
|
async def _run_background_tasks(self, app: FastAPI):
|
|
132
|
-
|
|
169
|
+
refresh_datasource_task = asyncio.create_task(self._refresh_datasource_params())
|
|
133
170
|
|
|
134
171
|
async with contextlib.AsyncExitStack() as stack:
|
|
135
172
|
await stack.enter_async_context(self.mcp.session_manager.run())
|
|
136
173
|
yield
|
|
137
174
|
|
|
138
|
-
|
|
175
|
+
refresh_datasource_task.cancel()
|
|
139
176
|
|
|
140
177
|
|
|
141
178
|
def _get_tags_metadata(self) -> list[dict]:
|
|
@@ -171,10 +208,10 @@ class ApiServer:
|
|
|
171
208
|
"name": "User Management",
|
|
172
209
|
"description": "Manage users and their attributes",
|
|
173
210
|
},
|
|
174
|
-
{
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
},
|
|
211
|
+
# {
|
|
212
|
+
# "name": "OAuth2",
|
|
213
|
+
# "description": "Authorize and get token using the OAuth2 protocol",
|
|
214
|
+
# },
|
|
178
215
|
])
|
|
179
216
|
return tags_metadata
|
|
180
217
|
|
|
@@ -188,17 +225,19 @@ class ApiServer:
|
|
|
188
225
|
"""
|
|
189
226
|
start = time.time()
|
|
190
227
|
|
|
191
|
-
squirrels_version_path = f'/api/squirrels
|
|
192
|
-
project_name =
|
|
228
|
+
squirrels_version_path = f'/api/squirrels/v{sq_major_version}'
|
|
229
|
+
project_name = self.manifest_cfg.project_variables.name
|
|
230
|
+
project_name_for_api = u.normalize_name_for_api(project_name)
|
|
231
|
+
project_label = self.manifest_cfg.project_variables.label
|
|
193
232
|
project_version = f"v{self.manifest_cfg.project_variables.major_version}"
|
|
194
|
-
project_metadata_path = squirrels_version_path + f"/project/{
|
|
233
|
+
project_metadata_path = squirrels_version_path + f"/project/{project_name_for_api}/{project_version}"
|
|
195
234
|
|
|
196
235
|
param_fields = self.param_cfg_set.get_all_api_field_info()
|
|
197
236
|
|
|
198
237
|
tags_metadata = self._get_tags_metadata()
|
|
199
238
|
|
|
200
239
|
app = FastAPI(
|
|
201
|
-
title=f"Squirrels APIs for '{
|
|
240
|
+
title=f"Squirrels APIs for '{project_label}'", openapi_tags=tags_metadata,
|
|
202
241
|
description="For specifying parameter selections to dataset APIs, you can choose between using query parameters with the GET method or using request body with the POST method",
|
|
203
242
|
lifespan=self._run_background_tasks,
|
|
204
243
|
openapi_url=project_metadata_path+"/openapi.json",
|
|
@@ -209,29 +248,30 @@ class ApiServer:
|
|
|
209
248
|
app.add_middleware(SessionMiddleware, secret_key=self.env_vars.get(c.SQRL_SECRET_KEY, ""), max_age=None, same_site="none", https_only=True)
|
|
210
249
|
|
|
211
250
|
async def _log_request_run(request: Request) -> None:
|
|
212
|
-
headers = dict(request.scope["headers"])
|
|
213
|
-
request_id = uuid.uuid4().hex
|
|
214
|
-
headers[b"x-request-id"] = request_id.encode()
|
|
215
|
-
request.scope["headers"] = list(headers.items())
|
|
216
|
-
|
|
217
251
|
try:
|
|
218
252
|
body = await request.json()
|
|
219
253
|
except Exception:
|
|
220
|
-
body = None
|
|
254
|
+
body = None # Non-JSON payloads may contain sensitive information, so we don't log them
|
|
255
|
+
|
|
256
|
+
partial_headers: dict[str, str] = {}
|
|
257
|
+
for header in request.headers.keys():
|
|
258
|
+
if header.startswith("x-") and header not in ["x-api-key"]:
|
|
259
|
+
partial_headers[header] = request.headers[header]
|
|
221
260
|
|
|
222
|
-
headers_dict = dict(request.headers)
|
|
223
261
|
path, params = request.url.path, dict(request.query_params)
|
|
224
262
|
path_with_params = f"{path}?{request.query_params}" if len(params) > 0 else path
|
|
225
|
-
data = {"request_method": request.method, "request_path": path, "request_params": params, "
|
|
226
|
-
info
|
|
227
|
-
self.logger.info(f'Running request: {request.method} {path_with_params}', extra={"data": data, "info": info})
|
|
263
|
+
data = {"request_method": request.method, "request_path": path, "request_params": params, "request_body": body, "partial_headers": partial_headers}
|
|
264
|
+
self.logger.info(f'Running request: {request.method} {path_with_params}', data=data)
|
|
228
265
|
|
|
229
266
|
@app.middleware("http")
|
|
230
267
|
async def catch_exceptions_middleware(request: Request, call_next):
|
|
268
|
+
# Generate and set request ID for this request
|
|
269
|
+
request_id = set_request_id()
|
|
270
|
+
|
|
231
271
|
buffer = io.StringIO()
|
|
232
272
|
try:
|
|
233
273
|
await _log_request_run(request)
|
|
234
|
-
|
|
274
|
+
response = await call_next(request)
|
|
235
275
|
except InvalidInputError as exc:
|
|
236
276
|
message = str(exc)
|
|
237
277
|
self.logger.error(message)
|
|
@@ -258,7 +298,10 @@ class ApiServer:
|
|
|
258
298
|
err_msg = buffer.getvalue()
|
|
259
299
|
if err_msg:
|
|
260
300
|
self.logger.error(err_msg)
|
|
261
|
-
|
|
301
|
+
|
|
302
|
+
# Add request ID to response header
|
|
303
|
+
response.headers["X-Request-ID"] = request_id
|
|
304
|
+
|
|
262
305
|
return response
|
|
263
306
|
|
|
264
307
|
# Configure CORS with smart credential handling
|
|
@@ -266,36 +309,60 @@ class ApiServer:
|
|
|
266
309
|
credential_origins_env = self.env_vars.get(c.SQRL_AUTH_CREDENTIAL_ORIGINS, "https://squirrels-analytics.github.io")
|
|
267
310
|
allowed_credential_origins = [origin.strip() for origin in credential_origins_env.split(",") if origin.strip()]
|
|
268
311
|
|
|
269
|
-
|
|
312
|
+
# Allow both underscore and dash versions of configurable headers
|
|
313
|
+
configurables_as_headers = []
|
|
314
|
+
for name in self.manifest_cfg.configurables.keys():
|
|
315
|
+
configurables_as_headers.append(f"x-config-{name}") # underscore version
|
|
316
|
+
configurables_as_headers.append(f"x-config-{u.normalize_name_for_api(name)}") # dash version
|
|
317
|
+
|
|
318
|
+
app.add_middleware(SmartCORSMiddleware, allowed_credential_origins=allowed_credential_origins, configurables_as_headers=configurables_as_headers)
|
|
270
319
|
|
|
271
320
|
# Setup route modules
|
|
272
|
-
self.oauth2_routes.setup_routes(app)
|
|
273
|
-
self.auth_routes.setup_routes(app)
|
|
274
|
-
get_parameters_definition = self.project_routes.setup_routes(app, self.mcp, project_metadata_path, project_name, project_version, param_fields)
|
|
321
|
+
# self.oauth2_routes.setup_routes(app, squirrels_version_path)
|
|
322
|
+
self.auth_routes.setup_routes(app, squirrels_version_path)
|
|
323
|
+
get_parameters_definition = self.project_routes.setup_routes(app, self.mcp, project_metadata_path, project_name, project_version, project_label, param_fields)
|
|
275
324
|
self.data_management_routes.setup_routes(app, project_metadata_path, param_fields)
|
|
276
|
-
self.dataset_routes.setup_routes(app, self.mcp, project_metadata_path, project_name,
|
|
325
|
+
self.dataset_routes.setup_routes(app, self.mcp, project_metadata_path, project_name, project_label, param_fields, get_parameters_definition)
|
|
277
326
|
self.dashboard_routes.setup_routes(app, project_metadata_path, param_fields, get_parameters_definition)
|
|
278
327
|
app.mount(project_metadata_path, self.mcp.streamable_http_app())
|
|
279
328
|
|
|
329
|
+
# Mount static files from public directory if it exists
|
|
330
|
+
# This allows users to serve static assets (images, CSS, JS, etc.) from {project_path}/public/
|
|
331
|
+
public_dir = Path(self.project._filepath) / c.PUBLIC_FOLDER
|
|
332
|
+
if public_dir.exists() and public_dir.is_dir():
|
|
333
|
+
app.mount("/public", StaticFiles(directory=str(public_dir)), name="public")
|
|
334
|
+
self.logger.info(f"Mounted static files from: {public_dir}")
|
|
335
|
+
|
|
280
336
|
# Add Root Path Redirection to Squirrels Studio
|
|
281
337
|
full_hostname = f"http://{uvicorn_args.host}:{uvicorn_args.port}"
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
338
|
+
squirrels_studio_path = f"/project/{project_name_for_api}/{project_version}/studio"
|
|
339
|
+
templates = Jinja2Templates(directory=str(Path(__file__).parent / "_package_data" / "templates"))
|
|
340
|
+
|
|
341
|
+
@app.get(squirrels_studio_path, include_in_schema=False)
|
|
342
|
+
async def squirrels_studio():
|
|
343
|
+
default_studio_path = "https://squirrels-analytics.github.io/squirrels-studio-v1"
|
|
344
|
+
sqrl_studio_base_url = self.env_vars.get(c.SQRL_STUDIO_BASE_URL, default_studio_path)
|
|
345
|
+
context = {
|
|
346
|
+
"sqrl_studio_base_url": sqrl_studio_base_url,
|
|
347
|
+
"project_name": project_name_for_api,
|
|
348
|
+
"project_version": project_version,
|
|
349
|
+
}
|
|
350
|
+
return HTMLResponse(content=templates.get_template("squirrels_studio.html").render(context))
|
|
285
351
|
|
|
286
352
|
@app.get("/", include_in_schema=False)
|
|
287
353
|
async def redirect_to_studio():
|
|
288
|
-
return RedirectResponse(url=
|
|
354
|
+
return RedirectResponse(url=squirrels_studio_path)
|
|
355
|
+
|
|
356
|
+
self.logger.log_activity_time("creating app server", start)
|
|
289
357
|
|
|
290
358
|
# Run the API Server
|
|
291
359
|
import uvicorn
|
|
292
360
|
|
|
293
361
|
print("\nWelcome to the Squirrels Data Application!\n")
|
|
294
|
-
print(f"- Application UI: {
|
|
362
|
+
print(f"- Application UI (Squirrels Studio): {full_hostname}{squirrels_studio_path}")
|
|
295
363
|
print(f"- API Docs (with ReDoc): {full_hostname}{project_metadata_path}/redoc")
|
|
296
364
|
print(f"- API Docs (with Swagger UI): {full_hostname}{project_metadata_path}/docs")
|
|
365
|
+
print(f"- MCP Server URL: {full_hostname}{project_metadata_path}/mcp")
|
|
297
366
|
print()
|
|
298
367
|
|
|
299
|
-
|
|
300
|
-
uvicorn.run(app, host=uvicorn_args.host, port=uvicorn_args.port)
|
|
301
|
-
|
|
368
|
+
uvicorn.run(app, host=uvicorn_args.host, port=uvicorn_args.port, proxy_headers=True, forwarded_allow_ips="*")
|
|
@@ -2,7 +2,7 @@ from typing import Callable, Any, Coroutine
|
|
|
2
2
|
import polars as pl
|
|
3
3
|
|
|
4
4
|
from .init_time_args import ParametersArgs, BuildModelArgs
|
|
5
|
-
from ..
|
|
5
|
+
from .._schemas.auth_models import AbstractUser
|
|
6
6
|
from .._parameters import Parameter, TextValue
|
|
7
7
|
|
|
8
8
|
|
|
@@ -10,14 +10,14 @@ class ContextArgs(ParametersArgs):
|
|
|
10
10
|
|
|
11
11
|
def __init__(
|
|
12
12
|
self, param_args: ParametersArgs,
|
|
13
|
-
user:
|
|
14
|
-
prms: dict[str, Parameter],
|
|
15
|
-
|
|
13
|
+
user: AbstractUser,
|
|
14
|
+
prms: dict[str, Parameter],
|
|
15
|
+
configurables: dict[str, str]
|
|
16
16
|
):
|
|
17
17
|
super().__init__(param_args.project_path, param_args.proj_vars, param_args.env_vars)
|
|
18
18
|
self.user = user
|
|
19
19
|
self._prms = prms
|
|
20
|
-
self.
|
|
20
|
+
self._configurables = configurables
|
|
21
21
|
self._placeholders = {}
|
|
22
22
|
|
|
23
23
|
@property
|
|
@@ -28,11 +28,11 @@ class ContextArgs(ParametersArgs):
|
|
|
28
28
|
return self._prms.copy()
|
|
29
29
|
|
|
30
30
|
@property
|
|
31
|
-
def
|
|
31
|
+
def configurables(self) -> dict[str, str]:
|
|
32
32
|
"""
|
|
33
|
-
A dictionary of
|
|
33
|
+
A dictionary of configurable name to value (set by application)
|
|
34
34
|
"""
|
|
35
|
-
return self.
|
|
35
|
+
return self._configurables.copy()
|
|
36
36
|
|
|
37
37
|
@property
|
|
38
38
|
def _placeholders_copy(self) -> dict[str, Any]:
|
|
@@ -80,7 +80,7 @@ class ModelArgs(BuildModelArgs, ContextArgs):
|
|
|
80
80
|
self._env_vars = ctx_args.env_vars
|
|
81
81
|
self.user = ctx_args.user
|
|
82
82
|
self._prms = ctx_args.prms
|
|
83
|
-
self.
|
|
83
|
+
self._configurables = ctx_args.configurables
|
|
84
84
|
self._placeholders = ctx_args._placeholders_copy
|
|
85
85
|
self._connections = build_model_args.connections
|
|
86
86
|
self._dependencies = build_model_args.dependencies
|