squirrels 0.5.0b4__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of squirrels might be problematic. Click here for more details.

Files changed (69) hide show
  1. squirrels/__init__.py +2 -0
  2. squirrels/_api_routes/auth.py +83 -74
  3. squirrels/_api_routes/base.py +58 -41
  4. squirrels/_api_routes/dashboards.py +37 -21
  5. squirrels/_api_routes/data_management.py +72 -27
  6. squirrels/_api_routes/datasets.py +107 -84
  7. squirrels/_api_routes/oauth2.py +11 -13
  8. squirrels/_api_routes/project.py +71 -33
  9. squirrels/_api_server.py +130 -63
  10. squirrels/_arguments/run_time_args.py +9 -9
  11. squirrels/_auth.py +117 -162
  12. squirrels/_command_line.py +68 -32
  13. squirrels/_compile_prompts.py +147 -0
  14. squirrels/_connection_set.py +11 -2
  15. squirrels/_constants.py +22 -8
  16. squirrels/_data_sources.py +38 -32
  17. squirrels/_dataset_types.py +2 -4
  18. squirrels/_initializer.py +1 -1
  19. squirrels/_logging.py +117 -0
  20. squirrels/_manifest.py +125 -58
  21. squirrels/_model_builder.py +10 -54
  22. squirrels/_models.py +224 -108
  23. squirrels/_package_data/base_project/.env +15 -4
  24. squirrels/_package_data/base_project/.env.example +14 -3
  25. squirrels/_package_data/base_project/connections.yml +4 -3
  26. squirrels/_package_data/base_project/dashboards/dashboard_example.py +2 -2
  27. squirrels/_package_data/base_project/dashboards/dashboard_example.yml +4 -4
  28. squirrels/_package_data/base_project/duckdb_init.sql +1 -0
  29. squirrels/_package_data/base_project/models/dbviews/dbview_example.sql +7 -2
  30. squirrels/_package_data/base_project/models/dbviews/dbview_example.yml +16 -10
  31. squirrels/_package_data/base_project/models/federates/federate_example.py +22 -15
  32. squirrels/_package_data/base_project/models/federates/federate_example.sql +3 -7
  33. squirrels/_package_data/base_project/models/federates/federate_example.yml +1 -1
  34. squirrels/_package_data/base_project/models/sources.yml +5 -6
  35. squirrels/_package_data/base_project/parameters.yml +24 -38
  36. squirrels/_package_data/base_project/pyconfigs/connections.py +5 -1
  37. squirrels/_package_data/base_project/pyconfigs/context.py +23 -12
  38. squirrels/_package_data/base_project/pyconfigs/parameters.py +68 -33
  39. squirrels/_package_data/base_project/pyconfigs/user.py +11 -18
  40. squirrels/_package_data/base_project/seeds/seed_categories.yml +1 -1
  41. squirrels/_package_data/base_project/seeds/seed_subcategories.yml +1 -1
  42. squirrels/_package_data/base_project/squirrels.yml.j2 +18 -28
  43. squirrels/_package_data/templates/squirrels_studio.html +20 -0
  44. squirrels/_parameter_configs.py +43 -22
  45. squirrels/_parameter_options.py +1 -1
  46. squirrels/_parameter_sets.py +8 -10
  47. squirrels/_project.py +351 -234
  48. squirrels/_request_context.py +33 -0
  49. squirrels/_schemas/auth_models.py +32 -9
  50. squirrels/_schemas/query_param_models.py +9 -1
  51. squirrels/_schemas/response_models.py +36 -10
  52. squirrels/_seeds.py +1 -1
  53. squirrels/_sources.py +23 -19
  54. squirrels/_utils.py +83 -35
  55. squirrels/_version.py +1 -1
  56. squirrels/arguments.py +5 -0
  57. squirrels/auth.py +4 -1
  58. squirrels/connections.py +2 -0
  59. squirrels/dashboards.py +3 -1
  60. squirrels/data_sources.py +6 -0
  61. squirrels/parameter_options.py +5 -0
  62. squirrels/parameters.py +5 -0
  63. squirrels/types.py +6 -1
  64. {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/METADATA +28 -13
  65. squirrels-0.5.1.dist-info/RECORD +98 -0
  66. squirrels-0.5.0b4.dist-info/RECORD +0 -94
  67. {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/WHEEL +0 -0
  68. {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/entry_points.txt +0 -0
  69. {squirrels-0.5.0b4.dist-info → squirrels-0.5.1.dist-info}/licenses/LICENSE +0 -0
@@ -1,7 +1,7 @@
1
1
  """
2
2
  Data management routes for build and query models
3
3
  """
4
- from typing import Callable, Any
4
+ from typing import Any
5
5
  from fastapi import FastAPI, Depends, Request, Response, status
6
6
  from fastapi.responses import JSONResponse
7
7
  from fastapi.security import HTTPBearer
@@ -9,13 +9,12 @@ from dataclasses import asdict
9
9
  from cachetools import TTLCache
10
10
  import time
11
11
 
12
- from .. import _constants as c
12
+ from .. import _constants as c, _utils as u
13
13
  from .._schemas import response_models as rm
14
14
  from .._exceptions import InvalidInputError
15
- from .._auth import BaseUser
16
- from .._manifest import PermissionScope
15
+ from .._schemas.auth_models import AbstractUser
17
16
  from .._dataset_types import DatasetResult
18
- from .._schemas.query_param_models import get_query_models_for_querying_models
17
+ from .._schemas.query_param_models import get_query_models_for_querying_models, get_query_models_for_compiled_models
19
18
  from .base import RouteBase
20
19
 
21
20
 
@@ -31,55 +30,74 @@ class DataManagementRoutes(RouteBase):
31
30
  self.query_models_cache = TTLCache(maxsize=dataset_results_cache_size, ttl=dataset_results_cache_ttl*60)
32
31
 
33
32
  async def _query_models_helper(
34
- self, sql_query: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
33
+ self, sql_query: str, user: AbstractUser, selections: tuple[tuple[str, Any], ...], configurables: tuple[tuple[str, str], ...]
35
34
  ) -> DatasetResult:
36
35
  """Helper to query models"""
37
- return await self.project.query_models(sql_query, selections=dict(selections), user=user)
36
+ cfg_filtered = {k: v for k, v in dict(configurables).items() if k in self.manifest_cfg.configurables}
37
+ return await self.project.query_models(sql_query, user=user, selections=dict(selections), configurables=cfg_filtered)
38
38
 
39
39
  async def _query_models_cachable(
40
- self, sql_query: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
40
+ self, sql_query: str, user: AbstractUser, selections: tuple[tuple[str, Any], ...], configurables: tuple[tuple[str, str], ...]
41
41
  ) -> DatasetResult:
42
42
  """Cachable version of query models helper"""
43
- return await self.do_cachable_action(self.query_models_cache, self._query_models_helper, sql_query, user, selections)
43
+ return await self.do_cachable_action(self.query_models_cache, self._query_models_helper, sql_query, user, selections, configurables)
44
44
 
45
45
  async def _query_models_definition(
46
- self, user: BaseUser | None, all_request_params: dict, params: dict
46
+ self, user: AbstractUser, all_request_params: dict, params: dict, *, headers: dict[str, str]
47
47
  ) -> rm.DatasetResultModel:
48
48
  """Query models definition"""
49
- self._validate_request_params(all_request_params, params)
49
+ self._validate_request_params(all_request_params, params, headers)
50
50
 
51
- if not self.authenticator.can_user_access_scope(user, PermissionScope.PRIVATE):
52
- raise InvalidInputError(403, "Forbidden to query data models", f"User '{user}' does not have permission to query data models")
51
+ if not u.user_has_elevated_privileges(user.access_level, self.project._elevated_access_level):
52
+ raise InvalidInputError(403, "unauthorized_access_to_query_models", f"User '{user}' does not have permission to query data models")
53
53
 
54
54
  sql_query = params.get("x_sql_query")
55
55
  if sql_query is None:
56
- raise InvalidInputError(400, "SQL query must be provided", "SQL query must be provided")
56
+ raise InvalidInputError(400, "sql_query_required", "SQL query must be provided")
57
57
 
58
58
  query_models_function = self._query_models_helper if self.no_cache else self._query_models_cachable
59
59
  uncached_keys = {"x_verify_params", "x_sql_query", "x_orientation", "x_limit", "x_offset"}
60
60
  selections = self.get_selections_as_immutable(params, uncached_keys)
61
- result = await query_models_function(sql_query, user, selections)
61
+ configurables = self.get_configurables_from_headers(headers)
62
+ result = await query_models_function(sql_query, user, selections, configurables)
62
63
 
63
- orientation = params.get("x_orientation", "records")
64
+ orientation_header = headers.get("x-orientation")
65
+ orientation = str(orientation_header).lower() if orientation_header is not None else params.get("x_orientation", "records")
64
66
  limit = params.get("x_limit", 1000)
65
67
  offset = params.get("x_offset", 0)
66
- return rm.DatasetResultModel(**result.to_json(orientation, tuple(), limit, offset))
68
+ return rm.DatasetResultModel(**result.to_json(orientation, limit, offset))
67
69
 
70
+ async def _get_compiled_model_definition(
71
+ self, model_name: str, user: AbstractUser, all_request_params: dict, params: dict, *, headers: dict[str, str]
72
+ ) -> rm.CompiledQueryModel:
73
+ """Get compiled model definition"""
74
+ normalized_model_name = u.normalize_name(model_name)
75
+ self._validate_request_params(all_request_params, params, headers)
76
+
77
+ # Internal users only
78
+ if not u.user_has_elevated_privileges(user.access_level, self.project._elevated_access_level):
79
+ raise InvalidInputError(403, "unauthorized_access_to_compile_model", f"User '{user}' does not have permission to fetch compiled SQL")
80
+
81
+ selections = self.get_selections_as_immutable(params, uncached_keys={"x_verify_params"})
82
+ configurables = self.get_configurables_from_headers(headers)
83
+ cfg_filtered = {k: v for k, v in dict(configurables).items() if k in self.manifest_cfg.configurables}
84
+ return await self.project.get_compiled_model_query(normalized_model_name, user=user, selections=dict(selections), configurables=cfg_filtered)
85
+
68
86
  def setup_routes(self, app: FastAPI, project_metadata_path: str, param_fields: dict) -> None:
69
87
  """Setup data management routes"""
70
88
 
71
89
  # Build project endpoint
72
90
  build_path = project_metadata_path + '/build'
73
91
 
74
- @app.post(build_path, tags=["Data Management"], summary="Build or update the virtual data environment for the project")
92
+ @app.post(build_path, tags=["Data Management"], summary="Build or update the Virtual Data Lake (VDL) for the project")
75
93
  async def build(user=Depends(self.get_current_user)): # type: ignore
76
- if not self.authenticator.can_user_access_scope(user, PermissionScope.PRIVATE):
77
- raise InvalidInputError(403, "Forbidden to build", f"User '{user}' does not have permission to build the virtual data environment")
78
- await self.project.build(stage_file=True)
94
+ if not u.user_has_elevated_privileges(user.access_level, self.project._elevated_access_level):
95
+ raise InvalidInputError(403, "unauthorized_access_to_build_model", f"User '{user}' does not have permission to build the virtual data lake (VDL)")
96
+ await self.project.build()
79
97
  return Response(status_code=status.HTTP_200_OK)
80
98
 
81
- # Query models endpoints
82
- query_models_path = project_metadata_path + '/query-models'
99
+ # Query result endpoints
100
+ query_models_path = project_metadata_path + '/query-result'
83
101
  QueryModelForQueryModels, QueryModelForPostQueryModels = get_query_models_for_querying_models(param_fields)
84
102
 
85
103
  @app.get(query_models_path, tags=["Data Management"], response_class=JSONResponse)
@@ -87,8 +105,8 @@ class DataManagementRoutes(RouteBase):
87
105
  request: Request, params: QueryModelForQueryModels, user=Depends(self.get_current_user) # type: ignore
88
106
  ) -> rm.DatasetResultModel:
89
107
  start = time.time()
90
- result = await self._query_models_definition(user, dict(request.query_params), asdict(params))
91
- self.log_activity_time("GET REQUEST for QUERY MODELS", start, request)
108
+ result = await self._query_models_definition(user, dict(request.query_params), asdict(params), headers=dict(request.headers))
109
+ self.logger.log_activity_time("GET REQUEST for QUERY MODELS", start)
92
110
  return result
93
111
 
94
112
  @app.post(query_models_path, tags=["Data Management"], response_class=JSONResponse)
@@ -97,7 +115,34 @@ class DataManagementRoutes(RouteBase):
97
115
  ) -> rm.DatasetResultModel:
98
116
  start = time.time()
99
117
  payload: dict = await request.json()
100
- result = await self._query_models_definition(user, payload, params.model_dump())
101
- self.log_activity_time("POST REQUEST for QUERY MODELS", start, request)
118
+ result = await self._query_models_definition(user, payload, params.model_dump(), headers=dict(request.headers))
119
+ self.logger.log_activity_time("POST REQUEST for QUERY MODELS", start)
120
+ return result
121
+
122
+ # Compiled models endpoints - TODO: remove duplication
123
+ compiled_models_path = project_metadata_path + '/compiled-models/{model_name}'
124
+ QueryModelForGetCompiled, QueryModelForPostCompiled = get_query_models_for_compiled_models(param_fields)
125
+
126
+ @app.get(compiled_models_path, tags=["Data Management"], response_class=JSONResponse, summary="Get compiled definition for a model")
127
+ async def get_compiled_model(
128
+ request: Request, model_name: str, params: QueryModelForGetCompiled, user=Depends(self.get_current_user)
129
+ ) -> rm.CompiledQueryModel:
130
+ start = time.time()
131
+ result = await self._get_compiled_model_definition(model_name, user, dict(request.query_params), asdict(params), headers=dict(request.headers))
132
+ self.logger.log_activity_time(
133
+ "GET REQUEST for GET COMPILED MODEL", start, additional_data={"model_name": model_name}
134
+ )
135
+ return result
136
+
137
+ @app.post(compiled_models_path, tags=["Data Management"], response_class=JSONResponse, summary="Get compiled definition for a model")
138
+ async def get_compiled_model_with_post(
139
+ request: Request, model_name: str, params: QueryModelForPostCompiled, user=Depends(self.get_current_user)
140
+ ) -> rm.CompiledQueryModel:
141
+ start = time.time()
142
+ payload: dict = await request.json()
143
+ result = await self._get_compiled_model_definition(model_name, user, payload, params.model_dump(), headers=dict(request.headers))
144
+ self.logger.log_activity_time(
145
+ "POST REQUEST for GET COMPILED MODEL", start, additional_data={"model_name": model_name}
146
+ )
102
147
  return result
103
148
 
@@ -1,10 +1,10 @@
1
1
  """
2
2
  Dataset routes for parameters and results
3
3
  """
4
- from typing import Callable, Any
5
- from pydantic import Field, BaseModel
4
+ from typing import Callable, Coroutine, Any
5
+ from pydantic import Field
6
6
  from fastapi import FastAPI, Depends, Request
7
- from fastapi.responses import JSONResponse, HTMLResponse
7
+ from fastapi.responses import JSONResponse
8
8
  from fastapi.security import HTTPBearer
9
9
 
10
10
  from mcp.server.fastmcp import FastMCP, Context
@@ -12,14 +12,14 @@ from dataclasses import asdict
12
12
  from cachetools import TTLCache
13
13
  from textwrap import dedent
14
14
 
15
- import time
15
+ import time, json
16
16
 
17
17
  from .. import _constants as c, _utils as u
18
18
  from .._schemas import response_models as rm
19
19
  from .._exceptions import ConfigurationError, InvalidInputError
20
20
  from .._dataset_types import DatasetResult
21
21
  from .._schemas.query_param_models import get_query_models_for_parameters, get_query_models_for_dataset
22
- from .._auth import BaseUser
22
+ from .._schemas.auth_models import AbstractUser
23
23
  from .base import RouteBase
24
24
 
25
25
 
@@ -34,39 +34,57 @@ class DatasetRoutes(RouteBase):
34
34
  dataset_results_cache_ttl = int(self.env_vars.get(c.SQRL_DATASETS_CACHE_TTL_MINUTES, 60))
35
35
  self.dataset_results_cache = TTLCache(maxsize=dataset_results_cache_size, ttl=dataset_results_cache_ttl*60)
36
36
 
37
+ # Setup max rows for AI
38
+ self.max_rows_for_ai = int(self.env_vars.get(c.SQRL_DATASETS_MAX_ROWS_FOR_AI, 100))
39
+
37
40
  async def _get_dataset_results_helper(
38
- self, dataset: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
41
+ self, dataset: str, user: AbstractUser, selections: tuple[tuple[str, Any], ...], configurables: tuple[tuple[str, str], ...]
39
42
  ) -> DatasetResult:
40
43
  """Helper to get dataset results"""
41
- return await self.project.dataset(dataset, selections=dict(selections), user=user)
44
+ # Only pass configurables that are defined in manifest
45
+ cfg_filtered = {k: v for k, v in dict(configurables).items() if k in self.manifest_cfg.configurables}
46
+ return await self.project.dataset(dataset, user=user, selections=dict(selections), configurables=cfg_filtered)
42
47
 
43
48
  async def _get_dataset_results_cachable(
44
- self, dataset: str, user: BaseUser | None, selections: tuple[tuple[str, Any], ...]
49
+ self, dataset: str, user: AbstractUser, selections: tuple[tuple[str, Any], ...], configurables: tuple[tuple[str, str], ...]
45
50
  ) -> DatasetResult:
46
51
  """Cachable version of dataset results helper"""
47
- return await self.do_cachable_action(self.dataset_results_cache, self._get_dataset_results_helper, dataset, user, selections)
52
+ return await self.do_cachable_action(self.dataset_results_cache, self._get_dataset_results_helper, dataset, user, selections, configurables)
48
53
 
49
54
  async def _get_dataset_results_definition(
50
- self, dataset_name: str, user: BaseUser | None, all_request_params: dict, params: dict
55
+ self, dataset_name: str, user: AbstractUser, all_request_params: dict, params: dict, headers: dict[str, str]
51
56
  ) -> rm.DatasetResultModel:
52
57
  """Get dataset results definition"""
53
- self._validate_request_params(all_request_params, params)
58
+ self._validate_request_params(all_request_params, params, headers)
54
59
 
55
60
  get_dataset_function = self._get_dataset_results_helper if self.no_cache else self._get_dataset_results_cachable
56
- uncached_keys = {"x_verify_params", "x_orientation", "x_select", "x_limit", "x_offset"}
61
+ uncached_keys = {"x_verify_params", "x_orientation", "x_sql_query", "x_limit", "x_offset"}
57
62
  selections = self.get_selections_as_immutable(params, uncached_keys)
58
- result = await get_dataset_function(dataset_name, user, selections)
59
63
 
60
- orientation = params.get("x_orientation", "records")
61
- raw_select: list[str] | None = params.get("x_select")
62
- select = tuple(raw_select) if raw_select is not None else tuple()
64
+ user_has_elevated_privileges = u.user_has_elevated_privileges(user.access_level, self.project._elevated_access_level)
65
+ configurables = self.get_configurables_from_headers(headers) if user_has_elevated_privileges else tuple()
66
+ result = await get_dataset_function(dataset_name, user, selections, configurables)
67
+
68
+ # Apply optional final SQL transformation before select/limit/offset
69
+ sql_query = params.get("x_sql_query")
70
+ if sql_query:
71
+ try:
72
+ transformed = u.run_sql_on_dataframes(sql_query, {"result": result.df.lazy()})
73
+ except Exception as e:
74
+ raise InvalidInputError(400, "invalid_sql_query", "Failed to run provided SQL on the dataset result") from e
75
+
76
+ transformed = transformed.drop("_row_num", strict=False).with_row_index("_row_num", offset=1)
77
+ result = DatasetResult(target_model_config=result.target_model_config, df=transformed)
78
+
79
+ orientation_header = headers.get("x-orientation")
80
+ orientation = str(orientation_header).lower() if orientation_header is not None else params.get("x_orientation", "records")
63
81
  limit = params.get("x_limit", 1000)
64
82
  offset = params.get("x_offset", 0)
65
- return rm.DatasetResultModel(**result.to_json(orientation, select, limit, offset))
83
+ return rm.DatasetResultModel(**result.to_json(orientation, limit, offset))
66
84
 
67
85
  def setup_routes(
68
- self, app: FastAPI, mcp: FastMCP, project_metadata_path: str, project_name: str, project_version: str,
69
- param_fields: dict, get_parameters_definition: Callable
86
+ self, app: FastAPI, mcp: FastMCP, project_metadata_path: str, project_name: str, project_label: str,
87
+ param_fields: dict, get_parameters_definition: Callable[..., Coroutine[Any, Any, rm.ParametersModel]]
70
88
  ) -> None:
71
89
  """Setup dataset routes"""
72
90
 
@@ -84,19 +102,19 @@ class DatasetRoutes(RouteBase):
84
102
  f"\n {all_params}"
85
103
  )
