nucliadb-utils 2.44.2.post353__py3-none-any.whl → 2.44.2.post354__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
nucliadb_utils/const.py CHANGED
@@ -78,3 +78,4 @@ class Features:
78
78
  VERSIONED_PRIVATE_PREDICT = "nucliadb_versioned_private_predict"
79
79
  BACK_PRESSURE = "nucliadb_back_pressure"
80
80
  REBALANCE_KB = "nucliadb_rebalance_kb"
81
+ CORS_MIDDLEWARE = "nucliadb_cors_middleware_enabled"
@@ -61,6 +61,10 @@ DEFAULT_FLAG_DATA: dict[str, Any] = {
61
61
  "rollout": 0,
62
62
  "variants": {"environment": ["local"]},
63
63
  },
64
+ const.Features.CORS_MIDDLEWARE: {
65
+ "rollout": 0,
66
+ "variants": {"environment": ["local"]},
67
+ },
64
68
  }
65
69
 
66
70
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nucliadb-utils
3
- Version: 2.44.2.post353
3
+ Version: 2.44.2.post354
4
4
  Home-page: https://nuclia.com
5
5
  License: BSD
6
6
  Classifier: Development Status :: 4 - Beta
@@ -21,8 +21,8 @@ Requires-Dist: nats-py[nkeys] >=2.6.0
21
21
  Requires-Dist: pyjwt >=2.4.0
22
22
  Requires-Dist: memorylru >=1.1.2
23
23
  Requires-Dist: mrflagly
24
- Requires-Dist: nucliadb-protos >=2.44.2.post353
25
- Requires-Dist: nucliadb-telemetry >=2.44.2.post353
24
+ Requires-Dist: nucliadb-protos >=2.44.2.post354
25
+ Requires-Dist: nucliadb-telemetry >=2.44.2.post354
26
26
  Provides-Extra: cache
27
27
  Requires-Dist: redis >=4.3.4 ; extra == 'cache'
28
28
  Requires-Dist: orjson >=3.6.7 ; extra == 'cache'
@@ -1,11 +1,10 @@
1
1
  nucliadb_utils/__init__.py,sha256=EvBCH1iTODe-AgXm48aj4kVUt_Std3PeL8QnwimR5wI,895
2
2
  nucliadb_utils/asyncio_utils.py,sha256=i4u0NP32zx0GWB7L5bc5eSWc-zGIxDMP7rQZhoYxRxo,2794
3
3
  nucliadb_utils/authentication.py,sha256=HTj6kFFrfdN4-DW_MNqldAPqNPdPwGHnSWl-SigiMvk,6071
4
- nucliadb_utils/const.py,sha256=FYACRPdKhjcNgWt5tCMpjnEYQq_4ZYbxrEAmIPIFpME,2355
5
- nucliadb_utils/cors.py,sha256=s4PQZ5rZOOYxiHA2oBi8LKfmuHhZTwlZiyEfWM3c5ag,8406
4
+ nucliadb_utils/const.py,sha256=yd7We31W7-B7zI-Wt2x5QDTXb78gBNeNXpoyK1IOKb0,2412
6
5
  nucliadb_utils/debug.py,sha256=saSfh_CDQoQl-35KCyqef5hdB_OVdrIEnlmWnZU18vg,2470
7
6
  nucliadb_utils/exceptions.py,sha256=y_3wk77WLVUtdo-5FtbBsdSkCtK_DsJkdWb5BoPn3qo,1094
8
- nucliadb_utils/featureflagging.py,sha256=hc7sestzm9_usfZN6_CF7RSs4UOx3CufrQ93ZxnIb3c,2711
7
+ nucliadb_utils/featureflagging.py,sha256=Rxr999pbYCA5QTzyubxUft2MSA4Tlc_gTYd35YDIBkQ,2826
9
8
  nucliadb_utils/grpc.py,sha256=USXwPRuCJiSyLf0JW4isIKZne6od8gM_GGLWaHHjknk,3336
10
9
  nucliadb_utils/helpers.py,sha256=fOL6eImdvKO3NV39ymmo2UOCT-GAK1dfXKoMKdzdmFo,1599
11
10
  nucliadb_utils/indexing.py,sha256=ELJ9bTnrxcb6tmO77HzzO1SDLPmKXTc_nBiLnJocz5I,3462
@@ -52,7 +51,6 @@ nucliadb_utils/tests/s3.py,sha256=YB8QqDaBXxyhHonEHmeBbRRDmvB7sTOaKBSi8KBGokg,23
52
51
  nucliadb_utils/tests/unit/__init__.py,sha256=itSI7dtTwFP55YMX4iK7JzdMHS5CQVUiB1XzQu4UBh8,833
53
52
  nucliadb_utils/tests/unit/test_asyncio_utils.py,sha256=Dfs-lEa6ZGp6m-dOmhtB8PGuQQ9CUOfVbX-YvUEfdrc,2339
