python-in-underwear 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
piu/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ from .app import PIU
2
+ from .wrappers import Request, Response
3
+ from .routing import Router, Route, Blueprint
4
+ from .middleware import MiddlewareStack
5
+ from .templating import TemplateEngine
6
+ from .static import serve_static
7
+ from .helpers import status_text
8
+ from .config import Config
9
+ from .sessions import SessionMiddleware, Session
10
+ from .csrf import CSRFMiddleware, csrf_input
11
+ from .ratelimit import RateLimitMiddleware, rate_limit
12
+ from .auth import require_auth, login_user, logout_user, current_user, is_authenticated
13
+ from .testing import TestClient
14
+ from .plugins import Plugin
15
+ from .tasks import BackgroundTasks
16
+ from .websocket import WebSocket
17
+ from .openapi import generate_schema
18
+
19
# Public API re-exported by ``from piu import *``; grouped by feature area.
__all__ = [
    # core application objects
    "PIU",
    "Request",
    "Response",
    # routing
    "Router",
    "Route",
    "Blueprint",
    # request pipeline and rendering
    "MiddlewareStack",
    "TemplateEngine",
    "serve_static",
    "status_text",
    # configuration
    "Config",
    # sessions, CSRF and rate limiting
    "SessionMiddleware",
    "Session",
    "CSRFMiddleware",
    "csrf_input",
    "RateLimitMiddleware",
    "rate_limit",
    # authentication helpers
    "require_auth",
    "login_user",
    "logout_user",
    "current_user",
    "is_authenticated",
    # testing, extension points and extras
    "TestClient",
    "Plugin",
    "BackgroundTasks",
    "WebSocket",
    "generate_schema",
]

__version__ = "0.5.0"
__author__ = "TodorW & n11kol11c"
piu/__main__.py ADDED
@@ -0,0 +1,3 @@
1
"""Entry point for ``python -m piu``: hand control to the PIU command-line interface."""

from piu.cli import main

