pytest-neon 2.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pytest_neon/plugin.py ADDED
@@ -0,0 +1,2087 @@
1
+ """Pytest plugin providing Neon database branch fixtures.
2
+
3
+ This plugin provides fixtures for database testing using Neon's instant
4
+ branching feature. Multiple isolation levels are available:
5
+
6
+ Main fixtures:
7
+ neon_branch_readonly: True read-only access via read_only endpoint (enforced)
8
+ neon_branch_dirty: Session-scoped read-write, shared state across all tests
9
+ neon_branch_isolated: Per-worker branch with reset after each test (recommended)
10
+ neon_branch_readwrite: Deprecated, use neon_branch_isolated instead
11
+ neon_branch: Deprecated alias for neon_branch_isolated
12
+ neon_branch_shared: Shared branch without reset (module-scoped)
13
+
14
+ Connection fixtures (require extras):
15
+ neon_connection: psycopg2 connection (requires psycopg2 extra)
16
+ neon_connection_psycopg: psycopg v3 connection (requires psycopg extra)
17
+ neon_engine: SQLAlchemy engine (requires sqlalchemy extra)
18
+
19
+ Architecture:
20
+ Parent Branch (configured or project default)
21
+ └── Migration Branch (session-scoped, read_write endpoint)
22
+ │ ↑ migrations run here ONCE
23
+
24
+ ├── Read-only Endpoint (read_only endpoint ON migration branch)
25
+ │ ↑ neon_branch_readonly uses this
26
+
27
+ ├── Dirty Branch (session-scoped child, shared across ALL workers)
28
+ │ ↑ neon_branch_dirty uses this
29
+
30
+ └── Isolated Branch (one per xdist worker, lazily created)
31
+ ↑ neon_branch_isolated uses this, reset after each test
32
+
33
+ SQLAlchemy Users:
34
+ If you create your own SQLAlchemy engine (not using neon_engine fixture),
35
+ you MUST use pool_pre_ping=True when using neon_branch_isolated:
36
+
37
+ engine = create_engine(DATABASE_URL, pool_pre_ping=True)
38
+
39
+ This is required because branch resets terminate server-side connections.
40
+ Without pool_pre_ping, SQLAlchemy may try to reuse dead pooled connections,
41
+ causing "SSL connection has been closed unexpectedly" errors.
42
+
43
+ Configuration:
44
+ Set NEON_API_KEY and NEON_PROJECT_ID environment variables, or use
45
+ --neon-api-key and --neon-project-id CLI options.
46
+
47
+ For full documentation, see: https://github.com/ZainRizvi/pytest-neon
48
+ """
49
+
50
+ from __future__ import annotations
51
+
52
+ import contextlib
53
+ import json
54
+ import os
55
+ import random
56
+ import time
57
+ import warnings
58
+ from collections.abc import Callable, Generator
59
+ from dataclasses import asdict, dataclass
60
+ from datetime import datetime, timedelta, timezone
61
+ from typing import Any, TypeVar
62
+
63
+ import pytest
64
+ import requests
65
+ from filelock import FileLock
66
+ from neon_api import NeonAPI
67
+ from neon_api.exceptions import NeonAPIError
68
+ from neon_api.schema import EndpointState
69
+
70
# Generic result type returned by _retry_on_rate_limit
T = TypeVar("T")

# Branch auto-expiry used when nothing else is configured: 10 minutes
DEFAULT_BRANCH_EXPIRY_SECONDS = 10 * 60

# Retry tuning for HTTP 429 responses from the Neon API.
# Published limits: 700 requests/minute (~11/sec), bursts up to 40/sec per
# route - see https://api-docs.neon.tech/reference/api-rate-limiting
_RATE_LIMIT_BASE_DELAY = 4.0  # first backoff step, in seconds
_RATE_LIMIT_MAX_TOTAL_DELAY = 90.0  # cap on cumulative sleep (1.5 minutes)
_RATE_LIMIT_JITTER_FACTOR = 0.25  # randomize each delay by up to +/- 25%
_RATE_LIMIT_MAX_ATTEMPTS = 10  # upper bound on attempts per operation

# Unique marker object used to detect that neon_apply_migrations was not
# overridden (compared by identity, so no real value can collide with it).
_MIGRATIONS_NOT_DEFINED = object()
85
+
86
+
87
class NeonRateLimitError(Exception):
    """Neon API kept answering with a rate-limit error and retries ran out.

    Raised by ``_retry_on_rate_limit`` once either the cumulative delay cap
    or the attempt limit is reached; the triggering API error is chained as
    ``__cause__``.
    """
91
+
92
+
93
def _calculate_retry_delay(
    attempt: int,
    base_delay: float = _RATE_LIMIT_BASE_DELAY,
    jitter_factor: float = _RATE_LIMIT_JITTER_FACTOR,
) -> float:
    """
    Return the sleep time for a retry: exponential backoff plus random jitter.

    The un-jittered delay doubles with every attempt
    (``base_delay * 2**attempt``); the result is then scaled by a random
    factor in ``[1 - jitter_factor, 1 + jitter_factor)``.

    Args:
        attempt: Zero-based retry attempt number
        base_delay: Delay in seconds for attempt 0, before jitter
        jitter_factor: Relative jitter amplitude (0.25 means +/- 25%)

    Returns:
        Delay in seconds, jitter included
    """
    backoff = base_delay * (2**attempt)
    # random.random() is in [0, 1), so `spread` falls in [-1, 1)
    spread = 2 * random.random() - 1
    return backoff + backoff * jitter_factor * spread
115
+
116
+
117
def _is_rate_limit_error(exc: Exception) -> bool:
    """
    Tell whether an exception represents an HTTP 429 (rate limit) response.

    ``NeonAPIError`` is tested first. Per the original notes it subclasses
    ``requests.HTTPError`` but keeps only the error text, not the response,
    so its message is searched for rate-limit markers. The phrase
    "too many requests" is matched rather than just "too many", so errors
    such as "too many connections" are not misclassified. A plain
    ``requests.HTTPError`` is judged by its response status code.

    Args:
        exc: The exception to inspect

    Returns:
        True for a rate-limit error, otherwise False
    """
    if isinstance(exc, NeonAPIError):
        message = str(exc).lower()
        markers = ("429", "rate limit", "too many requests")
        return any(marker in message for marker in markers)

    if not isinstance(exc, requests.HTTPError):
        return False

    response = exc.response
    return response is not None and response.status_code == 429
146
+
147
+
148
def _get_retry_after_from_error(exc: Exception) -> float | None:
    """
    Extract the Retry-After header value, in seconds, from an exception.

    Only a ``requests.HTTPError`` that carries a response object can provide
    the header. ``NeonAPIError`` does not keep the response, so for it this
    always returns None.

    Both header forms defined by RFC 9110 are accepted:
      * delay-seconds, e.g. ``Retry-After: 120``
      * HTTP-date, e.g. ``Retry-After: Wed, 21 Oct 2015 07:28:00 GMT``,
        which is converted to "seconds from now"

    Args:
        exc: The exception to check

    Returns:
        A finite, non-negative number of seconds, or None if the header is
        absent, unparseable, or not finite (NaN/inf would otherwise reach
        ``time.sleep`` or the delay-cap arithmetic in the caller).
    """
    import math
    from email.utils import parsedate_to_datetime

    if not isinstance(exc, requests.HTTPError) or exc.response is None:
        return None

    retry_after = exc.response.headers.get("Retry-After")
    if not retry_after:
        return None

    try:
        seconds = float(retry_after)
    except ValueError:
        # Not delay-seconds; try the HTTP-date form.
        try:
            when = parsedate_to_datetime(retry_after)
        except (TypeError, ValueError, IndexError, OverflowError):
            return None
        if when.tzinfo is None:
            # "-0000" zone parses as naive; HTTP-dates are always UTC.
            when = when.replace(tzinfo=timezone.utc)
        seconds = (when - datetime.now(timezone.utc)).total_seconds()

    if not math.isfinite(seconds):
        return None
    # A date in the past (or a negative number) means "retry now".
    return max(seconds, 0.0)
166
+
167
+
168
def _retry_on_rate_limit(
    operation: Callable[[], T],
    operation_name: str,
    base_delay: float = _RATE_LIMIT_BASE_DELAY,
    max_total_delay: float = _RATE_LIMIT_MAX_TOTAL_DELAY,
    jitter_factor: float = _RATE_LIMIT_JITTER_FACTOR,
    max_attempts: int = _RATE_LIMIT_MAX_ATTEMPTS,
) -> T:
    """
    Run ``operation``, retrying with backoff while it hits rate limits (429).

    The sleep before each retry is the server's Retry-After value when one is
    available (floored at 0.1s so a value of 0 cannot spin). Otherwise it is
    exponential backoff with jitter from ``_calculate_retry_delay``. Any
    other error propagates on the first failure.

    See: https://api-docs.neon.tech/reference/api-rate-limiting

    Args:
        operation: Zero-argument callable; may raise requests.HTTPError or
            NeonAPIError
        operation_name: Human-readable name used in error messages
        base_delay: Backoff delay in seconds for the first retry
        max_total_delay: Cap on cumulative sleep; exceeding it aborts
        jitter_factor: Relative jitter applied to backoff delays
        max_attempts: Maximum number of calls to ``operation`` (the initial
            call counts as one)

    Returns:
        Whatever ``operation`` returns on its first successful call

    Raises:
        NeonRateLimitError: The delay cap or the attempt limit was reached
            (the last rate-limit error is chained as the cause)
        requests.HTTPError / NeonAPIError: Any non-rate-limit API error
        Exception: Anything else raised by ``operation``
    """
    slept = 0.0
    failures = 0

    while True:
        try:
            return operation()
        except (requests.HTTPError, NeonAPIError) as err:
            if not _is_rate_limit_error(err):
                # Not a 429 - nothing to wait out.
                raise

            # Retry-After is honored if present (Neon may add it in future).
            server_hint = _get_retry_after_from_error(err)
            pause = (
                _calculate_retry_delay(failures, base_delay, jitter_factor)
                if server_hint is None
                else max(server_hint, 0.1)
            )

            if slept + pause > max_total_delay:
                raise NeonRateLimitError(
                    f"Rate limit exceeded for {operation_name}. "
                    f"Max total delay ({max_total_delay:.1f}s) reached after "
                    f"{failures + 1} attempts. "
                    "See: https://api-docs.neon.tech/reference/api-rate-limiting"
                ) from err

            failures += 1
            if failures >= max_attempts:
                raise NeonRateLimitError(
                    f"Rate limit exceeded for {operation_name}. "
                    f"Max attempts ({max_attempts}) reached after "
                    f"{slept:.1f}s total delay. "
                    "See: https://api-docs.neon.tech/reference/api-rate-limiting"
                ) from err

            time.sleep(pause)
            slept += pause