54
53
  nucliadb_utils/tests/unit/test_authentication.py,sha256=clxj4mmmUnHbUh9dvvtoA2XNruZSxRPEXD68MXXzgzU,5361
55
- nucliadb_utils/tests/unit/test_cors.py,sha256=7Gifju-MXjO0x_MtfB_STd2nVhr3jdsl4kntgpTW6-8,22062
56
54
  nucliadb_utils/tests/unit/test_helpers.py,sha256=w0Yxo1NteBZi_iSRVZ3vv4J-o3Z7SJIz1q6yyXGzQLg,1574
57
55
  nucliadb_utils/tests/unit/test_nats.py,sha256=4lLQOucQGwL9lubTfNFqXhXtua9u4lTlWjF_lFCYORc,4436
58
56
  nucliadb_utils/tests/unit/test_run.py,sha256=VHeojVBj_AfV_uNch9eEi4zCT-O94VKjwb_O-UZgqpA,1910
@@ -65,8 +63,8 @@ nucliadb_utils/tests/unit/storages/test_aws.py,sha256=GCsB_jwCUNV3Ogt8TZZEmNKAHv
65
63
  nucliadb_utils/tests/unit/storages/test_gcs.py,sha256=2XzJwgNpfjVGjtE-QdZhu3ayuT1EMEXINdM-_SatPCY,3554
66
64
  nucliadb_utils/tests/unit/storages/test_pg.py,sha256=sJfUttMSzq8W1XYolAUcMxl_R5HcEzb5fpCklPeMJiY,17000
67
65
  nucliadb_utils/tests/unit/storages/test_storage.py,sha256=VFpRq6Q6BjnIrBQCumYzR8DQUacwhxt5CzTKSlqqD24,6892
