postkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
postkit/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """
2
+ postkit - PostgreSQL-native authentication and authorization SDK.
3
+
4
+ Usage:
5
+ from postkit.authz import AuthzClient
6
+ from postkit.authn import AuthnClient
7
+ """
8
+
9
+ __version__ = "0.1.0"
@@ -0,0 +1,13 @@
1
"""postkit.authn - Authentication client for PostgreSQL-native auth.

This package re-exports the public API of ``postkit.authn.client``.
"""

from postkit.authn.client import AuthnClient, AuthnError, AuthnValidationError

# Names exposed by ``from postkit.authn import *``.
__all__ = ["AuthnClient", "AuthnError", "AuthnValidationError"]
@@ -0,0 +1,567 @@
1
+ """
2
+ postkit.authn - Authentication client for PostgreSQL-native auth.
3
+
4
+ This module provides:
5
+ - AuthnClient: SDK-style interface for authentication operations
6
+ - Exception classes: AuthnError, AuthnValidationError
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from datetime import datetime, timedelta
12
+ from uuid import UUID
13
+
14
+ import psycopg
15
+
16
+
17
# Public API of this module (what ``from postkit.authn.client import *`` exports).
__all__ = ["AuthnClient", "AuthnError", "AuthnValidationError"]
22
+
23
+
24
+ # =============================================================================
25
+ # EXCEPTIONS
26
+ # =============================================================================
27
+
28
+
29
class AuthnError(Exception):
    """Root exception for every postkit.authn operation.

    AuthnClient raises it, chaining the original psycopg error, when a
    database call fails.
    """
33
+
34
+
35
class AuthnValidationError(AuthnError):
    """Signals that input validation failed.

    Because it derives from AuthnError, ``except AuthnError`` catches it too.
    """
39
+
40
+
41
+ # =============================================================================
42
+ # CLIENT
43
+ # =============================================================================
44
+
45
+
46
class AuthnClient:
    """
    SDK-style client for postkit/authn.

    This wraps the SQL functions with a Pythonic API. Every method passes
    ``self.namespace`` to the underlying ``authn.*`` SQL function. Database
    errors raised by psycopg are re-raised as AuthnError with the original
    error chained. UUID values in returned rows are converted to strings.

    Example:
        authn = AuthnClient(cursor, namespace="production")

        # Create user
        user_id = authn.create_user("alice@example.com", "argon2_hash")

        # Create session
        session_id = authn.create_session(user_id, "sha256_token_hash")

        # Validate session
        user = authn.validate_session("sha256_token_hash")
        if user:
            print(f"Logged in as {user['email']}")
    """

    def __init__(self, cursor, namespace: str):
        """
        Bind the client to a cursor and a tenant namespace.

        Args:
            cursor: A psycopg cursor; all SQL is executed through it.
            namespace: Tenant namespace passed to every authn function.

        Raises:
            AuthnError: If ``authn.set_tenant`` fails.
        """
        self.cursor = cursor
        self.namespace = namespace
        # Actor context stored as instance state (applied per-operation in
        # _write_scalar; see set_actor for why).
        self._actor_id: str | None = None
        self._request_id: str | None = None
        self._ip_address: str | None = None
        self._user_agent: str | None = None
        # Set tenant context for RLS.
        try:
            self.cursor.execute("SELECT authn.set_tenant(%s)", (namespace,))
        except psycopg.Error as e:
            self._handle_error(e)

    def _handle_error(self, e: psycopg.Error) -> None:
        """Convert a psycopg error to AuthnError (always raises)."""
        raise AuthnError(str(e)) from e

    def _normalize_row(self, row: dict) -> dict:
        """Normalize types in result row (UUIDs to strings)."""
        return {k: str(v) if isinstance(v, UUID) else v for k, v in row.items()}

    def _scalar(self, sql: str, params: tuple):
        """Execute SQL and return the first column of the first row, or None."""
        try:
            self.cursor.execute(sql, params)
            result = self.cursor.fetchone()
            return result[0] if result else None
        except psycopg.Error as e:
            self._handle_error(e)

    def _row(self, sql: str, params: tuple) -> dict | None:
        """Execute SQL and return single row as dict with normalized types."""
        try:
            self.cursor.execute(sql, params)
            result = self.cursor.fetchone()
            if result is None:
                return None
            columns = [desc[0] for desc in self.cursor.description]
            return self._normalize_row(dict(zip(columns, result)))
        except psycopg.Error as e:
            self._handle_error(e)

    def _fetchall(self, sql: str, params: tuple) -> list[dict]:
        """
        Execute SQL and return all rows as list of dicts with normalized types.

        Raises:
            AuthnError: If the query fails (same contract as _scalar/_row).
        """
        try:
            self.cursor.execute(sql, params)
            columns = [desc[0] for desc in self.cursor.description]
            return [
                self._normalize_row(dict(zip(columns, row)))
                for row in self.cursor.fetchall()
            ]
        except psycopg.Error as e:
            self._handle_error(e)

    def _apply_actor(self) -> None:
        """Send the stored actor context to ``authn.set_actor``."""
        try:
            self.cursor.execute(
                "SELECT authn.set_actor(%s, %s, %s, %s)",
                (self._actor_id, self._request_id, self._ip_address, self._user_agent),
            )
        except psycopg.Error as e:
            self._handle_error(e)

    def _write_scalar(self, sql: str, params: tuple):
        """
        Execute a write operation with actor context for audit logging.

        With no actor set, this is just _scalar. Otherwise the actor context
        is applied in the same transaction as the write. If the caller already
        has a transaction open, that transaction is used. If not, a
        BEGIN/COMMIT block is opened here and rolled back on any error.

        Raises:
            AuthnError: If any statement fails at the database level.
        """
        if self._actor_id is None:
            return self._scalar(sql, params)

        status = self.cursor.connection.info.transaction_status
        if status != psycopg.pq.TransactionStatus.IDLE:
            # A transaction is already open: apply the actor inside it so the
            # transaction-local setting is visible to the write that follows.
            self._apply_actor()
            return self._scalar(sql, params)

        try:
            self.cursor.execute("BEGIN")
            self._apply_actor()
            result = self._scalar(sql, params)
            self.cursor.execute("COMMIT")
        except Exception as e:
            self.cursor.execute("ROLLBACK")
            if isinstance(e, psycopg.Error):
                # BEGIN/COMMIT failures surface as SDK errors as well.
                self._handle_error(e)
            raise
        return result

    # =========================================================================
    # User Management
    # =========================================================================

    def create_user(
        self,
        email: str,
        password_hash: str | None = None,
    ) -> str | None:
        """
        Create a new user.

        Args:
            email: User's email address (will be normalized to lowercase)
            password_hash: Pre-hashed password (None for SSO-only users)

        Returns:
            User ID (UUID string), or None if the SQL function returned NULL.
        """
        result = self._write_scalar(
            "SELECT authn.create_user(%s, %s, %s)",
            (email, password_hash, self.namespace),
        )
        return str(result) if result else None

    def get_user(self, user_id: str) -> dict | None:
        """Get user by ID, or None if no row matches. Does not return password_hash."""
        return self._row(
            "SELECT * FROM authn.get_user(%s::uuid, %s)",
            (user_id, self.namespace),
        )

    def get_user_by_email(self, email: str) -> dict | None:
        """Get user by email, or None if no row matches. Does not return password_hash."""
        return self._row(
            "SELECT * FROM authn.get_user_by_email(%s, %s)",
            (email, self.namespace),
        )

    def update_email(self, user_id: str, new_email: str) -> bool:
        """Update user's email. Clears email_verified_at."""
        return self._write_scalar(
            "SELECT authn.update_email(%s::uuid, %s, %s)",
            (user_id, new_email, self.namespace),
        )

    def disable_user(self, user_id: str) -> bool:
        """Disable user and revoke all their sessions."""
        return self._write_scalar(
            "SELECT authn.disable_user(%s::uuid, %s)",
            (user_id, self.namespace),
        )

    def enable_user(self, user_id: str) -> bool:
        """Re-enable a disabled user."""
        return self._write_scalar(
            "SELECT authn.enable_user(%s::uuid, %s)",
            (user_id, self.namespace),
        )

    def delete_user(self, user_id: str) -> bool:
        """Permanently delete a user and all associated data."""
        return self._write_scalar(
            "SELECT authn.delete_user(%s::uuid, %s)",
            (user_id, self.namespace),
        )

    def list_users(self, limit: int = 100, cursor: str | None = None) -> list[dict]:
        """
        List users with pagination.

        Args:
            limit: Maximum number of users to return.
            cursor: Pagination cursor passed through to authn.list_users
                (presumably taken from a previous page; confirm against the SQL).
        """
        return self._fetchall(
            "SELECT * FROM authn.list_users(%s, %s, %s)",
            (self.namespace, limit, cursor),
        )

    # =========================================================================
    # Credentials
    # =========================================================================

    def get_credentials(self, email: str) -> dict | None:
        """
        Get credentials for login verification.

        Returns user_id, password_hash, and disabled_at for caller to verify.
        This is the ONLY method that returns password_hash.
        """
        return self._row(
            "SELECT * FROM authn.get_credentials(%s, %s)",
            (email, self.namespace),
        )

    def update_password(self, user_id: str, new_password_hash: str) -> bool:
        """Update user's password hash."""
        return self._write_scalar(
            "SELECT authn.update_password(%s::uuid, %s, %s)",
            (user_id, new_password_hash, self.namespace),
        )

    # =========================================================================
    # Sessions
    # =========================================================================

    def create_session(
        self,
        user_id: str,
        token_hash: str,
        expires_in: timedelta | None = None,
        ip_address: str | None = None,
        user_agent: str | None = None,
    ) -> str | None:
        """
        Create a new session.

        Args:
            user_id: User ID
            token_hash: Pre-hashed session token (SHA-256)
            expires_in: Session duration. None lets the SQL function apply
                its default (documented as 7 days).
            ip_address: Client IP (cast to inet)
            user_agent: Client user agent

        Returns:
            Session ID (UUID string), or None if the SQL function returned NULL.
        """
        result = self._write_scalar(
            "SELECT authn.create_session(%s::uuid, %s, %s, %s::inet, %s, %s)",
            (user_id, token_hash, expires_in, ip_address, user_agent, self.namespace),
        )
        return str(result) if result else None

    def validate_session(self, token_hash: str) -> dict | None:
        """
        Validate a session token.

        Returns user info if valid, None otherwise.
        Does not log to audit (hot path).
        """
        return self._row(
            "SELECT * FROM authn.validate_session(%s, %s)",
            (token_hash, self.namespace),
        )

    def extend_session(
        self,
        token_hash: str,
        extend_by: timedelta | None = None,
    ) -> datetime | None:
        """
        Extend session expiration. Returns new expires_at.

        NOTE(review): runs via _scalar, so no actor context is applied.
        Presumably intentional (hot path); confirm.
        """
        return self._scalar(
            "SELECT authn.extend_session(%s, %s, %s)",
            (token_hash, extend_by, self.namespace),
        )

    def revoke_session(self, token_hash: str) -> bool:
        """Revoke a session."""
        return self._write_scalar(
            "SELECT authn.revoke_session(%s, %s)",
            (token_hash, self.namespace),
        )

    def revoke_all_sessions(self, user_id: str) -> int:
        """Revoke all sessions for a user. Returns count revoked."""
        return self._write_scalar(
            "SELECT authn.revoke_all_sessions(%s::uuid, %s)",
            (user_id, self.namespace),
        )

    def list_sessions(self, user_id: str) -> list[dict]:
        """List active sessions for a user. Does not return token_hash."""
        return self._fetchall(
            "SELECT * FROM authn.list_sessions(%s::uuid, %s)",
            (user_id, self.namespace),
        )

    # =========================================================================
    # Tokens (password reset, email verification, magic links)
    # =========================================================================

    def create_token(
        self,
        user_id: str,
        token_hash: str,
        token_type: str,
        expires_in: timedelta | None = None,
    ) -> str | None:
        """
        Create a one-time use token.

        Args:
            user_id: User ID
            token_hash: Pre-hashed token (SHA-256)
            token_type: 'password_reset', 'email_verify', or 'magic_link'
            expires_in: Token lifetime (None: defaults vary by type)

        Returns:
            Token ID (UUID string), or None if the SQL function returned NULL.
        """
        result = self._write_scalar(
            "SELECT authn.create_token(%s::uuid, %s, %s, %s, %s)",
            (user_id, token_hash, token_type, expires_in, self.namespace),
        )
        return str(result) if result else None

    def consume_token(self, token_hash: str, token_type: str) -> dict | None:
        """
        Consume a one-time token.

        Returns user info if valid, None otherwise.
        Token is marked as used after this call.

        NOTE(review): runs via _row, so no actor context is applied to this
        write; confirm that is intended.
        """
        return self._row(
            "SELECT * FROM authn.consume_token(%s, %s, %s)",
            (token_hash, token_type, self.namespace),
        )

    def verify_email(self, token_hash: str) -> dict | None:
        """
        Verify email using a token.

        Convenience method that consumes email_verify token and sets email_verified_at.
        Returns the row from authn.verify_email, or None if no row is returned.
        """
        return self._row(
            "SELECT * FROM authn.verify_email(%s, %s)",
            (token_hash, self.namespace),
        )

    def invalidate_tokens(self, user_id: str, token_type: str) -> int:
        """Invalidate all unused tokens of a type for a user."""
        return self._write_scalar(
            "SELECT authn.invalidate_tokens(%s::uuid, %s, %s)",
            (user_id, token_type, self.namespace),
        )

    # =========================================================================
    # MFA
    # =========================================================================

    def add_mfa(
        self,
        user_id: str,
        mfa_type: str,
        secret: str,
        name: str | None = None,
    ) -> str | None:
        """
        Add an MFA method for a user.

        Args:
            user_id: User ID
            mfa_type: 'totp', 'webauthn', or 'recovery_codes'
            secret: The MFA secret (caller stores this securely)
            name: Optional friendly name

        Returns:
            MFA ID (UUID string), or None if the SQL function returned NULL.
        """
        result = self._write_scalar(
            "SELECT authn.add_mfa(%s::uuid, %s, %s, %s, %s)",
            (user_id, mfa_type, secret, name, self.namespace),
        )
        return str(result) if result else None

    def get_mfa(self, user_id: str, mfa_type: str) -> list[dict]:
        """Get MFA secrets for verification. Returns secrets!"""
        return self._fetchall(
            "SELECT * FROM authn.get_mfa(%s::uuid, %s, %s)",
            (user_id, mfa_type, self.namespace),
        )

    def list_mfa(self, user_id: str) -> list[dict]:
        """List MFA methods. Does NOT return secrets."""
        return self._fetchall(
            "SELECT * FROM authn.list_mfa(%s::uuid, %s)",
            (user_id, self.namespace),
        )

    def remove_mfa(self, mfa_id: str) -> bool:
        """Remove an MFA method."""
        return self._write_scalar(
            "SELECT authn.remove_mfa(%s::uuid, %s)",
            (mfa_id, self.namespace),
        )

    def record_mfa_use(self, mfa_id: str) -> bool:
        """Record that an MFA method was used."""
        return self._write_scalar(
            "SELECT authn.record_mfa_use(%s::uuid, %s)",
            (mfa_id, self.namespace),
        )

    def has_mfa(self, user_id: str) -> bool:
        """Check if user has any MFA method enabled."""
        return self._scalar(
            "SELECT authn.has_mfa(%s::uuid, %s)",
            (user_id, self.namespace),
        )

    # =========================================================================
    # Lockout
    # =========================================================================

    def record_login_attempt(
        self,
        email: str,
        success: bool,
        ip_address: str | None = None,
    ) -> None:
        """
        Record a login attempt.

        NOTE(review): runs via _scalar, so no actor context is applied; confirm.
        """
        self._scalar(
            "SELECT authn.record_login_attempt(%s, %s, %s::inet, %s)",
            (email, success, ip_address, self.namespace),
        )

    def is_locked_out(
        self,
        email: str,
        window: timedelta | None = None,
        max_attempts: int | None = None,
    ) -> bool:
        """
        Check if an email is locked out due to too many failed attempts.

        ``window`` and ``max_attempts`` are passed through as-is; None leaves
        the choice to the SQL function.
        """
        return self._scalar(
            "SELECT authn.is_locked_out(%s, %s, %s, %s)",
            (email, self.namespace, window, max_attempts),
        )

    def get_recent_attempts(self, email: str, limit: int = 10) -> list[dict]:
        """Get recent login attempts for an email."""
        return self._fetchall(
            "SELECT * FROM authn.get_recent_attempts(%s, %s, %s)",
            (email, self.namespace, limit),
        )

    def clear_attempts(self, email: str) -> int:
        """Clear login attempts for an email. Returns count deleted."""
        return self._write_scalar(
            "SELECT authn.clear_attempts(%s, %s)",
            (email, self.namespace),
        )

    # =========================================================================
    # Maintenance
    # =========================================================================

    def cleanup_expired(self) -> dict:
        """Clean up expired sessions, tokens, and old login attempts ({} if no row)."""
        result = self._row(
            "SELECT * FROM authn.cleanup_expired(%s)",
            (self.namespace,),
        )
        return result or {}

    def get_stats(self) -> dict:
        """Get namespace statistics ({} if no row)."""
        result = self._row(
            "SELECT * FROM authn.get_stats(%s)",
            (self.namespace,),
        )
        return result or {}

    # =========================================================================
    # Audit Context
    # =========================================================================

    def set_actor(
        self,
        actor_id: str,
        request_id: str | None = None,
        ip_address: str | None = None,
        user_agent: str | None = None,
    ) -> None:
        """
        Set actor context for audit logging.

        Note: Unlike set_tenant() which applies immediately via SQL, actor context
        is stored as instance state and applied per-operation in _write_scalar.
        This ensures actor context is set within the same transaction as the
        audited operation (required because PostgreSQL's set_config with is_local=true
        only persists within the current transaction).
        """
        self._actor_id = actor_id
        self._request_id = request_id
        self._ip_address = ip_address
        self._user_agent = user_agent

    def clear_actor(self) -> None:
        """
        Clear actor context.

        Only the instance state is reset. Context that _write_scalar has
        already applied inside a still-open transaction is not undone here.
        """
        self._actor_id = None
        self._request_id = None
        self._ip_address = None
        self._user_agent = None

    def get_audit_events(
        self,
        limit: int = 100,
        event_type: str | None = None,
        resource_type: str | None = None,
        resource_id: str | None = None,
    ) -> list[dict]:
        """
        Query audit events for this namespace, newest first.

        Filters left as None are not applied.

        Args:
            limit: Maximum number of events to return.
            event_type: Only events with this event_type.
            resource_type: Only events with this resource_type.
            resource_id: Only events with this resource_id.

        Returns:
            Rows of authn.audit_events as dicts (UUIDs converted to strings),
            ordered by event_time DESC, id DESC.

        Raises:
            AuthnError: If the query fails.
        """
        filters = (
            ("event_type", event_type),
            ("resource_type", resource_type),
            ("resource_id", resource_id),
        )
        conditions = ["namespace = %s"]
        params: list = [self.namespace]
        for column, value in filters:
            if value is not None:
                # Column names are fixed literals; only values are parameters.
                conditions.append(f"{column} = %s")
                params.append(value)
        params.append(limit)

        where = " AND ".join(conditions)
        sql = f"""
            SELECT *
            FROM authn.audit_events
            WHERE {where}
            ORDER BY event_time DESC, id DESC
            LIMIT %s
        """
        return self._fetchall(sql, tuple(params))
@@ -0,0 +1,17 @@
1
"""postkit.authz - Authorization client for PostgreSQL-native ReBAC.

This package re-exports the public API of ``postkit.authz.client``.
"""

from postkit.authz.client import (
    AuthzClient,
    AuthzCycleError,
    AuthzError,
    AuthzValidationError,
    Entity,
)

# Names exposed by ``from postkit.authz import *``.
__all__ = [
    "AuthzClient",
    "AuthzError",
    "AuthzValidationError",
    "AuthzCycleError",
    "Entity",
]