sanjaya-sqlalchemy 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sanjaya_sqlalchemy-0.1.0/.gitignore +12 -0
- sanjaya_sqlalchemy-0.1.0/LICENSE +21 -0
- sanjaya_sqlalchemy-0.1.0/PKG-INFO +62 -0
- sanjaya_sqlalchemy-0.1.0/README.md +37 -0
- sanjaya_sqlalchemy-0.1.0/pyproject.toml +46 -0
- sanjaya_sqlalchemy-0.1.0/src/sanjaya_sqlalchemy/__init__.py +7 -0
- sanjaya_sqlalchemy-0.1.0/src/sanjaya_sqlalchemy/filters.py +109 -0
- sanjaya_sqlalchemy-0.1.0/src/sanjaya_sqlalchemy/provider.py +409 -0
- sanjaya_sqlalchemy-0.1.0/tests/__init__.py +0 -0
- sanjaya_sqlalchemy-0.1.0/tests/conftest.py +97 -0
- sanjaya_sqlalchemy-0.1.0/tests/test_filters.py +159 -0
- sanjaya_sqlalchemy-0.1.0/tests/test_provider.py +240 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Tom Brennan
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sanjaya-sqlalchemy
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: SQLAlchemy Core data-provider implementation for the Sanjaya reporting platform
|
|
5
|
+
Project-URL: Repository, https://github.com/tjb1982/sanjaya
|
|
6
|
+
Project-URL: Issues, https://github.com/tjb1982/sanjaya/issues
|
|
7
|
+
Author: Tom Brennan
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Keywords: analytics,data-provider,reporting,sqlalchemy
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
17
|
+
Classifier: Typing :: Typed
|
|
18
|
+
Requires-Python: >=3.12
|
|
19
|
+
Requires-Dist: sanjaya-core~=0.1
|
|
20
|
+
Requires-Dist: sqlalchemy<3,>=2.0
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: pytest-cov; extra == 'dev'
|
|
23
|
+
Requires-Dist: pytest>=8; extra == 'dev'
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
|
|
26
|
+
# sanjaya-sqlalchemy
|
|
27
|
+
|
|
28
|
+
SQLAlchemy Core data-provider implementation for the Sanjaya reporting platform.
|
|
29
|
+
|
|
30
|
+
Translates the `DataProvider` interface from `sanjaya-core` into SQL queries using SQLAlchemy Core expressions. No ORM required.
|
|
31
|
+
|
|
32
|
+
## Installation
|
|
33
|
+
|
|
34
|
+
```
|
|
35
|
+
pip install sanjaya-sqlalchemy
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
## Usage
|
|
39
|
+
|
|
40
|
+
```python
|
|
41
|
+
from sqlalchemy import MetaData, create_engine
|
|
42
|
+
from sanjaya_core.types import ColumnMeta, DatasetCapabilities
|
|
43
|
+
from sanjaya_core.enums import ColumnType
|
|
44
|
+
from sanjaya_sqlalchemy import SQLAlchemyProvider
|
|
45
|
+
|
|
46
|
+
engine = create_engine("postgresql://...")
|
|
47
|
+
metadata = MetaData()
|
|
48
|
+
metadata.reflect(bind=engine)
|
|
49
|
+
|
|
50
|
+
provider = SQLAlchemyProvider(
|
|
51
|
+
key="trade_activity",
|
|
52
|
+
label="Trade Activity",
|
|
53
|
+
engine=engine,
|
|
54
|
+
selectable=metadata.tables["trade_activity"],
|
|
55
|
+
columns=[
|
|
56
|
+
ColumnMeta(name="id", label="ID", type=ColumnType.NUMBER),
|
|
57
|
+
ColumnMeta(name="desk", label="Desk", type=ColumnType.STRING),
|
|
58
|
+
ColumnMeta(name="amount", label="Amount", type=ColumnType.CURRENCY),
|
|
59
|
+
],
|
|
60
|
+
capabilities=DatasetCapabilities(pivot=True),
|
|
61
|
+
)
|
|
62
|
+
```
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# sanjaya-sqlalchemy
|
|
2
|
+
|
|
3
|
+
SQLAlchemy Core data-provider implementation for the Sanjaya reporting platform.
|
|
4
|
+
|
|
5
|
+
Translates the `DataProvider` interface from `sanjaya-core` into SQL queries using SQLAlchemy Core expressions. No ORM required.
|
|
6
|
+
|
|
7
|
+
## Installation
|
|
8
|
+
|
|
9
|
+
```
|
|
10
|
+
pip install sanjaya-sqlalchemy
|
|
11
|
+
```
|
|
12
|
+
|
|
13
|
+
## Usage
|
|
14
|
+
|
|
15
|
+
```python
|
|
16
|
+
from sqlalchemy import MetaData, create_engine
|
|
17
|
+
from sanjaya_core.types import ColumnMeta, DatasetCapabilities
|
|
18
|
+
from sanjaya_core.enums import ColumnType
|
|
19
|
+
from sanjaya_sqlalchemy import SQLAlchemyProvider
|
|
20
|
+
|
|
21
|
+
engine = create_engine("postgresql://...")
|
|
22
|
+
metadata = MetaData()
|
|
23
|
+
metadata.reflect(bind=engine)
|
|
24
|
+
|
|
25
|
+
provider = SQLAlchemyProvider(
|
|
26
|
+
key="trade_activity",
|
|
27
|
+
label="Trade Activity",
|
|
28
|
+
engine=engine,
|
|
29
|
+
selectable=metadata.tables["trade_activity"],
|
|
30
|
+
columns=[
|
|
31
|
+
ColumnMeta(name="id", label="ID", type=ColumnType.NUMBER),
|
|
32
|
+
ColumnMeta(name="desk", label="Desk", type=ColumnType.STRING),
|
|
33
|
+
ColumnMeta(name="amount", label="Amount", type=ColumnType.CURRENCY),
|
|
34
|
+
],
|
|
35
|
+
capabilities=DatasetCapabilities(pivot=True),
|
|
36
|
+
)
|
|
37
|
+
```
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "sanjaya-sqlalchemy"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "SQLAlchemy Core data-provider implementation for the Sanjaya reporting platform"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.12"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
license-files = ["LICENSE"]
|
|
13
|
+
authors = [
|
|
14
|
+
{ name = "Tom Brennan" },
|
|
15
|
+
]
|
|
16
|
+
keywords = ["reporting", "sqlalchemy", "data-provider", "analytics"]
|
|
17
|
+
classifiers = [
|
|
18
|
+
"Development Status :: 3 - Alpha",
|
|
19
|
+
"Intended Audience :: Developers",
|
|
20
|
+
"License :: OSI Approved :: MIT License",
|
|
21
|
+
"Programming Language :: Python :: 3",
|
|
22
|
+
"Programming Language :: Python :: 3.12",
|
|
23
|
+
"Programming Language :: Python :: 3.13",
|
|
24
|
+
"Typing :: Typed",
|
|
25
|
+
]
|
|
26
|
+
dependencies = [
|
|
27
|
+
"sanjaya-core~=0.1",
|
|
28
|
+
"sqlalchemy>=2.0,<3",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[project.optional-dependencies]
|
|
32
|
+
dev = [
|
|
33
|
+
"pytest>=8",
|
|
34
|
+
"pytest-cov",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
[project.urls]
|
|
38
|
+
Repository = "https://github.com/tjb1982/sanjaya"
|
|
39
|
+
Issues = "https://github.com/tjb1982/sanjaya/issues"
|
|
40
|
+
|
|
41
|
+
[tool.hatch.build.targets.wheel]
|
|
42
|
+
packages = ["src/sanjaya_sqlalchemy"]
|
|
43
|
+
|
|
44
|
+
[tool.pytest.ini_options]
|
|
45
|
+
testpaths = ["tests"]
|
|
46
|
+
pythonpath = ["."]
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""Compile :class:`~sanjaya_core.filters.FilterGroup` trees into SQLAlchemy
|
|
2
|
+
``WHERE`` clause expressions.
|
|
3
|
+
|
|
4
|
+
The compiler walks the recursive filter model from *sanjaya-core* and
|
|
5
|
+
produces a :class:`sqlalchemy.sql.expression.ColumnElement[bool]` that can be
|
|
6
|
+
appended to any ``SELECT`` statement via ``.where()``.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import sqlalchemy as sa
|
|
14
|
+
from sqlalchemy.sql.expression import ColumnElement
|
|
15
|
+
|
|
16
|
+
from sanjaya_core.enums import FilterCombinator, FilterOperator
|
|
17
|
+
from sanjaya_core.filters import FilterCondition, FilterGroup
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def compile_filter_group(
    fg: FilterGroup,
    column_lookup: dict[str, ColumnElement[Any]],
) -> ColumnElement[bool]:
    """Translate a :class:`FilterGroup` into a SQLAlchemy boolean expression.

    Parameters
    ----------
    fg:
        The recursive filter tree to compile.
    column_lookup:
        Mapping of column *name* → SQLAlchemy :class:`Column`. The compiler
        looks up each :attr:`FilterCondition.column` here.

    Returns
    -------
    A composable SQLAlchemy expression suitable for ``stmt.where(expr)``.

    Raises
    ------
    KeyError
        If a condition references a column not present in *column_lookup*.
    """
    # Compile leaf conditions first, then recurse into nested sub-groups.
    compiled: list[ColumnElement[bool]] = [
        _compile_condition(cond, column_lookup) for cond in fg.conditions
    ]
    compiled.extend(
        compile_filter_group(subgroup, column_lookup) for subgroup in fg.groups
    )

    if not compiled:
        # An empty group matches every row.
        combined: ColumnElement[bool] = sa.literal(True)
    else:
        joiner = sa.and_ if fg.combinator == FilterCombinator.AND else sa.or_
        combined = joiner(*compiled)

    return sa.not_(combined) if fg.negate else combined
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def _compile_condition(
|
|
65
|
+
cond: FilterCondition,
|
|
66
|
+
column_lookup: dict[str, ColumnElement[Any]],
|
|
67
|
+
) -> ColumnElement[bool]:
|
|
68
|
+
"""Compile a single :class:`FilterCondition` to a SQLAlchemy expression."""
|
|
69
|
+
col = column_lookup[cond.column]
|
|
70
|
+
v = cond.value
|
|
71
|
+
|
|
72
|
+
expr: ColumnElement[bool]
|
|
73
|
+
|
|
74
|
+
match cond.operator:
|
|
75
|
+
case FilterOperator.EQ:
|
|
76
|
+
expr = col == v
|
|
77
|
+
case FilterOperator.NEQ:
|
|
78
|
+
expr = col != v
|
|
79
|
+
case FilterOperator.GT:
|
|
80
|
+
expr = col > v
|
|
81
|
+
case FilterOperator.LT:
|
|
82
|
+
expr = col < v
|
|
83
|
+
case FilterOperator.GTE:
|
|
84
|
+
expr = col >= v
|
|
85
|
+
case FilterOperator.LTE:
|
|
86
|
+
expr = col <= v
|
|
87
|
+
case FilterOperator.CONTAINS:
|
|
88
|
+
expr = col.contains(str(v), autoescape=True)
|
|
89
|
+
case FilterOperator.STARTSWITH:
|
|
90
|
+
expr = col.startswith(str(v), autoescape=True)
|
|
91
|
+
case FilterOperator.ENDSWITH:
|
|
92
|
+
expr = col.endswith(str(v), autoescape=True)
|
|
93
|
+
case FilterOperator.IS_NULL:
|
|
94
|
+
expr = col.is_(None)
|
|
95
|
+
case FilterOperator.IS_NOT_NULL:
|
|
96
|
+
expr = col.isnot(None)
|
|
97
|
+
case FilterOperator.BETWEEN:
|
|
98
|
+
lo, hi = v[0], v[1]
|
|
99
|
+
expr = col.between(lo, hi)
|
|
100
|
+
case FilterOperator.IN:
|
|
101
|
+
expr = col.in_(list(v))
|
|
102
|
+
case _:
|
|
103
|
+
# Unknown operator — match everything (safe fallback).
|
|
104
|
+
expr = sa.literal(True)
|
|
105
|
+
|
|
106
|
+
if cond.negate:
|
|
107
|
+
expr = sa.not_(expr)
|
|
108
|
+
|
|
109
|
+
return expr
|
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
"""SQLAlchemy Core implementation of :class:`~sanjaya_core.provider.DataProvider`.
|
|
2
|
+
|
|
3
|
+
``SQLAlchemyProvider`` is configured with a SQLAlchemy :class:`~sqlalchemy.Table`
|
|
4
|
+
(or selectable), a list of :class:`~sanjaya_core.types.ColumnMeta` definitions,
|
|
5
|
+
and a :class:`~sqlalchemy.engine.Engine`. It translates every provider
|
|
6
|
+
method into SQL using SQLAlchemy Core expressions — no ORM required.
|
|
7
|
+
|
|
8
|
+
Usage example::
|
|
9
|
+
|
|
10
|
+
from sqlalchemy import MetaData, create_engine
|
|
11
|
+
from sanjaya_core.types import ColumnMeta, DatasetCapabilities
|
|
12
|
+
from sanjaya_core.enums import ColumnType
|
|
13
|
+
from sanjaya_sqlalchemy import SQLAlchemyProvider
|
|
14
|
+
|
|
15
|
+
engine = create_engine("postgresql://...")
|
|
16
|
+
metadata = MetaData()
|
|
17
|
+
metadata.reflect(bind=engine)
|
|
18
|
+
trade_table = metadata.tables["trade_activity"]
|
|
19
|
+
|
|
20
|
+
provider = SQLAlchemyProvider(
|
|
21
|
+
key="trade_activity",
|
|
22
|
+
label="Trade Activity",
|
|
23
|
+
engine=engine,
|
|
24
|
+
selectable=trade_table,
|
|
25
|
+
columns=[
|
|
26
|
+
ColumnMeta(name="id", label="ID", type=ColumnType.NUMBER),
|
|
27
|
+
ColumnMeta(name="desk", label="Desk", type=ColumnType.STRING),
|
|
28
|
+
ColumnMeta(name="amount", label="Amount", type=ColumnType.CURRENCY),
|
|
29
|
+
...
|
|
30
|
+
],
|
|
31
|
+
capabilities=DatasetCapabilities(pivot=True),
|
|
32
|
+
)
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from __future__ import annotations
|
|
36
|
+
|
|
37
|
+
from typing import Any
|
|
38
|
+
|
|
39
|
+
import sqlalchemy as sa
|
|
40
|
+
from sqlalchemy import FromClause
|
|
41
|
+
from sqlalchemy.engine import Engine
|
|
42
|
+
from sqlalchemy.sql.expression import ColumnElement
|
|
43
|
+
|
|
44
|
+
from sanjaya_core.context import RequestContext
|
|
45
|
+
from sanjaya_core.enums import AggFunc, SortDirection
|
|
46
|
+
from sanjaya_core.filters import FilterGroup
|
|
47
|
+
from sanjaya_core.provider import DataProvider
|
|
48
|
+
from sanjaya_core.types import (
|
|
49
|
+
AggregateColumn,
|
|
50
|
+
AggregateResult,
|
|
51
|
+
ColumnMeta,
|
|
52
|
+
DatasetCapabilities,
|
|
53
|
+
SortSpec,
|
|
54
|
+
TabularResult,
|
|
55
|
+
ValueSpec,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
from sanjaya_sqlalchemy.filters import compile_filter_group
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class SQLAlchemyProvider(DataProvider):
    """A :class:`DataProvider` backed by a SQLAlchemy selectable + engine.

    Every provider method is translated into SQL using SQLAlchemy Core
    expressions and executed on a short-lived connection from *engine* —
    no ORM session is involved.

    Parameters
    ----------
    key:
        Dataset identifier.
    label:
        Human-readable name.
    engine:
        A SQLAlchemy :class:`~sqlalchemy.engine.Engine` used to execute
        queries.
    selectable:
        The table or subquery to query against (e.g. a
        :class:`~sqlalchemy.Table` or ``select()``).
    columns:
        Column metadata definitions. Each
        :attr:`~sanjaya_core.types.ColumnMeta.name` must match a column
        name in *selectable*.
    description:
        Optional description.
    capabilities:
        Optional capability flags.
    """

    def __init__(
        self,
        *,
        key: str,
        label: str,
        engine: Engine,
        selectable: FromClause,
        columns: list[ColumnMeta],
        description: str = "",
        capabilities: DatasetCapabilities | None = None,
    ) -> None:
        super().__init__(
            key=key,
            label=label,
            description=description,
            capabilities=capabilities or DatasetCapabilities(pivot=True),
        )
        self._engine = engine
        self._selectable = selectable
        self._columns = columns
        # Name → column expression map built eagerly: a ColumnMeta whose name
        # is missing from the selectable raises KeyError here, at construction
        # time, rather than on first query.
        self._column_lookup: dict[str, ColumnElement[Any]] = {
            c.name: selectable.c[c.name] for c in columns
        }

    # ------------------------------------------------------------------
    # DataProvider interface
    # ------------------------------------------------------------------

    def get_columns(self) -> list[ColumnMeta]:
        # Defensive copy so callers cannot mutate internal state.
        return list(self._columns)

    def query(
        self,
        selected_columns: list[str],
        *,
        filter_group: FilterGroup | None = None,
        sort: list[SortSpec] | None = None,
        limit: int = 100,
        offset: int = 0,
        ctx: RequestContext | None = None,
    ) -> TabularResult:
        """Return one page of raw rows plus the unpaginated total count.

        Raises ``KeyError`` if *selected_columns* or a filter/sort spec
        references a column name not in this provider's column metadata.
        ``ctx`` is accepted for interface compatibility but unused here.
        """
        # --- count ---
        # Total is computed on the filtered (but unpaginated) row set.
        count_stmt = sa.select(sa.func.count()).select_from(self._selectable)
        if filter_group:
            count_stmt = count_stmt.where(
                compile_filter_group(filter_group, self._column_lookup)
            )

        # --- data ---
        cols = [self._column_lookup[c] for c in selected_columns]
        data_stmt = sa.select(*cols).select_from(self._selectable)
        if filter_group:
            data_stmt = data_stmt.where(
                compile_filter_group(filter_group, self._column_lookup)
            )
        data_stmt = self._apply_sort(data_stmt, sort)
        data_stmt = data_stmt.limit(limit).offset(offset)

        with self._engine.connect() as conn:
            total = conn.execute(count_stmt).scalar_one()
            rows = [dict(r._mapping) for r in conn.execute(data_stmt)]

        return TabularResult(columns=selected_columns, rows=rows, total=total)

    def aggregate(
        self,
        group_by_rows: list[str],
        group_by_cols: list[str],
        values: list[ValueSpec],
        *,
        filter_group: FilterGroup | None = None,
        sort: list[SortSpec] | None = None,
        limit: int | None = None,
        offset: int = 0,
        ctx: RequestContext | None = None,
    ) -> AggregateResult:
        """Dispatch to pivot or plain GROUP BY aggregation.

        Any pivot columns in *group_by_cols* trigger the two-pass pivot
        path; otherwise a single grouped query is issued.
        """
        if group_by_cols:
            return self._pivot_aggregate(
                group_by_rows, group_by_cols, values,
                filter_group=filter_group, sort=sort,
                limit=limit, offset=offset,
            )
        return self._simple_aggregate(
            group_by_rows, values,
            filter_group=filter_group, sort=sort,
            limit=limit, offset=offset,
        )

    # ------------------------------------------------------------------
    # Internal: simple (non-pivot) aggregation
    # ------------------------------------------------------------------

    def _simple_aggregate(
        self,
        group_by_rows: list[str],
        values: list[ValueSpec],
        *,
        filter_group: FilterGroup | None = None,
        sort: list[SortSpec] | None = None,
        limit: int | None = None,
        offset: int = 0,
    ) -> AggregateResult:
        """Single GROUP BY query: one row per group, one column per ValueSpec."""
        # Build SELECT clause: group-by columns + aggregate expressions.
        group_cols = [self._column_lookup[c] for c in group_by_rows]
        agg_exprs: list[tuple[str, ColumnElement[Any]]] = []
        result_columns: list[AggregateColumn] = []

        for col_name in group_by_rows:
            result_columns.append(
                AggregateColumn(key=col_name, header=col_name)
            )

        for vs in values:
            # Result key, e.g. "sum_amount" — also used as the SQL label.
            key = f"{vs.agg}_{vs.column}"
            sa_expr = self._agg_expression(vs)
            agg_exprs.append((key, sa_expr))
            result_columns.append(
                AggregateColumn(
                    key=key,
                    header=vs.label or key,
                    measure=vs.column,
                    agg=vs.agg,
                )
            )

        select_cols = [
            *[c.label(c.name) for c in group_cols],
            *[expr.label(key) for key, expr in agg_exprs],
        ]

        # --- count (total groups) ---
        # Wrap the grouped query as a subquery and count its rows so the
        # total reflects the number of groups, not the number of base rows.
        count_sub = (
            sa.select(*[c.label(c.name) for c in group_cols])
            .select_from(self._selectable)
            .group_by(*group_cols)
        )
        if filter_group:
            count_sub = count_sub.where(
                compile_filter_group(filter_group, self._column_lookup)
            )
        count_stmt = sa.select(sa.func.count()).select_from(count_sub.subquery())

        # --- data ---
        data_stmt = (
            sa.select(*select_cols)
            .select_from(self._selectable)
            .group_by(*group_cols)
        )
        if filter_group:
            data_stmt = data_stmt.where(
                compile_filter_group(filter_group, self._column_lookup)
            )
        data_stmt = self._apply_sort(data_stmt, sort)
        if limit is not None:
            data_stmt = data_stmt.limit(limit)
        data_stmt = data_stmt.offset(offset)

        with self._engine.connect() as conn:
            total = conn.execute(count_stmt).scalar_one()
            rows = [dict(r._mapping) for r in conn.execute(data_stmt)]

        return AggregateResult(columns=result_columns, rows=rows, total=total)

    # ------------------------------------------------------------------
    # Internal: pivot aggregation
    # ------------------------------------------------------------------

    def _pivot_aggregate(
        self,
        group_by_rows: list[str],
        group_by_cols: list[str],
        values: list[ValueSpec],
        *,
        filter_group: FilterGroup | None = None,
        sort: list[SortSpec] | None = None,
        limit: int | None = None,
        offset: int = 0,
    ) -> AggregateResult:
        """Two-pass pivot: discover combos, then aggregate per combo.

        SQLAlchemy Core doesn't have built-in PIVOT support, so we:
        1. Query distinct pivot-column combinations.
        2. Build one ``CASE WHEN … END`` per (combo × measure) to simulate
           a pivot in a single grouped query.

        NOTE(review): combos are discovered and aggregated in two separate
        queries with no shared transaction; presumably rows inserted in
        between could skew results — acceptable for reporting workloads.
        """
        # Compile the filter once; the same clause is reused by all three
        # statements (combo discovery, count, data).
        where_clause: ColumnElement[bool] | None = None
        if filter_group:
            where_clause = compile_filter_group(
                filter_group, self._column_lookup
            )

        # --- pass 1: discover distinct pivot combos ---
        pivot_sa_cols = [self._column_lookup[c] for c in group_by_cols]
        combo_stmt = (
            sa.select(*pivot_sa_cols)
            .select_from(self._selectable)
            .distinct()
        )
        if where_clause is not None:
            combo_stmt = combo_stmt.where(where_clause)
        # Deterministic ordering of combos.
        combo_stmt = combo_stmt.order_by(*pivot_sa_cols)

        with self._engine.connect() as conn:
            combos = [tuple(r._mapping[c] for c in group_by_cols) for r in conn.execute(combo_stmt)]

        # --- build result column metadata + CASE expressions ---
        group_row_cols = [self._column_lookup[c] for c in group_by_rows]

        result_columns: list[AggregateColumn] = []
        for col_name in group_by_rows:
            result_columns.append(
                AggregateColumn(key=col_name, header=col_name)
            )

        case_exprs: list[tuple[str, ColumnElement[Any]]] = []

        for combo in combos:
            # Build the CASE WHEN condition for this combo.
            combo_cond = sa.and_(
                *(
                    self._column_lookup[group_by_cols[i]] == combo[i]
                    for i in range(len(group_by_cols))
                )
            )
            for vs in values:
                # Column key encodes combo values + agg + measure,
                # e.g. "US_sum_amount".
                pivot_key_parts = [str(v) for v in combo]
                col_key = "_".join(pivot_key_parts + [vs.agg, vs.column])
                result_columns.append(
                    AggregateColumn(
                        key=col_key,
                        header=col_key,
                        pivot_keys=pivot_key_parts,
                        measure=vs.column,
                        agg=vs.agg,
                    )
                )
                # CASE WHEN <combo_cond> THEN <measure_col> END → wrapped in agg
                measure_col = self._column_lookup[vs.column]
                case_expr = sa.case((combo_cond, measure_col))
                agg_expr = self._wrap_agg(vs.agg, case_expr)
                case_exprs.append((col_key, agg_expr))

        # --- pass 2: grouped query with CASE expressions ---
        select_cols = [
            *[c.label(c.name) for c in group_row_cols],
            *[expr.label(key) for key, expr in case_exprs],
        ]

        # count total groups
        count_sub = (
            sa.select(*[c.label(c.name) for c in group_row_cols])
            .select_from(self._selectable)
            .group_by(*group_row_cols)
        )
        if where_clause is not None:
            count_sub = count_sub.where(where_clause)
        count_stmt = sa.select(sa.func.count()).select_from(count_sub.subquery())

        data_stmt = (
            sa.select(*select_cols)
            .select_from(self._selectable)
            .group_by(*group_row_cols)
        )
        if where_clause is not None:
            data_stmt = data_stmt.where(where_clause)
        data_stmt = self._apply_sort(data_stmt, sort)
        if limit is not None:
            data_stmt = data_stmt.limit(limit)
        data_stmt = data_stmt.offset(offset)

        with self._engine.connect() as conn:
            total = conn.execute(count_stmt).scalar_one()
            rows = [dict(r._mapping) for r in conn.execute(data_stmt)]

        return AggregateResult(columns=result_columns, rows=rows, total=total)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _apply_sort(
        self,
        stmt: sa.Select[Any],
        sort: list[SortSpec] | None,
    ) -> sa.Select[Any]:
        """Append ORDER BY clauses to *stmt*."""
        if not sort:
            return stmt
        clauses = []
        for spec in sort:
            col = self._column_lookup[spec.column]
            clauses.append(
                col.desc() if spec.direction == SortDirection.DESC else col.asc()
            )
        return stmt.order_by(*clauses)

    def _agg_expression(self, vs: ValueSpec) -> ColumnElement[Any]:
        """Build a SQLAlchemy aggregate expression for a :class:`ValueSpec`."""
        col = self._column_lookup[vs.column]
        return self._wrap_agg(vs.agg, col)

    @staticmethod
    def _wrap_agg(agg: AggFunc, expr: ColumnElement[Any]) -> ColumnElement[Any]:
        """Wrap *expr* in the SQL aggregate function for *agg*."""
        match agg:
            case AggFunc.SUM:
                return sa.func.sum(expr)
            case AggFunc.AVG:
                return sa.func.avg(expr)
            case AggFunc.MIN:
                return sa.func.min(expr)
            case AggFunc.MAX:
                return sa.func.max(expr)
            case AggFunc.COUNT:
                return sa.func.count(expr)
            case AggFunc.DISTINCT_COUNT:
                return sa.func.count(sa.distinct(expr))
            case AggFunc.FIRST:
                return sa.func.min(expr)  # approximation — no SQL FIRST
            case AggFunc.LAST:
                return sa.func.max(expr)  # approximation — no SQL LAST
            case _:
                # Unknown aggregate — fall back to COUNT.
                return sa.func.count(expr)
|
|
File without changes
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Shared fixtures for sanjaya-sqlalchemy tests.
|
|
2
|
+
|
|
3
|
+
Uses an in-memory SQLite database with a ``trades`` table pre-populated
|
|
4
|
+
with deterministic sample data. The schema and data are designed to
|
|
5
|
+
exercise flat queries, simple aggregation, and pivot aggregation.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import pytest
|
|
11
|
+
import sqlalchemy as sa
|
|
12
|
+
|
|
13
|
+
from sanjaya_core.enums import AggFunc, ColumnType
|
|
14
|
+
from sanjaya_core.types import (
|
|
15
|
+
ColumnMeta,
|
|
16
|
+
ColumnPivotOptions,
|
|
17
|
+
DatasetCapabilities,
|
|
18
|
+
PivotAggOption,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
from sanjaya_sqlalchemy import SQLAlchemyProvider
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# ---------------------------------------------------------------------------
|
|
25
|
+
# Fixtures
|
|
26
|
+
# ---------------------------------------------------------------------------
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@pytest.fixture(scope="session")
def engine() -> sa.engine.Engine:
    """Session-scoped in-memory SQLite engine shared by all tests."""
    url = "sqlite:///:memory:"
    return sa.create_engine(url)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.fixture(scope="session")
def trades_table(engine: sa.engine.Engine) -> sa.Table:
    """Create and seed the ``trades`` table once per test session.

    Seven deterministic rows across 3 desks × 3 regions, chosen so that
    flat queries, simple aggregation, and pivot aggregation all have
    predictable results.
    """
    metadata = sa.MetaData()
    table = sa.Table(
        "trades",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("desk", sa.String(50)),
        sa.Column("region", sa.String(50)),
        sa.Column("instrument", sa.String(50)),
        sa.Column("amount", sa.Float),
        sa.Column("quantity", sa.Integer),
    )
    metadata.create_all(engine)

    rows = [
        {"id": 1, "desk": "FX", "region": "US", "instrument": "EURUSD", "amount": 1000.0, "quantity": 10},
        {"id": 2, "desk": "FX", "region": "US", "instrument": "GBPUSD", "amount": 2000.0, "quantity": 20},
        {"id": 3, "desk": "FX", "region": "EU", "instrument": "EURUSD", "amount": 1500.0, "quantity": 15},
        {"id": 4, "desk": "Rates", "region": "US", "instrument": "T-Note", "amount": 5000.0, "quantity": 50},
        {"id": 5, "desk": "Rates", "region": "EU", "instrument": "Bund", "amount": 3000.0, "quantity": 30},
        {"id": 6, "desk": "Rates", "region": "APAC", "instrument": "JGB", "amount": 4000.0, "quantity": 40},
        {"id": 7, "desk": "FX", "region": "APAC", "instrument": "USDJPY", "amount": 500.0, "quantity": 5},
    ]
    # engine.begin() commits on exit, so the data is visible to later
    # connections within the same in-memory database.
    with engine.begin() as conn:
        conn.execute(table.insert(), rows)

    return table
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
# Column metadata mirroring the ``trades`` table schema above. Only
# ``amount`` declares pivot options (SUM/AVG measure); the rest are plain
# dimension/identifier columns.
COLUMN_DEFS = [
    ColumnMeta(name="id", label="ID", type=ColumnType.NUMBER),
    ColumnMeta(name="desk", label="Desk", type=ColumnType.STRING),
    ColumnMeta(name="region", label="Region", type=ColumnType.STRING),
    ColumnMeta(name="instrument", label="Instrument", type=ColumnType.STRING),
    ColumnMeta(
        name="amount",
        label="Amount",
        type=ColumnType.CURRENCY,
        pivot=ColumnPivotOptions(
            role="measure",
            allowed_aggs=[
                PivotAggOption(agg=AggFunc.SUM, label="Sum"),
                PivotAggOption(agg=AggFunc.AVG, label="Avg"),
            ],
        ),
    ),
    ColumnMeta(name="quantity", label="Quantity", type=ColumnType.NUMBER),
]
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
@pytest.fixture(scope="session")
def provider(
    engine: sa.engine.Engine,
    trades_table: sa.Table,
) -> SQLAlchemyProvider:
    """The provider under test, wired to the seeded ``trades`` table."""
    config = dict(
        key="trades",
        label="Trades",
        engine=engine,
        selectable=trades_table,
        columns=COLUMN_DEFS,
        capabilities=DatasetCapabilities(pivot=True),
    )
    return SQLAlchemyProvider(**config)
|
|
@@ -0,0 +1,159 @@
|
|
|
1
|
+
"""Tests for the SQLAlchemy filter compiler."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sqlalchemy as sa
|
|
6
|
+
from sqlalchemy.engine import Engine
|
|
7
|
+
|
|
8
|
+
from sanjaya_core.enums import FilterCombinator, FilterOperator
|
|
9
|
+
from sanjaya_core.filters import FilterCondition, FilterGroup
|
|
10
|
+
|
|
11
|
+
from sanjaya_sqlalchemy.filters import compile_filter_group
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TestFilterCompiler:
    """Compile filter trees and verify by actually running SQL.

    Each test builds a ``FilterGroup``, compiles it to a SQLAlchemy
    expression, executes it against the seeded trades table, and checks
    the resulting row count (and, where useful, row contents).
    """

    @staticmethod
    def _rows(fg: FilterGroup, engine: Engine, trades_table: sa.Table) -> list:
        """Compile *fg* against the table's columns, run the SELECT, return all rows.

        Extracted helper: every test previously repeated the same
        compile -> select -> execute boilerplate verbatim.
        """
        expr = compile_filter_group(fg, {c.name: c for c in trades_table.columns})
        stmt = sa.select(trades_table).where(expr)
        with engine.connect() as conn:
            return conn.execute(stmt).fetchall()

    def test_eq(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
        ])
        rows = self._rows(fg, engine, trades_table)
        assert len(rows) == 4
        assert all(r._mapping["desk"] == "FX" for r in rows)

    def test_neq(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.NEQ, value="FX"),
        ])
        assert len(self._rows(fg, engine, trades_table)) == 3

    def test_gt_lt(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="amount", operator=FilterOperator.GT, value=1500),
            FilterCondition(column="amount", operator=FilterOperator.LT, value=5000),
        ])
        # amount > 1500 AND < 5000 → 2000, 3000, 4000 = 3 rows
        assert len(self._rows(fg, engine, trades_table)) == 3

    def test_between(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="amount", operator=FilterOperator.BETWEEN, value=[1000, 3000]),
        ])
        # 1000, 2000, 1500, 3000 = 4 rows
        assert len(self._rows(fg, engine, trades_table)) == 4

    def test_in(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="region", operator=FilterOperator.IN, value=["US", "EU"]),
        ])
        assert len(self._rows(fg, engine, trades_table)) == 5

    def test_contains(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="instrument", operator=FilterOperator.CONTAINS, value="USD"),
        ])
        # EURUSD (×2), GBPUSD, USDJPY = 4
        assert len(self._rows(fg, engine, trades_table)) == 4

    def test_startswith(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="instrument", operator=FilterOperator.STARTSWITH, value="EUR"),
        ])
        assert len(self._rows(fg, engine, trades_table)) == 2

    def test_or_combinator(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(
            combinator=FilterCombinator.OR,
            conditions=[
                FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
                FilterCondition(column="region", operator=FilterOperator.EQ, value="APAC"),
            ],
        )
        # FX: 1,2,3,7; APAC: 6,7 → union = 1,2,3,6,7 = 5
        assert len(self._rows(fg, engine, trades_table)) == 5

    def test_negate_condition(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX", negate=True),
        ])
        assert len(self._rows(fg, engine, trades_table)) == 3  # NOT FX → Rates rows

    def test_negate_group(self, engine: Engine, trades_table: sa.Table) -> None:
        fg = FilterGroup(
            negate=True,
            conditions=[
                FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
            ],
        )
        assert len(self._rows(fg, engine, trades_table)) == 3

    def test_nested_groups(self, engine: Engine, trades_table: sa.Table) -> None:
        """(desk = 'FX') AND (region = 'US' OR region = 'EU')."""
        fg = FilterGroup(
            conditions=[
                FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
            ],
            groups=[
                FilterGroup(
                    combinator=FilterCombinator.OR,
                    conditions=[
                        FilterCondition(column="region", operator=FilterOperator.EQ, value="US"),
                        FilterCondition(column="region", operator=FilterOperator.EQ, value="EU"),
                    ],
                ),
            ],
        )
        # FX + (US or EU) → ids 1, 2, 3
        assert len(self._rows(fg, engine, trades_table)) == 3

    def test_empty_group_matches_all(self, engine: Engine, trades_table: sa.Table) -> None:
        # An empty group compiles to a no-op predicate: all 7 rows match.
        fg = FilterGroup()
        assert len(self._rows(fg, engine, trades_table)) == 7
|
|
@@ -0,0 +1,240 @@
|
|
|
1
|
+
"""Tests for SQLAlchemyProvider — flat queries, aggregation, and pivot."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from sanjaya_core.enums import AggFunc, FilterOperator, SortDirection
|
|
6
|
+
from sanjaya_core.filters import FilterCondition, FilterGroup
|
|
7
|
+
from sanjaya_core.types import SortSpec, ValueSpec
|
|
8
|
+
|
|
9
|
+
from sanjaya_sqlalchemy import SQLAlchemyProvider
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
# Flat query
|
|
14
|
+
# ---------------------------------------------------------------------------
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TestQuery:
    """Flat (non-aggregated) queries: projection, paging, filtering, sorting."""

    def test_select_all(self, provider: SQLAlchemyProvider) -> None:
        res = provider.query(["id", "desk", "amount"])
        assert res.total == 7
        assert len(res.rows) == 7
        assert res.columns == ["id", "desk", "amount"]
        # Every row should have exactly the requested keys.
        wanted = {"id", "desk", "amount"}
        assert all(set(row.keys()) == wanted for row in res.rows)

    def test_limit_offset(self, provider: SQLAlchemyProvider) -> None:
        page_one = provider.query(["id"], sort=[SortSpec(column="id")], limit=3, offset=0)
        page_two = provider.query(["id"], sort=[SortSpec(column="id")], limit=3, offset=3)
        # total reflects the unpaged row count, not the page size.
        assert page_one.total == 7
        assert [row["id"] for row in page_one.rows] == [1, 2, 3]
        assert [row["id"] for row in page_two.rows] == [4, 5, 6]

    def test_filter(self, provider: SQLAlchemyProvider) -> None:
        only_fx = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
        ])
        res = provider.query(["id", "desk"], filter_group=only_fx)
        assert res.total == 4
        assert all(row["desk"] == "FX" for row in res.rows)

    def test_sort_asc(self, provider: SQLAlchemyProvider) -> None:
        res = provider.query(
            ["id", "amount"],
            sort=[SortSpec(column="amount", direction=SortDirection.ASC)],
        )
        values = [row["amount"] for row in res.rows]
        assert values == sorted(values)

    def test_sort_desc(self, provider: SQLAlchemyProvider) -> None:
        res = provider.query(
            ["id", "amount"],
            sort=[SortSpec(column="amount", direction=SortDirection.DESC)],
        )
        values = [row["amount"] for row in res.rows]
        assert values == sorted(values, reverse=True)

    def test_empty_result(self, provider: SQLAlchemyProvider) -> None:
        no_match = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="Nonexistent"),
        ])
        res = provider.query(["id"], filter_group=no_match)
        assert res.total == 0
        assert res.rows == []
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# ---------------------------------------------------------------------------
|
|
70
|
+
# Simple aggregation (no pivot)
|
|
71
|
+
# ---------------------------------------------------------------------------
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
class TestSimpleAggregate:
    """Aggregation without pivot dimensions (plain GROUP BY)."""

    def test_group_by_single_column(self, provider: SQLAlchemyProvider) -> None:
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=[],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
        )
        assert res.total == 2  # FX, Rates
        totals = {row["desk"]: row["sum_amount"] for row in res.rows}
        assert totals["FX"] == 5000  # 1000 + 2000 + 1500 + 500
        assert totals["Rates"] == 12000  # 5000 + 3000 + 4000

    def test_group_by_multiple_columns(self, provider: SQLAlchemyProvider) -> None:
        res = provider.aggregate(
            group_by_rows=["desk", "region"],
            group_by_cols=[],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
        )
        # FX-US, FX-EU, FX-APAC, Rates-US, Rates-EU, Rates-APAC
        assert res.total == 6

    def test_multiple_agg_funcs(self, provider: SQLAlchemyProvider) -> None:
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=[],
            values=[
                ValueSpec(column="amount", agg=AggFunc.SUM),
                ValueSpec(column="amount", agg=AggFunc.COUNT),
            ],
        )
        fx = next(row for row in res.rows if row["desk"] == "FX")
        assert fx["sum_amount"] == 5000.0
        assert fx["count_amount"] == 4  # 4 FX trades

    def test_aggregate_with_filter(self, provider: SQLAlchemyProvider) -> None:
        us_only = FilterGroup(conditions=[
            FilterCondition(column="region", operator=FilterOperator.EQ, value="US"),
        ])
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=[],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
            filter_group=us_only,
        )
        assert res.total == 2
        totals = {row["desk"]: row["sum_amount"] for row in res.rows}
        assert totals["FX"] == 3000.0
        assert totals["Rates"] == 5000.0

    def test_aggregate_with_limit_offset(self, provider: SQLAlchemyProvider) -> None:
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=[],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
            sort=[SortSpec(column="desk")],
            limit=1,
            offset=0,
        )
        # total counts all groups even though only one page is returned.
        assert res.total == 2
        assert len(res.rows) == 1
        assert res.rows[0]["desk"] == "FX"

    def test_distinct_count(self, provider: SQLAlchemyProvider) -> None:
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=[],
            values=[ValueSpec(column="region", agg=AggFunc.DISTINCT_COUNT)],
        )
        counts = {row["desk"]: row["distinctCount_region"] for row in res.rows}
        assert counts["FX"] == 3  # US, EU, APAC
        assert counts["Rates"] == 3  # US, EU, APAC
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# ---------------------------------------------------------------------------
|
|
148
|
+
# Pivot aggregation
|
|
149
|
+
# ---------------------------------------------------------------------------
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class TestPivotAggregate:
    """Aggregation with pivot (column) dimensions."""

    def test_basic_pivot(self, provider: SQLAlchemyProvider) -> None:
        """Pivot by region: desk × region → sum(amount)."""
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region"],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
        )
        assert res.total == 2  # FX, Rates
        # There should be row-dim columns + one per (region × measure)
        keys = [col.key for col in res.columns]
        assert "desk" in keys
        # Verify a few pivot columns exist
        assert any("US_sum_amount" in key for key in keys)
        assert any("EU_sum_amount" in key for key in keys)

        fx = next(row for row in res.rows if row["desk"] == "FX")
        assert fx["US_sum_amount"] == 3000.0  # 1000 + 2000
        assert fx["EU_sum_amount"] == 1500.0
        assert fx["APAC_sum_amount"] == 500.0

    def test_pivot_column_keys_are_strings(self, provider: SQLAlchemyProvider) -> None:
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region"],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
        )
        for col in res.columns:
            if not col.pivot_keys:
                continue
            assert all(isinstance(key, str) for key in col.pivot_keys)

    def test_multi_dim_pivot(self, provider: SQLAlchemyProvider) -> None:
        """Pivot by (region, instrument)."""
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region", "instrument"],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
        )
        assert res.total == 2
        # Find the FX-US-EURUSD cell
        fx = next(row for row in res.rows if row["desk"] == "FX")
        assert fx.get("US_EURUSD_sum_amount") == 1000.0

    def test_pivot_with_multiple_measures(self, provider: SQLAlchemyProvider) -> None:
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region"],
            values=[
                ValueSpec(column="amount", agg=AggFunc.SUM),
                ValueSpec(column="quantity", agg=AggFunc.SUM),
            ],
        )
        fx = next(row for row in res.rows if row["desk"] == "FX")
        assert fx["US_sum_amount"] == 3000.0
        assert fx["US_sum_quantity"] == 30  # 10 + 20

    def test_pivot_with_filter(self, provider: SQLAlchemyProvider) -> None:
        only_fx = FilterGroup(conditions=[
            FilterCondition(column="desk", operator=FilterOperator.EQ, value="FX"),
        ])
        res = provider.aggregate(
            group_by_rows=["desk"],
            group_by_cols=["region"],
            values=[ValueSpec(column="amount", agg=AggFunc.SUM)],
            filter_group=only_fx,
        )
        assert res.total == 1
        assert res.rows[0]["desk"] == "FX"
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
# ---------------------------------------------------------------------------
|
|
223
|
+
# Provider metadata
|
|
224
|
+
# ---------------------------------------------------------------------------
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
class TestProviderMeta:
    """Provider identity, column metadata, and capability flags."""

    def test_get_columns(self, provider: SQLAlchemyProvider) -> None:
        meta = provider.get_columns()
        assert len(meta) == 6
        names = {col.name for col in meta}
        assert "id" in names
        assert "desk" in names

    def test_capabilities(self, provider: SQLAlchemyProvider) -> None:
        assert provider.capabilities.pivot is True

    def test_identity(self, provider: SQLAlchemyProvider) -> None:
        assert provider.key == "trades"
        assert provider.label == "Trades"
|