iceaxe 0.8.3 (wheel: iceaxe-0.8.3-cp313-cp313-macosx_11_0_arm64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of iceaxe might be problematic; see the registry's release page for more details.

Files changed (75)
  1. iceaxe/__init__.py +20 -0
  2. iceaxe/__tests__/__init__.py +0 -0
  3. iceaxe/__tests__/benchmarks/__init__.py +0 -0
  4. iceaxe/__tests__/benchmarks/test_bulk_insert.py +45 -0
  5. iceaxe/__tests__/benchmarks/test_select.py +114 -0
  6. iceaxe/__tests__/conf_models.py +133 -0
  7. iceaxe/__tests__/conftest.py +204 -0
  8. iceaxe/__tests__/docker_helpers.py +208 -0
  9. iceaxe/__tests__/helpers.py +268 -0
  10. iceaxe/__tests__/migrations/__init__.py +0 -0
  11. iceaxe/__tests__/migrations/conftest.py +36 -0
  12. iceaxe/__tests__/migrations/test_action_sorter.py +237 -0
  13. iceaxe/__tests__/migrations/test_generator.py +140 -0
  14. iceaxe/__tests__/migrations/test_generics.py +91 -0
  15. iceaxe/__tests__/mountaineer/__init__.py +0 -0
  16. iceaxe/__tests__/mountaineer/dependencies/__init__.py +0 -0
  17. iceaxe/__tests__/mountaineer/dependencies/test_core.py +76 -0
  18. iceaxe/__tests__/schemas/__init__.py +0 -0
  19. iceaxe/__tests__/schemas/test_actions.py +1265 -0
  20. iceaxe/__tests__/schemas/test_cli.py +25 -0
  21. iceaxe/__tests__/schemas/test_db_memory_serializer.py +1571 -0
  22. iceaxe/__tests__/schemas/test_db_serializer.py +435 -0
  23. iceaxe/__tests__/schemas/test_db_stubs.py +190 -0
  24. iceaxe/__tests__/test_alias.py +83 -0
  25. iceaxe/__tests__/test_base.py +52 -0
  26. iceaxe/__tests__/test_comparison.py +383 -0
  27. iceaxe/__tests__/test_field.py +11 -0
  28. iceaxe/__tests__/test_helpers.py +9 -0
  29. iceaxe/__tests__/test_modifications.py +151 -0
  30. iceaxe/__tests__/test_queries.py +764 -0
  31. iceaxe/__tests__/test_queries_str.py +173 -0
  32. iceaxe/__tests__/test_session.py +1511 -0
  33. iceaxe/__tests__/test_text_search.py +287 -0
  34. iceaxe/alias_values.py +67 -0
  35. iceaxe/base.py +351 -0
  36. iceaxe/comparison.py +560 -0
  37. iceaxe/field.py +263 -0
  38. iceaxe/functions.py +1432 -0
  39. iceaxe/generics.py +140 -0
  40. iceaxe/io.py +107 -0
  41. iceaxe/logging.py +91 -0
  42. iceaxe/migrations/__init__.py +5 -0
  43. iceaxe/migrations/action_sorter.py +98 -0
  44. iceaxe/migrations/cli.py +228 -0
  45. iceaxe/migrations/client_io.py +62 -0
  46. iceaxe/migrations/generator.py +404 -0
  47. iceaxe/migrations/migration.py +86 -0
  48. iceaxe/migrations/migrator.py +101 -0
  49. iceaxe/modifications.py +176 -0
  50. iceaxe/mountaineer/__init__.py +10 -0
  51. iceaxe/mountaineer/cli.py +74 -0
  52. iceaxe/mountaineer/config.py +46 -0
  53. iceaxe/mountaineer/dependencies/__init__.py +6 -0
  54. iceaxe/mountaineer/dependencies/core.py +67 -0
  55. iceaxe/postgres.py +133 -0
  56. iceaxe/py.typed +0 -0
  57. iceaxe/queries.py +1459 -0
  58. iceaxe/queries_str.py +294 -0
  59. iceaxe/schemas/__init__.py +0 -0
  60. iceaxe/schemas/actions.py +864 -0
  61. iceaxe/schemas/cli.py +30 -0
  62. iceaxe/schemas/db_memory_serializer.py +711 -0
  63. iceaxe/schemas/db_serializer.py +347 -0
  64. iceaxe/schemas/db_stubs.py +529 -0
  65. iceaxe/session.py +860 -0
  66. iceaxe/session_optimized.c +12207 -0
  67. iceaxe/session_optimized.cpython-313-darwin.so +0 -0
  68. iceaxe/session_optimized.pyx +212 -0
  69. iceaxe/sql_types.py +149 -0
  70. iceaxe/typing.py +73 -0
  71. iceaxe-0.8.3.dist-info/METADATA +262 -0
  72. iceaxe-0.8.3.dist-info/RECORD +75 -0
  73. iceaxe-0.8.3.dist-info/WHEEL +6 -0
  74. iceaxe-0.8.3.dist-info/licenses/LICENSE +21 -0
  75. iceaxe-0.8.3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,52 @@
1
+ from typing import Generic, TypeVar
2
+
3
+ from iceaxe.base import (
4
+ DBModelMetaclass,
5
+ TableBase,
6
+ )
7
+ from iceaxe.field import DBFieldInfo
8
+
9
+
10
def test_autodetect():
    """A plain TableBase subclass is added to the model registry by default."""

    class WillAutodetect(TableBase):
        pass

    registry = DBModelMetaclass.get_registry()
    assert WillAutodetect in registry
15
+
16
+
17
def test_not_autodetect():
    """Declaring a subclass with autodetect=False keeps it out of the registry."""

    class WillNotAutodetect(TableBase, autodetect=False):
        pass

    registered_models = DBModelMetaclass.get_registry()
    assert WillNotAutodetect not in registered_models
22
+
23
+
24
def test_not_autodetect_generic(clear_registry):
    """
    A generic base declared with autodetect=False is not registered, but a
    concrete subclass parameterizing it is.

    Uses the `clear_registry` fixture so the registry is expected to contain
    exactly the model defined here.
    """
    T = TypeVar("T")

    class GenericSuperclass(TableBase, Generic[T], autodetect=False):
        value: T

    class WillAutodetect(GenericSuperclass[int]):
        pass

    expected_registry = [WillAutodetect]
    assert DBModelMetaclass.get_registry() == expected_registry
34
+
35
+
36
def test_model_fields():
    """
    Declared columns become required DBFieldInfo entries that carry their
    annotated type, and the special `modified_attrs` /
    `modified_attrs_callbacks` fields are present as DBFieldInfo too.
    """

    class User(TableBase):
        id: int
        name: str

    # Check the main fields.
    for field_name, expected_type in (("id", int), ("name", str)):
        field_info = User.model_fields[field_name]
        assert isinstance(field_info, DBFieldInfo)
        # Types are singletons, so identity is the precise check; this avoids
        # the `==` type comparison that needed an E721 suppression.
        assert field_info.annotation is expected_type
        assert field_info.is_required() is True

    # Check that the special fields exist with the right types.
    for special_name in ("modified_attrs", "modified_attrs_callbacks"):
        assert isinstance(User.model_fields[special_name], DBFieldInfo)
@@ -0,0 +1,383 @@
1
+ from re import compile as re_compile
2
+ from typing import Any
3
+
4
+ import pytest
5
+ from typing_extensions import assert_type
6
+
7
+ from iceaxe.__tests__.conf_models import UserDemo
8
+ from iceaxe.__tests__.helpers import pyright_raises
9
+ from iceaxe.base import TableBase
10
+ from iceaxe.comparison import ComparisonType, FieldComparison
11
+ from iceaxe.field import DBFieldClassDefinition, DBFieldInfo
12
+ from iceaxe.queries_str import QueryLiteral
13
+ from iceaxe.sql_types import ColumnType
14
+ from iceaxe.typing import column
15
+
16
+
17
def test_comparison_type_enum():
    """Each ComparisonType member compares equal to its SQL operator string."""
    expected_operators = [
        (ComparisonType.EQ, "="),
        (ComparisonType.NE, "!="),
        (ComparisonType.LT, "<"),
        (ComparisonType.LE, "<="),
        (ComparisonType.GT, ">"),
        (ComparisonType.GE, ">="),
        (ComparisonType.IN, "IN"),
        (ComparisonType.NOT_IN, "NOT IN"),
        (ComparisonType.LIKE, "LIKE"),
        (ComparisonType.NOT_LIKE, "NOT LIKE"),
        (ComparisonType.ILIKE, "ILIKE"),
        (ComparisonType.NOT_ILIKE, "NOT ILIKE"),
        (ComparisonType.IS, "IS"),
        (ComparisonType.IS_NOT, "IS NOT"),
        (ComparisonType.IS_DISTINCT_FROM, "IS DISTINCT FROM"),
        (ComparisonType.IS_NOT_DISTINCT_FROM, "IS NOT DISTINCT FROM"),
    ]
    for member, operator_text in expected_operators:
        assert member == operator_text
34
+
35
+
36
@pytest.fixture
def db_field():
    """Provide a standalone `test_key` column definition rooted on TableBase."""
    field_info = DBFieldInfo()
    return DBFieldClassDefinition(
        root_model=TableBase,
        key="test_key",
        field_definition=field_info,
    )
41
+
42
+
43
def test_eq(db_field: DBFieldClassDefinition):
    """`column == value` builds an EQ FieldComparison against the literal."""
    result = db_field == 5
    assert isinstance(result, FieldComparison)
    # `==` on a column definition returns a FieldComparison object rather
    # than a bool, so `result.left == db_field` would always be truthy.
    # Identity is the meaningful check.
    assert result.left is db_field
    assert result.comparison == ComparisonType.EQ
    assert result.right == 5
49
+
50
+
51
def test_eq_none(db_field: DBFieldClassDefinition):
    """`column == None` is translated into an IS comparison against None."""
    result = db_field == None  # noqa: E711
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.IS
    assert result.right is None
57
+
58
+
59
def test_ne(db_field: DBFieldClassDefinition):
    """`column != value` builds an NE FieldComparison against the literal."""
    result = db_field != 5
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.NE
    assert result.right == 5
65
+
66
+
67
def test_ne_none(db_field: DBFieldClassDefinition):
    """`column != None` is translated into an IS NOT comparison against None."""
    result = db_field != None  # noqa: E711
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.IS_NOT
    assert result.right is None
73
+
74
+
75
def test_lt(db_field: DBFieldClassDefinition):
    """`column < value` builds an LT FieldComparison."""
    result = db_field < 5
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.LT
    assert result.right == 5
81
+
82
+
83
def test_le(db_field: DBFieldClassDefinition):
    """`column <= value` builds an LE FieldComparison."""
    result = db_field <= 5
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.LE
    assert result.right == 5
89
+
90
+
91
def test_gt(db_field: DBFieldClassDefinition):
    """`column > value` builds a GT FieldComparison."""
    result = db_field > 5
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.GT
    assert result.right == 5
97
+
98
+
99
def test_ge(db_field: DBFieldClassDefinition):
    """`column >= value` builds a GE FieldComparison."""
    result = db_field >= 5
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.GE
    assert result.right == 5
105
+
106
+
107
def test_in(db_field: DBFieldClassDefinition):
    """`column.in_(values)` builds an IN FieldComparison holding the values."""
    result = db_field.in_([1, 2, 3])
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.IN
    assert result.right == [1, 2, 3]
113
+
114
+
115
def test_not_in(db_field: DBFieldClassDefinition):
    """`column.not_in(values)` builds a NOT IN FieldComparison holding the values."""
    result = db_field.not_in([1, 2, 3])
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.NOT_IN
    assert result.right == [1, 2, 3]
121
+
122
+
123
def test_contains(db_field: DBFieldClassDefinition):
    """
    `column.like(pattern)` builds a LIKE FieldComparison holding the pattern.

    (Despite the test name, this exercises `like`; the name is kept so
    existing test selections keep working.)
    """
    result = db_field.like("test")
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.LIKE
    assert result.right == "test"
129
+
130
+
131
def test_compare(db_field: DBFieldClassDefinition):
    """The low-level `_compare` helper builds a FieldComparison of the given type."""
    result = db_field._compare(ComparisonType.EQ, 10)
    assert isinstance(result, FieldComparison)
    # Column `==` yields a FieldComparison (always truthy), so compare by
    # identity to actually verify the left operand.
    assert result.left is db_field
    assert result.comparison == ComparisonType.EQ
    assert result.right == 10
137
+
138
+
139
@pytest.mark.parametrize(
    "value",
    [
        None,
        "",
        0,
        [],
        {},
        True,
        False,
        3.14,
        complex(1, 2),
        DBFieldClassDefinition(
            root_model=TableBase, key="other_key", field_definition=DBFieldInfo()
        ),
    ],
)
def test_comparison_with_different_types(db_field: DBFieldClassDefinition, value: Any):
    """
    Every comparison helper accepts a right-hand value of any type and stores
    it unchanged on the resulting FieldComparison.
    """
    for method in [
        db_field.__eq__,
        db_field.__ne__,
        db_field.__lt__,
        db_field.__le__,
        db_field.__gt__,
        db_field.__ge__,
        db_field.in_,
        db_field.not_in,
        db_field.like,
    ]:
        result = method(value)
        assert isinstance(result, FieldComparison)
        # Column `==` returns a FieldComparison (always truthy) rather than a
        # bool, so equality checks involving a column are vacuous. Use
        # identity wherever a column definition is on either side.
        assert result.left is db_field
        assert isinstance(result.comparison, ComparisonType)
        if isinstance(value, DBFieldClassDefinition):
            assert result.right is value
        else:
            assert result.right == value
173
+
174
+
175
+ #
176
+ # Typehinting
177
+ # These checks are run as part of the static typechecking we do
178
+ # for our codebase, not as part of the pytest runtime.
179
+ #
180
+
181
+
182
def test_typehint_like():
    """
    Typing check intended for the static type checker: on `str` columns the
    LIKE-family helpers are typed as returning `bool`, and on `int` columns
    pyright must report each helper as an inaccessible attribute.
    """

    class UserDemo(TableBase):
        id: int
        value_str: str
        value_int: int

    str_col = column(UserDemo.value_str)
    int_col = column(UserDemo.value_int)

    assert_type(str_col, DBFieldClassDefinition[str])
    assert_type(int_col, DBFieldClassDefinition[int])

    assert_type(str_col.ilike("test"), bool)
    assert_type(str_col.not_ilike("test"), bool)
    assert_type(str_col.like("test"), bool)
    assert_type(str_col.not_like("test"), bool)

    # Each expected diagnostic names the attribute that is actually accessed.
    # The previous version reused 'Cannot access attribute "ilike"' for every
    # call, which does not describe the not_ilike / like / not_like errors.
    with pyright_raises(
        "reportAttributeAccessIssue",
        matches=re_compile('Cannot access attribute "ilike"'),
    ):
        int_col.ilike(5)  # type: ignore

    with pyright_raises(
        "reportAttributeAccessIssue",
        matches=re_compile('Cannot access attribute "not_ilike"'),
    ):
        int_col.not_ilike(5)  # type: ignore

    with pyright_raises(
        "reportAttributeAccessIssue",
        matches=re_compile('Cannot access attribute "like"'),
    ):
        int_col.like(5)  # type: ignore

    with pyright_raises(
        "reportAttributeAccessIssue",
        matches=re_compile('Cannot access attribute "not_like"'),
    ):
        int_col.not_like(5)  # type: ignore
222
+
223
+
224
def test_typehint_in():
    """
    Typing check intended for the static type checker: `in_` / `not_in` are
    typed as returning `bool`, and mixed-type lists are rejected for a `str`
    column.
    """

    class UserDemo(TableBase):
        id: int
        value_str: str
        value_int: int

    text_column = column(UserDemo.value_str)
    number_column = column(UserDemo.value_int)

    # Membership against homogeneous lists type-checks as bool.
    assert_type(text_column.in_(["test"]), bool)
    assert_type(number_column.in_([5]), bool)
    assert_type(text_column.not_in(["test"]), bool)
    assert_type(number_column.not_in([5]), bool)

    # A list mixing str and int must not be accepted for a str column.
    with pyright_raises(
        "reportArgumentType",
        matches=re_compile('cannot be assigned to parameter "other"'),
    ):
        text_column.in_(["test", 5])  # type: ignore

    with pyright_raises(
        "reportArgumentType",
        matches=re_compile('cannot be assigned to parameter "other"'),
    ):
        text_column.not_in(["test", 5])  # type: ignore
250
+
251
+
252
@pytest.mark.parametrize(
    "comparison_type,expected_sql",
    [
        (ComparisonType.IN, '"userdemo"."name" = ANY($1)'),
        (ComparisonType.NOT_IN, '"userdemo"."name" != ALL($1)'),
    ],
)
def test_in_not_in_formatting(comparison_type: ComparisonType, expected_sql: str):
    """
    IN / NOT IN comparisons render as `= ANY($1)` / `!= ALL($1)`, with the
    whole list bound as a single query parameter.
    """
    candidate_names = ["John", "Jane"]
    comparison = FieldComparison(
        left=column(UserDemo.name),
        comparison=comparison_type,
        right=candidate_names,
    )

    sql, params = comparison.to_query()

    assert isinstance(sql, QueryLiteral)
    assert str(sql) == expected_sql
    assert params == [["John", "Jane"]]
270
+
271
+
272
def test_default_eq_ne_are_null_safe(db_field: DBFieldClassDefinition):
    """
    The default == and != operators choose null-safe SQL comparisons:
    IS / IS NOT against None, and IS [NOT] DISTINCT FROM against another column.
    """
    peer_column = DBFieldClassDefinition(
        root_model=TableBase, key="other_key", field_definition=DBFieldInfo()
    )

    expectations = [
        (db_field == None, ComparisonType.IS),  # noqa: E711
        (db_field != None, ComparisonType.IS_NOT),  # noqa: E711
        (db_field == peer_column, ComparisonType.IS_NOT_DISTINCT_FROM),
        (db_field != peer_column, ComparisonType.IS_DISTINCT_FROM),
    ]
    for produced, expected_type in expectations:
        assert isinstance(produced, FieldComparison)
        assert produced.comparison == expected_type
298
+
299
+
300
@pytest.mark.parametrize(
    "magic_method,value",
    [
        (dunder, 5)
        for dunder in ("__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__")
    ],
)
def test_python_magic_methods_set_expression_flag(
    db_field: DBFieldClassDefinition, magic_method: str, value: Any
):
    """
    Comparisons produced through Python's rich-comparison operators are
    flagged with python_expression=True.
    """
    operator_fn = getattr(db_field, magic_method)
    produced = operator_fn(value)
    assert isinstance(produced, FieldComparison)
    assert produced.python_expression is True
320
+
321
+
322
@pytest.mark.parametrize(
    "initial_comparison, python_expression, expected_comparison",
    [
        (ComparisonType.IS_NOT_DISTINCT_FROM, True, ComparisonType.EQ),
        (ComparisonType.IS_DISTINCT_FROM, True, ComparisonType.NE),
        (
            ComparisonType.IS_NOT_DISTINCT_FROM,
            False,
            ComparisonType.IS_NOT_DISTINCT_FROM,
        ),
        (ComparisonType.IS_DISTINCT_FROM, False, ComparisonType.IS_DISTINCT_FROM),
    ],
)
def test_force_join_constraints(
    initial_comparison: ComparisonType,
    python_expression: bool,
    expected_comparison: ComparisonType,
):
    """
    force_join_constraints turns null-safe DISTINCT comparisons back into
    plain = / != only when the comparison came from a Python expression.
    """
    left_column, right_column = (
        DBFieldClassDefinition(
            root_model=TableBase, key=column_key, field_definition=DBFieldInfo()
        )
        for column_key in ("test_key", "other_key")
    )

    original = FieldComparison(
        left=left_column,
        comparison=initial_comparison,
        right=right_column,
        python_expression=python_expression,
    )
    assert original.force_join_constraints().comparison == expected_comparison
358
+
359
+
360
@pytest.mark.parametrize(
    "sql_type_string, expected_column_type",
    [
        # Bare "timestamp" / "time" are SQL-standard aliases for the
        # "without time zone" variants.
        ("timestamp", ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE),
        ("timestamp without time zone", ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE),
        ("timestamp with time zone", ColumnType.TIMESTAMP_WITH_TIME_ZONE),
        ("time", ColumnType.TIME_WITHOUT_TIME_ZONE),
        ("time without time zone", ColumnType.TIME_WITHOUT_TIME_ZONE),
        ("time with time zone", ColumnType.TIME_WITH_TIME_ZONE),
    ],
)
def test_postgres_datetime_timezone_casting(
    sql_type_string: str, expected_column_type: ColumnType
):
    """
    ColumnType can be constructed from PostgreSQL's stored date/time type
    names (e.g. 'timestamp without time zone') and from their short aliases
    such as 'timestamp' and 'time'.
    """
    parsed_type = ColumnType(sql_type_string)
    assert parsed_type == expected_column_type
@@ -0,0 +1,11 @@
1
+ from iceaxe.base import TableBase
2
+ from iceaxe.field import DBFieldClassDefinition, DBFieldInfo
3
+
4
+
5
def test_db_field_class_definition_instantiation():
    """A DBFieldClassDefinition exposes the values it was constructed with."""
    field_info = DBFieldInfo()
    definition = DBFieldClassDefinition(
        root_model=TableBase,
        key="test_key",
        field_definition=field_info,
    )

    assert definition.root_model == TableBase
    assert definition.key == "test_key"
    assert isinstance(definition.field_definition, DBFieldInfo)
@@ -0,0 +1,9 @@
1
+ from iceaxe.__tests__.helpers import pyright_raises
2
+
3
+
4
def test_basic_type_error():
    """pyright_raises captures a reportArgumentType error for a str passed as int."""

    def expects_int(x: int) -> int:
        return 10

    with pyright_raises("reportArgumentType"):
        expects_int("20")  # type: ignore
@@ -0,0 +1,151 @@
1
+ import logging
2
+
3
+ import pytest
4
+
5
+ from iceaxe.__tests__.conf_models import ArtifactDemo, UserDemo
6
+ from iceaxe.modifications import (
7
+ MODIFICATION_TRACKER_VERBOSITY,
8
+ Modification,
9
+ ModificationTracker,
10
+ )
11
+
12
+
13
@pytest.fixture
def tracker():
    """Provide a new ModificationTracker that treats this test module as first-party."""
    first_party_modules = ["test_modifications"]
    return ModificationTracker(known_first_party=first_party_modules)
17
+
18
+
19
@pytest.fixture
def demo_instance():
    """Provide a UserDemo model instance to track in the tests."""
    user_fields = {"id": 1, "name": "test", "email": "test@example.com"}
    return UserDemo(**user_fields)
23
+
24
+
25
def test_get_current_stack_trace():
    """
    get_current_stack_trace returns a (full, user) pair of string traces.
    The user trace includes this test file and is no longer than the full trace.
    """
    full_trace, user_trace = Modification.get_current_stack_trace()

    for trace in (full_trace, user_trace):
        assert isinstance(trace, str)
    assert "test_modifications.py" in user_trace
    assert len(user_trace) <= len(full_trace)
33
+
34
+
35
def test_track_modification_new_instance(
    tracker: ModificationTracker, demo_instance: UserDemo
):
    """Tracking an instance records it under its id() along with a user stack trace."""
    tracker.track_modification(demo_instance)

    key = id(demo_instance)
    assert key in tracker.modified_models

    recorded = tracker.modified_models[key]
    assert recorded.instance == demo_instance
    assert "test_modifications.py" in recorded.user_stack_trace
47
+
48
+
49
def test_track_modification_duplicate(
    tracker: ModificationTracker, demo_instance: UserDemo
):
    """Tracking the same instance repeatedly yields a single entry."""
    for _ in range(2):
        tracker.track_modification(demo_instance)

    assert len(tracker.modified_models) == 1
57
+
58
+
59
def test_clear_status_single(tracker: ModificationTracker, demo_instance: UserDemo):
    """Clearing the status of a tracked instance removes its entry."""
    tracker.track_modification(demo_instance)

    committed = [demo_instance]
    tracker.clear_status(committed)

    assert id(demo_instance) not in tracker.modified_models
65
+
66
+
67
def test_clear_status_partial(tracker: ModificationTracker):
    """Clearing one of two tracked instances leaves the other entry untouched."""
    first, second = (
        UserDemo(id=n, name=f"test{n}", email=f"test{n}@example.com") for n in (1, 2)
    )

    for instance in (first, second):
        tracker.track_modification(instance)
    tracker.clear_status([first])

    remaining = tracker.modified_models
    assert id(first) not in remaining
    assert id(second) in remaining
    assert remaining[id(second)].instance == second
79
+
80
+
81
@pytest.mark.parametrize("verbosity", ["ERROR", "WARNING", "INFO", None])
def test_log_with_different_verbosity(
    tracker: ModificationTracker,
    demo_instance: UserDemo,
    verbosity: MODIFICATION_TRACKER_VERBOSITY,
    caplog,
):
    """
    tracker.log() emits an uncommitted-object warning when verbosity is set,
    adds the full stack trace at INFO verbosity, and logs nothing when
    verbosity is None.
    """
    tracker.verbosity = verbosity
    tracker.track_modification(demo_instance)

    with caplog.at_level(logging.INFO):
        tracker.log()

    records = caplog.records
    if not verbosity:
        assert len(records) == 0
        return

    assert len(records) > 0
    assert "Object modified locally but not committed" in records[0].message
    if verbosity == "INFO":
        assert any("Full stack trace" in entry.message for entry in records)
104
+
105
+
106
def test_multiple_model_types(tracker: ModificationTracker):
    """Instances of different model classes are tracked as separate entries."""
    user = UserDemo(id=1, name="test", email="test@example.com")
    artifact = ArtifactDemo(id=2, title="test", user_id=1)
    tracked = (user, artifact)

    for instance in tracked:
        tracker.track_modification(instance)

    assert len(tracker.modified_models) == 2
    assert all(id(instance) in tracker.modified_models for instance in tracked)
117
+
118
+
119
def test_clear_status_cleanup(tracker: ModificationTracker):
    """An instance's entry is present after tracking and gone after clear_status."""
    user = UserDemo(id=1, name="test", email="test@example.com")
    key = id(user)

    tracker.track_modification(user)
    assert key in tracker.modified_models

    tracker.clear_status([user])
    assert key not in tracker.modified_models
127
+
128
+
129
def test_callback_registration(tracker: ModificationTracker):
    """
    Registering tracker.track_modification as a model's modified-callback
    records the instance on its first attribute change. Later changes reuse
    that single entry.
    """
    user = UserDemo(id=1, name="test", email="test@example.com")
    user.register_modified_callback(tracker.track_modification)
    key = id(user)

    # Nothing is tracked before the first mutation.
    assert key not in tracker.modified_models

    user.name = "new name"

    assert key in tracker.modified_models
    entry = tracker.modified_models[key]
    assert entry.instance == user
    assert "test_modifications.py" in entry.user_stack_trace

    # A second mutation must not add another entry.
    user.email = "new@example.com"
    assert len(tracker.modified_models) == 1