pum 1.2.3__py3-none-any.whl → 1.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pum/checker.py CHANGED
@@ -1,7 +1,85 @@
1
- import difflib
1
+ from dataclasses import dataclass, field
2
+ from datetime import datetime
3
+ from enum import Enum
4
+ import re
2
5
 
3
6
  import psycopg
4
7
 
8
+ from .connection import format_connection_string
9
+
10
+
11
+ class DifferenceType(Enum):
12
+ """Type of difference found."""
13
+
14
+ ADDED = "added"
15
+ REMOVED = "removed"
16
+
17
+
18
+ @dataclass
19
+ class DifferenceItem:
20
+ """Represents a single difference between databases."""
21
+
22
+ type: DifferenceType
23
+ content: dict | str # dict for structured data, str for backward compatibility
24
+
25
+ def __str__(self) -> str:
26
+ """String representation with marker."""
27
+ marker = "+" if self.type == DifferenceType.ADDED else "-"
28
+ if isinstance(self.content, dict):
29
+ # For structured content, create a readable string
30
+ return f"{marker} {self.content}"
31
+ return f"{marker} {self.content}"
32
+
33
+
34
+ @dataclass
35
+ class CheckResult:
36
+ """Result of a single check (e.g., tables, columns)."""
37
+
38
+ name: str
39
+ key: str
40
+ passed: bool
41
+ differences: list[DifferenceItem] = field(default_factory=list)
42
+
43
+ @property
44
+ def difference_count(self) -> int:
45
+ """Number of differences found."""
46
+ return len(self.differences)
47
+
48
+
49
+ @dataclass
50
+ class ComparisonReport:
51
+ """Complete database comparison report."""
52
+
53
+ pg_connection1: str
54
+ pg_connection2: str
55
+ timestamp: datetime
56
+ check_results: list[CheckResult] = field(default_factory=list)
57
+
58
+ @property
59
+ def passed(self) -> bool:
60
+ """Whether all checks passed."""
61
+ return all(result.passed for result in self.check_results)
62
+
63
+ @property
64
+ def total_checks(self) -> int:
65
+ """Total number of checks performed."""
66
+ return len(self.check_results)
67
+
68
+ @property
69
+ def passed_checks(self) -> int:
70
+ """Number of checks that passed."""
71
+ return sum(1 for result in self.check_results if result.passed)
72
+
73
+ @property
74
+ def failed_checks(self) -> int:
75
+ """Number of checks that failed."""
76
+ return self.total_checks - self.passed_checks
77
+
78
+ @property
79
+ def total_differences(self) -> int:
80
+ """Total number of differences across all checks."""
81
+ return sum(result.difference_count for result in self.check_results)
82
+
5
83
 
6
84
  class Checker:
7
85
  """This class is used to compare 2 Postgres databases and show the
@@ -10,42 +88,34 @@ class Checker:
10
88
 
11
89
  def __init__(
12
90
  self,
13
- pg_service1,
14
- pg_service2,
91
+ pg_connection1,
92
+ pg_connection2,
15
93
  exclude_schema=None,
16
94
  exclude_field_pattern=None,
17
95
  ignore_list=None,
18
- verbose_level=1,
19
96
  ):
20
- """Constructor
21
-
22
- Parameters
23
- ----------
24
- pg_service1: str
25
- The name of the postgres service (defined in pg_service.conf)
26
- related to the first db to be compared
27
- pg_service2: str
28
- The name of the postgres service (defined in pg_service.conf)
29
- related to the first db to be compared
30
- ignore_list: list(str)
31
- List of elements to be ignored in check (ex. tables, columns,
32
- views, ...)
33
- exclude_schema: list of strings
34
- List of schemas to be ignored in check.
35
- exclude_field_pattern: list of strings
36
- List of field patterns to be ignored in check.
37
- verbose_level: int
38
- verbose level, 0 -> nothing, 1 -> print first 80 char of each
39
- difference, 2 -> print all the difference details
40
-
97
+ """Initialize the Checker.
98
+
99
+ Args:
100
+ pg_connection1: PostgreSQL service name or connection string for the first database.
101
+ Can be a service name (e.g., 'mydb') or a full connection string
102
+ (e.g., 'postgresql://user:pass@host/db' or 'host=localhost dbname=mydb').
103
+ pg_connection2: PostgreSQL service name or connection string for the second database.
104
+ exclude_schema: List of schemas to be ignored in check.
105
+ exclude_field_pattern: List of field patterns to be ignored in check.
106
+ ignore_list: List of elements to be ignored in check (ex. tables, columns,
107
+ views, ...).
41
108
  """
