cafs-cache-cdn-client 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,535 @@
1
+ import asyncio
2
+ from asyncio import Queue, QueueShutDown
3
+ from collections.abc import (
4
+ AsyncIterator,
5
+ Callable,
6
+ Collection,
7
+ Coroutine,
8
+ MutableMapping,
9
+ )
10
+ from contextlib import asynccontextmanager
11
+ from enum import Enum
12
+ from logging import LoggerAdapter, getLogger
13
+ from pathlib import Path
14
+ from typing import Any, Self, TypeVar
15
+
16
+ import aiofiles
17
+ import aiofiles.os as aio_os
18
+
19
+ from .blob.hash_ import calc_hash_file
20
+ from .blob.package import CompressionT, Packer, Unpacker
21
+ from .blob.utils import choose_compression
22
+ from .exceptions import (
23
+ BlobNotFoundError,
24
+ CAFSClientError,
25
+ EmptyConnectionPoolError,
26
+ UnexpectedResponseError,
27
+ )
28
+ from .types import AsyncReader, AsyncWriter
29
+
30
#: Public API of this module.
__all__ = ('CAFSClient',)

#: Default number of seconds allowed for establishing a connection and, on
#: each connection, for every individual receive.
DEFAULT_CONNECT_TIMEOUT: float = 5.0

#: Module-level logger; each connection wraps it in a ConnectionLoggerAdapter.
module_logger = getLogger(name=__name__)
36
+
37
+
38
class CommandT(bytes, Enum):
    """Four-byte command codes the client sends to the CAFS server."""

    VERSION = b'VERS'
    CHECK = b'CHCK'
    STREAM = b'STRM'
    SIZE = b'SIZE'
    PULL = b'PULL'

    def __str__(self) -> str:
        # Show the raw code (e.g. 'CHCK') rather than 'CommandT.CHECK' in logs.
        return str(self.value, encoding='utf-8')
47
+
48
+
49
class ResponseT(bytes, Enum):
    """Four-byte response codes the client expects from the CAFS server.

    Several names share a value. Enum makes every later name an alias of the
    first member defined with that value, so e.g. ``CHECK_NOT_FOUND_RESPONSE``
    is the very same object as ``VERSION_OK_RESPONSE``.
    """

    # b'NONE': canonical member, aliased by all the "not found" names below.
    VERSION_OK_RESPONSE = b'NONE'
    # b'HAVE': canonical member, aliased by STREAM_OK_RESPONSE.
    CHECK_FOUND_RESPONSE = b'HAVE'
    CHECK_NOT_FOUND_RESPONSE = b'NONE'
    STREAM_OK_RESPONSE = b'HAVE'
    SIZE_NOT_FOUND_RESPONSE = b'NONE'
    SIZE_OK_RESPONSE = b'SIZE'
    PULL_FOUND_RESPONSE = b'TAKE'
    PULL_NOT_FOUND_RESPONSE = b'NONE'
58
+
59
+
60
# Every server reply used by the client starts with a 4-byte response code.
RESPONSE_LENGTH = 4
# Number of greeting bytes read from the server before the version exchange.
HELLO_HEADER_LENGTH = 20

# Version string appended to the VERS command during the handshake.
CLIENT_VERSION = b'001'
# Hash algorithm name sent in front of every blob digest.
HASH_TYPE = b'blake3'
CAFS_DEFAULT_PORT = 2403
# Upload frames carry a 2-byte little-endian length, so a frame must fit in
# 16 bits; 0xFFFE == 65534.
STREAM_MAX_CHUNK_SIZE = 0xFFFE
67
+
68
+
69
class ConnectionLoggerAdapter(LoggerAdapter):
    """Logger adapter that prefixes every message with ``[host:port:connection_id]``.

    ``extra`` must provide the keys ``host``, ``port`` and ``connection_id``.
    """

    def process(
        self, msg: Any, kwargs: MutableMapping[str, Any]
    ) -> tuple[str, MutableMapping[str, Any]]:
        extra = self.extra
        prefix = '[{}:{}:{}]'.format(
            extra['host'], extra['port'], extra['connection_id']  # type: ignore[index]
        )
        # kwargs are passed through untouched.
        return f'{prefix} {msg}', kwargs