main()
piu/app.py ADDED
@@ -0,0 +1,226 @@
1
+ import asyncio
2
+ import inspect
3
+ from typing import Callable, Optional
4
+ from urllib.parse import parse_qs, urlparse
5
+
6
+ from .auth import current_user, is_authenticated, login_user, logout_user, require_auth
7
+ from .config import Config
8
+ from .csrf import CSRFMiddleware, csrf_input
9
+ from .helpers import status_text
10
+ from .middleware import MiddlewareStack
11
+ from .openapi import SWAGGER_HTML, generate_schema
12
+ from .plugins import Plugin
13
+ from .ratelimit import RateLimitMiddleware, rate_limit
14
+ from .routing import Blueprint, Router
15
+ from .serving import run_dev_server
16
+ from .sessions import SessionMiddleware
17
+ from .static import serve_static
18
+ from .tasks import BackgroundTasks
19
+ from .templating import TemplateEngine
20
+ from .websocket import WebSocket, WebSocketRouter
21
+ from .wrappers import Request, Response
22
+
23
+
24
class PIU:
    """The PIU application: routes, middleware, error handlers and server entry points.

    An instance is an ASGI application (see ``__call__``) and also exposes a
    WSGI callable through :meth:`wsgi`.

    Args:
        template_dir: Template directory; defaults to config ``TEMPLATE_DIR``.
        static_dir: Directory served for static files; defaults to ``STATIC_DIR``.
        static_url: URL prefix of static files; defaults to ``STATIC_URL``.
        config: Optional mapping of settings passed to :class:`Config`.
    """

    def __init__(self, template_dir: str = None,
                 static_dir: str = None,
                 static_url: str = None,
                 config: dict = None):
        self.config = Config(config)
        self.router = Router()
        self.ws_router = WebSocketRouter()
        self.middleware = MiddlewareStack()
        self._template_dir = template_dir or self.config["TEMPLATE_DIR"]
        self._static_dir = static_dir or self.config["STATIC_DIR"]
        self._static_url = static_url or self.config["STATIC_URL"]
        # Built lazily on the first render() call.
        self._template_engine: Optional[TemplateEngine] = None
        self._error_handlers: dict[int, Callable] = {}
        self._plugins: list[Plugin] = []
        self._docs_enabled = False
        self._docs_title = "PIU API"
        # Set by enable_docs(); initialised here so the attribute always exists.
        self._docs_path: Optional[str] = None

    async def __call__(self, scope: dict, receive: Callable, send: Callable):
        """ASGI entry point; delegates to :meth:`asgi`."""
        await self.asgi(scope, receive, send)

    # ------------------------------------------------------------------ routing

    def route(self, path: str, methods: Optional[list[str]] = None):
        """Return a decorator that registers a handler for ``path``.

        Args:
            path: URL pattern handed to the router.
            methods: HTTP methods to match; ``None`` (the default) means ``["GET"]``.
        """
        # A fresh list per call instead of one mutable default list object
        # shared by every route registered without explicit ``methods``.
        route_methods = ["GET"] if methods is None else methods

        def decorator(fn: Callable):
            self.router.add_route(path, fn, route_methods)
            return fn
        return decorator

    def get(self, path: str):
        """Shortcut for ``route(path, methods=["GET"])``."""
        return self.route(path, methods=["GET"])

    def post(self, path: str):
        """Shortcut for ``route(path, methods=["POST"])``."""
        return self.route(path, methods=["POST"])

    def put(self, path: str):
        """Shortcut for ``route(path, methods=["PUT"])``."""
        return self.route(path, methods=["PUT"])

    def patch(self, path: str):
        """Shortcut for ``route(path, methods=["PATCH"])``."""
        return self.route(path, methods=["PATCH"])

    def delete(self, path: str):
        """Shortcut for ``route(path, methods=["DELETE"])``."""
        return self.route(path, methods=["DELETE"])

    def ws(self, path: str):
        """Return a decorator that registers a WebSocket handler for ``path``."""
        def decorator(fn: Callable):
            self.ws_router.add(path, fn)
            return fn
        return decorator

    def enable_docs(self, title: str = "PIU API", path: str = "/docs"):
        """Serve Swagger UI at ``path`` and the OpenAPI schema at ``/openapi.json``.

        Each call registers both GET routes again.
        """
        self._docs_enabled = True
        self._docs_title = title
        self._docs_path = path

        @self.get(path)
        def _swagger(request: Request):
            html = SWAGGER_HTML.format(title=title)
            return Response(body=html, content_type="text/html")

        @self.get("/openapi.json")
        def _schema(request: Request):
            schema = generate_schema(
                self.router,
                title=title,
                version=self.config.get("VERSION", "0.1.0"),
            )
            return Response.json(schema)

    def register_plugin(self, plugin: Plugin):
        """Run ``plugin.setup(self)`` and remember the plugin."""
        plugin.setup(self)
        self._plugins.append(plugin)

    def register(self, blueprint: Blueprint, prefix: str = None):
        """Copy a blueprint's routes into this app, mounted under a prefix.

        Args:
            blueprint: Blueprint whose ``_routes`` are (path, handler, methods) triples.
            prefix: Mount point; when falsy, ``blueprint.prefix`` is used.
        """
        bp_prefix = (prefix or blueprint.prefix or "").rstrip("/")
        for path, handler, methods in blueprint._routes:
            full_path = bp_prefix + ("" if path == "/" else path)
            # A "/" route mounted at the root would otherwise be registered as "".
            self.router.add_route(full_path or "/", handler, methods)

    # ----------------------------------------------------------- error handling

    def errorhandler(self, status_code: int):
        """Return a decorator registering ``fn(request, error)`` for ``status_code``."""
        def decorator(fn: Callable):
            self._error_handlers[status_code] = fn
            return fn
        return decorator

    async def _handle_error(self, request: Request, status: int, error: Exception = None) -> Response:
        """Build the response for ``status`` via a registered handler or a plain default.

        Handlers are called as ``handler(request, error)`` and may be sync or
        async; a non-Response return value becomes the body with ``status``.
        """
        handler = self._error_handlers.get(status)
        if handler:
            result = handler(request, error) if not inspect.iscoroutinefunction(handler) \
                else await handler(request, error)
            return result if isinstance(result, Response) else Response(body=result, status=status)
        return Response(body=f"{status} {status_text(status)}", status=status)

    def render(self, template_name: str, **context) -> Response:
        """Render ``template_name`` with ``context`` into a text/html Response."""
        if self._template_engine is None:
            self._template_engine = TemplateEngine(self._template_dir)
        html = self._template_engine.render(template_name, **context)
        return Response(body=html, content_type="text/html")

    # ----------------------------------------------------------------- dispatch

    async def _dispatch(self, request: Request) -> Response:
        """Produce the response for ``request``.

        Static files are served first (bypassing middleware). Otherwise the
        route handler runs inside the middleware stack; an unmatched route goes
        to the 404 handler and any exception to the 500 handler.
        """
        static_resp = serve_static(request.path, self._static_dir, self._static_url)
        if static_resp is not None:
            return static_resp

        handler, path_params = self.router.resolve(request.path, request.method)

        if handler is None:
            return await self._handle_error(request, 404)

        request.background_tasks = BackgroundTasks()

        async def call_handler(req: Request) -> Response:
            try:
                result = await handler(req, **path_params) \
                    if inspect.iscoroutinefunction(handler) \
                    else handler(req, **path_params)
                resp = result if isinstance(result, Response) else Response(body=result)
                # Tasks run before the response is returned; because this is
                # inside the try, a failing task is reported as a 500.
                await req.background_tasks.run_all()
                return resp
            except Exception as e:
                return await self._handle_error(req, 500, e)

        return await self.middleware.run(request, call_handler)

    def _finalize(self, response: Response) -> Response:
        """Merge the response's cookie headers into ``response.headers``.

        Repeated names are joined with ``"\\n"``; the server adapters split them
        back into separate header lines (see :meth:`_header_pairs`).
        """
        for k, v in response._cookie_headers():
            existing = response.headers.get(k)
            if existing:
                response.headers[k] = existing + "\n" + v
            else:
                response.headers[k] = v
        return response

    @staticmethod
    def _header_pairs(response: Response) -> list[tuple[str, str]]:
        """Return the response headers as (name, value) pairs, Content-Type first.

        A newline is not allowed inside an HTTP header value, so each line of
        a value merged by _finalize (e.g. several Set-Cookie values) is
        emitted as its own header.
        """
        pairs = [("Content-Type", response.content_type)]
        for name, value in response.headers.items():
            pairs.extend((name, part) for part in value.split("\n"))
        return pairs

    # ---------------------------------------------------------- server adapters

    def wsgi(self, environ: dict, start_response: Callable):
        """WSGI entry point; runs the async dispatcher to completion per request."""
        # PATH_INFO is already a bare path. Passing it through urlparse() would
        # treat a leading "//" as a network location and drop part of the path.
        path = environ.get("PATH_INFO") or "/"
        query = parse_qs(environ.get("QUERY_STRING", ""))
        try:
            length = int(environ.get("CONTENT_LENGTH") or 0)
        except ValueError:
            length = 0
        body = environ["wsgi.input"].read(length) if length > 0 else b""

        headers = {
            k[5:].replace("_", "-").title(): v
            for k, v in environ.items() if k.startswith("HTTP_")
        }
        # PEP 3333 passes these two request headers without the HTTP_ prefix.
        for env_key, header_name in (("CONTENT_TYPE", "Content-Type"),
                                     ("CONTENT_LENGTH", "Content-Length")):
            if environ.get(env_key):
                headers.setdefault(header_name, environ[env_key])

        request = Request(
            method=environ.get("REQUEST_METHOD", "GET"),
            path=path, headers=headers,
            body=body, query_params=query,
        )

        response = self._finalize(asyncio.run(self._dispatch(request)))
        status_str = f"{response.status} {status_text(response.status)}"
        start_response(status_str, self._header_pairs(response))
        return [response.body]

    async def asgi(self, scope: dict, receive: Callable, send: Callable):
        """Handle one ASGI connection of type ``http`` or ``websocket``.

        Other scope types (e.g. ``lifespan``) are ignored.
        """
        if scope["type"] == "websocket":
            handler, params = self.ws_router.resolve(scope.get("path", "/"))
            if handler is None:
                # Closing before accepting rejects the handshake.
                await send({"type": "websocket.close", "code": 4004})
                return
            ws = WebSocket(scope, receive, send)
            await send({"type": "websocket.accept"})
            if inspect.iscoroutinefunction(handler):
                await handler(ws, **params)
            else:
                handler(ws, **params)
            return

        if scope["type"] != "http":
            return

        body = b""
        while True:
            event = await receive()
            body += event.get("body", b"")
            if not event.get("more_body", False):
                break

        # errors="replace" keeps malformed (non-UTF-8) bytes from crashing the request.
        query = parse_qs(scope.get("query_string", b"").decode("utf-8", "replace"))
        headers = {
            k.decode("utf-8", "replace"): v.decode("utf-8", "replace")
            for k, v in scope.get("headers", [])
        }

        request = Request(
            method=scope.get("method", "GET"),
            path=scope.get("path", "/"),
            headers=headers, body=body, query_params=query,
        )

        response = self._finalize(await self._dispatch(request))

        await send({
            "type": "http.response.start",
            "status": response.status,
            # The ASGI spec requires lower-cased response header names.
            "headers": [
                [name.lower().encode(), value.encode()]
                for name, value in self._header_pairs(response)
            ],
        })
        await send({"type": "http.response.body", "body": response.body})

    def run(self, host: str = None, port: int = None, reload: bool = None):
        """Start the development server; unset arguments fall back to config."""
        run_dev_server(
            self,
            host=host or self.config.get("HOST", "127.0.0.1"),
            port=port or self.config.get("PORT", 5000),
            reload=reload if reload is not None else self.config.get("DEBUG", False),
        )

    def __repr__(self):
        return f"<PIU routes={len(self.router._routes)} middleware={len(self.middleware._middlewares)}>"