68
- nucliadb_utils-2.44.2.post353.dist-info/METADATA,sha256=aRrRVotMIrJNyXT12mr6MjRVriJ7mtEYtBRl8MCRLro,1978
69
- nucliadb_utils-2.44.2.post353.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
70
- nucliadb_utils-2.44.2.post353.dist-info/top_level.txt,sha256=fE3vJtALTfgh7bcAWcNhcfXkNPp_eVVpbKK-2IYua3E,15
71
- nucliadb_utils-2.44.2.post353.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
72
- nucliadb_utils-2.44.2.post353.dist-info/RECORD,,
66
+ nucliadb_utils-2.44.2.post354.dist-info/METADATA,sha256=ry3N2eeoqt0rxNMZIUgkfPZf7OJMKm0-EyDGp53Pmcw,1978
67
+ nucliadb_utils-2.44.2.post354.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
68
+ nucliadb_utils-2.44.2.post354.dist-info/top_level.txt,sha256=fE3vJtALTfgh7bcAWcNhcfXkNPp_eVVpbKK-2IYua3E,15
69
+ nucliadb_utils-2.44.2.post354.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
70
+ nucliadb_utils-2.44.2.post354.dist-info/RECORD,,
nucliadb_utils/cors.py DELETED
@@ -1,210 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import functools
21
- import logging
22
- import re
23
- import typing
24
-
25
- from starlette.datastructures import Headers, MutableHeaders
26
- from starlette.responses import PlainTextResponse, Response
27
- from starlette.types import ASGIApp, Message, Receive, Scope, Send
28
-
29
- ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT")
30
- SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"}
31
-
32
- logger = logging.getLogger("cors_debug")
33
-
34
-
35
- class CORSMiddleware:
36
- def __init__(
37
- self,
38
- app: ASGIApp,
39
- allow_origins: typing.Sequence[str] = (),
40
- allow_methods: typing.Sequence[str] = ("GET",),
41
- allow_headers: typing.Sequence[str] = (),
42
- allow_credentials: bool = False,
43
- allow_origin_regex: typing.Optional[str] = None,
44
- expose_headers: typing.Sequence[str] = (),
45
- max_age: int = 600,
46
- ) -> None:
47
- if "*" in allow_methods:
48
- allow_methods = ALL_METHODS
49
-
50
- compiled_allow_origin_regex = None
51
- if allow_origin_regex is not None:
52
- compiled_allow_origin_regex = re.compile(allow_origin_regex)
53
-
54
- allow_all_origins = "*" in allow_origins
55
- allow_all_headers = "*" in allow_headers
56
- preflight_explicit_allow_origin = not allow_all_origins or allow_credentials
57
-
58
- simple_headers = {}
59
- if allow_all_origins:
60
- simple_headers["Access-Control-Allow-Origin"] = "*"
61
- if allow_credentials:
62
- simple_headers["Access-Control-Allow-Credentials"] = "true"
63
- if expose_headers:
64
- simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
65
-
66
- preflight_headers = {}
67
- if preflight_explicit_allow_origin:
68
- # The origin value will be set in preflight_response() if it is allowed.
69
- preflight_headers["Vary"] = "Origin"
70
- else:
71
- preflight_headers["Access-Control-Allow-Origin"] = "*"
72
- preflight_headers.update(
73
- {
74
- "Access-Control-Allow-Methods": ", ".join(allow_methods),
75
- "Access-Control-Max-Age": str(max_age),
76
- }
77
- )
78
- allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
79
- if allow_headers and not allow_all_headers:
80
- preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
81
- if allow_credentials:
82
- preflight_headers["Access-Control-Allow-Credentials"] = "true"
83
-
84
- self.app = app
85
- self.allow_origins = allow_origins
86
- self.allow_methods = allow_methods
87
- self.allow_headers = [h.lower() for h in allow_headers]
88
- self.allow_all_origins = allow_all_origins
89
- self.allow_all_headers = allow_all_headers
90
- self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
91
- self.allow_origin_regex = compiled_allow_origin_regex
92
- self.simple_headers = simple_headers
93
- self.preflight_headers = preflight_headers
94
-
95
- async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
96
- if scope["type"] != "http": # pragma: no cover
97
- await self.app(scope, receive, send)
98
- return
99
-
100
- method = scope["method"]
101
- headers = Headers(scope=scope)
102
- origin = headers.get("origin")
103
-
104
- if origin is None:
105
- await self.app(scope, receive, send)
106
- return
107
-
108
- if method == "OPTIONS" and "access-control-request-method" in headers:
109
- response = self.preflight_response(request_headers=headers)
110
- await response(scope, receive, send)
111
- return
112
-
113
- await self.simple_response(scope, receive, send, request_headers=headers)
114
-
115
- def is_allowed_origin(
116
- self, origin: str, allowed_origins: typing.Optional[str]
117
- ) -> bool:
118
- if allowed_origins:
119
- return origin in allowed_origins.split(",")
120
-
121
- if self.allow_all_origins:
122
- return True
123
-
124
- if self.allow_origin_regex is not None and self.allow_origin_regex.fullmatch(
125
- origin
126
- ):
127
- return True
128
-
129
- return origin in self.allow_origins
130
-
131
- def preflight_response(self, request_headers: Headers) -> Response:
132
- logger.info(f"CORS DEBUG REQUEST: {str(request_headers)}")
133
- requested_origin = request_headers["origin"]
134
- requested_method = request_headers["access-control-request-method"]
135
- requested_headers = request_headers.get("access-control-request-headers")
136
-
137
- headers = dict(self.preflight_headers)
138
- failures = []
139
-
140
- allowed_origins = request_headers.get("x-nucliadb-cors-allowed-origins")
141
- if self.is_allowed_origin(
142
- origin=requested_origin, allowed_origins=allowed_origins
143
- ):
144
- if self.preflight_explicit_allow_origin:
145
- # The "else" case is already accounted for in self.preflight_headers
146
- # and the value would be "*".
147
- headers["Access-Control-Allow-Origin"] = requested_origin
148
- else:
149
- failures.append("origin")
150
-
151
- if requested_method not in self.allow_methods:
152
- failures.append("method")
153
-
154
- # If we allow all headers, then we have to mirror back any requested
155
- # headers in the response.
156
- if self.allow_all_headers and requested_headers is not None:
157
- headers["Access-Control-Allow-Headers"] = requested_headers
158
- elif requested_headers is not None:
159
- for header in [h.lower() for h in requested_headers.split(",")]:
160
- if header.strip() not in self.allow_headers:
161
- failures.append("headers")
162
- break
163
-
164
- # We don't strictly need to use 400 responses here, since its up to
165
- # the browser to enforce the CORS policy, but its more informative
166
- # if we do.
167
- if failures:
168
- failure_text = "Disallowed CORS " + ", ".join(failures)
169
- return PlainTextResponse(failure_text, status_code=400, headers=headers)
170
-
171
- return PlainTextResponse("OK", status_code=200, headers=headers)
172
-
173
- async def simple_response(
174
- self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
175
- ) -> None:
176
- send = functools.partial(self.send, send=send, request_headers=request_headers)
177
- await self.app(scope, receive, send)
178
-
179
- async def send(
180
- self, message: Message, send: Send, request_headers: Headers
181
- ) -> None:
182
- if message["type"] != "http.response.start":
183
- await send(message)
184
- return
185
-
186
- message.setdefault("headers", [])
187
- headers = MutableHeaders(scope=message)
188
- headers.update(self.simple_headers)
189
- origin = request_headers["Origin"]
190
- has_cookie = "cookie" in request_headers
191
-
192
- # If request includes any cookie headers, then we must respond
193
- # with the specific origin instead of '*'.
194
- if self.allow_all_origins and has_cookie:
195
- self.allow_explicit_origin(headers, origin)
196
-
197
- # If we only allow specific origins, then we have to mirror back
198
- # the Origin header in the response.
199
- elif not self.allow_all_origins and self.is_allowed_origin(
200
- origin=origin,
201
- allowed_origins=headers.get("x-nucliadb-cors-allowed-origins"),
202
- ):
203
- self.allow_explicit_origin(headers, origin)
204
-
205
- await send(message)
206
-
207
- @staticmethod
208
- def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None:
209
- headers["Access-Control-Allow-Origin"] = origin
210
- headers.add_vary_header("Origin")
@@ -1,633 +0,0 @@
1
- # Copyright (C) 2021 Bosutech XXI S.L.
2
- #
3
- # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
- # For commercial licensing, contact us at info@nuclia.com.
5
- #
6
- # AGPL:
7
- # This program is free software: you can redistribute it and/or modify
8
- # it under the terms of the GNU Affero General Public License as
9
- # published by the Free Software Foundation, either version 3 of the
10
- # License, or (at your option) any later version.
11
- #
12
- # This program is distributed in the hope that it will be useful,
13
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
- # GNU Affero General Public License for more details.
16
- #
17
- # You should have received a copy of the GNU Affero General Public License
18
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
-
20
- import functools
21
- from typing import Any, Callable, Literal
22
-
23
- import pytest
24
- from starlette.applications import Starlette
25
- from starlette.middleware import Middleware
26
- from starlette.requests import Request
27
- from starlette.responses import PlainTextResponse
28
- from starlette.routing import Route
29
- from starlette.testclient import TestClient
30
- from starlette.types import ASGIApp
31
-
32
- from nucliadb_utils.cors import CORSMiddleware
33
-
34
- TestClientFactory = Callable[[ASGIApp], TestClient]
35
-
36
-
37
- @pytest.fixture
38
- def anyio_backend():
39
- return "asyncio"
40
-
41
-
42
- @pytest.fixture
43
- def test_client_factory(
44
- anyio_backend_name: Literal["asyncio"],
45
- anyio_backend_options: dict[str, Any],
46
- ) -> TestClientFactory:
47
- # anyio_backend_name defined by:
48
- # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
49
- return functools.partial(
50
- TestClient,
51
- backend=anyio_backend_name,
52
- backend_options=anyio_backend_options,
53
- )
54
-
55
-
56
- def test_cors_allow_all(
57
- test_client_factory: TestClientFactory,
58
- ) -> None:
59
- def homepage(request: Request) -> PlainTextResponse:
60
- return PlainTextResponse("Homepage", status_code=200)
61
-
62
- app = Starlette(
63
- routes=[Route("/", endpoint=homepage)],
64
- middleware=[
65
- Middleware(
66
- CORSMiddleware,
67
- allow_origins=["*"],
68
- allow_headers=["*"],
69
- allow_methods=["*"],
70
- expose_headers=["X-Status"],
71
- allow_credentials=True,
72
- )
73
- ],
74
- )
75
-
76
- client = test_client_factory(app)
77
-
78
- # Test pre-flight response
79
- headers = {
80
- "Origin": "https://example.org",
81
- "Access-Control-Request-Method": "GET",
82
- "Access-Control-Request-Headers": "X-Example",
83
- }
84
- response = client.options("/", headers=headers)
85
- assert response.status_code == 200
86
- assert response.text == "OK"
87
- assert response.headers["access-control-allow-origin"] == "https://example.org"
88
- assert response.headers["access-control-allow-headers"] == "X-Example"
89
- assert response.headers["access-control-allow-credentials"] == "true"
90
- assert response.headers["vary"] == "Origin"
91
-
92
- # Test standard response
93
- headers = {"Origin": "https://example.org"}
94
- response = client.get("/", headers=headers)
95
- assert response.status_code == 200
96
- assert response.text == "Homepage"
97
- assert response.headers["access-control-allow-origin"] == "*"
98
- assert response.headers["access-control-expose-headers"] == "X-Status"
99
- assert response.headers["access-control-allow-credentials"] == "true"
100
-
101
- # Test standard credentialed response
102
- headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
103
- response = client.get("/", headers=headers)
104
- assert response.status_code == 200
105
- assert response.text == "Homepage"
106
- assert response.headers["access-control-allow-origin"] == "https://example.org"
107
- assert response.headers["access-control-expose-headers"] == "X-Status"
108
- assert response.headers["access-control-allow-credentials"] == "true"
109
-
110
- # Test non-CORS response
111
- response = client.get("/")
112
- assert response.status_code == 200
113
- assert response.text == "Homepage"
114
- assert "access-control-allow-origin" not in response.headers
115
-
116
-
117
- def test_cors_allow_all_except_credentials(
118
- test_client_factory: TestClientFactory,
119
- ) -> None:
120
- def homepage(request: Request) -> PlainTextResponse:
121
- return PlainTextResponse("Homepage", status_code=200)
122
-
123
- app = Starlette(
124
- routes=[Route("/", endpoint=homepage)],
125
- middleware=[
126
- Middleware(
127
- CORSMiddleware,
128
- allow_origins=["*"],
129
- allow_headers=["*"],
130
- allow_methods=["*"],
131
- expose_headers=["X-Status"],
132
- )
133
- ],
134
- )
135
-
136
- client = test_client_factory(app)
137
-
138
- # Test pre-flight response
139
- headers = {
140
- "Origin": "https://example.org",
141
- "Access-Control-Request-Method": "GET",
142
- "Access-Control-Request-Headers": "X-Example",
143
- }
144
- response = client.options("/", headers=headers)
145
- assert response.status_code == 200
146
- assert response.text == "OK"
147
- assert response.headers["access-control-allow-origin"] == "*"
148
- assert response.headers["access-control-allow-headers"] == "X-Example"
149
- assert "access-control-allow-credentials" not in response.headers
150
- assert "vary" not in response.headers
151
-
152
- # Test standard response
153
- headers = {"Origin": "https://example.org"}
154
- response = client.get("/", headers=headers)
155
- assert response.status_code == 200
156
- assert response.text == "Homepage"
157
- assert response.headers["access-control-allow-origin"] == "*"
158
- assert response.headers["access-control-expose-headers"] == "X-Status"
159
- assert "access-control-allow-credentials" not in response.headers
160
-
161
- # Test non-CORS response
162
- response = client.get("/")
163
- assert response.status_code == 200
164
- assert response.text == "Homepage"
165
- assert "access-control-allow-origin" not in response.headers
166
-
167
-
168
- def test_cors_allow_specific_origin(
169
- test_client_factory: TestClientFactory,
170
- ) -> None:
171
- def homepage(request: Request) -> PlainTextResponse:
172
- return PlainTextResponse("Homepage", status_code=200)
173
-
174
- app = Starlette(
175
- routes=[Route("/", endpoint=homepage)],
176
- middleware=[
177
- Middleware(
178
- CORSMiddleware,
179
- allow_origins=["https://example.org"],
180
- allow_headers=["X-Example", "Content-Type"],
181
- )
182
- ],
183
- )
184
-
185
- client = test_client_factory(app)
186
-
187
- # Test pre-flight response
188
- headers = {
189
- "Origin": "https://example.org",
190
- "Access-Control-Request-Method": "GET",
191
- "Access-Control-Request-Headers": "X-Example, Content-Type",
192
- }
193
- response = client.options("/", headers=headers)
194
- assert response.status_code == 200
195
- assert response.text == "OK"
196
- assert response.headers["access-control-allow-origin"] == "https://example.org"
197
- assert response.headers["access-control-allow-headers"] == (
198
- "Accept, Accept-Language, Content-Language, Content-Type, X-Example"
199
- )
200
- assert "access-control-allow-credentials" not in response.headers
201
-
202
- # Test standard response
203
- headers = {"Origin": "https://example.org"}
204
- response = client.get("/", headers=headers)
205
- assert response.status_code == 200
206
- assert response.text == "Homepage"
207
- assert response.headers["access-control-allow-origin"] == "https://example.org"
208
- assert "access-control-allow-credentials" not in response.headers
209
-
210
- # Test non-CORS response
211
- response = client.get("/")
212
- assert response.status_code == 200
213
- assert response.text == "Homepage"
214
- assert "access-control-allow-origin" not in response.headers
215
-
216
-
217
- def test_cors_disallowed_preflight(
218
- test_client_factory: TestClientFactory,
219
- ) -> None:
220
- def homepage(request: Request) -> None:
221
- pass # pragma: no cover
222
-
223
- app = Starlette(
224
- routes=[Route("/", endpoint=homepage)],
225
- middleware=[
226
- Middleware(
227
- CORSMiddleware,
228
- allow_origins=["https://example.org"],
229
- allow_headers=["X-Example"],
230
- )
231
- ],
232
- )
233
-
234
- client = test_client_factory(app)
235
-
236
- # Test pre-flight response
237
- headers = {
238
- "Origin": "https://another.org",
239
- "Access-Control-Request-Method": "POST",
240
- "Access-Control-Request-Headers": "X-Nope",
241
- }
242
- response = client.options("/", headers=headers)
243
- assert response.status_code == 400
244
- assert response.text == "Disallowed CORS origin, method, headers"
245
- assert "access-control-allow-origin" not in response.headers
246
-
247
- # Bug specific test, https://github.com/encode/starlette/pull/1199
248
- # Test preflight response text with multiple disallowed headers
249
- headers = {
250
- "Origin": "https://example.org",
251
- "Access-Control-Request-Method": "GET",
252
- "Access-Control-Request-Headers": "X-Nope-1, X-Nope-2",
253
- }
254
- response = client.options("/", headers=headers)
255
- assert response.text == "Disallowed CORS headers"
256
-
257
-
258
- def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(
259
- test_client_factory: TestClientFactory,
260
- ) -> None:
261
- def homepage(request: Request) -> None:
262
- return # pragma: no cover
263
-
264
- app = Starlette(
265
- routes=[Route("/", endpoint=homepage)],
266
- middleware=[
267
- Middleware(
268
- CORSMiddleware,
269
- allow_origins=["*"],
270
- allow_methods=["POST"],
271
- allow_credentials=True,
272
- )
273
- ],
274
- )
275
-
276
- client = test_client_factory(app)
277
-
278
- # Test pre-flight response
279
- headers = {
280
- "Origin": "https://example.org",
281
- "Access-Control-Request-Method": "POST",
282
- }
283
- response = client.options(
284
- "/",
285
- headers=headers,
286
- )
287
- assert response.status_code == 200
288
- assert response.headers["access-control-allow-origin"] == "https://example.org"
289
- assert response.headers["access-control-allow-credentials"] == "true"
290
- assert response.headers["vary"] == "Origin"
291
-
292
-
293
- def test_cors_preflight_allow_all_methods(
294
- test_client_factory: TestClientFactory,
295
- ) -> None:
296
- def homepage(request: Request) -> None:
297
- pass # pragma: no cover
298
-
299
- app = Starlette(
300
- routes=[Route("/", endpoint=homepage)],
301
- middleware=[
302
- Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
303
- ],
304
- )
305
-
306
- client = test_client_factory(app)
307
-
308
- headers = {
309
- "Origin": "https://example.org",
310
- "Access-Control-Request-Method": "POST",
311
- }
312
-
313
- for method in ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"):
314
- response = client.options("/", headers=headers)
315
- assert response.status_code == 200
316
- assert method in response.headers["access-control-allow-methods"]
317
-
318
-
319
- def test_cors_allow_all_methods(
320
- test_client_factory: TestClientFactory,
321
- ) -> None:
322
- def homepage(request: Request) -> PlainTextResponse:
323
- return PlainTextResponse("Homepage", status_code=200)
324
-
325
- app = Starlette(
326
- routes=[
327
- Route(
328
- "/",
329
- endpoint=homepage,
330
- methods=["delete", "get", "head", "options", "patch", "post", "put"],
331
- )
332
- ],
333
- middleware=[
334
- Middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"])
335
- ],
336
- )
337
-
338
- client = test_client_factory(app)
339
-
340
- headers = {"Origin": "https://example.org"}
341
-
342
- for method in ("patch", "post", "put"):
343
- response = getattr(client, method)("/", headers=headers, json={})
344
- assert response.status_code == 200
345
- for method in ("delete", "get", "head", "options"):
346
- response = getattr(client, method)("/", headers=headers)
347
- assert response.status_code == 200
348
-
349
-
350
- def test_cors_allow_origin_regex(
351
- test_client_factory: TestClientFactory,
352
- ) -> None:
353
- def homepage(request: Request) -> PlainTextResponse:
354
- return PlainTextResponse("Homepage", status_code=200)
355
-
356
- app = Starlette(
357
- routes=[Route("/", endpoint=homepage)],
358
- middleware=[
359
- Middleware(
360
- CORSMiddleware,
361
- allow_headers=["X-Example", "Content-Type"],
362
- allow_origin_regex="https://.*",
363
- allow_credentials=True,
364
- )
365
- ],
366
- )
367
-
368
- client = test_client_factory(app)
369
-
370
- # Test standard response
371
- headers = {"Origin": "https://example.org"}
372
- response = client.get("/", headers=headers)
373
- assert response.status_code == 200
374
- assert response.text == "Homepage"
375
- assert response.headers["access-control-allow-origin"] == "https://example.org"
376
- assert response.headers["access-control-allow-credentials"] == "true"
377
-
378
- # Test standard credentialed response
379
- headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
380
- response = client.get("/", headers=headers)
381
- assert response.status_code == 200
382
- assert response.text == "Homepage"
383
- assert response.headers["access-control-allow-origin"] == "https://example.org"
384
- assert response.headers["access-control-allow-credentials"] == "true"
385
-
386
- # Test disallowed standard response
387
- # Note that enforcement is a browser concern. The disallowed-ness is reflected
388
- # in the lack of an "access-control-allow-origin" header in the response.
389
- headers = {"Origin": "http://example.org"}
390
- response = client.get("/", headers=headers)
391
- assert response.status_code == 200
392
- assert response.text == "Homepage"
393
- assert "access-control-allow-origin" not in response.headers
394
-
395
- # Test pre-flight response
396
- headers = {
397
- "Origin": "https://another.com",
398
- "Access-Control-Request-Method": "GET",
399
- "Access-Control-Request-Headers": "X-Example, content-type",
400
- }
401
- response = client.options("/", headers=headers)
402
- assert response.status_code == 200
403
- assert response.text == "OK"
404
- assert response.headers["access-control-allow-origin"] == "https://another.com"
405
- assert response.headers["access-control-allow-headers"] == (
406
- "Accept, Accept-Language, Content-Language, Content-Type, X-Example"
407
- )
408
- assert response.headers["access-control-allow-credentials"] == "true"
409
-
410
- # Test disallowed pre-flight response
411
- headers = {
412
- "Origin": "http://another.com",
413
- "Access-Control-Request-Method": "GET",
414
- "Access-Control-Request-Headers": "X-Example",
415
- }
416
- response = client.options("/", headers=headers)
417
- assert response.status_code == 400
418
- assert response.text == "Disallowed CORS origin"
419
- assert "access-control-allow-origin" not in response.headers
420
-
421
-
422
- def test_cors_allow_origin_regex_fullmatch(
423
- test_client_factory: TestClientFactory,
424
- ) -> None:
425
- def homepage(request: Request) -> PlainTextResponse:
426
- return PlainTextResponse("Homepage", status_code=200)
427
-
428
- app = Starlette(
429
- routes=[Route("/", endpoint=homepage)],
430
- middleware=[
431
- Middleware(
432
- CORSMiddleware,
433
- allow_headers=["X-Example", "Content-Type"],
434
- allow_origin_regex=r"https://.*\.example.org",
435
- )
436
- ],
437
- )
438
-
439
- client = test_client_factory(app)
440
-
441
- # Test standard response
442
- headers = {"Origin": "https://subdomain.example.org"}
443
- response = client.get("/", headers=headers)
444
- assert response.status_code == 200
445
- assert response.text == "Homepage"
446
- assert (
447
- response.headers["access-control-allow-origin"]
448
- == "https://subdomain.example.org"
449
- )
450
- assert "access-control-allow-credentials" not in response.headers
451
-
452
- # Test diallowed standard response
453
- headers = {"Origin": "https://subdomain.example.org.hacker.com"}
454
- response = client.get("/", headers=headers)
455
- assert response.status_code == 200
456
- assert response.text == "Homepage"
457
- assert "access-control-allow-origin" not in response.headers
458
-
459
-
460
- def test_cors_credentialed_requests_return_specific_origin(
461
- test_client_factory: TestClientFactory,
462
- ) -> None:
463
- def homepage(request: Request) -> PlainTextResponse:
464
- return PlainTextResponse("Homepage", status_code=200)
465
-
466
- app = Starlette(
467
- routes=[Route("/", endpoint=homepage)],
468
- middleware=[Middleware(CORSMiddleware, allow_origins=["*"])],
469
- )
470
- client = test_client_factory(app)
471
-
472
- # Test credentialed request
473
- headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
474
- response = client.get("/", headers=headers)
475
- assert response.status_code == 200
476
- assert response.text == "Homepage"
477
- assert response.headers["access-control-allow-origin"] == "https://example.org"
478
- assert "access-control-allow-credentials" not in response.headers
479
-
480
-
481
- def test_cors_vary_header_defaults_to_origin(
482
- test_client_factory: TestClientFactory,
483
- ) -> None:
484
- def homepage(request: Request) -> PlainTextResponse:
485
- return PlainTextResponse("Homepage", status_code=200)
486
-
487
- app = Starlette(
488
- routes=[Route("/", endpoint=homepage)],
489
- middleware=[Middleware(CORSMiddleware, allow_origins=["https://example.org"])],
490
- )
491
-
492
- headers = {"Origin": "https://example.org"}
493
-
494
- client = test_client_factory(app)
495
-
496
- response = client.get("/", headers=headers)
497
- assert response.status_code == 200
498
- assert response.headers["vary"] == "Origin"
499
-
500
-
501
- def test_cors_vary_header_is_not_set_for_non_credentialed_request(
502
- test_client_factory: TestClientFactory,
503
- ) -> None:
504
- def homepage(request: Request) -> PlainTextResponse:
505
- return PlainTextResponse(
506
- "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"}
507
- )
508
-
509
- app = Starlette(
510
- routes=[Route("/", endpoint=homepage)],
511
- middleware=[Middleware(CORSMiddleware, allow_origins=["*"])],
512
- )
513
- client = test_client_factory(app)
514
-
515
- response = client.get("/", headers={"Origin": "https://someplace.org"})
516
- assert response.status_code == 200
517
- assert response.headers["vary"] == "Accept-Encoding"
518
-
519
-
520
def test_cors_vary_header_is_properly_set_for_credentialed_request(
    test_client_factory: TestClientFactory,
) -> None:
    """A credentialed request under a wildcard origin should add ``Origin`` to Vary.

    The request carries a ``Cookie``. The test expects ``Origin`` to be
    appended after the endpoint's own ``Accept-Encoding`` entry.
    """

    def homepage(request: Request) -> PlainTextResponse:
        extra_headers = {"Vary": "Accept-Encoding"}
        return PlainTextResponse("Homepage", status_code=200, headers=extra_headers)

    cors = Middleware(CORSMiddleware, allow_origins=["*"])
    app = Starlette(routes=[Route("/", endpoint=homepage)], middleware=[cors])
    test_client = test_client_factory(app)

    credentialed = {"Cookie": "foo=bar", "Origin": "https://someplace.org"}
    resp = test_client.get("/", headers=credentialed)

    assert resp.status_code == 200
    assert resp.headers["vary"] == "Accept-Encoding, Origin"
