sanjaya-sqlalchemy 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,12 @@
1
+ node_modules/
2
+ tsp-output/
3
+
4
+ __pycache__/
5
+ *.egg-info/
6
+ .venv/
7
+ dist/
8
+ .pytest_cache/
9
+
10
+ # Internal planning docs (not for public repo)
11
+ permissions.md
12
+ docs/implementation-plan.md
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Tom Brennan
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,62 @@
1
+ Metadata-Version: 2.4
2
+ Name: sanjaya-sqlalchemy
3
+ Version: 0.1.0
4
+ Summary: SQLAlchemy Core data-provider implementation for the Sanjaya reporting platform
5
+ Project-URL: Repository, https://github.com/tjb1982/sanjaya
6
+ Project-URL: Issues, https://github.com/tjb1982/sanjaya/issues
7
+ Author: Tom Brennan
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Keywords: analytics,data-provider,reporting,sqlalchemy
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.12
16
+ Classifier: Programming Language :: Python :: 3.13
17
+ Classifier: Typing :: Typed
18
+ Requires-Python: >=3.12
19
+ Requires-Dist: sanjaya-core~=0.1
20
+ Requires-Dist: sqlalchemy<3,>=2.0
21
+ Provides-Extra: dev
22
+ Requires-Dist: pytest-cov; extra == 'dev'
23
+ Requires-Dist: pytest>=8; extra == 'dev'
24
+ Description-Content-Type: text/markdown
25
+
26
+ # sanjaya-sqlalchemy
27
+
28
+ SQLAlchemy Core data-provider implementation for the Sanjaya reporting platform.
29
+
30
+ Translates the `DataProvider` interface from `sanjaya-core` into SQL queries using SQLAlchemy Core expressions. No ORM required.
31
+
32
+ ## Installation
33
+
34
+ ```
35
+ pip install sanjaya-sqlalchemy
36
+ ```
37
+
38
+ ## Usage
39
+
40
+ ```python
41
+ from sqlalchemy import MetaData, create_engine
42
+ from sanjaya_core.types import ColumnMeta, DatasetCapabilities
43
+ from sanjaya_core.enums import ColumnType
44
+ from sanjaya_sqlalchemy import SQLAlchemyProvider
45
+
46
+ engine = create_engine("postgresql://...")
47
+ metadata = MetaData()
48
+ metadata.reflect(bind=engine)
49
+
50
+ provider = SQLAlchemyProvider(
51
+ key="trade_activity",
52
+ label="Trade Activity",
53
+ engine=engine,
54
+ selectable=metadata.tables["trade_activity"],
55
+ columns=[
56
+ ColumnMeta(name="id", label="ID", type=ColumnType.NUMBER),
57
+ ColumnMeta(name="desk", label="Desk", type=ColumnType.STRING),
58
+ ColumnMeta(name="amount", label="Amount", type=ColumnType.CURRENCY),
59
+ ],
60
+ capabilities=DatasetCapabilities(pivot=True),
61
+ )
62
+ ```
@@ -0,0 +1,37 @@
1
+ # sanjaya-sqlalchemy
2
+
3
+ SQLAlchemy Core data-provider implementation for the Sanjaya reporting platform.
4
+
5
+ Translates the `DataProvider` interface from `sanjaya-core` into SQL queries using SQLAlchemy Core expressions. No ORM required.
6
+
7
+ ## Installation
8
+
9
+ ```
10
+ pip install sanjaya-sqlalchemy
11
+ ```
12
+
13
+ ## Usage
14
+
15
+ ```python
16
+ from sqlalchemy import MetaData, create_engine
17
+ from sanjaya_core.types import ColumnMeta, DatasetCapabilities
18
+ from sanjaya_core.enums import ColumnType
19
+ from sanjaya_sqlalchemy import SQLAlchemyProvider
20
+
21
+ engine = create_engine("postgresql://...")
22
+ metadata = MetaData()
23
+ metadata.reflect(bind=engine)
24
+
25
+ provider = SQLAlchemyProvider(
26
+ key="trade_activity",
27
+ label="Trade Activity",
28
+ engine=engine,
29
+ selectable=metadata.tables["trade_activity"],
30
+ columns=[
31
+ ColumnMeta(name="id", label="ID", type=ColumnType.NUMBER),
32
+ ColumnMeta(name="desk", label="Desk", type=ColumnType.STRING),
33
+ ColumnMeta(name="amount", label="Amount", type=ColumnType.CURRENCY),
34
+ ],
35
+ capabilities=DatasetCapabilities(pivot=True),
36
+ )
37
+ ```
@@ -0,0 +1,46 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "sanjaya-sqlalchemy"
7
+ version = "0.1.0"
8
+ description = "SQLAlchemy Core data-provider implementation for the Sanjaya reporting platform"
9
+ readme = "README.md"
10
+ requires-python = ">=3.12"
11
+ license = "MIT"
12
+ license-files = ["LICENSE"]
13
+ authors = [
14
+ { name = "Tom Brennan" },
15
+ ]
16
+ keywords = ["reporting", "sqlalchemy", "data-provider", "analytics"]
17
+ classifiers = [
18
+ "Development Status :: 3 - Alpha",
19
+ "Intended Audience :: Developers",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Programming Language :: Python :: 3",
22
+ "Programming Language :: Python :: 3.12",
23
+ "Programming Language :: Python :: 3.13",
24
+ "Typing :: Typed",
25
+ ]
26
+ dependencies = [
27
+ "sanjaya-core~=0.1",
28
+ "sqlalchemy>=2.0,<3",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=8",
34
+ "pytest-cov",
35
+ ]
36
+
37
+ [project.urls]
38
+ Repository = "https://github.com/tjb1982/sanjaya"
39
+ Issues = "https://github.com/tjb1982/sanjaya/issues"
40
+
41
+ [tool.hatch.build.targets.wheel]
42
+ packages = ["src/sanjaya_sqlalchemy"]
43
+
44
+ [tool.pytest.ini_options]
45
+ testpaths = ["tests"]
46
+ pythonpath = ["."]
@@ -0,0 +1,7 @@
1
+ """SQLAlchemy Core data-provider for the Sanjaya reporting platform."""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ from sanjaya_sqlalchemy.provider import SQLAlchemyProvider
6
+
7
+ __all__ = ["SQLAlchemyProvider"]
@@ -0,0 +1,109 @@
1
+ """Compile :class:`~sanjaya_core.filters.FilterGroup` trees into SQLAlchemy
2
+ ``WHERE`` clause expressions.
3
+
4
+ The compiler walks the recursive filter model from *sanjaya-core* and
5
+ produces a :class:`sqlalchemy.sql.expression.ColumnElement[bool]` that can be
6
+ appended to any ``SELECT`` statement via ``.where()``.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any
12
+
13
+ import sqlalchemy as sa
14
+ from sqlalchemy.sql.expression import ColumnElement
15
+
16
+ from sanjaya_core.enums import FilterCombinator, FilterOperator
17
+ from sanjaya_core.filters import FilterCondition, FilterGroup
18
+
19
+
20
def compile_filter_group(
    fg: FilterGroup,
    column_lookup: dict[str, ColumnElement[Any]],
) -> ColumnElement[bool]:
    """Translate a :class:`FilterGroup` into a SQLAlchemy boolean expression.

    The group's own conditions are compiled first, followed by each nested
    sub-group (recursively). The pieces are joined with ``AND`` or ``OR``
    depending on ``fg.combinator``, and the result is wrapped in ``NOT`` when
    ``fg.negate`` is set. An empty group compiles to a literal ``TRUE``.

    Parameters
    ----------
    fg:
        The recursive filter tree to compile.
    column_lookup:
        Mapping of column *name* → SQLAlchemy column expression. Every
        :attr:`FilterCondition.column` in the tree is resolved here.

    Returns
    -------
    A composable SQLAlchemy expression suitable for ``stmt.where(expr)``.

    Raises
    ------
    KeyError
        If a condition references a column not present in *column_lookup*.
    """
    parts: list[ColumnElement[bool]] = [
        _compile_condition(condition, column_lookup) for condition in fg.conditions
    ]
    parts.extend(compile_filter_group(child, column_lookup) for child in fg.groups)

    expr: ColumnElement[bool]
    if parts:
        # Any combinator other than AND is treated as OR.
        combine = sa.and_ if fg.combinator == FilterCombinator.AND else sa.or_
        expr = combine(*parts)
    else:
        # No conditions and no sub-groups: match every row.
        expr = sa.literal(True)

    return sa.not_(expr) if fg.negate else expr
62
+
63
+
64
+ def _compile_condition(
65
+ cond: FilterCondition,
66
+ column_lookup: dict[str, ColumnElement[Any]],
67
+ ) -> ColumnElement[bool]:
68
+ """Compile a single :class:`FilterCondition` to a SQLAlchemy expression."""
69
+ col = column_lookup[cond.column]
70
+ v = cond.value
71
+
72
+ expr: ColumnElement[bool]
73
+
74
+ match cond.operator:
75
+ case FilterOperator.EQ:
76
+ expr = col == v
77
+ case FilterOperator.NEQ:
78
+ expr = col != v
79
+ case FilterOperator.GT:
80
+ expr = col > v
81
+ case FilterOperator.LT:
82
+ expr = col < v
83
+ case FilterOperator.GTE:
84
+ expr = col >= v
85
+ case FilterOperator.LTE:
86
+ expr = col <= v
87
+ case FilterOperator.CONTAINS:
88
+ expr = col.contains(str(v), autoescape=True)
89
+ case FilterOperator.STARTSWITH:
90
+ expr = col.startswith(str(v), autoescape=True)
91
+ case FilterOperator.ENDSWITH:
92
+ expr = col.endswith(str(v), autoescape=True)
93
+ case FilterOperator.IS_NULL:
94
+ expr = col.is_(None)
95
+ case FilterOperator.IS_NOT_NULL:
96
+ expr = col.isnot(None)
97
+ case FilterOperator.BETWEEN:
98
+ lo, hi = v[0], v[1]
99
+ expr = col.between(lo, hi)
100
+ case FilterOperator.IN:
101
+ expr = col.in_(list(v))
102
+ case _:
103
+ # Unknown operator — match everything (safe fallback).
104
+ expr = sa.literal(True)
105
+
106
+ if cond.negate:
107
+ expr = sa.not_(expr)
108
+
109
+ return expr
@@ -0,0 +1,409 @@
1
+ """SQLAlchemy Core implementation of :class:`~sanjaya_core.provider.DataProvider`.
2
+
3
+ ``SQLAlchemyProvider`` is configured with a SQLAlchemy :class:`~sqlalchemy.Table`
4
+ (or selectable), a list of :class:`~sanjaya_core.types.ColumnMeta` definitions,
5
+ and a :class:`~sqlalchemy.engine.Engine`. It translates every provider
6
+ method into SQL using SQLAlchemy Core expressions — no ORM required.
7
+
8
+ Usage example::
9
+
10
+ from sqlalchemy import MetaData, create_engine
11
+ from sanjaya_core.types import ColumnMeta, DatasetCapabilities
12
+ from sanjaya_core.enums import ColumnType
13
+ from sanjaya_sqlalchemy import SQLAlchemyProvider
14
+
15
+ engine = create_engine("postgresql://...")
16
+ metadata = MetaData()
17
+ metadata.reflect(bind=engine)
18
+ trade_table = metadata.tables["trade_activity"]
19
+
20
+ provider = SQLAlchemyProvider(
21
+ key="trade_activity",
22
+ label="Trade Activity",
23
+ engine=engine,
24
+ selectable=trade_table,
25
+ columns=[
26
+ ColumnMeta(name="id", label="ID", type=ColumnType.NUMBER),
27
+ ColumnMeta(name="desk", label="Desk", type=ColumnType.STRING),
28
+ ColumnMeta(name="amount", label="Amount", type=ColumnType.CURRENCY),
29
+ ...
30
+ ],
31
+ capabilities=DatasetCapabilities(pivot=True),
32
+ )
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ from typing import Any
38
+
39
+ import sqlalchemy as sa
40
+ from sqlalchemy import FromClause
41
+ from sqlalchemy.engine import Engine
42
+ from sqlalchemy.sql.expression import ColumnElement
43
+
44
+ from sanjaya_core.context import RequestContext
45
+ from sanjaya_core.enums import AggFunc, SortDirection
46
+ from sanjaya_core.filters import FilterGroup
47
+ from sanjaya_core.provider import DataProvider
48
+ from sanjaya_core.types import (
49
+ AggregateColumn,
50
+ AggregateResult,
51
+ ColumnMeta,
52
+ DatasetCapabilities,
53
+ SortSpec,
54
+ TabularResult,
55
+ ValueSpec,
56
+ )
57
+
58
+ from sanjaya_sqlalchemy.filters import compile_filter_group
59
+
60
+
61
+ class SQLAlchemyProvider(DataProvider):
62
+ """A :class:`DataProvider` backed by a SQLAlchemy selectable + engine.
63
+
64
+ Parameters
65
+ ----------
66
+ key:
67
+ Dataset identifier.
68
+ label:
69
+ Human-readable name.
70
+ engine:
71
+ A SQLAlchemy :class:`~sqlalchemy.engine.Engine` used to execute
72
+ queries.
73
+ selectable:
74
+ The table or subquery to query against (e.g. a
75
+ :class:`~sqlalchemy.Table` or ``select()``).
76
+ columns:
77
+ Column metadata definitions. Each
78
+ :attr:`~sanjaya_core.types.ColumnMeta.name` must match a column
79
+ name in *selectable*.
80
+ description:
81
+ Optional description.
82
+ capabilities:
83
+ Optional capability flags.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ *,
89
+ key: str,
90
+ label: str,
91
+ engine: Engine,
92
+ selectable: FromClause,
93
+ columns: list[ColumnMeta],
94
+ description: str = "",
95
+ capabilities: DatasetCapabilities | None = None,
96
+ ) -> None:
97
+ super().__init__(
98
+ key=key,
99
+ label=label,
100
+ description=description,
101
+ capabilities=capabilities or DatasetCapabilities(pivot=True),
102
+ )
103
+ self._engine = engine
104
+ self._selectable = selectable
105
+ self._columns = columns
106
+ self._column_lookup: dict[str, ColumnElement[Any]] = {
107
+ c.name: selectable.c[c.name] for c in columns
108
+ }
109
+
110
+ # ------------------------------------------------------------------
111
+ # DataProvider interface
112
+ # ------------------------------------------------------------------
113
+
114
+ def get_columns(self) -> list[ColumnMeta]:
115
+ return list(self._columns)
116
+
117
+ def query(
118
+ self,
119
+ selected_columns: list[str],
120
+ *,
121
+ filter_group: FilterGroup | None = None,
122
+ sort: list[SortSpec] | None = None,
123
+ limit: int = 100,
124
+ offset: int = 0,
125
+ ctx: RequestContext | None = None,
126
+ ) -> TabularResult:
127
+ # --- count ---
128
+ count_stmt = sa.select(sa.func.count()).select_from(self._selectable)
129
+ if filter_group:
130
+ count_stmt = count_stmt.where(
131
+ compile_filter_group(filter_group, self._column_lookup)
132
+ )
133
+
134
+ # --- data ---
135
+ cols = [self._column_lookup[c] for c in selected_columns]
136
+ data_stmt = sa.select(*cols).select_from(self._selectable)
137
+ if filter_group:
138
+ data_stmt = data_stmt.where(
139
+ compile_filter_group(filter_group, self._column_lookup)
140
+ )
141
+ data_stmt = self._apply_sort(data_stmt, sort)
142
+ data_stmt = data_stmt.limit(limit).offset(offset)
143
+
144
+ with self._engine.connect() as conn:
145
+ total = conn.execute(count_stmt).scalar_one()
146
+ rows = [dict(r._mapping) for r in conn.execute(data_stmt)]
147
+
148
+ return TabularResult(columns=selected_columns, rows=rows, total=total)
149
+
150
+ def aggregate(
151
+ self,
152
+ group_by_rows: list[str],
153
+ group_by_cols: list[str],
154
+ values: list[ValueSpec],
155
+ *,
156
+ filter_group: FilterGroup | None = None,
157
+ sort: list[SortSpec] | None = None,
158
+ limit: int | None = None,
159
+ offset: int = 0,
160
+ ctx: RequestContext | None = None,
161
+ ) -> AggregateResult:
162
+ if group_by_cols:
163
+ return self._pivot_aggregate(
164
+ group_by_rows, group_by_cols, values,
165
+ filter_group=filter_group, sort=sort,
166
+ limit=limit, offset=offset,
167
+ )
168
+ return self._simple_aggregate(
169
+ group_by_rows, values,
170
+ filter_group=filter_group, sort=sort,
171
+ limit=limit, offset=offset,
172
+ )
173
+
174
+ # ------------------------------------------------------------------
175
+ # Internal: simple (non-pivot) aggregation
176
+ # ------------------------------------------------------------------
177
+
178
+ def _simple_aggregate(
179
+ self,
180
+ group_by_rows: list[str],
181
+ values: list[ValueSpec],
182
+ *,
183
+ filter_group: FilterGroup | None = None,
184
+ sort: list[SortSpec] | None = None,
185
+ limit: int | None = None,
186
+ offset: int = 0,
187
+ ) -> AggregateResult:
188
+ # Build SELECT clause: group-by columns + aggregate expressions.
189
+ group_cols = [self._column_lookup[c] for c in group_by_rows]
190
+ agg_exprs: list[tuple[str, ColumnElement[Any]]] = []
191
+ result_columns: list[AggregateColumn] = []
192
+
193
+ for col_name in group_by_rows:
194
+ result_columns.append(
195
+ AggregateColumn(key=col_name, header=col_name)
196
+ )
197
+
198
+ for vs in values:
199
+ key = f"{vs.agg}_{vs.column}"
200
+ sa_expr = self._agg_expression(vs)
201
+ agg_exprs.append((key, sa_expr))
202
+ result_columns.append(
203
+ AggregateColumn(
204
+ key=key,
205
+ header=vs.label or key,
206
+ measure=vs.column,
207
+ agg=vs.agg,
208
+ )
209
+ )
210
+
211
+ select_cols = [
212
+ *[c.label(c.name) for c in group_cols],
213
+ *[expr.label(key) for key, expr in agg_exprs],
214
+ ]
215
+
216
+ # --- count (total groups) ---
217
+ count_sub = (
218
+ sa.select(*[c.label(c.name) for c in group_cols])
219
+ .select_from(self._selectable)
220
+ .group_by(*group_cols)
221
+ )
222
+ if filter_group:
223
+ count_sub = count_sub.where(
224
+ compile_filter_group(filter_group, self._column_lookup)
225
+ )
226
+ count_stmt = sa.select(sa.func.count()).select_from(count_sub.subquery())
227
+
228
+ # --- data ---
229
+ data_stmt = (
230
+ sa.select(*select_cols)
231
+ .select_from(self._selectable)
232
+ .group_by(*group_cols)
233
+ )
234
+ if filter_group:
235
+ data_stmt = data_stmt.where(
236
+ compile_filter_group(filter_group, self._column_lookup)
237
+ )
238
+ data_stmt = self._apply_sort(data_stmt, sort)
239
+ if limit is not None:
240
+ data_stmt = data_stmt.limit(limit)
241
+ data_stmt = data_stmt.offset(offset)
242
+
243
+ with self._engine.connect() as conn:
244
+ total = conn.execute(count_stmt).scalar_one()
245
+ rows = [dict(r._mapping) for r in conn.execute(data_stmt)]
246
+
247
+ return AggregateResult(columns=result_columns, rows=rows, total=total)
248
+
249
+ # ------------------------------------------------------------------
250
+ # Internal: pivot aggregation
251
+ # ------------------------------------------------------------------
252
+
253
+ def _pivot_aggregate(
254
+ self,
255
+ group_by_rows: list[str],
256
+ group_by_cols: list[str],
257
+ values: list[ValueSpec],
258
+ *,
259
+ filter_group: FilterGroup | None = None,
260
+ sort: list[SortSpec] | None = None,
261
+ limit: int | None = None,
262
+ offset: int = 0,
263
+ ) -> AggregateResult:
264
+ """Two-pass pivot: discover combos, then aggregate per combo.
265
+
266
+ SQLAlchemy Core doesn't have built-in PIVOT support, so we:
267
+ 1. Query distinct pivot-column combinations.
268
+ 2. Build one ``CASE WHEN … END`` per (combo × measure) to simulate
269
+ a pivot in a single grouped query.
270
+ """
271
+ where_clause: ColumnElement[bool] | None = None
272
+ if filter_group:
273
+ where_clause = compile_filter_group(
274
+ filter_group, self._column_lookup
275
+ )
276
+
277
+ # --- pass 1: discover distinct pivot combos ---
278
+ pivot_sa_cols = [self._column_lookup[c] for c in group_by_cols]
279
+ combo_stmt = (
280
+ sa.select(*pivot_sa_cols)
281
+ .select_from(self._selectable)
282
+ .distinct()
283
+ )
284
+ if where_clause is not None:
285
+ combo_stmt = combo_stmt.where(where_clause)
286
+ # Deterministic ordering of combos.
287
+ combo_stmt = combo_stmt.order_by(*pivot_sa_cols)
288
+
289
+ with self._engine.connect() as conn:
290
+ combos = [tuple(r._mapping[c] for c in group_by_cols) for r in conn.execute(combo_stmt)]
291
+
292
+ # --- build result column metadata + CASE expressions ---
293
+ group_row_cols = [self._column_lookup[c] for c in group_by_rows]
294
+
295
+ result_columns: list[AggregateColumn] = []
296
+ for col_name in group_by_rows:
297
+ result_columns.append(
298
+ AggregateColumn(key=col_name, header=col_name)
299
+ )
300
+
301
+ case_exprs: list[tuple[str, ColumnElement[Any]]] = []
302
+
303
+ for combo in combos:
304
+ # Build the CASE WHEN condition for this combo.
305
+ combo_cond = sa.and_(
306
+ *(
307
+ self._column_lookup[group_by_cols[i]] == combo[i]
308
+ for i in range(len(group_by_cols))
309
+ )
310
+ )
311
+ for vs in values:
312
+ pivot_key_parts = [str(v) for v in combo]
313
+ col_key = "_".join(pivot_key_parts + [vs.agg, vs.column])
314
+ result_columns.append(
315
+ AggregateColumn(
316
+ key=col_key,
317
+ header=col_key,
318
+ pivot_keys=pivot_key_parts,
319
+ measure=vs.column,
320
+ agg=vs.agg,
321
+ )
322
+ )
323
+ # CASE WHEN <combo_cond> THEN <measure_col> END → wrapped in agg
324
+ measure_col = self._column_lookup[vs.column]
325
+ case_expr = sa.case((combo_cond, measure_col))
326
+ agg_expr = self._wrap_agg(vs.agg, case_expr)
327
+ case_exprs.append((col_key, agg_expr))
328
+
329
+ # --- pass 2: grouped query with CASE expressions ---
330
+ select_cols = [
331
+ *[c.label(c.name) for c in group_row_cols],
332
+ *[expr.label(key) for key, expr in case_exprs],
333
+ ]
334
+
335
+ # count total groups
336
+ count_sub = (
337
+ sa.select(*[c.label(c.name) for c in group_row_cols])
338
+ .select_from(self._selectable)
339
+ .group_by(*group_row_cols)
340
+ )
341
+ if where_clause is not None:
342
+ count_sub = count_sub.where(where_clause)
343
+ count_stmt = sa.select(sa.func.count()).select_from(count_sub.subquery())
344
+
345
+ data_stmt = (
346
+ sa.select(*select_cols)
347
+ .select_from(self._selectable)
348
+ .group_by(*group_row_cols)
349
+ )
350
+ if where_clause is not None:
351
+ data_stmt = data_stmt.where(where_clause)
352
+ data_stmt = self._apply_sort(data_stmt, sort)
353
+ if limit is not None:
354
+ data_stmt = data_stmt.limit(limit)
355
+ data_stmt = data_stmt.offset(offset)
356
+
357
+ with self._engine.connect() as conn:
358
+ total = conn.execute(count_stmt).scalar_one()
359
+ rows = [dict(r._mapping) for r in conn.execute(data_stmt)]
360
+
361
+ return AggregateResult(columns=result_columns, rows=rows, total=total)
362
+
363
+ # ------------------------------------------------------------------
364
+ # Internal helpers
365
+ # ------------------------------------------------------------------
366
+
367
+ def _apply_sort(
368
+ self,
369
+ stmt: sa.Select[Any],
370
+ sort: list[SortSpec] | None,
371
+ ) -> sa.Select[Any]:
372
+ """Append ORDER BY clauses to *stmt*."""
373
+ if not sort:
374
+ return stmt
375
+ clauses = []
376
+ for spec in sort:
377
+ col = self._column_lookup[spec.column]
378
+ clauses.append(
379
+ col.desc() if spec.direction == SortDirection.DESC else col.asc()
380
+ )
381
+ return stmt.order_by(*clauses)
382
+
383
+ def _agg_expression(self, vs: ValueSpec) -> ColumnElement[Any]:
384
+ """Build a SQLAlchemy aggregate expression for a :class:`ValueSpec`."""
385
+ col = self._column_lookup[vs.column]
386
+ return self._wrap_agg(vs.agg, col)
387
+
388
+ @staticmethod
389
+ def _wrap_agg(agg: AggFunc, expr: ColumnElement[Any]) -> ColumnElement[Any]:
390
+ """Wrap *expr* in the SQL aggregate function for *agg*."""
391
+ match agg:
392
+ case AggFunc.SUM:
393
+ return sa.func.sum(expr)
394
+ case AggFunc.AVG:
395
+ return sa.func.avg(expr)
396
+ case AggFunc.MIN:
397
+ return sa.func.min(expr)
398
+ case AggFunc.MAX:
399
+ return sa.func.max(expr)
400
+ case AggFunc.COUNT:
401
+ return sa.func.count(expr)
402
+ case AggFunc.DISTINCT_COUNT:
403
+ return sa.func.count(sa.distinct(expr))
404
+ case AggFunc.FIRST:
405
+ return sa.func.min(expr) # approximation — no SQL FIRST
406
+ case AggFunc.LAST:
407
+ return sa.func.max(expr) # approximation — no SQL LAST
408
+ case _:
409
+ return sa.func.count(expr)
File without changes
@@ -0,0 +1,97 @@
1
+ """Shared fixtures for sanjaya-sqlalchemy tests.
2
+
3
+ Uses an in-memory SQLite database with a ``trades`` table pre-populated
4
+ with deterministic sample data. The schema and data are designed to
5
+ exercise flat queries, simple aggregation, and pivot aggregation.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import pytest
11
+ import sqlalchemy as sa
12
+
13
+ from sanjaya_core.enums import AggFunc, ColumnType
14
+ from sanjaya_core.types import (
15
+ ColumnMeta,
16
+ ColumnPivotOptions,
17
+ DatasetCapabilities,
18
+ PivotAggOption,
19
+ )
20
+
21
+ from sanjaya_sqlalchemy import SQLAlchemyProvider
22
+
23
+
24
+ # ---------------------------------------------------------------------------
25
+ # Fixtures
26
+ # ---------------------------------------------------------------------------
27
+
28
+
29
@pytest.fixture(scope="session")
def engine() -> sa.engine.Engine:
    """Session-wide engine bound to a private in-memory SQLite database."""
    url = "sqlite:///:memory:"
    return sa.create_engine(url)
