ibor-audit-tool 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ibor_audit_tool-0.1.0/.gitignore +11 -0
- ibor_audit_tool-0.1.0/.python-version +1 -0
- ibor_audit_tool-0.1.0/PKG-INFO +9 -0
- ibor_audit_tool-0.1.0/README.md +0 -0
- ibor_audit_tool-0.1.0/pyproject.toml +33 -0
- ibor_audit_tool-0.1.0/sample/config.yml +32 -0
- ibor_audit_tool-0.1.0/src/audit_tool/__init__.py +19 -0
- ibor_audit_tool-0.1.0/src/audit_tool/__main__.py +5 -0
- ibor_audit_tool-0.1.0/src/audit_tool/cli.py +33 -0
- ibor_audit_tool-0.1.0/src/audit_tool/compare.py +156 -0
- ibor_audit_tool-0.1.0/src/audit_tool/config.py +149 -0
- ibor_audit_tool-0.1.0/src/audit_tool/errors.py +10 -0
- ibor_audit_tool-0.1.0/src/audit_tool/factor.py +55 -0
- ibor_audit_tool-0.1.0/src/audit_tool/report.py +509 -0
- ibor_audit_tool-0.1.0/tests/test_cli.py +122 -0
- ibor_audit_tool-0.1.0/tests/test_compare.py +109 -0
- ibor_audit_tool-0.1.0/tests/test_config.py +141 -0
- ibor_audit_tool-0.1.0/tests/test_report.py +141 -0
- ibor_audit_tool-0.1.0/uv.lock +350 -0
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.12
|
|
File without changes
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "ibor-audit-tool"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Add your description here"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.12"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"psycopg2-binary>=2.9",
|
|
9
|
+
"pyyaml>=6.0",
|
|
10
|
+
"sqlalchemy>=2.0",
|
|
11
|
+
"typer>=0.12",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[project.scripts]
|
|
15
|
+
xaudit= "audit_tool.cli:main"
|
|
16
|
+
|
|
17
|
+
[[tool.uv.index]]
|
|
18
|
+
name = "pypi"
|
|
19
|
+
url = "https://pypi.org/simple"
|
|
20
|
+
publish-url = "https://upload.pypi.org/legacy/"
|
|
21
|
+
explicit = true
|
|
22
|
+
|
|
23
|
+
[tool.hatch.build.targets.wheel]
|
|
24
|
+
packages = ["src/audit_tool"]
|
|
25
|
+
|
|
26
|
+
[dependency-groups]
|
|
27
|
+
dev = [
|
|
28
|
+
"pytest>=8.0",
|
|
29
|
+
]
|
|
30
|
+
|
|
31
|
+
[build-system]
|
|
32
|
+
requires = ["hatchling"]
|
|
33
|
+
build-backend = "hatchling.build"
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
connections:
|
|
2
|
+
- name: FA
|
|
3
|
+
url: postgresql://ibor_user:ibor_user@localhost:5432/ibor-demo
|
|
4
|
+
- name: IBOR
|
|
5
|
+
url: postgresql://ibor_user:ibor_user@localhost:5432/ibor-demo
|
|
6
|
+
|
|
7
|
+
compares:
|
|
8
|
+
- name: total_pnl
|
|
9
|
+
mode: full
|
|
10
|
+
left:
|
|
11
|
+
sql: |
|
|
12
|
+
select
|
|
13
|
+
cast(replace(enddate,'-','') as int) as trade_date,
|
|
14
|
+
concat(t1.stkid,'.','HK') as secu_code,
|
|
15
|
+
'XIR_'||book as portfolio_id,
|
|
16
|
+
reportpl as ytd_total_pnl
|
|
17
|
+
from ods_faas.bd_ods_faas_glhs_bond_rpt_nxt_his t1
|
|
18
|
+
left join irmp.dim_dict_code_mapping t5
|
|
19
|
+
on t5.src_sys_name = 'PG'
|
|
20
|
+
and t5.src_cls_code = 'mkt_cd'
|
|
21
|
+
and t5.src_cls_cdval=t1.exch_id
|
|
22
|
+
and t5.valid_flag = '1'
|
|
23
|
+
where enddate='2026-06-04'
|
|
24
|
+
right:
|
|
25
|
+
sql: |
|
|
26
|
+
select
|
|
27
|
+
trade_date,
|
|
28
|
+
portfolio_id,
|
|
29
|
+
secu_code,
|
|
30
|
+
ytd_total_pnl
|
|
31
|
+
from irmp.ibor_bond_position
|
|
32
|
+
where secu_code like '%HK' and trade_date =20260604 and calc_basis ='1' and invest_type ='1'
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""Audit tool command-line application."""
|
|
2
|
+
|
|
3
|
+
from audit_tool.compare import Compare, CompareResult
|
|
4
|
+
from audit_tool.factor import Connection, Factor, SQLFactor, TableData
|
|
5
|
+
from audit_tool.report import CompareReportItem, ReportResult
|
|
6
|
+
|
|
7
|
+
__version__ = "0.1.0"
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"Compare",
|
|
11
|
+
"CompareResult",
|
|
12
|
+
"Connection",
|
|
13
|
+
"Factor",
|
|
14
|
+
"CompareReportItem",
|
|
15
|
+
"ReportResult",
|
|
16
|
+
"SQLFactor",
|
|
17
|
+
"TableData",
|
|
18
|
+
"__version__",
|
|
19
|
+
]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import typer
|
|
4
|
+
|
|
5
|
+
from audit_tool.errors import ConfigError, FactorLoadError
|
|
6
|
+
from audit_tool.report import run_report, write_html_report
|
|
7
|
+
|
|
8
|
+
app = typer.Typer()
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@app.command()
|
|
12
|
+
def run(
|
|
13
|
+
config_path: Path,
|
|
14
|
+
output: Path = typer.Option(Path("audit-report.html"), "--output", "-o"),
|
|
15
|
+
) -> None:
|
|
16
|
+
try:
|
|
17
|
+
report = run_report(config_path)
|
|
18
|
+
write_html_report(report, output)
|
|
19
|
+
except ConfigError as exc:
|
|
20
|
+
typer.echo(str(exc), err=True)
|
|
21
|
+
raise typer.Exit(2) from exc
|
|
22
|
+
except FactorLoadError as exc:
|
|
23
|
+
typer.echo(str(exc), err=True)
|
|
24
|
+
raise typer.Exit(2) from exc
|
|
25
|
+
|
|
26
|
+
typer.echo(
|
|
27
|
+
f"wrote report to {output} "
|
|
28
|
+
f"({report.total_count} total, {report.passed_count} passed, {report.failed_count} failed)"
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def main() -> None:
|
|
33
|
+
app()
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import Counter, defaultdict
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from audit_tool.errors import ConfigError
|
|
8
|
+
from audit_tool.factor import Factor, TableData
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
|
|
12
|
+
class CompareResult:
|
|
13
|
+
passed: bool
|
|
14
|
+
message: str
|
|
15
|
+
columns: tuple[str, ...] = field(default_factory=tuple)
|
|
16
|
+
diff_summary: dict[str, Any] = field(default_factory=dict)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class Compare:
|
|
20
|
+
def __init__(
|
|
21
|
+
self,
|
|
22
|
+
name: str,
|
|
23
|
+
left: Factor,
|
|
24
|
+
right: Factor,
|
|
25
|
+
*,
|
|
26
|
+
mode: str = "full",
|
|
27
|
+
) -> None:
|
|
28
|
+
self.name = name
|
|
29
|
+
self.left = left
|
|
30
|
+
self.right = right
|
|
31
|
+
self.mode = mode
|
|
32
|
+
|
|
33
|
+
def run(self) -> CompareResult:
|
|
34
|
+
if self.mode != "full":
|
|
35
|
+
raise ConfigError(f"unsupported compare mode {self.mode!r}")
|
|
36
|
+
|
|
37
|
+
left_data = self.left.load()
|
|
38
|
+
right_data = self.right.load()
|
|
39
|
+
return compare_full(
|
|
40
|
+
left_data,
|
|
41
|
+
right_data,
|
|
42
|
+
left_name=_factor_result_name(self.left),
|
|
43
|
+
right_name=_factor_result_name(self.right),
|
|
44
|
+
)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def compare_full(
|
|
48
|
+
left: TableData,
|
|
49
|
+
right: TableData,
|
|
50
|
+
*,
|
|
51
|
+
left_name: str = "left",
|
|
52
|
+
right_name: str = "right",
|
|
53
|
+
) -> CompareResult:
|
|
54
|
+
if Counter(left.columns) != Counter(right.columns):
|
|
55
|
+
return CompareResult(
|
|
56
|
+
passed=False,
|
|
57
|
+
message="columns differ",
|
|
58
|
+
columns=left.columns,
|
|
59
|
+
diff_summary={
|
|
60
|
+
f"{left_name}_columns": left.columns,
|
|
61
|
+
f"{right_name}_columns": right.columns,
|
|
62
|
+
},
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
right_column_indexes = _column_indexes(right.columns, left.columns)
|
|
66
|
+
left_rows = Counter(_normalize_row(row) for row in left.rows)
|
|
67
|
+
right_rows = Counter(
|
|
68
|
+
_normalize_row(_reorder_row(row, right_column_indexes))
|
|
69
|
+
for row in right.rows
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
if left_rows == right_rows:
|
|
73
|
+
return CompareResult(
|
|
74
|
+
passed=True,
|
|
75
|
+
message="tables match",
|
|
76
|
+
columns=left.columns,
|
|
77
|
+
diff_summary={
|
|
78
|
+
"columns": left.columns,
|
|
79
|
+
"row_count": len(left.rows),
|
|
80
|
+
},
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
left_only = left_rows - right_rows
|
|
84
|
+
right_only = right_rows - left_rows
|
|
85
|
+
return CompareResult(
|
|
86
|
+
passed=False,
|
|
87
|
+
message="rows differ",
|
|
88
|
+
columns=left.columns,
|
|
89
|
+
diff_summary={
|
|
90
|
+
f"{left_name}_row_count": len(left.rows),
|
|
91
|
+
f"{right_name}_row_count": len(right.rows),
|
|
92
|
+
f"{left_name}_only_count": left_only.total(),
|
|
93
|
+
f"{right_name}_only_count": right_only.total(),
|
|
94
|
+
f"{left_name}_only_sample": _sample_rows(left_only),
|
|
95
|
+
f"{right_name}_only_sample": _sample_rows(right_only),
|
|
96
|
+
},
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def _factor_result_name(factor: Factor) -> str:
|
|
101
|
+
connection = getattr(factor, "connection", None)
|
|
102
|
+
if connection is not None:
|
|
103
|
+
return str(connection.name)
|
|
104
|
+
return factor.name
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def _normalize_row(row: tuple[Any, ...]) -> tuple[Any, ...]:
|
|
108
|
+
return tuple(_normalize_value(value) for value in row)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _column_indexes(
|
|
112
|
+
source_columns: tuple[str, ...],
|
|
113
|
+
target_columns: tuple[str, ...],
|
|
114
|
+
) -> tuple[int, ...]:
|
|
115
|
+
source_indexes: dict[tuple[str, int], int] = {}
|
|
116
|
+
occurrences: defaultdict[str, int] = defaultdict(int)
|
|
117
|
+
for index, column in enumerate(source_columns):
|
|
118
|
+
occurrence = occurrences[column]
|
|
119
|
+
source_indexes[(column, occurrence)] = index
|
|
120
|
+
occurrences[column] += 1
|
|
121
|
+
|
|
122
|
+
target_indexes: list[int] = []
|
|
123
|
+
occurrences.clear()
|
|
124
|
+
for column in target_columns:
|
|
125
|
+
occurrence = occurrences[column]
|
|
126
|
+
target_indexes.append(source_indexes[(column, occurrence)])
|
|
127
|
+
occurrences[column] += 1
|
|
128
|
+
return tuple(target_indexes)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _reorder_row(row: tuple[Any, ...], indexes: tuple[int, ...]) -> tuple[Any, ...]:
|
|
132
|
+
return tuple(row[index] for index in indexes)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _normalize_value(value: Any) -> Any:
|
|
136
|
+
if isinstance(value, dict):
|
|
137
|
+
return tuple(
|
|
138
|
+
sorted(
|
|
139
|
+
((key, _normalize_value(item)) for key, item in value.items()),
|
|
140
|
+
key=repr,
|
|
141
|
+
)
|
|
142
|
+
)
|
|
143
|
+
if isinstance(value, list | tuple):
|
|
144
|
+
return tuple(_normalize_value(item) for item in value)
|
|
145
|
+
if isinstance(value, set):
|
|
146
|
+
return tuple(sorted((_normalize_value(item) for item in value), key=repr))
|
|
147
|
+
return value
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _sample_rows(rows: Counter[tuple[Any, ...]], limit: int = 5) -> list[tuple[Any, ...]]:
|
|
151
|
+
sample: list[tuple[Any, ...]] = []
|
|
152
|
+
for row, count in rows.items():
|
|
153
|
+
sample.extend([row] * min(count, limit - len(sample)))
|
|
154
|
+
if len(sample) >= limit:
|
|
155
|
+
break
|
|
156
|
+
return sample
|
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import re
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import yaml
|
|
8
|
+
|
|
9
|
+
from audit_tool.compare import Compare
|
|
10
|
+
from audit_tool.errors import ConfigError
|
|
11
|
+
from audit_tool.factor import Connection, SQLFactor
|
|
12
|
+
|
|
13
|
+
DEFAULT_SIDE_CONNECTIONS = {
|
|
14
|
+
"left": "FA",
|
|
15
|
+
"right": "IBOR",
|
|
16
|
+
}
|
|
17
|
+
ENV_VAR_REFERENCE = re.compile(
|
|
18
|
+
r"(\$\{[A-Za-z_][A-Za-z0-9_]*\}|\$[A-Z_][A-Z0-9_]*|^[A-Z_][A-Z0-9_]*$)"
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def load_config(path: Path) -> dict[str, Compare]:
|
|
23
|
+
raw_config = _read_yaml(path)
|
|
24
|
+
connections = _build_connections(raw_config.get("connections"))
|
|
25
|
+
compare_configs = _index_named_items(raw_config.get("compares"), "compares")
|
|
26
|
+
|
|
27
|
+
return {
|
|
28
|
+
name: _build_compare(
|
|
29
|
+
name,
|
|
30
|
+
_require_mapping(config, f"compares.{name}"),
|
|
31
|
+
connections,
|
|
32
|
+
)
|
|
33
|
+
for name, config in compare_configs.items()
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def _read_yaml(path: Path) -> dict[str, Any]:
|
|
38
|
+
try:
|
|
39
|
+
with path.open("r", encoding="utf-8") as file:
|
|
40
|
+
data = yaml.safe_load(file)
|
|
41
|
+
except OSError as exc:
|
|
42
|
+
raise ConfigError(f"failed to read config {str(path)!r}: {exc}") from exc
|
|
43
|
+
|
|
44
|
+
return _require_mapping(data, "config")
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _build_compare(
|
|
48
|
+
name: str,
|
|
49
|
+
config: dict[str, Any],
|
|
50
|
+
connections: dict[str, Connection],
|
|
51
|
+
) -> Compare:
|
|
52
|
+
mode = config.get("mode", "full")
|
|
53
|
+
if mode != "full":
|
|
54
|
+
raise ConfigError(f"compares.{name}.mode must be 'full'")
|
|
55
|
+
|
|
56
|
+
left = _build_compare_side(name, "left", config.get("left"), connections)
|
|
57
|
+
right = _build_compare_side(name, "right", config.get("right"), connections)
|
|
58
|
+
|
|
59
|
+
return Compare(name=name, left=left, right=right, mode=mode)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _build_compare_side(
|
|
63
|
+
compare_name: str,
|
|
64
|
+
side_name: str,
|
|
65
|
+
side_value: Any,
|
|
66
|
+
connections: dict[str, Connection],
|
|
67
|
+
) -> SQLFactor:
|
|
68
|
+
side_config = _require_mapping(side_value, f"compares.{compare_name}.{side_name}")
|
|
69
|
+
connection = _resolve_connection(
|
|
70
|
+
side_name,
|
|
71
|
+
side_config,
|
|
72
|
+
connections,
|
|
73
|
+
)
|
|
74
|
+
sql = _require_string(side_config.get("sql"), f"compares.{compare_name}.{side_name}.sql")
|
|
75
|
+
|
|
76
|
+
return SQLFactor(
|
|
77
|
+
name=f"{compare_name}.{side_name}",
|
|
78
|
+
sql=sql,
|
|
79
|
+
connection=connection,
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _resolve_connection(
|
|
84
|
+
side_name: str,
|
|
85
|
+
side_config: dict[str, Any],
|
|
86
|
+
connections: dict[str, Connection],
|
|
87
|
+
) -> Connection:
|
|
88
|
+
connection_name = _resolve_connection_name(side_name, side_config)
|
|
89
|
+
try:
|
|
90
|
+
return connections[connection_name]
|
|
91
|
+
except KeyError as exc:
|
|
92
|
+
raise ConfigError(f"connections.{connection_name} must be configured") from exc
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _resolve_connection_name(side_name: str, side_config: dict[str, Any]) -> str:
|
|
96
|
+
connection_name = side_config.get("connection")
|
|
97
|
+
if connection_name is None:
|
|
98
|
+
return DEFAULT_SIDE_CONNECTIONS[side_name]
|
|
99
|
+
return _require_string(connection_name, f"{side_name}.connection")
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _build_connections(value: Any) -> dict[str, Connection]:
|
|
103
|
+
connection_configs = _index_named_items(value, "connections")
|
|
104
|
+
return {
|
|
105
|
+
name: _build_connection(name, config)
|
|
106
|
+
for name, config in connection_configs.items()
|
|
107
|
+
}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _build_connection(name: str, config: dict[str, Any]) -> Connection:
|
|
111
|
+
url = _require_string(config.get("url"), f"connections.{name}.url")
|
|
112
|
+
_validate_connection_url(name, url)
|
|
113
|
+
return Connection(name=name, url=url)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _validate_connection_url(name: str, url: str) -> None:
|
|
117
|
+
if ENV_VAR_REFERENCE.search(url):
|
|
118
|
+
raise ConfigError(
|
|
119
|
+
f"connections.{name}.url must be a literal database URL, not an environment variable"
|
|
120
|
+
)
|
|
121
|
+
|
|
122
|
+
if "://" not in url:
|
|
123
|
+
raise ConfigError(f"connections.{name}.url must be a literal database URL")
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _require_mapping(value: Any, name: str) -> dict[str, Any]:
|
|
127
|
+
if not isinstance(value, dict):
|
|
128
|
+
raise ConfigError(f"{name} must be a mapping")
|
|
129
|
+
return value
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _index_named_items(value: Any, name: str) -> dict[str, dict[str, Any]]:
|
|
133
|
+
if not isinstance(value, list):
|
|
134
|
+
raise ConfigError(f"{name} must be a list")
|
|
135
|
+
|
|
136
|
+
items: dict[str, dict[str, Any]] = {}
|
|
137
|
+
for index, item in enumerate(value):
|
|
138
|
+
item_config = _require_mapping(item, f"{name}[{index}]")
|
|
139
|
+
item_name = _require_string(item_config.get("name"), f"{name}[{index}].name")
|
|
140
|
+
if item_name in items:
|
|
141
|
+
raise ConfigError(f"{name} contains duplicate name {item_name!r}")
|
|
142
|
+
items[item_name] = item_config
|
|
143
|
+
return items
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _require_string(value: Any, name: str) -> str:
|
|
147
|
+
if not isinstance(value, str) or not value:
|
|
148
|
+
raise ConfigError(f"{name} must be a non-empty string")
|
|
149
|
+
return value
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
class AuditToolError(Exception):
|
|
2
|
+
"""Base exception for audit-tool errors."""
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class ConfigError(AuditToolError):
|
|
6
|
+
"""Raised when the audit configuration is invalid."""
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class FactorLoadError(AuditToolError):
|
|
10
|
+
"""Raised when a factor cannot load its data."""
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from sqlalchemy import create_engine, text
|
|
8
|
+
from sqlalchemy.exc import SQLAlchemyError
|
|
9
|
+
|
|
10
|
+
from audit_tool.errors import FactorLoadError
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass(frozen=True)
|
|
14
|
+
class TableData:
|
|
15
|
+
columns: tuple[str, ...]
|
|
16
|
+
rows: tuple[tuple[Any, ...], ...]
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass(frozen=True)
|
|
20
|
+
class Connection:
|
|
21
|
+
name: str
|
|
22
|
+
url: str
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Factor(ABC):
|
|
26
|
+
def __init__(self, name: str) -> None:
|
|
27
|
+
self.name = name
|
|
28
|
+
|
|
29
|
+
@abstractmethod
|
|
30
|
+
def load(self) -> TableData:
|
|
31
|
+
"""Load this factor as table-shaped data."""
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class SQLFactor(Factor):
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
name: str,
|
|
38
|
+
sql: str,
|
|
39
|
+
connection: Connection,
|
|
40
|
+
) -> None:
|
|
41
|
+
super().__init__(name)
|
|
42
|
+
self.sql = sql
|
|
43
|
+
self.connection = connection
|
|
44
|
+
|
|
45
|
+
def load(self) -> TableData:
|
|
46
|
+
try:
|
|
47
|
+
engine = create_engine(self.connection.url)
|
|
48
|
+
with engine.connect() as connection:
|
|
49
|
+
result = connection.execute(text(self.sql))
|
|
50
|
+
columns = tuple(str(column) for column in result.keys())
|
|
51
|
+
rows = tuple(tuple(row) for row in result.fetchall())
|
|
52
|
+
except SQLAlchemyError as exc:
|
|
53
|
+
raise FactorLoadError(f"failed to load SQL factor {self.name!r}: {exc}") from exc
|
|
54
|
+
|
|
55
|
+
return TableData(columns=columns, rows=rows)
|