squirrels-0.5.0b4-py3-none-any.whl → squirrels-0.5.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- squirrels/__init__.py +2 -0
- squirrels/_api_routes/auth.py +83 -74
- squirrels/_api_routes/base.py +58 -41
- squirrels/_api_routes/dashboards.py +37 -21
- squirrels/_api_routes/data_management.py +72 -27
- squirrels/_api_routes/datasets.py +107 -84
- squirrels/_api_routes/oauth2.py +11 -13
- squirrels/_api_routes/project.py +71 -33
- squirrels/_api_server.py +130 -63
- squirrels/_arguments/run_time_args.py +9 -9
- squirrels/_auth.py +117 -162
- squirrels/_command_line.py +68 -32
- squirrels/_compile_prompts.py +147 -0
- squirrels/_connection_set.py +11 -2
- squirrels/_constants.py +22 -8
- squirrels/_data_sources.py +38 -32
- squirrels/_dataset_types.py +2 -4
- squirrels/_initializer.py +1 -1
- squirrels/_logging.py +117 -0
- squirrels/_manifest.py +125 -58
- squirrels/_model_builder.py +10 -54
- squirrels/_models.py +224 -108
- squirrels/_package_data/base_project/.env +15 -4
- squirrels/_package_data/base_project/.env.example +14 -3
- squirrels/_package_data/base_project/connections.yml +4 -3
- squirrels/_package_data/base_project/dashboards/dashboard_example.py +2 -2
- squirrels/_package_data/base_project/dashboards/dashboard_example.yml +4 -4
- squirrels/_package_data/base_project/duckdb_init.sql +1 -0
- squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
- squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
- squirrels/_package_data/base_project/models/federates/federate_example.py +22 -15
- squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
- squirrels/_package_data/base_project/models/federates/federate_example.yml +1 -1
- squirrels/_package_data/base_project/models/sources.yml +5 -6
- squirrels/_package_data/base_project/parameters.yml +24 -38
- squirrels/_package_data/base_project/pyconfigs/connections.py +5 -1
- squirrels/_package_data/base_project/pyconfigs/context.py +23 -12
- squirrels/_package_data/base_project/pyconfigs/parameters.py +68 -33
- squirrels/_package_data/base_project/pyconfigs/user.py +11 -18
- squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
- squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
- squirrels/_package_data/base_project/squirrels.yml.j2 +18 -28
- squirrels/_package_data/templates/squirrels_studio.html +20 -0
- squirrels/_parameter_configs.py +43 -22
- squirrels/_parameter_options.py +1 -1
- squirrels/_parameter_sets.py +8 -10
- squirrels/_project.py +351 -234
- squirrels/_request_context.py +33 -0
- squirrels/_schemas/auth_models.py +32 -9
- squirrels/_schemas/query_param_models.py +9 -1
- squirrels/_schemas/response_models.py +36 -10
- squirrels/_seeds.py +1 -1
- squirrels/_sources.py +23 -19
- squirrels/_utils.py +83 -35
- squirrels/_version.py +1 -1
- squirrels/arguments.py +5 -0
- squirrels/auth.py +4 -1
- squirrels/connections.py +2 -0
- squirrels/dashboards.py +3 -1
- squirrels/data_sources.py +6 -0
- squirrels/parameter_options.py +5 -0
- squirrels/parameters.py +5 -0
- squirrels/types.py +6 -1
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/METADATA +28 -13
- squirrels-0.5.1.dist-info/RECORD +98 -0
- squirrels-0.5.0b4.dist-info/RECORD +0 -94
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/WHEEL +0 -0
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/entry_points.txt +0 -0
- {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/licenses/LICENSE +0 -0
squirrels/_api_routes/data_management.py CHANGED

@@ -1,7 +1,7 @@
 """
 Data management routes for build and query models
 """
-from typing import
+from typing import Any
 from fastapi import FastAPI, Depends, Request, Response, status
 from fastapi.responses import JSONResponse
 from fastapi.security import HTTPBearer
@@ -9,13 +9,12 @@ from dataclasses import asdict
 from cachetools import TTLCache
 import time
 
-from .. import _constants as c
+from .. import _constants as c, _utils as u
 from .._schemas import response_models as rm
 from .._exceptions import InvalidInputError
-from ..
-from .._manifest import PermissionScope
+from .._schemas.auth_models import AbstractUser
 from .._dataset_types import DatasetResult
-from .._schemas.query_param_models import get_query_models_for_querying_models
+from .._schemas.query_param_models import get_query_models_for_querying_models, get_query_models_for_compiled_models
 from .base import RouteBase
 
 
@@ -31,55 +30,74 @@ class DataManagementRoutes(RouteBase):
         self.query_models_cache = TTLCache(maxsize=dataset_results_cache_size, ttl=dataset_results_cache_ttl*60)
 
     async def _query_models_helper(
-        self, sql_query: str, user:
+        self, sql_query: str, user: AbstractUser, selections: tuple[tuple[str, Any], ...], configurables: tuple[tuple[str, str], ...]
     ) -> DatasetResult:
         """Helper to query models"""
-
+        cfg_filtered = {k: v for k, v in dict(configurables).items() if k in self.manifest_cfg.configurables}
+        return await self.project.query_models(sql_query, user=user, selections=dict(selections), configurables=cfg_filtered)
 
     async def _query_models_cachable(
-        self, sql_query: str, user:
+        self, sql_query: str, user: AbstractUser, selections: tuple[tuple[str, Any], ...], configurables: tuple[tuple[str, str], ...]
     ) -> DatasetResult:
         """Cachable version of query models helper"""
-        return await self.do_cachable_action(self.query_models_cache, self._query_models_helper, sql_query, user, selections)
+        return await self.do_cachable_action(self.query_models_cache, self._query_models_helper, sql_query, user, selections, configurables)
 
     async def _query_models_definition(
-        self, user:
+        self, user: AbstractUser, all_request_params: dict, params: dict, *, headers: dict[str, str]
     ) -> rm.DatasetResultModel:
         """Query models definition"""
-        self._validate_request_params(all_request_params, params)
+        self._validate_request_params(all_request_params, params, headers)
 
-        if not
-            raise InvalidInputError(403, "
+        if not u.user_has_elevated_privileges(user.access_level, self.project._elevated_access_level):
+            raise InvalidInputError(403, "unauthorized_access_to_query_models", f"User '{user}' does not have permission to query data models")
 
         sql_query = params.get("x_sql_query")
         if sql_query is None:
-            raise InvalidInputError(400, "
+            raise InvalidInputError(400, "sql_query_required", "SQL query must be provided")
 
         query_models_function = self._query_models_helper if self.no_cache else self._query_models_cachable
         uncached_keys = {"x_verify_params", "x_sql_query", "x_orientation", "x_limit", "x_offset"}
         selections = self.get_selections_as_immutable(params, uncached_keys)
-
+        configurables = self.get_configurables_from_headers(headers)
+        result = await query_models_function(sql_query, user, selections, configurables)
 
-
+        orientation_header = headers.get("x-orientation")
+        orientation = str(orientation_header).lower() if orientation_header is not None else params.get("x_orientation", "records")
         limit = params.get("x_limit", 1000)
         offset = params.get("x_offset", 0)
-        return rm.DatasetResultModel(**result.to_json(orientation,
+        return rm.DatasetResultModel(**result.to_json(orientation, limit, offset))
 
+    async def _get_compiled_model_definition(
+        self, model_name: str, user: AbstractUser, all_request_params: dict, params: dict, *, headers: dict[str, str]
+    ) -> rm.CompiledQueryModel:
+        """Get compiled model definition"""
+        normalized_model_name = u.normalize_name(model_name)
+        self._validate_request_params(all_request_params, params, headers)
+
+        # Internal users only
+        if not u.user_has_elevated_privileges(user.access_level, self.project._elevated_access_level):
+            raise InvalidInputError(403, "unauthorized_access_to_compile_model", f"User '{user}' does not have permission to fetch compiled SQL")
+
+        selections = self.get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
+        configurables = self.get_configurables_from_headers(headers)
+        cfg_filtered = {k: v for k, v in dict(configurables).items() if k in self.manifest_cfg.configurables}
+        return await self.project.get_compiled_model_query(normalized_model_name, user=user, selections=dict(selections), configurables=cfg_filtered)
+
     def setup_routes(self, app: FastAPI, project_metadata_path: str, param_fields: dict) -> None:
         """Setup data management routes"""
 
         # Build project endpoint
         build_path = project_metadata_path + '/build'
 
-        @app.post(build_path, tags=["Data Management"], summary="Build or update the
+        @app.post(build_path, tags=["Data Management"], summary="Build or update the Virtual Data Lake (VDL) for the project")
         async def build(user=Depends(self.get_current_user)): # type: ignore
-            if not
-                raise InvalidInputError(403, "
-            await self.project.build(
+            if not u.user_has_elevated_privileges(user.access_level, self.project._elevated_access_level):
+                raise InvalidInputError(403, "unauthorized_access_to_build_model", f"User '{user}' does not have permission to build the virtual data lake (VDL)")
+            await self.project.build()
             return Response(status_code=status.HTTP_200_OK)
 
-        # Query
-        query_models_path = project_metadata_path + '/query-
+        # Query result endpoints
+        query_models_path = project_metadata_path + '/query-result'
         QueryModelForQueryModels, QueryModelForPostQueryModels = get_query_models_for_querying_models(param_fields)
 
         @app.get(query_models_path, tags=["Data Management"], response_class=JSONResponse)
@@ -87,8 +105,8 @@ class DataManagementRoutes(RouteBase):
             request: Request, params: QueryModelForQueryModels, user=Depends(self.get_current_user) # type: ignore
         ) -> rm.DatasetResultModel:
             start = time.time()
-            result = await self._query_models_definition(user, dict(request.query_params), asdict(params))
-            self.log_activity_time("GET REQUEST for QUERY MODELS", start
+            result = await self._query_models_definition(user, dict(request.query_params), asdict(params), headers=dict(request.headers))
+            self.logger.log_activity_time("GET REQUEST for QUERY MODELS", start)
             return result
 
         @app.post(query_models_path, tags=["Data Management"], response_class=JSONResponse)
@@ -97,7 +115,34 @@ class DataManagementRoutes(RouteBase):
         ) -> rm.DatasetResultModel:
             start = time.time()
             payload: dict = await request.json()
-            result = await self._query_models_definition(user, payload, params.model_dump())
-            self.log_activity_time("POST REQUEST for QUERY MODELS", start
+            result = await self._query_models_definition(user, payload, params.model_dump(), headers=dict(request.headers))
+            self.logger.log_activity_time("POST REQUEST for QUERY MODELS", start)
+            return result
+
+        # Compiled models endpoints - TODO: remove duplication
+        compiled_models_path = project_metadata_path + '/compiled-models/{model_name}'
+        QueryModelForGetCompiled, QueryModelForPostCompiled = get_query_models_for_compiled_models(param_fields)
+
+        @app.get(compiled_models_path, tags=["Data Management"], response_class=JSONResponse, summary="Get compiled definition for a model")
+        async def get_compiled_model(
+            request: Request, model_name: str, params: QueryModelForGetCompiled, user=Depends(self.get_current_user)
+        ) -> rm.CompiledQueryModel:
+            start = time.time()
+            result = await self._get_compiled_model_definition(model_name, user, dict(request.query_params), asdict(params), headers=dict(request.headers))
+            self.logger.log_activity_time(
+                "GET REQUEST for GET COMPILED MODEL", start, additional_data={"model_name": model_name}
+            )
+            return result
+
+        @app.post(compiled_models_path, tags=["Data Management"], response_class=JSONResponse, summary="Get compiled definition for a model")
+        async def get_compiled_model_with_post(
+            request: Request, model_name: str, params: QueryModelForPostCompiled, user=Depends(self.get_current_user)
+        ) -> rm.CompiledQueryModel:
+            start = time.time()
+            payload: dict = await request.json()
+            result = await self._get_compiled_model_definition(model_name, user, payload, params.model_dump(), headers=dict(request.headers))
+            self.logger.log_activity_time(
+                "POST REQUEST for GET COMPILED MODEL", start, additional_data={"model_name": model_name}
+            )
             return result
 
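Note on usage: the query endpoint above is renamed to `/query-result`, reads configurables and orientation from request headers, and is restricted to users with elevated privileges. A minimal client-side sketch of calling it, assuming a hypothetical base URL and bearer token (the real prefix comes from `project_metadata_path`, which is not shown in this diff):

import requests

# Hypothetical values for illustration only - the real path prefix comes from project_metadata_path
# and the token from the project's auth setup; neither is shown in this diff.
BASE_URL = "http://localhost:8000/api/my-project/v1"
TOKEN = "<bearer token for a user with elevated privileges>"

response = requests.post(
    BASE_URL + "/query-result",
    headers={
        "Authorization": f"Bearer {TOKEN}",
        "x-orientation": "records",  # orientation can now be passed as a header
    },
    json={
        "x_sql_query": "SELECT * FROM my_model LIMIT 10",  # required by the endpoint
        "x_limit": 100,
        "x_offset": 0,
    },
)
response.raise_for_status()
print(response.json())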
squirrels/_api_routes/datasets.py CHANGED

@@ -1,10 +1,10 @@
 """
 Dataset routes for parameters and results
 """
-from typing import Callable, Any
-from pydantic import Field
+from typing import Callable, Coroutine, Any
+from pydantic import Field
 from fastapi import FastAPI, Depends, Request
-from fastapi.responses import JSONResponse
+from fastapi.responses import JSONResponse
 from fastapi.security import HTTPBearer
 
 from mcp.server.fastmcp import FastMCP, Context
@@ -12,14 +12,14 @@ from dataclasses import asdict
 from cachetools import TTLCache
 from textwrap import dedent
 
-import time
+import time, json
 
 from .. import _constants as c, _utils as u
 from .._schemas import response_models as rm
 from .._exceptions import ConfigurationError, InvalidInputError
 from .._dataset_types import DatasetResult
 from .._schemas.query_param_models import get_query_models_for_parameters, get_query_models_for_dataset
-from ..
+from .._schemas.auth_models import AbstractUser
 from .base import RouteBase
 
 
@@ -34,39 +34,57 @@ class DatasetRoutes(RouteBase):
         dataset_results_cache_ttl = int(self.env_vars.get(c.SQRL_DATASETS_CACHE_TTL_MINUTES, 60))
         self.dataset_results_cache = TTLCache(maxsize=dataset_results_cache_size, ttl=dataset_results_cache_ttl*60)
 
+        # Setup max rows for AI
+        self.max_rows_for_ai = int(self.env_vars.get(c.SQRL_DATASETS_MAX_ROWS_FOR_AI, 100))
+
     async def _get_dataset_results_helper(
-        self, dataset: str, user:
+        self, dataset: str, user: AbstractUser, selections: tuple[tuple[str, Any], ...], configurables: tuple[tuple[str, str], ...]
     ) -> DatasetResult:
         """Helper to get dataset results"""
-
+        # Only pass configurables that are defined in manifest
+        cfg_filtered = {k: v for k, v in dict(configurables).items() if k in self.manifest_cfg.configurables}
+        return await self.project.dataset(dataset, user=user, selections=dict(selections), configurables=cfg_filtered)
 
     async def _get_dataset_results_cachable(
-        self, dataset: str, user:
+        self, dataset: str, user: AbstractUser, selections: tuple[tuple[str, Any], ...], configurables: tuple[tuple[str, str], ...]
     ) -> DatasetResult:
         """Cachable version of dataset results helper"""
-        return await self.do_cachable_action(self.dataset_results_cache, self._get_dataset_results_helper, dataset, user, selections)
+        return await self.do_cachable_action(self.dataset_results_cache, self._get_dataset_results_helper, dataset, user, selections, configurables)
 
     async def _get_dataset_results_definition(
-        self, dataset_name: str, user:
+        self, dataset_name: str, user: AbstractUser, all_request_params: dict, params: dict, headers: dict[str, str]
    ) -> rm.DatasetResultModel:
        """Get dataset results definition"""
-        self._validate_request_params(all_request_params, params)
+        self._validate_request_params(all_request_params, params, headers)
 
         get_dataset_function = self._get_dataset_results_helper if self.no_cache else self._get_dataset_results_cachable
-        uncached_keys = {"x_verify_params", "x_orientation", "
+        uncached_keys = {"x_verify_params", "x_orientation", "x_sql_query", "x_limit", "x_offset"}
         selections = self.get_selections_as_immutable(params, uncached_keys)
-        result = await get_dataset_function(dataset_name, user, selections)
 
-
-
-
+        user_has_elevated_privileges = u.user_has_elevated_privileges(user.access_level, self.project._elevated_access_level)
+        configurables = self.get_configurables_from_headers(headers) if user_has_elevated_privileges else tuple()
+        result = await get_dataset_function(dataset_name, user, selections, configurables)
+
+        # Apply optional final SQL transformation before select/limit/offset
+        sql_query = params.get("x_sql_query")
+        if sql_query:
+            try:
+                transformed = u.run_sql_on_dataframes(sql_query, {"result": result.df.lazy()})
+            except Exception as e:
+                raise InvalidInputError(400, "invalid_sql_query", "Failed to run provided SQL on the dataset result") from e
+
+            transformed = transformed.drop("_row_num", strict=False).with_row_index("_row_num", offset=1)
+            result = DatasetResult(target_model_config=result.target_model_config, df=transformed)
+
+        orientation_header = headers.get("x-orientation")
+        orientation = str(orientation_header).lower() if orientation_header is not None else params.get("x_orientation", "records")
         limit = params.get("x_limit", 1000)
         offset = params.get("x_offset", 0)
-        return rm.DatasetResultModel(**result.to_json(orientation,
+        return rm.DatasetResultModel(**result.to_json(orientation, limit, offset))
 
     def setup_routes(
-        self, app: FastAPI, mcp: FastMCP, project_metadata_path: str, project_name: str,
-        param_fields: dict, get_parameters_definition: Callable
+        self, app: FastAPI, mcp: FastMCP, project_metadata_path: str, project_name: str, project_label: str,
+        param_fields: dict, get_parameters_definition: Callable[..., Coroutine[Any, Any, rm.ParametersModel]]
     ) -> None:
         """Setup dataset routes"""
 
@@ -84,19 +102,19 @@ class DatasetRoutes(RouteBase):
                 f"\n {all_params}"
             )
 
-        async def get_dataset_parameters_updates(dataset_name: str, user:
+        async def get_dataset_parameters_updates(dataset_name: str, user: AbstractUser, all_request_params: dict, params: dict, headers: dict[str, str]):
             parameters_list = self.manifest_cfg.datasets[dataset_name].parameters
             scope = self.manifest_cfg.datasets[dataset_name].scope
             result = await get_parameters_definition(
-                parameters_list, "dataset", dataset_name, scope, user, all_request_params, params
+                parameters_list, "dataset", dataset_name, scope, user, all_request_params, params, headers=headers
             )
             return result
 
         # Dataset parameters and results APIs
         for dataset_name, dataset_config in self.manifest_cfg.datasets.items():
-
-            curr_parameters_path = dataset_parameters_path.format(dataset=
-            curr_results_path = dataset_results_path.format(dataset=
+            dataset_name_for_api = u.normalize_name_for_api(dataset_name)
+            curr_parameters_path = dataset_parameters_path.format(dataset=dataset_name_for_api)
+            curr_results_path = dataset_results_path.format(dataset=dataset_name_for_api)
 
             validate_parameters_list(dataset_config.parameters, "Dataset", dataset_name)
 
@@ -109,8 +127,10 @@ class DatasetRoutes(RouteBase):
             ) -> rm.ParametersModel:
                 start = time.time()
                 curr_dataset_name = self.get_name_from_path_section(request, -2)
-                result = await get_dataset_parameters_updates(curr_dataset_name, user, dict(request.query_params), asdict(params))
-                self.log_activity_time(
+                result = await get_dataset_parameters_updates(curr_dataset_name, user, dict(request.query_params), asdict(params), dict(request.headers))
+                self.logger.log_activity_time(
+                    "GET REQUEST for PARAMETERS", start, additional_data={"dataset_name": curr_dataset_name}
+                )
                 return result
 
             @app.post(curr_parameters_path, tags=[f"Dataset '{dataset_name}'"], description=self._parameters_description, response_class=JSONResponse)
@@ -120,8 +140,10 @@ class DatasetRoutes(RouteBase):
                 start = time.time()
                 curr_dataset_name = self.get_name_from_path_section(request, -2)
                 payload: dict = await request.json()
-                result = await get_dataset_parameters_updates(curr_dataset_name, user, payload, params.model_dump())
-                self.log_activity_time(
+                result = await get_dataset_parameters_updates(curr_dataset_name, user, payload, params.model_dump(), dict(request.headers))
+                self.logger.log_activity_time(
+                    "POST REQUEST for PARAMETERS", start, additional_data={"dataset_name": curr_dataset_name}
+                )
                 return result
 
             @app.get(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
@@ -130,8 +152,12 @@ class DatasetRoutes(RouteBase):
             ) -> rm.DatasetResultModel:
                 start = time.time()
                 curr_dataset_name = self.get_name_from_path_section(request, -1)
-                result = await self._get_dataset_results_definition(
-
+                result = await self._get_dataset_results_definition(
+                    curr_dataset_name, user, dict(request.query_params), asdict(params), headers=dict(request.headers)
+                )
+                self.logger.log_activity_time(
+                    "GET REQUEST for DATASET RESULTS", start, additional_data={"dataset_name": curr_dataset_name}
+                )
                 return result
 
             @app.post(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
@@ -141,18 +167,23 @@ class DatasetRoutes(RouteBase):
                 start = time.time()
                 curr_dataset_name = self.get_name_from_path_section(request, -1)
                 payload: dict = await request.json()
-                result = await self._get_dataset_results_definition(
-
+                result = await self._get_dataset_results_definition(
+                    curr_dataset_name, user, payload, params.model_dump(), headers=dict(request.headers)
+                )
+                self.logger.log_activity_time(
+                    "POST REQUEST for DATASET RESULTS", start, additional_data={"dataset_name": curr_dataset_name}
+                )
                 return result
 
         # Setup MCP tools
 
         @mcp.tool(
-            name=f"
+            name=f"get_dataset_parameters_from_{project_name}",
+            title=f"Get Dataset Parameters Updates (Project: {project_label})",
             description=dedent(f"""
                 Use this tool to get updates for dataset parameters in the Squirrels project "{project_name}" when a selection is to be made on a parameter with "trigger_refresh" as true.
 
-                For example, suppose there are two parameters, "country" and "city", and the user selects "United States" for "country". If "country" has the "trigger_refresh" field as true, then this tool
+                For example, suppose there are two parameters, "country" and "city", and the user selects "United States" for "country". If "country" has the "trigger_refresh" field as true, then this tool should be called to get the updates for other parameters such as "city".
 
                 Do not use this tool on parameters whose "trigger_refresh" field is false!
            """).strip()
@@ -162,28 +193,30 @@ class DatasetRoutes(RouteBase):
            dataset: str = Field(description="The name of the dataset whose parameters the trigger parameter will update"),
            parameter_name: str = Field(description="The name of the parameter triggering the refresh"),
            selected_ids: list[str] = Field(description="The ID(s) of the selected option(s) for the parameter"),
-        ):
-
+        ) -> rm.ParametersModel:
+            headers = self.get_headers_from_tool_ctx(ctx)
+            user = self.get_user_from_tool_headers(headers)
            dataset_name = u.normalize_name(dataset)
            payload = {
                "x_parent_param": parameter_name,
                parameter_name: selected_ids
            }
-            return await get_dataset_parameters_updates(dataset_name, user, payload, payload)
+            return await get_dataset_parameters_updates(dataset_name, user, payload, payload, headers)
 
         @mcp.tool(
-            name=f"
+            name=f"get_dataset_results_from_{project_name}",
+            title=f"Get Dataset Results (Project: {project_label})",
            description=dedent(f"""
                Use this tool to get the dataset results as a JSON object for a dataset in the Squirrels project "{project_name}".
                - Use the "offset" and "limit" arguments to limit the number of rows you require
-                - The "limit" argument controls the number of rows returned. The maximum allowed value is
+                - The "limit" argument controls the number of rows returned. The maximum allowed value is {self.max_rows_for_ai}. If the 'total_num_rows' field in the response is greater than {self.max_rows_for_ai}, let the user know that only {self.max_rows_for_ai} rows are shown and clarify if they would like to see more.
            """).strip()
        )
        async def get_dataset_results_tool(
            ctx: Context,
            dataset: str = Field(description="The name of the dataset to get results for"),
-            parameters:
-
+            parameters: str = Field(description=dedent("""
+                A JSON object (as string) containing key-value pairs for parameter name and selected value. The selected value to provide depends on the parameter widget type:
                - For single select, use a string for the ID of the selected value
                - For multi select, use an array of strings for the IDs of the selected values
                - For date, use a string like "YYYY-MM-DD"
@@ -191,52 +224,42 @@ class DatasetRoutes(RouteBase):
                - For number, use a number like 1
                - For number ranges, use array of numbers like [1,100]
                - For text, use a string for the text value
-                - Complex objects are NOT supported
-
-
-
-
-
-
-
+                - Complex objects are NOT supported
+            """).strip()),
+            sql_query: str | None = Field(None, description=dedent("""
+                A custom DuckDB SQL query to execute on the final dataset result.
+                - Use table name 'result' to reference the dataset result.
+                - Use this to apply transformations to the dataset result if needed (such as filtering, sorting, or selecting columns).
+                - If not provided, the dataset result is returned as is.
+            """).strip()),
+            offset: int = Field(0, description="The number of rows to skip from first row. Applied after final SQL. Default is 0."),
+            limit: int = Field(self.max_rows_for_ai, description=f"The maximum number of rows to return. Applied after final SQL. Default is {self.max_rows_for_ai}. Maximum allowed value is {self.max_rows_for_ai}."),
+        ) -> rm.DatasetResultModel:
+            if limit > self.max_rows_for_ai:
+                raise ValueError(f"The maximum number of rows to return is {self.max_rows_for_ai}.")
+
+            headers = self.get_headers_from_tool_ctx(ctx)
+            user = self.get_user_from_tool_headers(headers)
            dataset_name = u.normalize_name(dataset)
-
-
-
+
+            try:
+                params = json.loads(parameters)
+            except json.JSONDecodeError:
+                params = None # error handled below
+
+            if not isinstance(params, dict):
+                raise InvalidInputError(400, "invalid_parameters", f"The 'parameters' argument must be a JSON object.")
+
+            params.update({
+                "x_sql_query": sql_query,
                "x_offset": offset,
                "x_limit": limit
-            }
-            result = await self._get_dataset_results_definition(dataset_name, user, params, params)
-            return result
-
-        # Setup UI for tool results
-        mcp_tool_results_ui_path = project_metadata_path + "/mcp/tool-results-ui"
-
-        @app.get(mcp_tool_results_ui_path + "/list-tools", tags=["MCP Supplements"])
-        async def list_tools():
-            return ["get_dataset_results"]
+            })
 
-
-        ""
+            # Set default orientation as rows if not provided
+            if "x-orientation" not in headers:
+                headers["x-orientation"] = "rows"
 
-
-
-
-        @app.post(mcp_tool_results_ui_path + "/tool/{tool_name}", tags=["MCP Supplements"])
-        async def tool_results_ui(tool_name: str, tool_result: ToolResultBody):
-            if tool_name == "get_dataset_results":
-                # Convert Pydantic model to dict to access any extra fields
-                tool_result_dict = tool_result.model_dump()
-
-                # Prepare template context
-                context = {
-                    "schema": tool_result_dict.get("schema", {}),
-                    "data": tool_result_dict.get("data", []),
-                }
-
-                # Render HTML template
-                html_content = self.templates.get_template("dataset_results.html").render(context)
-                return HTMLResponse(content=html_content, status_code=200)
-            else:
-                raise InvalidInputError(400, "Invalid tool name", f"Tool name '{tool_name}' not supported for UI")
+            result = await self._get_dataset_results_definition(dataset_name, user, params, params, headers)
+            return result
 
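The new `x_sql_query` handling above runs a caller-supplied query against the dataset result, registered under the table name 'result', before limit and offset are applied. The `u.run_sql_on_dataframes` helper is not part of this diff; the sketch below only illustrates the idea with DuckDB and polars and is not the library's actual implementation:

import duckdb
import polars as pl

def run_sql_on_result(sql_query: str, result_df: pl.DataFrame) -> pl.DataFrame:
    # Register the dataset result under the table name 'result', as the tool description instructs,
    # then run the caller-provided DuckDB SQL against it and return the transformed frame.
    con = duckdb.connect()
    con.register("result", result_df)
    return con.sql(sql_query).pl()

# Example: the kind of transformation a client could pass via x_sql_query
df = pl.DataFrame({"city": ["Toronto", "Boston"], "sales": [120, 95]})
print(run_sql_on_result("SELECT city, sales FROM result WHERE sales > 100 ORDER BY city", df))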
squirrels/_api_routes/oauth2.py CHANGED

@@ -6,9 +6,10 @@ from typing import Annotated, cast
 from .base import RouteBase
 from .._schemas.auth_models import (
     ClientRegistrationRequest, ClientUpdateRequest, ClientRegistrationResponse, ClientDetailsResponse, ClientUpdateResponse,
-    TokenResponse, OAuthServerMetadata
+    TokenResponse, OAuthServerMetadata, AbstractUser
 )
 from .._exceptions import InvalidInputError
+from .. import _utils as u
 
 
 class OAuth2Routes(RouteBase):
@@ -50,16 +51,13 @@ class OAuth2Routes(RouteBase):
             status_code=200
         )
 
-    def setup_routes(self, app: FastAPI) -> None:
+    def setup_routes(self, app: FastAPI, squirrels_version_path: str) -> None:
         """Setup all OAuth2 routes"""
 
-
+        auth_path = squirrels_version_path + "/auth"
+        router_path = "/oauth2"
         router = APIRouter(prefix=router_path)
 
-        # Create user models
-        class UserInfoModel(self.UserInfoModel):
-            username: str
-
         # Authorization dependency for client management
         get_client_token = HTTPBearer(auto_error=False)
 
@@ -93,7 +91,7 @@ class OAuth2Routes(RouteBase):
         # Client Registration Endpoint
         client_management_path = '/client/{client_id}'
 
-        @router.post("/
+        @router.post("/client", description="Register a new OAuth client", tags=["OAuth2"])
         async def register_oauth_client(request: ClientRegistrationRequest) -> ClientRegistrationResponse:
             """Register a new OAuth client and return client credentials"""
 
@@ -148,7 +146,7 @@ class OAuth2Routes(RouteBase):
             state: str | None = Query(default=None, description="State parameter for CSRF protection"),
             code_challenge: str = Query(..., description="PKCE code challenge (required)"),
             code_challenge_method: str = Query(default="S256", description="PKCE code challenge method"),
-            user:
+            user: AbstractUser = Depends(self.get_current_user)
         ):
             """OAuth 2.1 authorization endpoint for initiating authorization code flow"""
 
@@ -158,9 +156,9 @@ class OAuth2Routes(RouteBase):
                 raise InvalidInputError(400, "unsupported_response_type", "Only 'code' response type is supported")
 
             # Check if user is authenticated
-            if user
+            if user.access_level == "guest":
                 # User is not authenticated - serve login page
-                return self.serve_login_page(
+                return self.serve_login_page(auth_path, request, client_id)
 
             # TODO: Serve a page with an "authorize" button even if user is already authenticated
             # Ex. if not request.session.get("authorization_approved"), redirect to a page with button that submits to "/approve-authorization"
@@ -281,7 +279,7 @@ class OAuth2Routes(RouteBase):
             """OAuth 2.1 Authorization Server Metadata endpoint (RFC 8414)"""
 
             # Get the base URL from the request
-            scheme =
+            scheme = u.get_scheme(request.url.hostname)
             base_url = scheme + "://" + request.url.netloc
 
             return OAuthServerMetadata(
@@ -289,7 +287,7 @@ class OAuth2Routes(RouteBase):
                 authorization_endpoint=f"{base_url}{router_path}/authorize",
                 token_endpoint=f"{base_url}{router_path}/token",
                 revocation_endpoint=f"{base_url}{router_path}/token/revoke",
-                registration_endpoint=f"{base_url}{router_path}/
+                registration_endpoint=f"{base_url}{router_path}/client",
                 scopes_supported=["read"],
                 response_types_supported=["code"],
                 grant_types_supported=["authorization_code", "refresh_token"],