32
+
33
+
34
@pytest.fixture(scope="session")
def trades_table(engine: sa.engine.Engine) -> sa.Table:
    """Create the ``trades`` table on *engine* and load seven fixed rows.

    Desks: FX (ids 1, 2, 3, 7) and Rates (ids 4, 5, 6).
    Regions: US, EU and APAC.
    """
    md = sa.MetaData()
    trades = sa.Table(
        "trades",
        md,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("desk", sa.String(50)),
        sa.Column("region", sa.String(50)),
        sa.Column("instrument", sa.String(50)),
        sa.Column("amount", sa.Float),
        sa.Column("quantity", sa.Integer),
    )
    md.create_all(engine)

    fields = ("id", "desk", "region", "instrument", "amount", "quantity")
    records = [
        (1, "FX", "US", "EURUSD", 1000.0, 10),
        (2, "FX", "US", "GBPUSD", 2000.0, 20),
        (3, "FX", "EU", "EURUSD", 1500.0, 15),
        (4, "Rates", "US", "T-Note", 5000.0, 50),
        (5, "Rates", "EU", "Bund", 3000.0, 30),
        (6, "Rates", "APAC", "JGB", 4000.0, 40),
        (7, "FX", "APAC", "USDJPY", 500.0, 5),
    ]
    payload = [dict(zip(fields, record)) for record in records]

    with engine.begin() as conn:
        conn.execute(trades.insert(), payload)

    return trades
