from __future__ import annotations
import functools
import itertools
from collections import deque
from collections.abc import Callable, Iterable, Iterator
from typing import (
TYPE_CHECKING,
Any,
Concatenate,
Generic,
Literal,
ParamSpec,
SupportsIndex,
TypeAlias,
TypeVar,
overload,
)
from sloths._utils import UNSET, batch, window
if TYPE_CHECKING:
from sloths.ext.asyncio import AsyncStream
T = TypeVar("T")
U = TypeVar("U")
P = ParamSpec("P")
Transform: TypeAlias = Callable[Concatenate[Iterable[T], P], Iterable[U]]
[docs]
class Stream(Generic[T], Iterable[T]):
"""
Typed interface to build lazy generator/coroutines pipelines.
This technically works with any iterable but is primarily built to compose
lazy-generator pipelines into a single iterator. When used with generators
this provides good memory and throughput controls.
None of this can't be achieved either by colocating everything in a single
loop or composing generators outside-in by hand. This is a fairly light
abstraction with almost no runtime cost and is provided mostly for
ergonomics. The core benefits are:
- flat-definition of the pipeline
- stages defined in reading order instead of reverse order
- type erasure and safety
- composability
The simplest stream just wraps and consumes an iterable:
>>> s = Stream.range(10)
>>> list(s) # This will consume the iterator
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
But it becomes really useful when composing transformations. Taking a
trivial example of outside-in composition:
>>> def add_2(gen):
... for x in gen:
... yield x + 2
...
>>> def drop_multiples_of_3(gen):
... for x in gen:
... if x % 3 > 0:
... yield x
...
>>> gen = drop_multiples_of_3(
... add_2(
... range(10),
... ),
... )
>>> list(gen)
[2, 4, 5, 7, 8, 10, 11]
The equivalent form with :class:`Stream` is:
>>> stream = (
... Stream.range(10)
... .pipe(add_2)
... .pipe(drop_multiples_of_3)
... )
>>> list(stream)
[2, 4, 5, 7, 8, 10, 11]
Streams also provide a chainable API and convenience methods (largely
inspired by Rust's iterator trait) to make it easy to compose readable
pipelines without nesting.
Streams are also lazy as long as the transforms are well implemented (i.e.
they don't consume the entire source iterable in memory) and the pipeline
will run from the last transform, polling up the stack as needed.
For a simple example:
>>> source = iter(range(100_000_000_000)) # Problematically large
>>> (
... Stream(source)
... .pipe(add_2)
... .batch(10)
... .flatten()
... .pipe(drop_multiples_of_3)
... .inspect(print)
... .take(20)
... .fold(lambda x,y: x+y, 0)
... )
2
4
5
7
8
10
11
13
14
16
17
19
20
22
23
25
26
28
29
31
330
We can see that we haven't consumed too far into the source iterable:
>>> next(source)
30
The print calls in the last example also illustrate the laziness of the
streams. The final iterators polls from the last step which essentially
polls up the stack until any iterable yields data. So in the example above
there's only ever 10 integers passing through the pipeline at any given
time. This is primarily useful with lazy generators in order to control
peak memory usage.
.. warning::
Streams are *just* chained generators and don't provide any concurrency
primitives (threads or async). Everything is executing linearly and
behind the GIL. However nothing prevents a transform from using threads,
processes or asyncio behind the scene.
"""
def __init__(self, source: Iterable[T]) -> None:
self._source = source
@functools.cached_property
def _iter(self) -> Iterator[T]:
return iter(self._source)
def __iter__(self) -> Iterator[T]:
yield from self._iter
def __next__(self) -> T:
return next(self._iter)
def __repr__(self) -> str:
return f"Stream<{self._source!r}>"
[docs]
@classmethod
def range(cls: type[Stream[int]], *args: SupportsIndex) -> Stream[int]:
"""
Create a simple stream over ``range()``.
"""
return Stream(range(*args))
[docs]
def chain(self, *others: Iterable[T]) -> Stream[T]:
"""
Chain one or more iterables to the current ones.
Works with other streams:
>>> Stream.range(10).chain(
... Stream.range(5).map(lambda x: x + 20)
... ).collect()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 20, 21, 22, 23, 24]
And simple iterables:
>>> Stream.range(2).chain(range(3), range(2)).collect()
[0, 1, 0, 1, 2, 0, 1]
"""
return Stream(itertools.chain(self, *others))
[docs]
def pipe(
self,
fn: Transform[T, P, U],
*args: P.args,
**kwargs: P.kwargs,
) -> Pipe[T, P, U]:
"""
Chain a transform to a stream and return the resulting stream.
Transforms are the core composability primitive and are simply callables
which take an iterable and return another iterable. Usually these are
lazy generators.
>>> def to_str(iterable: Iterable[int]) -> Iterable[str]:
... for x in iterable:
... yield str(x)
...
>>> list(Stream.range(10).pipe(to_str))
['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
.. note::
Type information of the source stream is preserved, so in the
example above the first layer (``Stream.range(10)``) is a
``Stream[int, int]`` while the final stream is ``Stream[int, str]``
which is also an ``Iterable[str]``.
Transforms can also decide to short-circuit or selectively yield for
control-flow:
>>> def to_str_if_odd(iterable: Iterable[int]) -> Iterable[str]:
... for x in iterable:
... if x % 2:
... yield str(x)
...
>>> list(Stream.range(10).pipe(to_str_if_odd))
['1', '3', '5', '7', '9']
As transforms are just generator-factories they can hold state:
>>> def track_bounds(gen: Iterable[int]) -> Iterable[int]:
... m, M = 0, 0
... for x in gen:
... m, M = min(m, x), max(M, x)
... yield x
... print(f'Min {m}, Max {M}')
>>> s = Stream.range(10).pipe(track_bounds)
>>> list(s)
Min 0, Max 9
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
The flip-side of this being that streams are generally not safe to reuse
once iterated upon.
.. warning::
When writing transforms be careful not to accidentally consume the
iterable as this would negate much of the benefit of chaining
generators in the first place.
"""
return Pipe(self, fn, fn.__name__, *args, **kwargs)
# Chained operations
[docs]
def inspect(self: Stream[T], cb: Callable[[T], Any]) -> Stream[T]:
"""
Execute a function on each element without modifying it.
This is mostly useful for debugging but could be used as the base for
monitoring and metrics or any other side-effects.
>>> Stream.range(4).inspect(print).collect()
0
1
2
3
[0, 1, 2, 3]
"""
def inspect(gen: Iterable[T]) -> Iterable[T]:
for x in gen:
cb(x)
yield x
return Pipe(
self,
inspect,
name=f"inspect<{getattr(cb, '__name__', None) or repr(cb)}>",
)
[docs]
def enumerate(self: Stream[T]) -> Stream[tuple[int, T]]:
"""
Python's ``enumerate`` as a transform.
>>> Stream.range(5, 11).enumerate().collect()
[(0, 5), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]
"""
return self.pipe(enumerate)
[docs]
def map(self, fn: Callable[[T], U]) -> Stream[U]:
"""
Run an element-wise transform over the stream.
>>> Stream.range(10).map(lambda x: x * 2).collect()
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
"""
def _map(gen: Iterable[T]) -> Iterable[U]:
for x in gen:
yield fn(x)
return Pipe(self, _map, name=f"map({fn})")
[docs]
def try_map(
self,
fn: Callable[[T], U],
exc_cls: tuple[type[Exception], ...] = (Exception,),
*,
cb: Callable[[Exception, T], None] | None = None,
) -> Stream[U]:
"""
Run an element-wise transform over the stream and discard errors.
>>> def no_2(x):
... if x == 2:
... raise ValueError(2)
... return x
>>> list(Stream.range(10).map(no_2))
Traceback (most recent call last):
...
ValueError: 2
>>> list(Stream.range(10).try_map(no_2, (ValueError,)))
[0, 1, 3, 4, 5, 6, 7, 8, 9]
Optionally you can pass in a callback to handle errors out of band:
>>> list(Stream.range(10).try_map(no_2, (ValueError,), cb=print))
2 2
[0, 1, 3, 4, 5, 6, 7, 8, 9]
"""
def _map_except(gen: Iterable[T]) -> Iterable[U]:
for x in gen:
try:
y = fn(x)
except exc_cls as e:
if cb:
cb(e, x)
continue
yield y
return Pipe(
self,
_map_except,
name=f"try_map({fn}, {exc_cls}, {cb})",
)
[docs]
def try_(
self,
exc_cls: tuple[type[Exception], ...] = (Exception,),
*,
cb: Callable[[Exception], None] | None = None,
) -> Stream[T]:
"""
Stop on the first exception and discard it.
This is more generic than :meth:`try_map` and will catch error that
happened when calling ``next()`` on the upstream transform but will stop
iteration on the first exception.
>>> def no_2(x):
... if x == 2:
... raise ValueError(2)
... return x
>>> Stream.range(10).map(no_2).collect()
Traceback (most recent call last):
...
ValueError: 2
>>> Stream.range(10).map(no_2).try_((ValueError,)).collect()
[0, 1]
Optionally you can pass in a callback to handle errors out of band:
>>> list(Stream.range(10).map(no_2).try_((ValueError,), cb=print))
2
[0, 1]
If there are no errors it flows to the end normally:
>>> Stream.range(10).map(lambda x: x + 2).try_((ValueError,), cb=print)\
.collect()
[2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
"""
def _try(gen: Iterable[T]) -> Iterable[T]:
it = iter(gen)
while True:
try:
yield next(it)
except StopIteration: # noqa: PERF203
return
except exc_cls as e:
if cb:
cb(e)
return
return Pipe(
self,
_try,
name=f"stop_on_exception({exc_cls}, {cb})",
)
[docs]
def batch(self, by: int) -> Stream[Iterable[T]]:
"""
Buffer the stream and provide groups to downstream consumers.
.. warning::
This partially unwinds the stream and will increase memory usage.
Only buffer to amounts you're comfortable holding in memory at once.
>>> Stream.range(11).batch(by=2).collect()
[(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10,)]
To simply buffer without exposing groups simply chain this with
:meth:`flatten()` which will ensure at least `by` elements are ready
before forwarding them downstream one by one:
>>> list(Stream.range(11).batch(by=2).flatten())
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Batches may not have the number of elements if the end of the stream
doesn't have enough to fill a batch:
>>> list(Stream.range(11).batch(by=3))
[(0, 1, 2), (3, 4, 5), (6, 7, 8), (9, 10)]
"""
return Pipe(self, batch, name=f"batch(by={by})", by=by)
[docs]
def flatten(self: Stream[Iterable[U]]) -> Stream[U]:
"""
Flatten iterators into their elements.
This is usually most useful after a buffered operation.
>>> Stream.range(11).batch(by=2).flatten().collect()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
.. seealso:: :meth:`~Stream.flat_map`.
"""
return Pipe(self, itertools.chain.from_iterable, name="flatten")
[docs]
def flat_map(self, fn: Callable[[T], Iterable[U]]) -> Stream[U]:
"""
Run an element-wise transform over the stream and flatten results.
>>> Stream.range(10).flat_map(lambda x: [x] * 2).collect()
[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9]
"""
def _flat_map(gen: Iterable[T]) -> Iterable[U]:
for x in gen:
yield from fn(x)
return Pipe(self, _flat_map, name=f"flat_map({fn})")
[docs]
def filter(self, predicate: Callable[[T], bool] | None = None) -> Stream[T]:
"""
Filter elements by running them through a predicate function.
This supports passing no predicate ion which case it checks for truthy
values:
>>> Stream([1, 2, None, 0, 4]).filter().collect()
[1, 2, 4]
>>> Stream.range(10).filter(lambda x: bool(x % 2)).collect()
[1, 3, 5, 7, 9]
"""
def _filter(gen: Iterable[T]) -> Iterable[T]:
if predicate:
for x in gen:
if predicate(x):
yield x
else:
for x in gen:
if x:
yield x
return Pipe(self, _filter, name=f"filter({predicate})")
[docs]
def take(self, count: int) -> Stream[T]:
"""
Take up to ``count`` element from the stream and interrupt.
Upstream generators will not be polled once we've reached the requested
number of elements so the source can be consumed to its end separately.
>>> it = iter(range(10))
>>> Stream(it).take(4).collect()
[0, 1, 2, 3]
>>> list(it)
[4, 5, 6, 7, 8, 9]
Taking more than the size in the iterator has no effect:
>>> Stream.range(5).take(10).collect()
[0, 1, 2, 3, 4]
"""
return Pipe(
self,
lambda x: itertools.islice(x, count),
name=f"take({count})",
)
[docs]
def skip(self, count: int) -> Stream[T]:
"""
Skip over ``count`` element from the iterator.
>>> list(Stream.range(10).skip(4))
[4, 5, 6, 7, 8, 9]
"""
return Pipe(
self,
lambda x: itertools.islice(x, count, None),
name=f"skip({count})",
)
[docs]
def take_while(
self,
predicate: Callable[[T], bool] | None = None,
) -> Stream[T]:
"""
Consume element from the stream until the predicate returns ``False``.
>>> it = iter(range(10))
>>> list(Stream(it).take_while(lambda x: x == 0 or x % 3 != 0))
[0, 1, 2]
Note that the first failing element of the iterator is consumed:
>>> list(it)
[4, 5, 6, 7, 8, 9]
Passing no predicate is also supported:
>>> list(Stream([1, 2, 0, 3]).take_while())
[1, 2]
"""
return Pipe(
self,
lambda x: itertools.takewhile(predicate or bool, x),
name=f"take_while({predicate})",
)
[docs]
def skip_while(
self,
predicate: Callable[[T], bool] | None = None,
) -> Stream[T]:
"""
Skip elements until the predicate returns ``True``.
>>> Stream.range(10).skip_while(lambda x: x == 0 or x % 3 != 0)\
.collect()
[3, 4, 5, 6, 7, 8, 9]
Passing no predicate is also supported:
>>> list(Stream([1, 2, 0, None, 1, 2, 0, 3]).skip_while())
[0, None, 1, 2, 0, 3]
"""
return Pipe(
self,
lambda x: itertools.dropwhile(predicate or bool, x),
name=f"take_while({predicate})",
)
[docs]
def step_by(self, step: int) -> Stream[T]:
"""
Consume iterators by a given step size each iteration.
This consumes elements after their predecessor has been consumed.
>>> Stream.range(10).step_by(2).collect()
[0, 2, 4, 6, 8]
"""
return Pipe(
self,
lambda x: itertools.islice(x, None, None, step),
name=f"step_by({step})",
)
[docs]
def window(self, size: int) -> Stream[tuple[T, ...]]:
"""
Transform the stream into a stream of sliding windows.
Each window is a tuple containing ``size`` consecutive elements from the
stream. The windows overlap, with each window shifted one element
forward
from the previous window.
If the stream contains fewer elements than the window size, an empty
stream is returned.
>>> Stream.range(5).window(3).collect()
[(0, 1, 2), (1, 2, 3), (2, 3, 4)]
>>> Stream([1, 2]).window(3).collect()
[]
>>> Stream([]).window(2).collect()
[]
"""
return Pipe(self, window, name=f"window(size={size})", size=size)
# Adapters
[docs]
def peekable(self) -> Peekable[T]:
"""
Return a :class:`Peekable` version of the current stream.
>>> s = Stream.range(100).peekable()
>>> s.peek()
0
"""
return Peekable(self)
[docs]
def to_async(self) -> AsyncStream[T]:
"""
Return a :class:`sloths.ext.asyncio.AsyncStream` version.
"""
from sloths.ext.asyncio import AsyncStream, make_async
return AsyncStream(make_async(self))
# Reducer / consuming methods
[docs]
def consume(self):
"""
Consume the stream but discard the results.
This is useful for infinite pipelines or processing pipelines where the
results are not important.
"""
for _ in self:
pass
@overload
def collect(self) -> list[T]: ...
@overload
def collect(self, collector: Callable[[Iterable[T]], U]) -> U: ...
[docs]
def collect(
self,
collector: Callable[[Iterable[T]], U] | None = None,
) -> U | list[T]:
"""
Collect the iterator.
By default this collects into a list, so this:
>>> list(Stream.range(10))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Is equivalent to:
>>> Stream.range(10).collect()
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Custom collectors are also supported:
>>> Stream.range(10).map(lambda x: x // 2).collect(set)
{0, 1, 2, 3, 4}
"""
if collector is None:
return list(self)
return collector(self)
[docs]
def count(self) -> int:
"""
Return the length of the stream after consuming it.
``__len__`` would implicitly consume the stream in various places so is
unsafe to add.
>>> Stream.range(100).count()
100
"""
return sum(1 for _ in self)
@overload
def nth(self, nth: int) -> T: ...
@overload
def nth(self, nth: int, *, default: T) -> T: ...
@overload
def nth(self, nth: int, *, default: U) -> T | U: ...
[docs]
def nth(
self,
nth: int,
*,
default: U | Literal[UNSET.U] = UNSET.U,
) -> T | U:
"""
Return the ``nth`` value.
>>> Stream.range(10).nth(0)
0
>>> Stream.range(10).nth(6)
6
Raises ``IndexError`` if the stream is too short:
>>> Stream.range(10).nth(12)
Traceback (most recent call last):
...
IndexError: 12
A ``default`` can be provided as a fallback:
>>> Stream.range(10).nth(12, default=42)
42
This short-cirtcuits so it won't consume the source iterator past the
target element:
>>> source = iter(range(10))
>>> Stream(source).nth(3)
3
>>> list(source)
[4, 5, 6, 7, 8, 9]
"""
self.take(nth).consume()
if default is not UNSET.U:
return next(self, default)
try:
return next(self)
except StopIteration:
raise IndexError(nth) from None
[docs]
def find(self, predicate: Callable[[T], bool] | None = None) -> T | None:
"""
Find the first elements that satisfies a predicate.
>>> Stream.range(10).find(lambda x: x == 3)
3
This short-cirtcuits so it won't consume the source iterator past the
target element:
>>> source = iter(range(10))
>>> Stream(source).find(lambda x: x == 3)
3
>>> list(source)
[4, 5, 6, 7, 8, 9]
Returns ``None`` if the item is not found:
>>> source = iter(range(10))
>>> Stream(source).find(lambda x: x == 102)
>>> list(source)
[]
"""
return next(self.filter(predicate), None)
[docs]
def fold(self, fn: Callable[[U, T], U], acc: U) -> U:
"""
Fold every element into an accumulator function.
>>> Stream.range(10).fold(lambda x,y: x + y, 0)
45
>>> Stream.range(10).fold(lambda y, x: [x, *y], [])
[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
"""
return functools.reduce(fn, self, acc)
class Pipe(Generic[T, P, U], Stream[U]):
"""
A stream representing a source stream passed through a transform function.
This should not be interacted with directly.
.. seealso::
:meth:`Stream.pipe`
"""
def __init__(
self,
source: Iterable[T],
fn: Transform[T, P, U],
name: str | None = None,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
self._inner = source
self._transform = fn
self._name = name
self._iterator: Iterator[U] | None = None
self._args = args
self._kwargs = kwargs
@functools.cached_property
def _iter(self) -> Iterator[U]:
return iter(
self._transform(
iter(self._inner),
*self._args,
**self._kwargs,
),
)
def __str__(self) -> str:
name = self._name or repr(self._transform)
return f"{self._inner} | {name}({self._args}, {self._kwargs})"
def __repr__(self) -> str:
return f"Pipe<{self!s}>"
[docs]
class Peekable(Stream[T]):
"""
A :class:`Stream` with a :meth:`peek()` method.
.. warning::
This may have a memory impact as it will buffer elements up to the
furthest index peeked at.
"""
def __init__(self, source: Iterator[T]) -> None:
self._source = source
self._buffer: deque[T] = deque()
@functools.cached_property
def _iter(self) -> Iterator[T]:
return iter(self._source)
def __iter__(self) -> Iterator[T]:
return self
def __next__(self) -> T:
if self._buffer:
return self._buffer.popleft()
return next(self._iter)
def __repr__(self) -> str:
return f"Peekable<{self._source!s}>"
@overload
def peek(self, n: int) -> T: ...
@overload
def peek(self, n: int, *, default: T) -> T: ...
[docs]
def peek(
self,
n: int = 1,
*,
default: U | Literal[UNSET.U] = UNSET.U,
) -> T | U:
"""
Return the element n positions ahead without consuming the stream.
>>> s = Stream.range(10).peekable()
>>> s.peek()
0
>>> next(s)
0
>>> s.peek(4)
4
>>> s.peek(2)
2
>>> next(s)
1
The ``Peekable`` instance is a regular stream so you can chain calls:
>>> s.take(5).collect()
[2, 3, 4, 5, 6]
Peeking past the stream raises ``IndexError``:
>>> s.peek(20)
Traceback (most recent call last):
...
IndexError: 20
Which can be avoided with a default value:
>>> s.peek(20, default=None) is None
True
"""
if len(self._buffer) < (n + 1):
self._buffer.extend(
itertools.islice(self._iter, n - len(self._buffer) + 1),
)
try:
return self._buffer[n - 1]
except IndexError:
if default is not UNSET.U:
return default
raise IndexError(n) from None