asyncstdlib 3.13.1__tar.gz → 3.13.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/PKG-INFO +2 -1
  2. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/__init__.py +1 -1
  3. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/_typing.py +9 -2
  4. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/asynctools.py +2 -3
  5. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/builtins.py +3 -5
  6. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/builtins.pyi +39 -6
  7. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/functools.py +1 -1
  8. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/functools.pyi +8 -0
  9. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/heapq.py +2 -2
  10. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/itertools.py +95 -72
  11. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/itertools.pyi +6 -2
  12. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/pyproject.toml +4 -0
  13. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/test_asynctools.py +0 -1
  14. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/test_functools_lru.py +5 -2
  15. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/test_heapq.py +0 -1
  16. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/test_itertools.py +95 -1
  17. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/utility.py +0 -1
  18. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/LICENSE +0 -0
  19. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/README.rst +0 -0
  20. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/_core.py +0 -0
  21. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/_lrucache.py +0 -0
  22. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/_lrucache.pyi +0 -0
  23. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/_utility.py +0 -0
  24. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/contextlib.py +0 -0
  25. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/contextlib.pyi +0 -0
  26. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/heapq.pyi +0 -0
  27. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/asyncstdlib/py.typed +0 -0
  28. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/__init__.py +0 -0
  29. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/test_builtins.py +0 -0
  30. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/test_contextlib.py +0 -0
  31. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/test_functools.py +0 -0
  32. {asyncstdlib-3.13.1 → asyncstdlib-3.13.3}/unittests/test_helpers.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: asyncstdlib
3
- Version: 3.13.1
3
+ Version: 3.13.3
4
4
  Summary: The missing async toolbox
5
5
  Keywords: async,enumerate,itertools,builtins,functools,contextlib
6
6
  Author-email: Max Kühn <maxfischer2781@gmail.com>
@@ -17,6 +17,7 @@ Classifier: Programming Language :: Python :: 3.10
17
17
  Classifier: Programming Language :: Python :: 3.11
18
18
  Classifier: Programming Language :: Python :: 3.12
19
19
  Classifier: Programming Language :: Python :: 3.13
20
+ Classifier: Programming Language :: Python :: 3.14
20
21
  License-File: LICENSE
21
22
  Requires-Dist: sphinx ; extra == "doc"
22
23
  Requires-Dist: sphinxcontrib-trio ; extra == "doc"