42
- self.conn1 = psycopg.connect(f"service={pg_service1}")
109
+ self.pg_connection1 = pg_connection1
110
+ self.pg_connection2 = pg_connection2
111
+
112
+ self.conn1 = psycopg.connect(format_connection_string(pg_connection1))
43
113
  self.cur1 = self.conn1.cursor()
44
114
 
45
- self.conn2 = psycopg.connect(f"service={pg_service2}")
115
+ self.conn2 = psycopg.connect(format_connection_string(pg_connection2))
46
116
  self.cur2 = self.conn2.cursor()
47
117
 
48
- self.ignore_list = ignore_list
118
+ self.ignore_list = ignore_list or []
49
119
  self.exclude_schema = "('information_schema'"
50
120
  if exclude_schema is not None:
51
121
  for schema in exclude_schema:
@@ -53,67 +123,51 @@ class Checker:
53
123
  self.exclude_schema += ")"
54
124
  self.exclude_field_pattern = exclude_field_pattern or []
55
125
 
56
- self.verbose_level = verbose_level
57
-
58
- def run_checks(self):
126
+ def run_checks(self) -> ComparisonReport:
59
127
  """Run all the checks functions.
60
128
 
61
- Returns
62
- -------
63
- bool
64
- True if all the checks are true
65
- False otherwise
66
- dict
67
- Dictionary of lists of differences
68
-
129
+ Returns:
130
+ Complete comparison report with all check results.
69
131
  """
70
- result = True
71
- differences_dict = {}
72
-
73
- if "tables" not in self.ignore_list:
74
- tmp_result, differences_dict["tables"] = self.check_tables()
75
- result = False if not tmp_result else result
76
- if "columns" not in self.ignore_list:
77
- tmp_result, differences_dict["columns"] = self.check_columns(
78
- "views" not in self.ignore_list
79
- )
80
- result = False if not tmp_result else result
81
- if "constraints" not in self.ignore_list:
82
- tmp_result, differences_dict["constraints"] = self.check_constraints()
83
- result = False if not tmp_result else result
84
- if "views" not in self.ignore_list:
85
- tmp_result, differences_dict["views"] = self.check_views()
86
- result = False if not tmp_result else result
87
- if "sequences" not in self.ignore_list:
88
- tmp_result, differences_dict["sequences"] = self.check_sequences()
89
- result = False if not tmp_result else result
90
- if "indexes" not in self.ignore_list:
91
- tmp_result, differences_dict["indexes"] = self.check_indexes()
92
- result = False if not tmp_result else result
93
- if "triggers" not in self.ignore_list:
94
- tmp_result, differences_dict["triggers"] = self.check_triggers()
95
- result = False if not tmp_result else result
96
- if "functions" not in self.ignore_list:
97
- tmp_result, differences_dict["functions"] = self.check_functions()
98
- result = False if not tmp_result else result
99
- if "rules" not in self.ignore_list:
100
- tmp_result, differences_dict["rules"] = self.check_rules()
101
- result = False if not tmp_result else result
102
- if self.verbose_level == 0:
103
- differences_dict = None
104
- return result, differences_dict
132
+ checks = [
133
+ ("tables", "Tables", self.check_tables),
134
+ ("columns", "Columns", lambda: self.check_columns("views" not in self.ignore_list)),
135
+ ("constraints", "Constraints", self.check_constraints),
136
+ ("views", "Views", self.check_views),
137
+ ("sequences", "Sequences", self.check_sequences),
138
+ ("indexes", "Indexes", self.check_indexes),
139
+ ("triggers", "Triggers", self.check_triggers),
140
+ ("functions", "Functions", self.check_functions),
141
+ ("rules", "Rules", self.check_rules),
142
+ ]
143
+
144
+ check_results = []
145
+ for check_key, check_name, check_func in checks:
146
+ if check_key not in self.ignore_list:
147
+ passed, differences = check_func()
148
+ check_results.append(
149
+ CheckResult(
150
+ name=check_name,
151
+ key=check_key,
152
+ passed=passed,
153
+ differences=differences,
154
+ )
155
+ )
156
+
157
+ return ComparisonReport(
158
+ pg_connection1=self.pg_connection1,
159
+ pg_connection2=self.pg_connection2,
160
+ timestamp=datetime.now(),
161
+ check_results=check_results,
162
+ )
105
163
 