77
+
78
+
79
class CAFSConnection:
    """One TCP connection to a CAFS server.

    The socket is opened lazily: every public operation calls :meth:`connect`
    first when ``is_connected`` is false. Each operation is a request/response
    exchange on a single stream, so presumably only one operation should run
    on a connection at a time (the pool hands a connection to one borrower).
    """

    host: str
    port: int
    timeout: float
    server_root: bytes
    is_connected: bool
    logger: ConnectionLoggerAdapter

    _reader: asyncio.StreamReader | None = None
    _writer: asyncio.StreamWriter | None = None

    def __init__(
        self,
        server_root: str,
        host: str,
        port: int = CAFS_DEFAULT_PORT,
        timeout: float = DEFAULT_CONNECT_TIMEOUT,
    ) -> None:
        """Prepare (but do not open) a connection.

        Args:
            server_root: root name sent during the handshake (UTF-8 encoded;
                at most 255 bytes once encoded).
            host: server host name or address.
            port: server TCP port.
            timeout: seconds allowed for connecting and for each receive.
        """
        self.server_root = server_root.encode('utf-8')
        self.host = host
        self.port = port
        self.timeout = timeout
        self.is_connected = False
        self.logger = ConnectionLoggerAdapter(
            module_logger, {'host': host, 'port': port, 'connection_id': id(self)}
        )

    async def connect(self) -> None:
        """Open the TCP connection and run the handshake; no-op if already connected.

        On any failure the half-open socket is closed before the error propagates.

        Raises:
            CAFSClientError: if connecting or a handshake read times out.
            OSError: if the TCP connection cannot be established.
            UnexpectedResponseError: if the server rejects the client version.
        """
        if self.is_connected:
            return

        self.logger.debug('Connecting')
        try:
            self._reader, self._writer = await asyncio.wait_for(
                asyncio.open_connection(self.host, self.port), timeout=self.timeout
            )
            await self._auth()
        except asyncio.TimeoutError as err:
            # Checked before OSError: TimeoutError is an OSError subclass on 3.11+.
            self.logger.error('Connection timed out')
            await self._abort()
            raise CAFSClientError('Connection timed out') from err
        except OSError as err:  # includes ConnectionRefusedError
            self.logger.error('Failed to connect: %s', str(err))
            await self._abort()
            raise
        except BaseException:
            # Rejected handshake, premature EOF, cancellation, ...: don't leak the socket.
            await self._abort()
            raise

        self.is_connected = True
        self.logger.debug('Connected')

    async def _abort(self) -> None:
        """Tear down a socket left half-open by a failed :meth:`connect`."""
        if self._writer is not None:
            await self.disconnect()

    async def disconnect(self) -> None:
        """Close the socket (if open) and reset the connection state.

        Errors while waiting for the close to complete are logged, not raised.
        """
        if self._writer and not self._writer.is_closing():
            self._writer.close()
            try:
                await self._writer.wait_closed()
            except Exception as e:  # pylint: disable=broad-exception-caught
                self.logger.error('Failed to close connection: %s', e)

        self.is_connected = False
        self._reader = None
        self._writer = None
        self.logger.debug('Disconnected')

    async def _send(self, data: bytes) -> None:
        """Write ``data`` and wait for the transport buffer to drain."""
        if not self._writer:
            raise CAFSClientError('Connection is not established')
        self._writer.write(data)
        await self._writer.drain()

    async def _receive(self, n: int = -1) -> bytes:
        """Read exactly ``n`` bytes within ``self.timeout`` seconds.

        Every caller in this class passes an explicit, non-negative size.

        Raises:
            asyncio.IncompleteReadError: if the peer closes before ``n`` bytes arrive.
            asyncio.TimeoutError: if the bytes do not arrive in time.
        """
        if not self._reader:
            raise CAFSClientError('Connection is not established')
        try:
            return await asyncio.wait_for(
                self._reader.readexactly(n), timeout=self.timeout
            )
        except (asyncio.IncompleteReadError, asyncio.TimeoutError) as err:
            self.logger.error('Failed to receive data: %s', err)
            raise

    async def _auth(self) -> None:
        """Run the handshake.

        The handshake reads the hello header, sends the client version,
        checks the reply, then sends the length-prefixed server root.

        Raises:
            CAFSClientError: if the encoded server root does not fit the
                1-byte length prefix.
            UnexpectedResponseError: if the version is rejected.
        """
        root_length = len(self.server_root)
        if root_length > 0xFF:
            # The root is framed with a single length byte.
            raise CAFSClientError(
                f'Server root is too long: {root_length} bytes (max 255)'
            )

        self.logger.debug('Authenticating')
        await self._receive(HELLO_HEADER_LENGTH)
        self.logger.debug('Got hello header')

        self.logger.debug('Sending version')
        await self._send(CommandT.VERSION + CLIENT_VERSION)

        response = await self._receive(RESPONSE_LENGTH)

        if response != ResponseT.VERSION_OK_RESPONSE:
            self.logger.error('Authentication failed, received: %s', response)
            raise UnexpectedResponseError(response)

        await self._send(bytes([root_length]) + self.server_root)
        self.logger.debug('Authenticated')

    async def check(self, blob_hash: str) -> bool:
        """Return whether the server has the blob with hex digest ``blob_hash``.

        Raises:
            UnexpectedResponseError: on any reply other than HAVE / NONE.
        """
        if not self.is_connected:
            await self.connect()

        self.logger.debug('Checking for blob: %s', blob_hash)
        blob_hash_bytes = bytes.fromhex(blob_hash)
        await self._send(CommandT.CHECK + HASH_TYPE + blob_hash_bytes)
        self.logger.debug('Sent %s command', CommandT.CHECK)

        response = await self._receive(RESPONSE_LENGTH)
        self.logger.debug('Got response: %s', response)
        if response == ResponseT.CHECK_FOUND_RESPONSE:
            return True
        if response == ResponseT.CHECK_NOT_FOUND_RESPONSE:
            return False

        self.logger.error('Received unexpected response: %s', response)
        raise UnexpectedResponseError(response)

    async def size(self, blob_hash: str) -> int:
        """Return the blob size reported by the server.

        The reply is a 4-byte code followed by an 8-byte little-endian integer.

        Raises:
            BlobNotFoundError: if the server replies NONE.
            UnexpectedResponseError: on any other unexpected reply.
        """
        if not self.is_connected:
            await self.connect()

        self.logger.debug('Getting size for blob: %s', blob_hash)
        blob_hash_bytes = bytes.fromhex(blob_hash)
        await self._send(CommandT.SIZE + HASH_TYPE + blob_hash_bytes)
        self.logger.debug('Sent %s command', CommandT.SIZE)

        response = await self._receive(RESPONSE_LENGTH + 8)
        self.logger.debug('Got response: %s', response)
        if response[:RESPONSE_LENGTH] == ResponseT.SIZE_NOT_FOUND_RESPONSE:
            self.logger.error('Blob not found: %s', blob_hash)
            raise BlobNotFoundError(blob_hash)
        if response[:RESPONSE_LENGTH] != ResponseT.SIZE_OK_RESPONSE:
            self.logger.error('Received unexpected response: %s', response)
            raise UnexpectedResponseError(response)
        return int.from_bytes(response[RESPONSE_LENGTH:], 'little')

    async def stream(self, blob_hash: str, reader: 'AsyncReader') -> None:
        """Upload a blob read from ``reader``.

        The wire format is the STRM command, then frames of a 2-byte
        little-endian length plus payload, then a zero-length terminator frame.

        Raises:
            UnexpectedResponseError: if the server does not confirm with HAVE.
        """
        if not self.is_connected:
            await self.connect()

        self.logger.debug('Streaming blob: %s', blob_hash)
        blob_hash_bytes = bytes.fromhex(blob_hash)
        await self._send(CommandT.STREAM + HASH_TYPE + blob_hash_bytes)

        while chunk := await reader.read(STREAM_MAX_CHUNK_SIZE):
            # Nothing guarantees the reader honours the size hint, but a frame
            # length must fit the 2-byte header, so split oversized chunks.
            for start in range(0, len(chunk), STREAM_MAX_CHUNK_SIZE):
                frame = chunk[start : start + STREAM_MAX_CHUNK_SIZE]
                size_header = len(frame).to_bytes(2, 'little')
                self.logger.debug(
                    'Streaming chunk of size: %d (%s)', len(frame), size_header.hex()
                )
                await self._send(size_header + frame)

        self.logger.debug('Ending stream')
        await self._send(b'\x00\x00')

        response = await self._receive(RESPONSE_LENGTH)
        self.logger.debug('Got response: %s', response)
        if response != ResponseT.STREAM_OK_RESPONSE:
            self.logger.error('Received unexpected response: %s', response)
            raise UnexpectedResponseError(response)

    async def pull(self, blob_hash: str, writer: AsyncWriter) -> None:
        """Download a blob into ``writer`` and flush it.

        Raises:
            BlobNotFoundError: if the server replies NONE.
            UnexpectedResponseError: on any reply other than TAKE / NONE.
        """
        if not self.is_connected:
            await self.connect()

        self.logger.debug('Pulling blob: %s', blob_hash)
        blob_hash_bytes = bytes.fromhex(blob_hash)
        await self._send(CommandT.PULL + HASH_TYPE + blob_hash_bytes + b'\x00' * 12)

        response = await self._receive(RESPONSE_LENGTH)
        if response == ResponseT.PULL_NOT_FOUND_RESPONSE:
            self.logger.error('Blob not found: %s', blob_hash)
            raise BlobNotFoundError(blob_hash)

        if response != ResponseT.PULL_FOUND_RESPONSE:
            self.logger.error('Received unexpected response: %s', response)
            raise UnexpectedResponseError(response)

        prefix_length = len(HASH_TYPE) + len(blob_hash_bytes)
        header = await self._receive(prefix_length + 12)
        # The bytes between the hash-sized prefix and the last 4 bytes hold the
        # little-endian blob size.
        blob_size = int.from_bytes(header[prefix_length:-4], 'little')
        self.logger.debug('Blob size: %d', blob_size)
        received = 0

        while received < blob_size:
            chunk_size = min(STREAM_MAX_CHUNK_SIZE, blob_size - received)
            self.logger.debug('Pulling chunk of size: %d', chunk_size)
            chunk = await self._receive(chunk_size)
            received += chunk_size
            self.logger.debug('Received %d bytes', len(chunk))
            await writer.write(chunk)

        await writer.flush()

        self.logger.debug('Pulled %d for blob %s', received, blob_hash)