241
+
242
+
243
def _get_xdist_worker_id() -> str:
    """
    Return this process's pytest-xdist worker ID, or "main" without xdist.

    pytest-xdist exports PYTEST_XDIST_WORKER (gw0, gw1, ...) in each worker.
    The ID keeps per-worker branches separate so that parallel tests do not
    pollute each other's database state.
    """
    worker = os.environ.get("PYTEST_XDIST_WORKER")
    return "main" if worker is None else worker
252
+
253
+
254
def _sanitize_branch_name(name: str) -> str:
    """
    Make ``name`` safe to embed in a Neon branch name.

    Every character outside ASCII letters, digits, "_" and "-" (non-ASCII
    included) becomes a hyphen. Hyphen runs are then squeezed to a single
    hyphen, and hyphens at either end are dropped.
    """
    import re

    hyphenated = re.sub(r"[^a-zA-Z0-9_-]", "-", name)
    squeezed = re.sub(r"-+", "-", hyphenated)
    return squeezed.strip("-")
270
+
271
+
272
def _get_git_branch_name() -> str | None:
    """
    Get the current git branch name (sanitized), or None if unavailable.

    Used to include the git branch in Neon branch names, making it easier
    to identify which git branch/PR created orphaned test branches.

    Returns None when git is missing, fails, or times out, and when not
    inside a repository. It also returns None when HEAD is detached (for
    example in many CI checkouts): ``git rev-parse --abbrev-ref HEAD`` then
    prints the literal "HEAD", which identifies nothing. The name is passed
    through ``_sanitize_branch_name``; if nothing survives sanitization,
    None is returned rather than an empty string.
    """
    import subprocess

    try:
        result = subprocess.run(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            capture_output=True,
            text=True,
            timeout=5,
            check=False,
        )
    except (subprocess.SubprocessError, OSError):
        # git not installed/executable, or it hung past the timeout
        return None

    if result.returncode != 0:
        return None

    branch = result.stdout.strip()
    if not branch or branch == "HEAD":
        return None
    return _sanitize_branch_name(branch) or None
296
+
297
+
298
def _extract_password_from_connection_string(connection_string: str) -> str:
    """
    Extract the password from a PostgreSQL URI-style connection string.

    Expected format: ``postgresql://user:password@host/db?params``. The
    password is returned exactly as it appears in the URI (``urlparse`` does
    not percent-decode it), so it can be embedded into another URI unchanged.

    Raises:
        ValueError: If the URI has no password in its userinfo part. The
            message names only the host, never the connection string itself,
            so credentials carried elsewhere in it (e.g. a ``?password=``
            query parameter) are not leaked into logs or test reports.
    """
    from urllib.parse import urlparse

    parsed = urlparse(connection_string)
    if parsed.password:
        return parsed.password
    raise ValueError(
        f"No password found in connection string (host: {parsed.hostname!r})"
    )