62
+
63
+
64
# Pivot configuration for the single measure column: SUM and AVG are allowed.
_AMOUNT_PIVOT = ColumnPivotOptions(
    role="measure",
    allowed_aggs=[
        PivotAggOption(agg=AggFunc.SUM, label="Sum"),
        PivotAggOption(agg=AggFunc.AVG, label="Avg"),
    ],
)

# Column metadata for the ``trades`` fixture table, in table column order.
COLUMN_DEFS = [
    ColumnMeta(name="id", label="ID", type=ColumnType.NUMBER),
    ColumnMeta(name="desk", label="Desk", type=ColumnType.STRING),
    ColumnMeta(name="region", label="Region", type=ColumnType.STRING),
    ColumnMeta(name="instrument", label="Instrument", type=ColumnType.STRING),
    ColumnMeta(
        name="amount",
        label="Amount",
        type=ColumnType.CURRENCY,
        pivot=_AMOUNT_PIVOT,
    ),
    ColumnMeta(name="quantity", label="Quantity", type=ColumnType.NUMBER),
]
83
+
84
+
85
@pytest.fixture(scope="session")
def provider(
    engine: sa.engine.Engine,
    trades_table: sa.Table,
) -> SQLAlchemyProvider:
    """A pivot-capable provider over the ``trades`` fixture table."""
    caps = DatasetCapabilities(pivot=True)
    return SQLAlchemyProvider(
        key="trades",
        label="Trades",
        engine=engine,
        selectable=trades_table,
        columns=COLUMN_DEFS,
        capabilities=caps,
    )