@@ -45,7 +45,7 @@ from .itertools import (
45
45
  from .asynctools import borrow, scoped_iter, await_each, any_iter, apply, sync
46
46
  from .heapq import merge, nlargest, nsmallest
47
47
 
48
- __version__ = "3.13.1"
48
+ __version__ = "3.13.3"
49
49
 
50
50
  __all__ = [
51
51
  "anext",
@@ -55,12 +55,19 @@ AC = TypeVar("AC", bound=Callable[..., Awaitable[Any]])
55
55
  #: Hashable Key
56
56
  HK = TypeVar("HK", bound=Hashable)
57
57
 
58
+
59
+ # bool(...)
60
+ class SupportsBool(Protocol):
61
+ def __bool__(self) -> bool:
62
+ raise NotImplementedError
63
+
64
+
58
65
  # LT < LT
59
66
  LT = TypeVar("LT", bound="SupportsLT")
60
67
 
61
68
 
62
69
  class SupportsLT(Protocol):
63
- def __lt__(self: LT, other: LT) -> bool:
70
+ def __lt__(self, __other: Any) -> SupportsBool:
64
71
  raise NotImplementedError
65
72
 
66
73
 
@@ -69,7 +76,7 @@ ADD = TypeVar("ADD", bound="SupportsAdd")
69
76
 
70
77
 
71
78
  class SupportsAdd(Protocol):
72
- def __add__(self: ADD, other: ADD, /) -> ADD:
79
+ def __add__(self, __other: Any, /) -> Any:
73
80
  raise NotImplementedError
74
81
 
75
82
 
@@ -20,7 +20,6 @@ from ._typing import T, T1, T2, T3, T4, T5, AnyIterable
20
20
  from ._core import aiter
21
21
  from .contextlib import nullcontext
22
22
 
23
-
24
23
  S = TypeVar("S")
25
24
 
26
25
 
@@ -35,7 +34,7 @@ class _BorrowedAsyncIterator(AsyncGenerator[T, S]):
35
34
  __slots__ = "__wrapped__", "__anext__", "asend", "athrow", "_wrapper"
36
35
 
37
36
  # Type checker does not understand `__slot__` definitions
38
- __anext__: Callable[[Any], Coroutine[Any, Any, T]]
37
+ __anext__: Callable[..., Coroutine[Any, Any, T]]
39
38
  asend: Any
40
39
  athrow: Any
41
40
 
@@ -49,7 +48,7 @@ class _BorrowedAsyncIterator(AsyncGenerator[T, S]):
49
48
  # An async *iterator* (e.g. `async def: yield`) must return
50
49
  # itself from __aiter__. If we do not shadow this then
51
50
  # running aiter(self).aclose closes the underlying iterator.
52
- self.__anext__ = self._wrapper.__anext__ # type: ignore
51
+ self.__anext__ = self._wrapper.__anext__
53
52
  if hasattr(iterator, "asend"):
54
53
  self.asend = (
55
54
  iterator.asend # pyright: ignore[reportUnknownMemberType,reportAttributeAccessIssue]
@@ -22,7 +22,6 @@ from ._core import (
22
22
  Sentinel,
23
23
  )
24
24
 
25
-
26
25
  __ANEXT_DEFAULT = Sentinel("<no default>")
27
26
 
28
27
 
@@ -55,7 +54,7 @@ __ITER_DEFAULT = Sentinel("<no default>")
55
54
 
56
55
 
57
56
  def iter(
58
- subject: Union[AnyIterable[T], Callable[[], Awaitable[T]]],
57
+ subject: Union[AnyIterable[T], Callable[[], Awaitable[T]], Callable[[], T]],
59
58
  sentinel: Union[Sentinel, T] = __ITER_DEFAULT,
60
59
  ) -> AsyncIterator[T]:
61
60
  """
@@ -84,13 +83,12 @@ def iter(
84
83
  raise TypeError("iter(v, w): v must be callable")
85
84
  else:
86
85
  assert not isinstance(sentinel, Sentinel)
87
- return acallable_iterator(subject, sentinel)
86
+ return acallable_iterator(_awaitify(subject), sentinel)
88
87
 
89
88
 
90
89
  async def acallable_iterator(
91
90
  subject: Callable[[], Awaitable[T]], sentinel: T
92
91
  ) -> AsyncIterator[T]:
93
- subject = _awaitify(subject)
94
92
  value = await subject()
95
93
  while value != sentinel:
96
94
  yield value
@@ -306,7 +304,7 @@ async def _min_max(
306
304
  raise ValueError(f"{name}() arg is an empty sequence")
307
305
  elif key is None:
308
306
  async for item in item_iter:
309
- if invert ^ (item < best):
307
+ if invert ^ bool(item < best):
310
308
  best = item
311
309
  else:
312
310
  key = _awaitify(key)
@@ -2,7 +2,7 @@ from typing import Any, AsyncIterator, Awaitable, Callable, overload
2
2
  from typing_extensions import TypeGuard
3
3
  import builtins
4
4
 
5
- from ._typing import ADD, AnyIterable, HK, LT, R, T, T1, T2, T3, T4, T5
5
+ from ._typing import ADD, AnyIterable, HK, LT, R, T, T1, T2, T3, T4, T5, SupportsLT
6
6
 
7
7
  @overload
8
8
  async def anext(iterator: AsyncIterator[T]) -> T: ...
@@ -16,6 +16,10 @@ def iter(
16
16
  ) -> AsyncIterator[T]: ...
17
17
  @overload
18
18
  def iter(subject: Callable[[], Awaitable[T]], sentinel: T) -> AsyncIterator[T]: ...
19
+ @overload
20
+ def iter(subject: Callable[[], T | None], sentinel: None) -> AsyncIterator[T]: ...
21
+ @overload
22
+ def iter(subject: Callable[[], T], sentinel: T) -> AsyncIterator[T]: ...
19
23
  async def all(iterable: AnyIterable[Any]) -> bool: ...
20
24
  async def any(iterable: AnyIterable[Any]) -> bool: ...
21
25
  @overload
@@ -180,20 +184,42 @@ async def max(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...
180
184
  @overload
181
185
  async def max(iterable: AnyIterable[LT], *, key: None = ..., default: T) -> LT | T: ...
182
186
  @overload
183
- async def max(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...
187
+ async def max(
188
+ iterable: AnyIterable[T1], *, key: Callable[[T1], Awaitable[SupportsLT]]
189
+ ) -> T1: ...
190
+ @overload
191
+ async def max(
192
+ iterable: AnyIterable[T1],
193
+ *,
194
+ key: Callable[[T1], Awaitable[SupportsLT]],
195
+ default: T2,
196
+ ) -> T1 | T2: ...
197
+ @overload
198
+ async def max(iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT]) -> T1: ...
184
199
  @overload
185
200
  async def max(
186
- iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
201
+ iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT], default: T2
187
202
  ) -> T1 | T2: ...
188
203
  @overload
189
204
  async def min(iterable: AnyIterable[LT], *, key: None = ...) -> LT: ...
190
205
  @overload
191
206
  async def min(iterable: AnyIterable[LT], *, key: None = ..., default: T) -> LT | T: ...
192
207
  @overload
193
- async def min(iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ...) -> T1: ...
208
+ async def min(
209
+ iterable: AnyIterable[T1], *, key: Callable[[T1], Awaitable[SupportsLT]]
210
+ ) -> T1: ...
194
211
  @overload
195
212
  async def min(
196
- iterable: AnyIterable[T1], *, key: Callable[[T1], LT] = ..., default: T2
213
+ iterable: AnyIterable[T1],
214
+ *,
215
+ key: Callable[[T1], Awaitable[SupportsLT]],
216
+ default: T2,
217
+ ) -> T1 | T2: ...
218
+ @overload
219
+ async def min(iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT]) -> T1: ...
220
+ @overload
221
+ async def min(
222
+ iterable: AnyIterable[T1], *, key: Callable[[T1], SupportsLT], default: T2
197
223
  ) -> T1 | T2: ...
198
224
  @overload
199
225
  def filter(
@@ -247,5 +273,12 @@ async def sorted(
247
273
  ) -> builtins.list[LT]: ...
248
274
  @overload
249
275
  async def sorted(
250
- iterable: AnyIterable[T], *, key: Callable[[T], LT], reverse: bool = ...
276
+ iterable: AnyIterable[T],
277
+ *,
278
+ key: Callable[[T], Awaitable[SupportsLT]],
279
+ reverse: bool = ...,
280
+ ) -> builtins.list[T]: ...
281
+ @overload
282
+ async def sorted(
283
+ iterable: AnyIterable[T], *, key: Callable[[T], SupportsLT], reverse: bool = ...
251
284
  ) -> builtins.list[T]: ...
@@ -257,7 +257,7 @@ def cached_property(
257
257
  if iscoroutinefunction(type_or_getter):
258
258
  return CachedProperty(type_or_getter)
259
259
  elif isinstance(type_or_getter, type) and issubclass(
260
- type_or_getter, AsyncContextManager
260
+ type_or_getter, AsyncContextManager # pyright: ignore[reportGeneralTypeIssues]
261
261
  ):
262
262
 
263
263
  def decorator(
@@ -33,6 +33,14 @@ def cached_property(
33
33
  asynccontextmanager_type: type[AsyncContextManager[Any]], /
34
34
  ) -> Callable[[Callable[[T], Awaitable[R]]], CachedProperty[T, R]]: ...
35
35
  @overload
36
+ async def reduce(
37
+ function: Callable[[T1, T2], Awaitable[T1]], iterable: AnyIterable[T2], initial: T1
38
+ ) -> T1: ...
39
+ @overload
40
+ async def reduce(
41
+ function: Callable[[T, T], Awaitable[T]], iterable: AnyIterable[T]
42
+ ) -> T: ...
43
+ @overload
36
44
  async def reduce(
37
45
  function: Callable[[T1, T2], T1], iterable: AnyIterable[T2], initial: T1
38
46
  ) -> T1: ...
@@ -92,7 +92,7 @@ class _KeyIter(Generic[LT]):
92
92
  return True
93
93
 
94
94
  def __lt__(self, other: _KeyIter[LT]) -> bool:
95
- return self.reverse ^ (self.head_key < other.head_key)
95
+ return self.reverse ^ bool(self.head_key < other.head_key)
96
96
 
97
97
  def __eq__(self, other: _KeyIter[LT]) -> bool: # type: ignore[override]
98
98
  return not (self.head_key < other.head_key or other.head_key < self.head_key)
@@ -161,7 +161,7 @@ class ReverseLT(Generic[LT]):
161
161
  self.key = key
162
162
 
163
163
  def __lt__(self, other: ReverseLT[LT]) -> bool:
164
- return other.key < self.key
164
+ return bool(other.key < self.key)
165
165
 
166
166
 
167
167
  # Python's heapq provides a *min*-heap
@@ -8,7 +8,6 @@ from typing import (
8
8
  Union,
9
9
  Callable,
10
10
  Optional,
11
- Deque,
12
11
  Generic,
13
12
  Iterable,
14
13
  Iterator,
@@ -16,15 +15,16 @@ from typing import (
16
15
  cast,
17
16
  overload,
18
17
  AsyncGenerator,
18
+ TYPE_CHECKING,
19
19
  )
20
- from collections import deque
20
+ if TYPE_CHECKING:
21
+ from typing_extensions import TypeAlias
21
22
 
22
23
  from ._typing import ACloseable, R, T, AnyIterable, ADD
23
24
  from ._utility import public_module
24
25
  from ._core import (
25
26
  ScopedIter,
26
27
  awaitify as _awaitify,
27
- Sentinel,
28
28
  borrow as _borrow,
29
29
  )
30
30
  from .builtins import (
@@ -33,6 +33,7 @@ from .builtins import (
33
33
  enumerate as aenumerate,
34
34
  iter as aiter,
35
35
  )
36
+ from itertools import count as _counter
36
37
 
37
38
  S = TypeVar("S")
38
39
  T_co = TypeVar("T_co", covariant=True)
@@ -64,9 +65,6 @@ async def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]:
64
65
  yield item
65
66
 
66
67
 
67
- __ACCUMULATE_SENTINEL = Sentinel("<no default>")
68
-
69
-
70
68
  async def add(x: ADD, y: ADD) -> ADD:
71
69
  """The default reduction of :py:func:`~.accumulate`"""
72
70
  return x + y
@@ -78,7 +76,7 @@ async def accumulate(
78
76
  Callable[[Any, Any], Any], Callable[[Any, Any], Awaitable[Any]]
79
77
  ] = add,
80
78
  *,
81
- initial: Any = __ACCUMULATE_SENTINEL,
79
+ initial: Any = None,
82
80
  ) -> AsyncIterator[Any]:
83
81
  """
84
82
  An :term:`asynchronous iterator` on the running reduction of ``iterable``
@@ -105,11 +103,7 @@ async def accumulate(
105
103
  """
106
104
  async with ScopedIter(iterable) as item_iter:
107
105
  try:
108
- value = (
109
- initial
110
- if initial is not __ACCUMULATE_SENTINEL
111
- else await anext(item_iter)
112
- )
106
+ value = initial if initial is not None else await anext(item_iter)
113
107
  except StopAsyncIteration:
114
108
  raise TypeError(
115
109
  "accumulate() of empty sequence with no initial value"
@@ -354,57 +348,79 @@ class NoLock:
354
348
  return None
355
349
 
356
350
 
357
- async def tee_peer(
358
- iterator: AsyncIterator[T],
359
- # the buffer specific to this peer
360
- buffer: Deque[T],
361
- # the buffers of all peers, including our own
362
- peers: List[Deque[T]],
363
- lock: AsyncContextManager[Any],
364
- ) -> AsyncGenerator[T, None]:
365
- """An individual iterator of a :py:func:`~.tee`"""
366
- try:
367
- while True:
368
- if not buffer:
369
- async with lock:
370
- # Another peer produced an item while we were waiting for the lock.
371
- # Proceed with the next loop iteration to yield the item.
372
- if buffer:
373
- continue
374
- try:
375
- item = await iterator.__anext__()
376
- except StopAsyncIteration:
377
- break
378
- else:
379
- # Append to all buffers, including our own. We'll fetch our
380
- # item from the buffer again, instead of yielding it directly.
381
- # This ensures the proper item ordering if any of our peers
382
- # are fetching items concurrently. They may have buffered their
383
- # item already.
384
- for peer_buffer in peers:
385
- peer_buffer.append(item)
386
- yield buffer.popleft()
387
- finally:
388
- # this peer is done – remove its buffer
389
- for idx, peer_buffer in enumerate(peers): # pragma: no branch
390
- if peer_buffer is buffer:
391
- peers.pop(idx)
392
- break
393
- # if we are the last peer, try and close the iterator
394
- if not peers and isinstance(iterator, ACloseable):
395
- await iterator.aclose()
351
+ _get_tee_index = _counter().__next__
352
+
353
+
354
+ _TeeNode: "TypeAlias" = "list[T | _TeeNode[T]]"
355
+
356
+
357
+ class TeePeer(Generic[T]):
358
+ def __init__(
359
+ self,
360
+ iterator: AsyncIterator[T],
361
+ buffer: "_TeeNode[T]",
362
+ lock: AsyncContextManager[Any],
363
+ tee_peers: "set[int]",
364
+ ) -> None:
365
+ self._iterator = iterator
366
+ self._lock = lock
367
+ self._buffer: _TeeNode[T] = buffer
368
+ self._tee_peers = tee_peers
369
+ self._tee_idx = _get_tee_index()
370
+ self._tee_peers.add(self._tee_idx)
371
+
372
+ def __aiter__(self):
373
+ return self
374
+
375
+ async def __anext__(self) -> T:
376
+ # the buffer is a singly-linked list as [value, [value, [...]]] | []
377
+ next_node = self._buffer
378
+ value: T
379
+ # for any most advanced TeePeer, the node is just []
380
+ # fetch the next value so we can mutate the node to [value, [...]]
381
+ if not next_node:
382
+ async with self._lock:
383
+ # Check if another peer produced an item while we were waiting for the lock
384
+ if not next_node:
385
+ await self._extend_buffer(next_node)
386
+ # for any other TeePeer, the node is already some [value, [...]]
387
+ value, self._buffer = next_node # type: ignore
388
+ return value
389
+
390
+ async def _extend_buffer(self, next_node: "_TeeNode[T]") -> None:
391
+ """Extend the buffer by fetching a new item from the iterable"""
392
+ try:
393
+ # another peer may fill the buffer while we wait here
394
+ next_value = await self._iterator.__anext__()
395
+ except StopAsyncIteration:
396
+ # no one else managed to fetch a value either
397
+ if not next_node:
398
+ raise
399
+ else:
400
+ # skip nodes that were filled in the meantime
401
+ while next_node:
402
+ _, next_node = next_node # type: ignore
403
+ next_node[:] = next_value, []
404
+
405
+ async def aclose(self) -> None:
406
+ self._tee_peers.discard(self._tee_idx)
407
+ if not self._tee_peers and isinstance(self._iterator, ACloseable):
408
+ await self._iterator.aclose()
409
+
410
+ def __del__(self) -> None:
411
+ self._tee_peers.discard(self._tee_idx)
396
412
 
397
413
 
398
414
  @public_module(__name__, "tee")
399
415
  class Tee(Generic[T]):
400
- """
416
+ r"""
401
417
  Create ``n`` separate asynchronous iterators over ``iterable``
402
418
 
403
419
  This splits a single ``iterable`` into multiple iterators, each providing
404
420
  the same items in the same order.
405
421
  All child iterators may advance separately but share the same items
406
422
  from ``iterable`` -- when the most advanced iterator retrieves an item,
407
- it is buffered until the least advanced iterator has yielded it as well.
423
+ it is buffered until all other iterators have yielded it as well.
408
424
  A ``tee`` works lazily and can handle an infinite ``iterable``, provided
409
425
  that all iterators advance.
410
426
 
@@ -415,16 +431,9 @@ class Tee(Generic[T]):
415
431
  await a.anext(previous) # advance one iterator
416
432
  return a.map(operator.sub, previous, current)
417
433
 
418
- Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
419
- of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
420
- to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
421
- immediately closes all children, and it can be used in an ``async with`` context
422
- for the same effect.
423
-
424
- If ``iterable`` is an iterator and read elsewhere, ``tee`` will *not*
425
- provide these items. Also, ``tee`` must internally buffer each item until the
426
- last iterator has yielded it; if the most and least advanced iterator differ
427
- by most data, using a :py:class:`list` is more efficient (but not lazy).
434
+ If ``iterable`` is an iterator and read elsewhere, ``tee`` will generally *not*
435
+ provide these items. However, a ``tee`` of a ``tee`` shares its buffer with parent,
436
+ sibling and child ``tee``\ s so that each sees the same items.
428
437
 
429
438
  If the underlying iterable is concurrency safe (``anext`` may be awaited
430
439
  concurrently) the resulting iterators are concurrency safe as well. Otherwise,
@@ -432,9 +441,15 @@ class Tee(Generic[T]):
432
441
  To enforce sequential use of ``anext``, provide a ``lock``
433
442
  - e.g. an :py:class:`asyncio.Lock` instance in an :py:mod:`asyncio` application -
434
443
  and access is automatically synchronised.
444
+
445
+ Unlike :py:func:`itertools.tee`, :py:func:`~.tee` returns a custom type instead
446
+ of a :py:class:`tuple`. Like a tuple, it can be indexed, iterated and unpacked
447
+ to get the child iterators. In addition, its :py:meth:`~.tee.aclose` method
448
+ immediately closes all children, and it can be used in an ``async with`` context
449
+ for the same effect.
435
450
  """
436
451
 
437
- __slots__ = ("_iterator", "_buffers", "_children")
452
+ __slots__ = ("_children",)
438
453
 
439
454
  def __init__(
440
455
  self,
@@ -443,16 +458,24 @@ class Tee(Generic[T]):
443
458
  *,
444
459
  lock: Optional[AsyncContextManager[Any]] = None,
445
460
  ):
446
- self._iterator = aiter(iterable)
447
- self._buffers: List[Deque[T]] = [deque() for _ in range(n)]
461
+ buffer: _TeeNode[T]
462
+ peers: set[int]
463
+ if not isinstance(iterable, TeePeer):
464
+ iterator = aiter(iterable)
465
+ buffer = []
466
+ peers = set()
467
+ else:
468
+ iterator = iterable._iterator # pyright: ignore[reportPrivateUsage]
469
+ buffer = iterable._buffer # pyright: ignore[reportPrivateUsage]
470
+ peers = iterable._tee_peers # pyright: ignore[reportPrivateUsage]
448
471
  self._children = tuple(
449
- tee_peer(
450
- iterator=self._iterator,
451
- buffer=buffer,
452
- peers=self._buffers,
453
- lock=lock if lock is not None else NoLock(),
472
+ TeePeer(
473
+ iterator,
474
+ buffer,
475
+ lock if lock is not None else NoLock(),
476
+ peers,
454
477
  )
455
- for buffer in self._buffers
478
+ for _ in range(n)
456
479
  )
457
480
 
458
481
  def __len__(self) -> int:
@@ -16,13 +16,17 @@ from ._typing import AnyIterable, ADD, T, T1, T2, T3, T4, T5
16
16
 
17
17
  def cycle(iterable: AnyIterable[T]) -> AsyncIterator[T]: ...
18
18
  @overload
19
- def accumulate(iterable: AnyIterable[ADD]) -> AsyncIterator[ADD]: ...
19
+ def accumulate(
20
+ iterable: AnyIterable[ADD], *, initial: None = ...
21
+ ) -> AsyncIterator[ADD]: ...
20
22
  @overload
21
23
  def accumulate(iterable: AnyIterable[ADD], *, initial: ADD) -> AsyncIterator[ADD]: ...
22
24
  @overload
23
25
  def accumulate(
24
26
  iterable: AnyIterable[T],
25
27
  function: Callable[[T, T], T] | Callable[[T, T], Awaitable[T]],
28
+ *,
29
+ initial: None = ...,
26
30
  ) -> AsyncIterator[T]: ...
27
31
  @overload
28
32
  def accumulate(
@@ -76,7 +80,7 @@ def filterfalse(
76
80
  predicate: Callable[[T], Any] | None, iterable: AnyIterable[T]
77
81
  ) -> AsyncIterator[T]: ...
78
82
  @overload
79
- def islice(iterable: AnyIterable[T], start: int | None, /) -> AsyncIterator[T]: ...
83
+ def islice(iterable: AnyIterable[T], stop: int | None, /) -> AsyncIterator[T]: ...
80
84
  @overload
81
85
  def islice(
82
86
  iterable: AnyIterable[T],
@@ -21,6 +21,7 @@ classifiers = [
21
21
  "Programming Language :: Python :: 3.11",
22
22
  "Programming Language :: Python :: 3.12",
23
23
  "Programming Language :: Python :: 3.13",
24
+ "Programming Language :: Python :: 3.14",
24
25
  ]
25
26
  license = {"file" = "LICENSE"}
26
27
  keywords = ["async", "enumerate", "itertools", "builtins", "functools", "contextlib"]
@@ -80,3 +81,6 @@ verboseOutput = true
80
81
  testpaths = [
81
82
  "unittests",
82
83
  ]
84
+
85
+ [tool.black]
86
+ target-version = ["py38", "py39","py310", "py311", "py312", "py313", "py314"]
@@ -4,7 +4,6 @@ import asyncstdlib as a
4
4
 
5
5
  from .utility import sync, asyncify
6
6
 
7
-
8
7
  CLOSED = "closed"
9
8
 
10
9
 
@@ -2,6 +2,7 @@ from typing import Callable, Any
2
2
  import sys
3
3
 
4
4
  import pytest
5
+ from typing_extensions import get_annotations, Format
5
6
 
6
7
  import asyncstdlib as a
7
8
 
@@ -175,5 +176,7 @@ def test_wrapper_attributes(size: "int | None"):
175
176
  if name != "method":
176
177
  continue
177
178
  # test direct and literal annotation styles
178
- assert Bar.method.__annotations__["int_arg"] in {int, "int"}
179
- assert Bar().method.__annotations__["int_arg"] in {int, "int"}
179
+ assert get_annotations(Bar.method, format=Format.STRING)["int_arg"] == "int"
180
+ assert (
181
+ get_annotations(Bar().method, format=Format.STRING)["int_arg"] == "int"
182
+ )
@@ -7,7 +7,6 @@ import asyncstdlib as a
7
7
 
8
8
  from .utility import sync, asyncify, awaitify
9
9
 
10
-
11
10
  MERGE_SAMPLES = [
12
11
  [[1, 2], [3, 4]],
13
12
  [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
@@ -1,3 +1,4 @@
1
+ from typing import AsyncIterator
1
2
  import itertools
2
3
  import sys
3
4
  import platform
@@ -34,6 +35,7 @@ async def test_accumulate():
34
35
 
35
36
  @sync
36
37
  async def test_accumulate_default():
38
+ """Test the default function of accumulate"""
37
39
  for itertype in (asyncify, list):
38
40
  assert await a.list(a.accumulate(itertype([0, 1]))) == list(
39
41
  itertools.accumulate([0, 1])
@@ -53,10 +55,21 @@ async def test_accumulate_default():
53
55
 
54
56
  @sync
55
57
  async def test_accumulate_misuse():
58
+ """Test wrong arguments to accumulate"""
56
59
  with pytest.raises(TypeError):
57
60
  assert await a.list(a.accumulate([]))
58
61
 
59
62
 
63
+ @sync
64
+ async def test_accumulate_initial():
65
+ """Test the `initial` argument to accumulate"""
66
+ assert (
67
+ await a.list(a.accumulate(asyncify([1, 2, 3]), initial=None))
68
+ == await a.list(a.accumulate(asyncify([1, 2, 3])))
69
+ == list(itertools.accumulate([1, 2, 3], initial=None))
70
+ )
71
+
72
+
60
73
  batched_cases = [
61
74
  (range(10), 2, [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]),
62
75
  (range(10), 3, [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]),
@@ -329,7 +342,7 @@ async def test_tee():
329
342
 
330
343
  @sync
331
344
  async def test_tee_concurrent_locked():
332
- """Test that properly uses a lock for synchronisation"""
345
+ """Test that tee properly uses a lock for synchronisation"""
333
346
  items = [1, 2, 3, -5, 12, 78, -1, 111]
334
347
 
335
348
  async def iter_values():
@@ -348,6 +361,52 @@ async def test_tee_concurrent_locked():
348
361
  assert results == items
349
362
 
350
363
 
364
+ @pytest.mark.parametrize("concurrency", (1, 2, 4, 7))
365
+ @sync
366
+ async def test_tee_share(concurrency: int) -> None:
367
+ """Test that related tees share their buffer and see all items"""
368
+ items = [1, 2, 3, -5, 12, 78, -1, 111]
369
+
370
+ async def tee_test(tee_state: AsyncIterator[int]) -> None:
371
+ """Asynchronously check that `tee_state` includes all `items`"""
372
+ for expected in items:
373
+ assert expected == await a.anext(tee_state)
374
+ await Switch(0, concurrency)
375
+
376
+ # create tees that are multiple times removed from an initial iterator
377
+ item_iter = a.iter(items)
378
+ for tee_peer in a.tee(item_iter, n=concurrency):
379
+ await Schedule(tee_test(a.tee(tee_peer)[0]))
380
+
381
+
382
+ @sync
383
+ async def test_tee_share_deep() -> None:
384
+ """Test that related tees share their buffer and see all items no matter when spawned"""
385
+ items = [1, 2, 3, -5, 12, 78, -1, 111]
386
+
387
+ async def tee_spawn_walker(
388
+ tee_state: AsyncIterator[int], start_idx: int = 0
389
+ ) -> None:
390
+ """Walk and check `tee_state` elements and spawn new walkers on every step"""
391
+ for idx in range(start_idx, len(items)):
392
+ await Switch(0, 3)
393
+ assert await a.anext(tee_state) == items[idx]
394
+ tee_state, *child_states = a.tee(tee_state, n=3)
395
+ await Schedule(
396
+ *(
397
+ tee_spawn_walker(child_state, idx + 1)
398
+ for child_state in child_states
399
+ )
400
+ )
401
+ await Switch()
402
+
403
+ head_peer, *child_peers = a.tee(items, n=3)
404
+ await Schedule(*(tee_spawn_walker(child, 0) for child in child_peers))
405
+ await Switch(len(items) // 2)
406
+ results = [item async for item in head_peer]
407
+ assert results == items
408
+
409
+
351
410
  # see https://github.com/python/cpython/issues/74956
352
411
  @pytest.mark.skipif(
353
412
  sys.version_info < (3, 8),
@@ -381,6 +440,41 @@ async def test_tee_concurrent_unlocked():
381
440
  await test_peer(this)
382
441
 
383
442
 
443
+ @pytest.mark.parametrize("size", [2, 3, 5, 9, 12])
444
+ @sync
445
+ async def test_tee_concurrent_ordering(size: int):
446
+ """Test that tee respects concurrent ordering for all peers"""
447
+
448
+ class ConcurrentInvertedIterable:
449
+ """Helper that concurrently iterates with earlier items taking longer"""
450
+
451
+ def __init__(self, count: int) -> None:
452
+ self.count = count
453
+ self._counter = itertools.count()
454
+
455
+ def __aiter__(self):
456
+ return self
457
+
458
+ async def __anext__(self):
459
+ value = next(self._counter)
460
+ if value >= self.count:
461
+ raise StopAsyncIteration()
462
+ await Switch(self.count - value)
463
+ return value
464
+
465
+ async def test_peer(peer_tee: AsyncIterator[int]):
466
+ # consume items from the tee with a delay so that slower items can arrive
467
+ seen_items: list[int] = []
468
+ async for item in peer_tee:
469
+ seen_items.append(item)
470
+ await Switch()
471
+ assert seen_items == expected_items
472
+
473
+ expected_items = list(range(size)[::-1])
474
+ peers = a.tee(ConcurrentInvertedIterable(size), n=size)
475
+ await Schedule(*map(test_peer, peers))
476
+
477
+
384
478
  @sync
385
479
  async def test_pairwise():
386
480
  assert await a.list(a.pairwise(range(5))) == [(0, 1), (1, 2), (2, 3), (3, 4)]
@@ -13,7 +13,6 @@ from functools import wraps
13
13
  from collections import deque
14
14
  from random import randint
15
15
 
16
-
17
16
  T = TypeVar("T")
18
17
 
19
18
 
File without changes
File without changes