squirrels 0.5.0b3__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- squirrels/__init__.py +4 -0
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +440 -792
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/{_init_time_args.py → init_time_args.py} +23 -43
- squirrels/_arguments/{_run_time_args.py → run_time_args.py} +32 -68
- squirrels/_auth.py +590 -264
- squirrels/_command_line.py +130 -58
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +16 -15
- squirrels/_constants.py +36 -11
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +40 -34
- squirrels/_dataset_types.py +16 -11
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +9 -37
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +7 -6
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +155 -77
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +11 -55
- squirrels/_model_configs.py +5 -5
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +276 -143
- squirrels/_package_data/base_project/.env +1 -24
- squirrels/_package_data/base_project/.env.example +31 -17
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +13 -7
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +6 -6
- squirrels/_package_data/base_project/docker/Dockerfile +2 -2
- squirrels/_package_data/base_project/docker/compose.yml +1 -1
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +2 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +27 -17
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +7 -7
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +8 -3
- squirrels/_package_data/base_project/pyconfigs/context.py +26 -14
- squirrels/_package_data/base_project/pyconfigs/parameters.py +124 -81
- squirrels/_package_data/base_project/pyconfigs/user.py +48 -15
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +21 -31
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +41 -30
- squirrels/_parameters.py +560 -123
- squirrels/_project.py +487 -277
- squirrels/_py_module.py +71 -10
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +52 -13
- squirrels/_sources.py +29 -23
- squirrels/_utils.py +221 -42
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -2
- squirrels/auth.py +4 -0
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +10 -3
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -1
- squirrels/_api_response_models.py +0 -190
- squirrels/_dashboard_types.py +0 -82
- squirrels/_dashboards_io.py +0 -79
- squirrels-0.5.0b3.dist-info/METADATA +0 -110
- squirrels-0.5.0b3.dist-info/RECORD +0 -80
- /squirrels/_package_data/base_project/{assets → resources}/expenses.db +0 -0
- /squirrels/_package_data/base_project/{assets → resources}/weather.db +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/licenses/LICENSE +0 -0
squirrels/__init__.py
CHANGED
|
@@ -2,6 +2,8 @@ from ._version import __version__
|
|
|
2
2
|
|
|
3
3
|
from .arguments import *
|
|
4
4
|
|
|
5
|
+
from .auth import *
|
|
6
|
+
|
|
5
7
|
from .connections import *
|
|
6
8
|
|
|
7
9
|
from .parameter_options import *
|
|
@@ -15,3 +17,5 @@ from .dashboards import *
|
|
|
15
17
|
from .types import *
|
|
16
18
|
|
|
17
19
|
from ._project import SquirrelsProject
|
|
20
|
+
|
|
21
|
+
__all__ = ["SquirrelsProject"]
|
|
@@ -0,0 +1,337 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Authentication and user management routes
|
|
3
|
+
"""
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
import secrets
|
|
6
|
+
from typing import Annotated, Literal
|
|
7
|
+
from urllib.parse import quote
|
|
8
|
+
from fastapi import FastAPI, Depends, Request, Response, Form, APIRouter
|
|
9
|
+
from fastapi.responses import RedirectResponse, HTMLResponse
|
|
10
|
+
from fastapi.security import HTTPBearer
|
|
11
|
+
from pydantic import BaseModel, Field
|
|
12
|
+
from authlib.integrations.starlette_client import OAuth
|
|
13
|
+
|
|
14
|
+
from .. import _utils as u
|
|
15
|
+
from .._schemas import response_models as rm
|
|
16
|
+
from .._exceptions import InvalidInputError
|
|
17
|
+
from .._schemas.auth_models import AbstractUser, RegisteredUser, GuestUser, UserFieldsModel, ApiKey
|
|
18
|
+
from .._manifest import AuthStrategy
|
|
19
|
+
from .base import RouteBase
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class AuthRoutes(RouteBase):
|
|
23
|
+
"""Authentication and user management routes"""
|
|
24
|
+
|
|
25
|
+
def __init__(self, get_bearer_token: HTTPBearer, project, no_cache: bool = False):
    """
    Create the authentication route group.

    All three arguments are handed to the `RouteBase` initializer unchanged;
    this subclass keeps no additional state of its own.
    """
    super().__init__(get_bearer_token, project, no_cache=no_cache)
|
|
27
|
+
|
|
28
|
+
def setup_routes(self, app: FastAPI) -> None:
    """
    Setup all authentication routes

    Two routers are built and attached to `app`:

    - `/auth`: user session, provider listing / login / callback and logout are always
      registered. Username & password login, change password and API key routes are only
      registered when the project's auth strategy is not `AuthStrategy.EXTERNAL`.
    - `/auth/user-management`: user fields and user CRUD routes. These are skipped entirely
      (the method returns early) when the auth strategy is `AuthStrategy.EXTERNAL`.

    Args:
        app: The FastAPI application that the routers are included into
    """

    auth_path = "/auth"
    auth_router = APIRouter(prefix=auth_path)
    user_management_router = APIRouter(prefix=auth_path + "/user-management")

    auth_strategy = self.manifest_cfg.project_variables.auth_strategy
    is_external = (auth_strategy == AuthStrategy.EXTERNAL)

    # Get expiry configuration
    expiry_mins = self.env_vars.auth_token_expire_minutes

    # Create user models
    # These are declared inside the method because the custom fields model is derived from
    # the authenticator of this particular project.
    class CustomFieldsModel(self.authenticator.CustomUserFields):
        pass

    class UpdateUserModel(BaseModel):
        access_level: Literal["admin", "member"] = Field(description="The access level of the user. Admins have more permissions such as creating and updating users.")
        custom_fields: CustomFieldsModel = Field(description="User fields that are specific to this Squirrels project")

    class UserInfoModel(UpdateUserModel):
        username: str

    class AddUserModel(UserInfoModel):
        password: str

    class UserSessionModel(BaseModel):
        user: UserInfoModel
        session_expiry_timestamp: float | None

    # Setup OAuth2 login providers
    oauth = OAuth()

    # Register one OAuth client per configured provider, keyed by the provider's name
    for provider in self.authenticator.auth_providers:
        oauth.register(
            name=provider.name,
            server_metadata_url=provider.provider_configs.server_metadata_url,
            client_id=provider.provider_configs.client_id,
            client_secret=provider.provider_configs.client_secret,
            client_kwargs=provider.provider_configs.client_kwargs
        )

    # User info endpoint
    user_session_path = '/user-session'

    @auth_router.get(user_session_path, description="Get the authenticated user's fields", tags=["Authentication"])
    async def get_user_session(
        request: Request, user: RegisteredUser | GuestUser = Depends(self.get_current_user)
    ) -> UserSessionModel:
        # Guests have no user session to report
        if isinstance(user, GuestUser):
            raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token, no user info found")

        # Prefer the expiry stored in the cookie session; otherwise fall back to the value
        # that the `get_current_user` dependency placed on `request.state`
        expiry = request.session.get("access_token_expiry")
        if expiry is None:
            expiry = getattr(request.state, "access_token_expiry", None)

        user_session = UserSessionModel(
            user=user.model_dump(mode='json'),
            session_expiry_timestamp=float(expiry) if expiry is not None else None
        )
        return user_session

    # Login endpoint
    if not is_external:
        login_path = '/login'

        @auth_router.post(login_path, tags=["Authentication"], description="Authenticate with username & password. Returns user information if no redirect_url is provided, otherwise redirects to the specified URL.")
        async def login(
            request: Request, username: Annotated[str, Form()], password: Annotated[str, Form()]
        ) -> UserSessionModel:
            user = self.authenticator.get_user(username, password)

            # Store the new access token and its expiry (epoch seconds) in the session
            access_token, expiry = self.authenticator.create_access_token(user, expiry_minutes=expiry_mins)
            expiry_timestamp = expiry.timestamp()
            request.session["access_token"] = access_token
            request.session["access_token_expiry"] = expiry_timestamp

            user_session = UserSessionModel(user=user.model_dump(mode='json'), session_expiry_timestamp=expiry_timestamp)
            return user_session

    # Provider authentication endpoints
    providers_path = '/providers'
    provider_login_path = '/providers/{provider_name}/login'
    provider_callback_path = '/providers/{provider_name}/callback'

    @auth_router.get(providers_path, tags=["Authentication"])
    async def get_providers(request: Request) -> list[rm.ProviderResponse]:
        """Get list of available authentication providers"""
        _, root_path = self._get_base_url_for_current_app(request)

        def get_icon_url(icon: str) -> str:
            # Icons under "/public/" are prefixed with the part of the root path that
            # precedes "/api/"; any other icon value is returned unchanged
            if icon.startswith("/public/"):
                core_url = root_path.split("/api/")[0]
                return core_url + icon
            return icon

        return [
            rm.ProviderResponse(
                name=provider.name,
                label=provider.label,
                icon=get_icon_url(provider.icon),
                login_url=f"{root_path}{auth_path}/providers/{quote(provider.name)}/login",
            )
            for provider in self.authenticator.auth_providers
        ]

    @auth_router.get(provider_login_path, tags=["Authentication"], responses={
        307: {"description": "Redirect to sign in with provider"},
    })
    async def provider_login(request: Request, provider_name: str, redirect_url: str | None = None) -> RedirectResponse:
        """
        Redirect to the login URL for the OAuth provider.

        If login is successful, this endpoint redirects to the specified `redirect_url`. If no `redirect_url` is provided, it returns the user information of the Squirrels project's user.
        """
        client = oauth.create_client(provider_name)
        if client is None:
            raise InvalidInputError(status_code=404, error="provider_not_found", error_description=f"Provider {provider_name} not found or configured.")

        # Build the callback URL by hand (origin + root path) instead of resolving it by route name
        origin, root_path = self._get_base_url_for_current_app(request)
        callback_uri = f"{origin}{root_path}{auth_path}/providers/{quote(provider_name)}/callback"
        # Remembered in the session so the callback handler can redirect after login.
        # NOTE(review): the value comes straight from the query string and is later used
        # as a redirect target without validation — confirm whether an allow-list is needed.
        request.session["redirect_url"] = redirect_url

        # OIDC best practice: include a nonce when requesting an id_token.
        # Not all providers will use it, but major OIDC providers support it.
        nonce = secrets.token_urlsafe(24)
        request.session[f"oidc_nonce:{provider_name}"] = nonce

        # PKCE: Some providers (e.g. Keycloak) require the authorization request to include
        # `code_challenge_method=S256`. We also store the verifier so we can send it when
        # exchanging the authorization code for tokens.
        code_verifier = secrets.token_urlsafe(64)  # ~86 chars; within 43-128 PKCE range
        request.session[f"pkce_verifier:{provider_name}"] = code_verifier
        code_challenge = u.generate_pkce_challenge(code_verifier)

        return await client.authorize_redirect(
            request,
            callback_uri,
            nonce=nonce,
            code_challenge=code_challenge,
            code_challenge_method="S256",
        )

    @auth_router.get(provider_callback_path, tags=["Authentication"], responses={
        200: {"description": "HTML page indicating successful login"},
        307: {"description": "Redirect to redirect_url provided from provider login"},
    })
    async def provider_callback(request: Request, provider_name: str):
        """Handle OAuth callback from provider"""
        client = oauth.create_client(provider_name)
        if client is None:
            raise InvalidInputError(status_code=404, error="provider_not_found", error_description=f"Provider {provider_name} not found or configured.")

        try:
            # The PKCE verifier is single-use, so it is removed from the session here
            code_verifier = request.session.pop(f"pkce_verifier:{provider_name}", None)
            if code_verifier is None:
                token_details: dict = await client.authorize_access_token(request)
            else:
                token_details = await client.authorize_access_token(request, code_verifier=code_verifier)
        except Exception as e:
            # Any failure during the token exchange is reported to the client as a 400
            raise InvalidInputError(status_code=400, error="provider_authorization_failed", error_description=f"Could not authorize with provider for access token: {str(e)}")

        if is_external:
            # Prefer id_token (JWT) for session auth if available. Many providers (e.g. Google)
            # issue opaque access tokens that do not contain an issuer, which breaks provider
            # auto-detection for session-based auth.
            access_token = token_details.get("access_token")
            id_token = token_details.get("id_token")
            # A string with exactly two dots is treated as a JWT (header.payload.signature)
            if isinstance(id_token, str) and id_token and id_token.count(".") == 2:
                access_token = id_token

            if not isinstance(access_token, str) or not access_token:
                raise InvalidInputError(400, "provider_missing_access_token", f"Provider token not found for {provider_name}")

            expires_in = token_details.get("expires_in")
            if expires_in is None:
                # Fallback for providers that only return absolute expiry
                expiry_timestamp = token_details.get("expires_at")
                if expiry_timestamp is None:
                    raise InvalidInputError(400, "provider_missing_expiry", f"Provider expiry timestamp not found for {provider_name}")
            else:
                # Convert the relative lifetime into an absolute UTC epoch timestamp
                expiry_timestamp = datetime.now(timezone.utc).timestamp() + float(expires_in)
        else:
            # Non-external strategy: resolve (or create) the project user for this provider
            # identity, then issue a project access token for that user
            expected_nonce = request.session.pop(f"oidc_nonce:{provider_name}", None)
            user_info = self.authenticator.get_user_info_from_token_details(
                provider_name, token_details, expected_nonce=expected_nonce
            )
            user = self.authenticator.create_or_get_user_from_provider(provider_name, user_info)
            access_token, expiry = self.authenticator.create_access_token(user, expiry_minutes=expiry_mins)
            expiry_timestamp = expiry.timestamp()

        request.session["access_token"] = access_token
        request.session["access_token_expiry"] = expiry_timestamp

        # Redirect to the URL remembered by `provider_login`, if one was given
        redirect_url = request.session.pop("redirect_url", None)
        if redirect_url:
            return RedirectResponse(url=redirect_url)

        # Otherwise show the packaged "login successful" page
        template = self.templates.get_template("login_successful.html")
        return HTMLResponse(content=template.render({"request": request}), status_code=200)

    # Logout endpoint
    logout_path = '/logout'

    @auth_router.post(logout_path, tags=["Authentication"])
    async def logout(request: Request):
        """Logout the current user by clearing the access token and expiry from the session"""
        request.session.pop("access_token", None)
        request.session.pop("access_token_expiry", None)
        return Response(status_code=200)

    if not is_external:
        # Change password endpoint
        change_password_path = '/password'

        class ChangePasswordRequest(BaseModel):
            old_password: str
            new_password: str

        @auth_router.put(change_password_path, description="Change the password for the current user", tags=["Authentication"])
        async def change_password(request: ChangePasswordRequest, user: RegisteredUser | GuestUser = Depends(self.get_current_user)) -> None:
            # Note: `request` here is the parsed JSON body, not the HTTP request object
            if isinstance(user, GuestUser):
                raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token")
            self.authenticator.change_password(user.username, request.old_password, request.new_password)

        # API Key endpoints
        api_key_path = '/api-keys'

        class ApiKeyRequestBody(BaseModel):
            title: str = Field(description="The title of the API key")
            expiry_minutes: int | None = Field(
                default=None,
                description="The number of minutes the API key is valid for (or valid indefinitely if not provided)."
            )

        @auth_router.post(api_key_path, description="Create a new API key for the user", tags=["Authentication"])
        async def create_api_key(body: ApiKeyRequestBody, user: RegisteredUser | GuestUser = Depends(self.get_current_user)) -> rm.ApiKeyResponse:
            if isinstance(user, GuestUser):
                raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token, cannot create API key")

            # An API key is an access token created with a title; the returned expiry is not used here
            api_key, _ = self.authenticator.create_access_token(user, expiry_minutes=body.expiry_minutes, title=body.title)
            return rm.ApiKeyResponse(api_key=api_key)

        @auth_router.get(api_key_path, description="Get all API keys with title for the current user", tags=["Authentication"])
        async def get_all_api_keys(user: RegisteredUser | GuestUser = Depends(self.get_current_user)) -> list[ApiKey]:
            if isinstance(user, GuestUser):
                raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token, cannot get API keys")
            return self.authenticator.get_all_api_keys(user.username)

        revoke_api_key_path = '/api-keys/{key_id}'

        @auth_router.delete(revoke_api_key_path, description="Revoke an API key", tags=["Authentication"], responses={
            204: { "description": "API key revoked successfully" }
        })
        async def revoke_api_key(key_id: str, user: RegisteredUser | GuestUser = Depends(self.get_current_user)) -> Response:
            if isinstance(user, GuestUser):
                raise InvalidInputError(401, "invalid_authorization_token", "Invalid authorization token, cannot revoke API key")
            self.authenticator.revoke_api_key(user.username, key_id)
            return Response(status_code=204)

    app.include_router(auth_router)

    # User management endpoints (disabled if external auth only)
    if is_external:
        return

    user_fields_path = '/user-fields'

    @user_management_router.get(user_fields_path, description="Get details of the user fields", tags=["User Management"])
    async def get_user_fields() -> UserFieldsModel:
        return self.authenticator.user_fields

    list_or_add_users_path = '/users'
    update_or_delete_user_path = '/users/{username}'

    # Every route below requires the current user to have the "admin" access level

    @user_management_router.get(list_or_add_users_path, tags=["User Management"])
    async def list_all_users(user: AbstractUser = Depends(self.get_current_user)) -> list[UserInfoModel]:
        if user.access_level != "admin":
            raise InvalidInputError(403, "unauthorized_to_list_users", "Current user does not have permission to list users")
        return self.authenticator.get_all_users()

    @user_management_router.post(list_or_add_users_path, description="Add a new user by providing details for username, password, and user fields", tags=["User Management"])
    async def add_user(
        new_user: AddUserModel, user: AbstractUser = Depends(self.get_current_user)
    ) -> UserInfoModel:
        if user.access_level != "admin":
            raise InvalidInputError(403, "unauthorized_to_add_user", "Current user does not have permission to add new users")
        # The username is passed separately, so it is excluded from the dumped fields
        return self.authenticator.add_user(new_user.username, new_user.model_dump(mode='json', exclude={"username"}))

    @user_management_router.put(update_or_delete_user_path, description="Update the user of the given username given the new user details", tags=["User Management"])
    async def update_user(
        username: str, updated_user: UpdateUserModel, user: AbstractUser = Depends(self.get_current_user)
    ) -> UserInfoModel:
        if user.access_level != "admin":
            raise InvalidInputError(403, "unauthorized_to_update_user", "Current user does not have permission to update users")
        # Updates go through the authenticator's `add_user` with `update_user=True`
        return self.authenticator.add_user(username, updated_user.model_dump(mode='json'), update_user=True)

    @user_management_router.delete(update_or_delete_user_path, tags=["User Management"], responses={
        204: { "description": "User deleted successfully" }
    })
    async def delete_user(username: str, user: AbstractUser = Depends(self.get_current_user)) -> Response:
        if user.access_level != "admin":
            raise InvalidInputError(403, "unauthorized_to_delete_user", "Current user cannot delete users")
        # Admins may not delete the account they are currently using
        if username == user.username:
            raise InvalidInputError(403, "cannot_delete_own_user", "Cannot delete your own user")
        self.authenticator.delete_user(username)
        return Response(status_code=204)

    app.include_router(user_management_router)
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base utilities and dependencies for API routes
|
|
3
|
+
"""
|
|
4
|
+
from typing import Any, Mapping, TypeVar, Callable, Coroutine, Literal
|
|
5
|
+
from textwrap import dedent
|
|
6
|
+
from fastapi import Request, Response, Depends, Header
|
|
7
|
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
8
|
+
from fastapi.templating import Jinja2Templates
|
|
9
|
+
from cachetools import TTLCache
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from datetime import datetime, timezone
|
|
12
|
+
|
|
13
|
+
from .. import _utils as u
|
|
14
|
+
from .._exceptions import InvalidInputError
|
|
15
|
+
from .._project import SquirrelsProject
|
|
16
|
+
from .._schemas.auth_models import AbstractUser
|
|
17
|
+
from .._dataset_types import DatasetResultFormat
|
|
18
|
+
from .._manifest import AuthType, AuthStrategy
|
|
19
|
+
|
|
20
|
+
T = TypeVar('T')
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class RouteBase:
|
|
24
|
+
"""Base class for route modules providing common functionality"""
|
|
25
|
+
|
|
26
|
+
def __init__(self, get_bearer_token: HTTPBearer, project: SquirrelsProject, no_cache: bool = False):
    """
    Capture the project's collaborators and build the authentication dependencies.

    After construction the instance exposes `get_user_from_headers` and
    `get_current_user` (a FastAPI dependency) as attributes.

    Args:
        get_bearer_token: FastAPI security dependency that reads the `Authorization` header
        project: The project whose logger, env vars, manifest config, authenticator and
            parameter config set are exposed on this instance
        no_cache: Stored as `self.no_cache`; it is not otherwise used by this initializer
    """
    self.project = project
    self.no_cache = no_cache
    self.logger = project._logger
    self.env_vars = project._env_vars
    self.manifest_cfg = project._manifest_cfg
    self.authenticator = project._auth
    self.param_cfg_set = project._param_cfg_set

    # Setup templates
    template_dir = Path(__file__).parent.parent / "_package_data" / "templates"
    self.templates = Jinja2Templates(directory=str(template_dir))

    # Authorization dependency for current user
    def get_token_from_session(request: Request) -> str | None:
        """
        Return the access token stored in the session if it has not expired.

        Raises:
            InvalidInputError: 401 if a token is present but expired (or has no expiry)
                and the project's auth type is `AuthType.REQUIRED`
        """
        access_token = request.session.get("access_token")
        if access_token is None:
            return None

        expiry = request.session.get("access_token_expiry")
        datetime_now = datetime.now(timezone.utc).timestamp()
        if expiry and expiry > datetime_now:
            return access_token

        # A token exists but is no longer valid: only fail hard when auth is mandatory
        if self.manifest_cfg.project_variables.auth_type == AuthType.REQUIRED:
            raise InvalidInputError(401, "session_expired", "Login session expired. Please login again.")
        else:
            return None

    def get_user_from_headers(api_key: str | None, bearer_token: str | None) -> tuple[AbstractUser, float | None]:
        """
        Resolve the user (and token expiry) for the given credentials.

        With the external auth strategy only the bearer token is considered. Otherwise
        the API key takes precedence over the bearer token. The project's guest user is
        returned whenever no user can be resolved.
        """
        auth_strategy = self.manifest_cfg.project_variables.auth_strategy
        if auth_strategy == AuthStrategy.EXTERNAL:
            if not bearer_token:
                return self.project._guest_user, None

            user, expiry = self.authenticator.get_user_from_external_token(bearer_token)
        else:
            final_token = api_key if api_key else bearer_token
            user, expiry = self.authenticator.get_user_from_token(final_token)

        if user is None:
            user = self.project._guest_user

        return user, expiry

    async def get_current_user(
        request: Request, response: Response,
        x_api_key: str | None = Header(None, description="API key for authentication (alternative to Authorization header)"),
        auth: HTTPAuthorizationCredentials = Depends(get_bearer_token)
    ) -> AbstractUser:
        """
        FastAPI dependency returning the user for the current request.

        The bearer token from the `Authorization` header is preferred; the session token
        is only consulted when no bearer token was sent. As side effects, the resolved
        username is written to the `Applied-Username` response header and the token
        expiry is stored on `request.state.access_token_expiry`.
        """
        # HTTP authentication scheme names are case-insensitive and the security
        # dependency passes the scheme through as the client sent it, so compare
        # case-insensitively (otherwise "bearer <token>" would be silently ignored).
        token = auth.credentials if auth and auth.scheme.lower() == "bearer" else None
        access_token = token if token else get_token_from_session(request)
        user, expiry = get_user_from_headers(x_api_key, access_token)
        response.headers["Applied-Username"] = user.username
        request.state.access_token_expiry = expiry
        return user

    self.get_user_from_headers = get_user_from_headers
    self.get_current_user = get_current_user