275
+
276
+
277
class ConnectionPool:
    """Fixed set of :class:`CAFSConnection` objects shared through a FIFO queue.

    Connections are created (not opened) by :meth:`initialize` and open lazily
    on first use. A connection that fails with a transport/protocol error is
    dropped for good. When the last one is dropped, the queue is shut down and
    further borrowers get :class:`EmptyConnectionPoolError`.
    """

    connect_timeout: float
    server_root: str
    servers: set[tuple[str, int]]
    connection_per_server: int

    _lock: asyncio.Lock
    _connections: set[CAFSConnection]
    _connection_queue: Queue[CAFSConnection]

    def __init__(
        self,
        server_root: str,
        servers: Collection[tuple[str, int]],
        connection_per_server: int = 1,
        connect_timeout: float = DEFAULT_CONNECT_TIMEOUT,
    ) -> None:
        """
        Args:
            server_root: root name every connection sends during its handshake.
            servers: ``(host, port)`` pairs; duplicates are collapsed.
            connection_per_server: connections created per distinct server.
            connect_timeout: timeout passed to every connection.
        """
        self.server_root = server_root
        self.connect_timeout = connect_timeout
        self.servers = set(servers)
        self.connection_per_server = connection_per_server

        self._connections = set()
        self._connection_queue = Queue()
        self._lock = asyncio.Lock()

    async def get_connection_count(self) -> int:
        """Number of connections still owned by the pool (borrowed or idle)."""
        async with self._lock:
            return len(self._connections)

    async def initialize(self) -> None:
        """Create ``connection_per_server`` unopened connections per server and queue them."""
        for server in self.servers:
            host, port = server
            for _ in range(self.connection_per_server):
                conn = CAFSConnection(
                    self.server_root, host, port, timeout=self.connect_timeout
                )
                self._connections.add(conn)
                await self._connection_queue.put(conn)

    async def _get_connection(self) -> CAFSConnection:
        """Wait for an idle connection.

        Raises:
            EmptyConnectionPoolError: if the queue has been shut down.
        """
        try:
            return await self._connection_queue.get()
        except QueueShutDown as err:
            raise EmptyConnectionPoolError() from err

    async def _release_connection(self, conn: CAFSConnection) -> None:
        """Return a healthy connection to the queue, or close it if the pool is closed."""
        try:
            await self._connection_queue.put(conn)
        except QueueShutDown:
            # The pool was closed while this connection was borrowed. The
            # borrower's operation succeeded, so just close the socket quietly.
            await conn.disconnect()

    async def _delete_connection(self, conn: CAFSConnection) -> None:
        """Disconnect and forget ``conn``; shut the queue down when none remain."""
        await conn.disconnect()
        async with self._lock:
            self._connections.remove(conn)
            if not self._connections:
                self._connection_queue.shutdown(immediate=True)

    async def close(self) -> None:
        """Shut the queue down and disconnect every open connection.

        Blocked borrowers then fail with EmptyConnectionPoolError.
        """
        async with self._lock:
            self._connection_queue.shutdown(immediate=True)
            for conn in self._connections:
                if conn.is_connected:
                    await conn.disconnect()

    @asynccontextmanager
    async def connection(self) -> AsyncIterator[CAFSConnection]:
        """Borrow a connection for the duration of the ``async with`` body.

        - On normal exit, or on :class:`BlobNotFoundError`, the connection is
          returned to the pool. The not-found case is a complete protocol
          answer read in full by the connection (assuming, as with CHECK, that
          the server sends nothing after NONE).
        - Any other exception, including cancellation, may leave the stream
          mid-exchange, so the connection is disconnected and dropped.

        Raises:
            EmptyConnectionPoolError: if the pool has no connections left.
        """
        conn = await self._get_connection()
        try:
            yield conn
        except BlobNotFoundError:
            await self._release_connection(conn)
            raise
        except BaseException:
            await self._delete_connection(conn)
            raise
        await self._release_connection(conn)

    async def __aenter__(self) -> Self:
        await self.initialize()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self.close()


