Source code for sloths.ext.asyncio

"""
asyncio native stream class.
"""

from __future__ import annotations

import functools
import inspect
from collections.abc import (
    AsyncIterable,
    AsyncIterator,
    Awaitable,
    Callable,
    Iterable,
)
from typing import (
    TYPE_CHECKING,
    Any,
    Concatenate,
    Generic,
    Literal,
    ParamSpec,
    SupportsIndex,
    TypeAlias,
    TypeVar,
    cast,
    overload,
)

from sloths._utils import UNSET

if TYPE_CHECKING:
    from collections.abc import Awaitable, Iterable

T = TypeVar("T")
U = TypeVar("U")

P = ParamSpec("P")


AsyncTransform: TypeAlias = Callable[
    Concatenate[AsyncIterable[T], P],
    AsyncIterable[U],
]


class AsyncStream(Generic[T], AsyncIterable[T]):
    """
    Asynchronous counterpart of :class:`sloths.Stream`, built on async
    iterators.

    It works essentially the same way and exposes the same interface, but in
    an ``async``/``await`` compatible manner.

    Some methods which take callbacks, such as :meth:`map`, also have an
    ``a``-prefixed equivalent (:meth:`amap`) which takes an async callback
    instead.
    """

    def __init__(self, source: AsyncIterable[T]) -> None:
        # Only the reference is stored here; the iterator over it is created
        # lazily by ``_iter``.
        self._source = source

    @classmethod
    def range(
        cls: type[AsyncStream[int]],
        *args: SupportsIndex,
    ) -> AsyncStream[int]:
        """
        Create a simple async stream over ``range()``.

        ``args`` are forwarded unchanged to the builtin ``range()``.
        """
        numbers = range(*args)
        return AsyncStream(make_async(numbers))

    @functools.cached_property
    def _iter(self) -> AsyncIterator[T]:
        # Cached, so ``__aiter__`` and ``__anext__`` always advance the same
        # underlying iterator.
        return aiter(self._source)

    async def __aiter__(self) -> AsyncIterator[T]:
        async for item in self._iter:
            yield item

    async def __anext__(self) -> T:
        return await anext(self._iter)

    def __repr__(self) -> str:
        return "AsyncStream<{!r}>".format(self._source)
def chain(self, *others: AsyncIterable[T]) -> AsyncStream[T]:
    """
    Chain one or more async iterables after the current stream.

    The resulting stream yields every element of this stream first, then the
    elements of each iterable in ``others``, in the order they were passed.

    .. seealso:: :meth:`sloths.Stream.chain`
    """
    sources = (self, *others)

    async def _chained():
        # Each source is drained completely before moving to the next one.
        for source in sources:
            async for item in source:
                yield item

    return AsyncStream(_chained())
def pipe(
    self,
    fn: AsyncTransform[T, P, U],
    *args: P.args,
    **kwargs: P.kwargs,
) -> AsyncPipe[T, P, U]:
    """
    Chain a transform to a stream and return the resulting stream.

    .. seealso:: :meth:`sloths.Stream.pipe`

    Transforms are the core composability primitive and are simply callables
    which take an async iterable and return another async iterable. Usually
    these are lazy async generators. Extra ``args`` and ``kwargs`` are passed
    to ``fn`` after the iterable.

    >>> import asyncio
    >>> async def to_str(
    ...     iterable: AsyncIterable[int],
    ... ) -> AsyncIterable[str]:
    ...     async for x in iterable:
    ...         yield str(x)
    ...
    >>> asyncio.run(AsyncStream.range(10).pipe(to_str).collect())
    ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    Transforms can also decide to short-circuit or selectively yield for
    control-flow:

    >>> async def to_str_if_odd(
    ...     iterable: AsyncIterable[int],
    ... ) -> AsyncIterable[str]:
    ...     async for x in iterable:
    ...         if x % 2:
    ...             yield str(x)
    ...
    >>> asyncio.run(AsyncStream.range(10).pipe(to_str_if_odd).collect())
    ['1', '3', '5', '7', '9']

    And all the same properties as the sync version.

    .. warning::
        When writing transforms be careful not to accidentally consume the
        iterable, as this would negate much of the benefit of chaining
        generators in the first place.
    """
    # Not every callable has ``__name__`` (``functools.partial`` objects and
    # callable instances do not). Passing ``None`` lets the pipe fall back to
    # the transform's ``repr`` instead of failing with ``AttributeError``.
    name = getattr(fn, "__name__", None)
    return AsyncPipe(self, fn, name, *args, **kwargs)