|
|
85
|
+
|
|
86
|
+
@staticmethod
|
|
87
|
+
def _get_base_url_for_current_app(request: Request) -> tuple[str, str]:
|
|
88
|
+
"""
|
|
89
|
+
Build the absolute base URL for the *current* mounted app, including `root_path`.
|
|
90
|
+
|
|
91
|
+
We avoid `request.url_for(...)` because route names can collide when multiple Squirrels
|
|
92
|
+
FastAPI apps are mounted into the same root app.
|
|
93
|
+
"""
|
|
94
|
+
origin = f"{request.url.scheme}://{request.url.netloc}"
|
|
95
|
+
root_path = str(request.scope.get("root_path") or "").rstrip("/")
|
|
96
|
+
return origin, root_path
|
|
97
|
+
|
|
98
|
+
@property
|
|
99
|
+
def _parameters_description(self) -> str:
|
|
100
|
+
"""Get the standard parameters description"""
|
|
101
|
+
return dedent("""
|
|
102
|
+
Selections of one parameter may cascade the available options in another parameter.
|
|
103
|
+
|
|
104
|
+
For example, if the dataset has parameters for 'country' and 'city', available options for 'city' would depend on the selected option 'country'.
|
|
105
|
+
|
|
106
|
+
If a parameter has `"trigger_refresh": true` and its selection changes, provide the parameter selection to this endpoint to refresh the parameter options of children parameters.
|
|
107
|
+
""").strip()
|
|
108
|
+
|
|
109
|
+
def get_selections_as_immutable(self, params: Mapping, uncached_keys: set[str]) -> tuple[tuple[str, Any], ...]:
    """
    Convert selections into a cachable tuple of `(normalized_name, value)` pairs.

    Entries whose key is in `uncached_keys` or whose value is `None` are left out.
    List and tuple values are converted to tuples; the original iteration order of
    `params` is kept.
    """
    def freeze(value: Any) -> Any:
        # Lists are unhashable, so sequence values are stored as tuples
        return tuple(value) if isinstance(value, (list, tuple)) else value

    return tuple(
        (u.normalize_name(name), freeze(value))
        for name, value in params.items()
        if name not in uncached_keys and value is not None
    )