86
104
 
87
- async def get_dataset_parameters_updates(dataset_name: str, user: BaseUser | None, all_request_params: dict, params: dict):
105
+ async def get_dataset_parameters_updates(dataset_name: str, user: AbstractUser, all_request_params: dict, params: dict, headers: dict[str, str]):
88
106
  parameters_list = self.manifest_cfg.datasets[dataset_name].parameters
89
107
  scope = self.manifest_cfg.datasets[dataset_name].scope
90
108
  result = await get_parameters_definition(
91
- parameters_list, "dataset", dataset_name, scope, user, all_request_params, params
109
+ parameters_list, "dataset", dataset_name, scope, user, all_request_params, params, headers=headers
92
110
  )
93
111
  return result
94
112
 
95
113
  # Dataset parameters and results APIs
96
114
  for dataset_name, dataset_config in self.manifest_cfg.datasets.items():
97
- dataset_normalized = u.normalize_name_for_api(dataset_name)
98
- curr_parameters_path = dataset_parameters_path.format(dataset=dataset_normalized)
99
- curr_results_path = dataset_results_path.format(dataset=dataset_normalized)
115
+ dataset_name_for_api = u.normalize_name_for_api(dataset_name)
116
+ curr_parameters_path = dataset_parameters_path.format(dataset=dataset_name_for_api)
117
+ curr_results_path = dataset_results_path.format(dataset=dataset_name_for_api)
100
118
 