# Chained operations
def inspect(self: AsyncStream[T], cb: Callable[[T], Any]) -> AsyncStream[T]:
    """
    Execute a function on each element without modifying it.

    .. seealso:: :meth:`sloths.Stream.inspect`

    ``cb`` is called with every element and its return value is ignored; the
    element itself is passed downstream untouched. This is mostly useful for
    debugging but could be used as the base for monitoring and metrics or any
    other side-effects.
    """
    # Prefer the callback's name for the pipe label, falling back to its repr.
    label = getattr(cb, "__name__", None) or repr(cb)

    async def _tap(items: AsyncIterable[T]) -> AsyncIterable[T]:
        async for item in items:
            cb(item)
            yield item

    return AsyncPipe(self, _tap, name=f"inspect<{label}>")
def enumerate(self: AsyncStream[T]) -> AsyncStream[tuple[int, T]]:
    """
    Python's ``enumerate`` as a transform.

    Every element is paired with its zero-based position in the stream and
    yielded as an ``(index, element)`` tuple.

    .. seealso:: :meth:`sloths.Stream.enumerate`
    """

    # The helper keeps the name ``_enumerate`` because ``pipe`` uses the
    # transform's ``__name__`` as the label of the resulting stream.
    async def _enumerate(
        gen: AsyncIterable[T],
    ) -> AsyncIterable[tuple[int, T]]:
        position = 0
        async for item in gen:
            pair = (position, item)
            position += 1
            yield pair

    return self.pipe(_enumerate)
def map(self, fn: Callable[[T], U]) -> AsyncStream[U]:
    """
    Run a synchronous element-wise transform over the stream.

    The resulting stream yields ``fn(x)`` for every element ``x``.

    .. seealso:: :meth:`sloths.Stream.map`
    """

    async def _apply(items: AsyncIterable[T]) -> AsyncIterable[U]:
        async for item in items:
            yield fn(item)

    label = f"map({fn})"
    return AsyncPipe(self, _apply, name=label)
def try_map(
    self,
    fn: Callable[[T], U],
    exc_cls: tuple[type[Exception], ...] = (Exception,),
    *,
    cb: Callable[[Exception, T], None] | None = None,
) -> AsyncStream[U]:
    """
    Run a synchronous element-wise transform over the stream and discard
    errors.

    Elements for which ``fn`` raises one of ``exc_cls`` are dropped from the
    output. If ``cb`` is provided it is called with the exception and the
    offending element before moving on to the next one.

    .. seealso:: :meth:`sloths.Stream.try_map`
    """

    async def _safe_map(items: AsyncIterable[T]) -> AsyncIterable[U]:
        async for item in items:
            try:
                result = fn(item)
            except exc_cls as error:
                if cb:
                    cb(error, item)
            else:
                # Only successful results are forwarded downstream.
                yield result

    label = f"try_map({fn}, {exc_cls}, {cb})"
    return AsyncPipe(self, _safe_map, name=label)
def amap(self, fn: Callable[[T], Awaitable[U]]) -> AsyncStream[U]:
    """
    Run an asynchronous element-wise transform over the stream.

    ``fn`` is called on each element and its result is awaited before being
    yielded, one element at a time.

    This is equivalent to ``AsyncStream(...).map(...).flatten()``.
    """

    async def _apply(items: AsyncIterable[T]) -> AsyncIterable[U]:
        async for item in items:
            result = await fn(item)
            yield result

    label = f"map({fn})"
    return AsyncPipe(self, _apply, name=label)
def atry_map(
    self,
    fn: Callable[[T], Awaitable[U]],
    exc_cls: tuple[type[Exception], ...] = (Exception,),
    *,
    cb: Callable[[Exception, T], None] | None = None,
) -> AsyncStream[U]:
    """
    Run an asynchronous element-wise transform over the stream and discard
    errors.

    Elements for which calling or awaiting ``fn`` raises one of ``exc_cls``
    are dropped from the output. If ``cb`` is provided it is called
    (synchronously) with the exception and the offending element.
    """

    async def _safe_map(items: AsyncIterable[T]) -> AsyncIterable[U]:
        async for item in items:
            try:
                result = await fn(item)
            except exc_cls as error:
                if cb:
                    cb(error, item)
            else:
                # Only successful results are forwarded downstream.
                yield result

    label = f"try_map({fn}, {exc_cls}, {cb})"
    return AsyncPipe(self, _safe_map, name=label)
