iceaxe 0.2.3.dev1__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/PKG-INFO +1 -1
  2. iceaxe-0.3.0/iceaxe/__tests__/helpers.py +263 -0
  3. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_comparison.py +84 -0
  4. iceaxe-0.3.0/iceaxe/__tests__/test_helpers.py +9 -0
  5. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_queries.py +15 -1
  6. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_session.py +50 -0
  7. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/base.py +18 -2
  8. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/comparison.py +28 -5
  9. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/field.py +15 -3
  10. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/queries.py +100 -15
  11. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/db_memory_serializer.py +1 -6
  12. iceaxe-0.3.0/iceaxe/session_optimized.pyx +199 -0
  13. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/typing.py +5 -2
  14. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/pyproject.toml +1 -1
  15. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/setup.py +1 -1
  16. iceaxe-0.2.3.dev1/iceaxe/session_optimized.pyx +0 -102
  17. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/LICENSE +0 -0
  18. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/README.md +0 -0
  19. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/build.py +0 -0
  20. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/.DS_Store +0 -0
  21. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__init__.py +0 -0
  22. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/__init__.py +0 -0
  23. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/benchmarks/__init__.py +0 -0
  24. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/benchmarks/test_select.py +0 -0
  25. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/conf_models.py +0 -0
  26. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/conftest.py +0 -0
  27. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/__init__.py +0 -0
  28. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/conftest.py +0 -0
  29. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/test_action_sorter.py +0 -0
  30. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/test_generator.py +0 -0
  31. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/migrations/test_generics.py +0 -0
  32. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/mountaineer/__init__.py +0 -0
  33. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  34. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/mountaineer/dependencies/test_core.py +0 -0
  35. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/__init__.py +0 -0
  36. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_actions.py +0 -0
  37. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_cli.py +0 -0
  38. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_db_memory_serializer.py +0 -0
  39. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_db_serializer.py +0 -0
  40. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/schemas/test_db_stubs.py +0 -0
  41. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_base.py +0 -0
  42. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/__tests__/test_field.py +0 -0
  43. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/functions.py +0 -0
  44. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/generics.py +0 -0
  45. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/io.py +0 -0
  46. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/logging.py +0 -0
  47. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/__init__.py +0 -0
  48. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/action_sorter.py +0 -0
  49. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/cli.py +0 -0
  50. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/client_io.py +0 -0
  51. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/generator.py +0 -0
  52. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/migration.py +0 -0
  53. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/migrations/migrator.py +0 -0
  54. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/__init__.py +0 -0
  55. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/cli.py +0 -0
  56. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/config.py +0 -0
  57. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/dependencies/__init__.py +0 -0
  58. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/mountaineer/dependencies/core.py +0 -0
  59. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/postgres.py +0 -0
  60. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/py.typed +0 -0
  61. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/queries_str.py +0 -0
  62. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/__init__.py +0 -0
  63. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/actions.py +0 -0
  64. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/cli.py +0 -0
  65. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/db_serializer.py +0 -0
  66. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/schemas/db_stubs.py +0 -0
  67. {iceaxe-0.2.3.dev1 → iceaxe-0.3.0}/iceaxe/session.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: iceaxe
3
- Version: 0.2.3.dev1
3
+ Version: 0.3.0
4
4
  Summary: A modern, fast ORM for Python.
5
5
  Author: Pierce Freeman
6
6
  Author-email: pierce@freeman.vc
