squirrels 0.5.0rc0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic; see the package registry's advisory page for more details.

Files changed (108)
  1. dateutils/__init__.py +6 -0
  2. dateutils/_enums.py +25 -0
  3. squirrels/dateutils.py → dateutils/_implementation.py +58 -111
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +10 -12
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +271 -0
  8. squirrels/_api_routes/base.py +171 -0
  9. squirrels/_api_routes/dashboards.py +158 -0
  10. squirrels/_api_routes/data_management.py +148 -0
  11. squirrels/_api_routes/datasets.py +265 -0
  12. squirrels/_api_routes/oauth2.py +298 -0
  13. squirrels/_api_routes/project.py +252 -0
  14. squirrels/_api_server.py +245 -781
  15. squirrels/_arguments/__init__.py +0 -0
  16. squirrels/{arguments → _arguments}/init_time_args.py +7 -2
  17. squirrels/{arguments → _arguments}/run_time_args.py +13 -35
  18. squirrels/_auth.py +720 -212
  19. squirrels/_command_line.py +81 -41
  20. squirrels/_compile_prompts.py +147 -0
  21. squirrels/_connection_set.py +16 -7
  22. squirrels/_constants.py +29 -9
  23. squirrels/{_dashboards_io.py → _dashboards.py} +87 -6
  24. squirrels/_data_sources.py +570 -0
  25. squirrels/{dataset_result.py → _dataset_types.py} +2 -4
  26. squirrels/_exceptions.py +9 -37
  27. squirrels/_initializer.py +83 -59
  28. squirrels/_logging.py +117 -0
  29. squirrels/_manifest.py +129 -62
  30. squirrels/_model_builder.py +10 -52
  31. squirrels/_model_configs.py +3 -3
  32. squirrels/_model_queries.py +1 -1
  33. squirrels/_models.py +249 -118
  34. squirrels/{package_data → _package_data}/base_project/.env +16 -4
  35. squirrels/{package_data → _package_data}/base_project/.env.example +15 -3
  36. squirrels/{package_data → _package_data}/base_project/connections.yml +4 -3
  37. squirrels/{package_data → _package_data}/base_project/dashboards/dashboard_example.py +4 -4
  38. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
  39. squirrels/{package_data → _package_data}/base_project/duckdb_init.sql +1 -0
  40. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  41. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.py +2 -2
  42. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.sql +1 -1
  43. squirrels/{package_data → _package_data}/base_project/models/builds/build_example.yml +2 -0
  44. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
  45. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
  46. squirrels/_package_data/base_project/models/federates/federate_example.py +48 -0
  47. squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
  48. squirrels/{package_data → _package_data}/base_project/models/federates/federate_example.yml +7 -7
  49. squirrels/{package_data → _package_data}/base_project/models/sources.yml +5 -6
  50. squirrels/{package_data → _package_data}/base_project/parameters.yml +32 -45
  51. squirrels/_package_data/base_project/pyconfigs/connections.py +18 -0
  52. squirrels/{package_data → _package_data}/base_project/pyconfigs/context.py +31 -22
  53. squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
  54. squirrels/_package_data/base_project/pyconfigs/user.py +44 -0
  55. squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.yml +1 -1
  56. squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.yml +1 -1
  57. squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
  58. squirrels/_package_data/templates/dataset_results.html +112 -0
  59. squirrels/_package_data/templates/oauth_login.html +271 -0
  60. squirrels/_package_data/templates/squirrels_studio.html +20 -0
  61. squirrels/_parameter_configs.py +76 -55
  62. squirrels/_parameter_options.py +348 -0
  63. squirrels/_parameter_sets.py +53 -45
  64. squirrels/_parameters.py +1664 -0
  65. squirrels/_project.py +403 -242
  66. squirrels/_py_module.py +3 -2
  67. squirrels/_request_context.py +33 -0
  68. squirrels/_schemas/__init__.py +0 -0
  69. squirrels/_schemas/auth_models.py +167 -0
  70. squirrels/_schemas/query_param_models.py +75 -0
  71. squirrels/{_api_response_models.py → _schemas/response_models.py} +48 -18
  72. squirrels/_seeds.py +1 -1
  73. squirrels/_sources.py +23 -19
  74. squirrels/_utils.py +121 -39
  75. squirrels/_version.py +1 -1
  76. squirrels/arguments.py +7 -0
  77. squirrels/auth.py +4 -0
  78. squirrels/connections.py +3 -0
  79. squirrels/dashboards.py +2 -81
  80. squirrels/data_sources.py +14 -563
  81. squirrels/parameter_options.py +13 -348
  82. squirrels/parameters.py +14 -1266
  83. squirrels/types.py +16 -0
  84. {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/METADATA +42 -30
  85. squirrels-0.5.1.dist-info/RECORD +98 -0
  86. squirrels/package_data/base_project/dashboards/dashboard_example.yml +0 -22
  87. squirrels/package_data/base_project/macros/macros_example.sql +0 -15
  88. squirrels/package_data/base_project/models/dbviews/dbview_example.sql +0 -12
  89. squirrels/package_data/base_project/models/dbviews/dbview_example.yml +0 -26
  90. squirrels/package_data/base_project/models/federates/federate_example.py +0 -44
  91. squirrels/package_data/base_project/models/federates/federate_example.sql +0 -17
  92. squirrels/package_data/base_project/pyconfigs/connections.py +0 -14
  93. squirrels/package_data/base_project/pyconfigs/parameters.py +0 -93
  94. squirrels/package_data/base_project/pyconfigs/user.py +0 -23
  95. squirrels/package_data/base_project/squirrels.yml.j2 +0 -71
  96. squirrels-0.5.0rc0.dist-info/RECORD +0 -70
  97. /squirrels/{package_data → _package_data}/base_project/assets/expenses.db +0 -0
  98. /squirrels/{package_data → _package_data}/base_project/assets/weather.db +0 -0
  99. /squirrels/{package_data → _package_data}/base_project/docker/.dockerignore +0 -0
  100. /squirrels/{package_data → _package_data}/base_project/docker/Dockerfile +0 -0
  101. /squirrels/{package_data → _package_data}/base_project/docker/compose.yml +0 -0
  102. /squirrels/{package_data/base_project/.gitignore → _package_data/base_project/gitignore} +0 -0
  103. /squirrels/{package_data → _package_data}/base_project/seeds/seed_categories.csv +0 -0
  104. /squirrels/{package_data → _package_data}/base_project/seeds/seed_subcategories.csv +0 -0
  105. /squirrels/{package_data → _package_data}/base_project/tmp/.gitignore +0 -0
  106. {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/WHEEL +0 -0
  107. {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/entry_points.txt +0 -0
  108. {squirrels-0.5.0rc0.dist-info → squirrels-0.5.1.dist-info}/licenses/LICENSE +0 -0
squirrels/_api_server.py CHANGED
@@ -1,30 +1,83 @@
1
- from typing import Coroutine, Mapping, Callable, TypeVar, Annotated, Any
2
- from dataclasses import make_dataclass, asdict
3
- from fastapi import Depends, FastAPI, Request, Response, status
4
- from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
5
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
6
- from fastapi.middleware.cors import CORSMiddleware
7
- from pydantic import create_model, BaseModel, Field
1
+ from fastapi import FastAPI, Request, status
2
+ from fastapi.responses import JSONResponse, HTMLResponse, RedirectResponse
3
+ from fastapi.security import HTTPBearer
4
+ from fastapi.templating import Jinja2Templates
5
+ from fastapi.staticfiles import StaticFiles
6
+ from starlette.middleware.base import BaseHTTPMiddleware
7
+ from starlette.responses import Response as StarletteResponse
8
8
  from contextlib import asynccontextmanager
9
- from cachetools import TTLCache
10
9
  from argparse import Namespace
11
10
  from pathlib import Path
12
- import io, time, mimetypes, traceback, uuid, asyncio, urllib.parse
11
+ from starlette.middleware.sessions import SessionMiddleware
12
+ from mcp.server.fastmcp import FastMCP
13
+ import io, time, mimetypes, traceback, uuid, asyncio, contextlib
13
14
 
14
- from . import _constants as c, _utils as u, _api_response_models as arm
15
+ from . import _constants as c, _utils as u, _parameter_sets as ps
15
16
  from ._exceptions import InvalidInputError, ConfigurationError, FileExecutionError
16
17
  from ._version import __version__, sq_major_version
17
- from ._manifest import PermissionScope
18
- from ._auth import BaseUser, AccessToken, UserField
19
- from ._parameter_sets import ParameterSet
20
- from .dashboards import Dashboard
21
18
  from ._project import SquirrelsProject
22
- from .dataset_result import DatasetResult
23
- from ._parameter_configs import APIParamFieldInfo
19
+ from ._request_context import set_request_id
20
+
21
+ # Import route modules
22
+ from ._api_routes.auth import AuthRoutes
23
+ from ._api_routes.project import ProjectRoutes
24
+ from ._api_routes.datasets import DatasetRoutes
25
+ from ._api_routes.dashboards import DashboardRoutes
26
+ from ._api_routes.data_management import DataManagementRoutes
27
+
28
+ # # Disabled for now, a 'bring your own OAuth2 server' approach will be provided in the future
29
+ # from ._api_routes.oauth2 import OAuth2Routes
24
30
 
25
31
  mimetypes.add_type('application/javascript', '.js')
26
32
 
27
33
 
34
+ class SmartCORSMiddleware(BaseHTTPMiddleware):
35
+ """
36
+ Custom CORS middleware that allows specific origins to use credentials
37
+ while still allowing all other origins without credentials.
38
+ """
39
+
40
+ def __init__(self, app, allowed_credential_origins: list[str], configurables_as_headers: list[str]):
41
+ super().__init__(app)
42
+
43
+ allowed_predefined_headers = ["Authorization", "Content-Type", "x-api-key", "x-orientation", "x-verify-params"]
44
+
45
+ self.allowed_credential_origins = allowed_credential_origins
46
+ self.allowed_request_headers = ",".join(allowed_predefined_headers + configurables_as_headers)
47
+
48
+ async def dispatch(self, request: Request, call_next):
49
+ origin = request.headers.get("origin")
50
+
51
+ # Handle preflight requests
52
+ if request.method == "OPTIONS":
53
+ response = StarletteResponse(status_code=200)
54
+ response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
55
+ response.headers["Access-Control-Allow-Headers"] = self.allowed_request_headers
56
+
57
+ else:
58
+ # Call the next middleware/route
59
+ response: StarletteResponse = await call_next(request)
60
+
61
+ # Always expose the Applied-Username header
62
+ response.headers["Access-Control-Expose-Headers"] = "Applied-Username"
63
+
64
+ if origin:
65
+ scheme = u.get_scheme(request.url.hostname)
66
+ request_origin = f"{scheme}://{request.url.netloc}"
67
+ # Check if this origin is in the whitelist or if origin matches the host origin
68
+ if origin == request_origin or origin in self.allowed_credential_origins:
69
+ response.headers["Access-Control-Allow-Origin"] = origin
70
+ response.headers["Access-Control-Allow-Credentials"] = "true"
71
+ else:
72
+ # Allow all other origins but without credentials / cookies
73
+ response.headers["Access-Control-Allow-Origin"] = "*"
74
+ else:
75
+ # No origin header (same-origin request or non-browser)
76
+ response.headers["Access-Control-Allow-Origin"] = "*"
77
+
78
+ return response
79
+
80
+
28
81
  class ApiServer:
29
82
  def __init__(self, no_cache: bool, project: SquirrelsProject) -> None:
30
83
  """
@@ -47,68 +100,83 @@ class ApiServer:
47
100
  self.param_cfg_set = project._param_cfg_set
48
101
  self.context_func = project._context_func
49
102
  self.dashboards = project._dashboards
103
+
104
+ self.mcp = FastMCP(
105
+ name="Squirrels",
106
+ stateless_http=True
107
+ )
108
+
109
+ # Initialize route modules
110
+ get_bearer_token = HTTPBearer(auto_error=False)
111
+ # self.oauth2_routes = OAuth2Routes(get_bearer_token, project, no_cache)
112
+ self.auth_routes = AuthRoutes(get_bearer_token, project, no_cache)
113
+ self.project_routes = ProjectRoutes(get_bearer_token, project, no_cache)
114
+ self.dataset_routes = DatasetRoutes(get_bearer_token, project, no_cache)
115
+ self.dashboard_routes = DashboardRoutes(get_bearer_token, project, no_cache)
116
+ self.data_management_routes = DataManagementRoutes(get_bearer_token, project, no_cache)
50
117
 
51
118
 
52
- async def _monitor_for_staging_file(self) -> None:
53
- """Background task that monitors for staging file and renames it when present"""
54
- duckdb_venv_path = self.project._duckdb_venv_path
55
- staging_file = Path(duckdb_venv_path + ".stg")
56
- target_file = Path(duckdb_venv_path)
57
-
119
+ async def _refresh_datasource_params(self) -> None:
120
+ """
121
+ Background task to periodically refresh datasource parameter options.
122
+ Runs every N minutes as configured by SQRL_PARAMETERS__DATASOURCE_REFRESH_MINUTES (default: 60).
123
+ """
124
+ refresh_minutes_str = self.env_vars.get(c.SQRL_PARAMETERS_DATASOURCE_REFRESH_MINUTES, "60")
125
+ try:
126
+ refresh_minutes = int(refresh_minutes_str)
127
+ if refresh_minutes <= 0:
128
+ self.logger.info(f"The value of {c.SQRL_PARAMETERS_DATASOURCE_REFRESH_MINUTES} is: {refresh_minutes_str} minutes")
129
+ self.logger.info(f"Datasource parameter refresh is disabled since the refresh interval is not positive.")
130
+ return
131
+ except ValueError:
132
+ self.logger.warning(f"Invalid value for {c.SQRL_PARAMETERS_DATASOURCE_REFRESH_MINUTES}: {refresh_minutes_str}. Must be an integer. Disabling datasource parameter refresh.")
133
+ return
134
+
135
+ refresh_seconds = refresh_minutes * 60
136
+ self.logger.info(f"Starting datasource parameter refresh background task (every {refresh_minutes} minutes)")
137
+
58
138
  while True:
59
139
  try:
60
- if staging_file.exists():
61
- try:
62
- staging_file.replace(target_file)
63
- self.logger.info("Successfully renamed staging database to virtual environment database")
64
- except OSError:
65
- # Silently continue if file cannot be renamed (will retry next iteration)
66
- pass
140
+ await asyncio.sleep(refresh_seconds)
141
+ self.logger.info("Refreshing datasource parameter options...")
142
+
143
+ # Fetch fresh dataframes from datasources in a thread pool to avoid blocking
144
+ loop = asyncio.get_running_loop()
145
+ default_conn_name = self.manifest_cfg.env_vars.get(c.SQRL_CONNECTIONS_DEFAULT_NAME_USED, "default")
146
+ df_dict = await loop.run_in_executor(
147
+ None,
148
+ ps.ParameterConfigsSetIO._get_df_dict_from_data_sources,
149
+ self.param_cfg_set,
150
+ default_conn_name,
151
+ self.seeds,
152
+ self.conn_set,
153
+ self.project._datalake_db_path
154
+ )
155
+
156
+ # Re-convert datasource parameters with fresh data
157
+ self.param_cfg_set._post_process_params(df_dict)
67
158
 
159
+ self.logger.info("Successfully refreshed datasource parameter options")
160
+ except asyncio.CancelledError:
161
+ self.logger.info("Datasource parameter refresh task cancelled")
162
+ break
68
163
  except Exception as e:
69
- # Log any unexpected errors but keep running
70
- self.logger.error(f"Error in monitoring {c.DUCKDB_VENV_FILE + '.stg'}: {str(e)}")
71
-
72
- await asyncio.sleep(1) # Check every second
164
+ self.logger.error(f"Error refreshing datasource parameter options: {e}", exc_info=True)
165
+ # Continue the loop even if there's an error
73
166
 
74
167
  @asynccontextmanager
75
168
  async def _run_background_tasks(self, app: FastAPI):
76
- task = asyncio.create_task(self._monitor_for_staging_file())
77
- yield
78
- task.cancel()
79
-
80
-
81
- def _validate_request_params(self, all_request_params: Mapping, params: Mapping) -> None:
82
- invalid_params = [param for param in all_request_params if param not in params]
83
- if params.get("x_verify_params", False) and invalid_params:
84
- raise InvalidInputError(201, f"Invalid query parameters: {', '.join(invalid_params)}")
169
+ refresh_datasource_task = asyncio.create_task(self._refresh_datasource_params())
85
170
 
86
-
87
- def run(self, uvicorn_args: Namespace) -> None:
88
- """
89
- Runs the API server with uvicorn for CLI "squirrels run"
90
-
91
- Arguments:
92
- uvicorn_args: List of arguments to pass to uvicorn.run. Currently only supports "host" and "port"
93
- """
94
- start = time.time()
95
-
96
- squirrels_version_path = f'/api/squirrels-v{sq_major_version}'
97
- project_name = u.normalize_name_for_api(self.manifest_cfg.project_variables.name)
98
- project_version = f"v{self.manifest_cfg.project_variables.major_version}"
99
- project_metadata_path = squirrels_version_path + f"/project/{project_name}/{project_version}"
171
+ async with contextlib.AsyncExitStack() as stack:
172
+ await stack.enter_async_context(self.mcp.session_manager.run())
173
+ yield
100
174
 
101
- param_fields = self.param_cfg_set.get_all_api_field_info()
175
+ refresh_datasource_task.cancel()
176
+
102
177
 
178
+ def _get_tags_metadata(self) -> list[dict]:
103
179
  tags_metadata = [
104
- {
105
- "name": "Authentication",
106
- "description": "Submit authentication credentials, and get token for authentication",
107
- },
108
- {
109
- "name": "User Management",
110
- "description": "Manage users and their attributes",
111
- },
112
180
  {
113
181
  "name": "Project Metadata",
114
182
  "description": "Get information on project such as name, version, and other API endpoints",
@@ -131,8 +199,45 @@ class ApiServer:
131
199
  "description": f"Get parameters or results for dashboard '{dashboard_name}'",
132
200
  })
133
201
 
202
+ tags_metadata.extend([
203
+ {
204
+ "name": "Authentication",
205
+ "description": "Submit authentication credentials and authorize with a session cookie",
206
+ },
207
+ {
208
+ "name": "User Management",
209
+ "description": "Manage users and their attributes",
210
+ },
211
+ # {
212
+ # "name": "OAuth2",
213
+ # "description": "Authorize and get token using the OAuth2 protocol",
214
+ # },
215
+ ])
216
+ return tags_metadata
217
+
218
+
219
+ def run(self, uvicorn_args: Namespace) -> None:
220
+ """
221
+ Runs the API server with uvicorn for CLI "squirrels run"
222
+
223
+ Arguments:
224
+ uvicorn_args: List of arguments to pass to uvicorn.run. Currently only supports "host" and "port"
225
+ """
226
+ start = time.time()
227
+
228
+ squirrels_version_path = f'/api/squirrels/v{sq_major_version}'
229
+ project_name = self.manifest_cfg.project_variables.name
230
+ project_name_for_api = u.normalize_name_for_api(project_name)
231
+ project_label = self.manifest_cfg.project_variables.label
232
+ project_version = f"v{self.manifest_cfg.project_variables.major_version}"
233
+ project_metadata_path = squirrels_version_path + f"/project/{project_name_for_api}/{project_version}"
234
+
235
+ param_fields = self.param_cfg_set.get_all_api_field_info()
236
+
237
+ tags_metadata = self._get_tags_metadata()
238
+
134
239
  app = FastAPI(
135
- title=f"Squirrels APIs for '{self.manifest_cfg.project_variables.label}'", openapi_tags=tags_metadata,
240
+ title=f"Squirrels APIs for '{project_label}'", openapi_tags=tags_metadata,
136
241
  description="For specifying parameter selections to dataset APIs, you can choose between using query parameters with the GET method or using request body with the POST method",
137
242
  lifespan=self._run_background_tasks,
138
243
  openapi_url=project_metadata_path+"/openapi.json",
@@ -140,765 +245,124 @@ class ApiServer:
140
245
  redoc_url=project_metadata_path+"/redoc"
141
246
  )
142
247
 
143
- async def _log_request_run(request: Request) -> None:
144
- headers = dict(request.scope["headers"])
145
- request_id = uuid.uuid4().hex
146
- headers[b"x-request-id"] = request_id.encode()
147
- request.scope["headers"] = list(headers.items())
248
+ app.add_middleware(SessionMiddleware, secret_key=self.env_vars.get(c.SQRL_SECRET_KEY, ""), max_age=None, same_site="none", https_only=True)
148
249
 
250
+ async def _log_request_run(request: Request) -> None:
149
251
  try:
150
252
  body = await request.json()
151
253
  except Exception:
152
- body = None
254
+ body = None # Non-JSON payloads may contain sensitive information, so we don't log them
255
+
256
+ partial_headers: dict[str, str] = {}
257
+ for header in request.headers.keys():
258
+ if header.startswith("x-") and header not in ["x-api-key"]:
259
+ partial_headers[header] = request.headers[header]
153
260
 
154
- headers_dict = dict(request.headers)
155
261
  path, params = request.url.path, dict(request.query_params)
156
262
  path_with_params = f"{path}?{request.query_params}" if len(params) > 0 else path
157
- data = {"request_method": request.method, "request_path": path, "request_params": params, "request_headers": headers_dict, "request_body": body}
158
- info = {"request_id": request_id}
159
- self.logger.info(f'Running request: {request.method} {path_with_params}', extra={"data": data, "info": info})
160
-
161
- def _get_request_id(request: Request) -> str:
162
- return request.headers.get("x-request-id", "")
263
+ data = {"request_method": request.method, "request_path": path, "request_params": params, "request_body": body, "partial_headers": partial_headers}
264
+ self.logger.info(f'Running request: {request.method} {path_with_params}', data=data)
163
265
 
164
266
  @app.middleware("http")
165
267
  async def catch_exceptions_middleware(request: Request, call_next):
268
+ # Generate and set request ID for this request
269
+ request_id = set_request_id()
270
+
166
271
  buffer = io.StringIO()
167
272
  try:
168
273
  await _log_request_run(request)
169
- return await call_next(request)
274
+ response = await call_next(request)
170
275
  except InvalidInputError as exc:
171
- traceback.print_exc(file=buffer)
172
276
  message = str(exc)
173
- if exc.error_code < 20:
174
- status_code = status.HTTP_401_UNAUTHORIZED
175
- elif exc.error_code < 40:
176
- status_code = status.HTTP_403_FORBIDDEN
177
- elif exc.error_code < 60:
178
- status_code = status.HTTP_404_NOT_FOUND
179
- elif exc.error_code < 70:
180
- if exc.error_code == 61:
181
- message = "The dataset depends on static data models that cannot be found. You may need to build the virtual data environment first."
182
- status_code = status.HTTP_409_CONFLICT
183
- else:
184
- status_code = status.HTTP_400_BAD_REQUEST
277
+ self.logger.error(message)
185
278
  response = JSONResponse(
186
- status_code=status_code, content={"message": message, "blame": "API client", "error_code": exc.error_code}
279
+ status_code=exc.status_code, content={"error": exc.error, "error_description": exc.error_description}
187
280
  )
188
281
  except FileExecutionError as exc:
189
282
  traceback.print_exception(exc.error, file=buffer)
190
283
  buffer.write(str(exc))
191
284
  response = JSONResponse(
192
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": f"An unexpected error occurred", "blame": "Squirrels project"}
285
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": f"An unexpected server error occurred", "blame": "Squirrels project"}
193
286
  )
194
287
  except ConfigurationError as exc:
195
288
  traceback.print_exc(file=buffer)
196
289
  response = JSONResponse(
197
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": f"An unexpected error occurred", "blame": "Squirrels project"}
290
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": f"An unexpected server error occurred", "blame": "Squirrels project"}
198
291
  )
199
292
  except Exception as exc:
200
293
  traceback.print_exc(file=buffer)
201
294
  response = JSONResponse(
202
- status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": f"An unexpected error occurred", "blame": "Squirrels framework"}
295
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": f"An unexpected server error occurred", "blame": "Squirrels framework"}
203
296
  )
204
297
 
205
298
  err_msg = buffer.getvalue()
206
- self.logger.error(err_msg)
207
- print(err_msg)
208
- return response
209
-
210
- app.add_middleware(
211
- CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
212
- expose_headers=["Applied-Username"]
213
- )
214
-
215
- # Helpers
216
- T = TypeVar('T')
217
-
218
- def get_selections_as_immutable(params: Mapping, uncached_keys: set[str]) -> tuple[tuple[str, Any], ...]:
219
- # Changing selections into a cachable "tuple of pairs" that will later be converted to dictionary
220
- selections = list()
221
- for key, val in params.items():
222
- if key in uncached_keys or val is None:
223
- continue
224
- if isinstance(val, (list, tuple)):
225
- if len(val) == 1: # for backward compatibility
226
- val = val[0]
227
- else:
228
- val = tuple(val)
229
- selections.append((u.normalize_name(key), val))
230
- return tuple(selections)
231
-
232
- async def do_cachable_action(cache: TTLCache, action: Callable[..., Coroutine[Any, Any, T]], *args) -> T:
233
- cache_key = tuple(args)
234
- result = cache.get(cache_key)
235
- if result is None:
236
- result = await action(*args)
237
- cache[cache_key] = result
238
- return result
239
-
240
- def _get_query_models_helper(widget_parameters: list[str] | None, predefined_params: list[APIParamFieldInfo]):
241
- if widget_parameters is None:
242
- widget_parameters = list(param_fields.keys())
299
+ if err_msg:
300
+ self.logger.error(err_msg)
243
301
 
244
- QueryModelForGetRaw = make_dataclass("QueryParams", [
245
- param_fields[param].as_query_info() for param in widget_parameters
246
- ] + [param.as_query_info() for param in predefined_params])
247
- QueryModelForGet = Annotated[QueryModelForGetRaw, Depends()]
248
-
249
- field_definitions = {param: param_fields[param].as_body_info() for param in widget_parameters}
250
- for param in predefined_params:
251
- field_definitions[param.name] = param.as_body_info()
252
- QueryModelForPost = create_model("RequestBodyParams", **field_definitions) # type: ignore
253
- return QueryModelForGet, QueryModelForPost
254
-
255
- def get_query_models_for_parameters(widget_parameters: list[str] | None):
256
- predefined_params = [
257
- APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dataset"),
258
- APIParamFieldInfo("x_parent_param", str, description="The parameter name used for parameter updates. If not provided, then all parameters are retrieved"),
259
- ]
260
- return _get_query_models_helper(widget_parameters, predefined_params)
261
-
262
- def get_query_models_for_dataset(widget_parameters: list[str] | None):
263
- predefined_params = [
264
- APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dataset"),
265
- APIParamFieldInfo("x_orientation", str, default="records", description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'"),
266
- APIParamFieldInfo("x_select", list[str], examples=[[]], description="The columns to select from the dataset. All are returned if not specified"),
267
- APIParamFieldInfo("x_offset", int, default=0, description="The number of rows to skip before returning data (applied after data caching)"),
268
- APIParamFieldInfo("x_limit", int, default=1000, description="The maximum number of rows to return (applied after data caching and offset)"),
269
- ]
270
- return _get_query_models_helper(widget_parameters, predefined_params)
271
-
272
- def get_query_models_for_dashboard(widget_parameters: list[str] | None):
273
- predefined_params = [
274
- APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid for the dashboard"),
275
- ]
276
- return _get_query_models_helper(widget_parameters, predefined_params)
277
-
278
- def get_query_models_for_querying_models():
279
- predefined_params = [
280
- APIParamFieldInfo("x_verify_params", bool, default=False, description="If true, the query parameters are verified to be valid"),
281
- APIParamFieldInfo("x_orientation", str, default="records", description="The orientation of the data to return, one of: 'records', 'rows', or 'columns'"),
282
- APIParamFieldInfo("x_offset", int, default=0, description="The number of rows to skip before returning data (applied after data caching)"),
283
- APIParamFieldInfo("x_limit", int, default=1000, description="The maximum number of rows to return (applied after data caching and offset)"),
284
- APIParamFieldInfo("x_sql_query", str, description="The SQL query to execute on the data models"),
285
- ]
286
- return _get_query_models_helper(None, predefined_params)
287
-
288
- def _get_section_from_request_path(request: Request, section: int) -> str:
289
- url_path: str = request.scope['route'].path
290
- return url_path.split('/')[section]
291
-
292
- def get_dataset_name(request: Request, section: int) -> str:
293
- dataset_raw = _get_section_from_request_path(request, section)
294
- return u.normalize_name(dataset_raw)
295
-
296
- def get_dashboard_name(request: Request, section: int) -> str:
297
- dashboard_raw = _get_section_from_request_path(request, section)
298
- return u.normalize_name(dashboard_raw)
299
-
300
- expiry_mins = self.env_vars.get(c.SQRL_AUTH_TOKEN_EXPIRE_MINUTES, 30)
301
- try:
302
- expiry_mins = int(expiry_mins)
303
- except ValueError:
304
- raise ConfigurationError(f"Value for environment variable {c.SQRL_AUTH_TOKEN_EXPIRE_MINUTES} is not an integer, got: {expiry_mins}")
305
-
306
- # Project Metadata API
307
-
308
- @app.get(project_metadata_path, tags=["Project Metadata"], response_class=JSONResponse)
309
- async def get_project_metadata(request: Request) -> arm.ProjectModel:
310
- return arm.ProjectModel(
311
- name=project_name,
312
- version=project_version,
313
- label=self.manifest_cfg.project_variables.label,
314
- description=self.manifest_cfg.project_variables.description,
315
- squirrels_version=__version__
316
- )
317
-
318
- # Authentication
319
- login_path = project_metadata_path + '/login'
320
-
321
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl=login_path, auto_error=False)
322
-
323
- async def get_current_user(response: Response, token: str = Depends(oauth2_scheme)) -> BaseUser | None:
324
- user = self.authenticator.get_user_from_token(token)
325
- username = "" if user is None else user.username
326
- response.headers["Applied-Username"] = username
327
- return user
328
-
329
- ## Login API
330
- @app.post(login_path, tags=["Authentication"])
331
- async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()) -> arm.LoginReponse:
332
- user = self.authenticator.get_user(form_data.username, form_data.password)
333
- access_token, expiry = self.authenticator.create_access_token(user, expiry_minutes=expiry_mins)
334
- return arm.LoginReponse(access_token=access_token, token_type="bearer", username=user.username, is_admin=user.is_admin, expiry_time=expiry)
335
-
336
- ## Change Password API
337
- change_password_path = project_metadata_path + '/change-password'
338
-
339
- class ChangePasswordRequest(BaseModel):
340
- old_password: str
341
- new_password: str
342
-
343
- @app.put(change_password_path, description="Change the password for the current user", tags=["Authentication"])
344
- async def change_password(request: ChangePasswordRequest, user: BaseUser | None = Depends(get_current_user)) -> None:
345
- if user is None:
346
- raise InvalidInputError(1, "Invalid authorization token")
347
- self.authenticator.change_password(user.username, request.old_password, request.new_password)
348
-
349
- ## Token API
350
- tokens_path = project_metadata_path + '/tokens'
351
-
352
- class TokenRequestBody(BaseModel):
353
- title: str | None = Field(default=None, description=f"The title of the token. If not provided, a temporary token is created (expiring in {expiry_mins} minutes) and cannot be revoked")
354
- expiry_minutes: int | None = Field(
355
- default=None,
356
- description=f"The number of minutes the token is valid for (or indefinitely if not provided). Ignored and set to {expiry_mins} minutes if title is not provided."
357
- )
358
-
359
- @app.post(tokens_path, description="Create a new token for the user", tags=["Authentication"])
360
- async def create_token(body: TokenRequestBody, user: BaseUser | None = Depends(get_current_user)) -> arm.LoginReponse:
361
- if user is None:
362
- raise InvalidInputError(1, "Invalid authorization token")
363
-
364
- if body.title is None:
365
- expiry_minutes = expiry_mins
366
- else:
367
- expiry_minutes = body.expiry_minutes
368
-
369
- access_token, expiry = self.authenticator.create_access_token(user, expiry_minutes=expiry_minutes, title=body.title)
370
- return arm.LoginReponse(access_token=access_token, token_type="bearer", username=user.username, is_admin=user.is_admin, expiry_time=expiry)
371
-
372
- ## Get All Tokens API
373
- @app.get(tokens_path, description="Get all tokens with title for the current user", tags=["Authentication"])
374
- async def get_all_tokens(user: BaseUser | None = Depends(get_current_user)) -> list[AccessToken]:
375
- if user is None:
376
- raise InvalidInputError(1, "Invalid authorization token")
377
- return self.authenticator.get_all_tokens(user.username)
378
-
379
- ## Revoke Token API
380
- revoke_token_path = project_metadata_path + '/tokens/{token_id}'
381
-
382
- @app.delete(revoke_token_path, description="Revoke a token", tags=["Authentication"])
383
- async def revoke_token(token_id: str, user: BaseUser | None = Depends(get_current_user)) -> None:
384
- if user is None:
385
- raise InvalidInputError(1, "Invalid authorization token")
386
- self.authenticator.revoke_token(user.username, token_id)
387
-
388
- ## Get Authenticated User Fields From Token API
389
- get_me_path = project_metadata_path + '/me'
390
-
391
- fields_without_username = {
392
- k: (v.annotation, v.default)
393
- for k, v in self.authenticator.User.model_fields.items()
394
- if k != "username"
395
- }
396
- UserModel = create_model("UserModel", __base__=BaseModel, **fields_without_username) # type: ignore
397
-
398
- class UserWithoutUsername(UserModel):
399
- pass
400
-
401
- class UserWithUsername(UserModel):
402
- username: str
403
-
404
- class AddUserRequestBody(UserWithUsername):
405
- password: str
406
-
407
- @app.get(get_me_path, description="Get the authenticated user's fields", tags=["Authentication"])
408
- async def get_me(user: BaseUser | None = Depends(get_current_user)) -> UserWithUsername:
409
- if user is None:
410
- raise InvalidInputError(1, "Invalid authorization token")
411
- return UserWithUsername(**user.model_dump(mode='json'))
412
-
413
- # User Management
414
-
415
- ## User Fields API
416
- user_fields_path = project_metadata_path + '/user-fields'
417
-
418
- @app.get(user_fields_path, description="Get details of the user fields", tags=["User Management"])
419
- async def get_user_fields() -> list[UserField]:
420
- return self.authenticator.user_fields
421
-
422
- ## Add User API
423
- add_user_path = project_metadata_path + '/users'
424
-
425
- @app.post(add_user_path, description="Add a new user by providing details for username, password, and user fields", tags=["User Management"])
426
- async def add_user(
427
- new_user: AddUserRequestBody, user: BaseUser | None = Depends(get_current_user)
428
- ) -> None:
429
- if user is None or not user.is_admin:
430
- raise InvalidInputError(20, "Authorized user is forbidden to add new users")
431
- self.authenticator.add_user(new_user.username, new_user.model_dump(mode='json', exclude={"username"}))
432
-
433
- ## Update User API
434
- update_user_path = project_metadata_path + '/users/{username}'
435
-
436
- @app.put(update_user_path, description="Update the user of the given username given the new user details", tags=["User Management"])
437
- async def update_user(
438
- username: str, updated_user: UserWithoutUsername, user: BaseUser | None = Depends(get_current_user)
439
- ) -> None:
440
- if user is None or not user.is_admin:
441
- raise InvalidInputError(20, "Authorized user is forbidden to update users")
442
- self.authenticator.add_user(username, updated_user.model_dump(mode='json'), update_user=True)
443
-
444
- ## List Users API
445
- list_users_path = project_metadata_path + '/users'
446
-
447
- @app.get(list_users_path, tags=["User Management"])
448
- async def list_all_users() -> list[UserWithUsername]:
449
- return self.authenticator.get_all_users()
450
-
451
- ## Delete User API
452
- delete_user_path = project_metadata_path + '/users/{username}'
453
-
454
- @app.delete(delete_user_path, tags=["User Management"])
455
- async def delete_user(username: str, user: BaseUser | None = Depends(get_current_user)) -> None:
456
- if user is None or not user.is_admin:
457
- raise InvalidInputError(21, "Authorized user is forbidden to delete users")
458
- if username == user.username:
459
- raise InvalidInputError(22, "Cannot delete your own user")
460
- self.authenticator.delete_user(username)
461
-
462
- # Data Catalog API
463
- data_catalog_path = project_metadata_path + '/data-catalog'
464
-
465
- dataset_results_path = project_metadata_path + '/dataset/{dataset}'
466
- dataset_parameters_path = dataset_results_path + '/parameters'
467
-
468
- dashboard_results_path = project_metadata_path + '/dashboard/{dashboard}'
469
- dashboard_parameters_path = dashboard_results_path + '/parameters'
470
-
471
- async def get_data_catalog0(user: BaseUser | None) -> arm.CatalogModel:
472
- parameters = self.param_cfg_set.apply_selections(None, {}, user)
473
- parameters_model = parameters.to_api_response_model0()
474
- full_parameters_list = [p.name for p in parameters_model.parameters]
475
-
476
- dataset_items: list[arm.DatasetItemModel] = []
477
- for name, config in self.manifest_cfg.datasets.items():
478
- if self.authenticator.can_user_access_scope(user, config.scope):
479
- name_normalized = u.normalize_name_for_api(name)
480
- metadata = self.project.dataset_metadata(name).to_json()
481
- parameters = config.parameters if config.parameters is not None else full_parameters_list
482
- dataset_items.append(arm.DatasetItemModel(
483
- name=name_normalized, label=config.label,
484
- description=config.description,
485
- schema=metadata["schema"], # type: ignore
486
- parameters=parameters,
487
- parameters_path=dataset_parameters_path.format(dataset=name_normalized),
488
- result_path=dataset_results_path.format(dataset=name_normalized)
489
- ))
490
-
491
- dashboard_items: list[arm.DashboardItemModel] = []
492
- for name, dashboard in self.dashboards.items():
493
- config = dashboard.config
494
- if self.authenticator.can_user_access_scope(user, config.scope):
495
- name_normalized = u.normalize_name_for_api(name)
496
-
497
- try:
498
- dashboard_format = self.dashboards[name].get_dashboard_format()
499
- except KeyError:
500
- raise ConfigurationError(f"No dashboard file found for: {name}")
501
-
502
- parameters = config.parameters if config.parameters is not None else full_parameters_list
503
- dashboard_items.append(arm.DashboardItemModel(
504
- name=name, label=config.label,
505
- description=config.description,
506
- result_format=dashboard_format,
507
- parameters=parameters,
508
- parameters_path=dashboard_parameters_path.format(dashboard=name_normalized),
509
- result_path=dashboard_results_path.format(dashboard=name_normalized)
510
- ))
511
-
512
- if user and user.is_admin:
513
- compiled_dag = await self.project._get_compiled_dag(user=user)
514
- connections_items = self.project._get_all_connections()
515
- data_models = self.project._get_all_data_models(compiled_dag)
516
- lineage_items = self.project._get_all_data_lineage(compiled_dag)
517
- else:
518
- connections_items = []
519
- data_models = []
520
- lineage_items = []
521
-
522
- return arm.CatalogModel(
523
- parameters=parameters_model.parameters,
524
- datasets=dataset_items,
525
- dashboards=dashboard_items,
526
- connections=connections_items,
527
- models=data_models,
528
- lineage=lineage_items,
529
- )
530
-
531
- @app.get(data_catalog_path, tags=["Project Metadata"], summary="Get catalog of datasets and dashboards available for user")
532
- async def get_data_catalog(request: Request, user: BaseUser | None = Depends(get_current_user)) -> arm.CatalogModel:
533
- """
534
- Get catalog of datasets and dashboards available for the authenticated user.
535
-
536
- For admin users, this endpoint will also return detailed information about all models and their lineage in the project.
537
- """
538
- return await get_data_catalog0(user)
539
-
540
- # Parameters API Helpers
541
- parameters_description = "Selections of one parameter may cascade the available options in another parameter. " \
542
- "For example, if the dataset has parameters for 'country' and 'city', available options for 'city' would " \
543
- "depend on the selected option 'country'. If a parameter has 'trigger_refresh' as true, provide the parameter " \
544
- "selection to this endpoint whenever it changes to refresh the parameter options of children parameters."
545
-
546
- async def get_parameters_helper(
547
- parameters_tuple: tuple[str, ...] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
548
- user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
549
- ) -> ParameterSet:
550
- selections_dict = dict(selections)
551
- if "x_parent_param" not in selections_dict:
552
- if len(selections_dict) > 1:
553
- raise InvalidInputError(202, f"The parameters endpoint takes at most 1 widget parameter selection (unless x_parent_param is provided). Got {selections_dict}")
554
- elif len(selections_dict) == 1:
555
- parent_param = next(iter(selections_dict))
556
- selections_dict["x_parent_param"] = parent_param
557
-
558
- parent_param = selections_dict.get("x_parent_param")
559
- if parent_param is not None and parent_param not in selections_dict:
560
- # this condition is possible for multi-select parameters with empty selection
561
- selections_dict[parent_param] = list()
302
+ # Add request ID to response header
303
+ response.headers["X-Request-ID"] = request_id
562
304
 
563
- if not self.authenticator.can_user_access_scope(user, entity_scope):
564
- raise self.project._permission_error(user, entity_type, entity_name, entity_scope.name)
565
-
566
- param_set = self.param_cfg_set.apply_selections(parameters_tuple, selections_dict, user, parent_param=parent_param)
567
- return param_set
568
-
569
- parameters_cache_size = int(self.env_vars.get(c.SQRL_PARAMETERS_CACHE_SIZE, 1024))
570
- parameters_cache_ttl = int(self.env_vars.get(c.SQRL_PARAMETERS_CACHE_TTL_MINUTES, 60))
571
- params_cache = TTLCache(maxsize=parameters_cache_size, ttl=parameters_cache_ttl*60)
572
-
573
- async def get_parameters_cachable(
574
- parameters_tuple: tuple[str, ...] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
575
- user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
576
- ) -> ParameterSet:
577
- return await do_cachable_action(params_cache, get_parameters_helper, parameters_tuple, entity_type, entity_name, entity_scope, user, selections)
578
-
579
- async def get_parameters_definition(
580
- parameters_list: list[str] | None, entity_type: str, entity_name: str, entity_scope: PermissionScope,
581
- user: BaseUser | None, all_request_params: dict, params: Mapping
582
- ) -> arm.ParametersModel:
583
- self._validate_request_params(all_request_params, params)
584
-
585
- get_parameters_function = get_parameters_helper if self.no_cache else get_parameters_cachable
586
- selections = get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
587
- parameters_tuple = tuple(parameters_list) if parameters_list is not None else None
588
- result = await get_parameters_function(parameters_tuple, entity_type, entity_name, entity_scope, user, selections)
589
- return result.to_api_response_model0()
590
-
591
- def validate_parameters_list(parameters: list[str] | None, entity_type: str) -> None:
592
- if parameters is None:
593
- return
594
- for param in parameters:
595
- if param not in param_fields:
596
- all_params = list(param_fields.keys())
597
- raise ConfigurationError(
598
- f"{entity_type} '{dataset_name}' use parameter '{param}' which doesn't exist. Available parameters are:"
599
- f"\n {all_params}"
600
- )
601
-
602
- # Project-Level Parameters API
603
- project_level_parameters_path = project_metadata_path + '/parameters'
604
-
605
- QueryModelForGetProjectParams, QueryModelForPostProjectParams = get_query_models_for_parameters(None)
606
-
607
- @app.get(project_level_parameters_path, tags=["Project Metadata"], description=parameters_description)
608
- async def get_project_parameters(
609
- request: Request, params: QueryModelForGetProjectParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
610
- ) -> arm.ParametersModel:
611
- start = time.time()
612
- result = await get_parameters_definition(
613
- None, "project", "", PermissionScope.PUBLIC, user, dict(request.query_params), asdict(params)
614
- )
615
- self.logger.log_activity_time("GET REQUEST for PROJECT PARAMETERS", start, request_id=_get_request_id(request))
616
- return result
617
-
618
- @app.post(project_level_parameters_path, tags=["Project Metadata"], description=parameters_description)
619
- async def get_project_parameters_with_post(
620
- request: Request, params: QueryModelForPostProjectParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
621
- ) -> arm.ParametersModel:
622
- start = time.time()
623
- params_model: BaseModel = params
624
- payload: dict = await request.json()
625
- result = await get_parameters_definition(
626
- None, "project", "", PermissionScope.PUBLIC, user, payload, params_model.model_dump()
627
- )
628
- self.logger.log_activity_time("POST REQUEST for PROJECT PARAMETERS", start, request_id=_get_request_id(request))
629
- return result
630
-
631
- # Dataset Results API Helpers
632
- async def get_dataset_results_helper(
633
- dataset: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
634
- ) -> DatasetResult:
635
- return await self.project.dataset(dataset, selections=dict(selections), user=user)
636
-
637
- dataset_results_cache_size = int(self.env_vars.get(c.SQRL_DATASETS_CACHE_SIZE, 128))
638
- dataset_results_cache_ttl = int(self.env_vars.get(c.SQRL_DATASETS_CACHE_TTL_MINUTES, 60))
639
- dataset_results_cache = TTLCache(maxsize=dataset_results_cache_size, ttl=dataset_results_cache_ttl*60)
640
-
641
- async def get_dataset_results_cachable(
642
- dataset: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
643
- ) -> DatasetResult:
644
- return await do_cachable_action(dataset_results_cache, get_dataset_results_helper, dataset, user, selections)
645
-
646
- async def get_dataset_results_definition(
647
- dataset_name: str, user: BaseUser | None, all_request_params: dict, params: Mapping
648
- ) -> arm.DatasetResultModel:
649
- self._validate_request_params(all_request_params, params)
650
-
651
- get_dataset_function = get_dataset_results_helper if self.no_cache else get_dataset_results_cachable
652
- uncached_keys = {"x_verify_params", "x_orientation", "x_select", "x_limit", "x_offset"}
653
- selections = get_selections_as_immutable(params, uncached_keys)
654
- result = await get_dataset_function(dataset_name, user, selections)
655
-
656
- orientation = params.get("x_orientation", "records")
657
- raw_select = params.get("x_select")
658
- select = tuple(raw_select) if raw_select is not None else tuple()
659
- limit = params.get("x_limit", 1000)
660
- offset = params.get("x_offset", 0)
661
- return arm.DatasetResultModel(**result.to_json(orientation, select, limit, offset))
662
-
663
- # Dashboard Results API Helpers
664
- async def get_dashboard_results_helper(
665
- dashboard: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
666
- ) -> Dashboard:
667
- return await self.project.dashboard(dashboard, selections=dict(selections), user=user)
668
-
669
- dashboard_results_cache_size = int(self.env_vars.get(c.SQRL_DASHBOARDS_CACHE_SIZE, 128))
670
- dashboard_results_cache_ttl = int(self.env_vars.get(c.SQRL_DASHBOARDS_CACHE_TTL_MINUTES, 60))
671
- dashboard_results_cache = TTLCache(maxsize=dashboard_results_cache_size, ttl=dashboard_results_cache_ttl*60)
672
-
673
- async def get_dashboard_results_cachable(
674
- dashboard: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
675
- ) -> Dashboard:
676
- return await do_cachable_action(dashboard_results_cache, get_dashboard_results_helper, dashboard, user, selections)
677
-
678
- async def get_dashboard_results_definition(
679
- dashboard_name: str, user: BaseUser | None, all_request_params: dict, params: Mapping
680
- ) -> Response:
681
- self._validate_request_params(all_request_params, params)
682
-
683
- get_dashboard_function = get_dashboard_results_helper if self.no_cache else get_dashboard_results_cachable
684
- selections = get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
685
- dashboard_obj = await get_dashboard_function(dashboard_name, user, selections)
686
- if dashboard_obj._format == c.PNG:
687
- assert isinstance(dashboard_obj._content, bytes)
688
- result = Response(dashboard_obj._content, media_type="image/png")
689
- elif dashboard_obj._format == c.HTML:
690
- result = HTMLResponse(dashboard_obj._content)
691
- else:
692
- raise NotImplementedError()
693
- return result
694
-
695
- # Dataset Parameters and Results APIs
696
- for dataset_name, dataset_config in self.manifest_cfg.datasets.items():
697
- dataset_normalized = u.normalize_name_for_api(dataset_name)
698
- curr_parameters_path = dataset_parameters_path.format(dataset=dataset_normalized)
699
- curr_results_path = dataset_results_path.format(dataset=dataset_normalized)
700
-
701
- validate_parameters_list(dataset_config.parameters, "Dataset")
702
-
703
- QueryModelForGetParams, QueryModelForPostParams = get_query_models_for_parameters(dataset_config.parameters)
704
- QueryModelForGetDataset, QueryModelForPostDataset = get_query_models_for_dataset(dataset_config.parameters)
705
-
706
- @app.get(curr_parameters_path, tags=[f"Dataset '{dataset_name}'"], description=parameters_description, response_class=JSONResponse)
707
- async def get_dataset_parameters(
708
- request: Request, params: QueryModelForGetParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
709
- ) -> arm.ParametersModel:
710
- start = time.time()
711
- curr_dataset_name = get_dataset_name(request, -2)
712
- parameters_list = self.manifest_cfg.datasets[curr_dataset_name].parameters
713
- scope = self.manifest_cfg.datasets[curr_dataset_name].scope
714
- result = await get_parameters_definition(
715
- parameters_list, "dataset", curr_dataset_name, scope, user, dict(request.query_params), asdict(params)
716
- )
717
- self.logger.log_activity_time("GET REQUEST for PARAMETERS", start, request_id=_get_request_id(request))
718
- return result
719
-
720
- @app.post(curr_parameters_path, tags=[f"Dataset '{dataset_name}'"], description=parameters_description, response_class=JSONResponse)
721
- async def get_dataset_parameters_with_post(
722
- request: Request, params: QueryModelForPostParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
723
- ) -> arm.ParametersModel:
724
- start = time.time()
725
- curr_dataset_name = get_dataset_name(request, -2)
726
- parameters_list = self.manifest_cfg.datasets[curr_dataset_name].parameters
727
- scope = self.manifest_cfg.datasets[curr_dataset_name].scope
728
- params: BaseModel = params
729
- payload: dict = await request.json()
730
- result = await get_parameters_definition(
731
- parameters_list, "dataset", curr_dataset_name, scope, user, payload, params.model_dump()
732
- )
733
- self.logger.log_activity_time("POST REQUEST for PARAMETERS", start, request_id=_get_request_id(request))
734
- return result
735
-
736
- @app.get(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
737
- async def get_dataset_results(
738
- request: Request, params: QueryModelForGetDataset, user: BaseUser | None = Depends(get_current_user) # type: ignore
739
- ) -> arm.DatasetResultModel:
740
- start = time.time()
741
- curr_dataset_name = get_dataset_name(request, -1)
742
- result = await get_dataset_results_definition(curr_dataset_name, user, dict(request.query_params), asdict(params))
743
- self.logger.log_activity_time("GET REQUEST for DATASET RESULTS", start, request_id=_get_request_id(request))
744
- return result
745
-
746
- @app.post(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
747
- async def get_dataset_results_with_post(
748
- request: Request, params: QueryModelForPostDataset, user: BaseUser | None = Depends(get_current_user) # type: ignore
749
- ) -> arm.DatasetResultModel:
750
- start = time.time()
751
- curr_dataset_name = get_dataset_name(request, -1)
752
- params: BaseModel = params
753
- payload: dict = await request.json()
754
- result = await get_dataset_results_definition(curr_dataset_name, user, payload, params.model_dump())
755
- self.logger.log_activity_time("POST REQUEST for DATASET RESULTS", start, request_id=_get_request_id(request))
756
- return result
757
-
758
- # Dashboard Parameters and Results APIs
759
- for dashboard_name, dashboard in self.dashboards.items():
760
- dashboard_normalized = u.normalize_name_for_api(dashboard_name)
761
- curr_parameters_path = dashboard_parameters_path.format(dashboard=dashboard_normalized)
762
- curr_results_path = dashboard_results_path.format(dashboard=dashboard_normalized)
763
-
764
- validate_parameters_list(dashboard.config.parameters, "Dashboard")
765
-
766
- QueryModelForGetParams, QueryModelForPostParams = get_query_models_for_parameters(dashboard.config.parameters)
767
- QueryModelForGetDash, QueryModelForPostDash = get_query_models_for_dashboard(dashboard.config.parameters)
768
-
769
- @app.get(curr_parameters_path, tags=[f"Dashboard '{dashboard_name}'"], description=parameters_description, response_class=JSONResponse)
770
- async def get_dashboard_parameters(
771
- request: Request, params: QueryModelForGetParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
772
- ) -> arm.ParametersModel:
773
- start = time.time()
774
- curr_dashboard_name = get_dashboard_name(request, -2)
775
- parameters_list = self.dashboards[curr_dashboard_name].config.parameters
776
- scope = self.dashboards[curr_dashboard_name].config.scope
777
- result = await get_parameters_definition(
778
- parameters_list, "dashboard", curr_dashboard_name, scope, user, dict(request.query_params), asdict(params)
779
- )
780
- self.logger.log_activity_time("GET REQUEST for PARAMETERS", start, request_id=_get_request_id(request))
781
- return result
782
-
783
- @app.post(curr_parameters_path, tags=[f"Dashboard '{dashboard_name}'"], description=parameters_description, response_class=JSONResponse)
784
- async def get_dashboard_parameters_with_post(
785
- request: Request, params: QueryModelForPostParams, user: BaseUser | None = Depends(get_current_user) # type: ignore
786
- ) -> arm.ParametersModel:
787
- start = time.time()
788
- curr_dashboard_name = get_dashboard_name(request, -2)
789
- parameters_list = self.dashboards[curr_dashboard_name].config.parameters
790
- scope = self.dashboards[curr_dashboard_name].config.scope
791
- params: BaseModel = params
792
- payload: dict = await request.json()
793
- result = await get_parameters_definition(
794
- parameters_list, "dashboard", curr_dashboard_name, scope, user, payload, params.model_dump()
795
- )
796
- self.logger.log_activity_time("POST REQUEST for PARAMETERS", start, request_id=_get_request_id(request))
797
- return result
798
-
799
- @app.get(curr_results_path, tags=[f"Dashboard '{dashboard_name}'"], description=dashboard.config.description, response_class=Response)
800
- async def get_dashboard_results(
801
- request: Request, params: QueryModelForGetDash, user: BaseUser | None = Depends(get_current_user) # type: ignore
802
- ) -> Response:
803
- start = time.time()
804
- curr_dashboard_name = get_dashboard_name(request, -1)
805
- result = await get_dashboard_results_definition(curr_dashboard_name, user, dict(request.query_params), asdict(params))
806
- self.logger.log_activity_time("GET REQUEST for DASHBOARD RESULTS", start, request_id=_get_request_id(request))
807
- return result
808
-
809
- @app.post(curr_results_path, tags=[f"Dashboard '{dashboard_name}'"], description=dashboard.config.description, response_class=Response)
810
- async def get_dashboard_results_with_post(
811
- request: Request, params: QueryModelForPostDash, user: BaseUser | None = Depends(get_current_user) # type: ignore
812
- ) -> Response:
813
- start = time.time()
814
- curr_dashboard_name = get_dashboard_name(request, -1)
815
- params: BaseModel = params
816
- payload: dict = await request.json()
817
- result = await get_dashboard_results_definition(curr_dashboard_name, user, payload, params.model_dump())
818
- self.logger.log_activity_time("POST REQUEST for DASHBOARD RESULTS", start, request_id=_get_request_id(request))
819
- return result
820
-
821
- # Build Project API
822
- @app.post(project_metadata_path + '/build', tags=["Data Management"], summary="Build or update the virtual data environment for the project")
823
- async def build(user: BaseUser | None = Depends(get_current_user)): # type: ignore
824
- if not self.authenticator.can_user_access_scope(user, PermissionScope.PRIVATE):
825
- raise InvalidInputError(26, f"User '{user}' does not have permission to build the virtual data environment")
826
- await self.project.build(stage_file=True)
827
- return Response(status_code=status.HTTP_200_OK)
828
-
829
- # Query Models API
830
- query_models_path = project_metadata_path + '/query-models'
831
- QueryModelForQueryModels, QueryModelForPostQueryModels = get_query_models_for_querying_models()
832
-
833
- async def query_models_helper(
834
- sql_query: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
835
- ) -> DatasetResult:
836
- return await self.project.query_models(sql_query, selections=dict(selections), user=user)
837
-
838
- async def query_models_cachable(
839
- sql_query: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
840
- ) -> DatasetResult:
841
- # Share the same cache for dataset results
842
- return await do_cachable_action(dataset_results_cache, query_models_helper, sql_query, user, selections)
843
-
844
- async def query_models_definition(
845
- user: BaseUser | None, all_request_params: dict, params: Mapping
846
- ) -> arm.DatasetResultModel:
847
- self._validate_request_params(all_request_params, params)
848
-
849
- if not self.authenticator.can_user_access_scope(user, PermissionScope.PRIVATE):
850
- raise InvalidInputError(27, f"User '{user}' does not have permission to query data models")
851
- sql_query = params.get("x_sql_query")
852
- if sql_query is None:
853
- raise InvalidInputError(203, "SQL query must be provided")
854
-
855
- query_models_function = query_models_helper if self.no_cache else query_models_cachable
856
- uncached_keys = {"x_verify_params", "x_sql_query", "x_orientation", "x_limit", "x_offset"}
857
- selections = get_selections_as_immutable(params, uncached_keys)
858
- result = await query_models_function(sql_query, user, selections)
859
-
860
- orientation = params.get("x_orientation", "records")
861
- limit = params.get("x_limit", 1000)
862
- offset = params.get("x_offset", 0)
863
- return arm.DatasetResultModel(**result.to_json(orientation, tuple(), limit, offset))
305
+ return response
864
306
 
865
- @app.get(query_models_path, tags=["Data Management"], response_class=JSONResponse)
866
- async def query_models(
867
- request: Request, params: QueryModelForQueryModels, user: BaseUser | None = Depends(get_current_user) # type: ignore
868
- ) -> arm.DatasetResultModel:
869
- start = time.time()
870
- result = await query_models_definition(user, dict(request.query_params), asdict(params))
871
- self.logger.log_activity_time("GET REQUEST for QUERY MODELS", start, request_id=_get_request_id(request))
872
- return result
873
-
874
- @app.post(query_models_path, tags=["Data Management"], response_class=JSONResponse)
875
- async def query_models_with_post(
876
- request: Request, params: QueryModelForPostQueryModels, user: BaseUser | None = Depends(get_current_user) # type: ignore
877
- ) -> arm.DatasetResultModel:
878
- start = time.time()
879
- params: BaseModel = params
880
- payload: dict = await request.json()
881
- result = await query_models_definition(user, payload, params.model_dump())
882
- self.logger.log_activity_time("POST REQUEST for QUERY MODELS", start, request_id=_get_request_id(request))
883
- return result
884
-
307
+ # Configure CORS with smart credential handling
308
+ # Get allowed origins for credentials from environment variable
309
+ credential_origins_env = self.env_vars.get(c.SQRL_AUTH_CREDENTIAL_ORIGINS, "https://squirrels-analytics.github.io")
310
+ allowed_credential_origins = [origin.strip() for origin in credential_origins_env.split(",") if origin.strip()]
311
+
312
+ # Allow both underscore and dash versions of configurable headers
313
+ configurables_as_headers = []
314
+ for name in self.manifest_cfg.configurables.keys():
315
+ configurables_as_headers.append(f"x-config-{name}") # underscore version
316
+ configurables_as_headers.append(f"x-config-{u.normalize_name_for_api(name)}") # dash version
317
+
318
+ app.add_middleware(SmartCORSMiddleware, allowed_credential_origins=allowed_credential_origins, configurables_as_headers=configurables_as_headers)
319
+
320
+ # Setup route modules
321
+ # self.oauth2_routes.setup_routes(app, squirrels_version_path)
322
+ self.auth_routes.setup_routes(app, squirrels_version_path)
323
+ get_parameters_definition = self.project_routes.setup_routes(app, self.mcp, project_metadata_path, project_name, project_version, project_label, param_fields)
324
+ self.data_management_routes.setup_routes(app, project_metadata_path, param_fields)
325
+ self.dataset_routes.setup_routes(app, self.mcp, project_metadata_path, project_name, project_label, param_fields, get_parameters_definition)
326
+ self.dashboard_routes.setup_routes(app, project_metadata_path, param_fields, get_parameters_definition)
327
+ app.mount(project_metadata_path, self.mcp.streamable_http_app())
328
+
329
+ # Mount static files from public directory if it exists
330
+ # This allows users to serve static assets (images, CSS, JS, etc.) from {project_path}/public/
331
+ public_dir = Path(self.project._filepath) / c.PUBLIC_FOLDER
332
+ if public_dir.exists() and public_dir.is_dir():
333
+ app.mount("/public", StaticFiles(directory=str(public_dir)), name="public")
334
+ self.logger.info(f"Mounted static files from: {public_dir}")
335
+
885
336
  # Add Root Path Redirection to Squirrels Studio
886
337
  full_hostname = f"http://{uvicorn_args.host}:{uvicorn_args.port}"
887
- encoded_hostname = urllib.parse.quote(full_hostname, safe="")
888
- squirrels_studio_url = f"https://squirrels-analytics.github.io/squirrels-studio/#/login?host={encoded_hostname}&projectName={project_name}&projectVersion={project_version}"
889
-
338
+ squirrels_studio_path = f"/project/{project_name_for_api}/{project_version}/studio"
339
+ templates = Jinja2Templates(directory=str(Path(__file__).parent / "_package_data" / "templates"))
340
+
341
+ @app.get(squirrels_studio_path, include_in_schema=False)
342
+ async def squirrels_studio():
343
+ default_studio_path = "https://squirrels-analytics.github.io/squirrels-studio-v1"
344
+ sqrl_studio_base_url = self.env_vars.get(c.SQRL_STUDIO_BASE_URL, default_studio_path)
345
+ context = {
346
+ "sqrl_studio_base_url": sqrl_studio_base_url,
347
+ "project_name": project_name_for_api,
348
+ "project_version": project_version,
349
+ }
350
+ return HTMLResponse(content=templates.get_template("squirrels_studio.html").render(context))
351
+
890
352
  @app.get("/", include_in_schema=False)
891
353
  async def redirect_to_studio():
892
- return RedirectResponse(url=squirrels_studio_url)
893
-
354
+ return RedirectResponse(url=squirrels_studio_path)
355
+
356
+ self.logger.log_activity_time("creating app server", start)
357
+
894
358
  # Run the API Server
895
359
  import uvicorn
896
360
 
897
361
  print("\nWelcome to the Squirrels Data Application!\n")
898
- print(f"- Application UI: {squirrels_studio_url}")
362
+ print(f"- Application UI (Squirrels Studio): {full_hostname}{squirrels_studio_path}")
899
363
  print(f"- API Docs (with ReDoc): {full_hostname}{project_metadata_path}/redoc")
900
364
  print(f"- API Docs (with Swagger UI): {full_hostname}{project_metadata_path}/docs")
365
+ print(f"- MCP Server URL: {full_hostname}{project_metadata_path}/mcp")
901
366
  print()
902
367
 
903
- self.logger.log_activity_time("creating app server", start)
904
- uvicorn.run(app, host=uvicorn_args.host, port=uvicorn_args.port)
368
+ uvicorn.run(app, host=uvicorn_args.host, port=uvicorn_args.port, proxy_headers=True, forwarded_allow_ips="*")