squirrels 0.5.0b3__py3-none-any.whl → 0.5.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- squirrels/__init__.py +2 -0
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +262 -0
- squirrels/_api_routes/base.py +154 -0
- squirrels/_api_routes/dashboards.py +142 -0
- squirrels/_api_routes/data_management.py +103 -0
- squirrels/_api_routes/datasets.py +242 -0
- squirrels/_api_routes/oauth2.py +300 -0
- squirrels/_api_routes/project.py +214 -0
- squirrels/_api_server.py +142 -745
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/{_init_time_args.py → init_time_args.py} +5 -0
- squirrels/_arguments/{_run_time_args.py → run_time_args.py} +1 -1
- squirrels/_auth.py +645 -92
- squirrels/_connection_set.py +1 -1
- squirrels/_constants.py +6 -0
- squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
- squirrels/_exceptions.py +9 -37
- squirrels/_model_builder.py +1 -1
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +13 -12
- squirrels/_package_data/base_project/.env +1 -0
- squirrels/_package_data/base_project/.env.example +1 -0
- squirrels/_package_data/base_project/pyconfigs/parameters.py +84 -76
- squirrels/_package_data/base_project/pyconfigs/user.py +30 -2
- squirrels/_package_data/templates/dataset_results.html +112 -0
- squirrels/_package_data/templates/oauth_login.html +271 -0
- squirrels/_parameter_configs.py +1 -1
- squirrels/_parameter_sets.py +31 -21
- squirrels/_parameters.py +521 -123
- squirrels/_project.py +43 -24
- squirrels/_py_module.py +3 -2
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +144 -0
- squirrels/_schemas/query_param_models.py +67 -0
- squirrels/{_api_response_models.py → _schemas/response_models.py} +12 -8
- squirrels/_utils.py +34 -2
- squirrels/arguments.py +2 -2
- squirrels/auth.py +1 -0
- squirrels/dashboards.py +1 -1
- squirrels/types.py +3 -3
- {squirrels-0.5.0b3.dist-info → squirrels-0.5.0b4.dist-info}/METADATA +4 -1
- {squirrels-0.5.0b3.dist-info → squirrels-0.5.0b4.dist-info}/RECORD +46 -32
- squirrels/_dashboard_types.py +0 -82
- {squirrels-0.5.0b3.dist-info → squirrels-0.5.0b4.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.5.0b4.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.5.0b4.dist-info}/licenses/LICENSE +0 -0
squirrels/__init__.py
CHANGED
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Authentication and user management routes
|
|
3
|
+
"""
|
|
4
|
+
from typing import Annotated, Callable
|
|
5
|
+
from fastapi import FastAPI, Depends, Request, Response, status, Form, APIRouter
|
|
6
|
+
from fastapi.responses import RedirectResponse
|
|
7
|
+
from fastapi.security import HTTPBearer
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
from authlib.integrations.starlette_client import OAuth
|
|
10
|
+
|
|
11
|
+
from .. import _constants as c
|
|
12
|
+
from .._schemas import response_models as rm
|
|
13
|
+
from .._exceptions import InvalidInputError
|
|
14
|
+
from .._auth import BaseUser
|
|
15
|
+
from .base import RouteBase
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AuthRoutes(RouteBase):
    """Authentication and user management routes.

    Registers two routers on the FastAPI app: one under "/api/auth" (login, logout,
    OAuth providers, password change, API keys) and one under
    "/api/auth/user-management" (user fields and user CRUD).
    """

    def __init__(self, get_bearer_token: HTTPBearer, project, no_cache: bool = False):
        super().__init__(get_bearer_token, project, no_cache)

    def setup_routes(self, app: FastAPI) -> None:
        """Define every authentication / user-management endpoint and attach them to ``app``."""

        auth_router = APIRouter(prefix="/api/auth")
        user_management_router = APIRouter(prefix="/api/auth/user-management")

        # Get expiry configuration
        expiry_mins = self._get_access_token_expiry_minutes()

        # Create user models (each one extends the previous with one more field)
        class UpdateUserModel(self.UserModel):
            is_admin: bool

        class UserInfoModel(UpdateUserModel):
            username: str

        class AddUserModel(UserInfoModel):
            password: str

        # Setup OAuth2 login providers
        oauth = OAuth()

        for provider in self.authenticator.auth_providers:
            oauth.register(
                name=provider.name,
                server_metadata_url=provider.provider_configs.server_metadata_url,
                client_id=provider.provider_configs.client_id,
                client_secret=provider.provider_configs.client_secret,
                client_kwargs=provider.provider_configs.client_kwargs
            )

        # User info endpoint
        @auth_router.get("/userinfo", description="Get the authenticated user's fields", tags=["Authentication"])
        async def get_userinfo(user: UserInfoModel | None = Depends(self.get_current_user)) -> UserInfoModel:
            if user is None:
                raise InvalidInputError(401, "Invalid authorization token", "Invalid authorization token")
            return user

        # Login helper
        def login_helper(
            request: Request, user: BaseUser, redirect_url: str | None, *,
            redirect_status_code: int = status.HTTP_307_TEMPORARY_REDIRECT
        ):
            """Create an access token for ``user``, store it (and its expiry) in the session,
            then redirect to ``redirect_url`` if one is given, otherwise return the user."""
            access_token, expiry = self.authenticator.create_access_token(user, expiry_minutes=expiry_mins)
            request.session["access_token"] = access_token
            request.session["access_token_expiry"] = expiry.timestamp()
            return RedirectResponse(url=redirect_url, status_code=redirect_status_code) if redirect_url else user

        # Login endpoints
        @auth_router.post("/login", tags=["Authentication"], description="Authenticate with username and password. Returns user information if no redirect_url is provided, otherwise redirects to the specified URL.", responses={
            200: {"model": UserInfoModel, "description": "Login successful, returns user information"},
            302: {"description": "Redirect if redirect URL parameter is specified"},
        })
        async def login(request: Request, username: Annotated[str, Form()], password: Annotated[str, Form()], redirect_url: str | None = None):
            user = self.authenticator.get_user(username, password)
            return login_helper(request, user, redirect_url, redirect_status_code=status.HTTP_302_FOUND)

        @auth_router.get("/login", tags=["Authentication"], description="Authenticate with an existing API key or session token. Returns user information if no redirect_url is provided, otherwise redirects to the specified URL.", responses={
            200: {"model": UserInfoModel, "description": "Login successful, returns user information"},
            307: {"description": "Redirect if redirect URL parameter is specified"},
        })
        async def login_with_api_key(
            request: Request, redirect_url: str | None = None, user: UserInfoModel | None = Depends(self.get_current_user)
        ):
            if user is None:
                raise InvalidInputError(401, "Invalid authorization token", "Invalid authorization token")
            return login_helper(request, user, redirect_url)

        # Provider authentication endpoints
        providers_path = '/providers'
        provider_login_path = '/providers/{provider_name}/login'
        provider_callback_path = '/providers/{provider_name}/callback'

        @auth_router.get(providers_path, tags=["Authentication"])
        async def get_providers(request: Request) -> list[rm.ProviderResponse]:
            """Get list of available authentication providers"""
            return [
                rm.ProviderResponse(
                    name=provider.name,
                    label=provider.label,
                    icon=provider.icon,
                    login_url=str(request.url_for('provider_login', provider_name=provider.name))
                )
                for provider in self.authenticator.auth_providers
            ]

        @auth_router.get(provider_login_path, tags=["Authentication"])
        async def provider_login(request: Request, provider_name: str, redirect_url: str | None = None) -> RedirectResponse:
            """Get OAuth login URL for the provider"""
            client = oauth.create_client(provider_name)
            if client is None:
                raise InvalidInputError(status_code=404, error="provider_not_found", error_description=f"Provider {provider_name} not found or configured.")

            callback_uri = str(request.url_for('provider_callback', provider_name=provider_name))
            # Remembered in the session so the callback endpoint can redirect after login
            request.session["redirect_url"] = redirect_url

            return await client.authorize_redirect(request, callback_uri)

        @auth_router.get(provider_callback_path, tags=["Authentication"], responses={
            200: {"model": UserInfoModel, "description": "Login successful, returns user information"},
            # RedirectResponse is created with its default status code (307) by login_helper
            307: {"description": "Redirect if redirect_url is in session"},
        })
        async def provider_callback(request: Request, provider_name: str):
            """Handle OAuth callback from provider"""
            client = oauth.create_client(provider_name)
            if client is None:
                raise InvalidInputError(status_code=404, error="provider_not_found", error_description=f"Provider {provider_name} not found or configured.")

            try:
                token = await client.authorize_access_token(request)
            except Exception as e:
                # Keep the provider error as the cause so it shows up in tracebacks
                raise InvalidInputError(status_code=400, error="provider_authorization_failed", error_description=f"Could not authorize with provider for access token: {str(e)}") from e

            # Prefer the 'userinfo' entry; fall back to a decoded 'id_token' that carries a 'sub' claim.
            # A missing/empty token is rejected too instead of continuing with empty user info.
            if token and 'userinfo' in token:
                user_info: dict = token['userinfo']
            elif token and 'id_token' in token and isinstance(token['id_token'], dict) and 'sub' in token['id_token']:
                user_info = token['id_token']
            else:
                raise InvalidInputError(status_code=400, error="invalid_provider_user_info", error_description=f"User information not found in token for {provider_name}")

            user = self.authenticator.create_or_get_user_from_provider(provider_name, user_info)
            redirect_url = request.session.pop("redirect_url", None)
            return login_helper(request, user, redirect_url)

        # Logout endpoint
        logout_path = '/logout'

        @auth_router.get(logout_path, tags=["Authentication"], responses={
            200: {"description": "Logout successful"},
            # RedirectResponse below uses its default status code (307)
            307: {"description": "Redirect if redirect URL parameter is specified"},
        })
        async def logout(request: Request, redirect_url: str | None = None):
            request.session.pop("access_token", None)
            request.session.pop("access_token_expiry", None)
            if redirect_url:
                return RedirectResponse(url=redirect_url)

        # Change password endpoint
        change_password_path = '/change-password'

        class ChangePasswordRequest(BaseModel):
            old_password: str
            new_password: str

        @auth_router.put(change_password_path, description="Change the password for the current user", tags=["Authentication"])
        async def change_password(request: ChangePasswordRequest, user: UserInfoModel | None = Depends(self.get_current_user)) -> None:
            if user is None:
                raise InvalidInputError(401, "Invalid authorization token", "Invalid authorization token")
            self.authenticator.change_password(user.username, request.old_password, request.new_password)

        # API Key endpoints
        api_key_path = '/api-key'

        class ApiKeyRequestBody(BaseModel):
            title: str = Field(description="The title of the API key")
            expiry_minutes: int | None = Field(
                default=None,
                description="The number of minutes the API key is valid for (or valid indefinitely if not provided)."
            )

        @auth_router.post(api_key_path, description="Create a new API key for the user", tags=["Authentication"])
        async def create_api_key(body: ApiKeyRequestBody, user: UserInfoModel | None = Depends(self.get_current_user)) -> rm.ApiKeyResponse:
            if user is None:
                raise InvalidInputError(401, "Invalid authorization token", "Invalid authorization token")

            api_key, _ = self.authenticator.create_access_token(user, expiry_minutes=body.expiry_minutes, title=body.title)
            return rm.ApiKeyResponse(api_key=api_key)

        @auth_router.get(api_key_path, description="Get all API keys with title for the current user", tags=["Authentication"])
        async def get_all_api_keys(user: UserInfoModel | None = Depends(self.get_current_user)):
            if user is None:
                raise InvalidInputError(401, "Invalid authorization token", "Invalid authorization token")
            return self.authenticator.get_all_api_keys(user.username)

        revoke_api_key_path = '/api-key/{api_key_id}'

        @auth_router.delete(revoke_api_key_path, description="Revoke an API key", tags=["Authentication"], responses={
            204: { "description": "API key revoked successfully" }
        })
        async def revoke_api_key(api_key_id: str, user: UserInfoModel | None = Depends(self.get_current_user)) -> Response:
            if user is None:
                raise InvalidInputError(401, "Invalid authorization token", "Invalid authorization token")
            self.authenticator.revoke_api_key(user.username, api_key_id)
            return Response(status_code=204)

        # User management endpoints
        user_fields_path = '/user-fields'

        @user_management_router.get(user_fields_path, description="Get details of the user fields", tags=["User Management"])
        async def get_user_fields():
            return self.authenticator.user_fields

        add_user_path = '/users'

        @user_management_router.post(add_user_path, description="Add a new user by providing details for username, password, and user fields", tags=["User Management"])
        async def add_user(
            new_user: AddUserModel, user: UserInfoModel | None = Depends(self.get_current_user)
        ) -> None:
            if user is None or not user.is_admin:
                raise InvalidInputError(403, "Forbidden to add user", "Authorized user is forbidden to add new users")
            self.authenticator.add_user(new_user.username, new_user.model_dump(mode='json', exclude={"username"}))

        update_user_path = '/users/{username}'

        @user_management_router.put(update_user_path, description="Update the user of the given username given the new user details", tags=["User Management"])
        async def update_user(
            username: str, updated_user: UpdateUserModel, user: UserInfoModel | None = Depends(self.get_current_user)
        ) -> None:
            if user is None or not user.is_admin:
                raise InvalidInputError(403, "Forbidden to update user", "Authorized user is forbidden to update users")
            self.authenticator.add_user(username, updated_user.model_dump(mode='json'), update_user=True)

        list_users_path = '/users'

        # NOTE(review): unlike add/update/delete, this endpoint performs no authentication or
        # admin check — confirm that listing all users is meant to be publicly accessible.
        @user_management_router.get(list_users_path, tags=["User Management"])
        async def list_all_users():
            return self.authenticator.get_all_users()

        delete_user_path = '/users/{username}'

        @user_management_router.delete(delete_user_path, tags=["User Management"], responses={
            204: { "description": "User deleted successfully" }
        })
        async def delete_user(username: str, user: UserInfoModel | None = Depends(self.get_current_user)) -> Response:
            if user is None or not user.is_admin:
                raise InvalidInputError(403, "Forbidden to delete user", "Authorized user is forbidden to delete users")
            if username == user.username:
                raise InvalidInputError(403, "Cannot delete your own user", "Cannot delete your own user")
            self.authenticator.delete_user(username)
            return Response(status_code=204)

        app.include_router(auth_router)
        app.include_router(user_management_router)
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base utilities and dependencies for API routes
|
|
3
|
+
"""
|
|
4
|
+
from typing import Any, Mapping, TypeVar, Callable, Coroutine
|
|
5
|
+
from fastapi import Request, Response, Depends
|
|
6
|
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
7
|
+
from fastapi.templating import Jinja2Templates
|
|
8
|
+
from cachetools import TTLCache
|
|
9
|
+
from pydantic import BaseModel, create_model
|
|
10
|
+
from mcp.server.fastmcp import Context
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from datetime import datetime, timezone
|
|
13
|
+
|
|
14
|
+
from .. import _utils as u, _constants as c
|
|
15
|
+
from .._exceptions import InvalidInputError, ConfigurationError
|
|
16
|
+
from .._project import SquirrelsProject
|
|
17
|
+
|
|
18
|
+
T = TypeVar('T')
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class RouteBase:
    """Base class for route modules providing common functionality.

    Holds references to the project's logger, env vars, manifest config, authenticator and
    parameter config set, builds the user models, and exposes ``get_current_user`` as a
    FastAPI dependency.
    """

    # Marks "key not present" in do_cachable_action so that a cached None is still a cache hit
    _CACHE_MISS = object()

    def __init__(self, get_bearer_token: HTTPBearer, project: SquirrelsProject, no_cache: bool = False):
        self.project = project
        self.no_cache = no_cache
        self.logger = project._logger
        self.env_vars = project._env_vars
        self.manifest_cfg = project._manifest_cfg
        self.authenticator = project._auth
        self.param_cfg_set = project._param_cfg_set

        # Setup templates
        template_dir = Path(__file__).parent.parent / "_package_data" / "templates"
        self.templates = Jinja2Templates(directory=str(template_dir))

        # Create user models: UserModel has every field of the authenticator's User except
        # "username"; UserInfoModel adds "username" back as a required string
        fields_without_username = {
            k: (v.annotation, v.default)
            for k, v in self.authenticator.User.model_fields.items()
            if k != "username"
        }
        self.UserModel = create_model("UserModel", __base__=BaseModel, **fields_without_username) # type: ignore
        self.UserInfoModel = create_model("UserInfoModel", __base__=self.UserModel, username=str)

        class UserInfoModel(self.UserInfoModel):
            username: str

            def __hash__(self):
                # Hash on username only, so instances can be used in hashed containers / keys
                return hash(self.username)

        # Authorization dependency for current user
        def get_token_from_session(request: Request) -> str | None:
            """Return the session's access token if unexpired; otherwise clear it and return None."""
            expiry = request.session.get("access_token_expiry")
            datetime_now = datetime.now(timezone.utc).timestamp()
            if expiry and expiry > datetime_now:
                return request.session.get("access_token")
            else:
                request.session.pop("access_token", None)
                request.session.pop("access_token_expiry", None)
                return None

        async def get_current_user(
            request: Request, response: Response, auth: HTTPAuthorizationCredentials = Depends(get_bearer_token)
        ) -> UserInfoModel | None:
            # A Bearer token takes precedence over the session token
            token = auth.credentials if auth and auth.scheme == "Bearer" else None
            final_token = token if token else get_token_from_session(request)
            user = self.authenticator.get_user_from_token(final_token)
            username = "" if user is None else user.username
            response.headers["Applied-Username"] = username
            return UserInfoModel(**user.model_dump(mode='json')) if user else None

        self.get_current_user = get_current_user

    @property
    def _parameters_description(self) -> str:
        """Get the standard parameters description"""
        return "Selections of one parameter may cascade the available options in another parameter. " \
            "For example, if the dataset has parameters for 'country' and 'city', available options for 'city' would " \
            "depend on the selected option 'country'. If a parameter has 'trigger_refresh' as true, provide the parameter " \
            "selection to this endpoint whenever it changes to refresh the parameter options of children parameters."

    def _validate_request_params(self, all_request_params: Mapping, params: Mapping) -> None:
        """Validate request parameters.

        Only when ``params`` has a truthy "x_verify_params" entry: raises InvalidInputError (400)
        listing every key of ``all_request_params`` that is not a key of ``params``.
        """
        if params.get("x_verify_params", False):
            invalid_params = [param for param in all_request_params if param not in params]
            if invalid_params:
                raise InvalidInputError(400, "Invalid query parameters", f"Invalid query parameters: {', '.join(invalid_params)}")

    def get_selections_as_immutable(self, params: Mapping, uncached_keys: set[str]) -> tuple[tuple[str, Any], ...]:
        """Convert selections into a cachable tuple of pairs.

        Entries whose key is in ``uncached_keys`` or whose value is None are skipped. Keys are
        passed through ``u.normalize_name``; list/tuple values become tuples, except that a
        single-element sequence is unwrapped to its element.
        """
        selections = list()
        for key, val in params.items():
            if key in uncached_keys or val is None:
                continue
            if isinstance(val, (list, tuple)):
                if len(val) == 1: # for backward compatibility
                    val = val[0]
                else:
                    val = tuple(val)
            selections.append((u.normalize_name(key), val))
        return tuple(selections)

    async def do_cachable_action(self, cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
        """Execute a cachable action.

        Looks up ``args`` (as a tuple) in ``cache``; on a miss, awaits ``action(*args)``, stores
        the result under that key and returns it. A stored result of None counts as a hit, so
        actions that return None are not re-run on every call.
        """
        cache_key = tuple(args)
        result = cache.get(cache_key, self._CACHE_MISS)
        if result is self._CACHE_MISS:
            result = await action(*args)
            cache[cache_key] = result
        return result

    def _get_request_id(self, request: Request) -> str:
        """Get request ID from the "x-request-id" header (empty string if absent)"""
        return request.headers.get("x-request-id", "")

    def log_activity_time(self, activity: str, start_time: float, request: Request) -> None:
        """Log activity time through the project logger, tagged with the request ID"""
        self.logger.log_activity_time(activity, start_time, request_id=self._get_request_id(request))

    def get_name_from_path_section(self, request: Request, section: int) -> str:
        """Extract name from request path section.

        Splits the matched route's path template on '/' and returns the normalized segment at
        index ``section`` (negative indices count from the end).
        """
        url_path: str = request.scope['route'].path
        name_raw = url_path.split('/')[section]
        return u.normalize_name(name_raw)

    def _get_access_token_expiry_minutes(self) -> int:
        """Get access token expiry minutes from the env vars (defaults to 30).

        Raises ConfigurationError if the configured value cannot be parsed as an integer.
        """
        expiry_mins = self.env_vars.get(c.SQRL_AUTH_TOKEN_EXPIRE_MINUTES, 30)
        try:
            expiry_mins = int(expiry_mins)
        except ValueError as err:
            raise ConfigurationError(f"Value for environment variable {c.SQRL_AUTH_TOKEN_EXPIRE_MINUTES} is not an integer, got: {expiry_mins}") from err
        return expiry_mins

    def get_user_from_tool_ctx(self, tool_ctx: Context):
        """Resolve the user from the 'Authorization' header of the tool context's request.

        Returns None when the header is missing or empty; raises ValueError when it is not of
        the form 'Bearer <token>'.
        """
        request = tool_ctx.request_context.request
        assert request is not None and hasattr(request, "headers")
        headers: dict[str, str] = request.headers
        # Check if 'Authorization' header is present
        authorization_header = headers.get('Authorization')

        if not authorization_header:
            return None

        # Split the header into 'Bearer <token>'
        parts = authorization_header.split()
        if len(parts) != 2 or parts[0] != 'Bearer':
            raise ValueError("Invalid Authorization header format")

        return self.authenticator.get_user_from_token(parts[1])
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Dashboard routes for parameters and results
|
|
3
|
+
"""
|
|
4
|
+
from typing import Callable, Any
|
|
5
|
+
from fastapi import FastAPI, Depends, Request, Response
|
|
6
|
+
from fastapi.responses import JSONResponse, HTMLResponse
|
|
7
|
+
from fastapi.security import HTTPBearer
|
|
8
|
+
from dataclasses import asdict
|
|
9
|
+
from cachetools import TTLCache
|
|
10
|
+
import time
|
|
11
|
+
|
|
12
|
+
from .. import _constants as c, _utils as u
|
|
13
|
+
from .._schemas import response_models as rm
|
|
14
|
+
from .._exceptions import ConfigurationError
|
|
15
|
+
from .._dashboards import Dashboard
|
|
16
|
+
from .._schemas.query_param_models import get_query_models_for_parameters, get_query_models_for_dashboard
|
|
17
|
+
from .._auth import BaseUser
|
|
18
|
+
from .base import RouteBase
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DashboardRoutes(RouteBase):
    """Routes that serve dashboard parameters and rendered dashboard results."""

    def __init__(self, get_bearer_token: HTTPBearer, project, no_cache: bool = False):
        super().__init__(get_bearer_token, project, no_cache)

        # Results cache, sized and timed from env vars (defaults: 128 entries, 60 minutes)
        max_entries = int(self.env_vars.get(c.SQRL_DASHBOARDS_CACHE_SIZE, 128))
        ttl_minutes = int(self.env_vars.get(c.SQRL_DASHBOARDS_CACHE_TTL_MINUTES, 60))
        self.dashboard_results_cache = TTLCache(maxsize=max_entries, ttl=ttl_minutes*60)

    async def _get_dashboard_results_helper(
        self, dashboard: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
    ) -> Dashboard:
        """Ask the project for the dashboard, turning the selection pairs back into a dict."""
        selections_as_dict = dict(selections)
        return await self.project.dashboard(dashboard, selections=selections_as_dict, user=user)

    async def _get_dashboard_results_cachable(
        self, dashboard: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
    ) -> Dashboard:
        """Same as the helper above, but routed through the dashboard results cache."""
        cache = self.dashboard_results_cache
        return await self.do_cachable_action(cache, self._get_dashboard_results_helper, dashboard, user, selections)

    async def _get_dashboard_results_definition(
        self, dashboard_name: str, user: BaseUser | None, all_request_params: dict, params: dict
    ) -> Response:
        """Validate the request, fetch the dashboard and wrap its content in a response.

        PNG content becomes an "image/png" response, HTML content an HTML response; any other
        format raises NotImplementedError.
        """
        self._validate_request_params(all_request_params, params)

        if self.no_cache:
            fetch_dashboard = self._get_dashboard_results_helper
        else:
            fetch_dashboard = self._get_dashboard_results_cachable
        selections = self.get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
        dashboard_obj = await fetch_dashboard(dashboard_name, user, selections)

        if dashboard_obj._format == c.PNG:
            assert isinstance(dashboard_obj._content, bytes)
            return Response(dashboard_obj._content, media_type="image/png")
        if dashboard_obj._format == c.HTML:
            return HTMLResponse(dashboard_obj._content)
        raise NotImplementedError()

    def setup_routes(
        self, app: FastAPI, project_metadata_path: str, param_fields: dict, get_parameters_definition: Callable
    ) -> None:
        """Register GET/POST parameter and result endpoints for every dashboard of the project."""

        results_path_template = project_metadata_path + '/dashboard/{dashboard}'
        parameters_path_template = results_path_template + '/parameters'

        def validate_parameters_list(parameters: list[str] | None, entity_type: str, dashboard_name: str) -> None:
            # Nothing to check when the entity declares no parameter list
            for param in parameters or []:
                if param in param_fields:
                    continue
                all_params = list(param_fields.keys())
                raise ConfigurationError(
                    f"{entity_type} '{dashboard_name}' use parameter '{param}' which doesn't exist. Available parameters are:"
                    f"\n {all_params}"
                )

        # Dashboard parameters and results APIs
        for dashboard_name, dashboard in self.project._dashboards.items():
            api_name = u.normalize_name_for_api(dashboard_name)
            curr_parameters_path = parameters_path_template.format(dashboard=api_name)
            curr_results_path = results_path_template.format(dashboard=api_name)
            tag = f"Dashboard '{dashboard_name}'"

            validate_parameters_list(dashboard.config.parameters, "Dashboard", dashboard_name)

            QueryModelForGetParams, QueryModelForPostParams = get_query_models_for_parameters(dashboard.config.parameters, param_fields)
            QueryModelForGetDash, QueryModelForPostDash = get_query_models_for_dashboard(dashboard.config.parameters, param_fields)

            # Handlers recover the dashboard name from the matched route path rather than
            # closing over the loop variable.

            @app.get(curr_parameters_path, tags=[tag], description=self._parameters_description, response_class=JSONResponse)
            async def get_dashboard_parameters(
                request: Request, params: QueryModelForGetParams, user=Depends(self.get_current_user) # type: ignore
            ) -> rm.ParametersModel:
                started_at = time.time()
                name = self.get_name_from_path_section(request, -2)
                parameters_list = self.project._dashboards[name].config.parameters
                scope = self.project._dashboards[name].config.scope
                query_params = dict(request.query_params)
                outcome = await get_parameters_definition(
                    parameters_list, "dashboard", name, scope, user, query_params, asdict(params)
                )
                self.log_activity_time("GET REQUEST for PARAMETERS", started_at, request)
                return outcome

            @app.post(curr_parameters_path, tags=[tag], description=self._parameters_description, response_class=JSONResponse)
            async def get_dashboard_parameters_with_post(
                request: Request, params: QueryModelForPostParams, user=Depends(self.get_current_user) # type: ignore
            ) -> rm.ParametersModel:
                started_at = time.time()
                name = self.get_name_from_path_section(request, -2)
                parameters_list = self.project._dashboards[name].config.parameters
                scope = self.project._dashboards[name].config.scope
                payload: dict = await request.json()
                outcome = await get_parameters_definition(
                    parameters_list, "dashboard", name, scope, user, payload, params.model_dump()
                )
                self.log_activity_time("POST REQUEST for PARAMETERS", started_at, request)
                return outcome

            @app.get(curr_results_path, tags=[tag], description=dashboard.config.description, response_class=Response)
            async def get_dashboard_results(
                request: Request, params: QueryModelForGetDash, user=Depends(self.get_current_user) # type: ignore
            ) -> Response:
                started_at = time.time()
                name = self.get_name_from_path_section(request, -1)
                query_params = dict(request.query_params)
                outcome = await self._get_dashboard_results_definition(name, user, query_params, asdict(params))
                self.log_activity_time("GET REQUEST for DASHBOARD RESULTS", started_at, request)
                return outcome

            @app.post(curr_results_path, tags=[tag], description=dashboard.config.description, response_class=Response)
            async def get_dashboard_results_with_post(
                request: Request, params: QueryModelForPostDash, user=Depends(self.get_current_user) # type: ignore
            ) -> Response:
                started_at = time.time()
                name = self.get_name_from_path_section(request, -1)
                payload: dict = await request.json()
                outcome = await self._get_dashboard_results_definition(name, user, payload, params.model_dump())
                self.log_activity_time("POST REQUEST for DASHBOARD RESULTS", started_at, request)
                return outcome
|
|
142
|
+
|