squirrels 0.1.0__py3-none-any.whl → 0.6.0.post0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. dateutils/__init__.py +6 -0
  2. dateutils/_enums.py +25 -0
  3. squirrels/dateutils.py → dateutils/_implementation.py +409 -380
  4. dateutils/types.py +6 -0
  5. squirrels/__init__.py +21 -18
  6. squirrels/_api_routes/__init__.py +5 -0
  7. squirrels/_api_routes/auth.py +337 -0
  8. squirrels/_api_routes/base.py +196 -0
  9. squirrels/_api_routes/dashboards.py +156 -0
  10. squirrels/_api_routes/data_management.py +148 -0
  11. squirrels/_api_routes/datasets.py +220 -0
  12. squirrels/_api_routes/project.py +289 -0
  13. squirrels/_api_server.py +552 -134
  14. squirrels/_arguments/__init__.py +0 -0
  15. squirrels/_arguments/init_time_args.py +83 -0
  16. squirrels/_arguments/run_time_args.py +111 -0
  17. squirrels/_auth.py +777 -0
  18. squirrels/_command_line.py +239 -107
  19. squirrels/_compile_prompts.py +147 -0
  20. squirrels/_connection_set.py +94 -0
  21. squirrels/_constants.py +141 -64
  22. squirrels/_dashboards.py +179 -0
  23. squirrels/_data_sources.py +570 -0
  24. squirrels/_dataset_types.py +91 -0
  25. squirrels/_env_vars.py +209 -0
  26. squirrels/_exceptions.py +29 -0
  27. squirrels/_http_error_responses.py +52 -0
  28. squirrels/_initializer.py +319 -110
  29. squirrels/_logging.py +121 -0
  30. squirrels/_manifest.py +357 -187
  31. squirrels/_mcp_server.py +578 -0
  32. squirrels/_model_builder.py +69 -0
  33. squirrels/_model_configs.py +74 -0
  34. squirrels/_model_queries.py +52 -0
  35. squirrels/_models.py +1201 -0
  36. squirrels/_package_data/base_project/.env +7 -0
  37. squirrels/_package_data/base_project/.env.example +44 -0
  38. squirrels/_package_data/base_project/connections.yml +16 -0
  39. squirrels/_package_data/base_project/dashboards/dashboard_example.py +40 -0
  40. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
  41. squirrels/_package_data/base_project/docker/.dockerignore +16 -0
  42. squirrels/_package_data/base_project/docker/Dockerfile +16 -0
  43. squirrels/_package_data/base_project/docker/compose.yml +7 -0
  44. squirrels/_package_data/base_project/duckdb_init.sql +10 -0
  45. squirrels/_package_data/base_project/gitignore +13 -0
  46. squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
  47. squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
  48. squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
  49. squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
  50. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
  51. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
  52. squirrels/_package_data/base_project/models/federates/federate_example.py +51 -0
  53. squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
  54. squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
  55. squirrels/_package_data/base_project/models/sources.yml +38 -0
  56. squirrels/_package_data/base_project/parameters.yml +142 -0
  57. squirrels/_package_data/base_project/pyconfigs/connections.py +19 -0
  58. squirrels/_package_data/base_project/pyconfigs/context.py +96 -0
  59. squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
  60. squirrels/_package_data/base_project/pyconfigs/user.py +56 -0
  61. squirrels/_package_data/base_project/resources/expenses.db +0 -0
  62. squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
  63. squirrels/_package_data/base_project/resources/weather.db +0 -0
  64. squirrels/_package_data/base_project/seeds/seed_categories.csv +6 -0
  65. squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
  66. squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
  67. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
  68. squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
  69. squirrels/_package_data/base_project/tmp/.gitignore +2 -0
  70. squirrels/_package_data/templates/login_successful.html +53 -0
  71. squirrels/_package_data/templates/squirrels_studio.html +22 -0
  72. squirrels/_package_loader.py +29 -0
  73. squirrels/_parameter_configs.py +592 -0
  74. squirrels/_parameter_options.py +348 -0
  75. squirrels/_parameter_sets.py +207 -0
  76. squirrels/_parameters.py +1703 -0
  77. squirrels/_project.py +796 -0
  78. squirrels/_py_module.py +122 -0
  79. squirrels/_request_context.py +33 -0
  80. squirrels/_schemas/__init__.py +0 -0
  81. squirrels/_schemas/auth_models.py +83 -0
  82. squirrels/_schemas/query_param_models.py +70 -0
  83. squirrels/_schemas/request_models.py +26 -0
  84. squirrels/_schemas/response_models.py +286 -0
  85. squirrels/_seeds.py +97 -0
  86. squirrels/_sources.py +112 -0
  87. squirrels/_utils.py +540 -149
  88. squirrels/_version.py +1 -3
  89. squirrels/arguments.py +7 -0
  90. squirrels/auth.py +4 -0
  91. squirrels/connections.py +3 -0
  92. squirrels/dashboards.py +3 -0
  93. squirrels/data_sources.py +14 -282
  94. squirrels/parameter_options.py +13 -189
  95. squirrels/parameters.py +14 -801
  96. squirrels/types.py +18 -0
  97. squirrels-0.6.0.post0.dist-info/METADATA +148 -0
  98. squirrels-0.6.0.post0.dist-info/RECORD +101 -0
  99. {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -2
  100. {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +1 -0
  101. squirrels-0.6.0.post0.dist-info/licenses/LICENSE +201 -0
  102. squirrels/_credentials_manager.py +0 -87
  103. squirrels/_module_loader.py +0 -37
  104. squirrels/_parameter_set.py +0 -151
  105. squirrels/_renderer.py +0 -286
  106. squirrels/_timed_imports.py +0 -37
  107. squirrels/connection_set.py +0 -126
  108. squirrels/package_data/base_project/.gitignore +0 -4
  109. squirrels/package_data/base_project/connections.py +0 -21
  110. squirrels/package_data/base_project/database/sample_database.db +0 -0
  111. squirrels/package_data/base_project/database/seattle_weather.db +0 -0
  112. squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -8
  113. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -23
  114. squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -7
  115. squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -10
  116. squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -2
  117. squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -30
  118. squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -6
  119. squirrels/package_data/base_project/squirrels.yaml +0 -26
  120. squirrels/package_data/static/favicon.ico +0 -0
  121. squirrels/package_data/static/script.js +0 -234
  122. squirrels/package_data/static/style.css +0 -110
  123. squirrels/package_data/templates/index.html +0 -32
  124. squirrels-0.1.0.dist-info/LICENSE +0 -22
  125. squirrels-0.1.0.dist-info/METADATA +0 -67
  126. squirrels-0.1.0.dist-info/RECORD +0 -40
  127. squirrels-0.1.0.dist-info/top_level.txt +0 -1
@@ -0,0 +1,578 @@
1
+ """
2
+ MCP Server implementation using the official MCP Python SDK low-level APIs.
3
+
4
+ This module provides the MCP server for Squirrels projects, exposing:
5
+ - Tools: get_data_catalog, get_dataset_parameters, get_dataset_results
6
+ - Resources: sqrl://data-catalog
7
+ """
8
+ from typing import Any, Protocol
9
+ from collections.abc import AsyncIterator
10
+ from contextlib import asynccontextmanager
11
+ from textwrap import dedent
12
+ from pydantic import AnyUrl
13
+ from starlette.applications import Starlette
14
+ from starlette.requests import Request
15
+ from starlette.routing import Mount
16
+ from starlette.middleware.base import BaseHTTPMiddleware
17
+ from starlette.types import ASGIApp
18
+ from mcp.server.lowlevel import Server
19
+ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
20
+ import mcp.types as types
21
+ import json
22
+
23
+ from . import _utils as u
24
+ from ._schemas.auth_models import AbstractUser
25
+ from ._schemas.request_models import McpRequestHeaders
26
+ from ._exceptions import InvalidInputError
27
+ from ._http_error_responses import invalid_input_error_to_json_response
28
+ from ._schemas import response_models as rm
29
+ from ._dataset_types import DatasetResult, DatasetResultFormat
30
+ from ._api_routes.base import RouteBase
31
+
32
+
33
class GetUserFromHeaders(Protocol):
    """
    Callable that resolves the requesting user from request credentials.

    Takes an optional API key and an optional bearer token and returns a
    `(user, expiry)` pair, where the second element is `float | None`.
    """

    def __call__(
        self, api_key: str | None, bearer_token: str | None
    ) -> tuple[AbstractUser, float | None]:
        ...
36
+
37
class GetDataCatalogForMcp(Protocol):
    """
    Async callable that returns the data catalog model for the given user.
    """

    async def __call__(self, user: AbstractUser) -> rm.CatalogModelForMcp:
        ...
40
+
41
class GetDatasetParametersForMcp(Protocol):
    """
    Async callable that returns the parameters model of a dataset.

    Receives the dataset name, the name of one parameter, the selected ID(s)
    for that parameter (a string or a list of strings), and the user.
    """

    async def __call__(
        self,
        dataset: str,
        parameter_name: str,
        selected_ids: str | list[str],
        user: AbstractUser,
    ) -> rm.ParametersModel:
        ...
46
+
47
class GetDatasetResultsForMcp(Protocol):
    """
    Async callable that produces the result object of a dataset.

    Receives the dataset name, the parameter selections, an optional SQL query,
    the user, and the configurables as a tuple of (name, value) string pairs.
    """

    async def __call__(
        self,
        dataset: str,
        parameters: dict[str, Any],
        sql_query: str | None,
        user: AbstractUser,
        configurables: tuple[tuple[str, str], ...],
    ) -> DatasetResult:
        ...
52
+
53
+
54
+ class McpServerBuilder:
55
+ """
56
+ Builder for the MCP server that exposes Squirrels tools and resources.
57
+
58
+ This class is responsible for:
59
+ - Creating the low-level MCP Server
60
+ - Registering list_tools, call_tool, list_resources, read_resource handlers
61
+ - Creating the StreamableHTTPSessionManager for HTTP transport
62
+ - Providing the ASGI app and lifespan manager
63
+ """
64
+
65
def __init__(
    self,
    project_name: str,
    project_label: str,
    max_rows_for_ai: int,
    get_user_from_headers: GetUserFromHeaders,
    get_data_catalog_for_mcp: GetDataCatalogForMcp,
    get_dataset_parameters_for_mcp: GetDatasetParametersForMcp,
    get_dataset_results_for_mcp: GetDatasetResultsForMcp,
    *,
    enforce_oauth_bearer: bool = False,
    oauth_resource_metadata_path: str = "/.well-known/oauth-protected-resource",
    www_authenticate_strip_path_suffix: str = "/mcp",
):
    """
    Set up the builder, then create the MCP server and its HTTP session manager.

    Args:
        project_name: The name of the Squirrels project (also embedded in tool and resource names)
        project_label: The human-readable label of the project (used in tool titles)
        max_rows_for_ai: Maximum number of rows to return for AI tools
        get_user_from_headers: Function that maps an API key / bearer token to a (user, expiry) pair
        get_data_catalog_for_mcp: Async function to get the data catalog
        get_dataset_parameters_for_mcp: Async function to get dataset parameters
        get_dataset_results_for_mcp: Async function to get dataset results
        enforce_oauth_bearer: If True, requests without a Bearer token are rejected with a 401 error
        oauth_resource_metadata_path: Path forwarded to the error-response helper as `oauth_resource_metadata_path`
        www_authenticate_strip_path_suffix: Path suffix forwarded to the error-response helper as `strip_path_suffix`
    """
    # Callbacks supplied by the caller
    self._get_user_from_headers = get_user_from_headers
    self._get_data_catalog_for_mcp = get_data_catalog_for_mcp
    self._get_dataset_parameters_for_mcp = get_dataset_parameters_for_mcp
    self._get_dataset_results_for_mcp = get_dataset_results_for_mcp

    # Project settings; the default row limit never exceeds 10
    self.project_name = project_name
    self.project_label = project_label
    self.max_rows_for_ai = max_rows_for_ai
    self.default_for_limit = min(self.max_rows_for_ai, 10)

    # OAuth-related settings
    self.enforce_oauth_bearer = enforce_oauth_bearer
    self.oauth_resource_metadata_path = oauth_resource_metadata_path
    self.www_authenticate_strip_path_suffix = www_authenticate_strip_path_suffix

    # Tool names
    self.catalog_tool_name = f"get_data_catalog_from_{project_name}"
    self.parameters_tool_name = f"get_dataset_parameters_from_{project_name}"
    self.results_tool_name = f"get_dataset_results_from_{project_name}"

    # Resource URI and name
    self.catalog_resource_uri = "sqrl://data-catalog"
    self.catalog_resource_name = f"data_catalog_from_{project_name}"

    # Build the server last, then wrap it in the HTTP session manager
    self._server = self._build_server()
    self._session_manager = StreamableHTTPSessionManager(
        app=self._server,
        stateless=True,
        json_response=True,
    )
120
+
121
def _get_tool_annotations(
    self, title: str, *, read_only: bool = True, destructive: bool = False,
    idempotent: bool = True, open_world: bool = False
) -> types.ToolAnnotations:
    """
    Create the annotations object for a tool.

    Args:
        title: The title to attach to the tool
        read_only: Value for readOnlyHint
        destructive: Value for destructiveHint
        idempotent: Value for idempotentHint
        open_world: Value for openWorldHint

    Returns:
        The ToolAnnotations carrying the title and the four hints
    """
    hints = {
        "readOnlyHint": read_only,
        "destructiveHint": destructive,
        "idempotentHint": idempotent,
        "openWorldHint": open_world,
    }
    return types.ToolAnnotations(title=title, **hints)
133
+
134
def _build_server(self) -> Server:
    """
    Create the low-level MCP Server named "Squirrels" and register its handlers.

    Returns:
        The server with the list_tools, call_tool, list_resources and
        read_resource handlers registered (in that order)
    """
    mcp_server = Server("Squirrels")

    # Each entry pairs a decorator factory of the server with the handler it registers
    registrations = (
        (mcp_server.list_tools, self._list_tools),
        (mcp_server.call_tool, self._call_tool),
        (mcp_server.list_resources, self._list_resources),
        (mcp_server.read_resource, self._read_resource),
    )
    for make_decorator, handler in registrations:
        make_decorator()(handler)

    return mcp_server
145
+
146
def _get_request_headers(self) -> McpRequestHeaders:
    """
    Wrap the HTTP headers of the current MCP request.

    The headers are read from server.request_context.request.headers. If the
    request is missing, has no headers, or the lookup raises AttributeError or
    LookupError, a default McpRequestHeaders is returned instead.
    """
    try:
        http_request = self._server.request_context.request
        headers_available = http_request is not None and hasattr(http_request, 'headers')
        if headers_available:
            return McpRequestHeaders(raw_headers=http_request.headers)
    except (AttributeError, LookupError):
        # Fall through to the default headers below
        pass

    return McpRequestHeaders()
161
+
162
+ def _get_request_metadata(self) -> dict[str, Any]:
163
+ """
164
+ Metadata of the current MCP request as a dictionary.
165
+
166
+ Returns:
167
+ A dictionary of the request metadata
168
+ """
169
+ request_metadata = self._server.request_context.meta
170
+ if request_metadata is None:
171
+ return {}
172
+ return request_metadata.model_dump(mode="json")
173
+
174
def _get_configurables(self, mcp_headers: McpRequestHeaders) -> tuple[tuple[str, str], ...]:
    """
    Collect configurables from the request headers and the request metadata.

    Headers whose lowercased name starts with "x-config-" contribute the rest of
    the name; every metadata key except "progressToken" contributes as well. All
    names are normalized with `u.normalize_name` and all values are converted
    to strings.

    Args:
        mcp_headers: The headers of the current MCP request

    Returns:
        The (name, value) pairs, headers first and metadata after

    Raises:
        InvalidInputError: If a normalized name is specified more than once
    """
    header_prefix = "x-config-"
    collected: dict[str, str] = {}

    # 1. Headers carrying the prefix
    for raw_key, raw_value in mcp_headers.raw_headers.items():
        lowered_key = str(raw_key).lower()
        if not lowered_key.startswith(header_prefix):
            continue

        name = u.normalize_name(lowered_key[len(header_prefix):])
        if name in collected:
            raise InvalidInputError(
                400, "duplicate_configurable",
                f"Configurable '{name}' specified multiple times in headers."
            )
        collected[name] = str(raw_value)

    # 2. Request metadata (read only after the headers were accepted)
    for meta_key, meta_value in self._get_request_metadata().items():
        if meta_key == "progressToken":
            continue

        name = u.normalize_name(meta_key)
        if name in collected:
            raise InvalidInputError(
                400, "duplicate_configurable",
                f"Configurable '{name}' specified multiple times (header and metadata)."
            )
        collected[name] = str(meta_value)

    return tuple(collected.items())
210
+
211
def _get_validated_user_for_request(self, mcp_headers: McpRequestHeaders) -> tuple[AbstractUser, float | None]:
    """
    Return the validated user (and token expiry) for the current HTTP request.

    A user already stored on `request.state.sqrl_user` is reused so that the
    credentials are not validated twice. Otherwise the user is resolved from
    the API key / bearer token found in the headers.

    If the MCP app runs with `enforce_oauth_bearer=True`, missing Bearer tokens
    must produce an HTTP 401 (not an MCP tool error), so InvalidInputError is raised.

    Args:
        mcp_headers: The headers of the current MCP request

    Raises:
        InvalidInputError: With status 401 if a Bearer token is required but absent
    """
    try:
        http_request = self._server.request_context.request
        if http_request is not None and hasattr(http_request, "state"):
            request_state = http_request.state
            cached_user = getattr(request_state, "sqrl_user", None)
            cached_expiry = getattr(request_state, "access_token_expiry", None)
            if cached_user is not None:
                return cached_user, cached_expiry
    except (AttributeError, LookupError):
        # No usable request context; resolve the user from the headers instead
        pass

    bearer_missing = self.enforce_oauth_bearer and not mcp_headers.bearer_token
    if bearer_missing:
        raise InvalidInputError(401, "user_required", "Authentication is required")

    return self._get_user_from_headers(
        api_key=mcp_headers.api_key,
        bearer_token=mcp_headers.bearer_token,
    )
234
+
235
async def _list_tools(self) -> list[types.Tool]:
    """
    Describe the MCP tools exposed for this project.

    Three tools are returned, in order: the data catalog tool, the parameter
    updates tool, and the dataset results tool. If the feature flags of the
    request contain "mcp-full-dataset-v1", the description of the dataset
    results tool gains an extra paragraph about how "offset" and "limit"
    relate to the "content" and "structuredContent" fields.

    Returns:
        The list of tool definitions
    """
    request_headers = self._get_request_headers()
    include_full_result_note = "mcp-full-dataset-v1" in request_headers.feature_flags

    if include_full_result_note:
        results_note = dedent("""
            The "offset" and "limit" arguments affect the "content" field, but not the "structuredContent" field, of this tool's result. Assume that you (the AI model) can only see the "content" field, but accessing this tool's result through code execution (if applicable) uses the "structuredContent" field. Note that the "sql_query" and "orientation" arguments still apply to both the "content" and "structuredContent" fields.
        """).strip()
    else:
        results_note = ""

    catalog_title = f"Data Catalog For {self.project_label}"
    parameters_title = f"Parameters Updates For {self.project_label}"
    results_title = f"Dataset Results For {self.project_label}"

    # NOTE: no outputSchema is declared for this tool
    # (rm.CatalogModelForMcp.model_json_schema() is left disabled)
    catalog_tool = types.Tool(
        name=self.catalog_tool_name,
        title=catalog_title,
        description=dedent(f"""
            Use this tool to get the details of all datasets and parameters you can access in the Squirrels project '{self.project_name}'.

            Unless the data catalog for this project has already been provided, use this tool at the start of each conversation.
        """).strip(),
        annotations=self._get_tool_annotations(title=catalog_title),
        inputSchema={
            "type": "object",
            "properties": {},
            "required": [],
        },
    )

    selected_ids_schema = {
        "type": "string",
        "description": dedent("""
            A JSON object (as string) with one key-value pair. The key is the name of the parameter triggering the refresh, and the value is the ID(s) of the selected option(s) for the parameter.
            - If the parameter's widget_type is single_select, use a string for the ID of the selected option
            - If the parameter's widget_type is multi_select, use an array of strings for the IDs of the selected options

            An error is raised if this JSON object does not have exactly one key-value pair.
        """).strip(),
    }

    # NOTE: no outputSchema is declared for this tool
    # (rm.ParametersModel.model_json_schema() is left disabled)
    parameters_tool = types.Tool(
        name=self.parameters_tool_name,
        title=parameters_title,
        description=dedent(f"""
            Use this tool to get updates for dataset parameters in the Squirrels project "{self.project_name}" when a selection is to be made on a parameter with `"trigger_refresh": true`.

            For example, suppose there are two parameters, "country" and "city", and the user selects "United States" for "country". If "country" has the "trigger_refresh" field as true, then this tool should be called to get the updates for other parameters such as "city".

            Do not use this tool on parameters that do not have `"trigger_refresh": true`.
        """).strip(),
        annotations=self._get_tool_annotations(title=parameters_title),
        inputSchema={
            "type": "object",
            "properties": {
                "dataset": {
                    "type": "string",
                    "description": "The name of the dataset whose parameters the trigger parameter will update",
                },
                "selected_ids": selected_ids_schema,
            },
            "required": ["dataset", "selected_ids"],
        },
    )

    results_properties = {
        "dataset": {
            "type": "string",
            "description": "The name of the dataset to get results for",
        },
        "parameters": {
            "type": "string",
            "description": dedent("""
                A JSON object (as string) containing key-value pairs for parameter name and selected value. The selected value to provide depends on the parameter widget type:
                - If the parameter's widget_type is single_select, use a string for the ID of the selected option
                - If the parameter's widget_type is multi_select, use an array of strings for the IDs of the selected options
                - If the parameter's widget_type is date, use a string like "YYYY-MM-DD"
                - If the parameter's widget_type is date_range, use array of strings like ["YYYY-MM-DD","YYYY-MM-DD"]
                - If the parameter's widget_type is number, use a number like 1
                - If the parameter's widget_type is number_range, use array of numbers like [1,100]
                - If the parameter's widget_type is text, use a string for the text value
                - Complex objects are NOT supported
            """).strip(),
        },
        "sql_query": {
            "type": ["string", "null"],
            "description": dedent("""
                A custom Polars SQL query to execute on the final dataset result.
                - Use table name 'result' to reference the dataset result.
                - Use this to apply transformations to the dataset result if needed (such as filtering, sorting, or selecting columns).
                - If not provided, the dataset result is returned as is.
            """).strip(),
            "default": None,
        },
        "orientation": {
            "type": "string",
            "enum": ["rows", "columns", "records"],
            "description": "The orientation of the dataset result. Options are 'rows', 'columns', and 'records'. Default is 'rows'.",
            "default": "rows",
        },
        "offset": {
            "type": "integer",
            "description": "The number of rows to skip from first row. Applied after the sql_query. Default is 0.",
            "default": 0,
        },
        "limit": {
            "type": "integer",
            "description": dedent(f"""
                The maximum number of rows to return. Applied after the sql_query.
                Default is {self.default_for_limit}. Maximum allowed value is {self.max_rows_for_ai}.
            """).strip(),
            "default": self.default_for_limit,
        },
    }

    results_tool = types.Tool(
        name=self.results_tool_name,
        title=results_title,
        description=dedent(f"""
            Use this tool to get the dataset results as a JSON object for a dataset in the Squirrels project "{self.project_name}".

            {results_note}
        """).strip(),
        annotations=self._get_tool_annotations(title=results_title),
        inputSchema={
            "type": "object",
            "properties": results_properties,
            "required": ["dataset", "parameters"],
        },
        outputSchema=rm.DatasetResultModel.model_json_schema(),
    )

    return [catalog_tool, parameters_tool, results_tool]
360
+
361
+ def _get_dataset_and_parameters(self, arguments: dict[str, Any], *, params_key: str = "parameters") -> tuple[str, dict[str, Any]]:
362
+ """Get dataset and parameters from arguments.
363
+
364
+ Args:
365
+ arguments: The arguments from the tool call
366
+ params_key: The key of the parameters in the arguments
367
+
368
+ Returns:
369
+ A tuple of the dataset and parameters
370
+
371
+ Raises:
372
+ InvalidInputError: If the dataset or parameters are invalid
373
+ """
374
+ try:
375
+ dataset = str(arguments["dataset"])
376
+ except KeyError:
377
+ raise InvalidInputError(400, "invalid_dataset", "The 'dataset' argument is required.")
378
+
379
+ parameters_arg = str(arguments.get(params_key, "{}"))
380
+
381
+ # validate parameters argument
382
+ try:
383
+ parameters = json.loads(parameters_arg)
384
+ except json.JSONDecodeError:
385
+ parameters = None # error handled below
386
+
387
+ if not isinstance(parameters, dict):
388
+ raise InvalidInputError(400, "invalid_parameters", f"The '{params_key}' argument must be a JSON object.")
389
+
390
+ return dataset, parameters
391
+
392
async def _call_tool(self, name: str, arguments: dict[str, Any] | None) -> types.CallToolResult:
    """Handle tool calls by dispatching to the appropriate function.

    Every outcome is returned as a CallToolResult. Successful calls carry the
    JSON text in "content" and the data in "structuredContent". For the dataset
    results tool, "structuredContent" holds the full result (offset 0, no limit)
    when the "mcp-full-dataset-v1" feature flag is set. Unknown tool names and
    failures are returned with isError=True.

    Args:
        name: The name of the tool being called
        arguments: The arguments of the tool call (None is treated as empty)

    Raises:
        InvalidInputError: Only when its status code is 401, so that it is not
            reported as a tool error
    """
    call_args = arguments or {}

    try:
        mcp_headers = self._get_request_headers()
        user, _ = self._get_validated_user_for_request(mcp_headers)

        full_result_flag = "mcp-full-dataset-v1" in mcp_headers.feature_flags

        # --- Data catalog tool ---
        if name == self.catalog_tool_name:
            catalog = await self._get_data_catalog_for_mcp(user)
            return types.CallToolResult(
                content=[types.TextContent(type="text", text=catalog.model_dump_json(by_alias=True))],
                structuredContent=catalog.model_dump(mode="json", by_alias=True),
            )

        # --- Parameter updates tool ---
        if name == self.parameters_tool_name:
            dataset, selections = self._get_dataset_and_parameters(call_args, params_key="selected_ids")

            # exactly one parameter may trigger the refresh
            if len(selections) != 1:
                raise InvalidInputError(
                    400, "invalid_selected_ids",
                    "The 'selected_ids' argument must have exactly one key-value pair."
                )

            ((parameter_name, selected_ids),) = selections.items()
            if not isinstance(selected_ids, (str, list)):
                raise InvalidInputError(
                    400, "invalid_selected_ids",
                    f"The selected ids of the parameter '{parameter_name}' must be a string or list of strings."
                )

            refreshed = await self._get_dataset_parameters_for_mcp(dataset, parameter_name, selected_ids, user)
            return types.CallToolResult(
                content=[types.TextContent(type="text", text=refreshed.model_dump_json(by_alias=True))],
                structuredContent=refreshed.model_dump(mode="json", by_alias=True),
            )

        # --- Dataset results tool ---
        if name == self.results_tool_name:
            dataset, parameters = self._get_dataset_and_parameters(call_args, params_key="parameters")

            # an empty or missing sql_query means "no custom query"
            raw_sql_query = call_args.get("sql_query")
            sql_query = str(raw_sql_query) if raw_sql_query else None

            result_format = RouteBase.extract_orientation_offset_and_limit(
                call_args, key_prefix="", default_orientation="rows", default_limit=self.default_for_limit
            )
            orientation, limit = result_format.orientation, result_format.limit
            if limit > self.max_rows_for_ai:
                raise InvalidInputError(400, "invalid_limit", f"The 'limit' argument must be less than or equal to {self.max_rows_for_ai}.")

            configurables = self._get_configurables(mcp_headers)
            result_obj = await self._get_dataset_results_for_mcp(
                dataset, parameters, sql_query, user, configurables
            )

            # "content" always follows the requested offset and limit
            limited_payload = result_obj.to_json(result_format)
            limited_model = rm.DatasetResultModel(**limited_payload)

            structured_payload = limited_payload
            if full_result_flag:
                structured_payload = result_obj.to_json(DatasetResultFormat(orientation, 0, None))

            return types.CallToolResult(
                content=[types.TextContent(type="text", text=limited_model.model_dump_json(by_alias=True))],
                structuredContent=structured_payload,
            )

        # --- Anything else ---
        return types.CallToolResult(
            content=[types.TextContent(type="text", text=f"Unknown tool: {name}")],
            isError=True
        )

    except InvalidInputError as err:
        # If auth is required, surface HTTP 401s as real HTTP responses.
        if err.status_code == 401:
            raise
        return types.CallToolResult(
            content=[types.TextContent(type="text", text=f"Error: {err.error_description}")],
            isError=True,
        )
    except Exception as err:
        return types.CallToolResult(
            content=[types.TextContent(type="text", text=f"Error: {str(err)}")],
            isError=True
        )
490
+
491
async def _list_resources(self) -> list[types.Resource]:
    """
    Describe the MCP resources exposed for this project.

    Returns:
        A list holding the single data catalog resource
    """
    catalog_resource = types.Resource(
        uri=AnyUrl(self.catalog_resource_uri),
        name=self.catalog_resource_name,
        description=f"Details of all datasets and parameters you can access in the Squirrels project '{self.project_name}'.",
    )
    return [catalog_resource]
500
+
501
async def _read_resource(self, uri: AnyUrl) -> str | bytes:
    """
    Read the content of a resource.

    Args:
        uri: The URI of the resource to read

    Returns:
        The data catalog of the requesting user as a JSON string

    Raises:
        ValueError: If the URI is not the data catalog resource URI
    """
    mcp_headers = self._get_request_headers()

    if str(uri) != self.catalog_resource_uri:
        raise ValueError(f"Unknown resource URI: {uri}")

    user, _ = self._get_validated_user_for_request(mcp_headers)
    catalog = await self._get_data_catalog_for_mcp(user)
    return catalog.model_dump_json(by_alias=True)
511
+
512
@asynccontextmanager
async def lifespan(self, app: object | None = None) -> AsyncIterator[None]:
    """
    Async context manager that keeps the MCP session manager running.

    The session manager is started on entry and stopped on exit. The `app`
    argument is accepted but not used. Use this in the FastAPI app lifespan to
    ensure proper startup/shutdown.
    """
    session_scope = self._session_manager.run()
    async with session_scope:
        yield
521
+
522
def get_asgi_app(self) -> ASGIApp:
    """
    Get the ASGI app for the MCP server.

    The app is a Starlette application that mounts the session manager at "/",
    uses `lifespan` for startup/shutdown, and converts InvalidInputError into a
    JSON error response. A middleware is added that, when `enforce_oauth_bearer`
    is set, requires a Bearer token, validates it, and stores the user and the
    token expiry on `request.state` for the handlers.
    """
    owner = self

    def _error_response(request: Request, exc: InvalidInputError):
        # When mounted under `/mcp` (or a larger mount path ending in `/mcp`),
        # strip only that mount suffix so the resource_metadata URL points to
        # the top-level endpoint.
        return invalid_input_error_to_json_response(
            request,
            exc,
            oauth_resource_metadata_path=owner.oauth_resource_metadata_path,
            strip_path_suffix=owner.www_authenticate_strip_path_suffix,
            is_mcp=True,
        )

    async def _invalid_input_handler(request: Request, exc: InvalidInputError):
        return _error_response(request, exc)

    class _McpOAuthGateMiddleware(BaseHTTPMiddleware):
        async def dispatch(self, request: Request, call_next):
            try:
                if owner.enforce_oauth_bearer:
                    auth_header = request.headers.get("authorization", "")
                    is_bearer = auth_header.lower().startswith("bearer ")
                    token = auth_header[7:].strip() if is_bearer else None

                    if not token:
                        raise InvalidInputError(401, "user_required", "Authentication is required")

                    # Validate once here; handlers reuse the values from request.state
                    user, expiry = owner._get_user_from_headers(api_key=None, bearer_token=token)
                    request.state.sqrl_user = user
                    request.state.access_token_expiry = expiry

                return await call_next(request)
            except InvalidInputError as exc:
                # Starlette's BaseHTTPMiddleware may bypass exception handlers for
                # exceptions raised within dispatch; handle explicitly here.
                return _error_response(request, exc)

    asgi_app = Starlette(
        routes=[
            Mount("/", app=self._session_manager.handle_request),
        ],
        lifespan=self.lifespan,
        exception_handlers={InvalidInputError: _invalid_input_handler},
    )
    asgi_app.add_middleware(_McpOAuthGateMiddleware)
    return asgi_app
578
+
@@ -0,0 +1,69 @@
1
+ from dataclasses import dataclass, field
2
+ import duckdb, time
3
+
4
+ from . import _utils as u, _connection_set as cs, _models as m
5
+
6
+
7
@dataclass
class ModelBuilder:
    """
    Builds the static models as DuckDB tables in the datalake database.
    """
    _datalake_db_path: str
    _conn_set: cs.ConnectionSet
    _static_models: dict[str, m.StaticModel]
    _conn_args: cs.ConnectionsArgs
    _logger: u.Logger = field(default_factory=lambda: u.Logger(""))

    def _attach_connections(self, duckdb_conn: duckdb.DuckDBPyConnection) -> None:
        """
        Attach every connection that exposes a DuckDB attach URI, read-only, as "db_<name>".

        Connections that are not ConnectionProperties, or whose attach URI is None
        (unsupported dialects), are skipped.
        """
        connections = self._conn_set.get_connections_as_dict()
        for alias, props in connections.items():
            uri = props.attach_uri_for_duckdb if isinstance(props, m.ConnectionProperties) else None
            if uri is None:
                continue
            statement = f"ATTACH IF NOT EXISTS '{uri}' AS db_{alias} (READ_ONLY)"
            u.run_duckdb_stmt(self._logger, duckdb_conn, statement, redacted_values=[uri])

    async def _build_models(self, duckdb_conn: duckdb.DuckDBPyConnection, select: str | None, full_refresh: bool) -> None:
        """
        Compile and construct the build models as DuckDB tables.

        If `select` is None, all static models are compiled and the build starts
        from their terminal nodes. Otherwise only the selected model is compiled
        and built.
        """
        # Compile the build models
        if select is None:
            targets = self._static_models.values()
        else:
            targets = [self._static_models[select]]
        for target in targets:
            target.compile_for_build(self._conn_args, self._static_models)

        # Find all terminal nodes
        if select is not None:
            terminal_nodes = {select}
        else:
            terminal_nodes = set()
            for target in targets:
                terminal_nodes.update(target.get_terminal_nodes_for_build(set()))
            for target in targets:
                target.confirmed_no_cycles = False

        # Run the build models together
        build_coroutines = [
            self._static_models[node_name].build_model(duckdb_conn, full_refresh)
            for node_name in terminal_nodes
        ]
        await u.asyncio_gather(build_coroutines)

    async def build(self, full_refresh: bool, select: str | None) -> None:
        """
        Build the models and log the total time taken.

        The DuckDB connection is always closed, even if the build fails. The
        total time is only logged when the build succeeds.
        """
        started_at = time.time()

        # Connect directly to DuckLake instead of attaching (supports concurrent connections)
        duckdb_conn = u.create_duckdb_connection(self._datalake_db_path)

        try:
            self._attach_connections(duckdb_conn)
            await self._build_models(duckdb_conn, select, full_refresh)
        finally:
            duckdb_conn.close()

        self._logger.log_activity_time("TOTAL TIME to build the Virtual Data Lake (VDL)", started_at)