101
119
  validate_parameters_list(dataset_config.parameters, "Dataset", dataset_name)
102
120
 
@@ -109,8 +127,10 @@ class DatasetRoutes(RouteBase):
109
127
  ) -> rm.ParametersModel:
110
128
  start = time.time()
111
129
  curr_dataset_name = self.get_name_from_path_section(request, -2)
112
- result = await get_dataset_parameters_updates(curr_dataset_name, user, dict(request.query_params), asdict(params))
113
- self.log_activity_time("GET REQUEST for PARAMETERS", start, request)
130
+ result = await get_dataset_parameters_updates(curr_dataset_name, user, dict(request.query_params), asdict(params), dict(request.headers))
131
+ self.logger.log_activity_time(
132
+ "GET REQUEST for PARAMETERS", start, additional_data={"dataset_name": curr_dataset_name}
133
+ )
114
134
  return result
115
135
 
116
136
  @app.post(curr_parameters_path, tags=[f"Dataset '{dataset_name}'"], description=self._parameters_description, response_class=JSONResponse)
@@ -120,8 +140,10 @@ class DatasetRoutes(RouteBase):
120
140
  start = time.time()
121
141
  curr_dataset_name = self.get_name_from_path_section(request, -2)
122
142
  payload: dict = await request.json()