|
|
119
|
+
|
|
120
|
+
async def do_cachable_action(self, cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
|
|
121
|
+
"""Execute a cachable action"""
|
|
122
|
+
cache_key = tuple(args)
|
|
123
|
+
result = cache.get(cache_key)
|
|
124
|
+
if result is None:
|
|
125
|
+
result = await action(*args)
|
|
126
|
+
cache[cache_key] = result
|
|
127
|
+
return result
|
|
128
|
+
|
|
129
|
+
def get_name_from_path_section(self, request: Request, section: int) -> str:
    """
    Return the normalized name found at position `section` of the request path.

    The path is split on "/" and `section` indexes into the resulting list.
    """
    path_segments: list[str] = request.url.path.split('/')
    return u.normalize_name(path_segments[section])
|
|
134
|
+
|
|
135
|
+
def get_configurables_from_headers(self, headers: Mapping[str, str]) -> tuple[tuple[str, str], ...]:
    """
    Collect configurables from the request headers that start with 'x-config-'.

    Header names are matched case-insensitively and the part after the prefix is
    normalized, so different spellings of the same configurable map to one name.

    Returns:
        A tuple of `(normalized_name, value)` pairs in header order

    Raises:
        InvalidInputError: 400 if two headers resolve to the same configurable name
    """
    prefix = "x-config-"
    header_by_name: dict[str, str] = {}  # normalized_name -> header_name
    collected: list[tuple[str, str]] = []

    for raw_key, raw_value in headers.items():
        header_name = str(raw_key).lower()
        if not header_name.startswith(prefix):
            continue

        # Convert to underscore convention
        name = u.normalize_name(header_name[len(prefix):])

        # Reject a second header that resolves to an already-seen configurable
        if name in header_by_name:
            first_header = header_by_name[name]
            raise InvalidInputError(
                400, "duplicate_configurable_header",
                f"Only one header format is allowed for configurable '{name}'. "
                f"Both '{first_header}' and '{header_name}' were provided."
            )

        header_by_name[name] = header_name
        collected.append((name, str(raw_value)))

    configurables = [name for name, _ in collected]
    self.logger.info(f"Configurables specified: {configurables}", data={"configurables_specified": configurables})
    return tuple(collected)
|
|
162
|
+
|
|
163
|
+
@staticmethod
def extract_orientation_offset_and_limit(
    params: Mapping[str, Any], *,
    key_prefix: str = "x_",
    default_orientation: Literal["records", "rows", "columns"] = "records",
    default_offset: int = 0, default_limit: int = 1000
) -> DatasetResultFormat:
    """
    Extract orientation, offset, and limit from query parameters.

    Args:
        params: Query parameters
        key_prefix: Prefix of the keys to read ("<prefix>orientation", "<prefix>offset", "<prefix>limit")
        default_orientation: Orientation used when the key is absent
        default_offset: Offset used when the key is absent
        default_limit: Limit used when the key is absent

    Returns:
        A DatasetResultFormat built from (orientation, offset, limit)

    Raises:
        InvalidInputError: 400 if the orientation is not one of 'records', 'rows' or
            'columns', or if offset / limit is not an integer or is negative
    """
    def to_int(key: str, default: int, error: str, label: str) -> int:
        # Report unparsable values as a client error instead of letting the raw
        # ValueError / TypeError from int() escape
        raw = params.get(f"{key_prefix}{key}", default)
        try:
            return int(raw)
        except (TypeError, ValueError) as err:
            raise InvalidInputError(400, error, f"{label} must be an integer. Invalid {key} provided: {raw}") from err

    # Handle orientation (case-insensitive)
    orientation = str(params.get(f"{key_prefix}orientation", default_orientation)).lower()

    if orientation not in ("records", "rows", "columns"):
        raise InvalidInputError(400, "invalid_orientation", f"Orientation must be 'records', 'rows', or 'columns'. Invalid orientation provided: {orientation}")

    # Handle limit and offset
    offset = to_int("offset", default_offset, "invalid_offset", "Offset")
    limit = to_int("limit", default_limit, "invalid_limit", "Limit")

    if offset < 0:
        raise InvalidInputError(400, "invalid_offset", "Offset must be non-negative")

    if limit < 0:
        raise InvalidInputError(400, "invalid_limit", "Limit must be non-negative")

    return DatasetResultFormat(orientation, offset, limit)
|
|
196
|
+
|