@@ -0,0 +1,263 @@
1
+ import ast
2
+ import inspect
3
+ import os
4
+ from contextlib import contextmanager
5
+ from dataclasses import dataclass
6
+ from json import JSONDecodeError, dump as json_dump, loads as json_loads
7
+ from re import Pattern
8
+ from tempfile import NamedTemporaryFile, TemporaryDirectory
9
+ from textwrap import dedent
10
+
11
+ from pyright import run
12
+
13
+
14
+ @dataclass
15
+ class PyrightDiagnostic:
16
+ file: str
17
+ severity: str
18
+ message: str
19
+ rule: str | None
20
+ line: int
21
+ column: int
22
+
23
+
24
+ class ExpectedPyrightError(Exception):
25
+ """
26
+ Exception raised when Pyright doesn't produce the expected error
27
+
28
+ """
29
+
30
+ pass
31
+
32
+
33
+ def get_imports_from_module(module_source: str) -> set[str]:
34
+ """
35
+ Extract all import statements from module source
36
+
37
+ """
38
+ tree = ast.parse(module_source)
39
+ imports: set[str] = set()
40
+
41
+ for node in ast.walk(tree):
42
+ if isinstance(node, ast.Import):
43
+ for name in node.names:
44
+ imports.add(f"import {name.name}")
45
+ elif isinstance(node, ast.ImportFrom):
46
+ names = ", ".join(name.name for name in node.names)
47
+ if node.module is None:
48
+ # Handle "from . import x" case
49
+ imports.add(f"from . import {names}")
50
+ else:
51
+ imports.add(f"from {node.module} import {names}")
52
+
53
+ return imports
54
+
55
+
56
+ def strip_type_ignore(line: str) -> str:
57
+ """
58
+ Strip type: ignore comments from a line while preserving the line content
59
+
60
+ """
61
+ if "#" not in line:
62
+ return line
63
+
64
+ # Split only on the first #
65
+ code_part, *comment_parts = line.split("#", 1)
66
+ if not comment_parts:
67
+ return line
68
+
69
+ comment = comment_parts[0]
70
+ # If this is a type: ignore comment, return just the code
71
+ if "type:" in comment and "ignore" in comment:
72
+ return code_part.rstrip()
73
+
74
+ # Otherwise return the full line
75
+ return line
76
+
77
+
78
+ def extract_current_function_code():
79
+ """
80
+ Extracts the source code of the function calling this utility,
81
+ along with any necessary imports at the module level. This only works for
82
+ functions in a pytest testing context that are prefixed with `test_`.
83
+
84
+ """
85
+ # Get the frame of the calling function
86
+ frame = inspect.currentframe()
87
+
88
+ try:
89
+ # Go up until we find the test function; workaround to not
90
+ # knowing the entrypoint of our contextmanager at runtime
91
+ while frame is not None:
92
+ func_name = frame.f_code.co_name
93
+ if func_name.startswith("test_"):
94
+ test_frame = frame
95
+ break
96
+ frame = frame.f_back
97
+ else:
98
+ raise RuntimeError("Could not find test function frame")
99
+
100
+ # Source code of the function
101
+ func_source = inspect.getsource(test_frame.f_code)
102
+
103
+ # Source code of the larger test file, which contains the test function
104
+ # All the imports used by the test function should be within this file
105
+ module = inspect.getmodule(test_frame)
106
+ if not module:
107
+ raise RuntimeError("Could not find module for test function")
108
+
109
+ module_source = inspect.getsource(module)
110
+
111
+ # Postprocess the source code to build into a valid new module
112
+ imports = get_imports_from_module(module_source)
113
+ filtered_lines = [strip_type_ignore(line) for line in func_source.split("\n")]
114
+ return "\n".join(sorted(imports)) + "\n\n" + dedent("\n".join(filtered_lines))
115
+
116
+ finally:
117
+ del frame # Avoid reference cycles
118
+
119
+
120
+ def create_pyright_config():
121
+ """
122
+ Creates a new pyright configuration that ignores unused imports or other
123
+ issues that are not related to context-manager wrapped type checking.
124
+
125
+ """
126
+ return {
127
+ "include": ["."],
128
+ "exclude": [],
129
+ "ignore": [],
130
+ "strict": [],
131
+ "typeCheckingMode": "strict",
132
+ "reportUnusedImport": False,
133
+ "reportUnusedVariable": False,
134
+ # Focus only on type checking
135
+ "reportOptionalMemberAccess": True,
136
+ "reportGeneralTypeIssues": True,
137
+ "reportPropertyTypeMismatch": True,
138
+ "reportFunctionMemberAccess": True,
139
+ "reportTypeCommentUsage": True,
140
+ "reportMissingTypeStubs": False,
141
+ # Only typehint intentional typehints, not inferred values
142
+ "reportUnknownParameterType": False,
143
+ "reportUnknownVariableType": False,
144
+ "reportUnknownMemberType": False,
145
+ "reportUnknownArgumentType": False,
146
+ "reportMissingParameterType": False,
147
+ }
148
+
149
+
150
+ def run_pyright(file_path: str) -> list[PyrightDiagnostic]:
151
+ """
152
+ Run pyright on a file and return the diagnostics
153
+
154
+ """
155
+ try:
156
+ with TemporaryDirectory() as temp_dir:
157
+ # Create pyright config
158
+ config_path = os.path.join(temp_dir, "pyrightconfig.json")
159
+ with open(config_path, "w") as f:
160
+ json_dump(create_pyright_config(), f)
161
+
162
+ # Copy the file to analyze into the project directory
163
+ test_file = os.path.join(temp_dir, "test.py")
164
+ with open(file_path, "r") as src, open(test_file, "w") as dst:
165
+ dst.write(src.read())
166
+
167
+ # Run pyright with the config
168
+ result = run(
169
+ "--project",
170
+ temp_dir,
171
+ "--outputjson",
172
+ test_file,
173
+ capture_output=True,
174
+ text=True,
175
+ )
176
+
177
+ try:
178
+ output = json_loads(result.stdout)
179
+ except JSONDecodeError:
180
+ print(f"Failed to parse pyright output: {result.stdout}") # noqa: T201
181
+ print(f"Stderr: {result.stderr}") # noqa: T201
182
+ raise
183
+
184
+ if "generalDiagnostics" not in output:
185
+ raise RuntimeError(
186
+ f"Unknown pyright output, missing generalDiagnostics: {output}"
187
+ )
188
+
189
+ diagnostics: list[PyrightDiagnostic] = []
190
+ for diag in output["generalDiagnostics"]:
191
+ diagnostics.append(
192
+ PyrightDiagnostic(
193
+ file=diag["file"],
194
+ severity=diag["severity"],
195
+ message=diag["message"],
196
+ rule=diag.get("rule"),
197
+ line=diag["range"]["start"]["line"] + 1, # Convert to 1-based
198
+ column=(
199
+ diag["range"]["start"]["character"]
200
+ + 1 # Convert to 1-based
201
+ ),
202
+ )
203
+ )
204
+
205
+ return diagnostics
206
+
207
+ except Exception as e:
208
+ raise RuntimeError(f"Failed to run pyright: {str(e)}")
209
+
210
+
211
+ @contextmanager
212
+ def pyright_raises(
213
+ expected_rule: str,
214
+ expected_line: int | None = None,
215
+ matches: Pattern | None = None,
216
+ ):
217
+ """
218
+ Context manager that verifies code produces a specific Pyright error.
219
+
220
+ :params expected_rule: The Pyright rule that should be violated
221
+ :params expected_line: Optional line number where the error should occur
222
+
223
+ :raises ExpectedPyrightError: If Pyright doesn't produce the expected error
224
+
225
+ """
226
+ # Create a temporary file to store the code
227
+ with NamedTemporaryFile(mode="w", suffix=".py") as temp_file:
228
+ temp_path = temp_file.name
229
+
230
+ # Extract the source code of the calling function
231
+ source_code = extract_current_function_code()
232
+ print(f"Running Pyright on:\n{source_code}") # noqa: T201
233
+
234
+ # Write the source code to the temporary file
235
+ temp_file.write(source_code)
236
+ temp_file.flush()
237
+
238
+ # At runtime, our actual code is probably a no-op but we still let it run
239
+ # inside the scope of the contextmanager
240
+ yield
241
+
242
+ # Run Pyright on the temporary file
243
+ diagnostics = run_pyright(temp_path)
244
+
245
+ # Check if any of the diagnostics match our expected error
246
+ for diagnostic in diagnostics:
247
+ if diagnostic.rule == expected_rule:
248
+ if expected_line is not None and diagnostic.line != expected_line:
249
+ continue
250
+ if matches and not matches.search(diagnostic.message):
251
+ continue
252
+ # Found matching error
253
+ return
254
+
255
+ # If we get here, we didn't find the expected error
256
+ actual_errors = [
257
+ f"{d.rule or 'unknown'} on line {d.line}: {d.message}" for d in diagnostics
258
+ ]
259
+ raise ExpectedPyrightError(
260
+ f"Expected Pyright error {expected_rule}"
261
+ f"{f' on line {expected_line}' if expected_line else ''}"
262
+ f" but got: {', '.join(actual_errors) if actual_errors else 'no errors'}"
263
+ )
@@ -1,10 +1,14 @@
1
+ from re import compile as re_compile
1
2
  from typing import Any