piu/auth.py ADDED
@@ -0,0 +1,114 @@
1
+ """
2
+ Auth utilities for PIU.
3
+
4
+ Provides:
5
+ - @require_auth — decorator that guards a route, redirects or 401s if not authed
6
+ - @require_auth(role=...) — additionally checks request.user["role"]
7
+ - login_user(request, user_data) — store user in session
8
+ - logout_user(request) — clear user from session
9
+ - current_user(request) — retrieve user dict from session (or None)
10
+
11
+ Requires SessionMiddleware to be registered.
12
+
13
+ Example::
14
+
15
+ @app.post("/login")
16
+ def login(request):
17
+ user = db.check_credentials(request.json())
18
+ if not user:
19
+ return Response(body="Bad credentials", status=401)
20
+ login_user(request, {"id": user.id, "role": "admin"})
21
+ return Response.redirect("/dashboard")
22
+
23
+ @app.get("/dashboard")
24
+ @require_auth
25
+ def dashboard(request):
26
+ return Response(body=f"Hello {current_user(request)['id']}")
27
+
28
+ @app.get("/admin")
29
+ @require_auth(role="admin")
30
+ def admin_panel(request):
31
+ ...
32
+ """
33
+
34
+ import functools
35
+ import inspect
36
+ from typing import Callable
37
+
38
+ from .wrappers import Request, Response
39
+
40
# Session key under which login_user() stores the user's data.
SESSION_USER_KEY = "_auth_user"

