squirrels 0.4.1__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- dateutils/__init__.py +6 -0
- dateutils/_enums.py +25 -0
- squirrels/dateutils.py → dateutils/_implementation.py +58 -111
- dateutils/types.py +6 -0
- squirrels/__init__.py +13 -11
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +271 -0
- squirrels/_api_routes/base.py +165 -0
- squirrels/_api_routes/dashboards.py +150 -0
- squirrels/_api_routes/data_management.py +145 -0
- squirrels/_api_routes/datasets.py +257 -0
- squirrels/_api_routes/oauth2.py +298 -0
- squirrels/_api_routes/project.py +252 -0
- squirrels/_api_server.py +256 -450
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/init_time_args.py +108 -0
- squirrels/_arguments/run_time_args.py +147 -0
- squirrels/_auth.py +960 -0
- squirrels/_command_line.py +126 -45
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +48 -26
- squirrels/_constants.py +68 -38
- squirrels/_dashboards.py +160 -0
- squirrels/_data_sources.py +570 -0
- squirrels/_dataset_types.py +84 -0
- squirrels/_exceptions.py +29 -0
- squirrels/_initializer.py +177 -80
- squirrels/_logging.py +115 -0
- squirrels/_manifest.py +208 -79
- squirrels/_model_builder.py +69 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +926 -367
- squirrels/_package_data/base_project/.env +42 -0
- squirrels/_package_data/base_project/.env.example +42 -0
- squirrels/_package_data/base_project/assets/expenses.db +0 -0
- squirrels/_package_data/base_project/connections.yml +16 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +34 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +5 -2
- squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +3 -3
- squirrels/{package_data → _package_data}/base_project/docker/compose.yml +1 -1
- squirrels/_package_data/base_project/duckdb_init.sql +10 -0
- squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +3 -2
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +12 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +26 -0
- squirrels/_package_data/base_project/models/federates/federate_example.py +37 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +19 -0
- squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/_package_data/base_project/models/sources.yml +38 -0
- squirrels/{package_data → _package_data}/base_project/parameters.yml +56 -40
- squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
- squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +21 -40
- squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
- squirrels/_package_data/templates/dataset_results.html +112 -0
- squirrels/_package_data/templates/oauth_login.html +271 -0
- squirrels/_package_data/templates/squirrels_studio.html +20 -0
- squirrels/_package_loader.py +8 -4
- squirrels/_parameter_configs.py +104 -103
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +57 -47
- squirrels/_parameters.py +1664 -0
- squirrels/_project.py +721 -0
- squirrels/_py_module.py +7 -5
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +167 -0
- squirrels/_schemas/query_param_models.py +75 -0
- squirrels/{_api_response_models.py → _schemas/response_models.py} +126 -47
- squirrels/_seeds.py +35 -16
- squirrels/_sources.py +110 -0
- squirrels/_utils.py +248 -73
- squirrels/_version.py +1 -1
- squirrels/arguments.py +7 -0
- squirrels/auth.py +4 -0
- squirrels/connections.py +3 -0
- squirrels/dashboards.py +2 -81
- squirrels/data_sources.py +14 -631
- squirrels/parameter_options.py +13 -348
- squirrels/parameters.py +14 -1266
- squirrels/types.py +16 -0
- squirrels-0.5.0.dist-info/METADATA +113 -0
- squirrels-0.5.0.dist-info/RECORD +97 -0
- {squirrels-0.4.1.dist-info → squirrels-0.5.0.dist-info}/WHEEL +1 -1
- squirrels-0.5.0.dist-info/entry_points.txt +3 -0
- {squirrels-0.4.1.dist-info → squirrels-0.5.0.dist-info/licenses}/LICENSE +1 -1
- squirrels/_authenticator.py +0 -85
- squirrels/_dashboards_io.py +0 -61
- squirrels/_environcfg.py +0 -84
- squirrels/arguments/init_time_args.py +0 -40
- squirrels/arguments/run_time_args.py +0 -208
- squirrels/package_data/assets/favicon.ico +0 -0
- squirrels/package_data/assets/index.css +0 -1
- squirrels/package_data/assets/index.js +0 -58
- squirrels/package_data/base_project/assets/expenses.db +0 -0
- squirrels/package_data/base_project/connections.yml +0 -7
- squirrels/package_data/base_project/dashboards/dashboard_example.py +0 -32
- squirrels/package_data/base_project/dashboards.yml +0 -10
- squirrels/package_data/base_project/env.yml +0 -29
- squirrels/package_data/base_project/models/dbviews/dbview_example.py +0 -47
- squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -22
- squirrels/package_data/base_project/models/federates/federate_example.py +0 -21
- squirrels/package_data/base_project/models/federates/federate_example.sql +0 -3
- squirrels/package_data/base_project/pyconfigs/auth.py +0 -45
- squirrels/package_data/base_project/pyconfigs/connections.py +0 -19
- squirrels/package_data/base_project/pyconfigs/parameters.py +0 -95
- squirrels/package_data/base_project/seeds/seed_subcategories.csv +0 -15
- squirrels/package_data/base_project/squirrels.yml.j2 +0 -94
- squirrels/package_data/templates/index.html +0 -18
- squirrels/project.py +0 -378
- squirrels/user_base.py +0 -55
- squirrels-0.4.1.dist-info/METADATA +0 -117
- squirrels-0.4.1.dist-info/RECORD +0 -60
- squirrels-0.4.1.dist-info/entry_points.txt +0 -4
- /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
|
@@ -0,0 +1,298 @@
|
|
|
1
|
+
from fastapi import FastAPI, Depends, Request, Query, Response, APIRouter, Form
|
|
2
|
+
from fastapi.responses import RedirectResponse, HTMLResponse
|
|
3
|
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
4
|
+
from typing import Annotated, cast
|
|
5
|
+
|
|
6
|
+
from .base import RouteBase
|
|
7
|
+
from .._schemas.auth_models import (
|
|
8
|
+
ClientRegistrationRequest, ClientUpdateRequest, ClientRegistrationResponse, ClientDetailsResponse, ClientUpdateResponse,
|
|
9
|
+
TokenResponse, OAuthServerMetadata, AbstractUser
|
|
10
|
+
)
|
|
11
|
+
from .._exceptions import InvalidInputError
|
|
12
|
+
from .. import _utils as u
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OAuth2Routes(RouteBase):
    """OAuth2 routes.

    Registers the OAuth authorization-server endpoints on the FastAPI app.

    Under the ``/oauth2`` router prefix:
        - dynamic client registration and client management
        - the authorization endpoint
        - the token endpoint
        - the token revocation endpoint

    On the app itself:
        - the authorization server metadata document at
          ``/.well-known/oauth-authorization-server``
    """

    def __init__(self, get_bearer_token: HTTPBearer, project, no_cache: bool = False):
        super().__init__(get_bearer_token, project, no_cache)

    @staticmethod
    def _build_redirect_url(redirect_uri: str, params: dict[str, str | None]) -> str:
        """Return ``redirect_uri`` with ``params`` appended to its query component.

        Handling of the URL and values:
            - Any query string already present on ``redirect_uri`` is retained, as RFC 6749
              section 3.1.2 requires.
            - Values are application/x-www-form-urlencoded.
            - Parameters whose value is None or empty are omitted.

        Args:
            redirect_uri: The client's redirect URI.
            params: Query parameters to add, in order.

        Returns:
            The full redirect URL.
        """
        from urllib.parse import urlencode, urlsplit, urlunsplit

        parts = urlsplit(redirect_uri)
        added_query = urlencode({key: value for key, value in params.items() if value})
        query = f"{parts.query}&{added_query}" if parts.query else added_query
        return urlunsplit(parts._replace(query=query))

    def serve_login_page(self, auth_path: str, request: Request, client_id: str) -> HTMLResponse:
        """Render the ``oauth_login.html`` template for an unauthenticated authorization request.

        Args:
            auth_path: Base path of the auth routes. The template receives
                ``{auth_path}/login`` as ``login_url`` and, per login provider,
                ``{auth_path}/providers/{name}/login``.
            request: The incoming authorization request. Its full URL is passed to the template
                as ``return_url`` (presumably so the flow can resume after login).
            client_id: OAuth client identifier, used to look up the client's display name.

        Returns:
            An HTML response with status 200 containing the rendered login page.
        """
        # Client name is optional: an unknown client simply gets no name on the page
        client_details = self.authenticator.get_oauth_client_details(client_id)
        client_name = client_details.client_name if client_details else None
        project_name = self.manifest_cfg.project_variables.label

        # One entry per configured login provider
        providers = [
            {
                "name": provider.name,
                "label": provider.label,
                "icon": provider.icon,
                "login_url": f"{auth_path}/providers/{provider.name}/login",
            }
            for provider in self.authenticator.auth_providers
        ]

        # Template context
        context = {
            "request": request,
            "project_name": project_name,
            "client_name": client_name,
            "providers": providers,
            "login_url": f"{auth_path}/login",
            "return_url": str(request.url),
        }

        return HTMLResponse(
            content=self.templates.get_template("oauth_login.html").render(context),
            status_code=200
        )

    def setup_routes(self, app: FastAPI, squirrels_version_path: str) -> None:
        """Register all OAuth2 routes on ``app``.

        Args:
            app: The FastAPI application. The ``/oauth2`` router is included into it, and the
                RFC 8414 metadata endpoint is registered on it directly.
            squirrels_version_path: Versioned API base path. Login page links are built relative
                to ``{squirrels_version_path}/auth``.
        """
        auth_path = squirrels_version_path + "/auth"
        router_path = "/oauth2"
        router = APIRouter(prefix=router_path)

        # Authorization dependency for client management. With auto_error=False a missing or
        # malformed header yields None, so the OAuth-style error below is returned instead.
        get_client_token = HTTPBearer(auto_error=False)

        async def validate_client_registration_token(
            client_id: str, auth: HTTPAuthorizationCredentials | None = Depends(get_client_token),
        ) -> None:
            """Validate the Bearer registration access token for client management operations.

            Raises:
                InvalidInputError: 401 ``invalid_client`` if the Authorization header is missing
                    or not a Bearer token; 401 ``invalid_token`` if the token is not valid for
                    ``client_id``.
            """
            # Authentication scheme names are case-insensitive (RFC 7235), and HTTPBearer keeps
            # the client's casing, so compare case-insensitively.
            if not auth or auth.scheme.lower() != "bearer":
                raise InvalidInputError(401, "invalid_client",
                    "Missing or invalid authorization header. Use 'Authorization: Bearer <registration_access_token>'"
                )

            token = auth.credentials
            is_valid = self.authenticator.validate_registration_access_token(client_id, token)
            if not is_valid:
                raise InvalidInputError(401, "invalid_token", "Invalid registration access token for this client")

        def validate_oauth_client_credentials(client_id: str | None, client_secret: str | None) -> str:
            """Validate OAuth client credentials sent as form fields (``client_secret_post``).

            Returns:
                The validated client_id.

            Raises:
                InvalidInputError: 400 ``invalid_client`` if either value is missing or the
                    credentials are rejected by the authenticator.
            """
            if not client_id or not client_secret or not self.authenticator.validate_client_credentials(client_id, client_secret):
                raise InvalidInputError(400, "invalid_client", "Invalid client credentials")

            return client_id

        # Client Registration Endpoint
        client_management_path = '/client/{client_id}'

        @router.post("/client", description="Register a new OAuth client", tags=["OAuth2"])
        async def register_oauth_client(request: ClientRegistrationRequest) -> ClientRegistrationResponse:
            """Register a new OAuth client and return client credentials"""

            # The authenticator is given the management path template for the new client
            client_registration_response = self.authenticator.register_oauth_client(
                request, client_management_path_format=router_path+client_management_path
            )

            return client_registration_response

        # Client Management Endpoints
        @router.get(client_management_path, description="Get OAuth client registration details", tags=["OAuth2"])
        async def get_oauth_client(
            client_id: str, _: Annotated[None, Depends(validate_client_registration_token)]
        ) -> ClientDetailsResponse:
            """Get OAuth client registration details"""

            client_details = self.authenticator.get_oauth_client_details(client_id)
            if client_details is None:
                # Returning None would fail response validation (HTTP 500). RFC 7592 asks for a
                # 401 when the client no longer exists.
                raise InvalidInputError(401, "invalid_token", "Invalid registration access token for this client")

            return client_details

        @router.put(client_management_path, description="Update OAuth client registration", tags=["OAuth2"])
        async def update_oauth_client(
            client_id: str, request: ClientUpdateRequest, _: Annotated[None, Depends(validate_client_registration_token)]
        ) -> ClientUpdateResponse:
            """Update OAuth client registration and rotate access token"""

            # Update the client and get new registration access token
            client_details = self.authenticator.update_oauth_client_with_token_rotation(client_id, request)

            return client_details

        @router.delete(client_management_path, description="Revoke OAuth client registration", tags=["OAuth2"], responses={
            204: { "description": "OAuth client registration revoked successfully" }
        })
        async def revoke_oauth_client(
            client_id: str, _: Annotated[None, Depends(validate_client_registration_token)]
        ) -> Response:
            """Revoke (deactivate) OAuth client registration"""

            self.authenticator.revoke_oauth_client(client_id)
            return Response(status_code=204)

        # Authorization Endpoint
        @router.get("/authorize", description="OAuth 2.1 Authorization Endpoint", tags=["OAuth2"], response_model=None)
        async def authorize_endpoint(
            request: Request,
            response_type: str = Query(default="code", description="OAuth response type"),
            client_id: str = Query(..., description="OAuth client identifier"),
            redirect_uri: str = Query(..., description="URI to redirect after authorization"),
            scope: str = Query(default="read", description="Requested scope"),
            state: str | None = Query(default=None, description="State parameter for CSRF protection"),
            code_challenge: str = Query(..., description="PKCE code challenge (required)"),
            code_challenge_method: str = Query(default="S256", description="PKCE code challenge method"),
            user: AbstractUser = Depends(self.get_current_user)
        ):
            """OAuth 2.1 authorization endpoint for initiating authorization code flow"""

            try:
                # Validate response_type
                if response_type != "code":
                    raise InvalidInputError(400, "unsupported_response_type", "Only 'code' response type is supported")

                # Guest users are not authenticated: serve the login page instead
                if user.access_level == "guest":
                    return self.serve_login_page(auth_path, request, client_id)

                # TODO: Serve a page with an "authorize" button even if user is already authenticated
                # Ex. if not request.session.get("authorization_approved"), redirect to a page with button that submits to "/approve-authorization"

                # User is authenticated - generate authorization code
                authorization_code = self.authenticator.create_authorization_code(
                    client_id=client_id,
                    username=user.username,
                    redirect_uri=redirect_uri,
                    scope=scope,
                    code_challenge=code_challenge,
                    code_challenge_method=code_challenge_method
                )

                # Redirect back to client with authorization code (state echoed back if given)
                success_url = self._build_redirect_url(
                    redirect_uri, {"code": authorization_code, "state": state}
                )
                return RedirectResponse(url=success_url)

            except InvalidInputError as e:
                # Only invalid_request errors are redirected back to the client; any other error
                # propagates as a normal error response.
                # NOTE(review): confirm the authenticator never raises invalid_request for an
                # unregistered/mismatched redirect_uri, otherwise this is an open redirect.
                if e.error == "invalid_request":
                    error_url = self._build_redirect_url(
                        redirect_uri, {"error": e.error, "error_description": e.error_description, "state": state}
                    )
                    return RedirectResponse(url=error_url)
                raise

        # Token Endpoint
        @router.post("/token", description="OAuth 2.1 Token Endpoint", tags=["OAuth2"])
        async def token_endpoint(
            grant_type: str = Form(...),
            code: str | None = Form(default=None),
            redirect_uri: str | None = Form(default=None),
            code_verifier: str | None = Form(default=None),
            refresh_token: str | None = Form(default=None),
            client_id: str | None = Form(default=None),
            client_secret: str | None = Form(default=None)
        ) -> TokenResponse:
            """OAuth 2.1 token endpoint for exchanging authorization code or refresh token for access token"""

            # Validate client credentials
            auth_client_id = validate_oauth_client_credentials(client_id, client_secret)

            # Get token expiry configuration
            expiry_mins = self._get_access_token_expiry_minutes()

            if grant_type == "authorization_code":
                # Validate required parameters for authorization code flow
                if not all([code, redirect_uri, code_verifier, auth_client_id]):
                    raise InvalidInputError(400, "invalid_request", "Missing required parameters for authorization_code grant")

                # Type casts since we validated above
                code = cast(str, code)
                redirect_uri = cast(str, redirect_uri)
                code_verifier = cast(str, code_verifier)

                # Exchange authorization code for tokens
                token_response = self.authenticator.exchange_authorization_code(
                    code=code,
                    client_id=auth_client_id,
                    redirect_uri=redirect_uri,
                    code_verifier=code_verifier,
                    access_token_expiry_minutes=expiry_mins
                )

                return token_response

            elif grant_type == "refresh_token":
                # Validate required parameters for refresh token flow
                if not all([refresh_token, auth_client_id]):
                    raise InvalidInputError(400, "invalid_request", "Missing required parameters for refresh_token grant")

                # Type cast since we validated above
                refresh_token = cast(str, refresh_token)

                # Refresh access token
                token_response = self.authenticator.refresh_oauth_access_token(
                    refresh_token=refresh_token,
                    client_id=auth_client_id,
                    access_token_expiry_minutes=expiry_mins
                )

                return token_response

            else:
                raise InvalidInputError(400, "unsupported_grant_type", f"Grant type '{grant_type}' is not supported")

        # Token Revocation Endpoint
        @router.post("/token/revoke", description="OAuth 2.1 Token Revocation Endpoint", tags=["OAuth2"])
        async def revoke_endpoint(
            token: str = Form(..., description="The token to be revoked"),
            token_type_hint: str | None = Form(default=None, description="Hint about the type of token being revoked"),
            client_id: str | None = Form(default=None),
            client_secret: str | None = Form(default=None)
        ) -> Response:
            """OAuth 2.1 token revocation endpoint for revoking refresh tokens"""

            # Client authentication failures are still reported as errors
            auth_client_id = validate_oauth_client_credentials(client_id, client_secret)

            # Revoke the token (per RFC 7009, always return 200 regardless of token validity)
            try:
                self.authenticator.revoke_oauth_token(auth_client_id, token, token_type_hint)
            except InvalidInputError:
                # Per OAuth spec, revocation endpoint should return 200 even for invalid tokens
                pass

            return Response(status_code=200)

        # Authorization Server Metadata Endpoint (well-known endpoint)
        @app.get("/.well-known/oauth-authorization-server", tags=["OAuth2"], description="OAuth 2.1 Authorization Server Metadata")
        async def authorization_server_metadata(request: Request) -> OAuthServerMetadata:
            """OAuth 2.1 Authorization Server Metadata endpoint (RFC 8414)"""

            # Get the base URL from the request
            scheme = u.get_scheme(request.url.hostname)
            base_url = scheme + "://" + request.url.netloc

            return OAuthServerMetadata(
                issuer=base_url,
                authorization_endpoint=f"{base_url}{router_path}/authorize",
                token_endpoint=f"{base_url}{router_path}/token",
                revocation_endpoint=f"{base_url}{router_path}/token/revoke",
                registration_endpoint=f"{base_url}{router_path}/client",
                scopes_supported=["read"],
                response_types_supported=["code"],
                grant_types_supported=["authorization_code", "refresh_token"],
                token_endpoint_auth_methods_supported=["client_secret_post"],
                code_challenge_methods_supported=["S256"]
            )

        app.include_router(router)
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Project metadata routes
|
|
3
|
+
"""
|
|
4
|
+
from typing import Any
|
|
5
|
+
from fastapi import FastAPI, Depends, Request
|
|
6
|
+
from fastapi.responses import JSONResponse
|
|
7
|
+
from fastapi.security import HTTPBearer
|
|
8
|
+
from mcp.server.fastmcp import FastMCP, Context
|
|
9
|
+
from dataclasses import asdict
|
|
10
|
+
from cachetools import TTLCache
|
|
11
|
+
from textwrap import dedent
|
|
12
|
+
import time
|
|
13
|
+
|
|
14
|
+
from .. import _utils as u, _constants as c
|
|
15
|
+
from .._schemas import response_models as rm
|
|
16
|
+
from .._parameter_sets import ParameterSet
|
|
17
|
+
from .._exceptions import ConfigurationError, InvalidInputError
|
|
18
|
+
from .._manifest import PermissionScope, AuthenticationEnforcement
|
|
19
|
+
from .._version import __version__
|
|
20
|
+
from .._schemas.query_param_models import get_query_models_for_parameters
|
|
21
|
+
from .._schemas.auth_models import AbstractUser
|
|
22
|
+
from .base import RouteBase
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class ProjectRoutes(RouteBase):
|
|
26
|
+
"""Project metadata and data catalog routes"""
|
|
27
|
+
|
|
28
|
+
def __init__(self, get_bearer_token: HTTPBearer, project, no_cache: bool = False):
    """Initialize the project routes and the TTL cache used for parameter lookups.

    Both cache settings are read from ``self.env_vars`` and converted with ``int()``:
        - maximum size: key ``c.SQRL_PARAMETERS_CACHE_SIZE``, default 1024 entries
        - time-to-live: key ``c.SQRL_PARAMETERS_CACHE_TTL_MINUTES``, default 60, in minutes
    """
    super().__init__(get_bearer_token, project, no_cache)

    env = self.env_vars
    max_entries = int(env.get(c.SQRL_PARAMETERS_CACHE_SIZE, 1024))
    # The env var is expressed in minutes; TTLCache expects seconds
    ttl_seconds = int(env.get(c.SQRL_PARAMETERS_CACHE_TTL_MINUTES, 60)) * 60
    self.parameters_cache = TTLCache(maxsize=max_entries, ttl=ttl_seconds)
|
|
35
|
+
|
|
36
|
+
async def _get_parameters_helper(
    self, parameters_tuple: tuple[str, ...] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
    user: AbstractUser, selections: tuple[tuple[str, Any], ...]
) -> ParameterSet:
    """Apply widget selections to the parameter configs and return the resulting ParameterSet.

    Handling of ``x_parent_param``:
        - If it is absent, at most one widget selection is allowed. A single selection becomes
          the parent parameter.
        - If the parent has no selection of its own, it is given an empty list.

    Args:
        parameters_tuple: Names of parameters to include, or None (passed through to
            ``apply_selections``).
        entity_type: Kind of entity requested; used only in the permission error.
        entity_name: Name of the entity requested; used only in the permission error.
        entity_scope: Permission scope the user must be able to access.
        user: The requesting user.
        selections: Parameter selections as hashable (name, value) pairs.

    Raises:
        InvalidInputError: 400 if more than one selection is given without ``x_parent_param``.
        The error returned by ``self.project._permission_error``, if the user cannot access
        ``entity_scope``.
    """
    chosen = dict(selections)

    if "x_parent_param" not in chosen:
        selection_count = len(chosen)
        if selection_count > 1:
            raise InvalidInputError(400, "invalid_input_for_cascading_parameters", f"The parameters endpoint takes at most 1 widget parameter selection (unless x_parent_param is provided). Got {chosen}")
        if selection_count == 1:
            # A lone selection is implicitly the parent parameter
            chosen["x_parent_param"] = next(iter(chosen))

    parent_param = chosen.get("x_parent_param")
    # A multi-select parent with an empty selection arrives without its own key
    if parent_param is not None:
        chosen.setdefault(parent_param, list())

    if not self.authenticator.can_user_access_scope(user, entity_scope):
        raise self.project._permission_error(user, entity_type, entity_name, entity_scope.name)

    return self.param_cfg_set.apply_selections(parameters_tuple, chosen, user, parent_param=parent_param)
|
|
59
|
+
|
|
60
|
+
async def _get_parameters_cachable(
    self, parameters_tuple: tuple[str, ...] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
    user: AbstractUser, selections: tuple[tuple[str, Any], ...]
) -> ParameterSet:
    """Run ``_get_parameters_helper`` through ``do_cachable_action`` with ``self.parameters_cache``.

    All arguments are forwarded unchanged, in the same order, so ``selections`` is expected to be
    hashable (a tuple of pairs) as the helper's signature declares.
    """
    helper_args = (parameters_tuple, entity_type, entity_name, entity_scope, user, selections)
    return await self.do_cachable_action(self.parameters_cache, self._get_parameters_helper, *helper_args)
|
|
68
|
+
|
|
69
|
+
def setup_routes(
|
|
70
|
+
self, app: FastAPI, mcp: FastMCP, project_metadata_path: str, project_name: str, project_version: str, project_label: str, param_fields: dict
|
|
71
|
+
):
|
|
72
|
+
"""Setup project metadata routes"""
|
|
73
|
+
|
|
74
|
+
elevated_access_level = self.project._elevated_access_level
|
|
75
|
+
if elevated_access_level != "admin":
|
|
76
|
+
self.logger.warning(f"{c.SQRL_PERMISSIONS_ELEVATED_ACCESS_LEVEL} has been set to a non-admin access level. For security reasons, DO NOT expose the APIs for this app publicly!")
|
|
77
|
+
|
|
78
|
+
# Project metadata endpoint
|
|
79
|
+
@app.get(project_metadata_path, tags=["Project Metadata"], response_class=JSONResponse)
|
|
80
|
+
async def get_project_metadata(request: Request) -> rm.ProjectModel:
|
|
81
|
+
return rm.ProjectModel(
|
|
82
|
+
name=project_name,
|
|
83
|
+
version=project_version,
|
|
84
|
+
label=self.manifest_cfg.project_variables.label,
|
|
85
|
+
description=self.manifest_cfg.project_variables.description,
|
|
86
|
+
elevated_access_level=elevated_access_level,
|
|
87
|
+
redoc_path=project_metadata_path + "/redoc",
|
|
88
|
+
swagger_path=project_metadata_path + "/docs",
|
|
89
|
+
mcp_server_path=project_metadata_path + "/mcp",
|
|
90
|
+
squirrels_version=__version__
|
|
91
|
+
)
|
|
92
|
+
|
|
93
|
+
# Data catalog endpoint
|
|
94
|
+
data_catalog_path = project_metadata_path + '/data-catalog'
|
|
95
|
+
|
|
96
|
+
async def get_data_catalog0(user: AbstractUser) -> rm.CatalogModel:
|
|
97
|
+
parameters = self.param_cfg_set.apply_selections(None, {}, user)
|
|
98
|
+
parameters_model = parameters.to_api_response_model0()
|
|
99
|
+
full_parameters_list = [p.name for p in parameters_model.parameters]
|
|
100
|
+
user_has_elevated_privileges = u.user_has_elevated_privileges(user.access_level, elevated_access_level)
|
|
101
|
+
|
|
102
|
+
dataset_items: list[rm.DatasetItemModel] = []
|
|
103
|
+
for name, config in self.manifest_cfg.datasets.items():
|
|
104
|
+
if self.authenticator.can_user_access_scope(user, config.scope):
|
|
105
|
+
name_for_api = u.normalize_name_for_api(name)
|
|
106
|
+
metadata = self.project.dataset_metadata(name).to_json()
|
|
107
|
+
parameters = config.parameters if config.parameters is not None else full_parameters_list
|
|
108
|
+
|
|
109
|
+
# Build dataset-specific configurables list
|
|
110
|
+
if user_has_elevated_privileges:
|
|
111
|
+
dataset_configurables_defaults = self.manifest_cfg.get_default_configurables(name)
|
|
112
|
+
dataset_configurables_list = [
|
|
113
|
+
rm.ConfigurableDefaultModel(name=name, default=default)
|
|
114
|
+
for name, default in dataset_configurables_defaults.items()
|
|
115
|
+
]
|
|
116
|
+
else:
|
|
117
|
+
dataset_configurables_list = []
|
|
118
|
+
|
|
119
|
+
dataset_items.append(rm.DatasetItemModel(
|
|
120
|
+
name=name, label=config.label,
|
|
121
|
+
description=config.description,
|
|
122
|
+
schema=metadata["schema"], # type: ignore
|
|
123
|
+
configurables=dataset_configurables_list,
|
|
124
|
+
parameters=parameters,
|
|
125
|
+
parameters_path=f"{project_metadata_path}/dataset/{name_for_api}/parameters",
|
|
126
|
+
result_path=f"{project_metadata_path}/dataset/{name_for_api}"
|
|
127
|
+
))
|
|
128
|
+
|
|
129
|
+
dashboard_items: list[rm.DashboardItemModel] = []
|
|
130
|
+
for name, dashboard in self.project._dashboards.items():
|
|
131
|
+
config = dashboard.config
|
|
132
|
+
if self.authenticator.can_user_access_scope(user, config.scope):
|
|
133
|
+
name_for_api = u.normalize_name_for_api(name)
|
|
134
|
+
|
|
135
|
+
try:
|
|
136
|
+
dashboard_format = self.project._dashboards[name].get_dashboard_format()
|
|
137
|
+
except KeyError:
|
|
138
|
+
raise ConfigurationError(f"No dashboard file found for: {name}")
|
|
139
|
+
|
|
140
|
+
parameters = config.parameters if config.parameters is not None else full_parameters_list
|
|
141
|
+
dashboard_items.append(rm.DashboardItemModel(
|
|
142
|
+
name=name, label=config.label,
|
|
143
|
+
description=config.description,
|
|
144
|
+
result_format=dashboard_format,
|
|
145
|
+
parameters=parameters,
|
|
146
|
+
parameters_path=f"{project_metadata_path}/dashboard/{name_for_api}/parameters",
|
|
147
|
+
result_path=f"{project_metadata_path}/dashboard/{name_for_api}"
|
|
148
|
+
))
|
|
149
|
+
|
|
150
|
+
if user_has_elevated_privileges:
|
|
151
|
+
compiled_dag = await self.project._get_compiled_dag(user)
|
|
152
|
+
connections_items = self.project._get_all_connections()
|
|
153
|
+
data_models = self.project._get_all_data_models(compiled_dag)
|
|
154
|
+
lineage_items = self.project._get_all_data_lineage(compiled_dag)
|
|
155
|
+
configurables_list = [
|
|
156
|
+
rm.ConfigurableItemModel(name=name, label=cfg.label, default=cfg.default, description=cfg.description)
|
|
157
|
+
for name, cfg in self.manifest_cfg.configurables.items()
|
|
158
|
+
]
|
|
159
|
+
else:
|
|
160
|
+
connections_items = []
|
|
161
|
+
data_models = []
|
|
162
|
+
lineage_items = []
|
|
163
|
+
configurables_list = []
|
|
164
|
+
|
|
165
|
+
return rm.CatalogModel(
|
|
166
|
+
parameters=parameters_model.parameters,
|
|
167
|
+
datasets=dataset_items,
|
|
168
|
+
dashboards=dashboard_items,
|
|
169
|
+
connections=connections_items,
|
|
170
|
+
models=data_models,
|
|
171
|
+
lineage=lineage_items,
|
|
172
|
+
configurables=configurables_list,
|
|
173
|
+
)
|
|
174
|
+
|
|
175
|
+
@app.get(data_catalog_path, tags=["Project Metadata"], summary="Get catalog of datasets and dashboards available for user")
|
|
176
|
+
async def get_data_catalog(request: Request, user: AbstractUser = Depends(self.get_current_user)) -> rm.CatalogModel:
|
|
177
|
+
"""
|
|
178
|
+
Get catalog of datasets and dashboards available for the authenticated user.
|
|
179
|
+
|
|
180
|
+
For admin users, this endpoint will also return detailed information about all models and their lineage in the project.
|
|
181
|
+
"""
|
|
182
|
+
start = time.time()
|
|
183
|
+
|
|
184
|
+
# If authentication is required, require user to be authenticated to access catalog
|
|
185
|
+
if self.manifest_cfg.authentication.enforcement == AuthenticationEnforcement.REQUIRED and user.access_level == "guest":
|
|
186
|
+
raise InvalidInputError(401, "user_required", "Authentication is required to access the data catalog")
|
|
187
|
+
data_catalog = await get_data_catalog0(user)
|
|
188
|
+
|
|
189
|
+
self.log_activity_time("GET REQUEST for DATA CATALOG", start, request)
|
|
190
|
+
return data_catalog
|
|
191
|
+
|
|
192
|
+
@mcp.tool(
    name=f"get_data_catalog_from_{project_name}",
    title=f"Get Data Catalog (Project: {project_label})",
    description=dedent(f"""
        Use this tool to get the details of all datasets and parameters you can access in the Squirrels project '{project_name}'.

        Unless the data catalog for this project has already been provided, use this tool at the start of each conversation.
    """).strip()
)
async def get_data_catalog_tool(ctx: Context) -> rm.CatalogModelForTool:
    """
    MCP tool counterpart of the data catalog endpoint.

    Resolves the user from the tool context's headers, builds the same catalog as the HTTP
    endpoint, and returns only its parameters and datasets.

    Raises:
        InvalidInputError: 401 when the project enforces authentication and the user is a guest.
    """
    headers = self.get_headers_from_tool_ctx(ctx)
    user = self.get_user_from_tool_headers(headers)

    # Enforce the same rule as the HTTP catalog endpoint. Without it, a guest could bypass
    # required authentication by reaching the catalog through the MCP tool instead.
    if self.manifest_cfg.authentication.enforcement == AuthenticationEnforcement.REQUIRED and user.access_level == "guest":
        raise InvalidInputError(401, "user_required", "Authentication is required to access the data catalog")

    data_catalog = await get_data_catalog0(user)
    return rm.CatalogModelForTool(parameters=data_catalog.parameters, datasets=data_catalog.datasets)
|
|
206
|
+
|
|
207
|
+
# Project-level parameters endpoints
project_level_parameters_path = project_metadata_path + '/parameters'

# Shared OpenAPI description for both the GET and POST variants of the project parameters endpoint.
parameters_description = (
    "Selections of one parameter may cascade the available options in another parameter. "
    "For example, if the dataset has parameters for 'country' and 'city', available options for 'city' would "
    "depend on the selected option 'country'. If a parameter has 'trigger_refresh' as true, provide the parameter "
    "selection to this endpoint whenever it changes to refresh the parameter options of children parameters."
)

# Request models for GET (query string) and POST (body) selections. The first argument is None
# here, presumably meaning "not scoped to a specific dataset/dashboard" (TODO confirm).
QueryModelForGetProjectParams, QueryModelForPostProjectParams = get_query_models_for_parameters(None, param_fields)
|
|
215
|
+
|
|
216
|
+
async def get_parameters_definition(
    parameters_list: list[str] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
    user: AbstractUser, all_request_params: dict, params: dict, *, headers: dict[str, str]
) -> rm.ParametersModel:
    """
    Build the parameters response for an entity, given the user's current selections.

    Validates the raw request parameters first. It then resolves the parameters either through the
    uncached helper (when caching is disabled) or through the cachable variant, and converts the
    result into the API response model.
    """
    self._validate_request_params(all_request_params, params, headers)

    if self.no_cache:
        fetch_parameters = self._get_parameters_helper
    else:
        fetch_parameters = self._get_parameters_cachable

    # "x_verify_params" is passed as an uncached key so it does not contribute to the cached selections.
    immutable_selections = self.get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
    # A tuple (unlike a list) is hashable, which presumably matters for the cachable lookup.
    parameter_names = None if parameters_list is None else tuple(parameters_list)

    parameters_result = await fetch_parameters(
        parameter_names, entity_type, entity_name, entity_scope, user, immutable_selections
    )
    return parameters_result.to_api_response_model0()
|
|
227
|
+
|
|
228
|
+
@app.get(project_level_parameters_path, tags=["Project Metadata"], description=parameters_description)
async def get_project_parameters(
    request: Request, params: QueryModelForGetProjectParams, user=Depends(self.get_current_user) # type: ignore
) -> rm.ParametersModel:
    # GET variant: parameter selections arrive in the query string.
    started_at = time.time()

    raw_query_params = dict(request.query_params)
    parsed_selections = asdict(params)
    request_headers = dict(request.headers)

    parameters_model = await get_parameters_definition(
        None, "project", "", PermissionScope.PUBLIC, user,
        raw_query_params, parsed_selections, headers=request_headers,
    )

    self.log_activity_time("GET REQUEST for PROJECT PARAMETERS", started_at, request)
    return parameters_model
|
|
238
|
+
|
|
239
|
+
@app.post(project_level_parameters_path, tags=["Project Metadata"], description=parameters_description)
async def get_project_parameters_with_post(
    request: Request, params: QueryModelForPostProjectParams, user=Depends(self.get_current_user) # type: ignore
) -> rm.ParametersModel:
    # POST variant: parameter selections arrive in the JSON request body.
    started_at = time.time()

    raw_body: dict = await request.json()
    parsed_selections = params.model_dump()
    request_headers = dict(request.headers)

    parameters_model = await get_parameters_definition(
        None, "project", "", PermissionScope.PUBLIC, user,
        raw_body, parsed_selections, headers=request_headers,
    )

    self.log_activity_time("POST REQUEST for PROJECT PARAMETERS", started_at, request)
    return parameters_model
|
|
250
|
+
|
|
251
|
+
# Hand the parameter-definition helper back to the caller (presumably so other route modules can reuse it — verify against caller).
return get_parameters_definition
|
|
252
|
+
|