307
+
308
+
309
def _get_schema_fingerprint(connection_string: str) -> tuple[tuple[Any, ...], ...]:
    """
    Get a fingerprint of the database schema for change detection.

    Queries ``information_schema.columns`` for every column in the ``public``
    schema. Each column contributes its table name, column name, data type,
    nullability, default and ordinal position, ordered by table and position.
    The result is a hashable tuple that can be compared before and after
    migrations to detect whether the schema actually changed.

    Only these column properties are captured. Changes to indexes,
    constraints, or objects outside the ``public`` schema do not affect the
    fingerprint.

    This is used to avoid creating unnecessary migration branches when
    no actual schema changes occurred.

    Returns:
        A tuple of row tuples, or ``()`` when neither psycopg (v3) nor
        psycopg2 is installed. NOTE(review): ``()`` compares equal to another
        ``()``, so callers must treat it as "unknown" rather than
        "unchanged" - confirm the caller does.
    """
    try:
        import psycopg
    except ImportError:
        try:
            import psycopg2 as psycopg  # type: ignore[import-not-found]
        except ImportError:
            # No driver available - can't fingerprint, assume migrations changed things
            return ()

    # Close explicitly: psycopg2's `with connection:` only commits/rolls back
    # the transaction and leaves the connection open, so relying on it would
    # leak one connection per call when psycopg2 is the driver in use.
    conn = psycopg.connect(connection_string)
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT table_name, column_name, data_type, is_nullable,
                       column_default, ordinal_position
                FROM information_schema.columns
                WHERE table_schema = 'public'
                ORDER BY table_name, ordinal_position
            """)
            rows = cur.fetchall()
    finally:
        conn.close()
    return tuple(tuple(row) for row in rows)
339
+
340
+
341
@dataclass
class NeonBranch:
    """Identity and connection details for one Neon test branch.

    Field order is part of the positional constructor signature.
    """

    # Neon branch identifier
    branch_id: str
    # Neon project the branch lives in
    project_id: str
    # postgresql:// URI for connecting through the branch's endpoint
    connection_string: str
    # Hostname of the endpoint that connection_string points at
    host: str
    # Branch this one was created from, when known
    parent_id: str | None = None
    # Endpoint behind connection_string, when recorded
    endpoint_id: str | None = None
351
+
352
+
353
@dataclass
class NeonConfig:
    """
    Configuration for Neon operations, extracted from pytest config.

    String settings are resolved through ``_get_config_value``
    (CLI option > environment variable > ini setting > default).
    """

    api_key: str
    project_id: str
    # None means "no explicit parent" (the project default is used)
    parent_branch_id: str | None
    database_name: str
    role_name: str
    # When True, branches are kept after tests instead of deleted
    keep_branches: bool
    # Branch auto-expiry in seconds; 0 disables expiry (per --neon-branch-expiry help)
    branch_expiry: int
    # Environment variable that receives the connection string
    env_var_name: str

    @classmethod
    def from_pytest_config(cls, config: pytest.Config) -> NeonConfig | None:
        """
        Extract NeonConfig from pytest configuration.

        Returns None if required values (api_key, project_id) are missing,
        allowing callers to skip tests gracefully.
        """
        api_key = _get_config_value(
            config, "neon_api_key", "NEON_API_KEY", "neon_api_key"
        )
        project_id = _get_config_value(
            config, "neon_project_id", "NEON_PROJECT_ID", "neon_project_id"
        )

        if not api_key or not project_id:
            return None

        parent_branch_id = _get_config_value(
            config, "neon_parent_branch", "NEON_PARENT_BRANCH_ID", "neon_parent_branch"
        )
        database_name = _get_config_value(
            config, "neon_database", "NEON_DATABASE", "neon_database", "neondb"
        )
        role_name = _get_config_value(
            config, "neon_role", "NEON_ROLE", "neon_role", "neondb_owner"
        )

        # --neon-keep-branches is a store_true flag, so getoption() returns
        # False (never None) when the flag is absent. An `is None` check
        # would therefore never fall through to the ini setting; combine the
        # two sources instead so either one can enable keeping branches.
        keep_branches = bool(
            config.getoption("neon_keep_branches", default=None)
        ) or bool(config.getini("neon_keep_branches"))

        branch_expiry = config.getoption("neon_branch_expiry", default=None)
        if branch_expiry is None:
            branch_expiry = int(config.getini("neon_branch_expiry"))

        env_var_name = _get_config_value(
            config, "neon_env_var", "", "neon_env_var", "DATABASE_URL"
        )

        return cls(
            api_key=api_key,
            project_id=project_id,
            parent_branch_id=parent_branch_id,
            database_name=database_name or "neondb",
            role_name=role_name or "neondb_owner",
            keep_branches=keep_branches,
            # Keep an explicit 0: it means "disable auto-expiry" and must not
            # be replaced by the default via `or`.
            branch_expiry=int(branch_expiry),
            env_var_name=env_var_name or "DATABASE_URL",
        )
416
+
417
+
418
class NeonBranchManager:
    """
    Manages Neon branch lifecycle operations.

    This class encapsulates all Neon API interactions for branch management,
    making it easier to test and reason about branch operations. API calls
    are wrapped in ``_retry_on_rate_limit`` so rate-limited (429) responses
    are retried with backoff.
    """

    def __init__(self, config: NeonConfig):
        self.config = config
        self._neon = NeonAPI(api_key=config.api_key)
        # Default-branch ID is looked up lazily and cached. The separate flag
        # lets a lookup that returned None be cached too (attempted only once).
        self._default_branch_id: str | None = None
        self._default_branch_id_fetched = False

    def get_default_branch_id(self) -> str | None:
        """Get the default/primary branch ID (fetched once, then cached)."""
        if not self._default_branch_id_fetched:
            self._default_branch_id = _get_default_branch_id(
                self._neon, self.config.project_id
            )
            self._default_branch_id_fetched = True
        return self._default_branch_id

    def create_branch(
        self,
        name_suffix: str = "",
        parent_branch_id: str | None = None,
        expiry_seconds: int | None = None,
    ) -> NeonBranch:
        """
        Create a new Neon branch with a read_write endpoint.

        Waits for the endpoint to become active, then resets the role
        password on the new branch to obtain credentials.

        Args:
            name_suffix: Suffix to add to branch name (e.g., "-migration", "-dirty")
            parent_branch_id: Parent branch ID (defaults to config's parent)
            expiry_seconds: Branch expiry in seconds (0 or None for no expiry)

        Returns:
            NeonBranch with connection details

        Raises:
            RuntimeError: If no endpoint was created, the endpoint did not
                become active in time, or the new branch is the project's
                default branch (safety check).
        """
        parent_id = parent_branch_id or self.config.parent_branch_id

        # Generate unique branch name. The git branch prefix helps trace
        # orphaned branches back to the git branch/PR that created them.
        random_suffix = os.urandom(2).hex()
        git_branch = _get_git_branch_name()
        if git_branch:
            git_prefix = git_branch[:15]
            branch_name = f"pytest-{git_prefix}-{random_suffix}{name_suffix}"
        else:
            branch_name = f"pytest-{random_suffix}{name_suffix}"

        # Build branch config
        branch_config: dict[str, Any] = {"name": branch_name}
        if parent_id:
            branch_config["parent_id"] = parent_id

        # Set expiry if specified
        if expiry_seconds and expiry_seconds > 0:
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=expiry_seconds)
            branch_config["expires_at"] = expires_at.strftime("%Y-%m-%dT%H:%M:%SZ")

        # Create branch with read_write endpoint
        result = _retry_on_rate_limit(
            lambda: self._neon.branch_create(
                project_id=self.config.project_id,
                branch=branch_config,
                endpoints=[{"type": "read_write"}],
            ),
            operation_name="branch_create",
        )

        branch = result.branch
        # The new endpoint's ID is reported on the create call's operations
        endpoint_id = None
        for op in result.operations:
            if op.endpoint_id:
                endpoint_id = op.endpoint_id
                break

        if not endpoint_id:
            raise RuntimeError(f"No endpoint created for branch {branch.id}")

        # Wait for endpoint to be active
        host = self._wait_for_endpoint(endpoint_id)

        # Safety check: never reset the password on (or otherwise operate
        # on) the project's default branch.
        default_branch_id = self.get_default_branch_id()
        if default_branch_id and branch.id == default_branch_id:
            raise RuntimeError(
                f"SAFETY CHECK FAILED: Attempted to operate on default branch "
                f"{branch.id}. Please report this bug."
            )

        connection_string = self._reset_password_and_build_connection_string(
            branch.id, host
        )

        return NeonBranch(
            branch_id=branch.id,
            project_id=self.config.project_id,
            connection_string=connection_string,
            host=host,
            parent_id=branch.parent_id,
            endpoint_id=endpoint_id,
        )

    def create_readonly_endpoint(self, branch: NeonBranch) -> NeonBranch:
        """
        Create a read_only endpoint on an existing branch.

        This creates a true read-only endpoint that enforces no writes at the
        database level.

        Args:
            branch: The branch to create the endpoint on

        Returns:
            NeonBranch with the read_only endpoint's connection details
        """
        result = _retry_on_rate_limit(
            lambda: self._neon.endpoint_create(
                project_id=self.config.project_id,
                endpoint={
                    "branch_id": branch.branch_id,
                    "type": "read_only",
                },
            ),
            operation_name="endpoint_create_readonly",
        )

        endpoint_id = result.endpoint.id
        host = self._wait_for_endpoint(endpoint_id)

        # Reuse the password from the parent branch's connection string.
        # DO NOT call role_password_reset here - it would invalidate the
        # password used by the parent branch's read_write endpoint, breaking
        # any existing connections (especially in xdist where other workers
        # may be using the cached connection string).
        # The extracted value is the URI component as written (urlparse does
        # not percent-decode), so it is embedded as-is without re-encoding.
        password = _extract_password_from_connection_string(branch.connection_string)
        connection_string = self._build_connection_string(host, password)

        return NeonBranch(
            branch_id=branch.branch_id,
            project_id=self.config.project_id,
            connection_string=connection_string,
            host=host,
            parent_id=branch.parent_id,
            endpoint_id=endpoint_id,
        )

    def delete_branch(self, branch_id: str) -> None:
        """Delete a branch; no-op when keep_branches is set.

        Failures are reported with ``warnings.warn`` instead of raised.
        """
        if self.config.keep_branches:
            return
        try:
            _retry_on_rate_limit(
                lambda: self._neon.branch_delete(
                    project_id=self.config.project_id, branch_id=branch_id
                ),
                operation_name="branch_delete",
            )
        except Exception as e:
            msg = f"Failed to delete Neon branch {branch_id}: {e}"
            warnings.warn(msg, stacklevel=2)

    def delete_endpoint(self, endpoint_id: str) -> None:
        """Delete an endpoint.

        Failures are reported with ``warnings.warn`` instead of raised.
        """
        try:
            _retry_on_rate_limit(
                lambda: self._neon.endpoint_delete(
                    project_id=self.config.project_id, endpoint_id=endpoint_id
                ),
                operation_name="endpoint_delete",
            )
        except Exception as e:
            warnings.warn(
                f"Failed to delete Neon endpoint {endpoint_id}: {e}", stacklevel=2
            )

    def reset_branch(self, branch: NeonBranch) -> None:
        """Reset a branch to its parent's state.

        Raises:
            RuntimeError: If the branch has no recorded parent.
        """
        if not branch.parent_id:
            msg = f"Branch {branch.branch_id} has no parent - cannot reset"
            raise RuntimeError(msg)

        _reset_branch_to_parent(branch, self.config.api_key)

    def _wait_for_endpoint(self, endpoint_id: str, max_wait_seconds: float = 60) -> str:
        """Poll until the endpoint is active and return its host.

        The timeout is measured against a monotonic clock, so time spent
        inside the status API calls (including rate-limit retries) counts
        toward ``max_wait_seconds``, not just the sleeps between polls.

        Raises:
            RuntimeError: If the endpoint is not active by the deadline.
        """
        poll_interval = 0.5
        deadline = time.monotonic() + max_wait_seconds

        while True:
            endpoint_response = _retry_on_rate_limit(
                lambda: self._neon.endpoint(
                    project_id=self.config.project_id, endpoint_id=endpoint_id
                ),
                operation_name="endpoint_status",
            )
            endpoint = endpoint_response.endpoint
            state = endpoint.current_state

            if state == EndpointState.active:
                return endpoint.host

            if time.monotonic() >= deadline:
                raise RuntimeError(
                    f"Timeout waiting for endpoint {endpoint_id} to become active "
                    f"(current state: {state})"
                )

            time.sleep(poll_interval)

    def _reset_password_and_build_connection_string(
        self, branch_id: str, host: str
    ) -> str:
        """Reset the role password on ``branch_id`` and build its connection URI."""
        from urllib.parse import quote

        password_response = _retry_on_rate_limit(
            lambda: self._neon.role_password_reset(
                project_id=self.config.project_id,
                branch_id=branch_id,
                role_name=self.config.role_name,
            ),
            operation_name="role_password_reset",
        )
        # Percent-encode so reserved characters (@, :, /, ...) cannot corrupt
        # the URI; a no-op for purely alphanumeric passwords.
        password = quote(password_response.role.password, safe="")
        return self._build_connection_string(host, password)

    def _build_connection_string(self, host: str, encoded_password: str) -> str:
        """Build a ``postgresql://`` URI for the configured role and database.

        ``encoded_password`` must already be percent-encoded for the URI
        userinfo component. The role and database names are encoded here.
        """
        from urllib.parse import quote

        role = quote(self.config.role_name, safe="")
        database = quote(self.config.database_name, safe="")
        return (
            f"postgresql://{role}:{encoded_password}@{host}/"
            f"{database}?sslmode=require"
        )
653
+
654
+
655
class XdistCoordinator:
    """
    Share expensive resources between pytest-xdist worker processes.

    Under xdist, a per-resource ``FileLock`` guards a JSON cache file. The
    first worker to take the lock creates the resource and writes it to the
    cache; every later worker reads the cached data instead of creating its
    own. Signal files provide a simple "done" notification between workers.
    Outside xdist, everything reduces to direct calls.
    """

    def __init__(self, tmp_path_factory: pytest.TempPathFactory):
        self.worker_id = _get_xdist_worker_id()
        self.is_xdist = self.worker_id != "main"
        # Lock, cache, and signal files live in the parent of this worker's
        # base temp dir. Coordination relies on that directory being common
        # to every worker in the session.
        self._lock_dir = (
            tmp_path_factory.getbasetemp().parent if self.is_xdist else None
        )

    def coordinate_resource(
        self,
        resource_name: str,
        create_fn: Callable[[], dict[str, Any]],
    ) -> tuple[dict[str, Any], bool]:
        """
        Create a shared resource once and hand the same data to every worker.

        Args:
            resource_name: Names the ``neon_<name>.json`` cache file and the
                ``neon_<name>.lock`` lock file
            create_fn: Factory that creates the resource and returns a
                JSON-serializable dict describing it

        Returns:
            ``(data, is_creator)``. ``is_creator`` is True when this call ran
            ``create_fn`` (always the case outside xdist), and False when the
            data was read from the cache file.
        """
        if not self.is_xdist:
            return create_fn(), True

        assert self._lock_dir is not None
        cache_path = self._lock_dir / f"neon_{resource_name}.json"
        lock_path = self._lock_dir / f"neon_{resource_name}.lock"

        with FileLock(str(lock_path)):
            if not cache_path.exists():
                created = create_fn()
                cache_path.write_text(json.dumps(created))
                return created, True
            return json.loads(cache_path.read_text()), False

    def wait_for_signal(self, signal_name: str, timeout: float = 60) -> None:
        """Block until another worker creates the ``neon_<signal_name>`` file.

        Returns immediately outside xdist. Polls every 0.5s.

        Raises:
            RuntimeError: If the file has not appeared after ``timeout``
                seconds of polling.
        """
        if not (self.is_xdist and self._lock_dir is not None):
            return

        marker = self._lock_dir / f"neon_{signal_name}"
        poll = 0.5
        elapsed = 0.0

        while not marker.exists():
            if elapsed >= timeout:
                raise RuntimeError(
                    f"Worker {self.worker_id} timed out waiting for signal "
                    f"'{signal_name}' after {timeout}s. This usually means the "
                    f"creator worker failed or is still processing."
                )
            time.sleep(poll)
            elapsed += poll

    def send_signal(self, signal_name: str) -> None:
        """Create the ``neon_<signal_name>`` file (no-op outside xdist)."""
        if not (self.is_xdist and self._lock_dir is not None):
            return

        marker = self._lock_dir / f"neon_{signal_name}"
        marker.write_text("done")
730
+
731
+
732
class EnvironmentManager:
    """Set an environment variable (DATABASE_URL by default) and undo it later."""

    def __init__(self, env_var_name: str = "DATABASE_URL"):
        self.env_var_name = env_var_name
        # Value seen before the first set(); None means the variable was unset.
        self._original_value: str | None = None
        # True from the first set() until the matching restore().
        self._is_set = False

    def set(self, connection_string: str) -> None:
        """Assign the variable; only the first call records the prior value."""
        name = self.env_var_name
        if not self._is_set:
            self._original_value = os.environ.get(name)
            self._is_set = True
        os.environ[name] = connection_string

    def restore(self) -> None:
        """Put back the value recorded by the first set(); no-op if not set."""
        if self._is_set:
            previous = self._original_value
            if previous is not None:
                os.environ[self.env_var_name] = previous
            else:
                os.environ.pop(self.env_var_name, None)
            self._is_set = False

    @contextlib.contextmanager
    def temporary(self, connection_string: str) -> Generator[None, None, None]:
        """Set the variable for the duration of a ``with`` block, then restore().

        Because restore() returns to the value from before the *first* set(),
        exiting undoes any earlier set() on this manager as well.
        """
        self.set(connection_string)
        try:
            yield
        finally:
            self.restore()
767
+
768
+
769
def _get_default_branch_id(neon: NeonAPI, project_id: str) -> str | None:
    """
    Look up the ID of the project's default (primary) branch.

    Serves as a safety check so that destructive operations, such as a
    password reset, are never aimed at the production branch. A branch
    counts as default if either its ``default`` or its ``primary`` attribute
    is truthy (both are checked for compatibility).

    Any failure, including exhausted rate-limit retries, is deliberately
    swallowed: the safety check is then skipped, but tests still run.

    Returns:
        The default branch's ID, or None if none was found or the lookup failed.
    """
    try:
        listing = _retry_on_rate_limit(
            lambda: neon.branches(project_id=project_id),
            operation_name="list_branches",
        )
        return next(
            (
                candidate.id
                for candidate in listing.branches
                if getattr(candidate, "default", False)
                or getattr(candidate, "primary", False)
            ),
            None,
        )
    except Exception:
        return None
796
+
797
+
798
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the plugin's ``--neon-*`` command-line flags and ini keys.

    Flags go into a "neon" option group; the ini keys can be set in
    pytest.ini, pyproject.toml, etc.
    """
    group = parser.getgroup("neon", "Neon database branching")

    # Plain string-valued flags: (flag, dest, help)
    string_flags = (
        (
            "--neon-api-key",
            "neon_api_key",
            "Neon API key (default: NEON_API_KEY env var)",
        ),
        (
            "--neon-project-id",
            "neon_project_id",
            "Neon project ID (default: NEON_PROJECT_ID env var)",
        ),
        (
            "--neon-parent-branch",
            "neon_parent_branch",
            "Parent branch ID to create test branches from (default: project default)",
        ),
        (
            "--neon-database",
            "neon_database",
            "Database name (default: neondb)",
        ),
        (
            "--neon-role",
            "neon_role",
            "Database role (default: neondb_owner)",
        ),
    )
    for flag, dest, help_text in string_flags:
        group.addoption(flag, dest=dest, help=help_text)

    group.addoption(
        "--neon-keep-branches",
        dest="neon_keep_branches",
        action="store_true",
        help="Don't delete branches after tests (useful for debugging)",
    )
    expiry_help = (
        f"Branch auto-expiry in seconds "
        f"(default: {DEFAULT_BRANCH_EXPIRY_SECONDS}). Set to 0 to disable."
    )
    group.addoption(
        "--neon-branch-expiry",
        dest="neon_branch_expiry",
        type=int,
        help=expiry_help,
    )
    group.addoption(
        "--neon-env-var",
        dest="neon_env_var",
        help="Environment variable to set with connection string (default: DATABASE_URL)",  # noqa: E501
    )

    # Plain string-valued ini keys: (name, help, default)
    string_ini_keys = (
        ("neon_api_key", "Neon API key", None),
        ("neon_project_id", "Neon project ID", None),
        ("neon_parent_branch", "Parent branch ID", None),
        ("neon_database", "Database name", "neondb"),
        ("neon_role", "Database role", "neondb_owner"),
    )
    for ini_name, help_text, ini_default in string_ini_keys:
        parser.addini(ini_name, help_text, default=ini_default)

    parser.addini(
        "neon_keep_branches",
        "Don't delete branches after tests",
        type="bool",
        default=False,
    )
    parser.addini(
        "neon_branch_expiry",
        "Branch auto-expiry in seconds",
        default=str(DEFAULT_BRANCH_EXPIRY_SECONDS),
    )
    parser.addini(
        "neon_env_var",
        "Environment variable for connection string",
        default="DATABASE_URL",
    )
871
+
872
+
873
def _get_config_value(
    config: pytest.Config,
    option: str,
    env_var: str,
    ini_name: str | None = None,
    default: str | None = None,
) -> str | None:
    """Get config value from CLI option, env var, ini setting, or default.

    Priority order: CLI option > environment variable > ini setting > default

    Empty strings from the environment are treated as "not set", matching how
    empty ini values are already treated. This way an exported-but-empty
    variable (e.g. ``NEON_API_KEY=`` from a CI template) does not mask a value
    configured in the ini file. An empty ``env_var`` name skips the
    environment lookup entirely.

    Args:
        config: The pytest config object
        option: ``dest`` name of the CLI option
        env_var: Environment variable to consult ("" to skip)
        ini_name: Ini key to consult (None to skip)
        default: Value returned when no source provides one

    Returns:
        The first non-empty value found, else ``default``
    """
    # 1. CLI option (highest priority)
    value = config.getoption(option, default=None)
    if value is not None:
        return value

    # 2. Environment variable
    if env_var:
        env_value = os.environ.get(env_var)
        if env_value:
            return env_value

    # 3. INI setting (pytest.ini, pyproject.toml, etc.)
    if ini_name is not None:
        ini_value = config.getini(ini_name)
        if ini_value:
            return ini_value

    # 4. Default
    return default
902
+
903
+
904
+ def _create_neon_branch(
905
+ request: pytest.FixtureRequest,
906
+ parent_branch_id_override: str | None = None,
907
+ branch_expiry_override: int | None = None,
908
+ branch_name_suffix: str = "",
909
+ ) -> Generator[NeonBranch, None, None]:
910
+ """
911
+ Internal helper that creates and manages a Neon branch lifecycle.
912
+
913
+ This is the core implementation used by branch fixtures.
914
+
915
+ Args:
916
+ request: Pytest fixture request
917
+ parent_branch_id_override: If provided, use this as parent instead of config
918
+ branch_expiry_override: If provided, use this expiry instead of config
919
+ branch_name_suffix: Optional suffix for branch name (e.g., "-migrated", "-test")
920
+ """
921
+ config = request.config
922
+
923
+ api_key = _get_config_value(config, "neon_api_key", "NEON_API_KEY", "neon_api_key")
924
+ project_id = _get_config_value(
925
+ config, "neon_project_id", "NEON_PROJECT_ID", "neon_project_id"
926
+ )
927
+ # Use override if provided, otherwise read from config
928
+ parent_branch_id = parent_branch_id_override or _get_config_value(
929
+ config, "neon_parent_branch", "NEON_PARENT_BRANCH_ID", "neon_parent_branch"
930
+ )
931
+ database_name = _get_config_value(
932
+ config, "neon_database", "NEON_DATABASE", "neon_database", "neondb"
933
+ )
934
+ role_name = _get_config_value(
935
+ config, "neon_role", "NEON_ROLE", "neon_role", "neondb_owner"
936
+ )
937
+
938
+ # For boolean/int options, check CLI first, then ini
939
+ keep_branches = config.getoption("neon_keep_branches", default=None)
940
+ if keep_branches is None:
941
+ keep_branches = config.getini("neon_keep_branches")
942
+
943
+ # Use override if provided, otherwise read from config
944
+ if branch_expiry_override is not None:
945
+ branch_expiry = branch_expiry_override
946
+ else:
947
+ branch_expiry = config.getoption("neon_branch_expiry", default=None)
948
+ if branch_expiry is None:
949
+ branch_expiry = int(config.getini("neon_branch_expiry"))
950
+
951
+ env_var_name = _get_config_value(
952
+ config, "neon_env_var", "", "neon_env_var", "DATABASE_URL"
953
+ )
954
+
955
+ if not api_key:
956
+ pytest.skip(
957
+ "Neon API key not configured (set NEON_API_KEY or use --neon-api-key)"
958
+ )
959
+ if not project_id:
960
+ pytest.skip(
961
+ "Neon project ID not configured "
962
+ "(set NEON_PROJECT_ID or use --neon-project-id)"
963
+ )
964
+
965
+ neon = NeonAPI(api_key=api_key)
966
+
967
+ # Cache the default branch ID for safety checks (only fetch once per session)
968
+ if not hasattr(config, "_neon_default_branch_id"):
969
+ config._neon_default_branch_id = _get_default_branch_id(neon, project_id) # type: ignore[attr-defined]
970
+
971
+ # Generate unique branch name
972
+ # Format: pytest-[git branch (first 15 chars)]-[random]-[suffix]
973
+ # This helps identify orphaned branches by showing which git branch created them
974
+ random_suffix = os.urandom(2).hex() # 2 bytes = 4 hex chars
975
+ git_branch = _get_git_branch_name()
976
+ if git_branch:
977
+ # Truncate git branch to 15 chars to keep branch names reasonable
978
+ git_prefix = git_branch[:15]
979
+ branch_name = f"pytest-{git_prefix}-{random_suffix}{branch_name_suffix}"
980
+ else:
981
+ branch_name = f"pytest-{random_suffix}{branch_name_suffix}"
982
+
983
+ # Build branch creation payload
984
+ branch_config: dict[str, Any] = {"name": branch_name}
985
+ if parent_branch_id:
986
+ branch_config["parent_id"] = parent_branch_id
987
+
988
+ # Set branch expiration (auto-delete) as a safety net for interrupted test runs
989
+ # This uses the branch expires_at field, not endpoint suspend_timeout
990
+ if branch_expiry and branch_expiry > 0:
991
+ expires_at = datetime.now(timezone.utc) + timedelta(seconds=branch_expiry)
992
+ branch_config["expires_at"] = expires_at.strftime("%Y-%m-%dT%H:%M:%SZ")
993
+
994
+ # Create branch with compute endpoint
995
+ # Wrap in retry logic to handle rate limits
996
+ # See: https://api-docs.neon.tech/reference/api-rate-limiting
997
+ result = _retry_on_rate_limit(
998
+ lambda: neon.branch_create(
999
+ project_id=project_id,
1000
+ branch=branch_config,
1001
+ endpoints=[{"type": "read_write"}],
1002
+ ),
1003
+ operation_name="branch_create",
1004
+ )
1005
+
1006
+ branch = result.branch
1007
+
1008
+ # Get endpoint_id from operations
1009
+ # (branch_create returns operations, not endpoints directly)
1010
+ endpoint_id = None
1011
+ for op in result.operations:
1012
+ if op.endpoint_id:
1013
+ endpoint_id = op.endpoint_id
1014
+ break
1015
+
1016
+ if not endpoint_id:
1017
+ raise RuntimeError(f"No endpoint created for branch {branch.id}")
1018
+
1019
+ # Wait for endpoint to be ready (it starts in "init" state)
1020
+ # Endpoints typically become active in 1-2 seconds, but we allow up to 60s
1021
+ # to handle occasional Neon API slowness or high load scenarios
1022
+ max_wait_seconds = 60
1023
+ poll_interval = 0.5 # Poll every 500ms for responsive feedback
1024
+ waited = 0.0
1025
+
1026
+ while True:
1027
+ # Wrap in retry logic to handle rate limits during polling
1028
+ # See: https://api-docs.neon.tech/reference/api-rate-limiting
1029
+ endpoint_response = _retry_on_rate_limit(
1030
+ lambda: neon.endpoint(project_id=project_id, endpoint_id=endpoint_id),
1031
+ operation_name="endpoint_status",
1032
+ )
1033
+ endpoint = endpoint_response.endpoint
1034
+ state = endpoint.current_state
1035
+
1036
+ if state == EndpointState.active:
1037
+ break
1038
+
1039
+ if waited >= max_wait_seconds:
1040
+ raise RuntimeError(
1041
+ f"Timeout waiting for endpoint {endpoint_id} to become active "
1042
+ f"(current state: {state})"
1043
+ )
1044
+
1045
+ time.sleep(poll_interval)
1046
+ waited += poll_interval
1047
+
1048
+ host = endpoint.host
1049
+
1050
+ # SAFETY CHECK: Ensure we never reset password on the default/production branch
1051
+ # This should be impossible since we just created this branch, but we check
1052
+ # defensively to prevent catastrophic mistakes if there's ever a bug
1053
+ default_branch_id = getattr(config, "_neon_default_branch_id", None)
1054
+ if default_branch_id and branch.id == default_branch_id:
1055
+ raise RuntimeError(
1056
+ f"SAFETY CHECK FAILED: Attempted to reset password on default branch "
1057
+ f"{branch.id}. This should never happen - the plugin creates new "
1058
+ f"branches and should never operate on the default branch. "
1059
+ f"Please report this bug at https://github.com/ZainRizvi/pytest-neon/issues"
1060
+ )
1061
+
1062
+ # Reset password to get the password value
1063
+ # (newly created branches don't expose password)
1064
+ # Wrap in retry logic to handle rate limits
1065
+ # See: https://api-docs.neon.tech/reference/api-rate-limiting
1066
+ password_response = _retry_on_rate_limit(
1067
+ lambda: neon.role_password_reset(
1068
+ project_id=project_id,
1069
+ branch_id=branch.id,
1070
+ role_name=role_name,
1071
+ ),
1072
+ operation_name="role_password_reset",
1073
+ )
1074
+ password = password_response.role.password
1075
+
1076
+ # Build connection string
1077
+ connection_string = (
1078
+ f"postgresql://{role_name}:{password}@{host}/{database_name}?sslmode=require"
1079
+ )
1080
+
1081
+ neon_branch_info = NeonBranch(
1082
+ branch_id=branch.id,
1083
+ project_id=project_id,
1084
+ connection_string=connection_string,
1085
+ host=host,
1086
+ parent_id=branch.parent_id,
1087
+ endpoint_id=endpoint_id,
1088
+ )
1089
+
1090
+ # Set DATABASE_URL (or configured env var) for the duration of the fixture scope
1091
+ original_env_value = os.environ.get(env_var_name)
1092
+ os.environ[env_var_name] = connection_string
1093
+
1094
+ try:
1095
+ yield neon_branch_info
1096
+ finally:
1097
+ # Restore original env var
1098
+ if original_env_value is None:
1099
+ os.environ.pop(env_var_name, None)
1100
+ else:
1101
+ os.environ[env_var_name] = original_env_value
1102
+
1103
+ # Cleanup: delete branch unless --neon-keep-branches was specified
1104
+ if not keep_branches:
1105
+ try:
1106
+ # Wrap in retry logic to handle rate limits
1107
+ # See: https://api-docs.neon.tech/reference/api-rate-limiting
1108
+ _retry_on_rate_limit(
1109
+ lambda: neon.branch_delete(
1110
+ project_id=project_id, branch_id=branch.id
1111
+ ),
1112
+ operation_name="branch_delete",
1113
+ )
1114
+ except Exception as e:
1115
+ # Log but don't fail tests due to cleanup issues
1116
+ warnings.warn(
1117
+ f"Failed to delete Neon branch {branch.id}: {e}",
1118
+ stacklevel=2,
1119
+ )
1120
+
1121
+
1122
+ def _create_readonly_endpoint(
1123
+ branch: NeonBranch,
1124
+ api_key: str,
1125
+ database_name: str,
1126
+ role_name: str,
1127
+ ) -> NeonBranch:
1128
+ """
1129
+ Create a read_only endpoint on an existing branch.
1130
+
1131
+ Returns a new NeonBranch object with the read_only endpoint's connection string.
1132
+ The read_only endpoint enforces that no writes can be made through this connection.
1133
+
1134
+ Args:
1135
+ branch: The branch to create a read_only endpoint on
1136
+ api_key: Neon API key
1137
+ database_name: Database name for connection string
1138
+ role_name: Role name for connection string
1139
+
1140
+ Returns:
1141
+ NeonBranch with read_only endpoint connection details
1142
+ """
1143
+ neon = NeonAPI(api_key=api_key)
1144
+
1145
+ # Create read_only endpoint on the branch
1146
+ # See: https://api-docs.neon.tech/reference/createprojectendpoint
1147
+ result = _retry_on_rate_limit(
1148
+ lambda: neon.endpoint_create(
1149
+ project_id=branch.project_id,
1150
+ endpoint={
1151
+ "branch_id": branch.branch_id,
1152
+ "type": "read_only",
1153
+ },
1154
+ ),
1155
+ operation_name="endpoint_create_readonly",
1156
+ )
1157
+
1158
+ endpoint = result.endpoint
1159
+ endpoint_id = endpoint.id
1160
+
1161
+ # Wait for endpoint to be ready
1162
+ max_wait_seconds = 60
1163
+ poll_interval = 0.5
1164
+ waited = 0.0
1165
+
1166
+ while True:
1167
+ endpoint_response = _retry_on_rate_limit(
1168
+ lambda: neon.endpoint(
1169
+ project_id=branch.project_id, endpoint_id=endpoint_id
1170
+ ),
1171
+ operation_name="endpoint_status_readonly",
1172
+ )
1173
+ endpoint = endpoint_response.endpoint
1174
+ state = endpoint.current_state
1175
+
1176
+ if state == EndpointState.active:
1177
+ break
1178
+
1179
+ if waited >= max_wait_seconds:
1180
+ raise RuntimeError(
1181
+ f"Timeout waiting for read_only endpoint {endpoint_id} "
1182
+ f"to become active (current state: {state})"
1183
+ )
1184
+
1185
+ time.sleep(poll_interval)
1186
+ waited += poll_interval
1187
+
1188
+ host = endpoint.host
1189
+
1190
+ # Reuse the password from the parent branch's connection string.
1191
+ # DO NOT call role_password_reset here - it would invalidate the
1192
+ # password used by the parent branch's read_write endpoint.
1193
+ password = _extract_password_from_connection_string(branch.connection_string)
1194
+
1195
+ # Build connection string for the read_only endpoint
1196
+ connection_string = (
1197
+ f"postgresql://{role_name}:{password}@{host}/{database_name}?sslmode=require"
1198
+ )
1199
+
1200
+ return NeonBranch(
1201
+ branch_id=branch.branch_id,
1202
+ project_id=branch.project_id,
1203
+ connection_string=connection_string,
1204
+ host=host,
1205
+ parent_id=branch.parent_id,
1206
+ endpoint_id=endpoint_id,
1207
+ )
1208
+
1209
+
1210
+ def _delete_endpoint(project_id: str, endpoint_id: str, api_key: str) -> None:
1211
+ """Delete a Neon endpoint."""
1212
+ neon = NeonAPI(api_key=api_key)
1213
+ try:
1214
+ _retry_on_rate_limit(
1215
+ lambda: neon.endpoint_delete(
1216
+ project_id=project_id, endpoint_id=endpoint_id
1217
+ ),
1218
+ operation_name="endpoint_delete",
1219
+ )
1220
+ except Exception as e:
1221
+ warnings.warn(
1222
+ f"Failed to delete Neon endpoint {endpoint_id}: {e}",
1223
+ stacklevel=2,
1224
+ )
1225
+
1226
+
1227
+ def _reset_branch_to_parent(branch: NeonBranch, api_key: str) -> None:
1228
+ """Reset a branch to its parent's state using the Neon API.
1229
+
1230
+ Uses exponential backoff retry logic with jitter to handle rate limit (429)
1231
+ errors. After initiating the restore, polls the operation status until it
1232
+ completes.
1233
+
1234
+ See: https://api-docs.neon.tech/reference/api-rate-limiting
1235
+
1236
+ Args:
1237
+ branch: The branch to reset
1238
+ api_key: Neon API key
1239
+ """
1240
+ if not branch.parent_id:
1241
+ raise RuntimeError(f"Branch {branch.branch_id} has no parent - cannot reset")
1242
+
1243
+ base_url = "https://console.neon.tech/api/v2"
1244
+ project_id = branch.project_id
1245
+ branch_id = branch.branch_id
1246
+ restore_url = f"{base_url}/projects/{project_id}/branches/{branch_id}/restore"
1247
+ headers = {
1248
+ "Authorization": f"Bearer {api_key}",
1249
+ "Content-Type": "application/json",
1250
+ }
1251
+
1252
+ def do_restore() -> dict[str, Any]:
1253
+ response = requests.post(
1254
+ restore_url,
1255
+ headers=headers,
1256
+ json={"source_branch_id": branch.parent_id},
1257
+ timeout=30,
1258
+ )
1259
+ response.raise_for_status()
1260
+ return response.json()
1261
+
1262
+ # Wrap in retry logic to handle rate limits
1263
+ # See: https://api-docs.neon.tech/reference/api-rate-limiting
1264
+ data = _retry_on_rate_limit(do_restore, operation_name="branch_restore")
1265
+ operations = data.get("operations", [])
1266
+
1267
+ # The restore API returns operations that run asynchronously.
1268
+ # We must wait for operations to complete before the next test
1269
+ # starts, otherwise connections may fail during the restore.
1270
+ if operations:
1271
+ _wait_for_operations(
1272
+ project_id=branch.project_id,
1273
+ operations=operations,
1274
+ headers=headers,
1275
+ base_url=base_url,
1276
+ )
1277
+
1278
+
1279
+ def _wait_for_operations(
1280
+ project_id: str,
1281
+ operations: list[dict[str, Any]],
1282
+ headers: dict[str, str],
1283
+ base_url: str,
1284
+ max_wait_seconds: float = 60,
1285
+ poll_interval: float = 0.5,
1286
+ ) -> None:
1287
+ """Wait for Neon operations to complete.
1288
+
1289
+ Handles rate limit (429) errors with exponential backoff retry.
1290
+ See: https://api-docs.neon.tech/reference/api-rate-limiting
1291
+
1292
+ Args:
1293
+ project_id: The Neon project ID
1294
+ operations: List of operation dicts from the API response
1295
+ headers: HTTP headers including auth
1296
+ base_url: Base URL for Neon API
1297
+ max_wait_seconds: Maximum time to wait (default: 60s)
1298
+ poll_interval: Time between polls (default: 0.5s)
1299
+ """
1300
+ # Get operation IDs that aren't already finished
1301
+ pending_op_ids = [
1302
+ op["id"] for op in operations if op.get("status") not in ("finished", "skipped")
1303
+ ]
1304
+
1305
+ if not pending_op_ids:
1306
+ return # All operations already complete
1307
+
1308
+ waited = 0.0
1309
+ first_poll = True
1310
+ while pending_op_ids and waited < max_wait_seconds:
1311
+ # Poll immediately first time (operation usually completes instantly),
1312
+ # then wait between subsequent polls
1313
+ if first_poll:
1314
+ time.sleep(0.1) # Tiny delay to let operation start
1315
+ waited += 0.1
1316
+ first_poll = False
1317
+ else:
1318
+ time.sleep(poll_interval)
1319
+ waited += poll_interval
1320
+
1321
+ # Check status of each pending operation
1322
+ still_pending = []
1323
+ for op_id in pending_op_ids:
1324
+ op_url = f"{base_url}/projects/{project_id}/operations/{op_id}"
1325
+
1326
+ def get_operation_status(url: str = op_url) -> dict[str, Any]:
1327
+ """Fetch operation status. Default arg captures url by value."""
1328
+ response = requests.get(url, headers=headers, timeout=10)
1329
+ response.raise_for_status()
1330
+ return response.json()
1331
+
1332
+ try:
1333
+ # Wrap in retry logic to handle rate limits
1334
+ # See: https://api-docs.neon.tech/reference/api-rate-limiting
1335
+ result = _retry_on_rate_limit(
1336
+ get_operation_status,
1337
+ operation_name=f"operation_status({op_id})",
1338
+ )
1339
+ op_data = result.get("operation", {})
1340
+ status = op_data.get("status")
1341
+
1342
+ if status == "failed":
1343
+ err = op_data.get("error", "unknown error")
1344
+ raise RuntimeError(f"Operation {op_id} failed: {err}")
1345
+ if status not in ("finished", "skipped", "cancelled"):
1346
+ still_pending.append(op_id)
1347
+ except requests.RequestException:
1348
+ # On network error (non-429), assume still pending and retry
1349
+ still_pending.append(op_id)
1350
+
1351
+ pending_op_ids = still_pending
1352
+
1353
+ if pending_op_ids:
1354
+ raise RuntimeError(
1355
+ f"Timeout waiting for operations to complete: {pending_op_ids}"
1356
+ )
1357
+
1358
+
1359
+ def _branch_to_dict(branch: NeonBranch) -> dict[str, Any]:
1360
+ """Convert NeonBranch to a JSON-serializable dict."""
1361
+ return asdict(branch)
1362
+
1363
+
1364
+ def _dict_to_branch(data: dict[str, Any]) -> NeonBranch:
1365
+ """Convert a dict back to NeonBranch."""
1366
+ return NeonBranch(**data)
1367
+
1368
+
1369
+ # Timeout for waiting for migrations to complete (seconds)
1370
+ _MIGRATION_WAIT_TIMEOUT = 300 # 5 minutes
1371
+
1372
+
1373
+ @pytest.fixture(scope="session")
1374
+ def _neon_config(request: pytest.FixtureRequest) -> NeonConfig:
1375
+ """
1376
+ Session-scoped Neon configuration extracted from pytest config.
1377
+
1378
+ Skips tests if required configuration (api_key, project_id) is missing.
1379
+ """
1380
+ config = NeonConfig.from_pytest_config(request.config)
1381
+ if config is None:
1382
+ pytest.skip(
1383
+ "Neon configuration missing. Set NEON_API_KEY and NEON_PROJECT_ID "
1384
+ "environment variables or use --neon-api-key and --neon-project-id."
1385
+ )
1386
+ return config
1387
+
1388
+
1389
+ @pytest.fixture(scope="session")
1390
+ def _neon_branch_manager(_neon_config: NeonConfig) -> NeonBranchManager:
1391
+ """Session-scoped branch manager for Neon operations."""
1392
+ return NeonBranchManager(_neon_config)
1393
+
1394
+
1395
+ @pytest.fixture(scope="session")
1396
+ def _neon_xdist_coordinator(
1397
+ tmp_path_factory: pytest.TempPathFactory,
1398
+ ) -> XdistCoordinator:
1399
+ """Session-scoped coordinator for xdist worker synchronization."""
1400
+ return XdistCoordinator(tmp_path_factory)
1401
+
1402
+
1403
+ @pytest.fixture(scope="session")
1404
+ def _neon_migration_branch(
1405
+ request: pytest.FixtureRequest,
1406
+ _neon_config: NeonConfig,
1407
+ _neon_branch_manager: NeonBranchManager,
1408
+ _neon_xdist_coordinator: XdistCoordinator,
1409
+ ) -> Generator[NeonBranch, None, None]:
1410
+ """
1411
+ Session-scoped branch where migrations are applied.
1412
+
1413
+ This branch is ALWAYS created from the configured parent and serves as
1414
+ the parent for all test branches (dirty, isolated, readonly endpoint).
1415
+ Migrations run once per session on this branch.
1416
+
1417
+ pytest-xdist Support:
1418
+ When running with pytest-xdist, the first worker to acquire the lock
1419
+ creates the migration branch. Other workers wait for migrations to
1420
+ complete, then reuse the same branch. This avoids redundant API calls
1421
+ and ensures migrations only run once. Only the creator cleans up the
1422
+ branch at session end.
1423
+
1424
+ Note: The migration branch cannot have an expiry because Neon doesn't
1425
+ allow creating child branches from branches with expiration dates.
1426
+ Cleanup relies on the fixture teardown at session end.
1427
+ """
1428
+ env_manager = EnvironmentManager(_neon_config.env_var_name)
1429
+ branch: NeonBranch
1430
+ is_creator: bool
1431
+
1432
+ def create_migration_branch() -> dict[str, Any]:
1433
+ b = _neon_branch_manager.create_branch(
1434
+ name_suffix="-migration",
1435
+ expiry_seconds=0, # No expiry - child branches need this
1436
+ )
1437
+ return {"branch": _branch_to_dict(b)}
1438
+
1439
+ # Coordinate branch creation across xdist workers
1440
+ data, is_creator = _neon_xdist_coordinator.coordinate_resource(
1441
+ "migration_branch", create_migration_branch
1442
+ )
1443
+ branch = _dict_to_branch(data["branch"])
1444
+
1445
+ # Store creator status for other fixtures
1446
+ request.config._neon_is_migration_creator = is_creator # type: ignore[attr-defined]
1447
+
1448
+ # Set DATABASE_URL
1449
+ env_manager.set(branch.connection_string)
1450
+
1451
+ # Non-creators wait for migrations to complete
1452
+ if not is_creator:
1453
+ _neon_xdist_coordinator.wait_for_signal(
1454
+ "migrations_done", timeout=_MIGRATION_WAIT_TIMEOUT
1455
+ )
1456
+
1457
+ try:
1458
+ yield branch
1459
+ finally:
1460
+ env_manager.restore()
1461
+ # Only creator cleans up
1462
+ if is_creator:
1463
+ _neon_branch_manager.delete_branch(branch.branch_id)
1464
+
1465
+
1466
+ @pytest.fixture(scope="session")
1467
+ def neon_apply_migrations(_neon_migration_branch: NeonBranch) -> Any:
1468
+ """
1469
+ Override this fixture to run migrations on the test database.
1470
+
1471
+ The migration branch is already created and DATABASE_URL is set.
1472
+ Migrations run once per test session, before any tests execute.
1473
+
1474
+ pytest-xdist Support:
1475
+ When running with pytest-xdist, migrations only run on the first
1476
+ worker (the one that created the migration branch). Other workers
1477
+ wait for migrations to complete before proceeding. This ensures
1478
+ migrations run exactly once, even with parallel workers.
1479
+
1480
+ Smart Migration Detection:
1481
+ The plugin automatically detects whether migrations actually modified
1482
+ the database schema. If no schema changes occurred (or this fixture
1483
+ isn't overridden), the plugin skips creating a separate migration
1484
+ branch, saving Neon costs and branch slots.
1485
+
1486
+ Example in conftest.py:
1487
+
1488
+ @pytest.fixture(scope="session")
1489
+ def neon_apply_migrations(_neon_migration_branch):
1490
+ import subprocess
1491
+ subprocess.run(["alembic", "upgrade", "head"], check=True)
1492
+
1493
+ Or with Django:
1494
+
1495
+ @pytest.fixture(scope="session")
1496
+ def neon_apply_migrations(_neon_migration_branch):
1497
+ from django.core.management import call_command
1498
+ call_command("migrate", "--noinput")
1499
+
1500
+ Or with raw SQL:
1501
+
1502
+ @pytest.fixture(scope="session")
1503
+ def neon_apply_migrations(_neon_migration_branch):
1504
+ import psycopg
1505
+ with psycopg.connect(_neon_migration_branch.connection_string) as conn:
1506
+ with open("schema.sql") as f:
1507
+ conn.execute(f.read())
1508
+ conn.commit()
1509
+
1510
+ Args:
1511
+ _neon_migration_branch: The migration branch with connection details.
1512
+ Use _neon_migration_branch.connection_string to connect directly,
1513
+ or rely on DATABASE_URL which is already set.
1514
+
1515
+ Returns:
1516
+ Any value (ignored). The default returns a sentinel to indicate
1517
+ the fixture was not overridden.
1518
+ """
1519
+ return _MIGRATIONS_NOT_DEFINED
1520
+
1521
+
1522
+ @pytest.fixture(scope="session")
1523
+ def _neon_migrations_synchronized(
1524
+ request: pytest.FixtureRequest,
1525
+ _neon_migration_branch: NeonBranch,
1526
+ _neon_xdist_coordinator: XdistCoordinator,
1527
+ neon_apply_migrations: Any,
1528
+ ) -> Any:
1529
+ """
1530
+ Internal fixture that synchronizes migrations across xdist workers.
1531
+
1532
+ This fixture ensures that:
1533
+ 1. Only the creator worker runs migrations (non-creators wait in
1534
+ _neon_migration_branch BEFORE neon_apply_migrations runs)
1535
+ 2. Creator signals completion after migrations finish
1536
+ 3. The return value from neon_apply_migrations is preserved for detection
1537
+
1538
+ Without xdist, this is a simple passthrough.
1539
+ """
1540
+ is_creator = getattr(request.config, "_neon_is_migration_creator", True)
1541
+
1542
+ if is_creator:
1543
+ # Creator: migrations just ran via neon_apply_migrations dependency
1544
+ # Signal completion to other workers
1545
+ _neon_xdist_coordinator.send_signal("migrations_done")
1546
+
1547
+ return neon_apply_migrations
1548
+
1549
+
1550
+ @pytest.fixture(scope="session")
1551
+ def _neon_dirty_branch(
1552
+ _neon_config: NeonConfig,
1553
+ _neon_branch_manager: NeonBranchManager,
1554
+ _neon_xdist_coordinator: XdistCoordinator,
1555
+ _neon_migration_branch: NeonBranch,
1556
+ _neon_migrations_synchronized: Any, # Ensures migrations complete first
1557
+ ) -> Generator[NeonBranch, None, None]:
1558
+ """
1559
+ Session-scoped dirty branch shared across ALL xdist workers.
1560
+
1561
+ This branch is a child of the migration branch. All tests using
1562
+ neon_branch_dirty share this single branch - writes persist and
1563
+ are visible to all tests (even across workers).
1564
+
1565
+ This is the "dirty" branch because:
1566
+ - No reset between tests
1567
+ - Shared across all workers (concurrent writes possible)
1568
+ - Fast because no per-test overhead
1569
+ """
1570
+ env_manager = EnvironmentManager(_neon_config.env_var_name)
1571
+ branch: NeonBranch
1572
+ is_creator: bool
1573
+
1574
+ def create_dirty_branch() -> dict[str, Any]:
1575
+ b = _neon_branch_manager.create_branch(
1576
+ name_suffix="-dirty",
1577
+ parent_branch_id=_neon_migration_branch.branch_id,
1578
+ expiry_seconds=_neon_config.branch_expiry,
1579
+ )
1580
+ return {"branch": _branch_to_dict(b)}
1581
+
1582
+ # Coordinate dirty branch creation - shared across ALL workers
1583
+ data, is_creator = _neon_xdist_coordinator.coordinate_resource(
1584
+ "dirty_branch", create_dirty_branch
1585
+ )
1586
+ branch = _dict_to_branch(data["branch"])
1587
+
1588
+ # Set DATABASE_URL
1589
+ env_manager.set(branch.connection_string)
1590
+
1591
+ try:
1592
+ yield branch
1593
+ finally:
1594
+ env_manager.restore()
1595
+ # Only creator cleans up
1596
+ if is_creator:
1597
+ _neon_branch_manager.delete_branch(branch.branch_id)
1598
+
1599
+
1600
+ @pytest.fixture(scope="session")
1601
+ def _neon_readonly_endpoint(
1602
+ _neon_config: NeonConfig,
1603
+ _neon_branch_manager: NeonBranchManager,
1604
+ _neon_xdist_coordinator: XdistCoordinator,
1605
+ _neon_migration_branch: NeonBranch,
1606
+ _neon_migrations_synchronized: Any, # Ensures migrations complete first
1607
+ ) -> Generator[NeonBranch, None, None]:
1608
+ """
1609
+ Session-scoped read_only endpoint on the migration branch.
1610
+
1611
+ This is a true read-only endpoint - writes are blocked at the database
1612
+ level. All workers share this endpoint since it's read-only anyway.
1613
+ """
1614
+ env_manager = EnvironmentManager(_neon_config.env_var_name)
1615
+ branch: NeonBranch
1616
+ is_creator: bool
1617
+
1618
+ def create_readonly_endpoint() -> dict[str, Any]:
1619
+ b = _neon_branch_manager.create_readonly_endpoint(_neon_migration_branch)
1620
+ return {"branch": _branch_to_dict(b)}
1621
+
1622
+ # Coordinate endpoint creation - shared across ALL workers
1623
+ data, is_creator = _neon_xdist_coordinator.coordinate_resource(
1624
+ "readonly_endpoint", create_readonly_endpoint
1625
+ )
1626
+ branch = _dict_to_branch(data["branch"])
1627
+
1628
+ # Set DATABASE_URL
1629
+ env_manager.set(branch.connection_string)
1630
+
1631
+ try:
1632
+ yield branch
1633
+ finally:
1634
+ env_manager.restore()
1635
+ # Only creator cleans up the endpoint
1636
+ if is_creator and branch.endpoint_id:
1637
+ _neon_branch_manager.delete_endpoint(branch.endpoint_id)
1638
+
1639
+
1640
+ @pytest.fixture(scope="session")
1641
+ def _neon_isolated_branch(
1642
+ request: pytest.FixtureRequest,
1643
+ _neon_config: NeonConfig,
1644
+ _neon_branch_manager: NeonBranchManager,
1645
+ _neon_xdist_coordinator: XdistCoordinator,
1646
+ _neon_migration_branch: NeonBranch,
1647
+ _neon_migrations_synchronized: Any, # Ensures migrations complete first
1648
+ ) -> Generator[NeonBranch, None, None]:
1649
+ """
1650
+ Session-scoped isolated branch, one per xdist worker.
1651
+
1652
+ Each worker gets its own branch. Unlike the dirty branch, this is
1653
+ per-worker to allow reset operations without affecting other workers.
1654
+
1655
+ The branch is reset after each test that uses neon_branch_isolated.
1656
+ """
1657
+ env_manager = EnvironmentManager(_neon_config.env_var_name)
1658
+ worker_id = _neon_xdist_coordinator.worker_id
1659
+
1660
+ # Each worker creates its own isolated branch - no coordination needed
1661
+ # because each worker has a unique ID
1662
+ branch = _neon_branch_manager.create_branch(
1663
+ name_suffix=f"-isolated-{worker_id}",
1664
+ parent_branch_id=_neon_migration_branch.branch_id,
1665
+ expiry_seconds=_neon_config.branch_expiry,
1666
+ )
1667
+
1668
+ # Store branch manager on config for reset operations
1669
+ request.config._neon_isolated_branch_manager = _neon_branch_manager # type: ignore[attr-defined]
1670
+
1671
+ # Set DATABASE_URL
1672
+ env_manager.set(branch.connection_string)
1673
+
1674
+ try:
1675
+ yield branch
1676
+ finally:
1677
+ env_manager.restore()
1678
+ _neon_branch_manager.delete_branch(branch.branch_id)
1679
+
1680
+
1681
+ @pytest.fixture(scope="session")
1682
+ def neon_branch_readonly(
1683
+ _neon_config: NeonConfig,
1684
+ _neon_readonly_endpoint: NeonBranch,
1685
+ ) -> NeonBranch:
1686
+ """
1687
+ Provide a true read-only Neon database connection.
1688
+
1689
+ This fixture uses a read_only endpoint on the migration branch, which
1690
+ enforces read-only access at the database level. Any attempt to write
1691
+ will result in a database error.
1692
+
1693
+ This is the recommended fixture for tests that only read data (SELECT queries).
1694
+ It's session-scoped and shared across all tests and workers since it's read-only.
1695
+
1696
+ Use this fixture when your tests only perform SELECT queries.
1697
+ For tests that INSERT, UPDATE, or DELETE data, use ``neon_branch_dirty``
1698
+ (for shared state) or ``neon_branch_isolated`` (for test isolation).
1699
+
1700
+ The connection string is automatically set in the DATABASE_URL environment
1701
+ variable (configurable via --neon-env-var).
1702
+
1703
+ Requires either:
1704
+ - NEON_API_KEY and NEON_PROJECT_ID environment variables, or
1705
+ - --neon-api-key and --neon-project-id command line options
1706
+
1707
+ Returns:
1708
+ NeonBranch: Object with branch_id, project_id, connection_string, and host.
1709
+
1710
+ Example::
1711
+
1712
+ def test_query_users(neon_branch_readonly):
1713
+ # DATABASE_URL is automatically set
1714
+ conn_string = os.environ["DATABASE_URL"]
1715
+
1716
+ # Read-only query
1717
+ with psycopg.connect(conn_string) as conn:
1718
+ result = conn.execute("SELECT * FROM users").fetchall()
1719
+ assert len(result) > 0
1720
+
1721
+ # This would fail with a database error:
1722
+ # conn.execute("INSERT INTO users (name) VALUES ('test')")
1723
+ """
1724
+ # DATABASE_URL is already set by _neon_readonly_endpoint
1725
+ return _neon_readonly_endpoint
1726
+
1727
+
1728
+ @pytest.fixture(scope="session")
1729
+ def neon_branch_dirty(
1730
+ _neon_config: NeonConfig,
1731
+ _neon_dirty_branch: NeonBranch,
1732
+ ) -> NeonBranch:
1733
+ """
1734
+ Provide a session-scoped Neon database branch for read-write access.
1735
+
1736
+ All tests share the same branch and writes persist across tests (no cleanup
1737
+ between tests). This is faster than neon_branch_isolated because there's no
1738
+ reset overhead.
1739
+
1740
+ Use this fixture when:
1741
+ - Most tests can share database state without interference
1742
+ - You want maximum performance with minimal API calls
1743
+ - You manually manage test data cleanup if needed
1744
+ - You're using it alongside ``neon_branch_isolated`` for specific tests
1745
+ that need guaranteed clean state
1746
+
1747
+ The connection string is automatically set in the DATABASE_URL environment
1748
+ variable (configurable via --neon-env-var).
1749
+
1750
+ Warning:
1751
+ Data written by one test WILL be visible to subsequent tests AND to
1752
+ other xdist workers. This is truly shared - use ``neon_branch_isolated``
1753
+ for tests that require guaranteed clean state.
1754
+
1755
+ pytest-xdist:
1756
+ ALL workers share the same dirty branch. Concurrent writes from different
1757
+ workers may conflict. This is "dirty" by design - for isolation, use
1758
+ ``neon_branch_isolated``.
1759
+
1760
+ Requires either:
1761
+ - NEON_API_KEY and NEON_PROJECT_ID environment variables, or
1762
+ - --neon-api-key and --neon-project-id command line options
1763
+
1764
+ Returns:
1765
+ NeonBranch: Object with branch_id, project_id, connection_string, and host.
1766
+
1767
+ Example::
1768
+
1769
+ def test_insert_user(neon_branch_dirty):
1770
+ # DATABASE_URL is automatically set
1771
+ import psycopg
1772
+ with psycopg.connect(neon_branch_dirty.connection_string) as conn:
1773
+ conn.execute("INSERT INTO users (name) VALUES ('test')")
1774
+ conn.commit()
1775
+ # Data persists - next test will see this user
1776
+
1777
+ def test_count_users(neon_branch_dirty):
1778
+ # This test sees data from previous tests
1779
+ import psycopg
1780
+ with psycopg.connect(neon_branch_dirty.connection_string) as conn:
1781
+ result = conn.execute("SELECT COUNT(*) FROM users").fetchone()
1782
+ # Count includes users from previous tests
1783
+ """
1784
+ # DATABASE_URL is already set by _neon_dirty_branch
1785
+ return _neon_dirty_branch
1786
+
1787
+
1788
@pytest.fixture(scope="function")
def neon_branch_isolated(
    request: pytest.FixtureRequest,
    _neon_config: NeonConfig,
    _neon_isolated_branch: NeonBranch,
) -> Generator[NeonBranch, None, None]:
    """
    Yield this worker's isolated Neon branch, resetting it once the test ends.

    Recommended for tests that write to the database and must not leak state
    into later tests. Every pytest-xdist worker owns a separate branch, and
    after each test that branch is reset to the migration state. Resets on
    one worker therefore never disturb another worker's branch.

    The branch's connection string is exported through the DATABASE_URL
    environment variable (the name can be changed with --neon-env-var).

    Typical uses:
        - Tests that INSERT, UPDATE or DELETE rows
        - Tests that must begin from a clean database state
        - Alongside ``neon_branch_dirty`` for the specific tests that need
          isolation

    SQLAlchemy Users:
        Branch resets terminate server-side connections. If you build your
        own engine instead of using the neon_engine fixture, you MUST enable
        pool_pre_ping. Otherwise SQLAlchemy may reuse dead pooled
        connections and fail with SSL errors::

            engine = create_engine(DATABASE_URL, pool_pre_ping=True)

    Credentials come from either:
        - NEON_API_KEY and NEON_PROJECT_ID environment variables, or
        - --neon-api-key and --neon-project-id command line options

    Yields:
        NeonBranch: Object with branch_id, project_id, connection_string, and host.

    Raises:
        Failed: via ``pytest.fail`` when the post-test reset raises.

    Example::

        def test_insert_user(neon_branch_isolated):
            conn_string = os.environ["DATABASE_URL"]  # set automatically

            with psycopg.connect(conn_string) as conn:
                conn.execute("INSERT INTO users (name) VALUES ('test')")
                conn.commit()
            # The branch is reset before the next test runs.
    """
    # _neon_isolated_branch has already exported DATABASE_URL.
    yield _neon_isolated_branch

    # Teardown: if a branch manager is registered on the config, restore the
    # branch to its migration state; otherwise no reset is attempted.
    manager = getattr(request.config, "_neon_isolated_branch_manager", None)
    if manager is None:
        return

    try:
        manager.reset_branch(_neon_isolated_branch)
    except Exception as exc:
        branch_id = _neon_isolated_branch.branch_id
        pytest.fail(
            f"\n\nFailed to reset branch {branch_id} "
            "after test. Subsequent tests may see dirty state.\n\n"
            f"Error: {exc}\n\n"
            "To keep the branch for debugging, use --neon-keep-branches"
        )
1856
+
1857
+
1858
@pytest.fixture(scope="function")
def neon_branch_readwrite(
    neon_branch_isolated: NeonBranch,
) -> Generator[NeonBranch, None, None]:
    """
    Deprecated alias that simply yields ``neon_branch_isolated``.

    A DeprecationWarning is emitted each time the fixture is set up.

    .. deprecated:: 2.3.0
        Use ``neon_branch_isolated`` for tests that modify data with reset,
        ``neon_branch_dirty`` for shared state, or ``neon_branch_readonly``
        for read-only access.
    """
    message = (
        "neon_branch_readwrite is deprecated. Use neon_branch_isolated (for tests "
        "that modify data with isolation) or neon_branch_dirty (for shared state)."
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    # Setup and reset are handled entirely by neon_branch_isolated.
    yield neon_branch_isolated
1879
+
1880
+
1881
@pytest.fixture(scope="function")
def neon_branch(
    neon_branch_isolated: NeonBranch,
) -> Generator[NeonBranch, None, None]:
    """
    Deprecated alias that simply yields ``neon_branch_isolated``.

    Prefer ``neon_branch_isolated``, ``neon_branch_dirty``, or
    ``neon_branch_readonly``. A DeprecationWarning is emitted on every setup.

    .. deprecated:: 1.1.0
        Use ``neon_branch_isolated`` for tests that modify data with reset,
        ``neon_branch_dirty`` for shared state, or ``neon_branch_readonly``
        for read-only access.
    """
    message = (
        "neon_branch is deprecated. Use neon_branch_isolated (for tests that "
        "modify data), neon_branch_dirty (for shared state), or "
        "neon_branch_readonly (for read-only tests)."
    )
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    # The branch lifecycle (creation and per-test reset) lives in
    # neon_branch_isolated; this fixture only forwards it.
    yield neon_branch_isolated
1904
+
1905
+
1906
@pytest.fixture(scope="module")
def neon_branch_shared(
    request: pytest.FixtureRequest,
    _neon_migration_branch: NeonBranch,
    _neon_migrations_synchronized: Any,  # Ensures migrations complete first
) -> Generator[NeonBranch, None, None]:
    """
    Provide a shared Neon database branch for all tests in a module.

    This fixture creates one branch per test module and shares it across all
    tests without resetting. This is the fastest option, but tests can see
    each other's data modifications. The branch is created as a child of the
    migration branch, with the name suffix ``-shared``.

    If you override the `neon_apply_migrations` fixture, migrations will run
    once before the first test, and this branch will include the migrated schema.

    Use this when:
        - Tests are read-only or don't interfere with each other
        - You manually clean up test data within each test
        - Maximum speed is more important than isolation

    Warning:
        Tests in the same module will share database state. Data created by
        one test will be visible to subsequent tests. Use
        ``neon_branch_isolated`` instead if you need isolation between tests
        (``neon_branch`` is only a deprecated alias for it).

    Yields:
        NeonBranch: Object with branch_id, project_id, connection_string, and host.

    Example::

        def test_read_only_query(neon_branch_shared):
            # Fast: no reset between tests, but be careful about data leakage
            conn_string = neon_branch_shared.connection_string
    """
    # _neon_migrations_synchronized is requested only so that migrations have
    # finished before the child branch is forked from the migration branch.
    yield from _create_neon_branch(
        request,
        parent_branch_id_override=_neon_migration_branch.branch_id,
        branch_name_suffix="-shared",
    )
1944
+
1945
+
1946
@pytest.fixture
def neon_connection(neon_branch_isolated: NeonBranch):
    """
    Provide a psycopg2 connection to the test branch.

    Requires the psycopg2 optional dependency::

        pip install pytest-neon[psycopg2]

    Uses neon_branch_isolated for test isolation. After each test, any open
    transaction is rolled back and the connection is closed. The rollback is
    skipped when the connection is already closed (for example, the test
    closed it itself). ``close()`` is still called even if the rollback raises.

    Yields:
        psycopg2 connection object

    Example::

        def test_insert(neon_connection):
            cur = neon_connection.cursor()
            cur.execute("INSERT INTO users (name) VALUES ('test')")
            neon_connection.commit()
    """
    try:
        import psycopg2
    except ImportError:
        pytest.fail(
            "\n\n"
            "═══════════════════════════════════════════════════════════════════\n"
            " MISSING DEPENDENCY: psycopg2\n"
            "═══════════════════════════════════════════════════════════════════\n\n"
            " The 'neon_connection' fixture requires psycopg2.\n\n"
            " To fix this, install the psycopg2 extra:\n\n"
            " pip install pytest-neon[psycopg2]\n\n"
            " Or use the 'neon_branch_isolated' fixture with your own driver:\n\n"
            " def test_example(neon_branch_isolated):\n"
            " import your_driver\n"
            " conn = your_driver.connect(\n"
            " neon_branch_isolated.connection_string)\n\n"
            "═══════════════════════════════════════════════════════════════════\n"
        )

    conn = psycopg2.connect(neon_branch_isolated.connection_string)
    try:
        yield conn
    finally:
        try:
            # psycopg2's ``closed`` is a nonzero int once the connection is
            # closed or broken; rollback() would raise InterfaceError then.
            if not conn.closed:
                conn.rollback()
        finally:
            # close() is a no-op on an already-closed psycopg2 connection.
            conn.close()
1989
+
1990
+
1991
@pytest.fixture
def neon_connection_psycopg(neon_branch_isolated: NeonBranch):
    """
    Provide a psycopg (v3) connection to the test branch.

    Requires the psycopg optional dependency::

        pip install pytest-neon[psycopg]

    Uses neon_branch_isolated for test isolation. After each test, any open
    transaction is rolled back and the connection is closed. The rollback is
    skipped when the connection is already closed (for example, the test
    closed it itself). ``close()`` is still called even if the rollback raises.

    Yields:
        psycopg connection object

    Example::

        def test_insert(neon_connection_psycopg):
            with neon_connection_psycopg.cursor() as cur:
                cur.execute("INSERT INTO users (name) VALUES ('test')")
            neon_connection_psycopg.commit()
    """
    try:
        import psycopg
    except ImportError:
        pytest.fail(
            "\n\n"
            "═══════════════════════════════════════════════════════════════════\n"
            " MISSING DEPENDENCY: psycopg (v3)\n"
            "═══════════════════════════════════════════════════════════════════\n\n"
            " The 'neon_connection_psycopg' fixture requires psycopg v3.\n\n"
            " To fix this, install the psycopg extra:\n\n"
            " pip install pytest-neon[psycopg]\n\n"
            " Or use the 'neon_branch_isolated' fixture with your own driver:\n\n"
            " def test_example(neon_branch_isolated):\n"
            " import your_driver\n"
            " conn = your_driver.connect(\n"
            " neon_branch_isolated.connection_string)\n\n"
            "═══════════════════════════════════════════════════════════════════\n"
        )

    conn = psycopg.connect(neon_branch_isolated.connection_string)
    try:
        yield conn
    finally:
        try:
            # rollback() on a closed psycopg 3 connection raises, so only roll
            # back while the connection is still open.
            if not conn.closed:
                conn.rollback()
        finally:
            # close() is safe to call on an already-closed connection.
            conn.close()
2034
+
2035
+
2036
@pytest.fixture
def neon_engine(neon_branch_isolated: NeonBranch):
    """
    Provide a SQLAlchemy engine connected to the test branch.

    Requires the sqlalchemy optional dependency::

        pip install pytest-neon[sqlalchemy]

    The engine is disposed after each test, which handles stale connections
    after branch resets automatically. Disposal is guaranteed via
    ``try/finally``. Uses neon_branch_isolated for test isolation.

    Note:
        If you create your own module-level engine instead of using this
        fixture, you MUST use pool_pre_ping=True::

            engine = create_engine(DATABASE_URL, pool_pre_ping=True)

        This is required because branch resets terminate server-side
        connections, and without pool_pre_ping SQLAlchemy may reuse dead
        pooled connections.

    Yields:
        SQLAlchemy Engine object

    Example::

        def test_query(neon_engine):
            with neon_engine.connect() as conn:
                result = conn.execute(text("SELECT 1"))
    """
    try:
        from sqlalchemy import create_engine
    except ImportError:
        pytest.fail(
            "\n\n"
            "═══════════════════════════════════════════════════════════════════\n"
            " MISSING DEPENDENCY: SQLAlchemy\n"
            "═══════════════════════════════════════════════════════════════════\n\n"
            " The 'neon_engine' fixture requires SQLAlchemy.\n\n"
            " To fix this, install the sqlalchemy extra:\n\n"
            " pip install pytest-neon[sqlalchemy]\n\n"
            " Or use the 'neon_branch_isolated' fixture with your own driver:\n\n"
            " def test_example(neon_branch_isolated):\n"
            " from sqlalchemy import create_engine\n"
            " engine = create_engine(\n"
            " neon_branch_isolated.connection_string)\n\n"
            "═══════════════════════════════════════════════════════════════════\n"
        )

    engine = create_engine(neon_branch_isolated.connection_string)
    try:
        yield engine
    finally:
        # This fixture depends on neon_branch_isolated, so pytest tears it down
        # first: the pool is released before that fixture resets the branch.
        engine.dispose()