def login_user(request: Request, user_data: dict):
    """Record ``user_data`` in the session as the logged-in user.

    Raises:
        RuntimeError: if the request carries no ``session``.
    """
    _require_session(request)
    session = request.session
    session[SESSION_USER_KEY] = user_data


def logout_user(request: Request):
    """Forget the logged-in user; a no-op when nobody is logged in.

    Raises:
        RuntimeError: if the request carries no ``session``.
    """
    _require_session(request)
    session = request.session
    session.pop(SESSION_USER_KEY, None)


def current_user(request: Request) -> dict | None:
    """Look up the logged-in user's data, returning None when absent.

    Raises:
        RuntimeError: if the request carries no ``session``.
    """
    _require_session(request)
    session = request.session
    return session.get(SESSION_USER_KEY)


def is_authenticated(request: Request) -> bool:
    """Report whether a user is currently stored in the session."""
    user = current_user(request)
    return user is not None


def _require_session(request: Request):
    """Fail loudly unless SessionMiddleware has attached ``request.session``."""
    if hasattr(request, "session"):
        return
    raise RuntimeError(
        "Auth helpers require SessionMiddleware. Register it before using @require_auth."
    )
70
+
71
def require_auth(fn: Callable = None, *, role: str = None,
                 redirect_to: str = None, status: int = 401):
    """
    Guard a route handler. Can be used with or without arguments:

        @require_auth
        def view(request): ...

        @require_auth(role="admin", redirect_to="/login")
        def admin(request): ...

    The returned wrapper is always a coroutine function; sync handlers are
    called directly, async ones are awaited.

    Args:
        role: If set, also checks request.session["_auth_user"]["role"].
        redirect_to: If set, returns a redirect instead of a 401/403 response.
        status: Status code for unauthenticated responses (default 401).
            Automatically becomes 403 when role check fails.

    Raises:
        TypeError: if a non-callable is passed positionally, e.g. the common
            mistake ``@require_auth("admin")`` (``role`` is keyword-only).
    """
    if fn is not None and not callable(fn):
        # Without this check the string would be wrapped silently and only
        # fail, confusingly, when the route is first requested.
        raise TypeError(
            "require_auth() got a non-callable positional argument; "
            "pass the role as a keyword, e.g. @require_auth(role='admin')"
        )

    def decorator(handler: Callable):
        @functools.wraps(handler)
        async def wrapper(request: Request, *args, **kwargs):
            # current_user() raises RuntimeError when SessionMiddleware is missing.
            user = current_user(request)

            if not user:
                return _deny(redirect_to, status, "401 Unauthorized.")

            if role and user.get("role") != role:
                return _deny(redirect_to, 403, f"403 Forbidden — requires role '{role}'.")

            if inspect.iscoroutinefunction(handler):
                return await handler(request, *args, **kwargs)
            return handler(request, *args, **kwargs)

        return wrapper

    # Bare ``@require_auth`` passes the handler in; ``@require_auth(...)`` does not.
    if fn is not None:
        return decorator(fn)
    return decorator