def try_(
    self,
    exc_cls: tuple[type[Exception], ...] = (Exception,),
    *,
    cb: Callable[[Exception], None] | None = None,
) -> AsyncStream[T]:
    """
    Stop on the first exception and discard it.

    .. seealso:: :meth:`sloths.Stream.try_`

    This is more generic than :meth:`try_map`: it catches errors raised while
    fetching the next element from the upstream transform, but iteration
    stops at the first such exception. If ``cb`` is provided it is called
    with the exception before stopping.
    """

    async def _until_error(items: AsyncIterable[T]) -> AsyncIterable[T]:
        upstream = aiter(items)
        while True:
            # The ``yield`` deliberately stays inside the ``try`` block, as
            # in the original implementation.
            try:
                yield await anext(upstream)
            except StopAsyncIteration:
                # Normal exhaustion of the upstream iterator.
                break
            except exc_cls as error:
                if cb:
                    cb(error)
                break

    label = f"stop_on_exception({exc_cls}, {cb})"
    return AsyncPipe(self, _until_error, name=label)
def batch(self, by: int) -> AsyncStream[tuple[T, ...]]:
    """
    Buffer the stream and provide groups to downstream consumers.

    Elements are grouped into tuples of ``by`` elements; the grouping itself
    is delegated to the module-level ``_batch`` transform.

    .. seealso:: :meth:`sloths.Stream.batch`

    .. warning::
        This partially unwinds the stream and will increase memory usage.
        Only buffer to amounts you're comfortable holding in memory at once.
    """
    label = f"batch(by={by})"
    return AsyncPipe(self, _batch, name=label, by=by)
def flatten(
    self: AsyncStream[AsyncIterable[U]]
    | AsyncStream[Iterable[U]]
    | AsyncStream[Awaitable[U]],
) -> AsyncStream[U]:
    """
    Flatten iterators into their elements.

    This will flatten iterables, async iterables and awaitables, so it has
    the same utility as :meth:`sloths.Stream.flatten`:

    >>> import asyncio
    >>> asyncio.run(AsyncStream.range(11).batch(by=2).flatten().collect())
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    >>> asyncio.run(AsyncStream.range(0).batch(by=2).flatten().collect())
    []

    But it can also be used to flatten async results (this trivial case is
    equivalent to calling :meth:`amap`):

    >>> async def aadd_2(x):
    ...     await asyncio.sleep(0.001)
    ...     return x + 2
    >>> asyncio.run(AsyncStream.range(11).map(aadd_2).flatten().collect())
    [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

    Or async iterables:

    >>> async def apair(x):
    ...     await asyncio.sleep(0.001)
    ...     for _ in range(2):
    ...         yield x
    >>> asyncio.run(AsyncStream.range(5).map(apair).flatten().collect())
    [0, 0, 1, 1, 2, 2, 3, 3, 4, 4]
    """
    # The original author believes this is correct in practice, but type
    # inference rejects it, hence the blanket ignore. Needs revisiting.
    flattened = AsyncPipe(self, _flatten, name="flatten")  # type: ignore
    return flattened
def filter(
    self,
    predicate: Callable[[T], bool] | None = None,
) -> AsyncStream[T]:
    """
    Filter elements by running them through a predicate function.

    Only elements for which ``predicate`` returns a truthy value are kept.
    Without a predicate, the truthiness of the element itself is used.

    .. seealso:: :meth:`sloths.Stream.filter`
    """

    async def _keep(items: AsyncIterable[T]) -> AsyncIterable[T]:
        # ``bool`` stands in for the missing predicate: keep truthy elements.
        accept = predicate if predicate else bool
        async for item in items:
            if accept(item):
                yield item

    label = f"filter({predicate})"
    return AsyncPipe(self, _keep, name=label)
def afilter(
    self,
    predicate: Callable[[T], Awaitable[bool]],
) -> AsyncStream[T]:
    """
    Filter elements by running them through an asynchronous predicate.

    ``predicate`` is called and awaited for each element, one at a time;
    elements for which the awaited result is truthy are kept.
    """

    async def _keep(items: AsyncIterable[T]) -> AsyncIterable[T]:
        async for item in items:
            accepted = await predicate(item)
            if accepted:
                yield item

    label = f"filter({predicate})"
    return AsyncPipe(self, _keep, name=label)
