wau 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wau-0.1.0.dist-info/METADATA +88 -0
- wau-0.1.0.dist-info/RECORD +6 -0
- wau-0.1.0.dist-info/WHEEL +5 -0
- wau-0.1.0.dist-info/licenses/LICENSE +14 -0
- wau-0.1.0.dist-info/top_level.txt +1 -0
- wau.py +1095 -0
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: wau
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Web API Utils
|
|
5
|
+
Author: Marco Schmalz
|
|
6
|
+
License-Expression: LGPL-3.0-or-later
|
|
7
|
+
Keywords: api,json,werkzeug,education
|
|
8
|
+
Classifier: Development Status :: 3 - Alpha
|
|
9
|
+
Classifier: Intended Audience :: Education
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
12
|
+
Classifier: Topic :: Internet :: WWW/HTTP
|
|
13
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
14
|
+
Requires-Python: >=3.14
|
|
15
|
+
Description-Content-Type: text/markdown
|
|
16
|
+
License-File: LICENSE
|
|
17
|
+
Requires-Dist: dataset>=2.0.0
|
|
18
|
+
Requires-Dist: pyjwt>=2.13.0
|
|
19
|
+
Requires-Dist: werkzeug>=3.1.8
|
|
20
|
+
Dynamic: license-file
|
|
21
|
+
|
|
22
|
+
# `wau` — Web API Utils
|
|
23
|
+
|
|
24
|
+
Web API Utils, or `wau` for short, is a thin layer on top of Werkzeug that provides a simple and consistent interface for writing APIs in Python. `wau` is built for educational purposes and is not intended for production use. It is opinionated: it only supports JSON as its data format. It uses simple type annotations to define the expected input and output of the API endpoints. Common tasks such as authentication, CORS and server-sent events are supported by default.
|
|
25
|
+
|
|
26
|
+
## Installation
|
|
27
|
+
|
|
28
|
+
Install from PyPI:
|
|
29
|
+
|
|
30
|
+
```powershell
|
|
31
|
+
pip install wau
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
or with uv:
|
|
35
|
+
|
|
36
|
+
```powershell
|
|
37
|
+
uv add wau
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## Testing
|
|
41
|
+
|
|
42
|
+
Test dependencies are separated from runtime dependencies in `pyproject.toml`
|
|
43
|
+
using the `test` dependency group.
|
|
44
|
+
|
|
45
|
+
Run the test suite:
|
|
46
|
+
|
|
47
|
+
```powershell
|
|
48
|
+
uv run --group test python -m pytest -q
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
Run doctests:
|
|
52
|
+
|
|
53
|
+
```powershell
|
|
54
|
+
uv run --group test python -m doctest .\wau.py
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
## Publishing
|
|
58
|
+
|
|
59
|
+
Build package artifacts:
|
|
60
|
+
|
|
61
|
+
```powershell
|
|
62
|
+
uv build
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
Validate metadata and README rendering:
|
|
66
|
+
|
|
67
|
+
```powershell
|
|
68
|
+
uvx twine check dist/*
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
Upload to TestPyPI first:
|
|
72
|
+
|
|
73
|
+
```powershell
|
|
74
|
+
uv publish --publish-url https://test.pypi.org/legacy/
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
Then publish to PyPI:
|
|
78
|
+
|
|
79
|
+
```powershell
|
|
80
|
+
uv publish
|
|
81
|
+
```
|
|
82
|
+
|
|
83
|
+
## License
|
|
84
|
+
|
|
85
|
+
This project is licensed under GNU LGPL v3 or later (`LGPL-3.0-or-later`).
|
|
86
|
+
|
|
87
|
+
If you distribute modified versions of this library, those library
|
|
88
|
+
modifications must be published under the same license terms.
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
wau.py,sha256=wQsoNvzvH6HFRBj8ZKIpKjT87_9aB-WfOi0kEt8pdCU,37026
|
|
2
|
+
wau-0.1.0.dist-info/licenses/LICENSE,sha256=5lfLO8VaZisSCb3xtnOUdjX26TYX52CYXYRVUJsSzUk,500
|
|
3
|
+
wau-0.1.0.dist-info/METADATA,sha256=bm0_c9LKeic0-u1efpcsK2YvduO8yVgHtNKXq3Ak_dY,2172
|
|
4
|
+
wau-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
5
|
+
wau-0.1.0.dist-info/top_level.txt,sha256=xH8uF0IWDusYlJBXD4LSHfJLMSsa-Yyj0_2I26TAJQQ,4
|
|
6
|
+
wau-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
GNU LESSER GENERAL PUBLIC LICENSE
|
|
2
|
+
Version 3, 29 June 2007
|
|
3
|
+
|
|
4
|
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
|
5
|
+
Everyone is permitted to copy and distribute verbatim copies
|
|
6
|
+
of this license document, but changing it is not allowed.
|
|
7
|
+
|
|
8
|
+
This project is licensed under the GNU Lesser General Public License,
|
|
9
|
+
version 3 or (at your option) any later version.
|
|
10
|
+
|
|
11
|
+
For the full license text, see:
|
|
12
|
+
https://www.gnu.org/licenses/lgpl-3.0.txt
|
|
13
|
+
|
|
14
|
+
SPDX-License-Identifier: LGPL-3.0-or-later
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
wau
|
wau.py
ADDED
|
@@ -0,0 +1,1095 @@
|
|
|
1
|
+
import collections
|
|
2
|
+
import datetime
|
|
3
|
+
import functools
|
|
4
|
+
import inspect
|
|
5
|
+
import itertools
|
|
6
|
+
import json
|
|
7
|
+
import queue
|
|
8
|
+
import re
|
|
9
|
+
import sys
|
|
10
|
+
import threading
|
|
11
|
+
import traceback
|
|
12
|
+
import urllib.parse
|
|
13
|
+
|
|
14
|
+
import werkzeug
|
|
15
|
+
from werkzeug.exceptions import (
|
|
16
|
+
HTTPException,
|
|
17
|
+
NotFound,
|
|
18
|
+
Unauthorized,
|
|
19
|
+
UnprocessableEntity,
|
|
20
|
+
UnsupportedMediaType,
|
|
21
|
+
)
|
|
22
|
+
from werkzeug.middleware.dispatcher import DispatcherMiddleware
|
|
23
|
+
from werkzeug.routing import Map, Rule
|
|
24
|
+
from werkzeug.security import check_password_hash
|
|
25
|
+
|
|
26
|
+
try:
|
|
27
|
+
import jwt
|
|
28
|
+
except ImportError:
|
|
29
|
+
# Do not complain now, but only when auth classes get instantiated
|
|
30
|
+
jwt = None
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class API:
    """An API speaking in JSON with the outside world.

    Here is a very simple first API:
    >>> app = API()
    >>> @app.GET("/hello")
    ... def root(request):
    ...     return "Hello World"
    ...
    >>> from werkzeug.test import Client
    >>> client = Client(app)
    >>> response = client.get('/hello')
    >>> response.get_json()
    'Hello World'

    The `request` parameter must be called `request` and must be the first
    parameter of the handler function, but it can be omitted if not needed.

    So, here is an even simpler version of the above code:
    >>> @app.GET("/hello_again")
    ... def root():
    ...     return "Hello again!"
    ...
    >>> response = client.get('/hello_again')
    >>> response.get_json()
    'Hello again!'

    URL paths can be parametrized:
    >>> @app.register("GET", "/user/{id}")
    ... def home(id):
    ...     return f"Welcome home {id}!"
    ...
    >>> response = client.get("/user/007")
    >>> response.status
    '200 OK'
    >>> response.get_json()
    'Welcome home 007!'

    Path parameters can be typed:
    >>> @app.register("GET", "/agent/{id:int}")
    ... def home(id):
    ...     return f"Welcome home {id}! You're more than {id - 1}."
    ...
    >>> response = client.get("/agent/Bond")
    >>> response.status
    '404 NOT FOUND'
    >>> response = client.get("/agent/007")
    >>> response.status
    '200 OK'
    >>> response.get_json()
    "Welcome home 7! You're more than 6."

    Now with generic POST data: add a `data` parameter and optionally specify
    its type (default is dict):
    >>> @app.POST("/")
    ... def create(request, data:list):
    ...     print(data)
    ...
    >>> response = client.post("/", json=[1, 2, 3])
    [1, 2, 3]

    >>> response = client.post("/", json={})
    >>> response.status
    '422 UNPROCESSABLE ENTITY'

    And finally requesting a dict in the POST data with specified fields
    (and types):

    >>> @app.PUT("/")
    ... def update(request, name, age:int, superhuman:bool=False):
    ...     print(f"{name} is {age} years old.")
    ...     print(f"{name} is {'' if superhuman else 'not '}superhuman.")
    ...
    >>> data = {"name": "Betsy", "age": 34}
    >>> response = client.put("/", json=data)
    Betsy is 34 years old.
    Betsy is not superhuman.
    >>> response.status
    '200 OK'

    >>> data = {"name": "Betsy"}
    >>> response = client.put("/", json=data)
    >>> response.status
    '422 UNPROCESSABLE ENTITY'

    >>> data = {"name": "Betsy", "age": "34"}
    >>> response = client.put("/", json=data)
    >>> response.status
    '422 UNPROCESSABLE ENTITY'

    >>> data = {"name": "Betsy", "age": 34, "superhuman": True}
    >>> response = client.put("/", json=data)
    Betsy is 34 years old.
    Betsy is superhuman.
    >>> response.status
    '200 OK'

    >>> data = {"name": "Betsy", "age": 34, "verysmart": True}
    >>> response = client.put("/", json=data)
    >>> response.status
    '422 UNPROCESSABLE ENTITY'

    >>> data = "This is not valid JSON"
    >>> response = client.put("/", data=data)
    >>> response.status
    '415 UNSUPPORTED MEDIA TYPE'
    """

    def __init__(self):
        self._url_map = Map()

    def register(self, method, url, func=None):
        """Register a route with a callback.

        `method` is the HTTP method, `url` the route (see below) and `func`
        the handler. Returns `func` unchanged, or, if `func` is omitted, a
        decorator that registers the decorated function.

        This function can be used either directly:

        >>> api = API()
        >>> api.register("GET", "/", func=lambda request: "Hello!") # doctest: +ELLIPSIS
        <function <lambda> at 0x...>

        or as a decorator
        >>> @api.register("GET", "/user/{id}")
        ... def home(request, id):
        ...     return f"Welcome home {id}!"
        ...

        To test it use the Client class.
        >>> from werkzeug.test import Client
        >>> client = Client(api)
        >>> response = client.get("/")
        >>> response.status
        '200 OK'
        >>> response.get_json()
        'Hello!'
        >>> response = client.get("/user/007")
        >>> response.status
        '200 OK'
        >>> response.get_json()
        'Welcome home 007!'

        `url` accepts parametrized and optionally typed placeholders
        (`{id}`, `{id:int}`).

        Handler parameters are interpreted positionally: an optional leading
        `request`, then one parameter per URL placeholder (in any order among
        themselves), then the body parameters. A single `data` body parameter
        receives the whole JSON body; otherwise each body parameter is a
        field of a JSON object, validated against its annotation.

        Raises TypeError if the handler's URL parameters do not match the
        route's placeholders.
        """
        if func is None:
            return functools.partial(self.register, method, url)

        url = _normalize_url_placeholders(url)

        rule = Rule(url, methods=(method,))
        Map([rule])  # Bind rule temporarily so that `rule.arguments` is populated
        url_params = rule.arguments

        sig = inspect.signature(func)
        params = sig.parameters
        param_keys = list(params.keys())

        # The first argument can optionally be request, after that route params.
        # Parameter order among the route params is ignored.
        offset = 1 if param_keys and param_keys[0] == "request" else 0
        func_url_params = set(param_keys[offset : offset + len(url_params)])
        body_params = param_keys[offset + len(url_params) :]

        mismatch = url_params ^ func_url_params
        if mismatch:
            raise TypeError(
                f"{func.__name__}() arguments and route parameter missmatch "
                f"({func_url_params} != {url_params})"
            )

        body_type = None
        content_types = {}
        if body_params == ["data"]:
            annotation = params["data"].annotation
            body_type = dict if annotation is inspect.Parameter.empty else annotation
        elif body_params:
            body_type = dict
            content_types = {
                key: params[key].annotation
                for key in body_params
                if params[key].annotation is not inspect.Parameter.empty
            }

        if body_type:
            wrapped_func = _parse_json_body(func, body_type, content_types)
        else:
            wrapped_func = _simple_wrapper(func)

        self._url_map.add(Rule(url, methods=(method,), endpoint=wrapped_func))
        return func

    def GET(self, string):
        """Shorthand for registering GET requests.

        Use as a decorator:
        >>> api = API()
        >>> @api.GET("/admin")
        ... def admin_home(request):
        ...     return "Nothing here"
        ...
        >>> from werkzeug.test import Client
        >>> client = Client(api)
        >>> client.get("/admin")
        <TestResponse streamed [200 OK]>
        """
        return self.register("GET", string)

    def POST(self, string):
        """Shorthand for registering POST requests."""
        return self.register("POST", string)

    def PUT(self, string):
        """Shorthand for registering PUT requests."""
        return self.register("PUT", string)

    def PATCH(self, string):
        """Shorthand for registering PATCH requests."""
        return self.register("PATCH", string)

    def DELETE(self, string):
        """Shorthand for registering DELETE requests."""
        return self.register("DELETE", string)

    @staticmethod
    def _http_error_response(e, environ):
        """Render a Werkzeug HTTPException as a JSON error response.

        Headers that Werkzeug attaches to the exception's own response (for
        example `Location` on a routing redirect or `Allow` on a 405) are
        carried over. Only the body-describing headers are replaced by the
        JSON ones.
        """
        response = _json_response(
            {"code": e.code, "name": e.name, "description": e.description},
            status=e.code,
        )
        for key, value in e.get_response(environ).headers:
            if key.lower() not in ("content-type", "content-length"):
                response.headers.add(key, value)
        return response

    def __call__(self, environ, start_response):
        """WSGI entry point: route, dispatch and serialize the response.

        Handler return values that are not already WSGI applications are
        serialized with `_json_response`. Werkzeug HTTP errors become JSON
        error bodies. Any other exception is logged to `wsgi.errors`
        (falling back to stderr) and answered with a generic 500 JSON body.
        """
        try:
            request = werkzeug.Request(environ)
            adapter = self._url_map.bind_to_environ(environ)
            endpoint, values = adapter.match()

            # Dispatch request
            response = endpoint(request, **values)
            if not callable(response):
                response = _json_response(response)
            return response(environ, start_response)
        except HTTPException as e:
            response = self._http_error_response(e, environ)
        except Exception as e:
            response = _json_response(
                {"code": 500, "name": "Internal Server Error"}, status=500
            )
            err = environ.get("wsgi.errors", sys.stderr)
            print(f"ERROR {e.__class__.__name__}: {str(e)}", file=err)
            traceback.print_exc(file=err)
        return response(environ, start_response)
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def _json_response(data, status=200):
    """Turn a handler's return value into a Werkzeug response.

    An existing `werkzeug.Response` is returned untouched and `None` yields
    an empty response with the given status. Anything else is serialized as
    indented JSON followed by a newline. Values `json` cannot encode fall
    back to `str`.
    """
    if isinstance(data, werkzeug.Response):
        # Already a response object: pass it through as-is.
        return data
    if data is None:
        return werkzeug.Response(status=status)
    body = json.dumps(data, indent=2, default=str)
    return werkzeug.Response(f"{body}\n", status=status, mimetype="application/json")
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def _simple_wrapper(func):
    """Adapt a handler without JSON body parameters to the dispatcher.

    The dispatcher always passes the request first. It is forwarded to
    `func` only when `func`'s first parameter is named `request`.
    Otherwise a non-empty request body is rejected with
    415 Unsupported Media Type.
    """
    names = list(inspect.signature(func).parameters)
    wants_request = bool(names) and names[0] == "request"

    @functools.wraps(func)
    def wrapper(request, *args, **kwargs):
        if wants_request:
            return func(request, *args, **kwargs)
        if request.data:
            raise UnsupportedMediaType("No request body allowed")
        return func(*args, **kwargs)

    return wrapper
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def _normalize_url_placeholders(url):
    """Rewrite `{...}` URL placeholders into Werkzeug's `<...>` rule syntax.

    The colon-separated parts inside the braces are reversed, so the
    converter moves in front of the name.

    Examples:
        `{id}` -> `<id>`
        `{id:int}` -> `<int:id>`
    """

    def _to_werkzeug(match):
        # "id:int" -> "int:id"; a lone name such as "id" is left unchanged.
        parts = match.group(1).split(":")
        return "<" + ":".join(parts[::-1]) + ">"

    return re.sub(r"\{([^\}\{]+)\}", _to_werkzeug, url)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def _check_value(key, value, value_type):
    """Validate, and where applicable convert, one JSON body field.

    Args:
        key: Field name, used in error messages.
        value: The decoded JSON value.
        value_type: The handler's annotation for this field.

    Returns:
        The value to hand to the handler, which is one of:

        - `value` itself when it already has the expected type;
        - `float(value)` for an integer sent to a `float` field;
        - `value_type(value)` for a string sent to a custom converter such
          as `timestamp`.

    Raises:
        UnprocessableEntity: if the value has the wrong type or the
            converter rejects it with ValueError.
    """
    # Annotations such as unions or partials may have no __name__.
    type_name = getattr(value_type, "__name__", repr(value_type))

    # JSON true/false decode to Python bool, which is a subclass of int. They
    # must not satisfy (or be silently coerced into) numeric fields.
    if isinstance(value, bool) and value_type in (int, float):
        raise UnprocessableEntity(
            f"Invalid format: '{key}' must be of type {type_name}."
        )
    if value_type is bool and isinstance(value, bool):
        return value
    if value_type is float and isinstance(value, int):
        return float(value)
    if isinstance(value_type, type) and isinstance(value, value_type):
        return value
    if (
        isinstance(value, str)
        and callable(value_type)
        and value_type not in (bool, int, float, str, list, dict)
    ):
        try:
            return value_type(value)
        except ValueError as err:
            raise UnprocessableEntity(
                f"Invalid format: '{key}' cannot be converted to {type_name}."
            ) from err
    raise UnprocessableEntity(
        f"Invalid format: '{key}' must be of type {type_name}."
    )
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
def _parse_json_body(func=None, body_type=dict, content_types=None):  # noqa: C901
    """Wrap `func` so that it receives the parsed JSON request body.

    The body must be UTF-8 encoded JSON whose top-level value is an instance
    of `body_type`. There are two calling modes:

    - "data" mode: the handler's only body parameter is `data`, and it
      receives the whole decoded body as `data=...`.
    - "fields" mode (`body_type` is dict): each top-level key fills the
      handler parameter of the same name. Keys without a matching parameter
      are rejected, as are required parameters that are missing. Annotated
      fields (`content_types`) are validated and converted by `_check_value`,
      and the handler receives the converted values.

    URL values arrive as keyword arguments and are never taken from the body.

    Args:
        func: The handler. If omitted, a decorator is returned.
        body_type: Expected type of the top-level JSON value.
        content_types: Mapping from field name to annotation.

    Raises:
        UnsupportedMediaType: if the body is empty, not UTF-8 or not JSON.
        UnprocessableEntity: if the body has the wrong shape or field types.
    """
    if content_types is None:
        content_types = {}

    if func is None:
        return functools.partial(
            _parse_json_body, body_type=body_type, content_types=content_types
        )

    sig = inspect.signature(func)
    param_names = list(sig.parameters)
    takes_request = bool(param_names) and param_names[0] == "request"
    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    # Named parameters that a JSON object key may fill. The leading `request`
    # and *args/**kwargs catch-alls are excluded.
    fillable = [
        name
        for name in (param_names[1:] if takes_request else param_names)
        if sig.parameters[name].kind not in variadic
    ]

    @functools.wraps(func)
    def wrapper(request, *args, **kwargs):
        try:
            data = str(request.data, "utf-8").strip()
        except UnicodeDecodeError:
            raise UnsupportedMediaType("Cannot parse request body: invalid UTF-8 data")

        if not data:
            raise UnsupportedMediaType("Cannot parse request body: no data supplied")

        try:
            data = json.loads(data)
        except json.decoder.JSONDecodeError:
            raise UnsupportedMediaType("Cannot parse request body: invalid JSON")

        if body_type is not None and not isinstance(data, body_type):
            raise UnprocessableEntity(
                f"Invalid data format: {body_type.__name__} expected"
            )

        if takes_request:
            args = (request,) + args

        # Parameters not already supplied as URL values come from the body.
        body_names = [name for name in fillable if name not in kwargs]
        fields_mode = body_type == dict and (
            content_types or (body_names and body_names != ["data"])
        )

        if not fields_mode:
            kwargs["data"] = data
            return func(*args, **kwargs)

        too_many = data.keys() - set(body_names)
        if too_many:
            raise UnprocessableEntity(f"Key not allowed: {', '.join(sorted(too_many))}")

        kwargs.update(data)
        bound = sig.bind_partial(*args, **kwargs)

        for key, value in bound.arguments.items():
            if key in content_types:
                bound.arguments[key] = _check_value(key, value, content_types[key])

        bound.apply_defaults()

        missing = sig.parameters.keys() - bound.arguments.keys()
        if missing:
            raise UnprocessableEntity(f"Key missing: {', '.join(sorted(missing))}")

        # Call with the bound arguments so that conversions made by
        # _check_value (int -> float, str -> timestamp, ...) and applied
        # defaults actually reach the handler.
        return func(*bound.args, **bound.kwargs)

    return wrapper
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def timestamp(string=None):
    """Parse JS date strings to UTC datetime objects.

    Returns the current datetime (now), if called with no argument.

    `timestamp` always returns timezone-aware UTC datetimes. Strings carrying
    a UTC offset are converted to UTC, and strings without any offset are
    assumed to already be in UTC.

    Example usage:

    To convert a JS timestamp, create one in the browser or in node:
    > let now = new Date()
    > JSON.stringify(now)
    '"2020-12-09T23:44:53.782Z"'

    Convert the value to a native Python datetime object:
    >>> timestamp(json.loads('"2020-12-09T23:44:53.782Z"'))
    datetime.datetime(2020, 12, 9, 23, 44, 53, 782000, tzinfo=datetime.timezone.utc)

    Offsets other than UTC are normalized:
    >>> timestamp("2020-12-10T01:44:53+02:00")
    datetime.datetime(2020, 12, 9, 23, 44, 53, tzinfo=datetime.timezone.utc)

    To get the current time:
    >>> timestamp() # doctest: +ELLIPSIS
    datetime.datetime(2..., tzinfo=datetime.timezone.utc)

    This function can be used as an annotation in request handlers:
    >>> api = API()
    >>> @api.POST("/reminder")
    ... def reminder(request, date:timestamp, text:str):
    ...     pass
    ...
    """
    if string is None:
        return datetime.datetime.now(datetime.timezone.utc)

    parsed = datetime.datetime.fromisoformat(string.replace("Z", "+00:00"))
    if parsed.tzinfo is None:
        # No offset given: interpret the value as UTC rather than returning a
        # naive datetime that cannot be compared with aware ones.
        return parsed.replace(tzinfo=datetime.timezone.utc)
    return parsed.astimezone(datetime.timezone.utc)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def _cors_same_host_middleware(app, allowed_host):
    """Wrap the WSGI `app` so browsers on `allowed_host` may call it cross-origin.

    An `Origin` whose hostname equals `allowed_host` (case-insensitive;
    any scheme and port) is echoed back in `Access-Control-Allow-Origin`.

    - CORS preflight requests (`OPTIONS` with
      `Access-Control-Request-Method`) from such an origin are answered
      directly with 204.
    - Every other request is passed to `app`. When the origin is allowed,
      the CORS header and `Vary: Origin` are added to the response.
    """
    host = allowed_host.lower().strip()
    allow_methods = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
    default_allow_headers = "Content-Type, Authorization"

    def _matching_origin(origin):
        # Return the origin unchanged if its hostname is the allowed host.
        try:
            url = urllib.parse.urlparse(origin)
        except ValueError:
            return None
        if url.scheme and url.hostname and url.hostname.lower() == host:
            return origin
        return None

    def _add_vary_origin(headers):
        # Mutate `headers` in place. Extend the first Vary header with
        # "Origin" unless it is already listed, or append a new Vary header.
        for position, (name, current) in enumerate(headers):
            if name.lower() != "vary":
                continue
            listed = {part.strip().lower() for part in current.split(",") if part.strip()}
            if "origin" not in listed:
                headers[position] = (name, f"{current}, Origin")
            return
        headers.append(("Vary", "Origin"))

    def wrapped(environ, start_response):
        allowed_origin = _matching_origin(environ.get("HTTP_ORIGIN", ""))
        preflight = (
            environ.get("REQUEST_METHOD") == "OPTIONS"
            and "HTTP_ACCESS_CONTROL_REQUEST_METHOD" in environ
        )

        if allowed_origin and preflight:
            requested = environ.get("HTTP_ACCESS_CONTROL_REQUEST_HEADERS", "")
            headers = [
                ("Access-Control-Allow-Origin", allowed_origin),
                ("Access-Control-Allow-Methods", allow_methods),
                ("Access-Control-Allow-Headers", requested or default_allow_headers),
                ("Access-Control-Max-Age", "86400"),
            ]
            _add_vary_origin(headers)
            return werkzeug.Response(status=204, headers=headers)(environ, start_response)

        def start_with_cors(status, headers, exc_info=None):
            if allowed_origin:
                headers = list(headers)
                headers.append(("Access-Control-Allow-Origin", allowed_origin))
                _add_vary_origin(headers)
            return start_response(status, headers, exc_info)

        return app(environ, start_with_cors)

    return wrapped
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
class BaseJWTAuthMiddleware:
    """Middleware authorizing access to chained application using JWT.

    Attention: this base class does not authenticate anyone by itself;
    subclasses must implement `authenticate`.

    This middleware exposes two endpoints (below `prefix`, default `/auth`):
    - /auth/login for generating new tokens.
    - /auth/renew for renewing an existing token

    Tokens are short-lived and are valid for only 15 minutes, but expired tokens
    can be renewed during one week starting from their initial issuing date.

    Upon successful authentication the username is stored in the WSGI environment,
    and can be retrieved from Werkzeug's Request object: `request.remote_user`

    Some requests skip the token check:

    - requests matching one of the `exempt` rules, an iterable of
      `(method, path, ...)` tuples using the same `{placeholder}` syntax as
      `API.register`;
    - requests to the login and renew endpoints themselves.
    """

    def __init__(
        self,
        app,
        secret,
        *,
        exempt=(),
        prefix="/auth",
        login_methods=("POST",),
    ):
        if jwt is None:
            print("WARNING: No module named 'jwt'", file=sys.stderr)
            print("Cannot perform authentication without PyJWT", file=sys.stderr)
            print("Run `pip install PyJWT` to fix this", file=sys.stderr)
            raise ModuleNotFoundError("No module named 'jwt'")

        # Normalize before mounting. DispatcherMiddleware matches mount points
        # on path segments, so a mount such as "/auth/" would never match, and
        # the exempt rules below must use the same prefix as the mount.
        prefix = prefix.rstrip("/")

        auth_api = API()
        for method in login_methods:
            auth_api.register(method, "/login", func=functools.partial(self._login))
        auth_api.register("POST", "/renew", func=functools.partial(self._renew))

        self.app = DispatcherMiddleware(app, {prefix: auth_api})

        exempt_map = Map()
        for method, path, *_ in exempt:
            exempt_map.add(
                Rule(
                    _normalize_url_placeholders(path),
                    methods=(method.upper(),),
                    endpoint=True,
                )
            )
        for method in login_methods:
            exempt_map.add(
                Rule(prefix + "/login", methods=(method.upper(),), endpoint=True)
            )
        exempt_map.add(Rule(prefix + "/renew", methods=("POST",), endpoint=True))
        self._exempt_map = exempt_map

        self.secret = secret

    def __call__(self, environ, start_response):
        try:
            if not self._is_exempt(environ):
                # Check authorization (throws an exception if it fails)
                username = self._check_authorization(environ)
                # Explicit check instead of `assert`, which `python -O` strips.
                if not isinstance(username, str) or username == "":
                    raise RuntimeError(
                        "_check_authorization() must return a non-empty username"
                    )
                # The token has been verified; hide it from the wrapped app.
                environ.pop("HTTP_AUTHORIZATION", None)
                environ["REMOTE_USER"] = username
            return self.app(environ, start_response)
        except HTTPException as e:
            response = _json_response(
                {"code": e.code, "name": e.name, "description": e.description},
                status=e.code,
            )
        except Exception as e:
            response = _json_response(
                {"code": 500, "name": "Internal Server Error"}, status=500
            )
            err = environ.get("wsgi.errors", sys.stderr)
            print(f"ERROR {e.__class__.__name__}: {str(e)}", file=err)
            traceback.print_exc(file=err)

        return response(environ, start_response)

    def _is_exempt(self, environ):
        """Return whether the request matches a registered exempt rule."""
        adapter = self._exempt_map.bind_to_environ(environ)
        try:
            adapter.match()
            return True
        except HTTPException:
            return False

    def _decode_token(self, token, *, verify_exp=True):
        """Decode and verify an HS256 token signed with `self.secret`.

        The `exp`, `iat` and `username` claims must be present. PyJWT 2.x
        uses `options={"require": [...]}`; the 1.x `require_*` options are
        silently ignored. With `verify_exp=False` an expired token is still
        accepted (used for renewal).

        Raises a `jwt.InvalidTokenError` subclass on failure.
        """
        return jwt.decode(
            token,
            self.secret,
            algorithms=["HS256"],
            options={
                "require": ["exp", "iat", "username"],
                "verify_exp": verify_exp,
            },
        )

    def _check_authorization(self, environ):
        """Verify request authorization header.

        Returns username if authorization passed.

        Raises a 401 Unauthorized exception if authorization failed.
        """
        if "HTTP_AUTHORIZATION" not in environ:
            raise Unauthorized("No authorization header supplied")

        scheme, _, token = environ["HTTP_AUTHORIZATION"].partition(" ")
        # Authentication scheme names are case-insensitive (RFC 9110 §11.1).
        if scheme.lower() != "bearer" or not token:
            raise Unauthorized("Invalid authorization header")

        try:
            claims = self._decode_token(token)
        except jwt.ExpiredSignatureError:
            raise Unauthorized("Expired token")
        except jwt.InvalidTokenError:
            raise Unauthorized("Invalid token")

        username = claims["username"]
        if not isinstance(username, str) or not username:
            raise Unauthorized("Invalid token")
        return username

    def _login(self, request):
        """Issue a fresh 15-minute token for the user `authenticate` accepts."""
        username = self.authenticate(request)
        if username is None:
            raise Unauthorized("User authentication failed")
        now = datetime.datetime.now(datetime.UTC)
        claims = {
            "username": username,
            "iat": now,
            "exp": now + datetime.timedelta(minutes=15),
        }
        token = jwt.encode(claims, self.secret, algorithm="HS256")

        return {"token": token}

    def _renew(self, request, token: str):
        """Issue a new 15-minute token for `token`, keeping its original `iat`."""
        now = datetime.datetime.now(datetime.UTC)
        try:
            # Valid tokens can always be renewed within their short lifetime,
            # independent of the issuing date
            claims = self._decode_token(token)
        except jwt.ExpiredSignatureError:
            claims = None
        except jwt.InvalidTokenError:
            raise Unauthorized("Invalid token")

        if claims is None:
            # Expired tokens can be renewed for at most one week after the
            # first issuing date
            try:
                claims = self._decode_token(token, verify_exp=False)
            except jwt.InvalidTokenError:
                raise Unauthorized("Invalid token")
            issued_at = datetime.datetime.fromtimestamp(claims["iat"], tz=datetime.UTC)
            if issued_at + datetime.timedelta(days=7) < now:
                raise Unauthorized("Unrenewable expired token")

        claims["exp"] = now + datetime.timedelta(minutes=15)

        token = jwt.encode(claims, self.secret, algorithm="HS256")

        return {"token": token}

    def authenticate(self, request):
        """Authenticate user.

        Returns a user identification string (usually the username) if
        authentication passed, None otherwise. This method must be
        overridden in an implementing subclass.
        """
        raise NotImplementedError()
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
class ExternalAuth(BaseJWTAuthMiddleware):
    """Trust a username established by an external authentication layer.

    Whatever sits in front of this middleware (a web server or another WSGI
    layer) must put the authenticated user's name into the WSGI environment
    under the `REMOTE_USER` key. Logging in accepts GET and POST by default.
    """

    def __init__(self, *args, login_methods=("GET", "POST"), **kwargs):
        kwargs["login_methods"] = login_methods
        super().__init__(*args, **kwargs)

    def authenticate(self, request):
        # Werkzeug's Request exposes environ["REMOTE_USER"] as `remote_user`.
        user = request.remote_user
        return user
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
class DummyAuth(BaseJWTAuthMiddleware):
    """Authenticator for testing and development that lets everybody in.

    Every login succeeds, and the issued tokens always belong to the user
    "dummyuser".

    Here is an example session:

    Create an API:
    >>> api = API()
    >>> @api.GET("/")
    ... def root(request):
    ...     return "Hello World"
    ...

    Wrap it with an authentication/authorization layer:
    >>> app = DummyAuth(api, "not a secret but still quite long")

    >>> from werkzeug.test import Client
    >>> client = Client(app)

    By default, access is denied:
    >>> client.get("/")
    <TestResponse streamed [401 UNAUTHORIZED]>

    Login to get a token:
    >>> response = client.post("/auth/login")
    >>> response.status
    '200 OK'
    >>> token = response.get_json()["token"]
    >>> token # doctest: +ELLIPSIS
    'eyJ...'

    Use the token to gain access:
    >>> headers = {"Authorization": f"Bearer {token}"}
    >>> client.get("/", headers=headers)
    <TestResponse streamed [200 OK]>
    """

    def authenticate(self, request):
        # No credentials are inspected: every caller becomes the same user.
        user = "dummyuser"
        return user
|
|
744
|
+
|
|
745
|
+
|
|
746
|
+
class UsernamePasswordAuth(BaseJWTAuthMiddleware):
    """Authenticate with a username and password combination.

    `dataset` user tables are supported natively. Such a table must contain a
    unique, identifying `username` column and a `password` column holding the
    password hash. Alternatively, a custom password hash retrieval function
    may be specified.

    Password hashes must be in the format produced by werkzeug's
    `werkzeug.security.generate_password_hash` function. They are verified
    with `check_password_hash`, which reads the hashing method (e.g. scrypt or
    PBKDF2) from the stored hash itself.

    Example usage:

    Let's create an in-memory database with a user table containing one entry:
    >>> import dataset
    >>> from werkzeug.security import generate_password_hash
    >>> db = dataset.connect("sqlite:///:memory:")
    >>> db['user'].insert(dict(username="paul", password=generate_password_hash("john")))
    1

    Assemble a dummy application and client:
    >>> app = UsernamePasswordAuth(API(), "not a secret but still quite long", user_table=db['user'])
    >>> from werkzeug.test import Client
    >>> client = Client(app)

    Logging in with correct credentials is now possible:
    >>> cred = {"username": "paul", "password": "john"}
    >>> response = client.post("/auth/login", json=cred)
    >>> response.status
    '200 OK'
    >>> response.get_json()["token"]  # doctest: +ELLIPSIS
    'eyJ...'

    Requests with invalid passwords fail:
    >>> cred = {"username": "paul", "password": "george"}
    >>> response = client.post("/auth/login", json=cred)
    >>> response.status
    '401 UNAUTHORIZED'

    The same is true for non-existent users:
    >>> cred = {"username": "yoko", "password": "john"}
    >>> response = client.post("/auth/login", json=cred)
    >>> response.status
    '401 UNAUTHORIZED'
    """

    def __init__(
        self,
        app,
        secret,
        user_table=None,
        find_password=None,
        *,
        exempt=None,
        prefix="/auth",
        login_methods=("POST",),
    ):
        """Initialize the authentication middleware.

        The `user_table` argument is expected to be a dataset table containing
        at least a unique identifying `username` and a `password` field.

        Alternatively the `find_password` parameter expects a function taking
        a username as argument, and returning the corresponding password hash.
        The function must return None if the user or password cannot be found.

        If both are present, the `user_table` parameter takes precedence over
        `find_password`.

        `exempt`, `prefix` and `login_methods` are passed on to
        `BaseJWTAuthMiddleware`. `exempt` defaults to a fresh empty list.

        Raises ValueError if neither `user_table` nor `find_password` is given.
        """
        # Build a new list per instance. A `[]` default would be a single list
        # object shared by every instance handed to the base class.
        if exempt is None:
            exempt = []
        super().__init__(
            app, secret, exempt=exempt, prefix=prefix, login_methods=login_methods
        )
        if user_table is not None:

            def user_table_find_password(username):
                user = user_table.find_one(username=username)
                if user:
                    return user["password"]
                return None

            self.find_password = user_table_find_password
        elif find_password is not None:
            self.find_password = find_password
        else:
            raise ValueError("One of 'user_table' and 'find_password' must be supplied")

    def authenticate(self, request):
        """Return the username if the JSON body holds valid credentials.

        The body is read through `_parse_json_body`, which declares `username`
        and `password` as strings. Returns None when the user is unknown or the
        password does not match the stored hash.
        """

        @_parse_json_body(content_types=dict(username=str, password=str))
        def handler(request, username, password):
            pw_hash = self.find_password(username)
            # NOTE(review): for unknown users no hash is checked at all, so the
            # response time may reveal whether a username exists.
            if pw_hash and check_password_hash(pw_hash, password):
                return username
            return None

        return handler(request)
|
|
839
|
+
|
|
840
|
+
|
|
841
|
+
def run(
    app,
    prefix=None,
    port=3000,
    hostname="localhost",
    allow_cors_from_hostname=False,
    use_reloader=True,
):
    """Run a wsgi application like an API.

    If `prefix` is given, `app` is mounted below that URL path prefix, and
    requests to any other path are answered with 404 Not Found.

    Optionally specify a listening `port` (default: 3000) and a bind
    `hostname` (default: localhost). Set the hostname to the empty string
    to listen on all interfaces.

    CORS is disabled by default. If `allow_cors_from_hostname` is set to
    `True`, requests from origins sharing the same hostname are allowed,
    regardless of their port number.

    `use_reloader` is forwarded to Werkzeug's development server, which is
    started with `threaded=True` (one thread per request).
    """
    if prefix is not None:
        # The fallback must be a WSGI application, i.e. an HTTPException
        # *instance*. Calling the class with (environ, start_response) would
        # merely construct an exception object instead of producing a response.
        app = DispatcherMiddleware(NotFound(), {prefix: app})

    if allow_cors_from_hostname:
        app = _cors_same_host_middleware(app, hostname)

    werkzeug.run_simple(
        hostname,
        port,
        app,
        threaded=True,
        use_reloader=use_reloader,
    )
|
|
874
|
+
|
|
875
|
+
|
|
876
|
+
# Record type for published events: a numeric `id`, the `event_type` and the
# `data` payload.
Event = collections.namedtuple("Event", "id event_type data")
|
|
877
|
+
|
|
878
|
+
|
|
879
|
+
class PubSub:
    """Class implementing a publish/subscribe event passing scheme.

    Basic example usage:
    >>> chat = PubSub()
    >>> subscription = chat.subscribe()
    >>> chat.publish("message", "Hello")
    >>> next(subscription)
    Event(id=0, event_type='message', data='Hello')

    Messages can be differentiated by topic:
    >>> general_room = chat.subscribe(topic="general")
    >>> nerd_room = chat.subscribe(topic="nerd")
    >>> chat.publish("new_user", "guido", topic="nerd")
    >>> chat.publish("message", "Hi geeks!", topic="nerd")
    >>> chat.publish("message", "It is 12 am", topic="general")
    >>> next(general_room)
    Event(id=3, event_type='message', data='It is 12 am')
    >>> next(nerd_room)
    Event(id=1, event_type='new_user', data='guido')
    >>> next(nerd_room)
    Event(id=2, event_type='message', data='Hi geeks!')

    Locking: `_main_lock` guards the per-topic dictionaries and the id
    counter. Each topic lock guards that topic's subscriber set and replay
    log. The code never holds both kinds of lock at the same time.
    """

    def __init__(self):
        self._main_lock = threading.Lock()
        self._topic_locks = collections.defaultdict(threading.Lock)
        # topic -> set of subscriber queues
        self._queues = collections.defaultdict(set)
        # topic -> the most recent events, kept for Last-Event-ID replay
        self._replay_log = collections.defaultdict(
            lambda: collections.deque(maxlen=1_000)
        )
        self._current_id = itertools.count()

    def publish(self, event_type, data, topic=None):
        """Publish an event.

        The event has an `event_type`, usually a string, and a `data` payload.
        `data` can be free-formed, but should be JSON-serializable.

        Optionally a topic can be specified. The message will be only forwarded
        to subscribers interested in the specified topic.
        """
        with self._main_lock:
            event_id = next(self._current_id)
            queues = self._queues[topic]
            replay_log = self._replay_log[topic]
            topic_lock = self._topic_locks[topic]

        event = Event(event_id, event_type, data)

        with topic_lock:
            # Logging and fan-out happen under the same lock that `_subscribe`
            # holds while snapshotting, so a resuming subscriber gets every
            # event exactly once: either replayed or through its queue.
            replay_log.append(event)

            stalled = []
            for q in queues:
                try:
                    q.put_nowait(event)
                except queue.Full:  # Somebody fell asleep?!?
                    stalled.append(q)

            # Dropping a full queue ends that subscriber's iterator.
            queues.difference_update(stalled)

    def broadcast(self, event_type, data):
        """Broadcast event to all subscribers."""
        with self._main_lock:
            topics = list(self._queues.keys())

        for topic in topics:
            self.publish(event_type, data, topic)

    def subscribe(self, topic=None):
        """Subscribe to published events.

        Events are returned as triples, containing a unique event `id`, the
        `event_type`, and the payload `data`.

        Optionally a specific `topic` can be specified.
        """
        _, subscription = self._subscribe(topic)
        return subscription

    def _subscribe(self, topic=None, last_id=None):
        """Snapshot missed events and register a new subscriber atomically.

        Returns `(replay_events, subscription)`. `replay_events` holds the
        logged events of `topic` after the one with id `last_id` (empty when
        `last_id` is None). `subscription` is an iterator of events published
        afterwards.

        Raises ValueError if `last_id` is not an integer or not in the replay
        log. In that case no subscriber is registered.
        """
        q = queue.Queue(100)
        with self._main_lock:
            queues = self._queues[topic]
            replay_log = self._replay_log[topic]
            topic_lock = self._topic_locks[topic]

        with topic_lock:
            # May raise ValueError. The queue is only registered afterwards,
            # so a failed resume leaves nothing behind.
            replay_events = self._events_after(replay_log, last_id)
            queues.add(q)

        def iterator():
            try:
                while q in queues:
                    try:
                        yield q.get(timeout=60)
                    except queue.Empty:
                        pass  # re-check whether we are still subscribed
            finally:
                # Runs on close() and on normal termination alike.
                with topic_lock:
                    queues.discard(q)

        return replay_events, iterator()

    @staticmethod
    def _events_after(replay_log, last_id):
        """Return the events in `replay_log` after the one with id `last_id`.

        Returns an empty tuple when `last_id` is None. Raises ValueError when
        `last_id` is not an integer or not present in the log. The caller must
        hold the topic lock of `replay_log`.
        """
        if last_id is None:
            return ()

        last_id = int(last_id)
        log_iter = iter(replay_log)
        for event in log_iter:
            if event.id == last_id:
                break
        else:
            raise ValueError(f"{last_id} is not in event log")
        return list(log_iter)

    def _replay_events(self, last_id, topic=None):
        """Return the logged events of `topic` published after `last_id`.

        Returns an empty tuple when `last_id` is None. Raises ValueError when
        `last_id` is not an integer or no longer in the replay log.
        """
        if last_id is None:
            return ()

        last_id = int(last_id)

        with self._main_lock:
            replay_log = self._replay_log[topic]
            topic_lock = self._topic_locks[topic]

        with topic_lock:
            return self._events_after(replay_log, last_id)

    @staticmethod
    def _encode_stream(replay_events, subscription):
        """Yield `replay_events`, then `subscription` events, as SSE frames."""
        try:
            for event in itertools.chain(replay_events, subscription):
                yield (
                    f"id: {event.id}\n"
                    f"event: {event.event_type}\n"
                    f"data: {json.dumps(event.data, default=str)}\n\n"
                ).encode("utf-8")
        finally:
            # Unsubscribe as soon as the stream is closed (client gone)
            # instead of relying on garbage collection of the generator.
            subscription.close()

    def _event_stream(self, replay_events=(), topic=None):
        """Yield `replay_events`, then live events of `topic`, as SSE frames.

        The subscription is opened lazily when the stream is first iterated.
        """
        subscription = self.subscribe(topic)
        yield from self._encode_stream(replay_events, subscription)

    def _resuming_event_stream(self, last_id=None, topic=None):
        """Yield the events after `last_id`, then live events, as SSE frames.

        The replay snapshot and the subscription are taken together when the
        stream is first iterated, so no event published in between is lost.
        If `last_id` dropped out of the replay log in the meantime, the stream
        ends immediately. The reconnecting client then gets a 404 from
        `streaming_response`.
        """
        try:
            replay_events, subscription = self._subscribe(topic, last_id)
        except ValueError:
            return
        yield from self._encode_stream(replay_events, subscription)

    def streaming_response(self, request, topic=None):
        """Generate a streaming HTTP response with server-sent events.

        See https://html.spec.whatwg.org/multipage/server-sent-events.html
        for more information about server-sent events.

        When reconnecting after losing the connection for a while, browsers
        automatically set the `Last-Event-ID` header field to the value of
        the id of the last received event. The response will first replay
        missed events of `topic`, before sending newly arriving events. When
        the event specified by `Last-Event-ID` is not found (or the header is
        not an integer), a 404 Not Found response is sent. This signals to
        the browser that a clean recovery is not possible.

        Here is an example session. First, let us create an API:
        >>> api = API()
        >>> chat = PubSub()
        >>> @api.POST("/")
        ... def post_message(request, message:str):
        ...     chat.publish("message", message)
        ...
        >>> @api.GET("/")
        ... def stream(request):
        ...     return chat.streaming_response(request)
        ...

        We can now post messages and see them appear in our subscription:
        >>> subscription = chat.subscribe()
        >>> from werkzeug.test import Client
        >>> client = Client(api)
        >>> resp = client.post("/", json={"message": "hello"})
        >>> resp = client.post("/", json={"message": "everybody"})
        >>> next(subscription)
        Event(id=0, event_type='message', data='hello')
        >>> next(subscription)
        Event(id=1, event_type='message', data='everybody')

        Now, let's simulate a reconnecting browser that only got the first
        message:
        >>> response = client.get("/", headers={"Last-Event-ID": "0"})
        >>> response.status
        '200 OK'
        >>> body = iter(response.response)

        The events are formatted according to the specification for server-sent
        events:
        >>> print(str(next(body), encoding="utf-8").strip())
        id: 1
        event: message
        data: "everybody"

        Further incoming messages are sent to the listening client, without
        closing the connection:
        >>> resp = client.post("/", json={"message": "howdoyoudo?"})
        >>> print(str(next(body), encoding="utf-8").strip())
        id: 2
        event: message
        data: "howdoyoudo?"
        """
        last_id = request.headers.get("Last-Event-ID", None)
        try:
            # Validate eagerly, in the requested topic's log. Once the
            # streaming body has started, no error status can be sent anymore.
            self._replay_events(last_id, topic)
        except ValueError:
            raise NotFound() from None

        return werkzeug.Response(
            self._resuming_event_stream(last_id, topic),
            mimetype="text/event-stream",
        )
|
|
1081
|
+
|
|
1082
|
+
|
|
1083
|
+
# Public names exported by `from wau import *`.
__all__ = (
    # Core API, HTTP errors and helpers
    "API",
    "NotFound",
    "Unauthorized",
    "UnprocessableEntity",
    "timestamp",
    # Authentication middlewares
    "BaseJWTAuthMiddleware",
    "ExternalAuth",
    "DummyAuth",
    "UsernamePasswordAuth",
    # Serving and server-sent events
    "run",
    "PubSub",
)
|