123
- result = await get_dataset_parameters_updates(curr_dataset_name, user, payload, params.model_dump())
124
- self.log_activity_time("POST REQUEST for PARAMETERS", start, request)
143
+ result = await get_dataset_parameters_updates(curr_dataset_name, user, payload, params.model_dump(), dict(request.headers))
144
+ self.logger.log_activity_time(
145
+ "POST REQUEST for PARAMETERS", start, additional_data={"dataset_name": curr_dataset_name}
146
+ )
125
147
  return result
126
148
 
127
149
  @app.get(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
@@ -130,8 +152,12 @@ class DatasetRoutes(RouteBase):
130
152
  ) -> rm.DatasetResultModel:
131
153
  start = time.time()
132
154
  curr_dataset_name = self.get_name_from_path_section(request, -1)
133
- result = await self._get_dataset_results_definition(curr_dataset_name, user, dict(request.query_params), asdict(params))
134
- self.log_activity_time("GET REQUEST for DATASET RESULTS", start, request)
155
+ result = await self._get_dataset_results_definition(
156
+ curr_dataset_name, user, dict(request.query_params), asdict(params), headers=dict(request.headers)
157
+ )
158
+ self.logger.log_activity_time(
159
+ "GET REQUEST for DATASET RESULTS", start, additional_data={"dataset_name": curr_dataset_name}
160
+ )
135
161
  return result