T = TypeVar('T')


async def _until_pool_empty_wrapper(
    func: Callable[[], Coroutine[Any, Any, T]], retry: bool = True
) -> T:
    """Run ``func``; with ``retry``, re-run it after failures until it succeeds.

    Failures inside :meth:`ConnectionPool.connection` drop the connection that
    was used, so retrying ends with :class:`EmptyConnectionPoolError` once no
    connection can serve the request. Errors raised outside the borrowed
    connection do not shrink the pool, so ``func`` should keep such work
    (e.g. opening local files) out of the retried body.

    :class:`BlobNotFoundError` is a definitive server answer, not a connection
    failure, and is re-raised immediately instead of being retried.
    """
    if not retry:
        return await func()

    while True:
        try:
            return await func()
        except (EmptyConnectionPoolError, BlobNotFoundError):
            raise
        except Exception as err:  # pylint: disable=broad-exception-caught
            module_logger.warning('Operation failed, retrying: %s', err)
374
+
375
+
376
class CAFSClient:
    """High-level asyncio client for a set of CAFS servers.

    Use as ``async with CAFSClient(...) as client:`` so the connection pool is
    initialized on entry and closed on exit.
    """

    _connection_pool: ConnectionPool

    def __init__(
        self,
        server_root: str,
        servers: Collection[str],
        connection_per_server: int = 1,
        connect_timeout: float = DEFAULT_CONNECT_TIMEOUT,
    ) -> None:
        """
        Args:
            server_root: root name sent to every server during the handshake.
            servers: ``host``, ``host:port``, ``[ipv6]`` or ``[ipv6]:port`` strings.
            connection_per_server: connections created per distinct server.
            connect_timeout: seconds allowed for connecting and for each receive.
        """
        servers_ = {self.parse_server_uri(server) for server in servers}
        self._connection_pool = ConnectionPool(
            server_root, servers_, connection_per_server, connect_timeout
        )

    async def pull(self, blob_hash: str, path: Path, retry: bool = True) -> None:
        """Download blob ``blob_hash`` into ``path``, creating parent directories.

        The file is opened once, outside the retried operation, so local I/O
        errors (e.g. permission denied) propagate instead of being retried.
        Every attempt rewinds and truncates the file, so a failed partial
        download never leaves stale bytes behind.

        Errors from the connection and pool propagate, subject to the retry
        policy of ``_until_pool_empty_wrapper``.
        """
        await aio_os.makedirs(path.parent, exist_ok=True)

        async with aiofiles.open(path, 'wb') as file:

            async def _pull() -> None:
                await file.seek(0)
                await file.truncate()
                async with self._connection_pool.connection() as conn:
                    unpacker = Unpacker(file, logger=conn.logger)
                    await conn.pull(blob_hash, unpacker)

            await _until_pool_empty_wrapper(_pull, retry=retry)

    async def pull_batch(
        self,
        blobs: list[tuple[str, Path]],
        retry: bool = True,
        max_concurrent: int | None = None,
    ) -> None:
        """Pull every ``(blob_hash, path)`` pair concurrently.

        Args:
            blobs: blobs to download and their destination paths.
            retry: forwarded to :meth:`pull`.
            max_concurrent: worker count; defaults to the pool's connection
                count. Workers beyond that would only wait for a connection
                while holding an open destination file.

        Raises:
            EmptyConnectionPoolError: if the pool has no connections.
            Exception: the first error raised by any worker.
        """
        if not blobs:
            return

        concurrency = await self._batch_concurrency(max_concurrent, len(blobs))

        async def _pull_one(job: tuple[str, Path]) -> None:
            blob_hash, blob_path = job
            await self.pull(blob_hash, blob_path, retry=retry)

        await self._run_workers(blobs, _pull_one, concurrency)

    async def check(self, blob_hash: str, retry: bool = True) -> bool:
        """Return whether the server has blob ``blob_hash``."""

        async def _check() -> bool:
            async with self._connection_pool.connection() as conn:
                return await conn.check(blob_hash)

        return await _until_pool_empty_wrapper(_check, retry=retry)

    async def size(self, blob_hash: str, retry: bool = True) -> int:
        """Return the size the server reports for blob ``blob_hash``."""

        async def _size() -> int:
            async with self._connection_pool.connection() as conn:
                return await conn.size(blob_hash)

        return await _until_pool_empty_wrapper(_size, retry=retry)

    async def stream(
        self,
        path: Path,
        compression: CompressionT = CompressionT.NONE,
        retry: bool = True,
    ) -> str:
        """Upload the file at ``path`` and return its content hash.

        The file is opened once, outside the retried operation, and rewound
        before each attempt so a retry re-sends it from the start.
        """
        blob_hash: str = await calc_hash_file(path)
        compression = choose_compression(path, preferred_compression=compression)

        async with aiofiles.open(path, 'rb') as file:

            async def _stream() -> str:
                await file.seek(0)
                async with self._connection_pool.connection() as conn:
                    packer = Packer(file, compression=compression, logger=conn.logger)
                    await conn.stream(blob_hash, packer)
                return blob_hash

            return await _until_pool_empty_wrapper(_stream, retry=retry)

    async def stream_batch(
        self,
        paths: list[Path],
        compression: CompressionT = CompressionT.NONE,
        retry: bool = True,
        max_concurrent: int | None = None,
    ) -> list[str]:
        """Upload every path concurrently; return the hashes in input order.

        Raises:
            EmptyConnectionPoolError: if the pool has no connections.
            CAFSClientError: if some upload produced no hash.
            Exception: the first error raised by any worker.
        """
        if not paths:
            return []

        concurrency = await self._batch_concurrency(max_concurrent, len(paths))
        results: list[str | None] = [None] * len(paths)

        async def _stream_one(job: tuple[int, Path]) -> None:
            idx, path = job
            results[idx] = await self.stream(
                path, compression=compression, retry=retry
            )

        await self._run_workers(list(enumerate(paths)), _stream_one, concurrency)

        if any(res is None for res in results):
            raise CAFSClientError(
                'Unexpected error during streaming, some blobs are None'
            )

        return [res for res in results if res is not None]

    async def _batch_concurrency(
        self, max_concurrent: int | None, job_count: int
    ) -> int:
        """Worker count for a batch.

        The count is ``max_concurrent`` (or the pool size when falsy), capped
        at ``job_count``.

        Raises:
            ValueError: if ``max_concurrent`` is negative.
            EmptyConnectionPoolError: if the pool has no connections (not
                initialized, or all dropped), since no worker could progress.
        """
        if max_concurrent is not None and max_concurrent < 0:
            raise ValueError(f'max_concurrent must be >= 0, got {max_concurrent}')
        limit = max_concurrent or await self._connection_pool.get_connection_count()
        concurrency = min(limit, job_count)
        if concurrency < 1:
            raise EmptyConnectionPoolError()
        return concurrency

    @staticmethod
    async def _run_workers(
        jobs: list[Any],
        handle: Callable[[Any], Coroutine[Any, Any, None]],
        concurrency: int,
    ) -> None:
        """Run ``handle(job)`` for every job using ``concurrency`` worker tasks.

        A worker stops at its first error. An EmptyConnectionPoolError also
        stops the other workers from taking new jobs. Once all workers finish,
        the first error (in worker order) is re-raised.
        """
        queue: asyncio.Queue[Any] = asyncio.Queue()
        for job in jobs:
            queue.put_nowait(job)
        stop_event = asyncio.Event()

        async def worker() -> None:
            while not stop_event.is_set():
                try:
                    job = queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
                try:
                    await handle(job)
                except EmptyConnectionPoolError:
                    stop_event.set()
                    raise
                finally:
                    queue.task_done()

        workers = [asyncio.create_task(worker()) for _ in range(concurrency)]
        outcomes = await asyncio.gather(*workers, return_exceptions=True)
        queue.shutdown(immediate=True)

        for outcome in outcomes:
            if isinstance(outcome, Exception):
                raise outcome

    async def __aenter__(self) -> Self:
        await self._connection_pool.__aenter__()
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self._connection_pool.__aexit__(exc_type, exc_val, exc_tb)

    @staticmethod
    def parse_server_uri(uri: str) -> tuple[str, int]:
        """Split ``host[:port]`` into ``(host, port)``; the port defaults to CAFS_DEFAULT_PORT.

        Bracketed IPv6 literals are supported: ``[::1]`` and ``[::1]:2403``
        both yield host ``::1``. Unbracketed input is split on its last ``:``.

        Raises:
            ValueError: if the port is not an integer or the URI is malformed.
        """
        if uri.startswith('['):
            host, closed, rest = uri[1:].partition(']')
            if not closed:
                raise ValueError(f'Unterminated IPv6 literal in server URI: {uri!r}')
            if not rest:
                return host, CAFS_DEFAULT_PORT
            if not rest.startswith(':'):
                raise ValueError(f'Invalid server URI: {uri!r}')
            return host, int(rest[1:])
        if ':' in uri:
            host, port = uri.rsplit(':', maxsplit=1)
            return host, int(port)
        return uri, CAFS_DEFAULT_PORT