2
3
 
3
4
  import pytest
5
+ from typing_extensions import assert_type
4
6
 
7
+ from iceaxe.__tests__.helpers import pyright_raises
5
8
  from iceaxe.base import TableBase
6
9
  from iceaxe.comparison import ComparisonType, FieldComparison
7
10
  from iceaxe.field import DBFieldClassDefinition, DBFieldInfo
11
+ from iceaxe.typing import column
8
12
 
9
13
 
10
14
  def test_comparison_type_enum():
@@ -17,6 +21,9 @@ def test_comparison_type_enum():
17
21
  assert ComparisonType.IN == "IN"
18
22
  assert ComparisonType.NOT_IN == "NOT IN"
19
23
  assert ComparisonType.LIKE == "LIKE"
24
+ assert ComparisonType.NOT_LIKE == "NOT LIKE"
25
+ assert ComparisonType.ILIKE == "ILIKE"
26
+ assert ComparisonType.NOT_ILIKE == "NOT ILIKE"
20
27
  assert ComparisonType.IS == "IS"
21
28
  assert ComparisonType.IS_NOT == "IS NOT"
22
29
 
@@ -158,3 +165,80 @@ def test_comparison_with_different_types(db_field: DBFieldClassDefinition, value
158
165
  assert result.left == db_field
159
166
  assert isinstance(result.comparison, ComparisonType)
160
167
  assert result.right == value
168
+
169
+
170
+ #
171
+ # Typehinting
172
+ # These checks are run as part of the static typechecking we do
173
+ # for our codebase, not as part of the pytest runtime.
174
+ #
175
+
176
+
177
+ def test_typehint_like():
178
+ class UserDemo(TableBase):
179
+ id: int
180
+ value_str: str
181
+ value_int: int
182
+
183
+ str_col = column(UserDemo.value_str)
184
+ int_col = column(UserDemo.value_int)
185
+
186
+ assert_type(str_col, DBFieldClassDefinition[str])
187
+ assert_type(int_col, DBFieldClassDefinition[int])
188
+
189
+ assert_type(str_col.ilike("test"), bool)
190
+ assert_type(str_col.not_ilike("test"), bool)
191
+ assert_type(str_col.like("test"), bool)
192
+ assert_type(str_col.not_like("test"), bool)
193
+
194
+ with pyright_raises(
195
+ "reportAttributeAccessIssue",
196
+ matches=re_compile('Cannot access attribute "ilike"'),
197
+ ):
198
+ int_col.ilike(5) # type: ignore
199
+
200
+ with pyright_raises(
201
+ "reportAttributeAccessIssue",
202
+ matches=re_compile('Cannot access attribute "ilike"'),
203
+ ):
204
+ int_col.not_ilike(5) # type: ignore
205
+
206
+ with pyright_raises(
207
+ "reportAttributeAccessIssue",
208
+ matches=re_compile('Cannot access attribute "ilike"'),
209
+ ):
210
+ int_col.like(5) # type: ignore
211
+
212
+ with pyright_raises(
213
+ "reportAttributeAccessIssue",
214
+ matches=re_compile('Cannot access attribute "ilike"'),
215
+ ):
216
+ int_col.not_like(5) # type: ignore
217
+
218
+
219
+ def test_typehint_in():
220
+ class UserDemo(TableBase):
221
+ id: int
222
+ value_str: str
223
+ value_int: int
224
+
225
+ str_col = column(UserDemo.value_str)
226
+ int_col = column(UserDemo.value_int)
227
+
228
+ assert_type(str_col.in_(["test"]), bool)
229
+ assert_type(int_col.in_([5]), bool)
230
+
231
+ assert_type(str_col.not_in(["test"]), bool)
232
+ assert_type(int_col.not_in([5]), bool)
233
+
234
+ with pyright_raises(
235
+ "reportArgumentType",
236
+ matches=re_compile('cannot be assigned to parameter "other"'),
237
+ ):
238
+ str_col.in_(["test", 5]) # type: ignore
239
+
240
+ with pyright_raises(
241
+ "reportArgumentType",
242
+ matches=re_compile('cannot be assigned to parameter "other"'),
243
+ ):
244
+ str_col.not_in(["test", 5]) # type: ignore
@@ -0,0 +1,9 @@
1
+ from iceaxe.__tests__.helpers import pyright_raises
2
+
3
+
4
+ def test_basic_type_error():
5
+ def type_error_func(x: int) -> int:
6
+ return 10
7
+
8
+ with pyright_raises("reportArgumentType"):
9
+ type_error_func("20") # type: ignore
@@ -9,7 +9,11 @@ from iceaxe.queries import QueryBuilder, and_, or_, select
9
9
 
10
10
  def test_select():
11
11
  new_query = QueryBuilder().select(UserDemo)
12
- assert new_query.build() == ('SELECT "userdemo".* FROM "userdemo"', [])
12
+ assert new_query.build() == (
13
+ 'SELECT "userdemo"."id" as "userdemo_id", "userdemo"."name" as '
14
+ '"userdemo_name", "userdemo"."email" as "userdemo_email" FROM "userdemo"',
15
+ [],
16
+ )
13
17
 
14
18
 
15
19
  def test_select_single_field():
@@ -263,3 +267,13 @@ def test_select_multiple_typehints():
263
267
  query = select((UserDemo, UserDemo.id, UserDemo.name))
264
268
  if TYPE_CHECKING:
265
269
  _: QueryBuilder[tuple[UserDemo, int, str], Literal["SELECT"]] = query
270
+
271
+
272
+ def test_allow_branching():
273
+ base_query = select(UserDemo)
274
+
275
+ query_1 = base_query.limit(1)
276
+ query_2 = base_query.limit(2)
277
+
278
+ assert query_1.limit_value == 1
279
+ assert query_2.limit_value == 2
@@ -281,6 +281,30 @@ async def test_select_join(db_connection: DBConnection):
281
281
  ]