136
162
 
137
163
  @app.post(curr_results_path, tags=[f"Dataset '{dataset_name}'"], description=dataset_config.description, response_class=JSONResponse)
@@ -141,18 +167,23 @@ class DatasetRoutes(RouteBase):
141
167
  start = time.time()
142
168
  curr_dataset_name = self.get_name_from_path_section(request, -1)
143
169
  payload: dict = await request.json()
144
- result = await self._get_dataset_results_definition(curr_dataset_name, user, payload, params.model_dump())
145
- self.log_activity_time("POST REQUEST for DATASET RESULTS", start, request)
170
+ result = await self._get_dataset_results_definition(
171
+ curr_dataset_name, user, payload, params.model_dump(), headers=dict(request.headers)
172
+ )
173
+ self.logger.log_activity_time(
174
+ "POST REQUEST for DATASET RESULTS", start, additional_data={"dataset_name": curr_dataset_name}
175
+ )
146
176
  return result
147
177
 
148
178
  # Setup MCP tools
149
179
 
150
180
  @mcp.tool(
151
- name=f"get_dataset_parameters_for_{project_name}_{project_version}",
181
+ name=f"get_dataset_parameters_from_{project_name}",
182
+ title=f"Get Dataset Parameters Updates (Project: {project_label})",
152
183
  description=dedent(f"""
153
184
  Use this tool to get updates for dataset parameters in the Squirrels project "{project_name}" when a selection is to be made on a parameter with "trigger_refresh" as true.
154
185
 
155
- For example, suppose there are two parameters, "country" and "city", and the user selects "United States" for "country". If "country" has the "trigger_refresh" field as true, then this tool will be called to get the updates for other parameters such as "city".
186
+ For example, suppose there are two parameters, "country" and "city", and the user selects "United States" for "country". If "country" has the "trigger_refresh" field as true, then this tool should be called to get the updates for other parameters such as "city".
156
187
 
157
188
  Do not use this tool on parameters whose "trigger_refresh" field is false!
158
189
  """).strip()
@@ -162,28 +193,30 @@ class DatasetRoutes(RouteBase):
162
193
  dataset: str = Field(description="The name of the dataset whose parameters the trigger parameter will update"),
163
194
  parameter_name: str = Field(description="The name of the parameter triggering the refresh"),
164
195
  selected_ids: list[str] = Field(description="The ID(s) of the selected option(s) for the parameter"),
165
- ):
166
- user = self.get_user_from_tool_ctx(ctx)
196
+ ) -> rm.ParametersModel:
197
+ headers = self.get_headers_from_tool_ctx(ctx)
198
+ user = self.get_user_from_tool_headers(headers)
167
199
  dataset_name = u.normalize_name(dataset)