539
-
540
-
541
def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(
    test_client_factory: TestClientFactory,
) -> None:
    """An explicit (non-wildcard) allow-list should add ``Origin`` to Vary.

    No credentials are sent. The test expects ``Origin`` to be appended after
    the endpoint's own ``Accept-Encoding`` entry.
    """

    def homepage(request: Request) -> PlainTextResponse:
        extra_headers = {"Vary": "Accept-Encoding"}
        return PlainTextResponse("Homepage", status_code=200, headers=extra_headers)

    allowed = ["https://example.org"]
    app = Starlette(
        routes=[Route("/", endpoint=homepage)],
        middleware=[Middleware(CORSMiddleware, allow_origins=allowed)],
    )
    test_client = test_client_factory(app)

    resp = test_client.get("/", headers={"Origin": "https://example.org"})

    assert resp.status_code == 200
    assert resp.headers["vary"] == "Accept-Encoding, Origin"
560
-
561
-
562
def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(
    test_client_factory: TestClientFactory,
) -> None:
    """A credentialed request must not change what later plain requests see.

    Three requests are sent in order: plain, with a cookie, then plain again.
    The plain requests expect the wildcard ``*``. The cookie-bearing request
    expects its own origin echoed back. None of them should carry
    ``access-control-allow-credentials``, since ``allow_credentials`` is
    never enabled here.
    """

    def homepage(request: Request) -> PlainTextResponse:
        return PlainTextResponse("Homepage", status_code=200)

    cors = Middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_headers=["*"],
        allow_methods=["*"],
    )
    app = Starlette(routes=[Route("/", endpoint=homepage)], middleware=[cors])
    test_client = test_client_factory(app)

    origin = "https://someplace.org"
    plain = {"Origin": origin}
    with_cookie = {"Cookie": "foo=bar", "Origin": origin}

    # Order matters: the final plain request checks nothing leaked from the
    # credentialed one before it.
    expectations = [(plain, "*"), (with_cookie, origin), (plain, "*")]
    for request_headers, expected_allow_origin in expectations:
        resp = test_client.get("/", headers=request_headers)
        assert resp.headers["access-control-allow-origin"] == expected_allow_origin
        assert "access-control-allow-credentials" not in resp.headers
