squirrels 0.1.0-py3-none-any.whl → 0.6.0.post0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dateutils/__init__.py +6 -0
- dateutils/_enums.py +25 -0
- squirrels/dateutils.py → dateutils/_implementation.py +409 -380
- dateutils/types.py +6 -0
- squirrels/__init__.py +21 -18
- squirrels/_api_routes/__init__.py +5 -0
- squirrels/_api_routes/auth.py +337 -0
- squirrels/_api_routes/base.py +196 -0
- squirrels/_api_routes/dashboards.py +156 -0
- squirrels/_api_routes/data_management.py +148 -0
- squirrels/_api_routes/datasets.py +220 -0
- squirrels/_api_routes/project.py +289 -0
- squirrels/_api_server.py +552 -134
- squirrels/_arguments/__init__.py +0 -0
- squirrels/_arguments/init_time_args.py +83 -0
- squirrels/_arguments/run_time_args.py +111 -0
- squirrels/_auth.py +777 -0
- squirrels/_command_line.py +239 -107
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +94 -0
- squirrels/_constants.py +141 -64
- squirrels/_dashboards.py +179 -0
- squirrels/_data_sources.py +570 -0
- squirrels/_dataset_types.py +91 -0
- squirrels/_env_vars.py +209 -0
- squirrels/_exceptions.py +29 -0
- squirrels/_http_error_responses.py +52 -0
- squirrels/_initializer.py +319 -110
- squirrels/_logging.py +121 -0
- squirrels/_manifest.py +357 -187
- squirrels/_mcp_server.py +578 -0
- squirrels/_model_builder.py +69 -0
- squirrels/_model_configs.py +74 -0
- squirrels/_model_queries.py +52 -0
- squirrels/_models.py +1201 -0
- squirrels/_package_data/base_project/.env +7 -0
- squirrels/_package_data/base_project/.env.example +44 -0
- squirrels/_package_data/base_project/connections.yml +16 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +40 -0
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +22 -0
- squirrels/_package_data/base_project/docker/.dockerignore +16 -0
- squirrels/_package_data/base_project/docker/Dockerfile +16 -0
- squirrels/_package_data/base_project/docker/compose.yml +7 -0
- squirrels/_package_data/base_project/duckdb_init.sql +10 -0
- squirrels/_package_data/base_project/gitignore +13 -0
- squirrels/_package_data/base_project/macros/macros_example.sql +17 -0
- squirrels/_package_data/base_project/models/builds/build_example.py +26 -0
- squirrels/_package_data/base_project/models/builds/build_example.sql +16 -0
- squirrels/_package_data/base_project/models/builds/build_example.yml +57 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +17 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +32 -0
- squirrels/_package_data/base_project/models/federates/federate_example.py +51 -0
- squirrels/_package_data/base_project/models/federates/federate_example.sql +21 -0
- squirrels/_package_data/base_project/models/federates/federate_example.yml +65 -0
- squirrels/_package_data/base_project/models/sources.yml +38 -0
- squirrels/_package_data/base_project/parameters.yml +142 -0
- squirrels/_package_data/base_project/pyconfigs/connections.py +19 -0
- squirrels/_package_data/base_project/pyconfigs/context.py +96 -0
- squirrels/_package_data/base_project/pyconfigs/parameters.py +141 -0
- squirrels/_package_data/base_project/pyconfigs/user.py +56 -0
- squirrels/_package_data/base_project/resources/expenses.db +0 -0
- squirrels/_package_data/base_project/resources/public/.gitkeep +0 -0
- squirrels/_package_data/base_project/resources/weather.db +0 -0
- squirrels/_package_data/base_project/seeds/seed_categories.csv +6 -0
- squirrels/_package_data/base_project/seeds/seed_categories.yml +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.csv +15 -0
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +21 -0
- squirrels/_package_data/base_project/squirrels.yml.j2 +61 -0
- squirrels/_package_data/base_project/tmp/.gitignore +2 -0
- squirrels/_package_data/templates/login_successful.html +53 -0
- squirrels/_package_data/templates/squirrels_studio.html +22 -0
- squirrels/_package_loader.py +29 -0
- squirrels/_parameter_configs.py +592 -0
- squirrels/_parameter_options.py +348 -0
- squirrels/_parameter_sets.py +207 -0
- squirrels/_parameters.py +1703 -0
- squirrels/_project.py +796 -0
- squirrels/_py_module.py +122 -0
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/__init__.py +0 -0
- squirrels/_schemas/auth_models.py +83 -0
- squirrels/_schemas/query_param_models.py +70 -0
- squirrels/_schemas/request_models.py +26 -0
- squirrels/_schemas/response_models.py +286 -0
- squirrels/_seeds.py +97 -0
- squirrels/_sources.py +112 -0
- squirrels/_utils.py +540 -149
- squirrels/_version.py +1 -3
- squirrels/arguments.py +7 -0
- squirrels/auth.py +4 -0
- squirrels/connections.py +3 -0
- squirrels/dashboards.py +3 -0
- squirrels/data_sources.py +14 -282
- squirrels/parameter_options.py +13 -189
- squirrels/parameters.py +14 -801
- squirrels/types.py +18 -0
- squirrels-0.6.0.post0.dist-info/METADATA +148 -0
- squirrels-0.6.0.post0.dist-info/RECORD +101 -0
- {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/WHEEL +1 -2
- {squirrels-0.1.0.dist-info → squirrels-0.6.0.post0.dist-info}/entry_points.txt +1 -0
- squirrels-0.6.0.post0.dist-info/licenses/LICENSE +201 -0
- squirrels/_credentials_manager.py +0 -87
- squirrels/_module_loader.py +0 -37
- squirrels/_parameter_set.py +0 -151
- squirrels/_renderer.py +0 -286
- squirrels/_timed_imports.py +0 -37
- squirrels/connection_set.py +0 -126
- squirrels/package_data/base_project/.gitignore +0 -4
- squirrels/package_data/base_project/connections.py +0 -21
- squirrels/package_data/base_project/database/sample_database.db +0 -0
- squirrels/package_data/base_project/database/seattle_weather.db +0 -0
- squirrels/package_data/base_project/datasets/sample_dataset/context.py +0 -8
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.py +0 -23
- squirrels/package_data/base_project/datasets/sample_dataset/database_view1.sql.j2 +0 -7
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.py +0 -10
- squirrels/package_data/base_project/datasets/sample_dataset/final_view.sql.j2 +0 -2
- squirrels/package_data/base_project/datasets/sample_dataset/parameters.py +0 -30
- squirrels/package_data/base_project/datasets/sample_dataset/selections.cfg +0 -6
- squirrels/package_data/base_project/squirrels.yaml +0 -26
- squirrels/package_data/static/favicon.ico +0 -0
- squirrels/package_data/static/script.js +0 -234
- squirrels/package_data/static/style.css +0 -110
- squirrels/package_data/templates/index.html +0 -32
- squirrels-0.1.0.dist-info/LICENSE +0 -22
- squirrels-0.1.0.dist-info/METADATA +0 -67
- squirrels-0.1.0.dist-info/RECORD +0 -40
- squirrels-0.1.0.dist-info/top_level.txt +0 -1
squirrels/_mcp_server.py
ADDED
@@ -0,0 +1,578 @@
+"""
+MCP Server implementation using the official MCP Python SDK low-level APIs.
+
+This module provides the MCP server for Squirrels projects, exposing:
+- Tools: get_data_catalog, get_dataset_parameters, get_dataset_results
+- Resources: sqrl://data-catalog
+"""
+from typing import Any, Protocol
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+from textwrap import dedent
+from pydantic import AnyUrl
+from starlette.applications import Starlette
+from starlette.requests import Request
+from starlette.routing import Mount
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.types import ASGIApp
+from mcp.server.lowlevel import Server
+from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
+import mcp.types as types
+import json
+
+from . import _utils as u
+from ._schemas.auth_models import AbstractUser
+from ._schemas.request_models import McpRequestHeaders
+from ._exceptions import InvalidInputError
+from ._http_error_responses import invalid_input_error_to_json_response
+from ._schemas import response_models as rm
+from ._dataset_types import DatasetResult, DatasetResultFormat
+from ._api_routes.base import RouteBase
+
+
+class GetUserFromHeaders(Protocol):
+    def __call__(self, api_key: str | None, bearer_token: str | None) -> tuple[AbstractUser, float | None]:
+        ...
+
+class GetDataCatalogForMcp(Protocol):
+    async def __call__(self, user: AbstractUser) -> rm.CatalogModelForMcp:
+        ...
+
+class GetDatasetParametersForMcp(Protocol):
+    async def __call__(
+        self, dataset: str, parameter_name: str, selected_ids: str | list[str], user: AbstractUser
+    ) -> rm.ParametersModel:
+        ...
+
+class GetDatasetResultsForMcp(Protocol):
+    async def __call__(
+        self, dataset: str, parameters: dict[str, Any], sql_query: str | None, user: AbstractUser, configurables: tuple[tuple[str, str], ...]
+    ) -> DatasetResult:
+        ...
+
+
+class McpServerBuilder:
+    """
+    Builder for the MCP server that exposes Squirrels tools and resources.
+
+    This class is responsible for:
+    - Creating the low-level MCP Server
+    - Registering list_tools, call_tool, list_resources, read_resource handlers
+    - Creating the StreamableHTTPSessionManager for HTTP transport
+    - Providing the ASGI app and lifespan manager
+    """
+
+    def __init__(
+        self,
+        project_name: str,
+        project_label: str,
+        max_rows_for_ai: int,
+        get_user_from_headers: GetUserFromHeaders,
+        get_data_catalog_for_mcp: GetDataCatalogForMcp,
+        get_dataset_parameters_for_mcp: GetDatasetParametersForMcp,
+        get_dataset_results_for_mcp: GetDatasetResultsForMcp,
+        *,
+        enforce_oauth_bearer: bool = False,
+        oauth_resource_metadata_path: str = "/.well-known/oauth-protected-resource",
+        www_authenticate_strip_path_suffix: str = "/mcp",
+    ):
+        """
+        Initialize the MCP server builder.
+
+        Args:
+            project_name: The name of the Squirrels project
+            project_label: The human-readable label of the project
+            max_rows_for_ai: Maximum number of rows to return for AI tools
+            get_data_catalog_for_mcp: Async function to get the data catalog
+            get_dataset_parameters_for_mcp: Async function to get dataset parameters
+            get_dataset_results_for_mcp: Async function to get dataset results
+        """
+        self.project_name = project_name
+        self.project_label = project_label
+        self.max_rows_for_ai = max_rows_for_ai
+        self.default_for_limit = min(self.max_rows_for_ai, 10)
+
+        self.enforce_oauth_bearer = enforce_oauth_bearer
+        self.oauth_resource_metadata_path = oauth_resource_metadata_path
+        self.www_authenticate_strip_path_suffix = www_authenticate_strip_path_suffix
+
+        self._get_user_from_headers = get_user_from_headers
+        self._get_data_catalog_for_mcp = get_data_catalog_for_mcp
+        self._get_dataset_parameters_for_mcp = get_dataset_parameters_for_mcp
+        self._get_dataset_results_for_mcp = get_dataset_results_for_mcp
+
+        # Tool names
+        self.catalog_tool_name = f"get_data_catalog_from_{project_name}"
+        self.parameters_tool_name = f"get_dataset_parameters_from_{project_name}"
+        self.results_tool_name = f"get_dataset_results_from_{project_name}"
+
+        # Resource URI
+        self.catalog_resource_uri = "sqrl://data-catalog"
+        self.catalog_resource_name = f"data_catalog_from_{project_name}"
+
+        # Build the server
+        self._server = self._build_server()
+        self._session_manager = StreamableHTTPSessionManager(
+            app=self._server,
+            stateless=True,
+            json_response=True,
+        )
+
+    def _get_tool_annotations(
+        self, title: str, *, read_only: bool = True, destructive: bool = False,
+        idempotent: bool = True, open_world: bool = False
+    ) -> types.ToolAnnotations:
+        """Get the tool annotations for the given title."""
+        return types.ToolAnnotations(
+            title=title,
+            readOnlyHint=read_only,
+            destructiveHint=destructive,
+            idempotentHint=idempotent,
+            openWorldHint=open_world,
+        )
+
+    def _build_server(self) -> Server:
+        """Build and configure the low-level MCP Server."""
+        server = Server("Squirrels")
+
+        # Register handlers
+        server.list_tools()(self._list_tools)
+        server.call_tool()(self._call_tool)
+        server.list_resources()(self._list_resources)
+        server.read_resource()(self._read_resource)
+
+        return server
+
+    def _get_request_headers(self) -> McpRequestHeaders:
+        """
+        Get HTTP headers from the current MCP request context.
+
+        Uses server.request_context.request.headers to access headers
+        from the underlying HTTP request.
+        """
+        try:
+            request = self._server.request_context.request
+            if request is not None and hasattr(request, 'headers'):
+                return McpRequestHeaders(raw_headers=request.headers)
+        except (AttributeError, LookupError):
+            pass
+
+        return McpRequestHeaders()
+
+    def _get_request_metadata(self) -> dict[str, Any]:
+        """
+        Metadata of the current MCP request as a dictionary.
+
+        Returns:
+            A dictionary of the request metadata
+        """
+        request_metadata = self._server.request_context.meta
+        if request_metadata is None:
+            return {}
+        return request_metadata.model_dump(mode="json")
+
+    def _get_configurables(self, mcp_headers: McpRequestHeaders) -> tuple[tuple[str, str], ...]:
+        """
+        Extract configurables from headers and metadata.
+        """
+        prefix = "x-config-"
+        cfg_dict: dict[str, str] = {}
+
+        # 1. Extract from headers
+        for key, value in mcp_headers.raw_headers.items():
+            key_lower = str(key).lower()
+            if key_lower.startswith(prefix):
+                cfg_name_raw = key_lower[len(prefix):]
+                cfg_name_normalized = u.normalize_name(cfg_name_raw)
+
+                if cfg_name_normalized in cfg_dict:
+                    raise InvalidInputError(
+                        400, "duplicate_configurable",
+                        f"Configurable '{cfg_name_normalized}' specified multiple times in headers."
+                    )
+                cfg_dict[cfg_name_normalized] = str(value)
+
+        # 2. Extract from metadata
+        metadata = self._get_request_metadata()
+        for key, value in metadata.items():
+            if key == "progressToken":
+                continue
+
+            cfg_name_normalized = u.normalize_name(key)
+            if cfg_name_normalized in cfg_dict:
+                raise InvalidInputError(
+                    400, "duplicate_configurable",
+                    f"Configurable '{cfg_name_normalized}' specified multiple times (header and metadata)."
+                )
+            cfg_dict[cfg_name_normalized] = str(value)
+
+        return tuple(cfg_dict.items())
+
+    def _get_validated_user_for_request(self, mcp_headers: McpRequestHeaders) -> tuple[AbstractUser, float | None]:
+        """
+        Return the validated user for the current HTTP request.
+
+        If the MCP app runs with `enforce_oauth_bearer=True`, missing Bearer tokens
+        must produce an HTTP 401 (not an MCP tool error), so we raise InvalidInputError.
+        """
+        # Prefer values set by the HTTP middleware to avoid double validation.
+        try:
+            request = self._server.request_context.request
+            if request is not None and hasattr(request, "state"):
+                state = request.state
+                user = getattr(state, "sqrl_user", None)
+                expiry = getattr(state, "access_token_expiry", None)
+                if user is not None:
+                    return user, expiry
+        except (AttributeError, LookupError):
+            pass
+
+        if self.enforce_oauth_bearer and not mcp_headers.bearer_token:
+            raise InvalidInputError(401, "user_required", "Authentication is required")
+
+        return self._get_user_from_headers(api_key=mcp_headers.api_key, bearer_token=mcp_headers.bearer_token)
+
+    async def _list_tools(self) -> list[types.Tool]:
+        """Return the list of available MCP tools."""
+        headers = self._get_request_headers()
+        feature_flags = headers.feature_flags
+        full_result_flag = "mcp-full-dataset-v1" in feature_flags
+
+        dataset_results_extended_description = dedent("""
+            The "offset" and "limit" arguments affect the "content" field, but not the "structuredContent" field, of this tool's result. Assume that you (the AI model) can only see the "content" field, but accessing this tool's result through code execution (if applicable) uses the "structuredContent" field. Note that the "sql_query" and "orientation" arguments still apply to both the "content" and "structuredContent" fields.
+        """).strip() if full_result_flag else ""
+
+        return [
+            types.Tool(
+                name=self.catalog_tool_name,
+                title=f"Data Catalog For {self.project_label}",
+                description=dedent(f"""
+                    Use this tool to get the details of all datasets and parameters you can access in the Squirrels project '{self.project_name}'.
+
+                    Unless the data catalog for this project has already been provided, use this tool at the start of each conversation.
+                """).strip(),
+                annotations=self._get_tool_annotations(title=f"Data Catalog For {self.project_label}"),
+                inputSchema={
+                    "type": "object",
+                    "properties": {},
+                    "required": [],
+                },
+                # outputSchema=rm.CatalogModelForMcp.model_json_schema(),
+            ),
+            types.Tool(
+                name=self.parameters_tool_name,
+                title=f"Parameters Updates For {self.project_label}",
+                description=dedent(f"""
+                    Use this tool to get updates for dataset parameters in the Squirrels project "{self.project_name}" when a selection is to be made on a parameter with `"trigger_refresh": true`.
+
+                    For example, suppose there are two parameters, "country" and "city", and the user selects "United States" for "country". If "country" has the "trigger_refresh" field as true, then this tool should be called to get the updates for other parameters such as "city".
+
+                    Do not use this tool on parameters that do not have `"trigger_refresh": true`.
+                """).strip(),
+                annotations=self._get_tool_annotations(title=f"Parameters Updates For {self.project_label}"),
+                inputSchema={
+                    "type": "object",
+                    "properties": {
+                        "dataset": {
+                            "type": "string",
+                            "description": "The name of the dataset whose parameters the trigger parameter will update",
+                        },
+                        "selected_ids": {
+                            "type": "string",
+                            "description": dedent("""
+                                A JSON object (as string) with one key-value pair. The key is the name of the parameter triggering the refresh, and the value is the ID(s) of the selected option(s) for the parameter.
+                                - If the parameter's widget_type is single_select, use a string for the ID of the selected option
+                                - If the parameter's widget_type is multi_select, use an array of strings for the IDs of the selected options
+
+                                An error is raised if this JSON object does not have exactly one key-value pair.
+                            """).strip(),
+                        },
+                    },
+                    "required": ["dataset", "selected_ids"],
+                },
+                # outputSchema=rm.ParametersModel.model_json_schema(),
+            ),
+            types.Tool(
+                name=self.results_tool_name,
+                title=f"Dataset Results For {self.project_label}",
+                description=dedent(f"""
+                    Use this tool to get the dataset results as a JSON object for a dataset in the Squirrels project "{self.project_name}".
+
+                    {dataset_results_extended_description}
+                """).strip(),
+                annotations=self._get_tool_annotations(title=f"Dataset Results For {self.project_label}"),
+                inputSchema={
+                    "type": "object",
+                    "properties": {
+                        "dataset": {
+                            "type": "string",
+                            "description": "The name of the dataset to get results for",
+                        },
+                        "parameters": {
+                            "type": "string",
+                            "description": dedent("""
+                                A JSON object (as string) containing key-value pairs for parameter name and selected value. The selected value to provide depends on the parameter widget type:
+                                - If the parameter's widget_type is single_select, use a string for the ID of the selected option
+                                - If the parameter's widget_type is multi_select, use an array of strings for the IDs of the selected options
+                                - If the parameter's widget_type is date, use a string like "YYYY-MM-DD"
+                                - If the parameter's widget_type is date_range, use array of strings like ["YYYY-MM-DD","YYYY-MM-DD"]
+                                - If the parameter's widget_type is number, use a number like 1
+                                - If the parameter's widget_type is number_range, use array of numbers like [1,100]
+                                - If the parameter's widget_type is text, use a string for the text value
+                                - Complex objects are NOT supported
+                            """).strip(),
+                        },
+                        "sql_query": {
+                            "type": ["string", "null"],
+                            "description": dedent("""
+                                A custom Polars SQL query to execute on the final dataset result.
+                                - Use table name 'result' to reference the dataset result.
+                                - Use this to apply transformations to the dataset result if needed (such as filtering, sorting, or selecting columns).
+                                - If not provided, the dataset result is returned as is.
+                            """).strip(),
+                            "default": None,
+                        },
+                        "orientation": {
+                            "type": "string",
+                            "enum": ["rows", "columns", "records"],
+                            "description": "The orientation of the dataset result. Options are 'rows', 'columns', and 'records'. Default is 'rows'.",
+                            "default": "rows",
+                        },
+                        "offset": {
+                            "type": "integer",
+                            "description": "The number of rows to skip from first row. Applied after the sql_query. Default is 0.",
+                            "default": 0,
+                        },
+                        "limit": {
+                            "type": "integer",
+                            "description": dedent(f"""
+                                The maximum number of rows to return. Applied after the sql_query.
+                                Default is {self.default_for_limit}. Maximum allowed value is {self.max_rows_for_ai}.
+                            """).strip(),
+                            "default": self.default_for_limit,
+                        },
+                    },
+                    "required": ["dataset", "parameters"],
+                },
+                outputSchema=rm.DatasetResultModel.model_json_schema(),
+            ),
+        ]
+
+    def _get_dataset_and_parameters(self, arguments: dict[str, Any], *, params_key: str = "parameters") -> tuple[str, dict[str, Any]]:
+        """Get dataset and parameters from arguments.
+
+        Args:
+            arguments: The arguments from the tool call
+            params_key: The key of the parameters in the arguments
+
+        Returns:
+            A tuple of the dataset and parameters
+
+        Raises:
+            InvalidInputError: If the dataset or parameters are invalid
+        """
+        try:
+            dataset = str(arguments["dataset"])
+        except KeyError:
+            raise InvalidInputError(400, "invalid_dataset", "The 'dataset' argument is required.")
+
+        parameters_arg = str(arguments.get(params_key, "{}"))
+
+        # validate parameters argument
+        try:
+            parameters = json.loads(parameters_arg)
+        except json.JSONDecodeError:
+            parameters = None # error handled below
+
+        if not isinstance(parameters, dict):
+            raise InvalidInputError(400, "invalid_parameters", f"The '{params_key}' argument must be a JSON object.")
+
+        return dataset, parameters
+
+    async def _call_tool(self, name: str, arguments: dict[str, Any] | None) -> types.CallToolResult:
+        """Handle tool calls by dispatching to the appropriate function.
+
+        Returns structured data (dict) directly for successful calls, which the MCP
+        framework will serialize to JSON. For errors, returns CallToolResult with isError=True.
+        """
+        arguments = arguments or {}
+
+        try:
+            mcp_headers = self._get_request_headers()
+            user, _ = self._get_validated_user_for_request(mcp_headers)
+
+            feature_flags = mcp_headers.feature_flags
+            full_result_flag = "mcp-full-dataset-v1" in feature_flags
+
+            if name == self.catalog_tool_name:
+                result = await self._get_data_catalog_for_mcp(user)
+                return types.CallToolResult(
+                    content=[types.TextContent(type="text", text=result.model_dump_json(by_alias=True))],
+                    structuredContent=result.model_dump(mode="json", by_alias=True),
+                )
+
+            elif name == self.parameters_tool_name:
+                dataset, parameters = self._get_dataset_and_parameters(arguments, params_key="selected_ids")
+
+                # validate parameters is a single key-value pair
+                if len(parameters) != 1:
+                    raise InvalidInputError(
+                        400, "invalid_selected_ids",
+                        "The 'selected_ids' argument must have exactly one key-value pair."
+                    )
+
+                # validate selected ids is a string or list of strings
+                parameter_name, selected_ids = next(iter(parameters.items()))
+                if not isinstance(selected_ids, (str, list)):
+                    raise InvalidInputError(
+                        400, "invalid_selected_ids",
+                        f"The selected ids of the parameter '{parameter_name}' must be a string or list of strings."
+                    )
+
+                # get dataset parameters
+                result = await self._get_dataset_parameters_for_mcp(dataset, parameter_name, selected_ids, user)
+                return types.CallToolResult(
+                    content=[types.TextContent(type="text", text=result.model_dump_json(by_alias=True))],
+                    structuredContent=result.model_dump(mode="json", by_alias=True),
+                )
+
+            elif name == self.results_tool_name:
+                dataset, parameters = self._get_dataset_and_parameters(arguments, params_key="parameters")
+
+                # validate sql_query argument
+                sql_query_arg = arguments.get("sql_query")
+                sql_query = str(sql_query_arg) if sql_query_arg else None
+
+                # validate orientation argument
+                result_format = RouteBase.extract_orientation_offset_and_limit(arguments, key_prefix="", default_orientation="rows", default_limit=self.default_for_limit)
+                orientation, limit = result_format.orientation, result_format.limit
+                if limit > self.max_rows_for_ai:
+                    raise InvalidInputError(400, "invalid_limit", f"The 'limit' argument must be less than or equal to {self.max_rows_for_ai}.")
+
+                # get dataset result object
+                configurables = self._get_configurables(mcp_headers)
+                result_obj = await self._get_dataset_results_for_mcp(
+                    dataset, parameters, sql_query, user, configurables
+                )
+
+                # format dataset result object
+                structured_result = result_obj.to_json(result_format)
+                result_model = rm.DatasetResultModel(**structured_result)
+
+                if full_result_flag:
+                    full_result_format = DatasetResultFormat(orientation, 0, None)
+                    structured_result = result_obj.to_json(full_result_format)
+
+                return types.CallToolResult(
+                    content=[types.TextContent(type="text", text=result_model.model_dump_json(by_alias=True))],
+                    structuredContent=structured_result,
+                )
+
+            else:
+                return types.CallToolResult(
+                    content=[types.TextContent(type="text", text=f"Unknown tool: {name}")],
+                    isError=True
+                )
+
+        except InvalidInputError as e:
+            # If auth is required, surface HTTP 401s as real HTTP responses.
+            if e.status_code == 401:
+                raise
+            return types.CallToolResult(
+                content=[types.TextContent(type="text", text=f"Error: {e.error_description}")],
+                isError=True,
+            )
+        except Exception as e:
+            return types.CallToolResult(
+                content=[types.TextContent(type="text", text=f"Error: {str(e)}")],
+                isError=True
+            )
+
+    async def _list_resources(self) -> list[types.Resource]:
+        """Return the list of available MCP resources."""
+        return [
+            types.Resource(
+                uri=AnyUrl(self.catalog_resource_uri),
+                name=self.catalog_resource_name,
+                description=f"Details of all datasets and parameters you can access in the Squirrels project '{self.project_name}'.",
+            ),
+        ]
+
+    async def _read_resource(self, uri: AnyUrl) -> str | bytes:
+        """Read the content of a resource."""
+        mcp_headers = self._get_request_headers()
+
+        if str(uri) == self.catalog_resource_uri:
+            user, _ = self._get_validated_user_for_request(mcp_headers)
+            result = await self._get_data_catalog_for_mcp(user)
+            return result.model_dump_json(by_alias=True)
+        else:
+            raise ValueError(f"Unknown resource URI: {uri}")
+
+    @asynccontextmanager
+    async def lifespan(self, app: object | None = None) -> AsyncIterator[None]:
+        """
+        Async context manager for the MCP session manager lifecycle.
+
+        Use this in the FastAPI app lifespan to ensure proper startup/shutdown.
+        """
+        async with self._session_manager.run():
+            yield
+
+    def get_asgi_app(self) -> ASGIApp:
+        """
+        Get the ASGI app for the MCP server.
+        """
+        async def _invalid_input_handler(request: Request, exc: InvalidInputError):
+            # When mounted under `/mcp` (or a larger mount path ending in `/mcp`),
+            # strip only that mount suffix so the resource_metadata URL points to
+            # the top-level endpoint.
+            return invalid_input_error_to_json_response(
+                request,
+                exc,
+                oauth_resource_metadata_path=self.oauth_resource_metadata_path,
+                strip_path_suffix=self.www_authenticate_strip_path_suffix,
+                is_mcp=True,
+            )
+
+        app = Starlette(
+            routes=[
+                Mount("/", app=self._session_manager.handle_request),
+            ],
+            lifespan=self.lifespan,
+            exception_handlers={InvalidInputError: _invalid_input_handler},
+        )
+
+        builder = self
+
+        class _McpOAuthGateMiddleware(BaseHTTPMiddleware):
+            async def dispatch(self, request: Request, call_next):
+                try:
+                    if builder.enforce_oauth_bearer:
+                        auth_header = request.headers.get("authorization", "")
+                        token = None
+                        if auth_header.lower().startswith("bearer "):
+                            token = auth_header[7:].strip()
+
+                        if not token:
+                            raise InvalidInputError(401, "user_required", "Authentication is required")
+
+                        user, expiry = builder._get_user_from_headers(api_key=None, bearer_token=token)
+                        request.state.sqrl_user = user
+                        request.state.access_token_expiry = expiry
+
+                    return await call_next(request)
+                except InvalidInputError as exc:
+                    # Starlette's BaseHTTPMiddleware may bypass exception handlers for
+                    # exceptions raised within dispatch; handle explicitly here.
+                    return invalid_input_error_to_json_response(
+                        request,
+                        exc,
+                        oauth_resource_metadata_path=builder.oauth_resource_metadata_path,
+                        strip_path_suffix=builder.www_authenticate_strip_path_suffix,
+                        is_mcp=True,
+                    )
+
+        app.add_middleware(_McpOAuthGateMiddleware)
+        return app
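
For context on the tool schemas above: both the "parameters" and "selected_ids" arguments are JSON objects passed as strings. Below is a minimal client-side sketch of assembling arguments for the get_dataset_results_from_<project> tool; the dataset and parameter names in it are hypothetical, not part of the package.

import json

# Selected values follow the widget_type rules from the inputSchema above:
# single_select -> string, multi_select -> list of strings,
# date_range -> ["YYYY-MM-DD", "YYYY-MM-DD"], number_range -> [low, high].
selections = {
    "country": "usa",                            # single_select (hypothetical)
    "cities": ["new_york", "chicago"],           # multi_select (hypothetical)
    "date_range": ["2024-01-01", "2024-12-31"],  # date_range (hypothetical)
}

# Arguments for the dataset-results tool call; "parameters" must be the
# JSON object encoded as a string, per the schema above.
arguments = {
    "dataset": "expenses_by_month",  # hypothetical dataset name
    "parameters": json.dumps(selections),
    "orientation": "records",
    "limit": 10,
}

print(json.dumps(arguments, indent=2))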
squirrels/_model_builder.py
ADDED
@@ -0,0 +1,69 @@
+from dataclasses import dataclass, field
+import duckdb, time
+
+from . import _utils as u, _connection_set as cs, _models as m
+
+
+@dataclass
+class ModelBuilder:
+    _datalake_db_path: str
+    _conn_set: cs.ConnectionSet
+    _static_models: dict[str, m.StaticModel]
+    _conn_args: cs.ConnectionsArgs
+    _logger: u.Logger = field(default_factory=lambda: u.Logger(""))
+
+    def _attach_connections(self, duckdb_conn: duckdb.DuckDBPyConnection) -> None:
+        for conn_name, conn_props in self._conn_set.get_connections_as_dict().items():
+            if not isinstance(conn_props, m.ConnectionProperties):
+                continue
+            attach_uri = conn_props.attach_uri_for_duckdb
+            if attach_uri is None:
+                continue # skip unsupported dialects
+            attach_stmt = f"ATTACH IF NOT EXISTS '{attach_uri}' AS db_{conn_name} (READ_ONLY)"
+            u.run_duckdb_stmt(self._logger, duckdb_conn, attach_stmt, redacted_values=[attach_uri])
+
+    async def _build_models(self, duckdb_conn: duckdb.DuckDBPyConnection, select: str | None, full_refresh: bool) -> None:
+        """
+        Compile and construct the build models as DuckDB tables.
+        """
+        # Compile the build models
+        models_list = self._static_models.values() if select is None else [self._static_models[select]]
+        for model in models_list:
+            model.compile_for_build(self._conn_args, self._static_models)
+
+        # Find all terminal nodes
+        terminal_nodes = set()
+        if select is None:
+            for model in models_list:
+                terminal_nodes.update(model.get_terminal_nodes_for_build(set()))
+            for model in models_list:
+                model.confirmed_no_cycles = False
+        else:
+            terminal_nodes.add(select)
+
+        # Run the build models
+        coroutines = []
+        for model_name in terminal_nodes:
+            model = self._static_models[model_name]
+            # await model.build_model(duckdb_conn, full_refresh)
+            coro = model.build_model(duckdb_conn, full_refresh)
+            coroutines.append(coro)
+        await u.asyncio_gather(coroutines)
+
+    async def build(self, full_refresh: bool, select: str | None) -> None:
+        start = time.time()
+
+        # Connect directly to DuckLake instead of attaching (supports concurrent connections)
+        duckdb_conn = u.create_duckdb_connection(self._datalake_db_path)
+
+        try:
+            # Attach connections
+            self._attach_connections(duckdb_conn)
+
+            # Construct build models
+            await self._build_models(duckdb_conn, select, full_refresh)
+
+        finally:
+            duckdb_conn.close()
+
+        self._logger.log_activity_time("TOTAL TIME to build the Virtual Data Lake (VDL)", start)
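
ModelBuilder._attach_connections above issues one DuckDB statement per supported connection, of the form ATTACH IF NOT EXISTS '<uri>' AS db_<name> (READ_ONLY). Below is a minimal standalone sketch of that statement shape; it assumes a recent DuckDB release, and the file name and alias are made up rather than taken from a real project connection.

import duckdb

# Create a throwaway database file to stand in for an external connection.
duckdb.connect("example_source.db").execute("CREATE OR REPLACE TABLE t AS SELECT 42 AS x").close()

# Attach it read-only under a db_<name> alias, mirroring the statement built
# in _attach_connections (hypothetical file name and alias).
conn = duckdb.connect()
conn.execute("ATTACH IF NOT EXISTS 'example_source.db' AS db_example (READ_ONLY)")
print(conn.execute("SELECT * FROM db_example.t").fetchall())  # [(42,)]
conn.close()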