@@ -0,0 +1,159 @@
1
+ """Tests for the SQLAlchemy filter compiler."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sqlalchemy as sa
6
+ from sqlalchemy.engine import Engine
7
+
8
+ from sanjaya_core.enums import FilterCombinator, FilterOperator
9
+ from sanjaya_core.filters import FilterCondition, FilterGroup
10
+
11
+ from sanjaya_sqlalchemy.filters import compile_filter_group
12
+
13
+
14
class TestFilterCompiler:
    """Compile filter trees and verify them by running SQL against ``trades``.

    The fixture table holds 7 rows: FX desk ids 1, 2, 3, 7 and Rates desk
    ids 4, 5, 6.
    """

    @staticmethod
    def _rows(engine: Engine, table: sa.Table, fg: FilterGroup) -> list[sa.Row]:
        """Compile *fg* against *table*'s columns and return the matching rows."""
        expr = compile_filter_group(fg, {c.name: c for c in table.columns})
        with engine.connect() as conn:
            return list(conn.execute(sa.select(table).where(expr)).fetchall())

    def test_eq(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
        ])
        rows = self._rows(engine, trades_table, fg)
        assert len(rows) == 4
        assert all(r._mapping["desk"] == "FX" for r in rows)

    def test_neq(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.NEQ, value="FX"),
        ])
        assert len(self._rows(engine, trades_table, fg)) == 3

    def test_gt_lt(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="amount", operator=FilterOperator.GT, value=1500),
            FilterCondition(column="amount", operator=FilterOperator.LT, value=5000),
        ])
        # amount > 1500 AND < 5000 → 2000, 3000, 4000 = 3 rows
        assert len(self._rows(engine, trades_table, fg)) == 3

    def test_gte_lte(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="amount", operator=FilterOperator.GTE, value=1500),
            FilterCondition(column="amount", operator=FilterOperator.LTE, value=4000),
        ])
        # Inclusive bounds: 1500, 2000, 3000, 4000 = 4 rows
        assert len(self._rows(engine, trades_table, fg)) == 4

    def test_between(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="amount", operator=FilterOperator.BETWEEN, value=[1000, 3000]),
        ])
        # 1000, 2000, 1500, 3000 = 4 rows
        assert len(self._rows(engine, trades_table, fg)) == 4

    def test_in(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="region", operator=FilterOperator.IN, value=["US", "EU"]),
        ])
        assert len(self._rows(engine, trades_table, fg)) == 5

    def test_contains(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="instrument", operator=FilterOperator.CONTAINS, value="USD"),
        ])
        # EURUSD (×2), GBPUSD, USDJPY = 4
        assert len(self._rows(engine, trades_table, fg)) == 4

    def test_contains_escapes_like_wildcards(
        self, engine: Engine, trades_table: sa.Table,
    ) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="instrument", operator=FilterOperator.CONTAINS, value="%"),
        ])
        # "%" must be matched literally, not as a LIKE wildcard.
        assert self._rows(engine, trades_table, fg) == []

    def test_startswith(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="instrument", operator=FilterOperator.STARTSWITH, value="EUR"),
        ])
        assert len(self._rows(engine, trades_table, fg)) == 2

    def test_endswith(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="instrument", operator=FilterOperator.ENDSWITH, value="USD"),
        ])
        # EURUSD (ids 1, 3), GBPUSD (id 2) = 3
        assert len(self._rows(engine, trades_table, fg)) == 3

    def test_or_combinator(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(
            combinator=FilterCombinator.OR,
            conditions=[
                FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
                FilterCondition(column="region", operator=FilterOperator.EQ, value="APAC"),
            ],
        )
        # FX: 1,2,3,7; APAC: 6,7 → union = 1,2,3,6,7 = 5
        assert len(self._rows(engine, trades_table, fg)) == 5

    def test_negate_condition(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX", negate=True),
        ])
        assert len(self._rows(engine, trades_table, fg)) == 3  # NOT FX → Rates rows

    def test_negate_group(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(
            negate=True,
            conditions=[
                FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
            ],
        )
        assert len(self._rows(engine, trades_table, fg)) == 3

    def test_nested_groups(self, engine: Engine, trades_table: sa.Table) -> None:
        """(desk = 'FX') AND (region = 'US' OR region = 'EU')."""
        fg = FilterGroup(
            conditions=[
                FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
            ],
            groups=[
                FilterGroup(
                    combinator=FilterCombinator.OR,
                    conditions=[
                        FilterCondition(column="region", operator=FilterOperator.EQ, value="US"),
                        FilterCondition(column="region", operator=FilterOperator.EQ, value="EU"),
                    ],
                ),
            ],
        )
        # FX + (US or EU) → ids 1, 2, 3
        assert len(self._rows(engine, trades_table, fg)) == 3

    def test_empty_group_matches_all(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup()
        assert len(self._rows(engine, trades_table, fg)) == 7
@@ -0,0 +1,240 @@
1
+ """Tests for SQLAlchemyProvider — flat queries, aggregation, and pivot."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from sanjaya_core.enums import AggFunc, FilterOperator, SortDirection
6
+ from sanjaya_core.filters import FilterCondition, FilterGroup
7
+ from sanjaya_core.types import SortSpec, ValueSpec
8
+
9
+ from sanjaya_sqlalchemy import SQLAlchemyProvider
10
+
11
+
12
+ # ---------------------------------------------------------------------------
13
+ # Flat query
14
+ # ---------------------------------------------------------------------------
15
+
16
+
17
class TestQuery:
    """Flat (non-aggregated) row queries through ``SQLAlchemyProvider.query``."""

    def test_select_all(self, provider: SQLAlchemyProvider) -> None:
        """Selecting three columns returns all seven rows with exactly those keys."""
        result = provider.query(["id", "desk", "amount"])
        assert result.total == 7
        assert len(result.rows) == 7
        assert result.columns == ["id", "desk", "amount"]
        # Every row should have exactly the requested keys.
        for row in result.rows:
            assert set(row.keys()) == {"id", "desk", "amount"}

    def test_limit_offset(self, provider: SQLAlchemyProvider) -> None:
        """``limit``/``offset`` page through id-sorted rows; ``total`` is the full count."""
        r1 = provider.query(["id"], sort=[SortSpec(column="id")], limit=3, offset=0)
        r2 = provider.query(["id"], sort=[SortSpec(column="id")], limit=3, offset=3)
        assert r1.total == 7
        ids_1 = [r["id"] for r in r1.rows]
        ids_2 = [r["id"] for r in r2.rows]
        assert ids_1 == [1, 2, 3]
        assert ids_2 == [4, 5, 6]

    def test_filter(self, provider: SQLAlchemyProvider) -> None:
        """An equality filter on ``desk`` returns only the four FX rows."""
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
        ])
        result = provider.query(["id", "desk"], filter_group=fg)
        assert result.total == 4
        # Check the returned rows too: ``all()`` over an empty list is True,
        # so without this the content assertion below could pass vacuously.
        assert len(result.rows) == 4
        assert all(r["desk"] == "FX" for r in result.rows)

    def test_sort_asc(self, provider: SQLAlchemyProvider) -> None:
        """Ascending sort on ``amount`` yields non-decreasing amounts."""
        result = provider.query(
            ["id", "amount"],
            sort=[SortSpec(column="amount", direction=SortDirection.ASC)],
        )
        amounts = [r["amount"] for r in result.rows]
        # An empty list is trivially "sorted"; make sure rows actually came back.
        assert len(amounts) == 7
        assert amounts == sorted(amounts)

    def test_sort_desc(self, provider: SQLAlchemyProvider) -> None:
        """Descending sort on ``amount`` yields non-increasing amounts."""
        result = provider.query(
            ["id", "amount"],
            sort=[SortSpec(column="amount", direction=SortDirection.DESC)],
        )
        amounts = [r["amount"] for r in result.rows]
        # An empty list is trivially "sorted"; make sure rows actually came back.
        assert len(amounts) == 7
        assert amounts == sorted(amounts, reverse=True)

    def test_empty_result(self, provider: SQLAlchemyProvider) -> None:
        """A filter matching nothing yields ``total == 0`` and no rows."""
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="Nonexistent"),
        ])
        result = provider.query(["id"], filter_group=fg)
        assert result.total == 0
        assert result.rows == []
67
+
68
+
69
+ # ---------------------------------------------------------------------------
70
+ # Simple aggregation (no pivot)
71
+ # ---------------------------------------------------------------------------
72
+
73
+
74
class TestSimpleAggregate:
    """Grouped aggregation with row dimensions only (``group_by_cols`` is empty)."""

    @staticmethod
    def _sum_amount(provider: SQLAlchemyProvider, group_by_rows: list[str], **extra: object):
        """Aggregate ``sum(amount)`` grouped by *group_by_rows*, forwarding *extra* kwargs."""
        return provider.aggregate(
            group_by_rows=group_by_rows,
            group_by_cols=[],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
            **extra,
        )

    def test_group_by_single_column(self, provider: SQLAlchemyProvider) -> None:
        outcome = self._sum_amount(provider, ["desk"])
        assert outcome.total == 2  # FX, Rates
        per_desk = {}
        for row in outcome.rows:
            per_desk[row["desk"]] = row["sum_amount"]
        assert per_desk["FX"] == 5000  # 1000 + 2000 + 1500 + 500
        assert per_desk["Rates"] == 12000  # 5000 + 3000 + 4000

    def test_group_by_multiple_columns(self, provider: SQLAlchemyProvider) -> None:
        outcome = self._sum_amount(provider, ["desk", "region"])
        # FX-US, FX-EU, FX-APAC, Rates-US, Rates-EU, Rates-APAC
        assert outcome.total == 6

    def test_multiple_agg_funcs(self, provider: SQLAlchemyProvider) -> None:
        measures = [
            ValueSpec(column="amount", agg=AggFunc.SUM),
            ValueSpec(column="amount", agg=AggFunc.COUNT),
        ]
        outcome = provider.aggregate(group_by_rows=["desk"], group_by_cols=[], values=measures)
        fx = next(row for row in outcome.rows if row["desk"] == "FX")
        assert fx["sum_amount"] == 5000.0
        assert fx["count_amount"] == 4  # four FX trades

    def test_aggregate_with_filter(self, provider: SQLAlchemyProvider) -> None:
        us_only = FilterGroup(
            conditions=[FilterCondition(column="region", operator=FilterOperator.EQ, value="US")]
        )
        outcome = self._sum_amount(provider, ["desk"], filter_group=us_only)
        assert outcome.total == 2
        per_desk = {row["desk"]: row["sum_amount"] for row in outcome.rows}
        assert (per_desk["FX"], per_desk["Rates"]) == (3000.0, 5000.0)

    def test_aggregate_with_limit_offset(self, provider: SQLAlchemyProvider) -> None:
        outcome = self._sum_amount(
            provider, ["desk"], sort=[SortSpec(column="desk")], limit=1, offset=0
        )
        # ``total`` still reports both groups even though only one row is returned.
        assert outcome.total == 2
        assert len(outcome.rows) == 1
        assert outcome.rows[0]["desk"] == "FX"

    def test_distinct_count(self, provider: SQLAlchemyProvider) -> None:
        outcome = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=[],
            values=[ValueSpec(column="region", agg=AggFunc.DISTINCT_COUNT)],
        )
        regions_per_desk = {row["desk"]: row["distinctCount_region"] for row in outcome.rows}
        # Each desk trades in US, EU and APAC.
        for desk in ("FX", "Rates"):
            assert regions_per_desk[desk] == 3
145
+
146
+
147
+ # ---------------------------------------------------------------------------
148
+ # Pivot aggregation
149
+ # ---------------------------------------------------------------------------
150
+
151
+
152
class TestPivotAggregate:
    """Aggregation with column dimensions (``group_by_cols``), i.e. pivot output."""

    def test_basic_pivot(self, provider: SQLAlchemyProvider) -> None:
        """Pivot by region: desk × region → sum(amount)."""
        result = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region"],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
        )
        assert result.total == 2  # FX, Rates
        # There should be row-dim columns + one per (region × measure)
        col_keys = [c.key for c in result.columns]
        assert "desk" in col_keys
        # Verify a few pivot columns exist
        assert any("US_sum_amount" in k for k in col_keys)
        assert any("EU_sum_amount" in k for k in col_keys)

        fx_row = next(r for r in result.rows if r["desk"] == "FX")
        assert fx_row["US_sum_amount"] == 3000.0  # 1000 + 2000
        assert fx_row["EU_sum_amount"] == 1500.0
        assert fx_row["APAC_sum_amount"] == 500.0

    def test_pivot_column_keys_are_strings(self, provider: SQLAlchemyProvider) -> None:
        """Every pivot column's ``pivot_keys`` entries are strings."""
        result = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region"],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
        )
        pivot_columns = [col for col in result.columns if col.pivot_keys]
        # Guard against a vacuous pass: if no column carried pivot_keys the
        # loop below would never assert anything.
        assert pivot_columns
        for col in pivot_columns:
            assert all(isinstance(k, str) for k in col.pivot_keys)

    def test_multi_dim_pivot(self, provider: SQLAlchemyProvider) -> None:
        """Pivot by (region, instrument)."""
        result = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region", "instrument"],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
        )
        assert result.total == 2
        # Find the FX-US-EURUSD cell
        fx_row = next(r for r in result.rows if r["desk"] == "FX")
        assert fx_row.get("US_EURUSD_sum_amount") == 1000.0

    def test_pivot_with_multiple_measures(self, provider: SQLAlchemyProvider) -> None:
        """Each measure gets its own pivot column per region."""
        result = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region"],
            values=[
                ValueSpec(column="amount", agg=AggFunc.SUM),
                ValueSpec(column="quantity", agg=AggFunc.SUM),
            ],
        )
        fx_row = next(r for r in result.rows if r["desk"] == "FX")
        assert fx_row["US_sum_amount"] == 3000.0
        assert fx_row["US_sum_quantity"] == 30  # 10 + 20

    def test_pivot_with_filter(self, provider: SQLAlchemyProvider) -> None:
        """A desk filter restricts the pivot to the FX row only."""
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
        ])
        result = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region"],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
            filter_group=fg,
        )
        assert result.total == 1
        assert result.rows[0]["desk"] == "FX"
220
+
221
+
222
+ # ---------------------------------------------------------------------------
223
+ # Provider metadata
224
+ # ---------------------------------------------------------------------------
225
+
226
+
227
class TestProviderMeta:
    """Metadata exposed by the provider: columns, capabilities and identity."""

    def test_get_columns(self, provider: SQLAlchemyProvider) -> None:
        columns = provider.get_columns()
        assert len(columns) == 6
        column_names = [col.name for col in columns]
        assert all(required in column_names for required in ("id", "desk"))

    def test_capabilities(self, provider: SQLAlchemyProvider) -> None:
        caps = provider.capabilities
        assert caps.pivot is True

    def test_identity(self, provider: SQLAlchemyProvider) -> None:
        assert (provider.key, provider.label) == ("trades", "Trades")