def take(self, count: int) -> AsyncStream[T]:
    """
    Take up to ``count`` elements from the stream, then stop.

    .. seealso:: :meth:`sloths.Stream.take`

    Upstream generators are not polled once the requested number of elements
    has been reached, so the rest of the source can be consumed separately.
    """
    label = f"take({count})"
    return AsyncPipe(self, _take, name=label, n=count)
def skip(self, count: int) -> AsyncStream[T]:
    """
    Skip over the first ``count`` elements of the stream.

    The remaining elements are yielded unchanged.

    .. seealso:: :meth:`sloths.Stream.skip`
    """
    label = f"skip({count})"
    return AsyncPipe(self, _skip, name=label, n=count)
def take_while(
    self,
    predicate: Callable[[T], bool] | None = None,
) -> AsyncStream[T]:
    """
    Yield elements from the stream until the predicate returns ``False``.

    Without a predicate, the truthiness of each element is used instead.

    .. seealso:: :meth:`sloths.Stream.take_while`
    """
    label = f"take_while({predicate})"
    return AsyncPipe(self, _take_while, name=label, predicate=predicate)
def skip_while(
    self,
    predicate: Callable[[T], bool] | None = None,
) -> AsyncStream[T]:
    """
    Skip elements for as long as the predicate returns ``True``.

    The first element for which the predicate returns a falsy value, and
    every element after it, is yielded. Without a predicate, the truthiness
    of each element is used instead.

    .. seealso:: :meth:`sloths.Stream.skip_while`
    """
    label = f"drop_while({predicate})"
    return AsyncPipe(self, _skip_while, name=label, predicate=predicate)
# Reducer / consuming methods
async def consume(self) -> None:
    """
    Consume the stream but discard the results.

    This is useful for infinite pipelines or processing pipelines where only
    the side effects matter and the results are not important.
    """
    async for _discarded in self:
        continue
@overload
async def collect(self) -> list[T]: ...


@overload
async def collect(
    self,
    collector: Callable[[AsyncIterable[T]], Awaitable[U]],
) -> U: ...


async def collect(
    self,
    collector: Callable[[AsyncIterable[T]], Awaitable[U]] | None = None,
) -> U | list[T]:
    """
    Collect the iterator.

    By default the elements are gathered into a list. A custom ``collector``
    may be passed instead: it receives the stream itself (an async iterable)
    and its awaited result is returned.
    """
    if collector is not None:
        return await collector(self)
    return [item async for item in self]
async def count(self) -> int:
    """
    Return the length of the stream after consuming it.

    This is an explicit method rather than an ``__alen__`` hook because such
    a hook would implicitly consume the stream in various places, which makes
    it unsafe to add.
    """
    total = 0
    async for _element in self:
        total = total + 1
    return total
@overload
async def nth(self, nth: int) -> T: ...


@overload
async def nth(self, nth: int, *, default: T) -> T: ...


@overload
async def nth(self, nth: int, *, default: U) -> T | U: ...


async def nth(
    self,
    nth: int,
    *,
    default: U | Literal[UNSET.U] = UNSET.U,
) -> T | U:
    """
    Return the ``nth`` value.

    .. seealso:: :meth:`sloths.Stream.nth`

    The first ``nth`` elements are consumed and discarded, then the next one
    is returned. If the stream isn't long enough, ``default`` is returned
    when one was provided.

    Raises ``IndexError`` if the stream isn't long enough and a default value
    is not provided.
    """
    # Drop everything located before the requested position.
    await self.take(nth).consume()

    if default is UNSET.U:
        try:
            return await anext(self)
        except StopAsyncIteration:
            raise IndexError(nth) from None

    return await anext(self, default)
[docs] async def find( self, predicate: Callable[[T], bool] | None = None, ) -> T | None: """ Find the first elements that satisfies a predicate. .. seealso:: :meth:`sloths.Stream.find` This short-cirtcuits so it won't consume the source iterator past the target element: """ return await anext(self.filter(predicate), None)
[docs] async def afind( self, predicate: Callable[[T], Awaitable[bool]], ) -> T | None: """ Find the first elements that satisfies an asynchronous predicate. This short-cirtcuits so it won't consume the source iterator past the target element: """ return await anext(self.afilter(predicate), None)
async def fold(self, fn: Callable[[U, T], U], acc: U) -> U:
    """
    Fold every element into an accumulator function.

    ``fn`` is called with the current accumulator and each element in turn;
    its return value becomes the new accumulator. ``acc`` is the initial
    value and is returned as-is for an empty stream.

    .. seealso:: :meth:`sloths.Stream.fold`
    """
    async for item in self:
        acc = fn(acc, item)
    return acc