def _deny(redirect_to: str, status: int, message: str) -> Response:
    """Return a redirect to ``redirect_to`` if set, else a ``status`` response with ``message``."""
    if redirect_to:
        return Response.redirect(redirect_to)
    return Response(body=message, status=status)
piu/cli.py ADDED
@@ -0,0 +1,161 @@
1
+ import argparse
2
+ import os
3
+ import sys
4
+
5
+
6
+ _APP_TEMPLATE = '''\
7
+ import piu
8
+ from piu import (
9
+ PIU, Request, Response,
10
+ SessionMiddleware, CSRFMiddleware,
11
+ RateLimitMiddleware,
12
+ )
13
+
14
+ app = PIU()
15
+
16
+ app.config.from_env_file(".env")
17
+ app.config.load_env()
18
+
19
+ app.middleware.use(RateLimitMiddleware(limit=100, window=60))
20
+ app.middleware.use(SessionMiddleware(
21
+ secret_key=app.config.get("SECRET_KEY", "dev-secret"),
22
+ max_age=3600,
23
+ ))
24
+ app.middleware.use(CSRFMiddleware())
25
+
26
+
27
+ @app.errorhandler(404)
28
+ def not_found(request, error):
29
+ return Response(body="<h1>404 — Not found</h1>", status=404)
30
+
31
+
32
+ @app.get("/")
33
+ def index(request: Request):
34
+ return Response(body="<h1>Hello from PIU 🩲</h1>")
35
+
36
+
37
+ if __name__ == "__main__":
38
+ app.run()
39
+ '''
40
+
41
+ _ENV_TEMPLATE = """\
42
+ DEBUG=true
43
+ HOST=127.0.0.1
44
+ PORT=5000
45
+ SECRET_KEY=change-me-in-production
46
+ """
47
+
48
+ _HTML_TEMPLATE = """\
49
+ <!DOCTYPE html>
50
+ <html lang="en">
51
+ <head>
52
+ <meta charset="UTF-8">
53
+ <title>{{ title }}</title>
54
+ </head>
55
+ <body>
56
+ <h1>🩲 {{ title }}</h1>
57
+ <p>Your PIU app is running.</p>
58
+ </body>
59
+ </html>
60
+ """
61
+
62
+ _GITIGNORE_TEMPLATE = """\
63
+ __pycache__/
64
+ *.pyc
65
+ .env
66
+ *.egg-info/
67
+ dist/
68
+ .venv/
69
+ """
70
+
71
+ _TEST_TEMPLATE = '''\
72
+ from piu.testing import TestClient
73
+ from app import app
74
+
75
+ client = TestClient(app)
76
+
77
+
78
+ def test_index():
79
+ resp = client.get("/")
80
+ assert resp.status == 200
81
+ '''
82
+
83
+
84
def cmd_new(args):
    """Scaffold a new PIU project directory ``args.name`` under the current directory.

    Creates templates/, static/ and tests/ plus app.py, .env, index.html,
    .gitignore and tests/test_app.py. Exits with status 1 if the target
    path already exists.
    """
    name = args.name
    base = os.path.join(os.getcwd(), name)

    if os.path.exists(base):
        print(f"[PIU] Error: directory '{name}' already exists.")
        sys.exit(1)

    dirs = [base, os.path.join(base, "templates"),
            os.path.join(base, "static"), os.path.join(base, "tests")]
    for d in dirs:
        os.makedirs(d)

    files = {
        os.path.join(base, "app.py"): _APP_TEMPLATE,
        os.path.join(base, ".env"): _ENV_TEMPLATE,
        os.path.join(base, "templates", "index.html"): _HTML_TEMPLATE,
        os.path.join(base, ".gitignore"): _GITIGNORE_TEMPLATE,
        os.path.join(base, "tests", "test_app.py"): _TEST_TEMPLATE,
    }
    for path, content in files.items():
        # The templates contain non-ASCII text (emoji, em dash). Write UTF-8
        # explicitly: the locale default (e.g. cp1252 on Windows) cannot encode
        # them. UTF-8 also matches Python's source default and the HTML charset.
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)

    print(f"\n[PIU] 🩲 '{name}' created!\n")
    print(f" cd {name}")
    print(" python app.py\n")