168
200
  payload = {
169
201
  "x_parent_param": parameter_name,
170
202
  parameter_name: selected_ids
171
203
  }
172
- return await get_dataset_parameters_updates(dataset_name, user, payload, payload)
204
+ return await get_dataset_parameters_updates(dataset_name, user, payload, payload, headers)
173
205
 
174
206
  @mcp.tool(
175
- name=f"get_dataset_results_for_{project_name}_{project_version}",
207
+ name=f"get_dataset_results_from_{project_name}",
208
+ title=f"Get Dataset Results (Project: {project_label})",
176
209
  description=dedent(f"""
177
210
  Use this tool to get the dataset results as a JSON object for a dataset in the Squirrels project "{project_name}".
178
211
  - Use the "offset" and "limit" arguments to limit the number of rows you require
179
- - The "limit" argument controls the number of rows returned. The maximum allowed value is 100. If the 'total_num_rows' field in the response is greater than 100, let the user know that only 100 rows are shown and clarify if they would like to see more.
212
+ - The "limit" argument controls the number of rows returned. The maximum allowed value is {self.max_rows_for_ai}. If the 'total_num_rows' field in the response is greater than {self.max_rows_for_ai}, let the user know that only {self.max_rows_for_ai} rows are shown and clarify if they would like to see more.
180
213
  """).strip()
181
214
  )
