squirrels 0.5.0b2__py3-none-any.whl → 0.5.0b4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of squirrels might be problematic. Click here for more details.
- dateutils/__init__.py +6 -460
- dateutils/_enums.py +25 -0
- dateutils/_implementation.py +409 -0
- dateutils/types.py +6 -0
- squirrels/__init__.py +9 -13
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +262 -0
- squirrels/_api_routes/base.py +154 -0
- squirrels/_api_routes/dashboards.py +142 -0
- squirrels/_api_routes/data_management.py +103 -0
- squirrels/_api_routes/datasets.py +242 -0
- squirrels/_api_routes/oauth2.py +300 -0
- squirrels/_api_routes/project.py +214 -0
- squirrels/_api_server.py +145 -748
- squirrels/_arguments/__init__.py +0 -0
- squirrels/{arguments → _arguments}/init_time_args.py +7 -2
- squirrels/{arguments → _arguments}/run_time_args.py +4 -26
- squirrels/_auth.py +646 -93
- squirrels/_connection_set.py +5 -5
- squirrels/_constants.py +7 -1
- squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
- squirrels/_data_sources.py +564 -0
- squirrels/_exceptions.py +9 -37
- squirrels/_initializer.py +31 -26
- squirrels/_manifest.py +5 -5
- squirrels/_model_builder.py +1 -1
- squirrels/_model_configs.py +2 -2
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +40 -27
- squirrels/{package_data → _package_data}/base_project/.env +1 -0
- squirrels/{package_data → _package_data}/base_project/.env.example +1 -0
- squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.py +4 -4
- squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.yml +2 -2
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.py +2 -2
- squirrels/{package_data → _package_data}/base_project/models/builds/build_example.sql +1 -1
- squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.sql +1 -1
- squirrels/_package_data/base_project/models/federates/federate_example.py +41 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +25 -0
- squirrels/{package_data → _package_data}/base_project/models/federates/federate_example.yml +6 -6
- squirrels/{package_data → _package_data}/base_project/parameters.yml +9 -8
- squirrels/_package_data/base_project/pyconfigs/connections.py +14 -0
- squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +14 -16
- squirrels/_package_data/base_project/pyconfigs/parameters.py +106 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +51 -0
- squirrels/_package_data/templates/dataset_results.html +112 -0
- squirrels/_package_data/templates/oauth_login.html +271 -0
- squirrels/_parameter_configs.py +35 -35
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +47 -37
- squirrels/_parameters.py +1664 -0
- squirrels/_project.py +76 -32
- squirrels/_py_module.py +3 -2
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +144 -0
- squirrels/_schemas/query_param_models.py +67 -0
- squirrels/{_api_response_models.py → _schemas/response_models.py} +12 -8
- squirrels/_utils.py +38 -4
- squirrels/arguments.py +2 -0
- squirrels/auth.py +1 -0
- squirrels/connections.py +1 -0
- squirrels/dashboards.py +1 -82
- squirrels/data_sources.py +8 -563
- squirrels/parameter_options.py +8 -348
- squirrels/parameters.py +9 -1266
- squirrels/types.py +11 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/METADATA +4 -1
- squirrels-0.5.0b4.dist-info/RECORD +94 -0
- squirrels/package_data/base_project/macros/macros_example.sql +0 -15
- squirrels/package_data/base_project/models/federates/federate_example.py +0 -44
- squirrels/package_data/base_project/models/federates/federate_example.sql +0 -17
- squirrels/package_data/base_project/pyconfigs/connections.py +0 -14
- squirrels/package_data/base_project/pyconfigs/parameters.py +0 -93
- squirrels/package_data/base_project/pyconfigs/user.py +0 -23
- squirrels-0.5.0b2.dist-info/RECORD +0 -70
- /squirrels/{dataset_result.py → _dataset_types.py} +0 -0
- /squirrels/{package_data → _package_data}/base_project/assets/expenses.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
- /squirrels/{package_data → _package_data}/base_project/connections.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +0 -0
- /squirrels/{package_data → _package_data}/base_project/docker/compose.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/duckdb_init.sql +0 -0
- /squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/builds/build_example.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/dbviews/dbview_example.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/models/sources.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.csv +0 -0
- /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.yml +0 -0
- /squirrels/{package_data → _package_data}/base_project/squirrels.yml.j2 +0 -0
- /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b2.dist-info → squirrels-0.5.0b4.dist-info}/licenses/LICENSE +0 -0
|
squirrels/_api_routes/datasets.py
@@ -0,0 +1,242 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Dataset routes for parameters and results
|
|
3
|
+
"""
|
|
4
|
+
from typing import Callable, Any
|
|
5
|
+
from pydantic import Field, BaseModel
|
|
6
|
+
from fastapi import FastAPI, Depends, Request
|
|
7
|
+
from fastapi.responses import JSONResponse, HTMLResponse
|
|
8
|
+
from fastapi.security import HTTPBearer
|
|
9
|
+
|
|
10
|
+
from mcp.server.fastmcp import FastMCP, Context
|
|
11
|
+
from dataclasses import asdict
|
|
12
|
+
from cachetools import TTLCache
|
|
13
|
+
from textwrap import dedent
|
|
14
|
+
|
|
15
|
+
import time
|
|
16
|
+
|
|
17
|
+
from .. import _constants as c, _utils as u
|
|
18
|
+
from .._schemas import response_models as rm
|
|
19
|
+
from .._exceptions import ConfigurationError, InvalidInputError
|
|
20
|
+
from .._dataset_types import DatasetResult
|
|
21
|
+
from .._schemas.query_param_models import get_query_models_for_parameters, get_query_models_for_dataset
|
|
22
|
+
from .._auth import BaseUser
|
|
23
|
+
from .base import RouteBase
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class DatasetRoutes(RouteBase):
    """Dataset parameter and result routes, plus the MCP tools that expose the same operations."""

    def __init__(self, get_bearer_token: HTTPBearer, project, no_cache: bool = False):
        """
        Initialise the shared route state and create the dataset results cache.

        The cache size and time-to-live are read from ``self.env_vars`` and default to
        128 entries and 60 minutes respectively.
        """
        super().__init__(get_bearer_token, project, no_cache)

        # Setup caches (the TTL setting is expressed in minutes, TTLCache is given seconds)
        dataset_results_cache_size = int(self.env_vars.get(c.SQRL_DATASETS_CACHE_SIZE, 128))
        dataset_results_cache_ttl = int(self.env_vars.get(c.SQRL_DATASETS_CACHE_TTL_MINUTES, 60))
        self.dataset_results_cache = TTLCache(maxsize=dataset_results_cache_size, ttl=dataset_results_cache_ttl*60)

    async def _get_dataset_results_helper(
        self, dataset: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
    ) -> DatasetResult:
        """
        Run the dataset for the given user and parameter selections.

        ``selections`` is a tuple of (name, value) pairs so that it can be used as part of a
        cache key; it is converted back to a dict before being handed to the project.
        """
        return await self.project.dataset(dataset, selections=dict(selections), user=user)

    async def _get_dataset_results_cachable(
        self, dataset: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
    ) -> DatasetResult:
        """Same as ``_get_dataset_results_helper`` but routed through ``self.dataset_results_cache``."""
        return await self.do_cachable_action(self.dataset_results_cache, self._get_dataset_results_helper, dataset, user, selections)

    async def _get_dataset_results_definition(
        self, dataset_name: str, user: BaseUser | None, all_request_params: dict, params: dict
    ) -> rm.DatasetResultModel:
        """
        Validate the request, obtain the (possibly cached) dataset result, and shape it for the response.

        Arguments:
            dataset_name: Name of the dataset to run
            user: The authenticated user, or None
            all_request_params: Every parameter the client actually sent (used for validation)
            params: The parsed parameters, including the "x_*" presentation options
        """
        self._validate_request_params(all_request_params, params)

        get_dataset_function = self._get_dataset_results_helper if self.no_cache else self._get_dataset_results_cachable

        # Presentation-only options are excluded from the selections so they do not fragment the cache;
        # they are applied afterwards when the result is converted to JSON.
        uncached_keys = {"x_verify_params", "x_orientation", "x_select", "x_limit", "x_offset"}
        selections = self.get_selections_as_immutable(params, uncached_keys)
        result = await get_dataset_function(dataset_name, user, selections)

        orientation = params.get("x_orientation", "records")
        raw_select: list[str] | None = params.get("x_select")
        select = tuple(raw_select) if raw_select is not None else tuple()
        limit = params.get("x_limit", 1000)
        offset = params.get("x_offset", 0)
        return rm.DatasetResultModel(**result.to_json(orientation, select, limit, offset))

    def setup_routes(
        self, app: FastAPI, mcp: FastMCP, project_metadata_path: str, project_name: str, project_version: str,
        param_fields: dict, get_parameters_definition: Callable
    ) -> None:
        """
        Register the dataset REST routes, the dataset MCP tools, and the MCP tool-results UI routes.

        Arguments:
            app: The FastAPI application to add routes to
            mcp: The MCP server to add tools to
            project_metadata_path: URL prefix under which the dataset routes are mounted
            project_name: Project name, used in MCP tool names and descriptions
            project_version: Project version, used in MCP tool names
            param_fields: Mapping of every known parameter name to its field definition
            get_parameters_definition: Coroutine function that computes the parameters response

        Raises:
            ConfigurationError: If a dataset references a parameter that is not in ``param_fields``
        """

        dataset_results_path = project_metadata_path + '/dataset/{dataset}'
        dataset_parameters_path = dataset_results_path + '/parameters'

        def validate_parameters_list(parameters: list[str] | None, entity_type: str, dataset_name: str) -> None:
            """Fail fast at startup if the entity lists a parameter that does not exist."""
            if parameters is None:
                return
            for param in parameters:
                if param not in param_fields:
                    all_params = list(param_fields.keys())
                    raise ConfigurationError(
                        f"{entity_type} '{dataset_name}' use parameter '{param}' which doesn't exist. Available parameters are:"
                        f"\n  {all_params}"
                    )

        async def get_dataset_parameters_updates(dataset_name: str, user: BaseUser | None, all_request_params: dict, params: dict):
            """Compute the parameters response for one dataset using its configured parameters and scope."""
            dataset_cfg = self.manifest_cfg.datasets[dataset_name]
            result = await get_parameters_definition(
                dataset_cfg.parameters, "dataset", dataset_name, dataset_cfg.scope, user, all_request_params, params
            )
            return result

        # Dataset parameters and results APIs
        for dataset_name, dataset_config in self.manifest_cfg.datasets.items():
            dataset_normalized = u.normalize_name_for_api(dataset_name)
            curr_parameters_path = dataset_parameters_path.format(dataset=dataset_normalized)
            curr_results_path = dataset_results_path.format(dataset=dataset_normalized)

            validate_parameters_list(dataset_config.parameters, "Dataset", dataset_name)

            QueryModelForGetParams, QueryModelForPostParams = get_query_models_for_parameters(dataset_config.parameters, param_fields)
            QueryModelForGetDataset, QueryModelForPostDataset = get_query_models_for_dataset(dataset_config.parameters, param_fields)

            # NOTE: the handlers below derive the dataset name from the request path instead of closing over
            # the loop variable, so every registered handler resolves the dataset it was actually called for.

            @app.get(curr_parameters_path, tags=[f"Dataset '{dataset_name}'"], description=self._parameters_description, response_class=JSONResponse)
            async def get_dataset_parameters(
                request: Request, params: QueryModelForGetParams, user=Depends(self.get_current_user) # type: ignore
            ) -> rm.ParametersModel:
                start = time.time()
                curr_dataset_name = self.get_name_from_path_section(request, -2)
                result = await get_dataset_parameters_updates(curr_dataset_name, user, dict(request.query_params), asdict(params))
                self.log_activity_time("GET REQUEST for PARAMETERS", start, request)
                return result

            @app.post(curr_parameters_path, tags=[f"Dataset '{dataset_name}'"], description=self._parameters_description, response_class=JSONResponse)
            async def get_dataset_parameters_with_post(
                request: Request, params: QueryModelForPostParams, user=Depends(self.get_current_user) # type: ignore
            ) -> rm.ParametersModel:
                start = time.time()
                curr_dataset_name = self.get_name_from_path_section(request, -2)
                payload: dict = await request.json()
                result = await get_dataset_parameters_updates(curr_dataset_name, user, payload, params.model_dump())
                self.log_activity_time("POST REQUEST for PARAMETERS", start, request)
                return result

            @app.get(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
            async def get_dataset_results(
                request: Request, params: QueryModelForGetDataset, user=Depends(self.get_current_user) # type: ignore
            ) -> rm.DatasetResultModel:
                start = time.time()
                curr_dataset_name = self.get_name_from_path_section(request, -1)
                result = await self._get_dataset_results_definition(curr_dataset_name, user, dict(request.query_params), asdict(params))
                self.log_activity_time("GET REQUEST for DATASET RESULTS", start, request)
                return result

            @app.post(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
            async def get_dataset_results_with_post(
                request: Request, params: QueryModelForPostDataset, user=Depends(self.get_current_user) # type: ignore
            ) -> rm.DatasetResultModel:
                start = time.time()
                curr_dataset_name = self.get_name_from_path_section(request, -1)
                payload: dict = await request.json()
                result = await self._get_dataset_results_definition(curr_dataset_name, user, payload, params.model_dump())
                self.log_activity_time("POST REQUEST for DATASET RESULTS", start, request)
                return result

        # Setup MCP tools

        @mcp.tool(
            name=f"get_dataset_parameters_for_{project_name}_{project_version}",
            description=dedent(f"""
                Use this tool to get updates for dataset parameters in the Squirrels project "{project_name}" when a selection is to be made on a parameter with "trigger_refresh" as true.

                For example, suppose there are two parameters, "country" and "city", and the user selects "United States" for "country". If "country" has the "trigger_refresh" field as true, then this tool will be called to get the updates for other parameters such as "city".

                Do not use this tool on parameters whose "trigger_refresh" field is false!
            """).strip()
        )
        async def get_dataset_parameters_tool(
            ctx: Context,
            dataset: str = Field(description="The name of the dataset whose parameters the trigger parameter will update"),
            parameter_name: str = Field(description="The name of the parameter triggering the refresh"),
            selected_ids: list[str] = Field(description="The ID(s) of the selected option(s) for the parameter"),
        ):
            user = self.get_user_from_tool_ctx(ctx)
            dataset_name = u.normalize_name(dataset)

            # The dataset name comes from the tool caller, so report an unknown name explicitly
            # instead of letting the manifest lookup fail with a bare KeyError.
            if dataset_name not in self.manifest_cfg.datasets:
                raise ValueError(f"Dataset '{dataset}' does not exist in this project.")

            payload = {
                "x_parent_param": parameter_name,
                parameter_name: selected_ids
            }
            return await get_dataset_parameters_updates(dataset_name, user, payload, payload)

        @mcp.tool(
            name=f"get_dataset_results_for_{project_name}_{project_version}",
            description=dedent(f"""
                Use this tool to get the dataset results as a JSON object for a dataset in the Squirrels project "{project_name}".
                - Use the "offset" and "limit" arguments to limit the number of rows you require
                - The "limit" argument controls the number of rows returned. The maximum allowed value is 100. If the 'total_num_rows' field in the response is greater than 100, let the user know that only 100 rows are shown and clarify if they would like to see more.
            """).strip()
        )
        async def get_dataset_results_tool(
            ctx: Context,
            dataset: str = Field(description="The name of the dataset to get results for"),
            parameters: dict[str, Any] = Field(description=dedent("""
                Key-value pairs for parameter name and selected value. The selected value to provide depends on the parameter widget type:
                - For single select, use a string for the ID of the selected value
                - For multi select, use an array of strings for the IDs of the selected values
                - For date, use a string like "YYYY-MM-DD"
                - For date ranges, use array of strings like ["YYYY-MM-DD","YYYY-MM-DD"]
                - For number, use a number like 1
                - For number ranges, use array of numbers like [1,100]
                - For text, use a string for the text value
                - Complex objects are NOT supported""").strip()),
            offset: int = Field(0, description="The number of rows to skip from first row. Default is 0."),
            limit: int = Field(100, description="The maximum number of rows to return. Default is 100. Maximum allowed value is 100."),
        ):
            if limit > 100:
                raise ValueError("The maximum number of rows to return is 100.")
            # Negative values have no meaning for pagination; reject them up front rather than
            # passing them on to the result slicing.
            if limit < 0 or offset < 0:
                raise ValueError("The 'limit' and 'offset' arguments must not be negative.")

            user = self.get_user_from_tool_ctx(ctx)
            dataset_name = u.normalize_name(dataset)
            params = {
                **parameters,
                "x_orientation": "rows",
                "x_offset": offset,
                "x_limit": limit
            }
            result = await self._get_dataset_results_definition(dataset_name, user, params, params)
            return result

        # Setup UI for tool results
        mcp_tool_results_ui_path = project_metadata_path + "/mcp/tool-results-ui"

        @app.get(mcp_tool_results_ui_path + "/list-tools", tags=["MCP Supplements"])
        async def list_tools():
            """List the tools whose results can be rendered by the tool-results UI."""
            return ["get_dataset_results"]

        class ToolResultBody(BaseModel):
            """Flexible model for tool results - accepts any additional fields"""

            # Pydantic v2 configuration (the inner `class Config` form is deprecated in v2)
            model_config = {"extra": "allow"}  # Allow additional fields not defined in the model

        @app.post(mcp_tool_results_ui_path + "/tool/{tool_name}", tags=["MCP Supplements"])
        async def tool_results_ui(tool_name: str, tool_result: ToolResultBody):
            """Render the posted tool result as an HTML page; only "get_dataset_results" is supported."""
            if tool_name != "get_dataset_results":
                raise InvalidInputError(400, "Invalid tool name", f"Tool name '{tool_name}' not supported for UI")

            # Convert Pydantic model to dict to access any extra fields
            tool_result_dict = tool_result.model_dump()

            # Prepare template context
            context = {
                "schema": tool_result_dict.get("schema", {}),
                "data": tool_result_dict.get("data", []),
            }

            # Render HTML template
            html_content = self.templates.get_template("dataset_results.html").render(context)
            return HTMLResponse(content=html_content, status_code=200)
|
|
242
|
+
|
|
squirrels/_api_routes/oauth2.py
@@ -0,0 +1,300 @@
|
|
|
1
|
+
from fastapi import FastAPI, Depends, Request, Query, Response, APIRouter, Form
|
|
2
|
+
from fastapi.responses import RedirectResponse, HTMLResponse
|
|
3
|
+
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
4
|
+
from typing import Annotated, cast
|
|
5
|
+
|
|
6
|
+
from .base import RouteBase
|
|
7
|
+
from .._schemas.auth_models import (
|
|
8
|
+
ClientRegistrationRequest, ClientUpdateRequest, ClientRegistrationResponse, ClientDetailsResponse, ClientUpdateResponse,
|
|
9
|
+
TokenResponse, OAuthServerMetadata
|
|
10
|
+
)
|
|
11
|
+
from .._exceptions import InvalidInputError
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class OAuth2Routes(RouteBase):
    """OAuth2 routes: client registration/management, authorization, token, revocation and server metadata."""

    def __init__(self, get_bearer_token: HTTPBearer, project, no_cache: bool = False):
        """Forward the shared route dependencies to ``RouteBase``; no OAuth2-specific state is created here."""
        super().__init__(get_bearer_token, project, no_cache)

    def serve_login_page(self, auth_path: str, request: Request, client_id: str) -> HTMLResponse:
        """
        Render the "oauth_login.html" template for the client that started the authorization flow.

        Arguments:
            auth_path: URL prefix of the auth routes, used to build the login URLs
            request: The current request; its full URL is passed to the template as the return URL
            client_id: OAuth client identifier, used to look up the client name for display
        """
        # Get client information for display (the client name is None when no details are found)
        client_details = self.authenticator.get_oauth_client_details(client_id)
        client_name = client_details.client_name if client_details else None
        project_name = self.manifest_cfg.project_variables.label

        # Get available login providers
        providers = [
            {
                "name": provider.name,
                "label": provider.label,
                "icon": provider.icon,
                "login_url": f"{auth_path}/providers/{provider.name}/login"
            }
            for provider in self.authenticator.auth_providers
        ]

        # Template context
        context = {
            "request": request,
            "project_name": project_name,
            "client_name": client_name,
            "providers": providers,
            "login_url": f"{auth_path}/login",
            "return_url": str(request.url),
        }

        return HTMLResponse(
            content=self.templates.get_template("oauth_login.html").render(context),
            status_code=200
        )

    def setup_routes(self, app: FastAPI) -> None:
        """
        Register all OAuth2 routes on the application.

        The routes under "/api/auth/oauth2" are added through a router; the authorization server
        metadata route is added directly on ``app`` at its well-known path.
        """
        # Imported here so this method does not depend on module-level imports it did not previously have
        from urllib.parse import urlencode, urlsplit, urlunsplit

        router_path = "/api/auth/oauth2"
        router = APIRouter(prefix=router_path)

        # Create user models
        class UserInfoModel(self.UserInfoModel):
            username: str

        # Authorization dependency for client management
        get_client_token = HTTPBearer(auto_error=False)

        def append_query_params(url: str, new_params: dict[str, str]) -> str:
            """
            Return ``url`` with ``new_params`` added to its query string, properly percent-encoded.

            Any query string already present on ``url`` is kept as-is, and a fragment (if any)
            stays after the query.
            """
            parts = urlsplit(url)
            encoded = urlencode(new_params)
            query = f"{parts.query}&{encoded}" if parts.query else encoded
            return urlunsplit(parts._replace(query=query))

        async def validate_client_registration_token(
            client_id: str, auth: HTTPAuthorizationCredentials = Depends(get_client_token),
        ) -> None:
            """Validate Bearer token for client management operations"""

            # The auth scheme is case-insensitive, so "bearer" must be accepted as well as "Bearer"
            if not auth or auth.scheme.lower() != "bearer":
                raise InvalidInputError(401, "invalid_client",
                    "Missing or invalid authorization header. Use 'Authorization: Bearer <registration_access_token>'"
                )

            token = auth.credentials
            is_valid = self.authenticator.validate_registration_access_token(client_id, token)
            if not is_valid:
                raise InvalidInputError(401, "invalid_token", "Invalid registration access token for this client")

        def validate_oauth_client_credentials(client_id: str | None, client_secret: str | None) -> str:
            """
            Validate the OAuth client credentials supplied in the form data.
            Returns the validated client_id.
            """

            # Validate client credentials
            if not client_id or not client_secret or not self.authenticator.validate_client_credentials(client_id, client_secret):
                raise InvalidInputError(400, "invalid_client", "Invalid client credentials")

            return cast(str, client_id)

        # Client Registration Endpoint
        client_management_path = '/client/{client_id}'

        @router.post("/register", description="Register a new OAuth client", tags=["OAuth2"])
        async def register_oauth_client(request: ClientRegistrationRequest) -> ClientRegistrationResponse:
            """Register a new OAuth client and return client credentials"""

            # Register the client using the authenticator
            client_registration_response = self.authenticator.register_oauth_client(
                request, client_management_path_format=router_path+client_management_path
            )

            return client_registration_response

        # Client Management Endpoints
        @router.get(client_management_path, description="Get OAuth client registration details", tags=["OAuth2"])
        async def get_oauth_client(
            client_id: str, _: Annotated[None, Depends(validate_client_registration_token)]
        ) -> ClientDetailsResponse:
            """Get OAuth client registration details"""

            client_details = self.authenticator.get_oauth_client_details(client_id)

            return client_details

        @router.put(client_management_path, description="Update OAuth client registration", tags=["OAuth2"])
        async def update_oauth_client(
            client_id: str, request: ClientUpdateRequest, _: Annotated[None, Depends(validate_client_registration_token)]
        ) -> ClientUpdateResponse:
            """Update OAuth client registration and rotate access token"""

            # Update the client and get new registration access token
            client_details = self.authenticator.update_oauth_client_with_token_rotation(client_id, request)

            return client_details

        @router.delete(client_management_path, description="Revoke OAuth client registration", tags=["OAuth2"], responses={
            204: { "description": "OAuth client registration revoked successfully" }
        })
        async def revoke_oauth_client(
            client_id: str, _: Annotated[None, Depends(validate_client_registration_token)]
        ) -> Response:
            """Revoke (deactivate) OAuth client registration"""

            self.authenticator.revoke_oauth_client(client_id)
            return Response(status_code=204)

        # Authorization Endpoint
        @router.get("/authorize", description="OAuth 2.1 Authorization Endpoint", tags=["OAuth2"], response_model=None)
        async def authorize_endpoint(
            request: Request,
            response_type: str = Query(default="code", description="OAuth response type"),
            client_id: str = Query(..., description="OAuth client identifier"),
            redirect_uri: str = Query(..., description="URI to redirect after authorization"),
            scope: str = Query(default="read", description="Requested scope"),
            state: str | None = Query(default=None, description="State parameter for CSRF protection"),
            code_challenge: str = Query(..., description="PKCE code challenge (required)"),
            code_challenge_method: str = Query(default="S256", description="PKCE code challenge method"),
            user: UserInfoModel | None = Depends(self.get_current_user)
        ):
            """OAuth 2.1 authorization endpoint for initiating authorization code flow"""

            try:
                # Validate response_type
                if response_type != "code":
                    raise InvalidInputError(400, "unsupported_response_type", "Only 'code' response type is supported")

                # Check if user is authenticated
                if user is None:
                    # User is not authenticated - serve login page
                    return self.serve_login_page("/api/auth", request, client_id)

                # TODO: Serve a page with an "authorize" button even if user is already authenticated
                # Ex. if not request.session.get("authorization_approved"), redirect to a page with button that submits to "/approve-authorization"

                # User is authenticated - generate authorization code
                authorization_code = self.authenticator.create_authorization_code(
                    client_id=client_id,
                    username=user.username,
                    redirect_uri=redirect_uri,
                    scope=scope,
                    code_challenge=code_challenge,
                    code_challenge_method=code_challenge_method
                )

                # Redirect back to client with authorization code. The parameters are percent-encoded
                # because "state" is client-supplied and may contain reserved characters.
                success_params = {"code": authorization_code}
                if state:
                    success_params["state"] = state

                return RedirectResponse(url=append_query_params(redirect_uri, success_params))

            except InvalidInputError as e:
                # Only "invalid_request" errors are reported back to the client via redirect;
                # everything else propagates to the regular error handling.
                if e.error != "invalid_request":
                    raise

                error_params = {"error": e.error, "error_description": e.error_description}
                if state:
                    error_params["state"] = state
                return RedirectResponse(url=append_query_params(redirect_uri, error_params))

        # Token Endpoint
        @router.post("/token", description="OAuth 2.1 Token Endpoint", tags=["OAuth2"])
        async def token_endpoint(
            grant_type: str = Form(...),
            code: str | None = Form(default=None),
            redirect_uri: str | None = Form(default=None),
            code_verifier: str | None = Form(default=None),
            refresh_token: str | None = Form(default=None),
            client_id: str | None = Form(default=None),
            client_secret: str | None = Form(default=None)
        ) -> TokenResponse:
            """OAuth 2.1 token endpoint for exchanging authorization code or refresh token for access token"""

            # Validate client credentials
            auth_client_id = validate_oauth_client_credentials(client_id, client_secret)

            # Get token expiry configuration
            expiry_mins = self._get_access_token_expiry_minutes()

            if grant_type == "authorization_code":
                # Validate required parameters for authorization code flow
                if not all([code, redirect_uri, code_verifier, auth_client_id]):
                    raise InvalidInputError(400, "invalid_request", "Missing required parameters for authorization_code grant")

                # Type casts since we validated above
                code = cast(str, code)
                redirect_uri = cast(str, redirect_uri)
                code_verifier = cast(str, code_verifier)

                # Exchange authorization code for tokens
                token_response = self.authenticator.exchange_authorization_code(
                    code=code,
                    client_id=auth_client_id,
                    redirect_uri=redirect_uri,
                    code_verifier=code_verifier,
                    access_token_expiry_minutes=expiry_mins
                )

                return token_response

            elif grant_type == "refresh_token":
                # Validate required parameters for refresh token flow
                if not all([refresh_token, auth_client_id]):
                    raise InvalidInputError(400, "invalid_request", "Missing required parameters for refresh_token grant")

                # Type cast since we validated above
                refresh_token = cast(str, refresh_token)

                # Refresh access token
                token_response = self.authenticator.refresh_oauth_access_token(
                    refresh_token=refresh_token,
                    client_id=auth_client_id,
                    access_token_expiry_minutes=expiry_mins
                )

                return token_response

            else:
                raise InvalidInputError(400, "unsupported_grant_type", f"Grant type '{grant_type}' is not supported")

        # Token Revocation Endpoint
        @router.post("/token/revoke", description="OAuth 2.1 Token Revocation Endpoint", tags=["OAuth2"])
        async def revoke_endpoint(
            token: str = Form(..., description="The token to be revoked"),
            token_type_hint: str | None = Form(default=None, description="Hint about the type of token being revoked"),
            client_id: str | None = Form(default=None),
            client_secret: str | None = Form(default=None)
        ) -> Response:
            """OAuth 2.1 token revocation endpoint for revoking refresh tokens"""

            # Validate client credentials
            auth_client_id = validate_oauth_client_credentials(client_id, client_secret)

            # Revoke the token (per RFC 7009, always return 200 regardless of token validity)
            try:
                self.authenticator.revoke_oauth_token(auth_client_id, token, token_type_hint)
            except InvalidInputError:
                # Deliberately ignored: the revocation endpoint returns 200 even for invalid tokens
                pass

            return Response(status_code=200)

        # Authorization Server Metadata Endpoint (well-known endpoint)
        @app.get("/.well-known/oauth-authorization-server", tags=["OAuth2"], description="OAuth 2.1 Authorization Server Metadata")
        async def authorization_server_metadata(request: Request) -> OAuthServerMetadata:
            """OAuth 2.1 Authorization Server Metadata endpoint (RFC 8414)"""

            # Get the base URL from the request: plain http is only advertised for local hosts,
            # any other host is advertised as https
            scheme = "http" if request.url.hostname in ("localhost", "127.0.0.1") else "https"
            base_url = scheme + "://" + request.url.netloc

            return OAuthServerMetadata(
                issuer=base_url,
                authorization_endpoint=f"{base_url}{router_path}/authorize",
                token_endpoint=f"{base_url}{router_path}/token",
                revocation_endpoint=f"{base_url}{router_path}/token/revoke",
                registration_endpoint=f"{base_url}{router_path}/register",
                scopes_supported=["read"],
                response_types_supported=["code"],
                grant_types_supported=["authorization_code", "refresh_token"],
                token_endpoint_auth_methods_supported=["client_secret_post"],
                code_challenge_methods_supported=["S256"]
            )

        app.include_router(router)
|