squirrels 0.5.0b3__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- squirrels/__init__.py +4 -0
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +440 -792
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/{_init_time_args.py → init_time_args.py} +23 -43
- squirrels/_arguments/{_run_time_args.py → run_time_args.py} +32 -68
- squirrels/_auth.py +590 -264
- squirrels/_command_line.py +130 -58
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +16 -15
- squirrels/_constants.py +36 -11
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +40 -34
- squirrels/_dataset_types.py +16 -11
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +9 -37
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +7 -6
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +155 -77
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +11 -55
- squirrels/_model_configs.py +5 -5
- squirrels/_model_queries.py +1 -1
- squirrels/_models.py +276 -143
- squirrels/_package_data/base_project/.env +1 -24
- squirrels/_package_data/base_project/.env.example +31 -17
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +13 -7
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +6 -6
- squirrels/_package_data/base_project/docker/Dockerfile +2 -2
- squirrels/_package_data/base_project/docker/compose.yml +1 -1
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +2 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +27 -17
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +7 -7
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +8 -3
- squirrels/_package_data/base_project/pyconfigs/context.py +26 -14
- squirrels/_package_data/base_project/pyconfigs/parameters.py +124 -81
- squirrels/_package_data/base_project/pyconfigs/user.py +48 -15
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +21 -31
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +41 -30
- squirrels/_parameters.py +560 -123
- squirrels/_project.py +487 -277
- squirrels/_py_module.py +71 -10
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +52 -13
- squirrels/_sources.py +29 -23
- squirrels/_utils.py +221 -42
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -2
- squirrels/auth.py +4 -0
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +10 -3
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -1
- squirrels/_api_response_models.py +0 -190
- squirrels/_dashboard_types.py +0 -82
- squirrels/_dashboards_io.py +0 -79
- squirrels-0.5.0b3.dist-info/METADATA +0 -110
- squirrels-0.5.0b3.dist-info/RECORD +0 -80
- /squirrels/_package_data/base_project/{assets → resources}/expenses.db +0 -0
- /squirrels/_package_data/base_project/{assets → resources}/weather.db +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b3.dist-info → squirrels-0.6.0.post0.dist-info}/licenses/LICENSE +0 -0
squirrels/_api_server.py
CHANGED
|
@@ -1,114 +1,180 @@
|
|
|
1
|
-
from typing import
|
|
2
|
-
from dataclasses import
|
|
3
|
-
from fastapi import
|
|
4
|
-
from fastapi.responses import HTMLResponse,
|
|
5
|
-
from fastapi.security import
|
|
6
|
-
from fastapi.
|
|
7
|
-
from
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from fastapi import FastAPI, Request, status
|
|
4
|
+
from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse, PlainTextResponse
|
|
5
|
+
from fastapi.security import HTTPBearer
|
|
6
|
+
from fastapi.staticfiles import StaticFiles
|
|
7
|
+
from fastapi.templating import Jinja2Templates
|
|
8
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
9
|
+
from starlette.responses import Response as StarletteResponse
|
|
10
|
+
from starlette.types import ASGIApp
|
|
8
11
|
from contextlib import asynccontextmanager
|
|
9
|
-
from cachetools import TTLCache
|
|
10
12
|
from argparse import Namespace
|
|
11
13
|
from pathlib import Path
|
|
12
|
-
|
|
14
|
+
from starlette.middleware.sessions import SessionMiddleware
|
|
15
|
+
import io, time, mimetypes, traceback, asyncio
|
|
13
16
|
|
|
14
|
-
from . import _constants as c, _utils as u,
|
|
17
|
+
from . import _constants as c, _utils as u, _parameter_sets as ps
|
|
18
|
+
from ._schemas import response_models as rm
|
|
15
19
|
from ._exceptions import InvalidInputError, ConfigurationError, FileExecutionError
|
|
16
|
-
from .
|
|
17
|
-
from ._manifest import
|
|
18
|
-
from .
|
|
19
|
-
from .
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
from .
|
|
23
|
-
from .
|
|
20
|
+
from ._http_error_responses import invalid_input_error_to_json_response
|
|
21
|
+
from ._manifest import AuthStrategy, AuthType
|
|
22
|
+
from ._request_context import set_request_id
|
|
23
|
+
from ._mcp_server import McpServerBuilder
|
|
24
|
+
|
|
25
|
+
# Import route modules
|
|
26
|
+
from ._api_routes.base import RouteBase
|
|
27
|
+
from ._api_routes.auth import AuthRoutes
|
|
28
|
+
from ._api_routes.project import ProjectRoutes
|
|
29
|
+
from ._api_routes.datasets import DatasetRoutes
|
|
30
|
+
from ._api_routes.dashboards import DashboardRoutes
|
|
31
|
+
from ._api_routes.data_management import DataManagementRoutes
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from contextlib import _AsyncGeneratorContextManager
|
|
35
|
+
from ._project import SquirrelsProject
|
|
36
|
+
|
|
24
37
|
|
|
25
38
|
mimetypes.add_type('application/javascript', '.js')
|
|
26
39
|
|
|
27
40
|
|
|
41
|
+
class SmartCORSMiddleware(BaseHTTPMiddleware):
|
|
42
|
+
"""
|
|
43
|
+
Custom CORS middleware that allows specific origins to use credentials
|
|
44
|
+
while still allowing all other origins without credentials.
|
|
45
|
+
"""
|
|
46
|
+
|
|
47
|
+
def __init__(self, app, allowed_credential_origins: list[str], configurables_as_headers: list[str]):
|
|
48
|
+
super().__init__(app)
|
|
49
|
+
|
|
50
|
+
allowed_predefined_headers = ["Authorization", "Content-Type", "x-api-key"]
|
|
51
|
+
|
|
52
|
+
self.allowed_credential_origins = allowed_credential_origins
|
|
53
|
+
self.allowed_request_headers = ",".join(allowed_predefined_headers + configurables_as_headers)
|
|
54
|
+
|
|
55
|
+
async def dispatch(self, request: Request, call_next):
|
|
56
|
+
origin = request.headers.get("origin")
|
|
57
|
+
|
|
58
|
+
# Handle preflight requests
|
|
59
|
+
if request.method == "OPTIONS":
|
|
60
|
+
response = StarletteResponse(status_code=200)
|
|
61
|
+
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
|
|
62
|
+
response.headers["Access-Control-Allow-Headers"] = self.allowed_request_headers
|
|
63
|
+
|
|
64
|
+
else:
|
|
65
|
+
# Call the next middleware/route
|
|
66
|
+
response: StarletteResponse = await call_next(request)
|
|
67
|
+
|
|
68
|
+
# Always expose the Applied-Username header
|
|
69
|
+
response.headers["Access-Control-Expose-Headers"] = "Applied-Username"
|
|
70
|
+
|
|
71
|
+
if origin:
|
|
72
|
+
request_origin = f"{request.url.scheme}://{request.url.netloc}"
|
|
73
|
+
# Check if this origin is in the whitelist or if origin matches the host origin
|
|
74
|
+
if origin == request_origin or origin in self.allowed_credential_origins:
|
|
75
|
+
response.headers["Access-Control-Allow-Origin"] = origin
|
|
76
|
+
response.headers["Access-Control-Allow-Credentials"] = "true"
|
|
77
|
+
else:
|
|
78
|
+
# Allow all other origins but without credentials / cookies
|
|
79
|
+
response.headers["Access-Control-Allow-Origin"] = "*"
|
|
80
|
+
else:
|
|
81
|
+
# No origin header (probably a non-browser request)
|
|
82
|
+
response.headers["Access-Control-Allow-Origin"] = "*"
|
|
83
|
+
|
|
84
|
+
return response
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@dataclass
|
|
88
|
+
class FastAPIComponents:
|
|
89
|
+
"""
|
|
90
|
+
HTTP server components to mount the Squirrels project into an existing FastAPI application.
|
|
91
|
+
|
|
92
|
+
Properties:
|
|
93
|
+
mount_path: The mount path for the Squirrels project.
|
|
94
|
+
lifespan: The lifespan context manager for the Squirrels project.
|
|
95
|
+
fastapi_app: The FastAPI app for the Squirrels project.
|
|
96
|
+
"""
|
|
97
|
+
mount_path: str
|
|
98
|
+
lifespan: "_AsyncGeneratorContextManager"
|
|
99
|
+
fastapi_app: "FastAPI"
|
|
100
|
+
|
|
101
|
+
|
|
28
102
|
class ApiServer:
|
|
29
|
-
def __init__(self, no_cache: bool, project: SquirrelsProject) -> None:
|
|
103
|
+
def __init__(self, no_cache: bool, project: "SquirrelsProject") -> None:
|
|
30
104
|
"""
|
|
31
105
|
Constructor for ApiServer
|
|
32
106
|
|
|
33
107
|
Arguments:
|
|
34
108
|
no_cache (bool): Whether to disable caching
|
|
35
109
|
"""
|
|
36
|
-
self.no_cache = no_cache
|
|
37
110
|
self.project = project
|
|
38
111
|
self.logger = project._logger
|
|
39
112
|
self.env_vars = project._env_vars
|
|
40
|
-
self.j2_env = project._j2_env
|
|
41
113
|
self.manifest_cfg = project._manifest_cfg
|
|
42
114
|
self.seeds = project._seeds
|
|
43
|
-
self.conn_args = project._conn_args
|
|
44
115
|
self.conn_set = project._conn_set
|
|
45
|
-
self.authenticator = project._auth
|
|
46
|
-
self.param_args = project._param_args
|
|
47
116
|
self.param_cfg_set = project._param_cfg_set
|
|
48
|
-
self.context_func = project._context_func
|
|
49
117
|
self.dashboards = project._dashboards
|
|
118
|
+
|
|
119
|
+
# Initialize route modules
|
|
120
|
+
get_bearer_token = HTTPBearer(auto_error=False)
|
|
121
|
+
# self.oauth2_routes = OAuth2Routes(get_bearer_token, project, no_cache)
|
|
122
|
+
self.auth_routes = AuthRoutes(get_bearer_token, project, no_cache)
|
|
123
|
+
self.project_routes = ProjectRoutes(get_bearer_token, project, no_cache)
|
|
124
|
+
self.dataset_routes = DatasetRoutes(get_bearer_token, project, no_cache)
|
|
125
|
+
self.dashboard_routes = DashboardRoutes(get_bearer_token, project, no_cache)
|
|
126
|
+
self.data_management_routes = DataManagementRoutes(get_bearer_token, project, no_cache)
|
|
127
|
+
|
|
128
|
+
self._mcp_builder: McpServerBuilder | None = None
|
|
129
|
+
self._mcp_app: ASGIApp | None = None
|
|
50
130
|
|
|
51
131
|
|
|
52
|
-
async def
|
|
53
|
-
"""
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
132
|
+
async def _refresh_datasource_params(self) -> None:
|
|
133
|
+
"""
|
|
134
|
+
Background task to periodically refresh datasource parameter options.
|
|
135
|
+
Runs every N minutes as configured by SQRL_PARAMETERS__DATASOURCE_REFRESH_MINUTES (default: 60).
|
|
136
|
+
"""
|
|
137
|
+
refresh_minutes = self.env_vars.parameters_datasource_refresh_minutes
|
|
138
|
+
if refresh_minutes <= 0:
|
|
139
|
+
self.logger.info(f"The value of {c.SQRL_PARAMETERS_DATASOURCE_REFRESH_MINUTES} is: {refresh_minutes} minutes")
|
|
140
|
+
self.logger.info(f"Datasource parameter refresh is disabled since the refresh interval is not positive.")
|
|
141
|
+
return
|
|
142
|
+
|
|
143
|
+
refresh_seconds = refresh_minutes * 60
|
|
144
|
+
self.logger.info(f"Starting datasource parameter refresh background task (every {refresh_minutes} minutes)")
|
|
145
|
+
|
|
146
|
+
default_conn_name = self.env_vars.connections_default_name_used
|
|
58
147
|
while True:
|
|
59
148
|
try:
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
149
|
+
await asyncio.sleep(refresh_seconds)
|
|
150
|
+
self.logger.info("Refreshing datasource parameter options...")
|
|
151
|
+
|
|
152
|
+
# Fetch fresh dataframes from datasources in a thread pool to avoid blocking
|
|
153
|
+
loop = asyncio.get_running_loop()
|
|
154
|
+
df_dict = await loop.run_in_executor(
|
|
155
|
+
None,
|
|
156
|
+
ps.ParameterConfigsSetIO._get_df_dict_from_data_sources,
|
|
157
|
+
self.param_cfg_set,
|
|
158
|
+
default_conn_name,
|
|
159
|
+
self.seeds,
|
|
160
|
+
self.conn_set,
|
|
161
|
+
self.project._vdl_catalog_db_path
|
|
162
|
+
)
|
|
67
163
|
|
|
164
|
+
# Re-convert datasource parameters with fresh data
|
|
165
|
+
self.param_cfg_set._post_process_params(df_dict)
|
|
166
|
+
|
|
167
|
+
self.logger.info("Successfully refreshed datasource parameter options")
|
|
168
|
+
except asyncio.CancelledError:
|
|
169
|
+
self.logger.info("Datasource parameter refresh task cancelled")
|
|
170
|
+
break
|
|
68
171
|
except Exception as e:
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
await asyncio.sleep(1) # Check every second
|
|
73
|
-
|
|
74
|
-
@asynccontextmanager
|
|
75
|
-
async def _run_background_tasks(self, app: FastAPI):
|
|
76
|
-
task = asyncio.create_task(self._monitor_for_staging_file())
|
|
77
|
-
yield
|
|
78
|
-
task.cancel()
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
def _validate_request_params(self, all_request_params: Mapping, params: Mapping) -> None:
|
|
82
|
-
invalid_params = [param for param in all_request_params if param not in params]
|
|
83
|
-
if params.get("x_verify_params", False) and invalid_params:
|
|
84
|
-
raise InvalidInputError(201, f"Invalid query parameters: {', '.join(invalid_params)}")
|
|
85
|
-
|
|
172
|
+
self.logger.error(f"Error refreshing datasource parameter options: {e}", exc_info=True)
|
|
173
|
+
# Continue the loop even if there's an error
|
|
86
174
|
|
|
87
|
-
def run(self, uvicorn_args: Namespace) -> None:
|
|
88
|
-
"""
|
|
89
|
-
Runs the API server with uvicorn for CLI "squirrels run"
|
|
90
|
-
|
|
91
|
-
Arguments:
|
|
92
|
-
uvicorn_args: List of arguments to pass to uvicorn.run. Currently only supports "host" and "port"
|
|
93
|
-
"""
|
|
94
|
-
start = time.time()
|
|
95
|
-
|
|
96
|
-
squirrels_version_path = f'/api/squirrels-v{sq_major_version}'
|
|
97
|
-
project_name = u.normalize_name_for_api(self.manifest_cfg.project_variables.name)
|
|
98
|
-
project_version = f"v{self.manifest_cfg.project_variables.major_version}"
|
|
99
|
-
project_metadata_path = squirrels_version_path + f"/project/{project_name}/{project_version}"
|
|
100
|
-
|
|
101
|
-
param_fields = self.param_cfg_set.get_all_api_field_info()
|
|
102
175
|
|
|
176
|
+
def _get_tags_metadata(self) -> list[dict]:
|
|
103
177
|
tags_metadata = [
|
|
104
|
-
{
|
|
105
|
-
"name": "Authentication",
|
|
106
|
-
"description": "Submit authentication credentials, and get token for authentication",
|
|
107
|
-
},
|
|
108
|
-
{
|
|
109
|
-
"name": "User Management",
|
|
110
|
-
"description": "Manage users and their attributes",
|
|
111
|
-
},
|
|
112
178
|
{
|
|
113
179
|
"name": "Project Metadata",
|
|
114
180
|
"description": "Get information on project such as name, version, and other API endpoints",
|
|
@@ -131,59 +197,154 @@ class ApiServer:
|
|
|
131
197
|
"description": f"Get parameters or results for dashboard '{dashboard_name}'",
|
|
132
198
|
})
|
|
133
199
|
|
|
200
|
+
tags_metadata.extend([
|
|
201
|
+
{
|
|
202
|
+
"name": "Authentication",
|
|
203
|
+
"description": "Submit authentication credentials and authorize with a session cookie",
|
|
204
|
+
},
|
|
205
|
+
{
|
|
206
|
+
"name": "User Management",
|
|
207
|
+
"description": "Manage users and their attributes",
|
|
208
|
+
}
|
|
209
|
+
])
|
|
210
|
+
return tags_metadata
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def _print_banner(self, mount_path: str, host: str | None, port: int | None, is_standalone_mode: bool) -> None:
|
|
214
|
+
"""
|
|
215
|
+
Print the welcome banner with information about the running server.
|
|
216
|
+
"""
|
|
217
|
+
full_hostname = f"http://{host}:{port}" if host and port else ""
|
|
218
|
+
mount_path_stripped = mount_path.rstrip("/")
|
|
219
|
+
show_multiple_options = is_standalone_mode and mount_path_stripped != ""
|
|
220
|
+
|
|
221
|
+
banner_width = 80
|
|
222
|
+
|
|
223
|
+
print()
|
|
224
|
+
print("═" * banner_width)
|
|
225
|
+
print("👋 WELCOME TO SQUIRRELS!".center(banner_width))
|
|
226
|
+
print("═" * banner_width)
|
|
227
|
+
print()
|
|
228
|
+
print(" 🖥️ Application UI")
|
|
229
|
+
print(f" └─ Squirrels Studio: {full_hostname}{mount_path_stripped}/studio")
|
|
230
|
+
if show_multiple_options:
|
|
231
|
+
print(f" ├─ The root path also redirects to Squirrels Studio: {full_hostname}/")
|
|
232
|
+
print( " ├─ This requires an internet connection to load the JS and CSS files")
|
|
233
|
+
print( f" └─ Automatically uses mount path: {mount_path_stripped}")
|
|
234
|
+
print()
|
|
235
|
+
print(" 🔌 MCP Server URLs")
|
|
236
|
+
if show_multiple_options:
|
|
237
|
+
print(f" ├─ Option 1: {full_hostname}{mount_path_stripped}/mcp")
|
|
238
|
+
print(f" └─ Option 2: {full_hostname}/mcp")
|
|
239
|
+
else:
|
|
240
|
+
print(f" └─ Project MCP: {full_hostname}{mount_path_stripped}/mcp")
|
|
241
|
+
print()
|
|
242
|
+
print(" 📖 API Documentation (for the latest version of API contract)")
|
|
243
|
+
print(f" ├─ Swagger UI: {full_hostname}{mount_path_stripped}{c.LATEST_API_VERSION_MOUNT_PATH}/docs")
|
|
244
|
+
print(f" ├─ ReDoc UI: {full_hostname}{mount_path_stripped}{c.LATEST_API_VERSION_MOUNT_PATH}/redoc")
|
|
245
|
+
print(f" └─ OpenAPI Spec: {full_hostname}{mount_path_stripped}{c.LATEST_API_VERSION_MOUNT_PATH}/openapi.json")
|
|
246
|
+
print()
|
|
247
|
+
print(f" To explore all HTTP endpoints, see: {full_hostname}{mount_path_stripped}/docs")
|
|
248
|
+
print()
|
|
249
|
+
print("─" * banner_width)
|
|
250
|
+
print("✨ Server is running! Press CTRL+C to stop.".center(banner_width))
|
|
251
|
+
print("─" * banner_width)
|
|
252
|
+
print()
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def get_lifespan(
|
|
256
|
+
self, mount_path: str, host: str | None, port: int | None, is_standalone_mode: bool
|
|
257
|
+
) -> "_AsyncGeneratorContextManager":
|
|
258
|
+
"""
|
|
259
|
+
Get the lifespan context manager for the Squirrels project.
|
|
260
|
+
"""
|
|
261
|
+
@asynccontextmanager
|
|
262
|
+
async def lifespan(app: FastAPI | None = None):
|
|
263
|
+
"""App lifespan that includes MCP server lifecycle and background tasks."""
|
|
264
|
+
self._print_banner(mount_path, host, port, is_standalone_mode)
|
|
265
|
+
|
|
266
|
+
refresh_datasource_task = asyncio.create_task(self._refresh_datasource_params())
|
|
267
|
+
|
|
268
|
+
if self._mcp_builder:
|
|
269
|
+
async with self._mcp_builder.lifespan():
|
|
270
|
+
yield
|
|
271
|
+
else:
|
|
272
|
+
yield
|
|
273
|
+
|
|
274
|
+
refresh_datasource_task.cancel()
|
|
275
|
+
|
|
276
|
+
return lifespan
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def create_app(
|
|
280
|
+
self,
|
|
281
|
+
lifespan: "_AsyncGeneratorContextManager",
|
|
282
|
+
*,
|
|
283
|
+
mount_path: str = ""
|
|
284
|
+
) -> FastAPI:
|
|
285
|
+
"""
|
|
286
|
+
Create the FastAPI app for the Squirrels project.
|
|
287
|
+
"""
|
|
288
|
+
start = time.time()
|
|
289
|
+
|
|
290
|
+
project_name = self.manifest_cfg.project_variables.name
|
|
291
|
+
project_label = self.manifest_cfg.project_variables.label
|
|
292
|
+
|
|
293
|
+
param_fields = self.param_cfg_set.get_all_api_field_info()
|
|
294
|
+
tags_metadata = self._get_tags_metadata()
|
|
295
|
+
|
|
296
|
+
mount_path_stripped = mount_path.rstrip("/")
|
|
297
|
+
api_v0_mount_path = "/api/0"
|
|
298
|
+
|
|
134
299
|
app = FastAPI(
|
|
135
|
-
title=f"Squirrels
|
|
300
|
+
title=f"Squirrels for '{project_label}'",
|
|
301
|
+
lifespan=lifespan
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
api_v0_app = FastAPI(
|
|
305
|
+
title=f"Squirrels APIs for '{project_label}'", openapi_tags=tags_metadata,
|
|
136
306
|
description="For specifying parameter selections to dataset APIs, you can choose between using query parameters with the GET method or using request body with the POST method",
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
redoc_url=project_metadata_path+"/redoc"
|
|
307
|
+
openapi_url="/openapi.json",
|
|
308
|
+
docs_url="/docs",
|
|
309
|
+
redoc_url="/redoc"
|
|
141
310
|
)
|
|
142
311
|
|
|
143
|
-
|
|
144
|
-
headers = dict(request.scope["headers"])
|
|
145
|
-
request_id = uuid.uuid4().hex
|
|
146
|
-
headers[b"x-request-id"] = request_id.encode()
|
|
147
|
-
request.scope["headers"] = list(headers.items())
|
|
312
|
+
api_v0_app.add_middleware(SessionMiddleware, secret_key=self.env_vars.secret_key, max_age=None, same_site="none", https_only=True)
|
|
148
313
|
|
|
314
|
+
async def _log_request_run(request: Request) -> None:
|
|
149
315
|
try:
|
|
150
316
|
body = await request.json()
|
|
151
317
|
except Exception:
|
|
152
|
-
body = None
|
|
318
|
+
body = None # Non-JSON payloads may contain sensitive information, so we don't log them
|
|
319
|
+
|
|
320
|
+
partial_headers: dict[str, str] = {}
|
|
321
|
+
for header in request.headers.keys():
|
|
322
|
+
if header.startswith("x-") and header not in ["x-api-key"]:
|
|
323
|
+
partial_headers[header] = request.headers[header]
|
|
153
324
|
|
|
154
|
-
headers_dict = dict(request.headers)
|
|
155
325
|
path, params = request.url.path, dict(request.query_params)
|
|
156
326
|
path_with_params = f"{path}?{request.query_params}" if len(params) > 0 else path
|
|
157
|
-
data = {"request_method": request.method, "request_path": path, "request_params": params, "
|
|
158
|
-
info
|
|
159
|
-
self.logger.info(f'Running request: {request.method} {path_with_params}', extra={"data": data, "info": info})
|
|
160
|
-
|
|
161
|
-
def _get_request_id(request: Request) -> str:
|
|
162
|
-
return request.headers.get("x-request-id", "")
|
|
327
|
+
data = {"request_method": request.method, "request_path": path, "request_params": params, "request_body": body, "partial_headers": partial_headers}
|
|
328
|
+
self.logger.info(f'Running request: {request.method} {path_with_params}', data=data)
|
|
163
329
|
|
|
164
|
-
@
|
|
330
|
+
@api_v0_app.middleware("http")
|
|
165
331
|
async def catch_exceptions_middleware(request: Request, call_next):
|
|
332
|
+
# Generate and set request ID for this request
|
|
333
|
+
request_id = set_request_id()
|
|
334
|
+
|
|
166
335
|
buffer = io.StringIO()
|
|
167
336
|
try:
|
|
168
337
|
await _log_request_run(request)
|
|
169
|
-
|
|
338
|
+
response = await call_next(request)
|
|
170
339
|
except InvalidInputError as exc:
|
|
171
|
-
traceback.print_exc(file=buffer)
|
|
172
340
|
message = str(exc)
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
if exc.error_code == 61:
|
|
181
|
-
message = "The dataset depends on static data models that cannot be found. You may need to build the virtual data environment first."
|
|
182
|
-
status_code = status.HTTP_409_CONFLICT
|
|
183
|
-
else:
|
|
184
|
-
status_code = status.HTTP_400_BAD_REQUEST
|
|
185
|
-
response = JSONResponse(
|
|
186
|
-
status_code=status_code, content={"message": message, "blame": "API client", "error_code": exc.error_code}
|
|
341
|
+
self.logger.error(message)
|
|
342
|
+
strip_path_suffix = f"{mount_path_stripped}{api_v0_mount_path}"
|
|
343
|
+
response = invalid_input_error_to_json_response(
|
|
344
|
+
request,
|
|
345
|
+
exc,
|
|
346
|
+
oauth_resource_metadata_path="/.well-known/oauth-protected-resource",
|
|
347
|
+
strip_path_suffix=strip_path_suffix,
|
|
187
348
|
)
|
|
188
349
|
except FileExecutionError as exc:
|
|
189
350
|
traceback.print_exception(exc.error, file=buffer)
|
|
@@ -203,702 +364,189 @@ class ApiServer:
|
|
|
203
364
|
)
|
|
204
365
|
|
|
205
366
|
err_msg = buffer.getvalue()
|
|
206
|
-
|
|
207
|
-
|
|
367
|
+
if err_msg:
|
|
368
|
+
self.logger.error(err_msg)
|
|
369
|
+
|
|
370
|
+
# Add request ID to response header
|
|
371
|
+
response.headers["X-Request-ID"] = request_id
|
|
372
|
+
|
|
208
373
|
return response
|
|
209
374
|
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
375
|
+
# Configure CORS with smart credential handling
|
|
376
|
+
allowed_credential_origins = self.env_vars.auth_credential_origins
|
|
377
|
+
|
|
378
|
+
configurables_as_headers = []
|
|
379
|
+
for name in self.manifest_cfg.configurables.keys():
|
|
380
|
+
configurables_as_headers.append(f"x-config-{name}") # underscore version
|
|
381
|
+
configurables_as_headers.append(f"x-config-{u.normalize_name_for_api(name)}") # dash version
|
|
382
|
+
|
|
383
|
+
api_v0_app.add_middleware(SmartCORSMiddleware, allowed_credential_origins=allowed_credential_origins, configurables_as_headers=configurables_as_headers)
|
|
384
|
+
|
|
385
|
+
# Setup route modules for the v0 API
|
|
386
|
+
get_parameters_definition = self.project_routes.setup_routes(api_v0_app, param_fields)
|
|
387
|
+
self.data_management_routes.setup_routes(api_v0_app, param_fields)
|
|
388
|
+
self.dataset_routes.setup_routes(api_v0_app, param_fields, get_parameters_definition)
|
|
389
|
+
self.dashboard_routes.setup_routes(api_v0_app, param_fields, get_parameters_definition)
|
|
390
|
+
# self.oauth2_routes.setup_routes(api_v0_app)
|
|
391
|
+
self.auth_routes.setup_routes(api_v0_app)
|
|
392
|
+
|
|
393
|
+
app.mount(api_v0_mount_path, api_v0_app)
|
|
394
|
+
|
|
395
|
+
@app.get("/health", summary="Health check endpoint")
|
|
396
|
+
async def health() -> PlainTextResponse:
|
|
397
|
+
return PlainTextResponse(status_code=200, content="OK")
|
|
398
|
+
|
|
399
|
+
# Mount static files from the public directories if they exist
|
|
400
|
+
# This allows users to serve public-facing static assets (images, CSS, JS, etc.) with HTTP requests
|
|
401
|
+
public_dirs = ["public"]
|
|
402
|
+
for public_dir in public_dirs:
|
|
403
|
+
static_dir = Path(self.project._project_path) / "resources" / public_dir
|
|
404
|
+
if static_dir.exists() and static_dir.is_dir():
|
|
405
|
+
app.mount(f"/{public_dir}", StaticFiles(directory=str(static_dir)), name=public_dir)
|
|
406
|
+
self.logger.info(f"Mounted static files from: {str(static_dir)}")
|
|
407
|
+
|
|
408
|
+
# Build the MCP server after routes are set up
|
|
409
|
+
enforce_mcp_oauth = (
|
|
410
|
+
self.manifest_cfg.project_variables.auth_strategy == AuthStrategy.EXTERNAL
|
|
411
|
+
and self.manifest_cfg.project_variables.auth_type == AuthType.REQUIRED
|
|
213
412
|
)
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
def get_query_models_for_parameters(widget_parameters: list[str] | None):
|
|
256
|
-
predefined_params = [
|
|
257
|
-
APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dataset"),
|
|
258
|
-
APIParamFieldInfo("x_parent_param", str, description="The parameter name used for parameter updates. If not provided, then all parameters are retrieved"),
|
|
259
|
-
]
|
|
260
|
-
return _get_query_models_helper(widget_parameters, predefined_params)
|
|
261
|
-
|
|
262
|
-
def get_query_models_for_dataset(widget_parameters: list[str] | None):
|
|
263
|
-
predefined_params = [
|
|
264
|
-
APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dataset"),
|
|
265
|
-
APIParamFieldInfo("x_orientation", str, default="records", description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'"),
|
|
266
|
-
APIParamFieldInfo("x_select", list[str], examples=[[]], description="The columns to select from the dataset. All are returned if not specified"),
|
|
267
|
-
APIParamFieldInfo("x_offset", int, default=0, description="The number of rows to skip before returning data (applied after data caching)"),
|
|
268
|
-
APIParamFieldInfo("x_limit", int, default=1000, description="The maximum number of rows to return (applied after data caching and offset)"),
|
|
269
|
-
]
|
|
270
|
-
return _get_query_models_helper(widget_parameters, predefined_params)
|
|
271
|
-
|
|
272
|
-
def get_query_models_for_dashboard(widget_parameters: list[str] | None):
|
|
273
|
-
predefined_params = [
|
|
274
|
-
APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dashboard"),
|
|
275
|
-
]
|
|
276
|
-
return _get_query_models_helper(widget_parameters, predefined_params)
|
|
277
|
-
|
|
278
|
-
def get_query_models_for_querying_models():
|
|
279
|
-
predefined_params = [
|
|
280
|
-
APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid"),
|
|
281
|
-
APIParamFieldInfo("x_orientation", str, default="records", description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'"),
|
|
282
|
-
APIParamFieldInfo("x_offset", int, default=0, description="The number of rows to skip before returning data (applied after data caching)"),
|
|
283
|
-
APIParamFieldInfo("x_limit", int, default=1000, description="The maximum number of rows to return (applied after data caching and offset)"),
|
|
284
|
-
APIParamFieldInfo("x_sql_query", str, description="The SQL query to execute on the data models"),
|
|
285
|
-
]
|
|
286
|
-
return _get_query_models_helper(None, predefined_params)
|
|
287
|
-
|
|
288
|
-
def _get_section_from_request_path(request: Request, section: int) -> str:
|
|
289
|
-
url_path: str = request.scope['route'].path
|
|
290
|
-
return url_path.split('/')[section]
|
|
291
|
-
|
|
292
|
-
def get_dataset_name(request: Request, section: int) -> str:
|
|
293
|
-
dataset_raw = _get_section_from_request_path(request, section)
|
|
294
|
-
return u.normalize_name(dataset_raw)
|
|
295
|
-
|
|
296
|
-
def get_dashboard_name(request: Request, section: int) -> str:
|
|
297
|
-
dashboard_raw = _get_section_from_request_path(request, section)
|
|
298
|
-
return u.normalize_name(dashboard_raw)
|
|
299
|
-
|
|
300
|
-
expiry_mins = self.env_vars.get(c.SQRL_AUTH_TOKEN_EXPIRE_MINUTES, 30)
|
|
301
|
-
try:
|
|
302
|
-
expiry_mins = int(expiry_mins)
|
|
303
|
-
except ValueError:
|
|
304
|
-
raise ConfigurationError(f"Value for environment variable {c.SQRL_AUTH_TOKEN_EXPIRE_MINUTES} is not an integer, got: {expiry_mins}")
|
|
305
|
-
|
|
306
|
-
# Project Metadata API
|
|
307
|
-
|
|
308
|
-
@app.get(project_metadata_path, tags=["Project Metadata"], response_class=JSONResponse)
|
|
309
|
-
async def get_project_metadata(request: Request) -> arm.ProjectModel:
|
|
310
|
-
return arm.ProjectModel(
|
|
311
|
-
name=project_name,
|
|
312
|
-
version=project_version,
|
|
313
|
-
label=self.manifest_cfg.project_variables.label,
|
|
314
|
-
description=self.manifest_cfg.project_variables.description,
|
|
315
|
-
squirrels_version=__version__
|
|
413
|
+
self._mcp_builder = McpServerBuilder(
|
|
414
|
+
project_name=project_name,
|
|
415
|
+
project_label=project_label,
|
|
416
|
+
max_rows_for_ai=self.env_vars.datasets_max_rows_for_ai,
|
|
417
|
+
get_user_from_headers=self.project_routes.get_user_from_headers,
|
|
418
|
+
get_data_catalog_for_mcp=self.project_routes._get_data_catalog_for_mcp,
|
|
419
|
+
get_dataset_parameters_for_mcp=self.dataset_routes._get_dataset_parameters_for_mcp,
|
|
420
|
+
get_dataset_results_for_mcp=self.dataset_routes._get_dataset_results_for_mcp,
|
|
421
|
+
enforce_oauth_bearer=enforce_mcp_oauth,
|
|
422
|
+
oauth_resource_metadata_path="/.well-known/oauth-protected-resource",
|
|
423
|
+
www_authenticate_strip_path_suffix=f"{mount_path_stripped}/mcp",
|
|
424
|
+
)
|
|
425
|
+
self._mcp_app = self._mcp_builder.get_asgi_app()
|
|
426
|
+
|
|
427
|
+
# Mount MCP server
|
|
428
|
+
app.add_route("/mcp", self._mcp_app, methods=["GET", "POST"])
|
|
429
|
+
|
|
430
|
+
# Get API versions and other endpoints
|
|
431
|
+
@app.get("/", summary="Explore all HTTP endpoints")
|
|
432
|
+
async def explore_http_endpoints(request: Request) -> rm.ExploreEndpointsModel:
|
|
433
|
+
_, root_path = RouteBase._get_base_url_for_current_app(request)
|
|
434
|
+
return rm.ExploreEndpointsModel(
|
|
435
|
+
health_url=root_path + "/health",
|
|
436
|
+
api_versions={
|
|
437
|
+
"0": rm.APIVersionMetadataModel(
|
|
438
|
+
project_metadata_url=root_path + api_v0_mount_path + "/",
|
|
439
|
+
documentation_routes=rm.DocumentationRoutesModel(
|
|
440
|
+
swagger_url=root_path + api_v0_mount_path + "/docs",
|
|
441
|
+
redoc_url=root_path + api_v0_mount_path + "/redoc",
|
|
442
|
+
openapi_url=root_path + api_v0_mount_path + "/openapi.json"
|
|
443
|
+
)
|
|
444
|
+
)
|
|
445
|
+
},
|
|
446
|
+
documentation_routes=rm.DocumentationRoutesModel(
|
|
447
|
+
swagger_url=root_path + "/docs",
|
|
448
|
+
redoc_url=root_path + "/redoc",
|
|
449
|
+
openapi_url=root_path + "/openapi.json"
|
|
450
|
+
),
|
|
451
|
+
mcp_server_url=root_path + "/mcp",
|
|
452
|
+
studio_url=root_path + "/studio",
|
|
316
453
|
)
|
|
317
454
|
|
|
318
|
-
#
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=login_path, auto_error=False)
|
|
322
|
-
|
|
323
|
-
async def get_current_user(response: Response, token: str = Depends(oauth2_scheme)) -> BaseUser | None:
|
|
324
|
-
user = self.authenticator.get_user_from_token(token)
|
|
325
|
-
username = "" if user is None else user.username
|
|
326
|
-
response.headers["Applied-Username"] = username
|
|
327
|
-
return user
|
|
328
|
-
|
|
329
|
-
## Login API
|
|
330
|
-
@app.post(login_path, tags=["Authentication"])
|
|
331
|
-
async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()) -> arm.LoginReponse:
|
|
332
|
-
user = self.authenticator.get_user(form_data.username, form_data.password)
|
|
333
|
-
access_token, expiry = self.authenticator.create_access_token(user, expiry_minutes=expiry_mins)
|
|
334
|
-
return arm.LoginReponse(access_token=access_token, token_type="bearer", username=user.username, is_admin=user.is_admin, expiry_time=expiry)
|
|
335
|
-
|
|
336
|
-
## Change Password API
|
|
337
|
-
change_password_path = project_metadata_path + '/change-password'
|
|
338
|
-
|
|
339
|
-
class ChangePasswordRequest(BaseModel):
|
|
340
|
-
old_password: str
|
|
341
|
-
new_password: str
|
|
342
|
-
|
|
343
|
-
@app.put(change_password_path, description="Change the password for the current user", tags=["Authentication"])
|
|
344
|
-
async def change_password(request: ChangePasswordRequest, user: BaseUser | None = Depends(get_current_user)) -> None:
|
|
345
|
-
if user is None:
|
|
346
|
-
raise InvalidInputError(1, "Invalid authorization token")
|
|
347
|
-
self.authenticator.change_password(user.username, request.old_password, request.new_password)
|
|
348
|
-
|
|
349
|
-
## Token API
|
|
350
|
-
tokens_path = project_metadata_path + '/tokens'
|
|
351
|
-
|
|
352
|
-
class TokenRequestBody(BaseModel):
|
|
353
|
-
title: str | None = Field(default=None, description=f"The title of the token. If not provided, a temporary token is created (expiring in {expiry_mins} minutes) and cannot be revoked")
|
|
354
|
-
expiry_minutes: int | None = Field(
|
|
355
|
-
default=None,
|
|
356
|
-
description=f"The number of minutes the token is valid for (or indefinitely if not provided). Ignored and set to {expiry_mins} minutes if title is not provided."
|
|
357
|
-
)
|
|
455
|
+
# Add Squirrels Studio
|
|
456
|
+
templates = Jinja2Templates(directory=str(Path(__file__).parent / "_package_data" / "templates"))
|
|
358
457
|
|
|
359
|
-
@app.
|
|
360
|
-
async def
|
|
361
|
-
|
|
362
|
-
raise InvalidInputError(1, "Invalid authorization token")
|
|
458
|
+
@app.get("/studio", include_in_schema=False)
|
|
459
|
+
async def squirrels_studio(request: Request):
|
|
460
|
+
sqrl_studio_base_url = self.env_vars.studio_base_url
|
|
363
461
|
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
462
|
+
# IMPORTANT: avoid `request.url_for("explore_http_endpoints")` here.
|
|
463
|
+
# When multiple Squirrels FastAPI apps are mounted into a root app, that route name
|
|
464
|
+
# can become ambiguous and resolve to the wrong mounted app. `request.base_url`
|
|
465
|
+
# is derived from the current request scope (including `root_path`), so it always
|
|
466
|
+
# points at the correct mounted Squirrels server instance.
|
|
467
|
+
_, mount_path = RouteBase._get_base_url_for_current_app(request)
|
|
368
468
|
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
if user is None:
|
|
376
|
-
raise InvalidInputError(1, "Invalid authorization token")
|
|
377
|
-
return self.authenticator.get_all_tokens(user.username)
|
|
378
|
-
|
|
379
|
-
## Revoke Token API
|
|
380
|
-
revoke_token_path = project_metadata_path + '/tokens/{token_id}'
|
|
381
|
-
|
|
382
|
-
@app.delete(revoke_token_path, description="Revoke a token", tags=["Authentication"])
|
|
383
|
-
async def revoke_token(token_id: str, user: BaseUser | None = Depends(get_current_user)) -> None:
|
|
384
|
-
if user is None:
|
|
385
|
-
raise InvalidInputError(1, "Invalid authorization token")
|
|
386
|
-
self.authenticator.revoke_token(user.username, token_id)
|
|
387
|
-
|
|
388
|
-
## Get Authenticated User Fields From Token API
|
|
389
|
-
get_me_path = project_metadata_path + '/me'
|
|
390
|
-
|
|
391
|
-
fields_without_username = {
|
|
392
|
-
k: (v.annotation, v.default)
|
|
393
|
-
for k, v in self.authenticator.User.model_fields.items()
|
|
394
|
-
if k != "username"
|
|
395
|
-
}
|
|
396
|
-
UserModel = create_model("UserModel", __base__=BaseModel, **fields_without_username) # type: ignore
|
|
397
|
-
|
|
398
|
-
class UserWithoutUsername(UserModel):
|
|
399
|
-
pass
|
|
400
|
-
|
|
401
|
-
class UserWithUsername(UserModel):
|
|
402
|
-
username: str
|
|
469
|
+
context = {
|
|
470
|
+
"sqrl_studio_base_url": sqrl_studio_base_url,
|
|
471
|
+
"mount_path": mount_path,
|
|
472
|
+
}
|
|
473
|
+
template = templates.get_template("squirrels_studio.html")
|
|
474
|
+
return HTMLResponse(content=template.render(context))
|
|
403
475
|
|
|
404
|
-
|
|
405
|
-
|
|
476
|
+
self.logger.log_activity_time("creating app server", start)
|
|
477
|
+
return app
|
|
478
|
+
|
|
479
|
+
def get_fastapi_components(
|
|
480
|
+
self, host: str, port: int, *,
|
|
481
|
+
mount_path_format: str = "/analytics/{project_name}/v{project_version}",
|
|
482
|
+
is_standalone_mode: bool = False
|
|
483
|
+
) -> FastAPIComponents:
|
|
484
|
+
"""
|
|
485
|
+
Get the FastAPI components for the Squirrels project including mount path, lifespan, and FastAPI app.
|
|
486
|
+
"""
|
|
487
|
+
project_name = u.normalize_name_for_api(self.manifest_cfg.project_variables.name)
|
|
488
|
+
project_version = self.manifest_cfg.project_variables.major_version
|
|
489
|
+
mount_path = mount_path_format.format(project_name=project_name, project_version=project_version)
|
|
406
490
|
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
raise InvalidInputError(1, "Invalid authorization token")
|
|
411
|
-
return UserWithUsername(**user.model_dump(mode='json'))
|
|
491
|
+
lifespan = self.get_lifespan(mount_path, host, port, is_standalone_mode)
|
|
492
|
+
fastapi_app = self.create_app(lifespan, mount_path=mount_path)
|
|
493
|
+
return FastAPIComponents(mount_path=mount_path, lifespan=lifespan, fastapi_app=fastapi_app)
|
|
412
494
|
|
|
413
|
-
|
|
495
|
+
def run(self, uvicorn_args: Namespace) -> None:
|
|
496
|
+
"""
|
|
497
|
+
Runs the API server with uvicorn for CLI "squirrels run"
|
|
414
498
|
|
|
415
|
-
|
|
416
|
-
|
|
499
|
+
Arguments:
|
|
500
|
+
uvicorn_args: List of arguments to pass to uvicorn.run. Supports "host", "port", and "forwarded_allow_ips"
|
|
501
|
+
"""
|
|
502
|
+
host = uvicorn_args.host
|
|
503
|
+
port = uvicorn_args.port
|
|
504
|
+
forwarded_allow_ips = uvicorn_args.forwarded_allow_ips
|
|
505
|
+
|
|
506
|
+
server = self.get_fastapi_components(host=host, port=port, is_standalone_mode=True)
|
|
507
|
+
|
|
508
|
+
root_app = FastAPI(lifespan=server.lifespan)
|
|
509
|
+
root_app.mount(server.mount_path, server.fastapi_app)
|
|
510
|
+
|
|
511
|
+
# Enable CORS handling on the root app so preflight requests (OPTIONS)
|
|
512
|
+
# to top-level endpoints like `/.well-known/oauth-protected-resource` do not 405.
|
|
513
|
+
allowed_credential_origins = self.env_vars.auth_credential_origins
|
|
514
|
+
configurables_as_headers: list[str] = []
|
|
515
|
+
for name in self.manifest_cfg.configurables.keys():
|
|
516
|
+
configurables_as_headers.append(f"x-config-{name}") # underscore version
|
|
517
|
+
configurables_as_headers.append(f"x-config-{u.normalize_name_for_api(name)}") # dash version
|
|
518
|
+
|
|
519
|
+
root_app.add_middleware(
|
|
520
|
+
SmartCORSMiddleware,
|
|
521
|
+
allowed_credential_origins=allowed_credential_origins,
|
|
522
|
+
configurables_as_headers=configurables_as_headers,
|
|
523
|
+
)
|
|
417
524
|
|
|
418
|
-
@
|
|
419
|
-
async def
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
## Add User API
|
|
423
|
-
add_user_path = project_metadata_path + '/users'
|
|
424
|
-
|
|
425
|
-
@app.post(add_user_path, description="Add a new user by providing details for username, password, and user fields", tags=["User Management"])
|
|
426
|
-
async def add_user(
|
|
427
|
-
new_user: AddUserRequestBody, user: BaseUser | None = Depends(get_current_user)
|
|
428
|
-
) -> None:
|
|
429
|
-
if user is None or not user.is_admin:
|
|
430
|
-
raise InvalidInputError(20, "Authorized user is forbidden to add new users")
|
|
431
|
-
self.authenticator.add_user(new_user.username, new_user.model_dump(mode='json', exclude={"username"}))
|
|
432
|
-
|
|
433
|
-
## Update User API
|
|
434
|
-
update_user_path = project_metadata_path + '/users/{username}'
|
|
435
|
-
|
|
436
|
-
@app.put(update_user_path, description="Update the user of the given username given the new user details", tags=["User Management"])
|
|
437
|
-
async def update_user(
|
|
438
|
-
username: str, updated_user: UserWithoutUsername, user: BaseUser | None = Depends(get_current_user)
|
|
439
|
-
) -> None:
|
|
440
|
-
if user is None or not user.is_admin:
|
|
441
|
-
raise InvalidInputError(20, "Authorized user is forbidden to update users")
|
|
442
|
-
self.authenticator.add_user(username, updated_user.model_dump(mode='json'), update_user=True)
|
|
443
|
-
|
|
444
|
-
## List Users API
|
|
445
|
-
list_users_path = project_metadata_path + '/users'
|
|
446
|
-
|
|
447
|
-
@app.get(list_users_path, tags=["User Management"])
|
|
448
|
-
async def list_all_users() -> list[UserWithUsername]:
|
|
449
|
-
return self.authenticator.get_all_users()
|
|
450
|
-
|
|
451
|
-
## Delete User API
|
|
452
|
-
delete_user_path = project_metadata_path + '/users/{username}'
|
|
453
|
-
|
|
454
|
-
@app.delete(delete_user_path, tags=["User Management"])
|
|
455
|
-
async def delete_user(username: str, user: BaseUser | None = Depends(get_current_user)) -> None:
|
|
456
|
-
if user is None or not user.is_admin:
|
|
457
|
-
raise InvalidInputError(21, "Authorized user is forbidden to delete users")
|
|
458
|
-
if username == user.username:
|
|
459
|
-
raise InvalidInputError(22, "Cannot delete your own user")
|
|
460
|
-
self.authenticator.delete_user(username)
|
|
461
|
-
|
|
462
|
-
# Data Catalog API
|
|
463
|
-
data_catalog_path = project_metadata_path + '/data-catalog'
|
|
525
|
+
@root_app.get("/.well-known/oauth-protected-resource", tags=["Authentication"])
|
|
526
|
+
async def oauth_protected_resource(request: Request) -> rm.OAuthProtectedResourceMetadata:
|
|
527
|
+
resource = str(request.base_url).rstrip("/")
|
|
464
528
|
|
|
465
|
-
|
|
466
|
-
|
|
529
|
+
auth_servers: list[str] = []
|
|
530
|
+
for provider in self.project._auth.auth_providers:
|
|
531
|
+
auth_servers.append(provider.provider_configs.server_url)
|
|
467
532
|
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
parameters = self.param_cfg_set.apply_selections(None, {}, user)
|
|
473
|
-
parameters_model = parameters.to_api_response_model0()
|
|
474
|
-
full_parameters_list = [p.name for p in parameters_model.parameters]
|
|
475
|
-
|
|
476
|
-
dataset_items: list[arm.DatasetItemModel] = []
|
|
477
|
-
for name, config in self.manifest_cfg.datasets.items():
|
|
478
|
-
if self.authenticator.can_user_access_scope(user, config.scope):
|
|
479
|
-
name_normalized = u.normalize_name_for_api(name)
|
|
480
|
-
metadata = self.project.dataset_metadata(name).to_json()
|
|
481
|
-
parameters = config.parameters if config.parameters is not None else full_parameters_list
|
|
482
|
-
dataset_items.append(arm.DatasetItemModel(
|
|
483
|
-
name=name_normalized, label=config.label,
|
|
484
|
-
description=config.description,
|
|
485
|
-
schema=metadata["schema"], # type: ignore
|
|
486
|
-
parameters=parameters,
|
|
487
|
-
parameters_path=dataset_parameters_path.format(dataset=name_normalized),
|
|
488
|
-
result_path=dataset_results_path.format(dataset=name_normalized)
|
|
489
|
-
))
|
|
490
|
-
|
|
491
|
-
dashboard_items: list[arm.DashboardItemModel] = []
|
|
492
|
-
for name, dashboard in self.dashboards.items():
|
|
493
|
-
config = dashboard.config
|
|
494
|
-
if self.authenticator.can_user_access_scope(user, config.scope):
|
|
495
|
-
name_normalized = u.normalize_name_for_api(name)
|
|
496
|
-
|
|
497
|
-
try:
|
|
498
|
-
dashboard_format = self.dashboards[name].get_dashboard_format()
|
|
499
|
-
except KeyError:
|
|
500
|
-
raise ConfigurationError(f"No dashboard file found for: {name}")
|
|
501
|
-
|
|
502
|
-
parameters = config.parameters if config.parameters is not None else full_parameters_list
|
|
503
|
-
dashboard_items.append(arm.DashboardItemModel(
|
|
504
|
-
name=name, label=config.label,
|
|
505
|
-
description=config.description,
|
|
506
|
-
result_format=dashboard_format,
|
|
507
|
-
parameters=parameters,
|
|
508
|
-
parameters_path=dashboard_parameters_path.format(dashboard=name_normalized),
|
|
509
|
-
result_path=dashboard_results_path.format(dashboard=name_normalized)
|
|
510
|
-
))
|
|
511
|
-
|
|
512
|
-
if user and user.is_admin:
|
|
513
|
-
compiled_dag = await self.project._get_compiled_dag(user=user)
|
|
514
|
-
connections_items = self.project._get_all_connections()
|
|
515
|
-
data_models = self.project._get_all_data_models(compiled_dag)
|
|
516
|
-
lineage_items = self.project._get_all_data_lineage(compiled_dag)
|
|
517
|
-
else:
|
|
518
|
-
connections_items = []
|
|
519
|
-
data_models = []
|
|
520
|
-
lineage_items = []
|
|
521
|
-
|
|
522
|
-
return arm.CatalogModel(
|
|
523
|
-
parameters=parameters_model.parameters,
|
|
524
|
-
datasets=dataset_items,
|
|
525
|
-
dashboards=dashboard_items,
|
|
526
|
-
connections=connections_items,
|
|
527
|
-
models=data_models,
|
|
528
|
-
lineage=lineage_items,
|
|
533
|
+
return rm.OAuthProtectedResourceMetadata(
|
|
534
|
+
resource=resource,
|
|
535
|
+
authorization_servers=list(set(auth_servers)),
|
|
536
|
+
scopes_supported=["email", "profile"],
|
|
529
537
|
)
|
|
530
|
-
|
|
531
|
-
@app.get(data_catalog_path, tags=["Project Metadata"], summary="Get catalog of datasets and dashboards available for user")
|
|
532
|
-
async def get_data_catalog(request: Request, user: BaseUser | None = Depends(get_current_user)) -> arm.CatalogModel:
|
|
533
|
-
"""
|
|
534
|
-
Get catalog of datasets and dashboards available for the authenticated user.
|
|
535
|
-
|
|
536
|
-
For admin users, this endpoint will also return detailed information about all models and their lineage in the project.
|
|
537
|
-
"""
|
|
538
|
-
return await get_data_catalog0(user)
|
|
539
|
-
|
|
540
|
-
# Parameters API Helpers
|
|
541
|
-
parameters_description = "Selections of one parameter may cascade the available options in another parameter. " \
|
|
542
|
-
"For example, if the dataset has parameters for 'country' and 'city', available options for 'city' would " \
|
|
543
|
-
"depend on the selected option 'country'. If a parameter has 'trigger_refresh' as true, provide the parameter " \
|
|
544
|
-
"selection to this endpoint whenever it changes to refresh the parameter options of children parameters."
|
|
545
|
-
|
|
546
|
-
async def get_parameters_helper(
|
|
547
|
-
parameters_tuple: tuple[str, ...] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
|
|
548
|
-
user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
|
|
549
|
-
) -> ParameterSet:
|
|
550
|
-
selections_dict = dict(selections)
|
|
551
|
-
if "x_parent_param" not in selections_dict:
|
|
552
|
-
if len(selections_dict) > 1:
|
|
553
|
-
raise InvalidInputError(202, f"The parameters endpoint takes at most 1 widget parameter selection (unless x_parent_param is provided). Got {selections_dict}")
|
|
554
|
-
elif len(selections_dict) == 1:
|
|
555
|
-
parent_param = next(iter(selections_dict))
|
|
556
|
-
selections_dict["x_parent_param"] = parent_param
|
|
557
|
-
|
|
558
|
-
parent_param = selections_dict.get("x_parent_param")
|
|
559
|
-
if parent_param is not None and parent_param not in selections_dict:
|
|
560
|
-
# this condition is possible for multi-select parameters with empty selection
|
|
561
|
-
selections_dict[parent_param] = list()
|
|
562
|
-
|
|
563
|
-
if not self.authenticator.can_user_access_scope(user, entity_scope):
|
|
564
|
-
raise self.project._permission_error(user, entity_type, entity_name, entity_scope.name)
|
|
565
|
-
|
|
566
|
-
param_set = self.param_cfg_set.apply_selections(parameters_tuple, selections_dict, user, parent_param=parent_param)
|
|
567
|
-
return param_set
|
|
568
|
-
|
|
569
|
-
parameters_cache_size = int(self.env_vars.get(c.SQRL_PARAMETERS_CACHE_SIZE, 1024))
|
|
570
|
-
parameters_cache_ttl = int(self.env_vars.get(c.SQRL_PARAMETERS_CACHE_TTL_MINUTES, 60))
|
|
571
|
-
params_cache = TTLCache(maxsize=parameters_cache_size, ttl=parameters_cache_ttl*60)
|
|
572
|
-
|
|
573
|
-
async def get_parameters_cachable(
|
|
574
|
-
parameters_tuple: tuple[str, ...] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
|
|
575
|
-
user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
|
|
576
|
-
) -> ParameterSet:
|
|
577
|
-
return await do_cachable_action(params_cache, get_parameters_helper, parameters_tuple, entity_type, entity_name, entity_scope, user, selections)
|
|
578
|
-
|
|
579
|
-
async def get_parameters_definition(
|
|
580
|
-
parameters_list: list[str] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
|
|
581
|
-
user: BaseUser | None, all_request_params: dict, params: Mapping
|
|
582
|
-
) -> arm.ParametersModel:
|
|
583
|
-
self._validate_request_params(all_request_params, params)
|
|
584
|
-
|
|
585
|
-
get_parameters_function = get_parameters_helper if self.no_cache else get_parameters_cachable
|
|
586
|
-
selections = get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
|
|
587
|
-
parameters_tuple = tuple(parameters_list) if parameters_list is not None else None
|
|
588
|
-
result = await get_parameters_function(parameters_tuple, entity_type, entity_name, entity_scope, user, selections)
|
|
589
|
-
return result.to_api_response_model0()
|
|
590
|
-
|
|
591
|
-
def validate_parameters_list(parameters: list[str] | None, entity_type: str) -> None:
|
|
592
|
-
if parameters is None:
|
|
593
|
-
return
|
|
594
|
-
for param in parameters:
|
|
595
|
-
if param not in param_fields:
|
|
596
|
-
all_params = list(param_fields.keys())
|
|
597
|
-
raise ConfigurationError(
|
|
598
|
-
f"{entity_type} '{dataset_name}' use parameter '{param}' which doesn't exist. Available parameters are:"
|
|
599
|
-
f"\n {all_params}"
|
|
600
|
-
)
|
|
601
|
-
|
|
602
|
-
# Project-Level Parameters API
|
|
603
|
-
project_level_parameters_path = project_metadata_path + '/parameters'
|
|
604
|
-
|
|
605
|
-
QueryModelForGetProjectParams, QueryModelForPostProjectParams = get_query_models_for_parameters(None)
|
|
606
|
-
|
|
607
|
-
@app.get(project_level_parameters_path, tags=["Project Metadata"], description=parameters_description)
|
|
608
|
-
async def get_project_parameters(
|
|
609
|
-
request: Request, params: QueryModelForGetProjectParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
610
|
-
) -> arm.ParametersModel:
|
|
611
|
-
start = time.time()
|
|
612
|
-
result = await get_parameters_definition(
|
|
613
|
-
None, "project", "", PermissionScope.PUBLIC, user, dict(request.query_params), asdict(params)
|
|
614
|
-
)
|
|
615
|
-
self.logger.log_activity_time("GET REQUEST for PROJECT PARAMETERS", start, request_id=_get_request_id(request))
|
|
616
|
-
return result
|
|
617
|
-
|
|
618
|
-
@app.post(project_level_parameters_path, tags=["Project Metadata"], description=parameters_description)
|
|
619
|
-
async def get_project_parameters_with_post(
|
|
620
|
-
request: Request, params: QueryModelForPostProjectParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
621
|
-
) -> arm.ParametersModel:
|
|
622
|
-
start = time.time()
|
|
623
|
-
params_model: BaseModel = params
|
|
624
|
-
payload: dict = await request.json()
|
|
625
|
-
result = await get_parameters_definition(
|
|
626
|
-
None, "project", "", PermissionScope.PUBLIC, user, payload, params_model.model_dump()
|
|
627
|
-
)
|
|
628
|
-
self.logger.log_activity_time("POST REQUEST for PROJECT PARAMETERS", start, request_id=_get_request_id(request))
|
|
629
|
-
return result
|
|
630
|
-
|
|
631
|
-
# Dataset Results API Helpers
|
|
632
|
-
async def get_dataset_results_helper(
|
|
633
|
-
dataset: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
|
|
634
|
-
) -> DatasetResult:
|
|
635
|
-
return await self.project.dataset(dataset, selections=dict(selections), user=user)
|
|
636
|
-
|
|
637
|
-
dataset_results_cache_size = int(self.env_vars.get(c.SQRL_DATASETS_CACHE_SIZE, 128))
|
|
638
|
-
dataset_results_cache_ttl = int(self.env_vars.get(c.SQRL_DATASETS_CACHE_TTL_MINUTES, 60))
|
|
639
|
-
dataset_results_cache = TTLCache(maxsize=dataset_results_cache_size, ttl=dataset_results_cache_ttl*60)
|
|
640
|
-
|
|
641
|
-
async def get_dataset_results_cachable(
|
|
642
|
-
dataset: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
|
|
643
|
-
) -> DatasetResult:
|
|
644
|
-
return await do_cachable_action(dataset_results_cache, get_dataset_results_helper, dataset, user, selections)
|
|
645
|
-
|
|
646
|
-
async def get_dataset_results_definition(
|
|
647
|
-
dataset_name: str, user: BaseUser | None, all_request_params: dict, params: Mapping
|
|
648
|
-
) -> arm.DatasetResultModel:
|
|
649
|
-
self._validate_request_params(all_request_params, params)
|
|
650
|
-
|
|
651
|
-
get_dataset_function = get_dataset_results_helper if self.no_cache else get_dataset_results_cachable
|
|
652
|
-
uncached_keys = {"x_verify_params", "x_orientation", "x_select", "x_limit", "x_offset"}
|
|
653
|
-
selections = get_selections_as_immutable(params, uncached_keys)
|
|
654
|
-
result = await get_dataset_function(dataset_name, user, selections)
|
|
655
|
-
|
|
656
|
-
orientation = params.get("x_orientation", "records")
|
|
657
|
-
raw_select = params.get("x_select")
|
|
658
|
-
select = tuple(raw_select) if raw_select is not None else tuple()
|
|
659
|
-
limit = params.get("x_limit", 1000)
|
|
660
|
-
offset = params.get("x_offset", 0)
|
|
661
|
-
return arm.DatasetResultModel(**result.to_json(orientation, select, limit, offset))
|
|
662
|
-
|
|
663
|
-
# Dashboard Results API Helpers
|
|
664
|
-
async def get_dashboard_results_helper(
|
|
665
|
-
dashboard: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
|
|
666
|
-
) -> Dashboard:
|
|
667
|
-
return await self.project.dashboard(dashboard, selections=dict(selections), user=user)
|
|
668
|
-
|
|
669
|
-
dashboard_results_cache_size = int(self.env_vars.get(c.SQRL_DASHBOARDS_CACHE_SIZE, 128))
|
|
670
|
-
dashboard_results_cache_ttl = int(self.env_vars.get(c.SQRL_DASHBOARDS_CACHE_TTL_MINUTES, 60))
|
|
671
|
-
dashboard_results_cache = TTLCache(maxsize=dashboard_results_cache_size, ttl=dashboard_results_cache_ttl*60)
|
|
672
|
-
|
|
673
|
-
async def get_dashboard_results_cachable(
|
|
674
|
-
dashboard: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
|
|
675
|
-
) -> Dashboard:
|
|
676
|
-
return await do_cachable_action(dashboard_results_cache, get_dashboard_results_helper, dashboard, user, selections)
|
|
677
|
-
|
|
678
|
-
async def get_dashboard_results_definition(
|
|
679
|
-
dashboard_name: str, user: BaseUser | None, all_request_params: dict, params: Mapping
|
|
680
|
-
) -> Response:
|
|
681
|
-
self._validate_request_params(all_request_params, params)
|
|
682
|
-
|
|
683
|
-
get_dashboard_function = get_dashboard_results_helper if self.no_cache else get_dashboard_results_cachable
|
|
684
|
-
selections = get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
|
|
685
|
-
dashboard_obj = await get_dashboard_function(dashboard_name, user, selections)
|
|
686
|
-
if dashboard_obj._format == c.PNG:
|
|
687
|
-
assert isinstance(dashboard_obj._content, bytes)
|
|
688
|
-
result = Response(dashboard_obj._content, media_type="image/png")
|
|
689
|
-
elif dashboard_obj._format == c.HTML:
|
|
690
|
-
result = HTMLResponse(dashboard_obj._content)
|
|
691
|
-
else:
|
|
692
|
-
raise NotImplementedError()
|
|
693
|
-
return result
|
|
694
|
-
|
|
695
|
-
# Dataset Parameters and Results APIs
|
|
696
|
-
for dataset_name, dataset_config in self.manifest_cfg.datasets.items():
|
|
697
|
-
dataset_normalized = u.normalize_name_for_api(dataset_name)
|
|
698
|
-
curr_parameters_path = dataset_parameters_path.format(dataset=dataset_normalized)
|
|
699
|
-
curr_results_path = dataset_results_path.format(dataset=dataset_normalized)
|
|
700
|
-
|
|
701
|
-
validate_parameters_list(dataset_config.parameters, "Dataset")
|
|
702
|
-
|
|
703
|
-
QueryModelForGetParams, QueryModelForPostParams = get_query_models_for_parameters(dataset_config.parameters)
|
|
704
|
-
QueryModelForGetDataset, QueryModelForPostDataset = get_query_models_for_dataset(dataset_config.parameters)
|
|
705
|
-
|
|
706
|
-
@app.get(curr_parameters_path, tags=[f"Dataset '{dataset_name}'"], description=parameters_description, response_class=JSONResponse)
|
|
707
|
-
async def get_dataset_parameters(
|
|
708
|
-
request: Request, params: QueryModelForGetParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
709
|
-
) -> arm.ParametersModel:
|
|
710
|
-
start = time.time()
|
|
711
|
-
curr_dataset_name = get_dataset_name(request, -2)
|
|
712
|
-
parameters_list = self.manifest_cfg.datasets[curr_dataset_name].parameters
|
|
713
|
-
scope = self.manifest_cfg.datasets[curr_dataset_name].scope
|
|
714
|
-
result = await get_parameters_definition(
|
|
715
|
-
parameters_list, "dataset", curr_dataset_name, scope, user, dict(request.query_params), asdict(params)
|
|
716
|
-
)
|
|
717
|
-
self.logger.log_activity_time("GET REQUEST for PARAMETERS", start, request_id=_get_request_id(request))
|
|
718
|
-
return result
|
|
719
|
-
|
|
720
|
-
@app.post(curr_parameters_path, tags=[f"Dataset '{dataset_name}'"], description=parameters_description, response_class=JSONResponse)
|
|
721
|
-
async def get_dataset_parameters_with_post(
|
|
722
|
-
request: Request, params: QueryModelForPostParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
723
|
-
) -> arm.ParametersModel:
|
|
724
|
-
start = time.time()
|
|
725
|
-
curr_dataset_name = get_dataset_name(request, -2)
|
|
726
|
-
parameters_list = self.manifest_cfg.datasets[curr_dataset_name].parameters
|
|
727
|
-
scope = self.manifest_cfg.datasets[curr_dataset_name].scope
|
|
728
|
-
params: BaseModel = params
|
|
729
|
-
payload: dict = await request.json()
|
|
730
|
-
result = await get_parameters_definition(
|
|
731
|
-
parameters_list, "dataset", curr_dataset_name, scope, user, payload, params.model_dump()
|
|
732
|
-
)
|
|
733
|
-
self.logger.log_activity_time("POST REQUEST for PARAMETERS", start, request_id=_get_request_id(request))
|
|
734
|
-
return result
|
|
735
|
-
|
|
736
|
-
@app.get(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
|
|
737
|
-
async def get_dataset_results(
|
|
738
|
-
request: Request, params: QueryModelForGetDataset, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
739
|
-
) -> arm.DatasetResultModel:
|
|
740
|
-
start = time.time()
|
|
741
|
-
curr_dataset_name = get_dataset_name(request, -1)
|
|
742
|
-
result = await get_dataset_results_definition(curr_dataset_name, user, dict(request.query_params), asdict(params))
|
|
743
|
-
self.logger.log_activity_time("GET REQUEST for DATASET RESULTS", start, request_id=_get_request_id(request))
|
|
744
|
-
return result
|
|
745
|
-
|
|
746
|
-
@app.post(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
|
|
747
|
-
async def get_dataset_results_with_post(
|
|
748
|
-
request: Request, params: QueryModelForPostDataset, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
749
|
-
) -> arm.DatasetResultModel:
|
|
750
|
-
start = time.time()
|
|
751
|
-
curr_dataset_name = get_dataset_name(request, -1)
|
|
752
|
-
params: BaseModel = params
|
|
753
|
-
payload: dict = await request.json()
|
|
754
|
-
result = await get_dataset_results_definition(curr_dataset_name, user, payload, params.model_dump())
|
|
755
|
-
self.logger.log_activity_time("POST REQUEST for DATASET RESULTS", start, request_id=_get_request_id(request))
|
|
756
|
-
return result
|
|
757
|
-
|
|
758
|
-
# Dashboard Parameters and Results APIs
|
|
759
|
-
for dashboard_name, dashboard in self.dashboards.items():
|
|
760
|
-
dashboard_normalized = u.normalize_name_for_api(dashboard_name)
|
|
761
|
-
curr_parameters_path = dashboard_parameters_path.format(dashboard=dashboard_normalized)
|
|
762
|
-
curr_results_path = dashboard_results_path.format(dashboard=dashboard_normalized)
|
|
763
538
|
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
QueryModelForGetDash, QueryModelForPostDash = get_query_models_for_dashboard(dashboard.config.parameters)
|
|
768
|
-
|
|
769
|
-
@app.get(curr_parameters_path, tags=[f"Dashboard '{dashboard_name}'"], description=parameters_description, response_class=JSONResponse)
|
|
770
|
-
async def get_dashboard_parameters(
|
|
771
|
-
request: Request, params: QueryModelForGetParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
772
|
-
) -> arm.ParametersModel:
|
|
773
|
-
start = time.time()
|
|
774
|
-
curr_dashboard_name = get_dashboard_name(request, -2)
|
|
775
|
-
parameters_list = self.dashboards[curr_dashboard_name].config.parameters
|
|
776
|
-
scope = self.dashboards[curr_dashboard_name].config.scope
|
|
777
|
-
result = await get_parameters_definition(
|
|
778
|
-
parameters_list, "dashboard", curr_dashboard_name, scope, user, dict(request.query_params), asdict(params)
|
|
779
|
-
)
|
|
780
|
-
self.logger.log_activity_time("GET REQUEST for PARAMETERS", start, request_id=_get_request_id(request))
|
|
781
|
-
return result
|
|
782
|
-
|
|
783
|
-
@app.post(curr_parameters_path, tags=[f"Dashboard '{dashboard_name}'"], description=parameters_description, response_class=JSONResponse)
|
|
784
|
-
async def get_dashboard_parameters_with_post(
|
|
785
|
-
request: Request, params: QueryModelForPostParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
786
|
-
) -> arm.ParametersModel:
|
|
787
|
-
start = time.time()
|
|
788
|
-
curr_dashboard_name = get_dashboard_name(request, -2)
|
|
789
|
-
parameters_list = self.dashboards[curr_dashboard_name].config.parameters
|
|
790
|
-
scope = self.dashboards[curr_dashboard_name].config.scope
|
|
791
|
-
params: BaseModel = params
|
|
792
|
-
payload: dict = await request.json()
|
|
793
|
-
result = await get_parameters_definition(
|
|
794
|
-
parameters_list, "dashboard", curr_dashboard_name, scope, user, payload, params.model_dump()
|
|
795
|
-
)
|
|
796
|
-
self.logger.log_activity_time("POST REQUEST for PARAMETERS", start, request_id=_get_request_id(request))
|
|
797
|
-
return result
|
|
798
|
-
|
|
799
|
-
@app.get(curr_results_path, tags=[f"Dashboard '{dashboard_name}'"], description=dashboard.config.description, response_class=Response)
|
|
800
|
-
async def get_dashboard_results(
|
|
801
|
-
request: Request, params: QueryModelForGetDash, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
802
|
-
) -> Response:
|
|
803
|
-
start = time.time()
|
|
804
|
-
curr_dashboard_name = get_dashboard_name(request, -1)
|
|
805
|
-
result = await get_dashboard_results_definition(curr_dashboard_name, user, dict(request.query_params), asdict(params))
|
|
806
|
-
self.logger.log_activity_time("GET REQUEST for DASHBOARD RESULTS", start, request_id=_get_request_id(request))
|
|
807
|
-
return result
|
|
808
|
-
|
|
809
|
-
@app.post(curr_results_path, tags=[f"Dashboard '{dashboard_name}'"], description=dashboard.config.description, response_class=Response)
|
|
810
|
-
async def get_dashboard_results_with_post(
|
|
811
|
-
request: Request, params: QueryModelForPostDash, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
812
|
-
) -> Response:
|
|
813
|
-
start = time.time()
|
|
814
|
-
curr_dashboard_name = get_dashboard_name(request, -1)
|
|
815
|
-
params: BaseModel = params
|
|
816
|
-
payload: dict = await request.json()
|
|
817
|
-
result = await get_dashboard_results_definition(curr_dashboard_name, user, payload, params.model_dump())
|
|
818
|
-
self.logger.log_activity_time("POST REQUEST for DASHBOARD RESULTS", start, request_id=_get_request_id(request))
|
|
819
|
-
return result
|
|
820
|
-
|
|
821
|
-
# Build Project API
|
|
822
|
-
@app.post(project_metadata_path + '/build', tags=["Data Management"], summary="Build or update the virtual data environment for the project")
|
|
823
|
-
async def build(user: BaseUser | None = Depends(get_current_user)): # type: ignore
|
|
824
|
-
if not self.authenticator.can_user_access_scope(user, PermissionScope.PRIVATE):
|
|
825
|
-
raise InvalidInputError(26, f"User '{user}' does not have permission to build the virtual data environment")
|
|
826
|
-
await self.project.build(stage_file=True)
|
|
827
|
-
return Response(status_code=status.HTTP_200_OK)
|
|
828
|
-
|
|
829
|
-
# Query Models API
|
|
830
|
-
query_models_path = project_metadata_path + '/query-models'
|
|
831
|
-
QueryModelForQueryModels, QueryModelForPostQueryModels = get_query_models_for_querying_models()
|
|
832
|
-
|
|
833
|
-
async def query_models_helper(
|
|
834
|
-
sql_query: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
|
|
835
|
-
) -> DatasetResult:
|
|
836
|
-
return await self.project.query_models(sql_query, selections=dict(selections), user=user)
|
|
837
|
-
|
|
838
|
-
async def query_models_cachable(
|
|
839
|
-
sql_query: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
|
|
840
|
-
) -> DatasetResult:
|
|
841
|
-
# Share the same cache for dataset results
|
|
842
|
-
return await do_cachable_action(dataset_results_cache, query_models_helper, sql_query, user, selections)
|
|
843
|
-
|
|
844
|
-
async def query_models_definition(
|
|
845
|
-
user: BaseUser | None, all_request_params: dict, params: Mapping
|
|
846
|
-
) -> arm.DatasetResultModel:
|
|
847
|
-
self._validate_request_params(all_request_params, params)
|
|
848
|
-
|
|
849
|
-
if not self.authenticator.can_user_access_scope(user, PermissionScope.PRIVATE):
|
|
850
|
-
raise InvalidInputError(27, f"User '{user}' does not have permission to query data models")
|
|
851
|
-
sql_query = params.get("x_sql_query")
|
|
852
|
-
if sql_query is None:
|
|
853
|
-
raise InvalidInputError(203, "SQL query must be provided")
|
|
539
|
+
mount_path_stripped = server.mount_path.rstrip("/")
|
|
540
|
+
if mount_path_stripped != "":
|
|
541
|
+
root_app.add_route("/mcp", self._mcp_app, methods=["GET", "POST"])
|
|
854
542
|
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
orientation = params.get("x_orientation", "records")
|
|
861
|
-
limit = params.get("x_limit", 1000)
|
|
862
|
-
offset = params.get("x_offset", 0)
|
|
863
|
-
return arm.DatasetResultModel(**result.to_json(orientation, tuple(), limit, offset))
|
|
864
|
-
|
|
865
|
-
@app.get(query_models_path, tags=["Data Management"], response_class=JSONResponse)
|
|
866
|
-
async def query_models(
|
|
867
|
-
request: Request, params: QueryModelForQueryModels, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
868
|
-
) -> arm.DatasetResultModel:
|
|
869
|
-
start = time.time()
|
|
870
|
-
result = await query_models_definition(user, dict(request.query_params), asdict(params))
|
|
871
|
-
self.logger.log_activity_time("GET REQUEST for QUERY MODELS", start, request_id=_get_request_id(request))
|
|
872
|
-
return result
|
|
873
|
-
|
|
874
|
-
@app.post(query_models_path, tags=["Data Management"], response_class=JSONResponse)
|
|
875
|
-
async def query_models_with_post(
|
|
876
|
-
request: Request, params: QueryModelForPostQueryModels, user: BaseUser | None = Depends(get_current_user) # type: ignore
|
|
877
|
-
) -> arm.DatasetResultModel:
|
|
878
|
-
start = time.time()
|
|
879
|
-
params: BaseModel = params
|
|
880
|
-
payload: dict = await request.json()
|
|
881
|
-
result = await query_models_definition(user, payload, params.model_dump())
|
|
882
|
-
self.logger.log_activity_time("POST REQUEST for QUERY MODELS", start, request_id=_get_request_id(request))
|
|
883
|
-
return result
|
|
884
|
-
|
|
885
|
-
# Add Root Path Redirection to Squirrels Studio
|
|
886
|
-
full_hostname = f"http://{uvicorn_args.host}:{uvicorn_args.port}"
|
|
887
|
-
encoded_hostname = urllib.parse.quote(full_hostname, safe="")
|
|
888
|
-
squirrels_studio_url = f"https://squirrels-analytics.github.io/squirrels-studio/#/login?host={encoded_hostname}&projectName={project_name}&projectVersion={project_version}"
|
|
889
|
-
|
|
890
|
-
@app.get("/", include_in_schema=False)
|
|
891
|
-
async def redirect_to_studio():
|
|
892
|
-
return RedirectResponse(url=squirrels_studio_url)
|
|
893
|
-
|
|
543
|
+
@root_app.get("/", include_in_schema=False)
|
|
544
|
+
async def redirect_to_studio():
|
|
545
|
+
return RedirectResponse(url=f"{mount_path_stripped}/studio")
|
|
546
|
+
|
|
894
547
|
# Run the API Server
|
|
895
548
|
import uvicorn
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
print(f"- API Docs (with Swagger UI): {full_hostname}{project_metadata_path}/docs")
|
|
901
|
-
print()
|
|
902
|
-
|
|
903
|
-
self.logger.log_activity_time("creating app server", start)
|
|
904
|
-
uvicorn.run(app, host=uvicorn_args.host, port=uvicorn_args.port)
|
|
549
|
+
uvicorn.run(
|
|
550
|
+
root_app, host=host, port=port, proxy_headers=True, forwarded_allow_ips=forwarded_allow_ips
|
|
551
|
+
)
|
|
552
|
+
|