182
215
  async def get_dataset_results_tool(
183
216
  ctx: Context,
184
217
  dataset: str = Field(description="The name of the dataset to get results for"),
185
- parameters: dict[str, Any] = Field(description=dedent("""
186
- Key-value pairs for parameter name and selected value. The selected value to provide depends on the parameter widget type:
218
+ parameters: str = Field(description=dedent("""
219
+ A JSON object (as string) containing key-value pairs for parameter name and selected value. The selected value to provide depends on the parameter widget type:
187
220
  - For single select, use a string for the ID of the selected value
188
221
  - For multi select, use an array of strings for the IDs of the selected values
189
222
  - For date, use a string like "YYYY-MM-DD"
@@ -191,52 +224,42 @@ class DatasetRoutes(RouteBase):
191
224
  - For number, use a number like 1
192
225
  - For number ranges, use array of numbers like [1,100]
193
226
  - For text, use a string for the text value
194
- - Complex objects are NOT supported""").strip()),
195
- offset: int = Field(0, description="The number of rows to skip from first row. Default is 0."),
196
- limit: int = Field(100, description="The maximum number of rows to return. Default is 100. Maximum allowed value is 100."),
197
- ):
198
- if limit > 100:
199
- raise ValueError("The maximum number of rows to return is 100.")
200
-
201
- user = self.get_user_from_tool_ctx(ctx)
227
+ - Complex objects are NOT supported
228
+ """).strip()),
229
+ sql_query: str | None = Field(None, description=dedent("""
230
+ A custom DuckDB SQL query to execute on the final dataset result.
231
+ - Use table name 'result' to reference the dataset result.
232
+ - Use this to apply transformations to the dataset result if needed (such as filtering, sorting, or selecting columns).
233
+ - If not provided, the dataset result is returned as is.
234
+ """).strip()),
235
+ offset: int = Field(0, description="The number of rows to skip from first row. Applied after final SQL. Default is 0."),
236
+ limit: int = Field(self.max_rows_for_ai, description=f"The maximum number of rows to return. Applied after final SQL. Default is {self.max_rows_for_ai}. Maximum allowed value is {self.max_rows_for_ai}."),
237
+ ) -> rm.DatasetResultModel:
238
+ if limit > self.max_rows_for_ai:
239
+ raise ValueError(f"The maximum number of rows to return is {self.max_rows_for_ai}.")
240
+
241
+ headers = self.get_headers_from_tool_ctx(ctx)
242
+ user = self.get_user_from_tool_headers(headers)
202
243
  dataset_name = u.normalize_name(dataset)
203
- params = {
204
- **parameters,
205
- "x_orientation": "rows",
244
+
245
+ try:
246
+ params = json.loads(parameters)
247
+ except json.JSONDecodeError:
248
+ params = None # error handled below
249
+
250
+ if not isinstance(params, dict):
251
+ raise InvalidInputError(400, "invalid_parameters", f"The 'parameters' argument must be a JSON object.")
252
+
253
+ params.update({
254
+ "x_sql_query": sql_query,
206
255
  "x_offset": offset,
207
256
  "x_limit": limit
208
- }
209
- result = await self._get_dataset_results_definition(dataset_name, user, params, params)
210
- return result
211
-
212
- # Setup UI for tool results
213
- mcp_tool_results_ui_path = project_metadata_path + "/mcp/tool-results-ui"
214
-
215
- @app.get(mcp_tool_results_ui_path + "/list-tools", tags=["MCP Supplements"])
216
- async def list_tools():
217
- return ["get_dataset_results"]
257
+ })
218
258
 
219
- class ToolResultBody(BaseModel):
220
- """Flexible model for tool results - accepts any additional fields"""
259
+ # Set default orientation as rows if not provided
260
+ if "x-orientation" not in headers:
261
+ headers["x-orientation"] = "rows"
221
262
 
222
- class Config:
223
- extra = "allow" # Allow additional fields not defined in the model
224
-
225
- @app.post(mcp_tool_results_ui_path + "/tool/{tool_name}", tags=["MCP Supplements"])
226
- async def tool_results_ui(tool_name: str, tool_result: ToolResultBody):
227
- if tool_name == "get_dataset_results":
228
- # Convert Pydantic model to dict to access any extra fields
229
- tool_result_dict = tool_result.model_dump()
230
-
231
- # Prepare template context
232
- context = {
233
- "schema": tool_result_dict.get("schema", {}),
234
- "data": tool_result_dict.get("data", []),
235
- }
236
-
237
- # Render HTML template
238
- html_content = self.templates.get_template("dataset_results.html").render(context)
239
- return HTMLResponse(content=html_content, status_code=200)
240
- else:
241
- raise InvalidInputError(400, "Invalid tool name", f"Tool name '{tool_name}' not supported for UI")
263
+ result = await self._get_dataset_results_definition(dataset_name, user, params, params, headers)
264
+ return result
242
265
 
