python-arango-async 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,139 @@
1
# Public names exported by a star-import of this module.
__all__ = [
    "AcceptEncoding", "ContentEncoding",
    "CompressionManager", "DefaultCompressionManager",
]
7
+
8
+ import zlib
9
+ from abc import ABC, abstractmethod
10
+ from enum import Enum, auto
11
+ from typing import Optional
12
+
13
+
14
class AcceptEncoding(Enum):
    """Encodings a client may advertise in the Accept-Encoding header.

    Values are spelled out explicitly; they are the same ones ``auto()``
    would assign (consecutive integers starting at 1).
    """

    DEFLATE = 1
    GZIP = 2
    IDENTITY = 3
20
+
21
+
22
class ContentEncoding(Enum):
    """Encodings that may appear in the Content-Encoding header of a request.

    Values are spelled out explicitly; they are the same ones ``auto()``
    would assign (consecutive integers starting at 1).
    """

    DEFLATE = 1
    GZIP = 2
27
+
28
+
29
+ class CompressionManager(ABC): # pragma: no cover
30
+ """Abstract base class for handling request/response compression."""
31
+
32
+ @abstractmethod
33
+ def needs_compression(self, data: str | bytes) -> bool:
34
+ """Determine if the data needs to be compressed
35
+
36
+ Args:
37
+ data (str | bytes): Data to check
38
+
39
+ Returns:
40
+ bool: True if the data needs to be compressed
41
+ """
42
+ raise NotImplementedError
43
+
44
+ @abstractmethod
45
+ def compress(self, data: str | bytes) -> bytes:
46
+ """Compress the data
47
+
48
+ Args:
49
+ data (str | bytes): Data to compress
50
+
51
+ Returns:
52
+ bytes: Compressed data
53
+ """
54
+ raise NotImplementedError
55
+
56
+ @property
57
+ @abstractmethod
58
+ def content_encoding(self) -> str:
59
+ """Return the content encoding.
60
+
61
+ This is the value of the Content-Encoding header in the HTTP request.
62
+ Must match the encoding used in the compress method.
63
+
64
+ Returns:
65
+ str: Content encoding
66
+ """
67
+ raise NotImplementedError
68
+
69
+ @property
70
+ @abstractmethod
71
+ def accept_encoding(self) -> str | None:
72
+ """Return the accept encoding.
73
+
74
+ This is the value of the Accept-Encoding header in the HTTP request.
75
+ Currently, only "deflate" and "gzip" are supported.
76
+
77
+ Returns:
78
+ str: Accept encoding
79
+ """
80
+ raise NotImplementedError
81
+
82
+
83
class DefaultCompressionManager(CompressionManager):
    """Compress requests using the deflate algorithm.

    Args:
        threshold (int): Compress requests to the server if the size of the
            request body (in bytes, UTF-8 encoded for text) is at least this
            value. Setting it to -1 (or any negative value) disables request
            compression. Defaults to 1024.
        level (int): zlib compression level passed to ``zlib.compress``.
            Defaults to 6.
        accept (AcceptEncoding | None): Encoding advertised in the
            Accept-Encoding header. Set it to `None` to omit the header.
            Defaults to ``AcceptEncoding.DEFLATE``.
    """

    def __init__(
        self,
        threshold: int = 1024,
        level: int = 6,
        accept: Optional[AcceptEncoding] = AcceptEncoding.DEFLATE,
    ) -> None:
        self._threshold = threshold
        self._level = level
        # compress() always uses zlib, so the content encoding is fixed.
        self._content_encoding = ContentEncoding.DEFLATE.name.lower()
        self._accept_encoding = accept.name.lower() if accept is not None else None

    @property
    def threshold(self) -> int:
        """Minimum body size in bytes that triggers compression (< 0 disables)."""
        return self._threshold

    @threshold.setter
    def threshold(self, value: int) -> None:
        self._threshold = value

    @property
    def level(self) -> int:
        """zlib compression level used by :meth:`compress`."""
        return self._level

    @level.setter
    def level(self, value: int) -> None:
        self._level = value

    @property
    def accept_encoding(self) -> Optional[str]:
        """Lower-case Accept-Encoding header value, or None if disabled."""
        return self._accept_encoding

    @accept_encoding.setter
    def accept_encoding(self, value: Optional[AcceptEncoding]) -> None:
        self._accept_encoding = value.name.lower() if value is not None else None

    @property
    def content_encoding(self) -> str:
        """Content-Encoding header value matching :meth:`compress` ("deflate")."""
        return self._content_encoding

    def needs_compression(self, data: str | bytes) -> bool:
        """Return True if ``data`` is at least ``threshold`` bytes long.

        A negative threshold disables compression entirely, as documented on
        the class. Text is measured by its UTF-8 encoded size so the
        comparison is in bytes, not characters.
        """
        if self._threshold < 0:
            return False
        size = len(data.encode("utf-8")) if isinstance(data, str) else len(data)
        return size >= self._threshold

    def compress(self, data: str | bytes) -> bytes:
        """Deflate-compress ``data``, UTF-8 encoding it first if it is text."""
        payload = data.encode("utf-8") if isinstance(data, str) else data
        return zlib.compress(payload, self._level)