111
+
112
+
113
def cmd_run(args):
    """Import ``app`` from ./app.py and start its development server.

    Host, port and reload come from the command line when given. Otherwise
    they come from the app's config, after ./.env (if present) and PIU_*
    environment variables are loaded. Exits with status 1 in three cases:
    app.py is missing, it fails to import, or it defines no ``app``.
    """
    if not os.path.isfile("app.py"):
        print("[PIU] Error: no app.py found in the current directory.")
        sys.exit(1)

    # Put the project directory first so "import app" resolves to ./app.py.
    sys.path.insert(0, os.getcwd())
    import importlib
    try:
        app_module = importlib.import_module("app")
    except Exception as exc:
        print(f"[PIU] Failed to import app.py: {exc}")
        sys.exit(1)

    application = getattr(app_module, "app", None)
    if application is None:
        print("[PIU] Error: app.py must define an 'app' variable.")
        sys.exit(1)

    settings = application.config
    if os.path.isfile(".env"):
        settings.from_env_file(".env")
    settings.load_env()

    chosen_host = args.host or settings.get("HOST", "127.0.0.1")
    chosen_port = args.port or settings.get("PORT", 5000)
    auto_reload = args.reload or settings.get("DEBUG", False)

    application.run(host=chosen_host, port=int(chosen_port), reload=auto_reload)
140
+
141
+
142
def main():
    """Parse the command line and dispatch to the chosen ``piu`` sub-command."""
    parser = argparse.ArgumentParser(prog="piu", description="🩲 Python In Underwear CLI")
    commands = parser.add_subparsers(dest="command", required=True)

    # piu new NAME
    new_cmd = commands.add_parser("new", help="Scaffold a new PIU project")
    new_cmd.add_argument("name", help="Project name")
    new_cmd.set_defaults(func=cmd_new)

    # piu run [--host H] [--port P] [--reload]
    run_cmd = commands.add_parser("run", help="Run the development server")
    run_cmd.add_argument("--host", default=None)
    run_cmd.add_argument("--port", default=None, type=int)
    run_cmd.add_argument("--reload", action="store_true")
    run_cmd.set_defaults(func=cmd_run)

    parsed = parser.parse_args()
    # Each sub-parser stored its handler in ``func`` via set_defaults().
    handler = parsed.func
    handler(parsed)


if __name__ == "__main__":
    main()