@@ -6,9 +6,10 @@ from typing import Annotated, cast
6
6
  from .base import RouteBase
7
7
  from .._schemas.auth_models import (
8
8
  ClientRegistrationRequest, ClientUpdateRequest, ClientRegistrationResponse, ClientDetailsResponse, ClientUpdateResponse,
9
- TokenResponse, OAuthServerMetadata
9
+ TokenResponse, OAuthServerMetadata, AbstractUser
10
10
  )
11
11
  from .._exceptions import InvalidInputError
12
+ from .. import _utils as u
12
13
 
13
14
 
14
15
  class OAuth2Routes(RouteBase):
@@ -50,16 +51,13 @@ class OAuth2Routes(RouteBase):
50
51
  status_code=200
51
52
  )
52
53
 
53
- def setup_routes(self, app: FastAPI) -> None:
54
+ def setup_routes(self, app: FastAPI, squirrels_version_path: str) -> None:
54
55
  """Setup all OAuth2 routes"""
55
56
 
56
- router_path = "/api/auth/oauth2"
57
+ auth_path = squirrels_version_path + "/auth"
58
+ router_path = "/oauth2"
57
59
  router = APIRouter(prefix=router_path)
58
60
 
59
- # Create user models
60
- class UserInfoModel(self.UserInfoModel):
61
- username: str
62
-
63
61
  # Authorization dependency for client management
64
62
  get_client_token = HTTPBearer(auto_error=False)
65
63
 
@@ -93,7 +91,7 @@ class OAuth2Routes(RouteBase):
93
91
  # Client Registration Endpoint
94
92
  client_management_path = '/client/{client_id}'
95
93
 
96
- @router.post("/register", description="Register a new OAuth client", tags=["OAuth2"])
94
+ @router.post("/client", description="Register a new OAuth client", tags=["OAuth2"])
97
95
  async def register_oauth_client(request: ClientRegistrationRequest) -> ClientRegistrationResponse:
98
96
  """Register a new OAuth client and return client credentials"""
99
97
 
@@ -148,7 +146,7 @@ class OAuth2Routes(RouteBase):
148
146
  state: str | None = Query(default=None, description="State parameter for CSRF protection"),
149
147
  code_challenge: str = Query(..., description="PKCE code challenge (required)"),
150
148
  code_challenge_method: str = Query(default="S256", description="PKCE code challenge method"),
151
- user: UserInfoModel | None = Depends(self.get_current_user)
149
+ user: AbstractUser = Depends(self.get_current_user)
152
150
  ):
153
151
  """OAuth 2.1 authorization endpoint for initiating authorization code flow"""
154
152
 
@@ -158,9 +156,9 @@ class OAuth2Routes(RouteBase):
158
156
  raise InvalidInputError(400, "unsupported_response_type", "Only 'code' response type is supported")
159
157
 
160
158
  # Check if user is authenticated
161
- if user is None:
159
+ if user.access_level == "guest":
162
160
  # User is not authenticated - serve login page
163
- return self.serve_login_page("/api/auth", request, client_id)
161
+ return self.serve_login_page(auth_path, request, client_id)
164
162
 
165
163
  # TODO: Serve a page with an "authorize" button even if user is already authenticated
166
164
  # Ex. if not request.session.get("authorization_approved"), redirect to a page with button that submits to "/approve-authorization"
@@ -281,7 +279,7 @@ class OAuth2Routes(RouteBase):
281
279
  """OAuth 2.1 Authorization Server Metadata endpoint (RFC 8414)"""
282
280
 
283
281
  # Get the base URL from the request
284
- scheme = "http" if request.url.hostname in ("localhost", "127.0.0.1") else "https"
282
+ scheme = u.get_scheme(request.url.hostname)
285
283
  base_url = scheme + "://" + request.url.netloc
286
284
 
287
285
  return OAuthServerMetadata(
@@ -289,7 +287,7 @@ class OAuth2Routes(RouteBase):
289
287
  authorization_endpoint=f"{base_url}{router_path}/authorize",
290
288
  token_endpoint=f"{base_url}{router_path}/token",
291
289
  revocation_endpoint=f"{base_url}{router_path}/token/revoke",
292
- registration_endpoint=f"{base_url}{router_path}/register",
290
+ registration_endpoint=f"{base_url}{router_path}/client",
293
291
  scopes_supported=["read"],
294
292
  response_types_supported=["code"],
295
293
  grant_types_supported=["authorization_code", "refresh_token"],