@@ -0,0 +1,515 @@
1
# Public names exported by a star-import of this module.
__all__ = [
    "BaseConnection", "BasicConnection", "Connection",
    "JwtConnection", "JwtSuperuserConnection",
]
+ ]
8
+
9
+ from abc import ABC, abstractmethod
10
+ from typing import Any, List, Optional
11
+
12
+ from jwt import ExpiredSignatureError
13
+
14
+ from arangoasync.auth import Auth, JwtToken
15
+ from arangoasync.compression import CompressionManager
16
+ from arangoasync.errno import HTTP_UNAUTHORIZED
17
+ from arangoasync.exceptions import (
18
+ AuthHeaderError,
19
+ ClientConnectionAbortedError,
20
+ ClientConnectionError,
21
+ DeserializationError,
22
+ JWTRefreshError,
23
+ SerializationError,
24
+ ServerConnectionError,
25
+ )
26
+ from arangoasync.http import HTTPClient
27
+ from arangoasync.logger import logger
28
+ from arangoasync.request import Method, Request
29
+ from arangoasync.resolver import HostResolver
30
+ from arangoasync.response import Response
31
+ from arangoasync.serialization import (
32
+ DefaultDeserializer,
33
+ DefaultSerializer,
34
+ Deserializer,
35
+ Serializer,
36
+ )
37
+ from arangoasync.typings import Json, Jsons
38
+
39
+
40
class BaseConnection(ABC):
    """Blueprint for connection to a specific ArangoDB database.

    Args:
        sessions (list): List of client sessions, indexed by the host index
            returned by the host resolver.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        serializer (Serializer | None): For overriding the default JSON serialization.
            Leave `None` for default.
        deserializer (Deserializer | None): For overriding the default JSON
            deserialization. Leave `None` for default.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        serializer: Optional[Serializer[Json]] = None,
        deserializer: Optional[Deserializer[Json, Jsons]] = None,
    ) -> None:
        self._sessions = sessions
        # Prefix prepended to every request endpoint by process_request().
        self._db_endpoint = f"/_db/{db_name}"
        self._host_resolver = host_resolver
        self._http_client = http_client
        self._db_name = db_name
        self._compression = compression
        self._serializer: Serializer[Json] = serializer or DefaultSerializer()
        self._deserializer: Deserializer[Json, Jsons] = (
            deserializer or DefaultDeserializer()
        )

    @property
    def db_name(self) -> str:
        """Return the database name."""
        return self._db_name

    @property
    def serializer(self) -> Serializer[Json]:
        """Return the serializer."""
        return self._serializer

    @property
    def deserializer(self) -> Deserializer[Json, Jsons]:
        """Return the deserializer."""
        return self._deserializer

    @staticmethod
    def raise_for_status(request: Request, resp: Response) -> None:
        """Raise an exception based on the response.

        Args:
            request (Request): Request object.
            resp (Response): Response object.

        Raises:
            ServerConnectionError: If the response status code is not successful.
        """
        if resp.status_code in {401, 403}:
            raise ServerConnectionError(resp, request, "Authentication failed.")
        if not resp.is_success:
            raise ServerConnectionError(resp, request, "Bad server response.")

    def prep_response(self, request: Request, resp: Response) -> Response:
        """Prepare response for return.

        Sets ``resp.is_success`` from the status code and, for unsuccessful
        responses whose body decodes to an error document, copies
        ``errorNum``/``errorMessage`` onto the response.

        Args:
            request (Request): Request object.
            resp (Response): Response object.

        Returns:
            Response: Response object
        """
        resp.is_success = 200 <= resp.status_code < 300
        if not resp.is_success:
            try:
                body = self._deserializer.loads(resp.raw_body)
            except DeserializationError as e:
                # Best effort: an undecodable error body is not fatal here.
                logger.debug(
                    f"Failed to decode response body: {e} (from request {request})"
                )
            else:
                if body.get("error") is True:
                    resp.error_code = body.get("errorNum")
                    resp.error_message = body.get("errorMessage")
        return resp

    def compress_request(self, request: Request) -> bool:
        """Compress request if needed.

        Additionally, the server may be instructed to compress the response.
        The decision to compress the request is based on the compression strategy
        passed during the connection initialization.
        The request data and headers may be modified as a result of this operation.

        Args:
            request (Request): Request to be compressed.

        Returns:
            bool: True if compression settings were applied.
        """
        if self._compression is None:
            return False

        result: bool = False
        if request.data is not None and self._compression.needs_compression(
            request.data
        ):
            request.data = self._compression.compress(request.data)
            request.headers["content-encoding"] = self._compression.content_encoding
            result = True

        accept_encoding: str | None = self._compression.accept_encoding
        if accept_encoding is not None:
            request.headers["accept-encoding"] = accept_encoding
            result = True

        return result

    async def process_request(self, request: Request) -> Response:
        """Process request, potentially trying multiple hosts.

        Note that ``request.endpoint`` is prefixed in place with the database
        path (``/_db/<db_name>``). Passing the same request object again
        without restoring its endpoint would prefix it twice.

        Args:
            request (Request): Request object.

        Returns:
            Response: Response object.

        Raises:
            ClientConnectionAbortedError: If it can't connect to host(s) within
                limit. The last connection error is chained as the cause.
        """
        request.endpoint = f"{self._db_endpoint}{request.endpoint}"
        host_index = self._host_resolver.get_host_index()
        last_error: Optional[ClientConnectionError] = None
        for _ in range(self._host_resolver.max_tries):
            try:
                resp = await self._http_client.send_request(
                    self._sessions[host_index], request
                )
                return self.prep_response(request, resp)
            except ClientConnectionError as e:
                last_error = e
                logger.debug(f"Failed to connect to host {host_index}: {e}")
                ex_host_index = host_index
                host_index = self._host_resolver.get_host_index()
                if ex_host_index == host_index:
                    # Force change host if the same host is selected
                    self._host_resolver.change_host()
                    host_index = self._host_resolver.get_host_index()

        raise ClientConnectionAbortedError(
            f"Can't connect to host(s) within limit ({self._host_resolver.max_tries})"
        ) from last_error

    async def ping(self) -> int:
        """Ping host to check if connection is established.

        Returns:
            int: Response status code.

        Raises:
            ServerConnectionError: If the response status code is not successful.
        """
        request = Request(method=Method.GET, endpoint="/_api/collection")
        resp = await self.send_request(request)
        self.raise_for_status(request, resp)
        return resp.status_code

    @abstractmethod
    async def send_request(self, request: Request) -> Response:  # pragma: no cover
        """Send an HTTP request to the ArangoDB server.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response.
        """
        raise NotImplementedError
221
+
222
+
223
class BasicConnection(BaseConnection):
    """Connection to a specific ArangoDB database using basic authentication.

    When authentication information (username and password) is supplied,
    it is attached to every outgoing request.

    Args:
        sessions (list): List of client sessions.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        serializer (Serializer | None): Override default JSON serialization.
        deserializer (Deserializer | None): Override default JSON deserialization.
        auth (Auth | None): Authentication information.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        serializer: Optional[Serializer[Json]] = None,
        deserializer: Optional[Deserializer[Json, Jsons]] = None,
        auth: Optional[Auth] = None,
    ) -> None:
        super().__init__(
            sessions=sessions,
            host_resolver=host_resolver,
            http_client=http_client,
            db_name=db_name,
            compression=compression,
            serializer=serializer,
            deserializer=deserializer,
        )
        self._auth = auth

    async def send_request(self, request: Request) -> Response:
        """Compress, authenticate and dispatch a request to the ArangoDB server.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response

        Raises:
            ArangoClientError: If an error occurred from the client side.
            ArangoServerError: If an error occurred from the server side.
        """
        credentials = self._auth
        self.compress_request(request)
        if credentials:
            request.auth = credentials
        response = await self.process_request(request)
        return response