282
282
 
283
283
 
284
+ @pytest.mark.asyncio
285
+ async def test_select_join_multiple_tables(db_connection: DBConnection):
286
+ user = UserDemo(name="John Doe", email="john@example.com")
287
+ await db_connection.insert([user])
288
+ assert user.id is not None
289
+
290
+ artifact = ArtifactDemo(title="Artifact 1", user_id=user.id)
291
+ await db_connection.insert([artifact])
292
+
293
+ new_query = (
294
+ QueryBuilder()
295
+ .select((ArtifactDemo, UserDemo))
296
+ .join(UserDemo, UserDemo.id == ArtifactDemo.user_id)
297
+ .where(UserDemo.name == "John Doe")
298
+ )
299
+ result = await db_connection.exec(new_query)
300
+ assert result == [
301
+ (
302
+ ArtifactDemo(id=artifact.id, title="Artifact 1", user_id=user.id),
303
+ UserDemo(id=user.id, name="John Doe", email="john@example.com"),
304
+ )
305
+ ]
306
+
307
+
284
308
  @pytest.mark.asyncio
285
309
  async def test_select_with_limit_and_offset(db_connection: DBConnection):
286
310
  users = [
@@ -418,6 +442,32 @@ async def test_select_with_left_join(db_connection: DBConnection):
418
442
  assert result[1] == ("John", 2)
419
443
 
420
444
 
445
+ @pytest.mark.asyncio
446
+ async def test_select_with_left_join_object(db_connection: DBConnection):
447
+ users = [
448
+ UserDemo(name="John", email="john@example.com"),
449
+ UserDemo(name="Jane", email="jane@example.com"),
450
+ ]
451
+ await db_connection.insert(users)
452
+
453
+ posts = [
454
+ ArtifactDemo(title="John's Post", user_id=users[0].id),
455
+ ArtifactDemo(title="Another Post", user_id=users[0].id),
456
+ ]
457
+ await db_connection.insert(posts)
458
+
459
+ query = (
460
+ QueryBuilder()
461
+ .select((UserDemo, ArtifactDemo))
462
+ .join(ArtifactDemo, UserDemo.id == ArtifactDemo.user_id, "LEFT")
463
+ )
464
+ result = await db_connection.exec(query)
465
+ assert len(result) == 3
466
+ assert result[0] == (users[0], posts[0])
467
+ assert result[1] == (users[0], posts[1])
468
+ assert result[2] == (users[1], None)
469
+
470
+
421
471
  # @pytest.mark.asyncio
422
472
  # async def test_select_with_subquery(db_connection: DBConnection):
423
473
  # users = [
@@ -34,8 +34,8 @@ class DBModelMetaclass(_model_construction.ModelMetaclass):
34
34
  mcs._cached_args[cls] = raw_kwargs
35
35
 
36
36
  # If we have already set the class's fields, we should wrap them
37
- if hasattr(cls, "model_fields"):
38
- cls.model_fields = {
37
+ if hasattr(cls, "__pydantic_fields__"):
38
+ cls.__pydantic_fields__ = {
39
39
  field: info
40
40
  if isinstance(info, DBFieldInfo)
41
41
  else DBFieldInfo.extend_field(
@@ -98,6 +98,14 @@ class DBModelMetaclass(_model_construction.ModelMetaclass):
98
98
 
99
99
  return default
100
100
 
101
+ @property
102
+ def model_fields(self) -> dict[str, DBFieldInfo]: # type: ignore
103
+ # model_fields must be reimplemented in our custom metaclass, otherwise
104
+ # clients will get the super typehinting signature when they try
105
+ # to access Model.model_fields. This overrides the ClassVar typehint
106
+ # that's placed in the TableBase itself.
107
+ return super().model_fields # type: ignore
108
+
101
109
 
102
110
  class UniqueConstraint(BaseModel):
103
111
  columns: list[str]
@@ -136,3 +144,11 @@ class TableBase(BaseModel, metaclass=DBModelMetaclass):
136
144
  if cls.table_name == PydanticUndefined:
137
145
  return cls.__name__.lower()
138
146
  return cls.table_name
147
+
148
+ @classmethod
149
+ def get_client_fields(cls):
150
+ return {
151
+ field: info
152
+ for field, info in cls.model_fields.items()
153
+ if field not in INTERNAL_TABLE_FIELDS
154
+ }
@@ -1,12 +1,13 @@
1
1
  from abc import ABC, abstractmethod
2
2
  from dataclasses import dataclass
3
3
  from enum import StrEnum
4
- from typing import Any, Generic, Self, TypeVar
4
+ from typing import Any, Generic, Self, Sequence, TypeVar
5
5
 
6
6
  from iceaxe.queries_str import QueryElementBase, QueryLiteral
7
7
  from iceaxe.typing import is_column, is_comparison, is_comparison_group
8
8
 
9
9
  T = TypeVar("T", bound="ComparisonBase")
10
+ J = TypeVar("J")
10
11
 
11
12
 
12
13
  class ComparisonType(StrEnum):
@@ -18,7 +19,12 @@ class ComparisonType(StrEnum):
18
19
  GE = ">="
19
20
  IN = "IN"
20
21
  NOT_IN = "NOT IN"
22
+
21
23
  LIKE = "LIKE"
24
+ NOT_LIKE = "NOT LIKE"
25
+ ILIKE = "ILIKE"
26
+ NOT_ILIKE = "NOT ILIKE"
27
+
22
28
  IS = "IS"
23
29
  IS_NOT = "IS NOT"
24
30
 
@@ -95,7 +101,7 @@ class FieldComparisonGroup:
95
101
  return QueryLiteral(queries), all_variables
96
102
 
97
103
 
98
- class ComparisonBase(ABC):
104
+ class ComparisonBase(ABC, Generic[J]):
99
105
  def __eq__(self, other): # type: ignore
100
106
  if other is None:
101
107
  return self._compare(ComparisonType.IS, None)
@@ -118,15 +124,32 @@ class ComparisonBase(ABC):
118
124
  def __ge__(self, other):
119
125
  return self._compare(ComparisonType.GE, other)
120
126
 
121
- def in_(self, other) -> bool:
127
+ def in_(self, other: Sequence[J]) -> bool:
122
128
  return self._compare(ComparisonType.IN, other) # type: ignore
123
129
 
124
- def not_in(self, other) -> bool:
130
+ def not_in(self, other: Sequence[J]) -> bool:
125
131
  return self._compare(ComparisonType.NOT_IN, other) # type: ignore
126
132
 
127
- def like(self, other) -> bool:
133
+ def like(
134
+ self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
135
+ ) -> bool:
128
136
  return self._compare(ComparisonType.LIKE, other) # type: ignore
129
137
 
138
+ def not_like(
139
+ self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
140
+ ) -> bool:
141
+ return self._compare(ComparisonType.NOT_LIKE, other) # type: ignore
142
+
143
+ def ilike(
144
+ self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
145
+ ) -> bool:
146
+ return self._compare(ComparisonType.ILIKE, other) # type: ignore
147
+
148
+ def not_ilike(
149
+ self: "ComparisonBase[str] | ComparisonBase[str | None]", other: str
150
+ ) -> bool:
151
+ return self._compare(ComparisonType.NOT_ILIKE, other) # type: ignore
152
+
130
153
  def _compare(self, comparison: ComparisonType, other: Any) -> FieldComparison[Self]:
131
154
  return FieldComparison(left=self, comparison=comparison, right=other)
132
155
 
@@ -4,8 +4,10 @@ from typing import (
4
4
  Any,
5
5
  Callable,
6
6
  Concatenate,
7
+ Generic,
7
8
  ParamSpec,
8
9
  Type,
10
+ TypeVar,
9
11
  Unpack,
10
12
  cast,
11
13
  )
@@ -23,6 +25,8 @@ if TYPE_CHECKING:
23
25
 
24
26
  P = ParamSpec("P")
25
27
 
28
+ _Unset: Any = PydanticUndefined
29
+
26
30
 
27
31
  class DBFieldInputs(_FieldInfoInputs, total=False):
28
32
  primary_key: bool
@@ -114,11 +118,16 @@ def __get_db_field(_: Callable[Concatenate[Any, P], Any] = PydanticField): # ty
114
118
  index: bool = False,
115
119
  check_expression: str | None = None,
116
120
  is_json: bool = False,
117
- default: Any = PydanticUndefined,
121
+ default: Any = _Unset,
122
+ default_factory: (
123
+ Callable[[], Any] | Callable[[dict[str, Any]], Any] | None
124
+ ) = _Unset,
118
125
  *args: P.args,
119
126
  **kwargs: P.kwargs,
120
127
  ):
121
- raw_field = PydanticField(default=default, **kwargs) # type: ignore
128
+ raw_field = PydanticField(
129
+ default=default, default_factory=default_factory, **kwargs
130
+ ) # type: ignore
122
131
 
123
132
  # The Any request is required for us to be able to assign fields to any
124
133
  # arbitrary type, like `value: str = Field()`
@@ -139,7 +148,10 @@ def __get_db_field(_: Callable[Concatenate[Any, P], Any] = PydanticField): # ty
139
148
  return func
140
149
 
141
150
 
142
- class DBFieldClassDefinition(ComparisonBase):
151
+ T = TypeVar("T")
152
+
153
+
154
+ class DBFieldClassDefinition(Generic[T], ComparisonBase[T]):
143
155
  root_model: Type["TableBase"]
144
156
  key: str
145
157
  field_definition: DBFieldInfo