wau 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,88 @@
1
+ Metadata-Version: 2.4
2
+ Name: wau
3
+ Version: 0.1.0
4
+ Summary: Web API Utils
5
+ Author: Marco Schmalz
6
+ License-Expression: LGPL-3.0-or-later
7
+ Keywords: api,json,werkzeug,education
8
+ Classifier: Development Status :: 3 - Alpha
9
+ Classifier: Intended Audience :: Education
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Programming Language :: Python :: 3.14
12
+ Classifier: Topic :: Internet :: WWW/HTTP
13
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
14
+ Requires-Python: >=3.14
15
+ Description-Content-Type: text/markdown
16
+ License-File: LICENSE
17
+ Requires-Dist: dataset>=2.0.0
18
+ Requires-Dist: pyjwt>=2.13.0
19
+ Requires-Dist: werkzeug>=3.1.8
20
+ Dynamic: license-file
21
+
22
+ # `wau` — Web API Utils
23
+
24
+ Web API Utils, or `wau` for short, is a thin layer on top of Werkzeug that provides a simple and consistent interface for writing APIs in Python. `wau` is built for educational purposes and is not intended for production use. It is opinionated: JSON is the only supported data format. It uses simple type annotations to define the expected input and output of the API endpoints. Common tasks such as authentication, CORS and server-sent events are supported by default.
25
+
26
+ ## Installation
27
+
28
+ Install from PyPI:
29
+
30
+ ```powershell
31
+ pip install wau
32
+ ```
33
+
34
+ or with uv:
35
+
36
+ ```powershell
37
+ uv add wau
38
+ ```
39
+
40
+ ## Testing
41
+
42
+ Test dependencies are separated from runtime dependencies in `pyproject.toml`
43
+ using the `test` dependency group.
44
+
45
+ Run the test suite:
46
+
47
+ ```powershell
48
+ uv run --group test python -m pytest -q
49
+ ```
50
+
51
+ Run doctests:
52
+
53
+ ```powershell
54
+ uv run --group test python -m doctest .\wau.py
55
+ ```
56
+
57
+ ## Publishing
58
+
59
+ Build package artifacts:
60
+
61
+ ```powershell
62
+ uv build
63
+ ```
64
+
65
+ Validate metadata and README rendering:
66
+
67
+ ```powershell
68
+ uvx twine check dist/*
69
+ ```
70
+
71
+ Upload to TestPyPI first:
72
+
73
+ ```powershell
74
+ uv publish --publish-url https://test.pypi.org/legacy/
75
+ ```
76
+
77
+ Then publish to PyPI:
78
+
79
+ ```powershell
80
+ uv publish
81
+ ```
82
+
83
+ ## License
84
+
85
+ This project is licensed under GNU LGPL v3 or later (`LGPL-3.0-or-later`).
86
+
87
+ If you distribute modified versions of this library, those library
88
+ modifications must be made available under the same license terms.
@@ -0,0 +1,6 @@
1
+ wau.py,sha256=wQsoNvzvH6HFRBj8ZKIpKjT87_9aB-WfOi0kEt8pdCU,37026
2
+ wau-0.1.0.dist-info/licenses/LICENSE,sha256=5lfLO8VaZisSCb3xtnOUdjX26TYX52CYXYRVUJsSzUk,500
3
+ wau-0.1.0.dist-info/METADATA,sha256=bm0_c9LKeic0-u1efpcsK2YvduO8yVgHtNKXq3Ak_dY,2172
4
+ wau-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
5
+ wau-0.1.0.dist-info/top_level.txt,sha256=xH8uF0IWDusYlJBXD4LSHfJLMSsa-Yyj0_2I26TAJQQ,4
6
+ wau-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,14 @@
1
+ GNU LESSER GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ This project is licensed under the GNU Lesser General Public License,
9
+ version 3 or (at your option) any later version.
10
+
11
+ For the full license text, see:
12
+ https://www.gnu.org/licenses/lgpl-3.0.txt
13
+
14
+ SPDX-License-Identifier: LGPL-3.0-or-later
@@ -0,0 +1 @@
1
+ wau
wau.py ADDED
@@ -0,0 +1,1095 @@
1
+ import collections
2
+ import datetime
3
+ import functools
4
+ import inspect
5
+ import itertools
6
+ import json
7
+ import queue
8
+ import re
9
+ import sys
10
+ import threading
11
+ import traceback
12
+ import urllib.parse
13
+
14
+ import werkzeug
15
+ from werkzeug.exceptions import (
16
+ HTTPException,
17
+ NotFound,
18
+ Unauthorized,
19
+ UnprocessableEntity,
20
+ UnsupportedMediaType,
21
+ )
22
+ from werkzeug.middleware.dispatcher import DispatcherMiddleware
23
+ from werkzeug.routing import Map, Rule
24
+ from werkzeug.security import check_password_hash
25
+
26
+ try:
27
+ import jwt
28
+ except ImportError:
29
+ # Do not complain now, but only when auth classes get instantiated
30
+ jwt = None
31
+
32
+
33
class API:
    """An API speaking in JSON with the outside world.

    Here is a very simple first API:
    >>> app = API()
    >>> @app.GET("/hello")
    ... def root(request):
    ...     return "Hello World"
    ...
    >>> from werkzeug.test import Client
    >>> client = Client(app)
    >>> response = client.get('/hello')
    >>> response.get_json()
    'Hello World'

    The `request` parameter must be called `request` and must be the first
    parameter of the handler function, but it can be omitted if not needed.

    So, here is an even simpler version of the above code:
    >>> @app.GET("/hello_again")
    ... def root():
    ...     return "Hello again!"
    ...
    >>> response = client.get('/hello_again')
    >>> response.get_json()
    'Hello again!'

    URL paths can be parametrized:
    >>> @app.register("GET", "/user/{id}")
    ... def home(id):
    ...     return f"Welcome home {id}!"
    ...
    >>> response = client.get("/user/007")
    >>> response.status
    '200 OK'
    >>> response.get_json()
    'Welcome home 007!'

    Path parameters can be typed:
    >>> @app.register("GET", "/agent/{id:int}")
    ... def home(id):
    ...     return f"Welcome home {id}! You're more than {id - 1}."
    ...
    >>> response = client.get("/agent/Bond")
    >>> response.status
    '404 NOT FOUND'
    >>> response = client.get("/agent/007")
    >>> response.status
    '200 OK'
    >>> response.get_json()
    "Welcome home 7! You're more than 6."

    Now with generic POST data: add a `data` parameter and optionally specify
    its type (default is dict):
    >>> @app.POST("/")
    ... def create(request, data:list):
    ...     print(data)
    ...
    >>> response = client.post("/", json=[1, 2, 3])
    [1, 2, 3]

    >>> response = client.post("/", json={})
    >>> response.status
    '422 UNPROCESSABLE ENTITY'

    And finally requesting a dict in the POST data with specified fields
    (and types):

    >>> @app.PUT("/")
    ... def update(request, name, age:int, superhuman:bool=False):
    ...     print(f"{name} is {age} years old.")
    ...     print(f"{name} is {'' if superhuman else 'not '}superhuman.")
    ...
    >>> data = {"name": "Betsy", "age": 34}
    >>> response = client.put("/", json=data)
    Betsy is 34 years old.
    Betsy is not superhuman.
    >>> response.status
    '200 OK'

    >>> data = {"name": "Betsy"}
    >>> response = client.put("/", json=data)
    >>> response.status
    '422 UNPROCESSABLE ENTITY'

    >>> data = {"name": "Betsy", "age": "34"}
    >>> response = client.put("/", json=data)
    >>> response.status
    '422 UNPROCESSABLE ENTITY'

    >>> data = {"name": "Betsy", "age": 34, "superhuman": True}
    >>> response = client.put("/", json=data)
    Betsy is 34 years old.
    Betsy is superhuman.
    >>> response.status
    '200 OK'

    >>> data = {"name": "Betsy", "age": 34, "verysmart": True}
    >>> response = client.put("/", json=data)
    >>> response.status
    '422 UNPROCESSABLE ENTITY'

    >>> data = "This is not valid JSON"
    >>> response = client.put("/", data=data)
    >>> response.status
    '415 UNSUPPORTED MEDIA TYPE'

    Fields without an annotation accept any JSON value:
    >>> @app.PATCH("/")
    ... def rename(request, name):
    ...     print(f"Renamed to {name}.")
    ...
    >>> response = client.patch("/", json={"name": "Betsy"})
    Renamed to Betsy.
    >>> response.status
    '200 OK'

    A method that is not registered for a path is answered with 405, and the
    response lists the allowed methods in its `Allow` header:
    >>> response = client.delete("/")
    >>> response.status
    '405 METHOD NOT ALLOWED'
    >>> "PUT" in response.headers["Allow"]
    True
    """

    def __init__(self):
        self._url_map = Map()

    def register(self, method, url, func=None):
        """Register a route with a callback.

        This function can be used either directly:

        >>> api = API()
        >>> api.register("GET", "/", func=lambda request: "Hello!")  # doctest: +ELLIPSIS
        <function <lambda> at 0x...>

        or as a decorator
        >>> @api.register("GET", "/user/{id}")
        ... def home(request, id):
        ...     return f"Welcome home {id}!"
        ...

        To test it use the Client class.
        >>> from werkzeug.test import Client
        >>> client = Client(api)
        >>> response = client.get("/")
        >>> response.status
        '200 OK'
        >>> response.get_json()
        'Hello!'
        >>> response = client.get("/user/007")
        >>> response.status
        '200 OK'
        >>> response.get_json()
        'Welcome home 007!'

        `url` accepts parametrized and optionally typed placeholders
        (`{id}`, `{id:int}`).

        Handler parameters are interpreted in this order: an optional first
        `request`, then one parameter per route placeholder (in any order),
        then body parameters. A single body parameter named `data` receives
        the whole JSON body (type from its annotation, default `dict`). Any
        other body parameters are read as fields of a JSON object.

        Returns `func` unchanged, so the decorator form keeps the handler usable.

        Raises:
            TypeError: If the handler's route parameters do not match the
                placeholders in `url`.
        """
        if func is None:
            return functools.partial(self.register, method, url)

        url = _normalize_url_placeholders(url)

        rule = Rule(url, methods=(method,))
        Map([rule])  # Bind rule temporarily so `rule.arguments` is populated
        url_params = rule.arguments

        sig = inspect.signature(func)
        params = sig.parameters
        param_keys = list(params)

        # The first argument can optionally be request, after that route params.
        # Parameter order is ignored.
        offset = 1 if param_keys and param_keys[0] == "request" else 0
        url_end = offset + len(url_params)
        func_url_params = set(param_keys[offset:url_end])
        body_params = param_keys[url_end:]

        mismatch = url_params ^ func_url_params
        if mismatch:
            # functools.partial objects (used by the auth middleware) have no __name__
            name = getattr(func, "__name__", repr(func))
            raise TypeError(
                f"{name}() arguments and route parameter missmatch "
                f"({func_url_params} != {url_params})"
            )

        body_type = None
        content_types = {}
        if body_params == ["data"]:
            annotation = params["data"].annotation
            body_type = dict if annotation is inspect.Parameter.empty else annotation
        elif body_params:
            body_type = dict
            # Unannotated fields map to `object`, which accepts any JSON value.
            # That keeps the handler in field mode even when no field is
            # annotated. An empty mapping would make _parse_json_body pass the
            # body as `data=`, which such a handler does not accept.
            content_types = {
                key: (
                    object
                    if params[key].annotation is inspect.Parameter.empty
                    else params[key].annotation
                )
                for key in body_params
            }

        if body_type is not None:
            endpoint = _parse_json_body(func, body_type, content_types)
        else:
            endpoint = _simple_wrapper(func)

        self._url_map.add(Rule(url, methods=(method,), endpoint=endpoint))
        return func

    def GET(self, string):
        """Shorthand for registering GET requests.

        Use as a decorator:
        >>> api = API()
        >>> @api.GET("/admin")
        ... def admin_home(request):
        ...     return "Nothing here"
        ...
        >>> from werkzeug.test import Client
        >>> client = Client(api)
        >>> client.get("/admin")
        <TestResponse streamed [200 OK]>
        """
        return self.register("GET", string)

    def POST(self, string):
        """Shorthand for registering POST requests."""
        return self.register("POST", string)

    def PUT(self, string):
        """Shorthand for registering PUT requests."""
        return self.register("PUT", string)

    def PATCH(self, string):
        """Shorthand for registering PATCH requests."""
        return self.register("PATCH", string)

    def DELETE(self, string):
        """Shorthand for registering DELETE requests."""
        return self.register("DELETE", string)

    def __call__(self, environ, start_response):
        """Handle one WSGI request.

        The matched handler's return value is serialized to JSON unless it is
        callable (e.g. a `werkzeug.Response`), in which case it is used as the
        WSGI response directly. HTTP errors become JSON error bodies that keep
        their protocol headers. Any other exception is logged to `wsgi.errors`
        and answered with a generic 500.
        """
        try:
            request = werkzeug.Request(environ)
            adapter = self._url_map.bind_to_environ(environ)
            endpoint, values = adapter.match()

            # Dispatch request
            response = endpoint(request, **values)
            if not callable(response):
                response = _json_response(response)
            return response(environ, start_response)
        except HTTPException as e:
            if e.code is not None and 300 <= e.code < 400:
                # Routing redirects (e.g. a 308 RequestRedirect for a missing
                # trailing slash) only work with their Location header, so
                # use Werkzeug's own redirect response.
                response = e.get_response(environ)
            else:
                response = _json_response(
                    {"code": e.code, "name": e.name, "description": e.description},
                    status=e.code or 500,
                )
                # Keep protocol headers such as `Allow` (405) or
                # `WWW-Authenticate` (401). Only the content type is replaced.
                for key, value in e.get_headers(environ):
                    if key.lower() != "content-type":
                        response.headers.add(key, value)
        except Exception as e:
            response = _json_response(
                {"code": 500, "name": "Internal Server Error"}, status=500
            )
            err = environ["wsgi.errors"]
            print(f"ERROR {e.__class__.__name__}: {str(e)}", file=err)
            traceback.print_exc(file=err)
        return response(environ, start_response)