piu/config.py ADDED
@@ -0,0 +1,85 @@
1
+ import os
2
+ from typing import Any
3
+
4
+
5
+ class Config:
6
+ def __init__(self, defaults: dict = None):
7
+ self._data: dict[str, Any] = {
8
+ "DEBUG": False,
9
+ "HOST": "127.0.0.1",
10
+ "PORT": 5000,
11
+ "SECRET_KEY": "",
12
+ "TEMPLATE_DIR": "templates",
13
+ "STATIC_DIR": "static",
14
+ "STATIC_URL": "/static",
15
+ }
16
+ if defaults:
17
+ self._data.update(defaults)
18
+
19
+ def __getitem__(self, key: str) -> Any:
20
+ return self._data[key.upper()]
21
+
22
+ def __setitem__(self, key: str, value: Any):
23
+ self._data[key.upper()] = value
24
+
25
+ def __contains__(self, key: str) -> bool:
26
+ return key.upper() in self._data
27
+
28
+ def get(self, key: str, default: Any = None) -> Any:
29
+ return self._data.get(key.upper(), default)
30
+
31
+ def set(self, key: str, value: Any):
32
+ self._data[key.upper()] = value
33
+
34
+ def all(self) -> dict:
35
+ return dict(self._data)
36
+
37
+ def from_dict(self, d: dict):
38
+ for k, v in d.items():
39
+ self._data[k.upper()] = v
40
+
41
+ def from_env_file(self, path: str = ".env"):
42
+ if not os.path.isfile(path):
43
+ return
44
+ with open(path) as f:
45
+ for line in f:
46
+ line = line.strip()
47
+ if not line or line.startswith("#") or "=" not in line:
48
+ continue
49
+ key, _, val = line.partition("=")
50
+ self._data[key.strip().upper()] = _cast(val.strip())
51
+
52
+ def from_yaml(self, path: str):
53
+ try:
54
+ import yaml
55
+ except ImportError:
56
+ raise RuntimeError("from_yaml() requires pyyaml. Run: pip install pyyaml")
57
+ with open(path) as f:
58
+ data = yaml.safe_load(f) or {}
59
+ for k, v in data.items():
60
+ self._data[k.upper()] = v
61
+
62
+ def load_env(self, prefix: str = "PIU_"):
63
+ for k, v in os.environ.items():
64
+ if k.startswith(prefix):
65
+ key = k[len(prefix):]
66
+ self._data[key.upper()] = _cast(v)
67
+
68
+ def __repr__(self):
69
+ return f"<Config keys={list(self._data.keys())}>"
70
+
71
+
72
+ def _cast(value: str) -> Any:
73
+ if value.lower() in ("true", "yes"):
74
+ return True
75
+ if value.lower() in ("false", "no"):
76
+ return False
77
+ try:
78
+ return int(value)
79
+ except ValueError:
80
+ pass
81
+ try:
82
+ return float(value)
83
+ except ValueError:
84
+ pass
85
+ return value
piu/csrf.py ADDED
@@ -0,0 +1,65 @@
1
+ import os
2
+ import hmac
3
+ import hashlib
4
+ from typing import Callable
5
+
6
+ from .wrappers import Request, Response
7
+
8
+ CSRF_SESSION_KEY = "_csrf_token"
9
+ UNSAFE_METHODS = {"POST", "PUT", "PATCH", "DELETE"}
10
+
11
+
12
+ def _generate_token() -> str:
13
+ return os.urandom(32).hex()
14
+
15
+
16
+ def _tokens_equal(a: str, b: str) -> bool:
17
+ return hmac.compare_digest(a.encode(), b.encode())
18
+
19
+
20
+ class CSRFMiddleware:
21
+ def __init__(self, exempt_paths: list[str] = None):
22
+ """
23
+ Args:
24
+ exempt_paths: List of URL path prefixes to skip CSRF checks on.
25
+ Useful for API routes using token auth instead of sessions.
26
+ e.g. ["/api/", "/webhook"]
27
+ """
28
+ self._exempt = exempt_paths or []
29
+
30
+ def _is_exempt(self, path: str) -> bool:
31
+ return any(path.startswith(p) for p in self._exempt)
32
+
33
+ async def __call__(self, request: Request, next: Callable) -> Response:
34
+ if not hasattr(request, "session"):
35
+ raise RuntimeError(
36
+ "CSRFMiddleware requires SessionMiddleware to run first. "
37
+ "Register SessionMiddleware before CSRFMiddleware."
38
+ )
39
+
40
+
41
+ if CSRF_SESSION_KEY not in request.session:
42
+ request.session[CSRF_SESSION_KEY] = _generate_token()
43
+
44
+ token = request.session[CSRF_SESSION_KEY]
45
+
46
+
47
+ request.csrf_token = token
48
+
49
+
50
+ if request.method in UNSAFE_METHODS and not self._is_exempt(request.path):
51
+
52
+ submitted = (
53
+ request.headers.get("X-Csrf-Token")
54
+ or request.headers.get("X-CSRF-Token")
55
+ or request.form().get("_csrf_token", [None])[0]
56
+ )
57
+ if not submitted or not _tokens_equal(token, submitted):
58
+ return Response(body="403 CSRF token invalid or missing.", status=403)
59
+
60
+ return await next(request)
61
+
62
+
63
+ def csrf_input(token: str) -> str:
64
+ """Return an HTML hidden input string for use in templates."""
65
+ return f'<input type="hidden" name="_csrf_token" value="{token}">'