280
+
281
+
282
class JwtConnection(BaseConnection):
    """Connection to a specific ArangoDB database, using JWT authentication.

    Providing login information (username and password), allows to refresh the JWT.

    Args:
        sessions (list): List of client sessions.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        serializer (Serializer | None): For custom serialization.
        deserializer (Deserializer | None): For custom deserialization.
        auth (Auth | None): Authentication information.
        token (JwtToken | None): JWT token.

    Raises:
        ValueError: If neither token nor auth is provided.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        serializer: Optional[Serializer[Json]] = None,
        deserializer: Optional[Deserializer[Json, Jsons]] = None,
        auth: Optional[Auth] = None,
        token: Optional[JwtToken] = None,
    ) -> None:
        super().__init__(
            sessions,
            host_resolver,
            http_client,
            db_name,
            compression,
            serializer,
            deserializer,
        )
        self._auth = auth
        # Leeway passed to JwtToken.needs_refresh() when deciding to retry.
        self._expire_leeway: int = 0
        self._token: Optional[JwtToken] = token
        self._auth_header: Optional[str] = None
        # Go through the setter so the authorization header is derived.
        self.token = self._token

        if self._token is None and self._auth is None:
            raise ValueError("Either token or auth must be provided.")

    @property
    def token(self) -> Optional[JwtToken]:
        """Get the JWT token.

        Returns:
            JwtToken | None: JWT token.
        """
        return self._token

    @token.setter
    def token(self, token: Optional[JwtToken]) -> None:
        """Set the JWT token.

        Args:
            token (JwtToken | None): JWT token.
                Setting it to None will cause the token to be automatically
                refreshed on the next request, if auth information is provided.
        """
        self._token = token
        self._auth_header = f"bearer {self._token.token}" if self._token else None

    async def refresh_token(self) -> None:
        """Refresh the JWT token by logging in via ``/_open/auth``.

        Raises:
            JWTRefreshError: If the token can't be refreshed, including when
                the server response cannot be decoded or lacks a ``jwt`` field.
        """
        if self._auth is None:
            raise JWTRefreshError("Auth must be provided to refresh the token.")

        auth_data = dict(username=self._auth.username, password=self._auth.password)
        try:
            auth = self._serializer.dumps(auth_data)
        except SerializationError as e:
            # Never log auth_data itself: it contains the plaintext password.
            logger.debug(
                f"Failed to serialize auth data for user {self._auth.username!r}"
            )
            raise JWTRefreshError(str(e)) from e

        request = Request(
            method=Method.POST,
            endpoint="/_open/auth",
            data=auth.encode("utf-8"),
        )

        try:
            resp = await self.process_request(request)
        except (ClientConnectionAbortedError, ServerConnectionError) as e:
            raise JWTRefreshError(str(e)) from e

        if not resp.is_success:
            raise JWTRefreshError(
                f"Failed to refresh the JWT token: "
                f"{resp.status_code} {resp.status_text}"
            )

        try:
            body = self._deserializer.loads(resp.raw_body)
            raw_token = body["jwt"]
        except (DeserializationError, KeyError, TypeError) as e:
            raise JWTRefreshError(
                "Failed to refresh the JWT token: malformed server response"
            ) from e

        try:
            self.token = JwtToken(raw_token)
        except ExpiredSignatureError as e:
            raise JWTRefreshError(
                "Failed to refresh the JWT token: got an expired token"
            ) from e

    async def send_request(self, request: Request) -> Response:
        """Send an HTTP request to the ArangoDB server.

        If the server answers 401 and the current token needs refreshing,
        the token is refreshed once and the request is retried.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response

        Raises:
            AuthHeaderError: If the authentication header could not be generated.
            ArangoClientError: If an error occurred from the client side.
            ArangoServerError: If an error occurred from the server side.
        """
        if self._auth_header is None:
            await self.refresh_token()

        if self._auth_header is None:
            raise AuthHeaderError("Failed to generate authorization header.")

        request.headers["authorization"] = self._auth_header
        self.compress_request(request)

        # process_request() prefixes the endpoint in place; keep the original
        # so a retry does not end up with a doubled /_db/<name> prefix.
        endpoint = request.endpoint
        resp = await self.process_request(request)
        if (
            resp.status_code == HTTP_UNAUTHORIZED
            and self._token is not None
            and self._token.needs_refresh(self._expire_leeway)
        ):
            # If the token has expired, refresh it and retry the request
            # with the original endpoint and the new authorization header.
            await self.refresh_token()
            request.endpoint = endpoint
            if self._auth_header is not None:
                request.headers["authorization"] = self._auth_header
            resp = await self.process_request(request)
        return resp
429
+
430
+
431
class JwtSuperuserConnection(BaseConnection):
    """Connection to a specific ArangoDB database, using superuser JWT.

    The JWT token is not refreshed and (username and password) are not required.
    If no token is set, every request fails with ``AuthHeaderError``.

    Args:
        sessions (list): List of client sessions.
        host_resolver (HostResolver): Host resolver.
        http_client (HTTPClient): HTTP client.
        db_name (str): Database name.
        compression (CompressionManager | None): Compression manager.
        serializer (Serializer | None): For custom serialization.
        deserializer (Deserializer | None): For custom deserialization.
        token (JwtToken | None): JWT token.
    """

    def __init__(
        self,
        sessions: List[Any],
        host_resolver: HostResolver,
        http_client: HTTPClient,
        db_name: str,
        compression: Optional[CompressionManager] = None,
        serializer: Optional[Serializer[Json]] = None,
        deserializer: Optional[Deserializer[Json, Jsons]] = None,
        token: Optional[JwtToken] = None,
    ) -> None:
        super().__init__(
            sessions,
            host_resolver,
            http_client,
            db_name,
            compression,
            serializer,
            deserializer,
        )
        self._token: Optional[JwtToken] = token
        self._auth_header: Optional[str] = None
        # Go through the setter so the authorization header is derived.
        self.token = self._token

    @property
    def token(self) -> Optional[JwtToken]:
        """Get the JWT token.

        Returns:
            JwtToken | None: JWT token.
        """
        return self._token

    @token.setter
    def token(self, token: Optional[JwtToken]) -> None:
        """Set the JWT token.

        Args:
            token (JwtToken | None): JWT token.
                Superuser tokens are never refreshed automatically: setting
                it to None makes subsequent requests raise ``AuthHeaderError``
                until a new token is assigned.
        """
        self._token = token
        self._auth_header = f"bearer {self._token.token}" if self._token else None

    async def send_request(self, request: Request) -> Response:
        """Send an HTTP request to the ArangoDB server.

        Args:
            request (Request): HTTP request.

        Returns:
            Response: HTTP response

        Raises:
            AuthHeaderError: If no token is set, so no authorization header
                can be generated.
            ArangoClientError: If an error occurred from the client side.
            ArangoServerError: If an error occurred from the server side.
        """
        if self._auth_header is None:
            raise AuthHeaderError("Failed to generate authorization header.")
        request.headers["authorization"] = self._auth_header
        self.compress_request(request)

        return await self.process_request(request)
513
+
514
+
515
# Type alias covering every concrete connection class in this module.
Connection = (
    BasicConnection
    | JwtConnection
    | JwtSuperuserConnection
)