285
+
286
+
287
def _json_response(data, status=200):
    """Wrap a handler result in a Werkzeug response with the given `status`.

    An existing `werkzeug.Response` is handed back untouched, and `None`
    produces an empty body. Anything else is serialized as indented JSON
    followed by a newline; values the encoder cannot handle are converted
    with `str`.
    """
    if isinstance(data, werkzeug.Response):
        # Already a response object — pass through as-is
        return data
    if data is None:
        return werkzeug.Response(status=status)
    body = f"{json.dumps(data, indent=2, default=str)}\n"
    return werkzeug.Response(body, status=status, mimetype="application/json")
296
+
297
+
298
+ def _simple_wrapper(func):
299
+ sig = inspect.signature(func)
300
+
301
+ @functools.wraps(func)
302
+ def wrapper(request, *args, **kwargs):
303
+ if sig.parameters and list(sig.parameters.keys())[0] == "request":
304
+ args = (request,) + args
305
+ elif request.data:
306
+ raise UnsupportedMediaType("No request body allowed")
307
+ return func(*args, **kwargs)
308
+
309
+ return wrapper
310
+
311
+
312
+ def _normalize_url_placeholders(url):
313
+ """Convert readable URL placeholders to Werkzeug form.
314
+
315
+ Examples:
316
+ `{id}` -> `<id>`
317
+ `{id:int}` -> `<int:id>`
318
+ """
319
+ for orig, contents in re.findall(r"(\{([^\}\{]+)\})", url):
320
+ if ":" in contents:
321
+ contents = ":".join(reversed(contents.split(":")))
322
+ url = url.replace(orig, f"<{contents}>")
323
+ return url
324
+
325
+
326
+ def _check_value(key, value, value_type):
327
+ if value_type is bool and (value is True or value is False):
328
+ return value
329
+ elif value_type == float and isinstance(value, int):
330
+ return float(value)
331
+ if isinstance(value_type, type) and isinstance(value, value_type):
332
+ return value
333
+ elif (
334
+ isinstance(value, str)
335
+ and callable(value_type)
336
+ and value_type not in (bool, int, float, str, list, dict)
337
+ ):
338
+ func = value_type
339
+ try:
340
+ return func(value)
341
+ except ValueError:
342
+ raise UnprocessableEntity(
343
+ f"Invalid format: '{key}' cannot be converted to {func.__name__}."
344
+ )
345
+ else:
346
+ raise UnprocessableEntity(
347
+ f"Invalid format: '{key}' must be of type {value_type.__name__}."
348
+ )
349
+
350
+
351
+ def _parse_json_body(func=None, body_type=dict, content_types={}): # noqa: C901
352
+ if func is None:
353
+ return functools.partial(
354
+ _parse_json_body, body_type=body_type, content_types=content_types
355
+ )
356
+
357
+ sig = inspect.signature(func)
358
+
359
+ @functools.wraps(func)
360
+ def wrapper(request, *args, **kwargs):
361
+ try:
362
+ data = str(request.data, "utf-8").strip()
363
+ except UnicodeDecodeError:
364
+ raise UnsupportedMediaType("Cannot parse request body: invalid UTF-8 data")
365
+
366
+ if not data:
367
+ raise UnsupportedMediaType("Cannot parse request body: no data supplied")
368
+
369
+ try:
370
+ data = json.loads(data)
371
+ except json.decoder.JSONDecodeError:
372
+ raise UnsupportedMediaType("Cannot parse request body: invalid JSON")
373
+
374
+ if body_type is not None and not isinstance(data, body_type):
375
+ raise UnprocessableEntity(
376
+ f"Invalid data format: {body_type.__name__} expected"
377
+ )
378
+
379
+ if sig.parameters and list(sig.parameters.keys())[0] == "request":
380
+ args = (request,) + args
381
+
382
+ if body_type == dict and content_types:
383
+ too_many = data.keys() - (sig.parameters.keys() - kwargs.keys())
384
+ if too_many:
385
+ raise UnprocessableEntity(f"Key not allowed: {', '.join(too_many)}")
386
+
387
+ kwargs.update(data)
388
+ bound = sig.bind_partial(*args, **kwargs)
389
+
390
+ for key, value in bound.arguments.items():
391
+ if key in content_types:
392
+ bound.arguments[key] = _check_value(key, value, content_types[key])
393
+
394
+ bound.apply_defaults()
395
+
396
+ missing = sig.parameters.keys() - bound.arguments.keys()
397
+ if missing:
398
+ raise UnprocessableEntity(f"Key missing: {', '.join(missing)}")
399
+
400
+ else:
401
+ kwargs["data"] = data
402
+
403
+ return func(*args, **kwargs)
404
+
405
+ return wrapper
406
+
407
+
408
def timestamp(string=None):
    """Parse JS date strings to datetime objects.

    Returns the current datetime (now), if called with no argument.

    `timestamp` returns timezone-aware UTC timestamps. Strings carrying
    another UTC offset are converted to UTC. Strings without any offset are
    assumed to already be in UTC.

    Example usage:

    To convert a JS timestamp, create one in the browser or in node:
    > let now = new Date()
    > JSON.stringify(now)
    '"2020-12-09T23:44:53.782Z"'

    Convert the value to a native Python datetime object:
    >>> timestamp(json.loads('"2020-12-09T23:44:53.782Z"'))
    datetime.datetime(2020, 12, 9, 23, 44, 53, 782000, tzinfo=datetime.timezone.utc)

    Other offsets are normalized to UTC:
    >>> timestamp("2020-12-10T01:44:53+02:00")
    datetime.datetime(2020, 12, 9, 23, 44, 53, tzinfo=datetime.timezone.utc)

    To get the current time:
    >>> timestamp()  # doctest: +ELLIPSIS
    datetime.datetime(2..., tzinfo=datetime.timezone.utc)

    This function can be used as an annotation in request handlers:
    >>> api = API()
    >>> @api.POST("/reminder")
    ... def reminder(request, date:timestamp, text:str):
    ...     pass
    ...
    """
    utc = datetime.timezone.utc
    if string is None:
        return datetime.datetime.now(utc)

    parsed = datetime.datetime.fromisoformat(string.replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        # Naive strings are read as UTC, so every result is timezone-aware
        # and can be compared with the others.
        return parsed.replace(tzinfo=utc)
    return parsed.astimezone(utc)
441
+
442
+
443
+ def _cors_same_host_middleware(app, allowed_host):
444
+ allowed_host = allowed_host.lower().strip()
445
+
446
+ def _origin_for_host(origin):
447
+ try:
448
+ parsed = urllib.parse.urlparse(origin)
449
+ except ValueError:
450
+ return None
451
+
452
+ if not parsed.scheme or not parsed.hostname:
453
+ return None
454
+
455
+ if parsed.hostname.lower() == allowed_host:
456
+ return origin
457
+ return None
458
+
459
+ def _append_vary_origin(headers):
460
+ for index, (key, value) in enumerate(headers):
461
+ if key.lower() == "vary":
462
+ vary_values = {v.strip().lower() for v in value.split(",") if v.strip()}
463
+ if "origin" not in vary_values:
464
+ headers[index] = (key, f"{value}, Origin")
465
+ return
466
+ headers.append(("Vary", "Origin"))
467
+
468
+ def wrapped(environ, start_response):
469
+ origin = environ.get("HTTP_ORIGIN", "")
470
+ allowed_origin = _origin_for_host(origin)
471
+ is_preflight = (
472
+ environ.get("REQUEST_METHOD") == "OPTIONS"
473
+ and "HTTP_ACCESS_CONTROL_REQUEST_METHOD" in environ
474
+ )
475
+
476
+ if is_preflight and allowed_origin:
477
+ requested_headers = environ.get("HTTP_ACCESS_CONTROL_REQUEST_HEADERS", "")
478
+ headers = [
479
+ ("Access-Control-Allow-Origin", allowed_origin),
480
+ (
481
+ "Access-Control-Allow-Methods",
482
+ "GET, POST, PUT, PATCH, DELETE, OPTIONS",
483
+ ),
484
+ (
485
+ "Access-Control-Allow-Headers",
486
+ (
487
+ requested_headers
488
+ if requested_headers
489
+ else "Content-Type, Authorization"
490
+ ),
491
+ ),
492
+ ("Access-Control-Max-Age", "86400"),
493
+ ]
494
+ _append_vary_origin(headers)
495
+ response = werkzeug.Response(status=204, headers=headers)
496
+ return response(environ, start_response)
497
+
498
+ def cors_start_response(status, headers, exc_info=None):
499
+ if allowed_origin:
500
+ headers = list(headers)
501
+ headers.append(("Access-Control-Allow-Origin", allowed_origin))
502
+ _append_vary_origin(headers)
503
+ return start_response(status, headers, exc_info)
504
+
505
+ return app(environ, cors_start_response)
506
+
507
+ return wrapped
508
+
509
+
510
class BaseJWTAuthMiddleware:
    """Middleware authorizing access to a wrapped WSGI application using JWT.

    This base class authenticates nobody by itself: subclasses must implement
    `authenticate`.

    This middleware exposes two endpoints below `prefix` (default `/auth`):
    - `<prefix>/login` for generating new tokens.
    - `<prefix>/renew` for renewing an existing token.

    Tokens are short-lived and are valid for only 15 minutes, but expired
    tokens can be renewed during one week starting from their initial issuing
    date.

    Upon successful authorization the username is stored in the WSGI
    environment and can be retrieved from Werkzeug's Request object:
    `request.remote_user`. The Authorization header is removed before the
    request is passed on.

    Constructor arguments:
        app: The WSGI application to protect.
        secret: Key used to sign and verify HS256 tokens.
        exempt: Iterable of `(method, path, ...)` tuples that bypass
            authorization. `path` may use `{name}` placeholders.
        prefix: Mount point of the login/renew endpoints. A trailing slash is
            ignored.
        login_methods: HTTP methods accepted by the login endpoint.
    """

    def __init__(
        self,
        app,
        secret,
        *,
        exempt=(),
        prefix="/auth",
        login_methods=("POST",),
    ):
        if jwt is None:
            print("WARNING: No module named 'jwt'", file=sys.stderr)
            print("Cannot perform authentication without PyJWT", file=sys.stderr)
            print("Run `pip install PyJWT` to fix this", file=sys.stderr)
            raise ModuleNotFoundError("No module named 'jwt'")

        # Normalize before mounting. DispatcherMiddleware only matches mount
        # points without a trailing slash, so "/auth/" would otherwise make
        # login/renew unreachable while the exempt rules below use "/auth".
        prefix = prefix.rstrip("/")

        auth_api = API()
        for method in login_methods:
            auth_api.register(method, "/login", func=functools.partial(self._login))
        auth_api.register("POST", "/renew", func=functools.partial(self._renew))

        self.app = DispatcherMiddleware(app, {prefix: auth_api})

        exempt_map = Map()
        for method, path, *_ in exempt:
            exempt_map.add(
                Rule(
                    _normalize_url_placeholders(path),
                    methods=(method.upper(),),
                    endpoint=True,
                )
            )
        for method in login_methods:
            exempt_map.add(
                Rule(prefix + "/login", methods=(method.upper(),), endpoint=True)
            )
        exempt_map.add(Rule(prefix + "/renew", methods=("POST",), endpoint=True))
        self._exempt_map = exempt_map

        self.secret = secret

    def __call__(self, environ, start_response):
        """WSGI entry point: authorize the request, then delegate to the app."""
        try:
            if not self._is_exempt(environ):
                # Check authorization (throws an exception if it fails)
                username = self._check_authorization(environ)
                # The token is consumed here; downstream only sees the user.
                del environ["HTTP_AUTHORIZATION"]
                environ["REMOTE_USER"] = username
            return self.app(environ, start_response)
        except HTTPException as e:
            response = _json_response(
                {"code": e.code, "name": e.name, "description": e.description},
                status=e.code,
            )
        except Exception as e:
            response = _json_response(
                {"code": 500, "name": "Internal Server Error"}, status=500
            )
            err = environ["wsgi.errors"]
            print(f"ERROR {e.__class__.__name__}: {str(e)}", file=err)
            traceback.print_exc(file=err)

        return response(environ, start_response)

    def _is_exempt(self, environ):
        """Return whether the request matches a registered exempt rule."""
        adapter = self._exempt_map.bind_to_environ(environ)
        try:
            adapter.match()
            return True
        except HTTPException:
            return False

    def _check_authorization(self, environ):
        """Verify the request's `Authorization: Bearer <token>` header.

        Returns the token's `username` claim if authorization passed.

        Raises a 401 Unauthorized exception if the header is missing or
        malformed, or if the token is expired, invalid, or has no non-empty
        string `username` claim.
        """
        auth = environ.get("HTTP_AUTHORIZATION")
        if auth is None:
            raise Unauthorized("No authorization header supplied")

        # HTTP authentication scheme names are case-insensitive.
        scheme, _, token = auth.partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            raise Unauthorized("Invalid authorization header")

        try:
            claims = jwt.decode(
                token,
                self.secret,
                algorithms=["HS256"],
                options={"require_exp": True, "require_iat": True},
            )
        except jwt.ExpiredSignatureError:
            raise Unauthorized("Expired token")
        except jwt.InvalidTokenError:
            raise Unauthorized("Invalid token")

        # Explicit check instead of `assert`, which is stripped under -O and
        # would otherwise surface as a 500 rather than a 401.
        username = claims.get("username")
        if not isinstance(username, str) or not username:
            raise Unauthorized("Invalid token")
        return username

    def _login(self, request):
        """Issue a 15-minute token for the user returned by `authenticate`.

        Raises a 401 Unauthorized exception if `authenticate` returns None.
        """
        username = self.authenticate(request)
        if username is None:
            raise Unauthorized("User authentication failed")
        now = datetime.datetime.now(datetime.UTC)
        claims = {
            "username": username,
            "iat": now,
            "exp": now + datetime.timedelta(minutes=15),
        }
        token = jwt.encode(claims, self.secret, algorithm="HS256")

        return {"token": token}

    def _renew(self, request, token: str):
        """Issue a fresh 15-minute token for a valid or recently expired token.

        The original `iat` claim is kept, so expired tokens stop being
        renewable seven days after their first issue.
        """
        now = datetime.datetime.now(datetime.UTC)
        try:
            # Valid tokens can always be renewed within their short lifetime,
            # independent of the issuing date
            claims = jwt.decode(
                token,
                self.secret,
                algorithms=["HS256"],
                options={"require_exp": True, "require_iat": True},
            )
        except jwt.ExpiredSignatureError:
            # Expired tokens can be renewed for at most one week after the
            # first issuing date
            claims = jwt.decode(
                token,
                self.secret,
                algorithms=["HS256"],
                options={
                    "require_exp": True,
                    "require_iat": True,
                    "verify_exp": False,
                },
            )
            issued_at = datetime.datetime.fromtimestamp(claims["iat"], tz=datetime.UTC)
            if issued_at + datetime.timedelta(days=7) < now:
                raise Unauthorized("Unrenewable expired token")
        except jwt.InvalidTokenError:
            raise Unauthorized("Invalid token")

        claims["exp"] = now + datetime.timedelta(minutes=15)

        token = jwt.encode(claims, self.secret, algorithm="HS256")

        return {"token": token}

    def authenticate(self, request):
        """Authenticate user.

        Returns a user identification string (usually the username) if
        authentication passed, None otherwise. This method must be
        overridden in an implementing subclass.
        """
        raise NotImplementedError()
688
+
689
+
690
class ExternalAuth(BaseJWTAuthMiddleware):
    """Delegate authentication to whatever sits in front of the application.

    A successfully authenticated user's name has to be present in the WSGI
    environment under the `REMOTE_USER` key. Logins are accepted via both GET
    and POST by default.
    """

    def __init__(self, *args, login_methods=("GET", "POST"), **kwargs):
        options = dict(kwargs, login_methods=login_methods)
        super().__init__(*args, **options)

    def authenticate(self, request):
        # Werkzeug exposes the REMOTE_USER WSGI variable as `remote_user`.
        user = request.remote_user
        return user
702
+
703
+
704
class DummyAuth(BaseJWTAuthMiddleware):
    """Authenticator that accepts every login; meant for tests and development.

    No credentials are checked: each login issues a token for the fixed
    username "dummyuser".

    A complete session looks like this.

    Set up a tiny API:
    >>> api = API()
    >>> @api.GET("/")
    ... def root(request):
    ...     return "Hello World"
    ...

    Put the authentication/authorization layer in front of it:
    >>> app = DummyAuth(api, "not a secret but still quite long")

    >>> from werkzeug.test import Client
    >>> client = Client(app)

    Requests without a token are rejected:
    >>> client.get("/")
    <TestResponse streamed [401 UNAUTHORIZED]>

    The login endpoint hands out a token:
    >>> response = client.post("/auth/login")
    >>> response.status
    '200 OK'
    >>> token = response.get_json()["token"]
    >>> token  # doctest: +ELLIPSIS
    'eyJ...'

    Sending it as a bearer token grants access:
    >>> headers = {"Authorization": f"Bearer {token}"}
    >>> client.get("/", headers=headers)
    <TestResponse streamed [200 OK]>
    """

    def authenticate(self, request):
        # The request is deliberately ignored: everybody is "dummyuser".
        return "dummyuser"
744
+
745
+
746
class UsernamePasswordAuth(BaseJWTAuthMiddleware):
    """Authenticate with a username and password combination.

    `dataset` user tables are supported natively. It must contain a unique
    and identifiable `username` and a `password` column. Alternatively, a
    custom password hash retrieval function may be specified.

    Passwords are expected to be hashed with one of werkzeug's supported
    algorithms. Use werkzeug's `werkzeug.security.generate_password_hash`
    function to generate compatible password hashes.

    Example usage:

    Let's create an in-memory database with a user table containing one entry:
    >>> import dataset
    >>> from werkzeug.security import generate_password_hash
    >>> db = dataset.connect("sqlite:///:memory:")
    >>> db['user'].insert(dict(username="paul", password=generate_password_hash("john")))
    1

    Assemble a dummy application and client:
    >>> app = UsernamePasswordAuth(API(), "not a secret but still quite long", user_table=db['user'])
    >>> from werkzeug.test import Client
    >>> client = Client(app)

    Logging in with correct credentials is now possible:
    >>> cred = {"username": "paul", "password": "john"}
    >>> response = client.post("/auth/login", json=cred)
    >>> response.status
    '200 OK'
    >>> response.get_json()["token"]  # doctest: +ELLIPSIS
    'eyJ...'

    Requests with invalid passwords fail:
    >>> cred = {"username": "paul", "password": "george"}
    >>> response = client.post("/auth/login", json=cred)
    >>> response.status
    '401 UNAUTHORIZED'

    The same is true for non-existent users:
    >>> cred = {"username": "yoko", "password": "john"}
    >>> response = client.post("/auth/login", json=cred)
    >>> response.status
    '401 UNAUTHORIZED'
    """

    def __init__(
        self,
        app,
        secret,
        user_table=None,
        find_password=None,
        *,
        exempt=(),
        prefix="/auth",
        login_methods=("POST",),
    ):
        """Initialize the authentication middleware.

        The `user_table` argument is expected to be a dataset table containing
        at least a unique identifying `username` and a `password` field.

        Alternatively the `find_password` parameter expects a function taking
        a username as argument, and returning the corresponding password hash.
        The function must return None if the user or password cannot be found.

        If both are present, the `user_table` parameter takes precedence over
        `find_password`.

        Raises:
            ValueError: If neither `user_table` nor `find_password` is given.
        """
        super().__init__(
            app, secret, exempt=exempt, prefix=prefix, login_methods=login_methods
        )
        if user_table is not None:

            def user_table_find_password(username):
                user = user_table.find_one(username=username)
                if user is not None:
                    return user["password"]
                return None

            self.find_password = user_table_find_password
        elif find_password is not None:
            self.find_password = find_password
        else:
            raise ValueError("One of 'user_table' and 'find_password' must be supplied")

    def authenticate(self, request):
        """Check the `username`/`password` JSON body of a login request.

        Returns the username when the stored hash matches the password, and
        None otherwise. Malformed bodies are rejected by `_parse_json_body`.
        """

        @_parse_json_body(content_types=dict(username=str, password=str))
        def handler(request, username, password):
            pw_hash = self.find_password(username)
            if pw_hash and check_password_hash(pw_hash, password):
                return username
            return None

        return handler(request)
839
+
840
+
841
def run(
    app,
    prefix=None,
    port=3000,
    hostname="localhost",
    allow_cors_from_hostname=False,
    use_reloader=True,
):
    """Run a WSGI application like an API on Werkzeug's development server.

    Optionally mount `app` below a URL `prefix` (e.g. "/api"). Requests
    outside the prefix are answered with 404 Not Found. A trailing slash on
    the prefix is ignored.

    Optionally specify a listening `port` (default: 3000) and a bind
    `hostname` (default: localhost). Set the hostname to the empty string
    to listen on all interfaces.

    CORS is disabled by default. If `allow_cors_from_hostname` is set to
    `True`, requests from origins whose hostname equals `hostname` are
    allowed, regardless of their port number. The comparison is literal, so
    it never matches when `hostname` is the empty string.

    `use_reloader` is forwarded to Werkzeug's development server, which runs
    threaded.
    """
    if prefix is not None:
        # DispatcherMiddleware calls its fallback as a WSGI app. A NotFound
        # *instance* is one; calling the class would only build an exception
        # object. Mount points must not end with "/" to be matched.
        app = DispatcherMiddleware(NotFound(), {prefix.rstrip("/"): app})

    if allow_cors_from_hostname:
        app = _cors_same_host_middleware(app, hostname)

    werkzeug.run_simple(
        hostname,
        port,
        app,
        threaded=True,
        use_reloader=use_reloader,
    )
874
+
875
+
876
# Record handed to PubSub subscribers: the `id` drawn from the publisher's
# counter, the `event_type` label and the free-form `data` payload.
Event = collections.namedtuple("Event", "id event_type data")
877
+
878
+
879
class PubSub:
    """Class implementing a publish/subscribe event passing scheme.

    Basic example usage:
    >>> chat = PubSub()
    >>> subscription = chat.subscribe()
    >>> chat.publish("message", "Hello")
    >>> next(subscription)
    Event(id=0, event_type='message', data='Hello')

    Messages can be differentiated by topic:
    >>> general_room = chat.subscribe(topic="general")
    >>> nerd_room = chat.subscribe(topic="nerd")
    >>> chat.publish("new_user", "guido", topic="nerd")
    >>> chat.publish("message", "Hi geeks!", topic="nerd")
    >>> chat.publish("message", "It is 12 am", topic="general")
    >>> next(general_room)
    Event(id=3, event_type='message', data='It is 12 am')
    >>> next(nerd_room)
    Event(id=1, event_type='new_user', data='guido')
    >>> next(nerd_room)
    Event(id=2, event_type='message', data='Hi geeks!')

    Every subscriber gets a bounded queue; a subscriber whose queue is full
    when a new event arrives is unsubscribed. The most recent events of each
    topic are kept in a replay log so that reconnecting server-sent-event
    clients can be sent what they missed.
    """

    # Pending events one subscriber may buffer before it is dropped.
    _QUEUE_SIZE = 100
    # Events remembered per topic for Last-Event-ID replay.
    _REPLAY_LOG_SIZE = 1_000
    # Seconds a subscription waits before re-checking it is still registered.
    _POLL_TIMEOUT = 60

    def __init__(self):
        # _main_lock guards the per-topic dictionaries and the id counter;
        # each topic's own lock guards that topic's queue set and replay log.
        self._main_lock = threading.Lock()
        self._topic_locks = collections.defaultdict(threading.Lock)
        self._queues = collections.defaultdict(set)
        self._replay_log = collections.defaultdict(
            lambda: collections.deque(maxlen=self._REPLAY_LOG_SIZE)
        )
        self._current_id = itertools.count()

    def _topic_state(self, topic):
        """Return ``(queues, replay_log, lock)`` for `topic`, creating them on first use."""
        with self._main_lock:
            return (
                self._queues[topic],
                self._replay_log[topic],
                self._topic_locks[topic],
            )

    def publish(self, event_type, data, topic=None):
        """Publish an event.

        The event has an `event_type`, usually a string, and a `data` payload.
        `data` can be free-form, but should be JSON-serializable.

        Optionally a topic can be specified. The message will be only forwarded
        to subscribers interested in the specified topic.
        """
        with self._main_lock:
            event_id = next(self._current_id)
        queues, replay_log, topic_lock = self._topic_state(topic)
        event = Event(event_id, event_type, data)

        with topic_lock:
            replay_log.append(event)

            stalled = []
            for q in queues:
                try:
                    q.put_nowait(event)
                except queue.Full:  # Somebody fell asleep?!?
                    stalled.append(q)
            # Still under the topic lock, so no other publish is iterating.
            queues.difference_update(stalled)

    def broadcast(self, event_type, data):
        """Publish the event once on every topic known to this PubSub."""
        with self._main_lock:
            topics = list(self._queues)

        for topic in topics:
            self.publish(event_type, data, topic)

    def subscribe(self, topic=None):
        """Subscribe to published events.

        Returns an iterator of `Event` named tuples, each containing a unique
        event `id`, the `event_type`, and the payload `data`.

        Optionally a specific `topic` can be specified.
        """
        _, subscription = self._subscribe(topic)
        return subscription

    def _subscribe(self, topic=None, last_id=None):
        """Register a new subscriber on `topic`; return ``(missed, iterator)``.

        When `last_id` is not None, `missed` lists the logged events of
        `topic` that followed the event with that id; otherwise it is empty.
        The snapshot and the registration happen under the same topic lock
        that `publish` holds, so every event ends up either in `missed` or in
        the new queue - none falls in between and none is delivered twice.

        Raises ValueError, before registering anything, if `last_id` is not an
        integer or is not in the topic's replay log.
        """
        queues, replay_log, topic_lock = self._topic_state(topic)
        q = queue.Queue(self._QUEUE_SIZE)
        with topic_lock:
            missed = [] if last_id is None else self._events_after(replay_log, last_id)
            queues.add(q)

        def iterator():
            try:
                # publish() drops the queue when it overflows; stop then.
                while q in queues:
                    try:
                        yield q.get(timeout=self._POLL_TIMEOUT)
                    except queue.Empty:
                        pass
            finally:
                # Runs on normal exit and on close(); discard tolerates a
                # queue that publish() already removed.
                with topic_lock:
                    queues.discard(q)

        return missed, iterator()

    @staticmethod
    def _events_after(replay_log, last_id):
        """Return the events in `replay_log` that follow the one with id `last_id`.

        Raises ValueError if `last_id` is not an integer or is not in the log.
        The caller must hold the topic lock protecting `replay_log`.
        """
        last_id = int(last_id)
        log_iter = iter(replay_log)
        for event in log_iter:
            if event.id == last_id:
                return list(log_iter)
        raise ValueError(f"{last_id} is not in event log")

    def _event_stream(self, replay_events=(), topic=None, subscription=None):
        """Yield `replay_events`, then live events, encoded as SSE messages.

        If no `subscription` is given, one for `topic` is opened when the
        stream is first iterated. The subscription is closed when the stream
        ends or is closed (a subscription that was never iterated keeps its
        queue registered until publish() finds it full).
        """
        if subscription is None:
            subscription = self.subscribe(topic)
        try:
            for event in itertools.chain(replay_events, subscription):
                yield (
                    f"id: {event.id}\n"
                    f"event: {event.event_type}\n"
                    f"data: {json.dumps(event.data, default=str)}\n\n"
                ).encode("utf-8")
        finally:
            subscription.close()

    def _replay_events(self, last_id, topic=None):
        """Return the events of `topic` logged after `last_id` (``()`` if None).

        Raises ValueError if `last_id` is not an integer or is not in the log.
        """
        if last_id is None:
            return ()

        last_id = int(last_id)
        _, replay_log, topic_lock = self._topic_state(topic)
        with topic_lock:
            return self._events_after(replay_log, last_id)

    def streaming_response(self, request, topic=None):
        """Generate a streaming HTTP response with server-sent events.

        See https://html.spec.whatwg.org/multipage/server-sent-events.html
        for more information about server-sent events.

        Only events published on `topic` (default: the unnamed topic) are
        streamed and replayed.

        When reconnecting after losing the connection for a while, browsers
        automatically set the `Last-Event-ID` header field to the value of
        the id of the last received event. The response will first replay
        missed events, before sending newly arriving events. When the event
        specified by `Last-Event-ID` is not found (or the header is not an
        integer), a 404 Not Found response is sent, signalling to the browser
        that a clean recovery is not possible.

        Here is an example session. First, let us create an API:
        >>> api = API()
        >>> chat = PubSub()
        >>> @api.POST("/")
        ... def post_message(request, message:str):
        ...     chat.publish("message", message)
        ...
        >>> @api.GET("/")
        ... def stream(request):
        ...     return chat.streaming_response(request)
        ...

        We can now post messages and see them appear in our subscription:
        >>> subscription = chat.subscribe()
        >>> from werkzeug.test import Client
        >>> client = Client(api)
        >>> resp = client.post("/", json={"message": "hello"})
        >>> resp = client.post("/", json={"message": "everybody"})
        >>> next(subscription)
        Event(id=0, event_type='message', data='hello')
        >>> next(subscription)
        Event(id=1, event_type='message', data='everybody')

        Now, let's simulate a reconnecting browser that only got the first
        message:
        >>> response = client.get("/", headers={"Last-Event-ID": "0"})
        >>> response.status
        '200 OK'
        >>> body = iter(response.response)

        The events are formatted according to the specification for server-sent
        events:
        >>> print(str(next(body), encoding="utf-8").strip())
        id: 1
        event: message
        data: "everybody"

        Further incoming messages are sent to the listening client, without
        closing the connection:
        >>> resp = client.post("/", json={"message": "howdoyoudo?"})
        >>> print(str(next(body), encoding="utf-8").strip())
        id: 2
        event: message
        data: "howdoyoudo?"
        """
        last_id = request.headers.get("Last-Event-ID", None)
        try:
            # Snapshot the missed events of *this* topic and subscribe in one
            # atomic step, so nothing published in between is lost.
            replay_events, subscription = self._subscribe(topic, last_id)
        except ValueError:
            raise NotFound() from None

        return werkzeug.Response(
            self._event_stream(replay_events, topic, subscription),
            mimetype="text/event-stream",
        )
1081
+
1082
+
1083
# Public names exported by a star-import of this module.
__all__ = (
    # Application class and HTTP errors
    "API",
    "NotFound",
    "Unauthorized",
    "UnprocessableEntity",
    # Helpers
    "timestamp",
    # Authentication
    "BaseJWTAuthMiddleware",
    "ExternalAuth",
    "DummyAuth",
    "UsernamePasswordAuth",
    # Serving and server-sent events
    "run",
    "PubSub",
)