106
164
  def check_tables(self):
107
165
  """Check if the tables are equals.
108
166
 
109
- Returns
110
- -------
111
- bool
112
- True if the tables are the same
113
- False otherwise
114
- list
115
- A list with the differences
116
-
167
+ Returns:
168
+ tuple: A tuple containing:
169
+ - bool: True if the tables are the same, False otherwise.
170
+ - list: A list with the differences.
117
171
  """
118
172
  query = rf"""SELECT table_schema, table_name
119
173
  FROM information_schema.tables
@@ -128,74 +182,116 @@ class Checker:
128
182
  def check_columns(self, check_views=True):
129
183
  """Check if the columns in all tables are equals.
130
184
 
131
- Parameters
132
- ----------
133
- check_views: bool
134
- if True, check the columns of all the tables and views, if
135
- False check only the columns of the tables
136
-
137
- Returns
138
- -------
139
- bool
140
- True if the columns are the same
141
- False otherwise
142
- list
143
- A list with the differences
185
+ Args:
186
+ check_views: If True, check the columns of all the tables and views,
187
+ if False check only the columns of the tables.
144
188
 
189
+ Returns:
190
+ tuple: A tuple containing:
191
+ - bool: True if the columns are the same, False otherwise.
192
+ - list: A list with the differences.
145
193
  """
146
- with_query = None
194
+ # First, get the list of tables that exist in BOTH databases
195
+ # to avoid reporting columns from tables that don't exist in one DB
147
196
  if check_views:
148
- with_query = rf"""WITH table_list AS (
149
- SELECT table_schema, table_name
197
+ table_query = rf"""SELECT table_schema, table_name
150
198
  FROM information_schema.tables
151
199
  WHERE table_schema NOT IN {self.exclude_schema}
152
200
  AND table_schema NOT LIKE 'pg\_%'
153
201
  ORDER BY table_schema,table_name
154
- )"""
155
-
202
+ """
156
203
  else:
157
- with_query = rf"""WITH table_list AS (
158
- SELECT table_schema, table_name
204
+ table_query = rf"""SELECT table_schema, table_name
159
205
  FROM information_schema.tables
160
206
  WHERE table_schema NOT IN {self.exclude_schema}
161
207
  AND table_schema NOT LIKE 'pg\_%'
162
208
  AND table_type NOT LIKE 'VIEW'
163
209
  ORDER BY table_schema,table_name
164
- )"""
210
+ """
211
+
212
+ # Get tables from both databases
213
+ self.cur1.execute(table_query)
214
+ tables1 = set(self.cur1.fetchall())
215
+
216
+ self.cur2.execute(table_query)
217
+ tables2 = set(self.cur2.fetchall())
218
+
219
+ # Only check columns for tables that exist in both databases
220
+ common_tables = tables1.intersection(tables2)
221
+
222
+ if not common_tables:
223
+ # No common tables, so no columns to compare
224
+ return True, []
165
225
 
166
- query = """{wq}
226
+ # Build the WHERE clause to only include common tables
227
+ table_conditions = " OR ".join(
228
+ [
229
+ f"(isc.table_schema = '{schema}' AND isc.table_name = '{table}')"
230
+ for schema, table in common_tables
231
+ ]
232
+ )
233
+
234
+ query = f"""
167
235
  SELECT isc.table_schema, isc.table_name, column_name,
168
236
  column_default, is_nullable, data_type,
169
237
  character_maximum_length::text, numeric_precision::text,
170
238
  numeric_precision_radix::text, datetime_precision::text
171
- FROM information_schema.columns isc,
172
- table_list tl
173
- WHERE isc.table_schema = tl.table_schema
174
- AND isc.table_name = tl.table_name
175
- {efp}
239
+ FROM information_schema.columns isc
240
+ WHERE ({table_conditions})
241
+ {("".join([f" AND column_name NOT LIKE '{pattern}'" for pattern in self.exclude_field_pattern]))}
176
242
  ORDER BY isc.table_schema, isc.table_name, column_name
177
- """.format(
178
- wq=with_query,
179
- efp="".join(
180
- [f" AND column_name NOT LIKE '{pattern}'" for pattern in self.exclude_field_pattern]
181
- ),
182
- )
243
+ """
183
244
 
184
245
  return self.__check_equals(query)
185
246
 
186
247
  def check_constraints(self):
187
248
  """Check if the constraints are equals.
188
249
 
189
- Returns
190
- -------
191
- bool
192
- True if the constraints are the same
193
- False otherwise
194
- list
195
- A list with the differences
196
-
250
+ Returns:
251
+ tuple: A tuple containing:
252
+ - bool: True if the constraints are the same, False otherwise.
253
+ - list: A list with the differences.
197
254
  """
198
- query = f""" select
255
+ # Get tables from both databases to filter constraints
256
+ table_query = f"""SELECT table_schema, table_name
257
+ FROM information_schema.tables
258
+ WHERE table_schema NOT IN {self.exclude_schema}
259
+ AND table_schema NOT LIKE 'pg\\_%'
260
+ AND table_type NOT LIKE 'VIEW'
261
+ ORDER BY table_schema,table_name
262
+ """
263
+
264
+ self.cur1.execute(table_query)
265
+ tables1 = set(self.cur1.fetchall())
266
+
267
+ self.cur2.execute(table_query)
268
+ tables2 = set(self.cur2.fetchall())
269
+
270
+ # Only check constraints for tables that exist in both databases
271
+ common_tables = tables1.intersection(tables2)
272
+
273
+ if not common_tables:
274
+ return True, []
275
+
276
+ # Build the WHERE clause to only include common tables
277
+ table_conditions = " OR ".join(
278
+ [
279
+ f"(tc.constraint_schema = '{schema}' AND tc.table_name = '{table}')"
280
+ for schema, table in common_tables
281
+ ]
282
+ )
283
+
284
+ # Build WHERE clause for CHECK constraints
285
+ check_table_conditions = " OR ".join(
286
+ [
287
+ f"(n.nspname = '{schema}' AND cl.relname = '{table}')"
288
+ for schema, table in common_tables
289
+ ]
290
+ )
291
+
292
+ # Query for KEY constraints (PRIMARY KEY, FOREIGN KEY, UNIQUE)
293
+ key_query = f"""
294
+ SELECT
199
295
  tc.constraint_name,
200
296
  tc.constraint_schema || '.' || tc.table_name || '.' ||
201
297
  kcu.column_name as physical_full_name,
@@ -204,34 +300,76 @@ class Checker:
204
300
  kcu.column_name,
205
301
  ccu.table_name as foreign_table_name,
206
302
  ccu.column_name as foreign_column_name,
207
- tc.constraint_type
208
- from information_schema.table_constraints as tc
209
- join information_schema.key_column_usage as kcu on
210
- (tc.constraint_name = kcu.constraint_name and
303
+ tc.constraint_type,
304
+ pg_get_constraintdef((
305
+ SELECT con.oid FROM pg_constraint con
306
+ JOIN pg_namespace nsp ON con.connamespace = nsp.oid
307
+ WHERE con.conname = tc.constraint_name
308
+ AND nsp.nspname = tc.constraint_schema
309
+ LIMIT 1
310
+ )) as constraint_definition
311
+ FROM information_schema.table_constraints as tc
312
+ JOIN information_schema.key_column_usage as kcu ON
313
+ (tc.constraint_name = kcu.constraint_name AND
211
314
  tc.table_name = kcu.table_name)
212
- join information_schema.constraint_column_usage as ccu on
315
+ JOIN information_schema.constraint_column_usage as ccu ON
213
316
  ccu.constraint_name = tc.constraint_name
214
- WHERE tc.constraint_schema NOT IN {self.exclude_schema}
317
+ WHERE ({table_conditions})
215
318
  ORDER BY tc.constraint_schema, physical_full_name,
216
319
  tc.constraint_name, foreign_table_name,
217
- foreign_column_name """
320
+ foreign_column_name
321
+ """
322
+
323
+ # Query for CHECK constraints (they don't appear in key_column_usage)
324
+ check_query = f"""
325
+ SELECT
326
+ c.conname as constraint_name,
327
+ n.nspname || '.' || cl.relname as physical_full_name,
328
+ n.nspname as constraint_schema,
329
+ cl.relname as table_name,
330
+ '' as column_name,
331
+ '' as foreign_table_name,
332
+ '' as foreign_column_name,
333
+ 'CHECK' as constraint_type,
334
+ pg_get_constraintdef(c.oid) as constraint_definition
335
+ FROM pg_constraint c
336
+ JOIN pg_class cl ON c.conrelid = cl.oid
337
+ JOIN pg_namespace n ON cl.relnamespace = n.oid
338
+ WHERE c.contype = 'c'
339
+ AND ({check_table_conditions})
340
+ ORDER BY n.nspname, cl.relname, c.conname
341
+ """
342
+
343
+ # Normalization function for constraint records
344
+ def normalize_constraint_record(record_dict, col_names):
345
+ """Normalize constraint definitions in a record."""
346
+ normalized = record_dict.copy()
347
+ if "constraint_definition" in normalized and normalized["constraint_definition"]:
348
+ normalized["constraint_definition"] = self.__normalize_constraint_definition(
349
+ normalized["constraint_definition"]
350
+ )
351
+ return normalized
352
+
353
+ # Execute both queries and combine results
354
+ passed_keys, diffs_keys = self.__check_equals(
355
+ key_query, normalize_func=normalize_constraint_record
356
+ )
357
+ passed_checks, diffs_checks = self.__check_equals(
358
+ check_query, normalize_func=normalize_constraint_record
359
+ )
218
360
 
219
- return self.__check_equals(query)
361
+ return (passed_keys and passed_checks, diffs_keys + diffs_checks)
220
362
 
221
363
  def check_views(self):
222
364
  """Check if the views are equals.
223
365
 
224
- Returns
225
- -------
226
- bool
227
- True if the views are the same
228
- False otherwise
229
- list
230
- A list with the differences
231
-
366
+ Returns:
367
+ tuple: A tuple containing:
368
+ - bool: True if the views are the same, False otherwise.
369
+ - list: A list with the differences.
232
370
  """
233
371
  query = rf"""
234
- SELECT table_name, REPLACE(view_definition,'"','')
372
+ SELECT table_schema, table_name, REPLACE(view_definition,'"','')
235
373
  FROM INFORMATION_SCHEMA.views
236
374
  WHERE table_schema NOT IN {self.exclude_schema}
237
375
  AND table_schema NOT LIKE 'pg\_%'
@@ -244,14 +382,10 @@ class Checker:
244
382
  def check_sequences(self):
245
383
  """Check if the sequences are equals.
246
384
 
247
- Returns
248
- -------
249
- bool
250
- True if the sequences are the same
251
- False otherwise
252
- list
253
- A list with the differences
254
-
385
+ Returns:
386
+ tuple: A tuple containing:
387
+ - bool: True if the sequences are the same, False otherwise.
388
+ - list: A list with the differences.
255
389
  """
256
390
  query = f"""
257
391
  SELECT c.relname,
@@ -267,55 +401,76 @@ class Checker:
267
401
  def check_indexes(self):
268
402
  """Check if the indexes are equals.
269
403
 
270
- Returns
271
- -------
272
- bool
273
- True if the indexes are the same
274
- False otherwise
275
- list
276
- A list with the differences
277
-
404
+ Returns:
405
+ tuple: A tuple containing:
406
+ - bool: True if the indexes are the same, False otherwise.
407
+ - list: A list with the differences.
278
408
  """
409
+ # Get tables from both databases to filter indexes
410
+ table_query = f"""SELECT table_schema, table_name
411
+ FROM information_schema.tables
412
+ WHERE table_schema NOT IN {self.exclude_schema}
413
+ AND table_schema NOT LIKE 'pg\\_%'
414
+ AND table_type NOT LIKE 'VIEW'
415
+ ORDER BY table_schema,table_name
416
+ """
417
+
418
+ self.cur1.execute(table_query)
419
+ tables1 = set(self.cur1.fetchall())
420
+
421
+ self.cur2.execute(table_query)
422
+ tables2 = set(self.cur2.fetchall())
423
+
424
+ # Only check indexes for tables that exist in both databases
425
+ common_tables = tables1.intersection(tables2)
426
+
427
+ if not common_tables:
428
+ return True, []
429
+
430
+ # Build the WHERE clause to only include common tables
431
+ table_conditions = " OR ".join(
432
+ [
433
+ f"(ns.nspname = '{schema}' AND t.relname = '{table}')"
434
+ for schema, table in common_tables
435
+ ]
436
+ )
437
+
279
438
  query = rf"""
280
- select
439
+ SELECT
440
+ ns.nspname as schema_name,
281
441
  t.relname as table_name,
282
442
  i.relname as index_name,
283
443
  a.attname as column_name,
284
- ns.nspname as schema_name
285
- from
444
+ pg_get_indexdef(i.oid) as index_definition
445
+ FROM
286
446
  pg_class t,
287
447
  pg_class i,
288
448
  pg_index ix,
289
449
  pg_attribute a,
290
450
  pg_namespace ns
291
- where
451
+ WHERE
292
452
  t.oid = ix.indrelid
293
- and i.oid = ix.indexrelid
294
- and a.attrelid = t.oid
295
- and t.relnamespace = ns.oid
296
- and a.attnum = ANY(ix.indkey)
297
- and t.relkind = 'r'
298
- AND t.relname NOT IN ('information_schema')
299
- AND t.relname NOT LIKE 'pg\_%'
300
- AND ns.nspname NOT IN {self.exclude_schema}
301
- order by
453
+ AND i.oid = ix.indexrelid
454
+ AND a.attrelid = t.oid
455
+ AND t.relnamespace = ns.oid
456
+ AND a.attnum = ANY(ix.indkey)
457
+ AND t.relkind = 'r'
458
+ AND ({table_conditions})
459
+ ORDER BY
460
+ ns.nspname,
302
461
  t.relname,
303
462
  i.relname,
304
463
  a.attname
305
464
  """
306
465
  return self.__check_equals(query)
307
466
 
308
- def check_triggers(self) -> dict:
467
+ def check_triggers(self):
309
468
  """Check if the triggers are equals.
310
469
 
311
- Returns
312
- -------
313
- bool
314
- True if the triggers are the same
315
- False otherwise
316
- list
317
- A list with the differences
318
-
470
+ Returns:
471
+ tuple: A tuple containing:
472
+ - bool: True if the triggers are the same, False otherwise.
473
+ - list: A list with the differences.
319
474
  """
320
475
  query = f"""
321
476
  WITH trigger_list AS (
@@ -337,14 +492,10 @@ class Checker:
337
492
  def check_functions(self):
338
493
  """Check if the functions are equals.
339
494
 
340
- Returns
341
- -------
342
- bool
343
- True if the functions are the same
344
- False otherwise
345
- list
346
- A list with the differences
347
-
495
+ Returns:
496
+ tuple: A tuple containing:
497
+ - bool: True if the functions are the same, False otherwise.
498
+ - list: A list with the differences.
348
499
  """
349
500
  query = rf"""
350
501
  SELECT routines.routine_schema, routines.routine_name, parameters.data_type,
@@ -364,14 +515,10 @@ class Checker:
364
515
  def check_rules(self):
365
516
  """Check if the rules are equals.
366
517
 
367
- Returns
368
- -------
369
- bool
370
- True if the rules are the same
371
- False otherwise
372
- list
373
- A list with the differences
374
-
518
+ Returns:
519
+ tuple: A tuple containing:
520
+ - bool: True if the rules are the same, False otherwise.
521
+ - list: A list with the differences.
375
522
  """
376
523
  query = rf"""
377
524
  select n.nspname as rule_schema,
@@ -395,17 +542,62 @@ class Checker:
395
542
 
396
543
  return self.__check_equals(query)
397
544
 
398
- def __check_equals(self, query):
545
+ @staticmethod
546
+ def __normalize_constraint_definition(definition: str) -> str:
547
+ """Normalize a constraint definition for comparison.
548
+
549
+ PostgreSQL may represent functionally equivalent constraints differently,
550
+ especially after dump/restore operations. This function normalizes common
551
+ variations to enable accurate comparison.
552
+
553
+ Args:
554
+ definition: The constraint definition string from pg_get_constraintdef()
555
+
556
+ Returns:
557
+ Normalized constraint definition
558
+ """
559
+ if not definition:
560
+ return definition
561
+
562
+ # Normalize different ARRAY representations:
563
+ # Before: (ARRAY['a'::type, 'b'::type])::type[] OR ARRAY[('a'::type)::text, ...]
564
+ # After: Canonical form based on sorted elements
565
+
566
+ # Strategy: Extract the constraint type and key values, ignoring formatting details
567
+ # For ANY/ALL with arrays, extract just the operator and the array values
568
+
569
+ # Remove extra parentheses around ARRAY expressions
570
+ # (ARRAY[...])::type[] -> ARRAY[...]::type[]
571
+ definition = re.sub(r"\(\(ARRAY\[(.*?)\]\)::(.*?)\[\]\)", r"ARRAY[\1]::\2[]", definition)
572
+
573
+ # Also remove parentheses without cast: (ARRAY[...]) -> ARRAY[...]
574
+ definition = re.sub(r"\(ARRAY\[([^\]]+)\]\)", r"ARRAY[\1]", definition)
575
+
576
+ # Normalize array element casts: ('value'::type1)::type2 -> 'value'::type1
577
+ # This handles the case where elements are double-cast
578
+ definition = re.sub(r"\('([^']+)'::([^)]+)\)::(\w+)", r"'\1'::\2", definition)
579
+
580
+ # Remove trailing array cast that may be present or absent: ::text[] or ::character varying[]
581
+ # This is safe because the type information is already in each array element
582
+ definition = re.sub(r"::(?:text|character varying)\[\]", "", definition)
583
+
584
+ # Remove extra whitespace and normalize spacing
585
+ definition = re.sub(r"\s+", " ", definition).strip()
586
+
587
+ return definition
588
+
589
+ def __check_equals(self, query, normalize_func=None) -> tuple[bool, list[DifferenceItem]]:
399
590
  """Check if the query results on the two databases are equals.
400
591
 
401
- Returns
402
- -------
403
- bool
404
- True if the results are the same
405
- False otherwise
406
- list
407
- A list with the differences
592
+ Args:
593
+ query: The SQL query to execute on both databases.
594
+ normalize_func: Optional function to normalize specific fields in records.
595
+ Should accept (dict, col_names) and return normalized dict.
408
596
 
597
+ Returns:
598
+ tuple: A tuple containing:
599
+ - bool: True if the results are the same, False otherwise.
600
+ - list[DifferenceItem]: A list of DifferenceItem objects with structured data.
409
601
  """
410
602
  self.cur1.execute(query)
411
603
  records1 = self.cur1.fetchall()
@@ -416,16 +608,54 @@ class Checker:
416
608
  result = True
417
609
  differences = []
418
610
 
419
- d = difflib.Differ()
420
- records1 = [str(x) for x in records1]
421
- records2 = [str(x) for x in records2]
422
-
423
- for line in d.compare(records1, records2):
424
- if line[0] in ("-", "+"):
425
- result = False
426
- if self.verbose_level == 1:
427
- differences.append(line[0:79])
428
- elif self.verbose_level == 2:
429
- differences.append(line)
611
+ # Convert records to dictionaries based on column names
612
+ col_names = [desc[0] for desc in self.cur1.description]
613
+
614
+ # Create structured records
615
+ structured1 = [dict(zip(col_names, record)) for record in records1]
616
+ structured2 = [dict(zip(col_names, record)) for record in records2]
617
+
618
+ # Apply normalization if provided
619
+ if normalize_func:
620
+ structured1 = [normalize_func(r, col_names) for r in structured1]
621
+ structured2 = [normalize_func(r, col_names) for r in structured2]
622
+ # Recreate records from normalized structured data
623
+ records1 = [tuple(r[col] for col in col_names) for r in structured1]
624
+ records2 = [tuple(r[col] for col in col_names) for r in structured2]
625
+
626
+ # Create sets for comparison
627
+ set1 = {str(tuple(r)) for r in records1}
628
+ set2 = {str(tuple(r)) for r in records2}
629
+
630
+ # Find differences
631
+ removed = set1 - set2
632
+ added = set2 - set1
633
+
634
+ if removed or added:
635
+ result = False
636
+
637
+ # Map string representations back to structured data
638
+ str_to_struct1 = {str(tuple(r)): s for r, s in zip(records1, structured1)}
639
+ str_to_struct2 = {str(tuple(r)): s for r, s in zip(records2, structured2)}
640
+
641
+ # Add removed items
642
+ for item_str in removed:
643
+ if item_str in str_to_struct1:
644
+ differences.append(
645
+ DifferenceItem(
646
+ type=DifferenceType.REMOVED,
647
+ content=str_to_struct1[item_str],
648
+ )
649
+ )
650
+
651
+ # Add added items
652
+ for item_str in added:
653
+ if item_str in str_to_struct2:
654
+ differences.append(
655
+ DifferenceItem(
656
+ type=DifferenceType.ADDED,
657
+ content=str_to_struct2[item_str],
658
+ )
659
+ )
430
660
 
431
661
  return result, differences