596
-
597
-
598
def test_cors_x_nucliadb_cors_allowed_origins(
    test_client_factory: TestClientFactory,
) -> None:
    """An endpoint-supplied allow-list header should govern CORS for that response.

    The endpoint returns ``x-nucliadb-cors-allowed-origins`` listing two
    origins. The test expects the statically configured origin to be rejected
    and each listed origin to be echoed back individually.
    """

    def homepage(request: Request) -> PlainTextResponse:
        extra_headers = {
            "Vary": "Accept-Encoding",
            "x-nucliadb-cors-allowed-origins": "https://example-a.org,https://example-b.org",
        }
        return PlainTextResponse("Homepage", status_code=200, headers=extra_headers)

    cors = Middleware(CORSMiddleware, allow_origins=["https://example.org"])
    app = Starlette(routes=[Route("/", endpoint=homepage)], middleware=[cors])
    test_client = test_client_factory(app)

    # The statically configured origin is not in the endpoint's list, so the
    # test expects no allow-origin header and no Origin added to Vary.
    resp = test_client.get("/", headers={"Origin": "https://example.org"})
    assert resp.status_code == 200
    assert resp.headers["vary"] == "Accept-Encoding"
    assert "access-control-allow-origin" not in resp.headers

    # Each origin named in the endpoint's header is accepted and echoed back.
    for origin in ("https://example-a.org", "https://example-b.org"):
        resp = test_client.get("/", headers={"Origin": origin})
        assert resp.status_code == 200
        assert resp.headers["vary"] == "Accept-Encoding, Origin"
        assert resp.headers["access-control-allow-origin"] == origin