async def afold(self, fn: Callable[[U, T], Awaitable[U]], acc: U) -> U:
    """
    Fold every element into an asynchronous accumulator function.

    ``fn`` is called with the current accumulator and each element in turn;
    its awaited result becomes the new accumulator. ``acc`` is the initial
    value and is returned as-is for an empty stream.
    """
    async for item in self:
        acc = await fn(acc, item)
    return acc
class AsyncPipe(Generic[T, P, U], AsyncStream[U]):
    """
    A stream representing a source stream passed through a transform
    function.

    This should not be interacted with directly.

    .. seealso:: :meth:`Stream.pipe`
    """

    def __init__(
        self,
        source: AsyncIterable[T],
        fn: AsyncTransform[T, P, U],
        name: str | None = None,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> None:
        # Upstream iterable and the transform applied to it.
        self._inner = source
        self._transform = fn
        # Optional display label; ``__str__`` falls back to the transform's
        # repr when it is empty.
        self._name = name
        # Not read anywhere in this class.
        self._iterator: AsyncIterator[U] | None = None
        # Extra arguments forwarded to the transform after the iterable.
        self._args = args
        self._kwargs = kwargs

    @functools.cached_property
    def _iter(self) -> AsyncIterator[U]:
        # The transform is only invoked on first access, so building a
        # pipeline does not start consuming the source.
        upstream = aiter(self._inner)
        transformed = self._transform(upstream, *self._args, **self._kwargs)
        return aiter(transformed)

    def __str__(self) -> str:
        label = self._name if self._name else repr(self._transform)
        return f"{self._inner} | {label}({self._args}, {self._kwargs})"

    def __repr__(self) -> str:
        return "AsyncPipe<" + str(self) + ">"
# Async iterator utils
async def make_async(it: Iterable[T]) -> AsyncIterable[T]:
    """
    Wrap a synchronous iterable in an asynchronous one.

    The elements of ``it`` are yielded one by one, in order, from an async
    generator.
    """
    source = iter(it)
    for element in source:
        yield element
async def _batch(it: AsyncIterable[T], by: int) -> AsyncIterable[tuple[T, ...]]: """ Chunk an iterable into tuples of a given size. """ iterator = aiter(it) _batch: list[T] = [] try: while True: for _ in range(by): _batch.append(await anext(iterator)) # noqa: PERF401 yield tuple(_batch) del _batch[:] except StopAsyncIteration: if _batch: yield tuple(_batch) async def _flatten( gen: AsyncIterable[AsyncIterable[U]] | AsyncIterable[Iterable[U]] | AsyncIterable[Awaitable[U]], ) -> AsyncIterable[U]: """ Flatten an async iterator of awaitables, iterables, or async iterables. """ it = aiter(gen) try: first = await anext(it) except StopAsyncIteration: return if inspect.isawaitable(first): yield await first async for x in cast("AsyncIterator[Awaitable[U]]", it): yield await x elif isinstance(first, (AsyncIterable, AsyncIterator)): async for y in first: yield y async for x in cast("AsyncIterator[AsyncIterable[U]]", it): async for y in x: yield y else: for y in first: yield y async for x in cast("AsyncIterator[Iterable[U]]", it): for y in x: yield y async def _take(gen: AsyncIterable[T], *, n: int) -> AsyncIterable[T]: it = aiter(gen) try: for _ in range(n): yield await anext(it) except StopAsyncIteration: return async def _skip(gen: AsyncIterable[T], *, n: int) -> AsyncIterable[T]: it = aiter(gen) try: for _ in range(n): await anext(it) except StopAsyncIteration: return async for x in it: yield x async def _take_while( gen: AsyncIterable[T], *, predicate: Callable[[T], bool] | None = None, ) -> AsyncIterable[T]: it = aiter(gen) if predicate: async for x in it: if predicate(x): yield x else: return else: async for x in it: if x: yield x else: return async def _skip_while( gen: AsyncIterable[T], *, predicate: Callable[[T], bool] | None = None, ) -> AsyncIterable[T]: it = aiter(gen) if predicate: async for x in it: if not predicate(x): yield x break else: async for x in it: if not x: yield x break async for x in it: yield x