iceaxe 0.2.3.dev1__tar.gz → 0.3.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/PKG-INFO +1 -1
- iceaxe-0.3.0/iceaxe/__tests__/helpers.py +263 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_comparison.py +84 -0
- iceaxe-0.3.0/iceaxe/__tests__/test_helpers.py +9 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_queries.py +15 -1
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_session.py +50 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/base.py +18 -2
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/comparison.py +28 -5
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/field.py +15 -3
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/queries.py +100 -15
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/db_memory_serializer.py +1 -6
- iceaxe-0.3.0/iceaxe/session_optimized.pyx +199 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/typing.py +5 -2
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/pyproject.toml +1 -1
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/setup.py +1 -1
- iceaxe-0.2.3.dev1/iceaxe/session_optimized.pyx +0 -102
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/LICENSE +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/README.md +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/build.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/.DS_Store +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/benchmarks/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/benchmarks/test_select.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/conf_models.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/conftest.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/conftest.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/test_action_sorter.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/test_generator.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/test_generics.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/mountaineer/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/mountaineer/dependencies/test_core.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_actions.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_cli.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_db_memory_serializer.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_db_serializer.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_db_stubs.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_base.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_field.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/functions.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/generics.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/io.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/logging.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/action_sorter.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/cli.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/client_io.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/generator.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/migration.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/migrator.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/cli.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/config.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/dependencies/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/dependencies/core.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/postgres.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/py.typed +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/queries_str.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/__init__.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/actions.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/cli.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/db_serializer.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/db_stubs.py +0 -0
- {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/session.py +0 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import inspect
|
|
3
|
+
import os
|
|
4
|
+
from contextlib import contextmanager
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from json import JSONDecodeError, dump as json_dump, loads as json_loads
|
|
7
|
+
from re import Pattern
|
|
8
|
+
from tempfile import NamedTemporaryFile, TemporaryDirectory
|
|
9
|
+
from textwrap import dedent
|
|
10
|
+
|
|
11
|
+
from pyright import run
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class PyrightDiagnostic:
|
|
16
|
+
file: str
|
|
17
|
+
severity: str
|
|
18
|
+
message: str
|
|
19
|
+
rule: str | None
|
|
20
|
+
line: int
|
|
21
|
+
column: int
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ExpectedPyrightError(Exception):
    """
    Raised by `pyright_raises` when Pyright does not report the expected error.

    The exception message names the expected rule (and line, if one was given)
    together with the diagnostics Pyright actually produced.
    """
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def get_imports_from_module(module_source: str) -> set[str]:
    """
    Extract all import statements from module source as normalized strings.

    `import a, b` is split into one statement per imported module, while
    `from x import a, b` stays a single statement. `as` aliases and the leading
    dots of relative imports are preserved, so each regenerated statement binds
    exactly the same names as the original one.

    Because `ast.walk` visits every node, imports nested inside functions or
    conditional blocks are collected as well.

    :param module_source: Python source code of the module to scan
    :returns: Set of import statements, one statement per string
    """

    def render_alias(alias: ast.alias) -> str:
        # Keep `as` aliases: code that refers to the alias would otherwise
        # reference an undefined name once the import is regenerated.
        return f"{alias.name} as {alias.asname}" if alias.asname else alias.name

    tree = ast.parse(module_source)
    imports: set[str] = set()

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                imports.add(f"import {render_alias(alias)}")
        elif isinstance(node, ast.ImportFrom):
            names = ", ".join(render_alias(alias) for alias in node.names)
            # `level` is the number of leading dots of a relative import, and
            # `module` is None for bare relative imports such as `from . import x`.
            source = "." * node.level + (node.module or "")
            imports.add(f"from {source} import {names}")

    return imports
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def strip_type_ignore(line: str) -> str:
    """
    Strip a type-checker suppression comment from a line while preserving the code.

    Both `# type: ignore` (optionally followed by `[rules]`) and Pyright's own
    `# pyright: ignore[...]` are removed together with everything after them,
    so the suppressed diagnostic becomes visible to Pyright again. Lines without
    such a comment are returned unchanged.

    The suppression marker itself is searched for, instead of splitting on the
    first `#`. This keeps `#` characters inside string literals intact, and it
    leaves alone unrelated comments that merely mention "type:" and "ignore".

    :param line: A single line of Python source
    :returns: The code before the suppression comment (trailing whitespace
        removed), or the original line if it has no suppression comment
    """
    import re

    suppression = re.search(r"#\s*(?:type|pyright):\s*ignore\b", line)
    if suppression is None:
        return line
    return line[: suppression.start()].rstrip()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _order_import_lines(imports: set[str]) -> list[str]:
    """
    Sort import statements deterministically, keeping `from __future__`
    imports first as Python requires.
    """
    # False sorts before True, so __future__ imports lead; the statement text
    # breaks ties so the generated module is stable across runs.
    return sorted(
        imports, key=lambda line: (not line.startswith("from __future__ "), line)
    )


def extract_current_function_code():
    """
    Extract the source of the calling test function as a standalone module.

    Walks up the call stack to the nearest frame whose function name starts with
    `test_`. It then returns that function's source, dedented and with type-ignore
    comments stripped via `strip_type_ignore`, prefixed by every import
    statement found in the test's module. This only works for functions in a
    pytest testing context that are prefixed with `test_`.

    :returns: Source code of the synthesized module: the imports, one blank
        line, then the test function
    :raises RuntimeError: If no `test_` frame or its module can be found
    """
    frame = inspect.currentframe()

    try:
        # Go up until we find the test function; workaround to not knowing the
        # entrypoint of our contextmanager at runtime
        while frame is not None and not frame.f_code.co_name.startswith("test_"):
            frame = frame.f_back
        if frame is None:
            raise RuntimeError("Could not find test function frame")

        # Source code of the function
        func_source = inspect.getsource(frame.f_code)

        # Source code of the larger test file, which contains the test function.
        # All the imports used by the test function should be within this file.
        module = inspect.getmodule(frame)
        if not module:
            raise RuntimeError("Could not find module for test function")
        module_source = inspect.getsource(module)

        # Postprocess the source code to build into a valid new module
        imports = get_imports_from_module(module_source)
        filtered_lines = [strip_type_ignore(line) for line in func_source.split("\n")]
        return (
            "\n".join(_order_import_lines(imports))
            + "\n\n"
            + dedent("\n".join(filtered_lines))
        )

    finally:
        # Frame objects reference their locals; drop our only frame reference
        # so no reference cycle keeps the call stack alive.
        del frame
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def create_pyright_config():
    """
    Build the pyright configuration used for `pyright_raises` checks.

    Strict mode is turned on so typing mistakes surface as errors. Noise that is
    unrelated to context-manager wrapped type checking is switched off: unused
    imports and variables, missing stubs, and "unknown" types inferred from
    unannotated code.

    :returns: A freshly built dict, ready to be serialized as pyrightconfig.json
    """
    return dict(
        include=["."],
        exclude=[],
        ignore=[],
        strict=[],
        typeCheckingMode="strict",
        reportUnusedImport=False,
        reportUnusedVariable=False,
        # Focus only on type checking
        reportOptionalMemberAccess=True,
        reportGeneralTypeIssues=True,
        reportPropertyTypeMismatch=True,
        reportFunctionMemberAccess=True,
        reportTypeCommentUsage=True,
        reportMissingTypeStubs=False,
        # Only report on intentional typehints, not on inferred values
        reportUnknownParameterType=False,
        reportUnknownVariableType=False,
        reportUnknownMemberType=False,
        reportUnknownArgumentType=False,
        reportMissingParameterType=False,
    )
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def run_pyright(file_path: str) -> list[PyrightDiagnostic]:
    """
    Run pyright on a file and return the diagnostics.

    The file is copied as `test.py` into a throwaway project directory, next to
    a `pyrightconfig.json` produced by `create_pyright_config`. Pyright is then
    invoked on that project with JSON output.

    :param file_path: Path of the Python file to analyze
    :returns: The parsed diagnostics, with 1-based line and column numbers
    :raises RuntimeError: If pyright cannot be run or its output cannot be
        interpreted; the underlying error is chained as ``__cause__``
    """
    import shutil

    try:
        with TemporaryDirectory() as temp_dir:
            # Create pyright config
            config_path = os.path.join(temp_dir, "pyrightconfig.json")
            with open(config_path, "w", encoding="utf-8") as f:
                json_dump(create_pyright_config(), f)

            # Copy the file to analyze into the project directory. A byte-for-byte
            # copy keeps the source encoding untouched.
            test_file = os.path.join(temp_dir, "test.py")
            shutil.copyfile(file_path, test_file)

            # Run pyright with the config
            result = run(
                "--project",
                temp_dir,
                "--outputjson",
                test_file,
                capture_output=True,
                text=True,
            )

            try:
                output = json_loads(result.stdout)
            except JSONDecodeError:
                print(f"Failed to parse pyright output: {result.stdout}")  # noqa: T201
                print(f"Stderr: {result.stderr}")  # noqa: T201
                raise

            if "generalDiagnostics" not in output:
                raise RuntimeError(
                    f"Unknown pyright output, missing generalDiagnostics: {output}"
                )

            diagnostics: list[PyrightDiagnostic] = []
            for diag in output["generalDiagnostics"]:
                # Be tolerant of a missing "range", which pyright may omit for
                # diagnostics without a source location; those get line/column 0.
                start = (diag.get("range") or {}).get("start") or {}
                diagnostics.append(
                    PyrightDiagnostic(
                        file=diag["file"],
                        severity=diag["severity"],
                        message=diag["message"],
                        rule=diag.get("rule"),
                        # Pyright positions are 0-based; convert to 1-based
                        line=start.get("line", -1) + 1,
                        column=start.get("character", -1) + 1,
                    )
                )

            return diagnostics

    except RuntimeError:
        # Already descriptive; don't bury it under a second RuntimeError
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to run pyright: {str(e)}") from e
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
@contextmanager
def pyright_raises(
    expected_rule: str,
    expected_line: int | None = None,
    matches: Pattern[str] | None = None,
):
    """
    Context manager that verifies code produces a specific Pyright error.

    On entry, the enclosing `test_` function is extracted, together with its
    module's imports, into a standalone file. After the `with` body has run,
    pyright is executed on that file and at least one diagnostic must match.

    Pyright analyzes the *whole* test function, not only the body of the `with`
    statement. A diagnostic caused by any other line of the test can therefore
    satisfy the check. Use distinctive `matches` patterns (or `expected_line`)
    so that each block asserts on its own diagnostic.

    :param expected_rule: The Pyright rule that should be violated
    :param expected_line: Optional 1-based line number where the error should
        occur. Lines are counted in the generated file: the collected import
        statements come first, then one blank line, then the test function.
        They are not line numbers of the original test module.
    :param matches: Optional compiled regex that must be found (via `search`)
        in the diagnostic message

    :raises ExpectedPyrightError: If Pyright doesn't produce the expected error
    """
    with TemporaryDirectory() as temp_dir:
        # A plain file inside a temporary directory, unlike an open
        # NamedTemporaryFile, can be reopened by run_pyright on every platform.
        temp_path = os.path.join(temp_dir, "source.py")

        # Extract the source code of the calling function
        source_code = extract_current_function_code()
        print(f"Running Pyright on:\n{source_code}")  # noqa: T201

        with open(temp_path, "w", encoding="utf-8") as temp_file:
            temp_file.write(source_code)

        # At runtime, our actual code is probably a no-op but we still let it run
        # inside the scope of the contextmanager
        yield

        # Run Pyright on the extracted source
        diagnostics = run_pyright(temp_path)

    # Check if any of the diagnostics match our expected error
    for diagnostic in diagnostics:
        if diagnostic.rule != expected_rule:
            continue
        if expected_line is not None and diagnostic.line != expected_line:
            continue
        if matches is not None and not matches.search(diagnostic.message):
            continue
        # Found matching error
        return

    # If we get here, we didn't find the expected error
    actual_errors = [
        f"{d.rule or 'unknown'} on line {d.line}: {d.message}" for d in diagnostics
    ]
    line_suffix = f" on line {expected_line}" if expected_line is not None else ""
    raise ExpectedPyrightError(
        f"Expected Pyright error {expected_rule}{line_suffix}"
        f" but got: {', '.join(actual_errors) if actual_errors else 'no errors'}"
    )
|
|
@@ -1,10 +1,14 @@
|
|
|
1
|
+
from re import compile as re_compile
|
|
1
2
|
from typing import Any
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
5
|
+
from typing_extensions import assert_type
|
|
4
6
|
|
|
7
|
+
from iceaxe.__tests__.helpers import pyright_raises
|
|
5
8
|
from iceaxe.base import TableBase
|
|
6
9
|
from iceaxe.comparison import ComparisonType, FieldComparison
|
|
7
10
|
from iceaxe.field import DBFieldClassDefinition, DBFieldInfo
|
|
11
|
+
from iceaxe.typing import column
|
|
8
12
|
|
|
9
13
|
|
|
10
14
|
def test_comparison_type_enum():
|
|
@@ -17,6 +21,9 @@ def test_comparison_type_enum():
|
|
|
17
21
|
assert ComparisonType.IN == "IN"
|
|
18
22
|
assert ComparisonType.NOT_IN == "NOT IN"
|
|
19
23
|
assert ComparisonType.LIKE == "LIKE"
|
|
24
|
+
assert ComparisonType.NOT_LIKE == "NOT LIKE"
|
|
25
|
+
assert ComparisonType.ILIKE == "ILIKE"
|
|
26
|
+
assert ComparisonType.NOT_ILIKE == "NOT ILIKE"
|
|
20
27
|
assert ComparisonType.IS == "IS"
|
|
21
28
|
assert ComparisonType.IS_NOT == "IS NOT"
|
|
22
29
|
|
|
@@ -158,3 +165,80 @@ def test_comparison_with_different_types(db_field: DBFieldClassDefinition, value
|
|
|
158
165
|
assert result.left == db_field
|
|
159
166
|
assert isinstance(result.comparison, ComparisonType)
|
|
160
167
|
assert result.right == value
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
#
|
|
171
|
+
# Typehinting
|
|
172
|
+
# The assert_type() checks are verified by the static typechecking we run
# over our codebase; the pyright_raises() blocks additionally invoke pyright
# during the pytest run to confirm that invalid usages are rejected.
|
|
174
|
+
#
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
def test_typehint_like():
    """
    LIKE / ILIKE comparisons are accepted on string columns and rejected by
    pyright on integer columns.

    pyright_raises analyzes this whole function at once, so each block matches
    the diagnostic for its own method name. Matching "ilike" in every block
    would let the first error satisfy all four checks.
    """

    class UserDemo(TableBase):
        id: int
        value_str: str
        value_int: int

    str_col = column(UserDemo.value_str)
    int_col = column(UserDemo.value_int)

    assert_type(str_col, DBFieldClassDefinition[str])
    assert_type(int_col, DBFieldClassDefinition[int])

    assert_type(str_col.ilike("test"), bool)
    assert_type(str_col.not_ilike("test"), bool)
    assert_type(str_col.like("test"), bool)
    assert_type(str_col.not_like("test"), bool)

    with pyright_raises(
        "reportAttributeAccessIssue",
        matches=re_compile('Cannot access attribute "ilike"'),
    ):
        int_col.ilike(5)  # type: ignore

    with pyright_raises(
        "reportAttributeAccessIssue",
        matches=re_compile('Cannot access attribute "not_ilike"'),
    ):
        int_col.not_ilike(5)  # type: ignore

    with pyright_raises(
        "reportAttributeAccessIssue",
        matches=re_compile('Cannot access attribute "like"'),
    ):
        int_col.like(5)  # type: ignore

    with pyright_raises(
        "reportAttributeAccessIssue",
        matches=re_compile('Cannot access attribute "not_like"'),
    ):
        int_col.not_like(5)  # type: ignore
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def test_typehint_in():
    """
    in_ / not_in accept sequences of the column's value type, and pyright
    rejects sequences containing other types.

    pyright_raises analyzes this whole function at once, so each block also
    matches the function name in the diagnostic. Otherwise the in_ error alone
    would satisfy the not_in check.
    """

    class UserDemo(TableBase):
        id: int
        value_str: str
        value_int: int

    str_col = column(UserDemo.value_str)
    int_col = column(UserDemo.value_int)

    assert_type(str_col.in_(["test"]), bool)
    assert_type(int_col.in_([5]), bool)

    assert_type(str_col.not_in(["test"]), bool)
    assert_type(int_col.not_in([5]), bool)

    with pyright_raises(
        "reportArgumentType",
        matches=re_compile('cannot be assigned to parameter "other".*in function "in_"'),
    ):
        str_col.in_(["test", 5])  # type: ignore

    with pyright_raises(
        "reportArgumentType",
        matches=re_compile(
            'cannot be assigned to parameter "other".*in function "not_in"'
        ),
    ):
        str_col.not_in(["test", 5])  # type: ignore
|
|
@@ -9,7 +9,11 @@ from iceaxe.queries import QueryBuilder, and_, or_, select
|
|
|
9
9
|
|
|
10
10
|
def test_select():
    """Selecting a whole table aliases every column as <table>_<column>."""
    expected_sql = (
        'SELECT "userdemo"."id" as "userdemo_id", "userdemo"."name" as '
        '"userdemo_name", "userdemo"."email" as "userdemo_email" FROM "userdemo"'
    )
    query = QueryBuilder().select(UserDemo)
    assert query.build() == (expected_sql, [])
|
|
13
17
|
|
|
14
18
|
|
|
15
19
|
def test_select_single_field():
|
|
@@ -263,3 +267,13 @@ def test_select_multiple_typehints():
|
|
|
263
267
|
query = select((UserDemo, UserDemo.id, UserDemo.name))
|
|
264
268
|
if TYPE_CHECKING:
|
|
265
269
|
_: QueryBuilder[tuple[UserDemo, int, str], Literal["SELECT"]] = query
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def test_allow_branching():
    """
    Queries derived from the same base query stay independent: deriving a
    second query must not change the limit of the first.
    """
    base_query = select(UserDemo)
    derived = {limit: base_query.limit(limit) for limit in (1, 2)}

    for limit, query in derived.items():
        assert query.limit_value == limit
|
|
@@ -281,6 +281,30 @@ async def test_select_join(db_connection: DBConnection):
|
|
|
281
281
|
]
|
|
282
282
|
|
|
283
283
|
|
|
284
|
+
@pytest.mark.asyncio
async def test_select_join_multiple_tables(db_connection: DBConnection):
    """Selecting two joined tables yields one (artifact, user) tuple per row."""
    user = UserDemo(name="John Doe", email="john@example.com")
    await db_connection.insert([user])
    assert user.id is not None

    artifact = ArtifactDemo(title="Artifact 1", user_id=user.id)
    await db_connection.insert([artifact])

    query = QueryBuilder().select((ArtifactDemo, UserDemo))
    query = query.join(UserDemo, UserDemo.id == ArtifactDemo.user_id)
    query = query.where(UserDemo.name == "John Doe")
    rows = await db_connection.exec(query)

    expected_artifact = ArtifactDemo(id=artifact.id, title="Artifact 1", user_id=user.id)
    expected_user = UserDemo(id=user.id, name="John Doe", email="john@example.com")
    assert rows == [(expected_artifact, expected_user)]
|
|
306
|
+
|
|
307
|
+
|
|
284
308
|
@pytest.mark.asyncio
|
|
285
309
|
async def test_select_with_limit_and_offset(db_connection: DBConnection):
|
|
286
310
|
users = [
|
|
@@ -418,6 +442,32 @@ async def test_select_with_left_join(db_connection: DBConnection):
|
|
|
418
442
|
assert result[1] == ("John", 2)
|
|
419
443
|
|
|
420
444
|
|
|
445
|
+
@pytest.mark.asyncio
async def test_select_with_left_join_object(db_connection: DBConnection):
    """
    A LEFT JOIN that selects full models yields None for the right-hand model
    when no row matches.
    """
    users = [
        UserDemo(name="John", email="john@example.com"),
        UserDemo(name="Jane", email="jane@example.com"),
    ]
    await db_connection.insert(users)

    posts = [
        ArtifactDemo(title="John's Post", user_id=users[0].id),
        ArtifactDemo(title="Another Post", user_id=users[0].id),
    ]
    await db_connection.insert(posts)

    query = (
        QueryBuilder()
        .select((UserDemo, ArtifactDemo))
        .join(ArtifactDemo, UserDemo.id == ArtifactDemo.user_id, "LEFT")
    )
    result = await db_connection.exec(query)

    # The query has no ORDER BY, so SQL does not guarantee the row order;
    # compare as an unordered collection instead. The expected rows are all
    # distinct, so equal lengths plus membership of every expected row means
    # the result holds exactly these rows.
    expected_rows = [
        (users[0], posts[0]),
        (users[0], posts[1]),
        (users[1], None),
    ]
    assert len(result) == len(expected_rows)
    for row in expected_rows:
        assert row in result
|
|
469
|
+
|
|
470
|
+
|
|
421
471
|
# @pytest.mark.asyncio
|
|
422
472
|
# async def test_select_with_subquery(db_connection: DBConnection):
|
|
423
473
|
# users = [
|
|
@@ -34,8 +34,8 @@ class DBModelMetaclass(_model_construction.ModelMetaclass):
|
|
|
34
34
|
mcs._cached_args[cls] = raw_kwargs
|
|
35
35
|
|
|
36
36
|
# If we have already set the class's fields, we should wrap them
|
|
37
|
-
if hasattr(cls, "
|
|
38
|
-
cls.
|
|
37
|
+
if hasattr(cls, "__pydantic_fields__"):
|
|
38
|
+
cls.__pydantic_fields__ = {
|
|
39
39
|
field: info
|
|
40
40
|
if isinstance(info, DBFieldInfo)
|
|
41
41
|
else DBFieldInfo.extend_field(
|
|
@@ -98,6 +98,14 @@ class DBModelMetaclass(_model_construction.ModelMetaclass):
|
|
|
98
98
|
|
|
99
99
|
return default
|
|
100
100
|
|
|
101
|
+
@property
def model_fields(self) -> dict[str, DBFieldInfo]:  # type: ignore
    """
    Field definitions of the model class, typed as DBFieldInfo.

    Re-declared on the custom metaclass so that clients accessing
    Model.model_fields see this signature instead of the parent metaclass's.
    It overrides the ClassVar typehint that's placed in TableBase itself. The
    value is whatever the parent metaclass property returns.
    """
    fields = super().model_fields  # type: ignore
    return fields
|
|
108
|
+
|
|
101
109
|
|
|
102
110
|
class UniqueConstraint(BaseModel):
|
|
103
111
|
columns: list[str]
|
|
@@ -136,3 +144,11 @@ class TableBase(BaseModel, metaclass=DBModelMetaclass):
|
|
|
136
144
|
if cls.table_name == PydanticUndefined:
|
|
137
145
|
return cls.__name__.lower()
|
|
138
146
|
return cls.table_name
|
|
147
|
+
|
|
148
|
+
@classmethod
def get_client_fields(cls):
    """
    Return the model's fields, minus any whose name is in INTERNAL_TABLE_FIELDS.

    :returns: Mapping of field name to field info, in the iteration order of
        `model_fields`
    """
    client_fields: dict[str, DBFieldInfo] = {}
    for name, info in cls.model_fields.items():
        if name in INTERNAL_TABLE_FIELDS:
            continue
        client_fields[name] = info
    return client_fields
|
|
@@ -1,12 +1,13 @@
|
|
|
1
1
|
from abc import ABC, abstractmethod
|
|
2
2
|
from dataclasses import dataclass
|
|
3
3
|
from enum import StrEnum
|
|
4
|
-
from typing import Any, Generic, Self, TypeVar
|
|
4
|
+
from typing import Any, Generic, Self, Sequence, TypeVar
|
|
5
5
|
|
|
6
6
|
from iceaxe.queries_str import QueryElementBase, QueryLiteral
|
|
7
7
|
from iceaxe.typing import is_column, is_comparison, is_comparison_group
|
|
8
8
|
|
|
9
9
|
T = TypeVar("T", bound="ComparisonBase")
|
|
10
|
+
J = TypeVar("J")
|
|
10
11
|
|
|
11
12
|
|
|
12
13
|
class ComparisonType(StrEnum):
|
|
@@ -18,7 +19,12 @@ class ComparisonType(StrEnum):
|
|
|
18
19
|
GE = ">="
|
|
19
20
|
IN = "IN"
|
|
20
21
|
NOT_IN = "NOT IN"
|
|
22
|
+
|
|
21
23
|
LIKE = "LIKE"
|
|
24
|
+
NOT_LIKE = "NOT LIKE"
|
|
25
|
+
ILIKE = "ILIKE"
|
|
26
|
+
NOT_ILIKE = "NOT ILIKE"
|
|
27
|
+
|
|
22
28
|
IS = "IS"
|
|
23
29
|
IS_NOT = "IS NOT"
|
|
24
30
|
|
|
@@ -95,7 +101,7 @@ class FieldComparisonGroup:
|
|
|
95
101
|
return QueryLiteral(queries), all_variables
|
|
96
102
|
|
|
97
103
|
|
|
98
|
-
class ComparisonBase(ABC):
|
|
104
|
+
class ComparisonBase(ABC, Generic[J]):
|
|
99
105
|
def __eq__(self, other): # type: ignore
|
|
100
106
|
if other is None:
|
|
101
107
|
return self._compare(ComparisonType.IS, None)
|
|
@@ -118,15 +124,32 @@ class ComparisonBase(ABC):
|
|
|
118
124
|
def __ge__(self, other):
    """Build a `>=` comparison between this operand and `other`."""
    return self._compare(comparison=ComparisonType.GE, other=other)

def in_(self, other: Sequence[J]) -> bool:
    """
    Build an `IN` comparison against a sequence of candidate values.

    Annotated as returning `bool` (hence the `type: ignore`); the runtime value
    is the FieldComparison produced by `_compare`. A plain `str` also satisfies
    `Sequence[str]`, so the type checker cannot reject a bare string here.
    """
    return self._compare(comparison=ComparisonType.IN, other=other)  # type: ignore

def not_in(self, other: Sequence[J]) -> bool:
    """
    Build a `NOT IN` comparison against a sequence of values.

    Like `in_`, annotated as `bool` while returning a FieldComparison.
    """
    return self._compare(comparison=ComparisonType.NOT_IN, other=other)  # type: ignore

def like(
    self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
) -> bool:
    """Build a `LIKE` pattern comparison; only typed for string operands."""
    return self._compare(comparison=ComparisonType.LIKE, other=other)  # type: ignore

def not_like(
    self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
) -> bool:
    """Build a `NOT LIKE` pattern comparison; only typed for string operands."""
    return self._compare(comparison=ComparisonType.NOT_LIKE, other=other)  # type: ignore

def ilike(
    self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
) -> bool:
    """Build an `ILIKE` pattern comparison; only typed for string operands."""
    return self._compare(comparison=ComparisonType.ILIKE, other=other)  # type: ignore

def not_ilike(
    self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
) -> bool:
    """Build a `NOT ILIKE` pattern comparison; only typed for string operands."""
    return self._compare(comparison=ComparisonType.NOT_ILIKE, other=other)  # type: ignore

def _compare(self, comparison: ComparisonType, other: Any) -> FieldComparison[Self]:
    """Wrap this operand, the operator and `other` into a FieldComparison."""
    return FieldComparison(comparison=comparison, left=self, right=other)
|
|
132
155
|
|
|
@@ -4,8 +4,10 @@ from typing import (
|
|
|
4
4
|
Any,
|
|
5
5
|
Callable,
|
|
6
6
|
Concatenate,
|
|
7
|
+
Generic,
|
|
7
8
|
ParamSpec,
|
|
8
9
|
Type,
|
|
10
|
+
TypeVar,
|
|
9
11
|
Unpack,
|
|
10
12
|
cast,
|
|
11
13
|
)
|
|
@@ -23,6 +25,8 @@ if TYPE_CHECKING:
|
|
|
23
25
|
|
|
24
26
|
P = ParamSpec("P")
|
|
25
27
|
|
|
28
|
+
_Unset: Any = PydanticUndefined
|
|
29
|
+
|
|
26
30
|
|
|
27
31
|
class DBFieldInputs(_FieldInfoInputs, total=False):
|
|
28
32
|
primary_key: bool
|
|
@@ -114,11 +118,16 @@ def __get_db_field(_: Callable[Concatenate[Any, P], Any] = PydanticField): # ty
|
|
|
114
118
|
index: bool = False,
|
|
115
119
|
check_expression: str | None = None,
|
|
116
120
|
is_json: bool = False,
|
|
117
|
-
default: Any =
|
|
121
|
+
default: Any = _Unset,
|
|
122
|
+
default_factory: (
|
|
123
|
+
Callable[[], Any] | Callable[[dict[str, Any]], Any] | None
|
|
124
|
+
) = _Unset,
|
|
118
125
|
*args: P.args,
|
|
119
126
|
**kwargs: P.kwargs,
|
|
120
127
|
):
|
|
121
|
-
raw_field = PydanticField(
|
|
128
|
+
raw_field = PydanticField(
|
|
129
|
+
default=default, default_factory=default_factory, **kwargs
|
|
130
|
+
) # type: ignore
|
|
122
131
|
|
|
123
132
|
# The Any request is required for us to be able to assign fields to any
|
|
124
133
|
# arbitrary type, like `value: str = Field()`
|
|
@@ -139,7 +148,10 @@ def __get_db_field(_: Callable[Concatenate[Any, P], Any] = PydanticField): # ty
|
|
|
139
148
|
return func
|
|
140
149
|
|
|
141
150
|
|
|
142
|
-
|
|
151
|
+
T = TypeVar("T")
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class DBFieldClassDefinition(Generic[T], ComparisonBase[T]):
|
|
143
155
|
root_model: Type["TableBase"]
|
|
144
156
|
key: str
|
|
145
157
|
field_definition: DBFieldInfo
|