@@ -0,0 +1,30 @@
1
__all__ = (
    'CAFSClientError',
    'BlobNotFoundError',
    'UnexpectedResponseError',
    'EmptyConnectionPoolError',
)


class CAFSClientError(Exception):
    """Base class for the CAFS client's own errors."""


class BlobNotFoundError(CAFSClientError):
    """The server reported that it does not have the requested blob."""

    blob_hash: str

    def __init__(self, blob_hash: str) -> None:
        self.blob_hash = blob_hash
        super().__init__(f'Blob not found: {blob_hash}')


class UnexpectedResponseError(CAFSClientError):
    """The server sent a response the client did not expect."""

    response: bytes

    def __init__(self, response: bytes) -> None:
        self.response = response
        # Exception does not %-format its arguments (unlike logging), so the
        # message must be fully built here.
        super().__init__(f'Unexpected response: {response!r}')


class EmptyConnectionPoolError(CAFSClientError):
    """No usable connection is left in the pool."""
@@ -0,0 +1,19 @@
1
+ from typing import Any, Protocol
2
+
3
__all__ = ('AsyncReader', 'AsyncWriter')


class AsyncReader(Protocol):
    """Structural type for objects with an awaitable ``read(size)`` returning bytes."""

    async def read(self, size: int = -1) -> bytes:
        ...


class AsyncWriter(Protocol):
    """Structural type for objects with awaitable ``write(data)`` and ``flush()``."""

    async def write(self, data: bytes, /) -> Any:
        ...

    async def flush(self) -> None:
        ...