-
-
Notifications
You must be signed in to change notification settings - Fork 126
feat: add worker-side task batching #634
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| # broker.py | ||
| import asyncio | ||
|
|
||
| from taskiq import InMemoryBroker | ||
|
|
||
| broker = InMemoryBroker() | ||
|
|
||
|
|
||
| @broker.task(batch=True, batch_size=100, batch_timeout=3) | ||
| async def process_items(items: list[int]) -> int: | ||
| # The worker collects many `.kiq` calls and invokes this | ||
| # function once with the accumulated list of items. | ||
| print(f"Processing a batch of {len(items)} items.") | ||
| return sum(items) | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| await broker.startup() | ||
| # Each `.kiq` sends a single item. They are buffered and run | ||
| # together once the batch is flushed. | ||
| tasks = [await process_items.kiq(i) for i in range(10)] | ||
| # In tests, `wait_all` flushes any pending batches and waits | ||
| # for them to finish before we read the results. | ||
| await broker.wait_all() | ||
| for task in tasks: | ||
| result = await task.wait_result(timeout=5) | ||
| # Every item in the batch shares the same batch result. | ||
| print(f"Returned value: {result.return_value}") | ||
| await broker.shutdown() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| # broker.py | ||
| import asyncio | ||
|
|
||
| from taskiq_redis import RedisAsyncResultBackend, RedisStreamBroker | ||
|
|
||
| broker = RedisStreamBroker(url="redis://localhost:6379").with_result_backend( | ||
| RedisAsyncResultBackend(redis_url="redis://localhost:6379"), | ||
| ) | ||
|
|
||
|
|
||
| @broker.task(batch=True, batch_size=100, batch_timeout=3) | ||
| async def process_items(items: list[int]) -> int: | ||
| # The worker collects many `.kiq` calls and invokes this | ||
| # function once with the accumulated list of items. | ||
| print(f"Processing a batch of {len(items)} items.") | ||
| return sum(items) | ||
|
|
||
|
|
||
| async def main() -> None: | ||
| await broker.startup() | ||
| # Each `.kiq` sends a single item. The worker buffers them and | ||
| # runs `process_items` once with the whole batch. | ||
| tasks = [await process_items.kiq(i) for i in range(10)] | ||
| for task in tasks: | ||
| result = await task.wait_result(timeout=5) | ||
| # Every item in the batch shares the same batch result. | ||
| print(f"Returned value: {result.return_value}") | ||
| await broker.shutdown() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| asyncio.run(main()) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,92 @@ | ||
| --- | ||
| title: Batching tasks | ||
| order: 6 | ||
| --- | ||
|
|
||
| # Batching tasks | ||
|
|
||
| Some tasks have a high fixed cost per call but become much cheaper when processed together. Think of database writes, calls to an external API, ML inference, event publishing or search indexing — running them one by one wastes most of the time on overhead. If a single call takes ~1 second, then 10 separate calls take ~10 seconds, but processing all 10 at once might take only ~3 seconds. | ||
|
|
||
| Taskiq can collect many task invocations into a single batched call. Instead of running every message on its own, the worker buffers messages of the same task and executes the function once with the whole list. | ||
|
|
||
| ## Defining a batched task | ||
|
|
||
| Pass `batch=True` to the `task` decorator and declare the function with a single parameter that receives the list of items. | ||
|
|
||
| ```python | ||
| @broker.task(batch=True, batch_size=100, batch_timeout=3) | ||
| async def process_items(items: list[int]) -> int: | ||
| return sum(items) | ||
| ``` | ||
|
|
||
| Each `.kiq` call sends a single item, exactly like a normal task: | ||
|
|
||
| ```python | ||
| await process_items.kiq(1) | ||
| await process_items.kiq(2) | ||
| ``` | ||
|
|
||
| The worker accumulates these items and calls `process_items` once with the collected list (e.g. `[1, 2, ...]`). | ||
|
|
||
| ::: tip Typed by design | ||
|
|
||
| `.kiq` accepts a single element, while the function body receives `list[item]`. | ||
| Both sides are correctly typed: `process_items.kiq(1)` type-checks, but `process_items.kiq([1, 2])` is reported as a type error. | ||
|
|
||
| ::: | ||
|
|
||
| ## When a batch is flushed | ||
|
|
||
| A batch is sent for execution as soon as **either** condition is met: | ||
|
|
||
| - **`batch_size`** — the buffer reaches this number of items, or | ||
| - **`batch_timeout`** — this many seconds pass since the first item entered the buffer. | ||
|
|
||
| Whichever happens first wins. You must set at least one of the two; you can set both. The timer starts with the first item of a fresh buffer and resets after each flush. When a worker shuts down gracefully, any buffered items are flushed so nothing is lost. | ||
|
|
||
| Each worker buffers independently and keeps a separate buffer per task name. | ||
|
|
||
| ## Results and acknowledgement | ||
|
|
||
| A batch produces a single result. That same return value (or error) is stored for **every** task in the batch, so each `.kiq` call can still await its own result. If the batched function raises, every task in the batch receives that error. Every message is acknowledged according to the configured [acknowledgement type](./cli.md). | ||
|
|
||
| ::: caution Per-item granularity | ||
|
|
||
| Batching trades per-item isolation for throughput. The whole batch shares one result and one fate — there are no per-item results or per-item error handling. A batched task must take exactly one positional argument (the list); keyword arguments are not part of the batched call. | ||
|
|
||
| ::: | ||
|
|
||
| ## Trying it locally | ||
|
|
||
| Batching is a worker-side feature, but the `InMemoryBroker` supports it too, so you can try it without setting up a real broker. Call `wait_all` to flush any pending batches and wait for them to finish before reading results. | ||
|
|
||
| @[code python](../examples/batching/inmemory_batch.py) | ||
|
|
||
| Running this prints a single batch execution and the shared result: | ||
|
|
||
| ```bash:no-line-numbers | ||
| $ python broker.py | ||
| Processing a batch of 10 items. | ||
| Returned value: 45 | ||
| ... (10 times) | ||
| ``` | ||
|
|
||
| ::: warning InMemoryBroker behavior | ||
|
|
||
| The `InMemoryBroker` executes tasks inplace, so batches are flushed by `batch_size`, by `wait_all`, or — with `await_inplace=True` — immediately as one-item batches. This is convenient for tests, but to see real batching across processes you need a distributed broker and a worker. | ||
|
|
||
| ::: | ||
|
|
||
| ## Running with a worker | ||
|
|
||
| In production, batching happens inside the worker. Using [taskiq-redis](https://pypi.org/project/taskiq-redis/) as an example: | ||
|
|
||
| @[code python](../examples/batching/redis_batch.py) | ||
|
|
||
| Start one or more workers: | ||
|
|
||
| ```bash:no-line-numbers | ||
| taskiq worker broker:broker | ||
| ``` | ||
|
|
||
| Then run the script to send items. The worker collects them and runs `process_items` once per batch. With several workers, each one batches the messages it receives independently, so the load is spread across all of them. |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |||||||||||||||||||||||||||||||||||
| TYPE_CHECKING, | ||||||||||||||||||||||||||||||||||||
| Any, | ||||||||||||||||||||||||||||||||||||
| ClassVar, | ||||||||||||||||||||||||||||||||||||
| Literal, | ||||||||||||||||||||||||||||||||||||
| ParamSpec, | ||||||||||||||||||||||||||||||||||||
| TypeAlias, | ||||||||||||||||||||||||||||||||||||
| TypeVar, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -21,9 +22,9 @@ | |||||||||||||||||||||||||||||||||||
| from taskiq.abc.middleware import TaskiqMiddleware | ||||||||||||||||||||||||||||||||||||
| from taskiq.abc.serializer import TaskiqSerializer | ||||||||||||||||||||||||||||||||||||
| from taskiq.acks import AckableMessage | ||||||||||||||||||||||||||||||||||||
| from taskiq.decor import AsyncTaskiqDecoratedTask | ||||||||||||||||||||||||||||||||||||
| from taskiq.decor import AsyncBatchedTaskiqDecoratedTask, AsyncTaskiqDecoratedTask | ||||||||||||||||||||||||||||||||||||
| from taskiq.events import TaskiqEvents | ||||||||||||||||||||||||||||||||||||
| from taskiq.exceptions import TaskBrokerMismatchError | ||||||||||||||||||||||||||||||||||||
| from taskiq.exceptions import TaskBrokerMismatchError, TaskiqBatchConfigError | ||||||||||||||||||||||||||||||||||||
| from taskiq.formatters.proxy_formatter import ProxyFormatter | ||||||||||||||||||||||||||||||||||||
| from taskiq.message import BrokerMessage | ||||||||||||||||||||||||||||||||||||
| from taskiq.result_backends.dummy import DummyResultBackend | ||||||||||||||||||||||||||||||||||||
|
|
@@ -43,6 +44,7 @@ | |||||||||||||||||||||||||||||||||||
| from taskiq.abc.result_backend import AsyncResultBackend | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| _T = TypeVar("_T") | ||||||||||||||||||||||||||||||||||||
| _Item = TypeVar("_Item") | ||||||||||||||||||||||||||||||||||||
| _FuncParams = ParamSpec("_FuncParams") | ||||||||||||||||||||||||||||||||||||
| _ReturnType = TypeVar("_ReturnType") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
@@ -255,6 +257,40 @@ def listen(self) -> AsyncGenerator[bytes | AckableMessage, None]: | |||||||||||||||||||||||||||||||||||
| :return: nothing. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @staticmethod | ||||||||||||||||||||||||||||||||||||
| def _validate_batch_labels(labels: dict[str, Any]) -> None: | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| Validate batch related labels. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| :param labels: labels passed to the task decorator. | ||||||||||||||||||||||||||||||||||||
| :raises TaskiqBatchConfigError: if batch configuration is invalid. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| if not labels.get("batch"): | ||||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||||
| batch_size = labels.get("batch_size") | ||||||||||||||||||||||||||||||||||||
| batch_timeout = labels.get("batch_timeout") | ||||||||||||||||||||||||||||||||||||
| invalid = ( | ||||||||||||||||||||||||||||||||||||
| (batch_size is None and batch_timeout is None) | ||||||||||||||||||||||||||||||||||||
| or (batch_size is not None and batch_size < 1) | ||||||||||||||||||||||||||||||||||||
| or (batch_timeout is not None and batch_timeout <= 0) | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| if invalid: | ||||||||||||||||||||||||||||||||||||
| raise TaskiqBatchConfigError | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @overload | ||||||||||||||||||||||||||||||||||||
| def task( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
| *, | ||||||||||||||||||||||||||||||||||||
| batch: Literal[True], | ||||||||||||||||||||||||||||||||||||
| batch_size: int | None = None, | ||||||||||||||||||||||||||||||||||||
| batch_timeout: float | None = None, | ||||||||||||||||||||||||||||||||||||
| **labels: Any, | ||||||||||||||||||||||||||||||||||||
| ) -> Callable[ | ||||||||||||||||||||||||||||||||||||
| [Callable[[list[_Item]], Awaitable[_ReturnType]]], | ||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This definition ignores possible use of sync functions. Is it by design? |
||||||||||||||||||||||||||||||||||||
| AsyncBatchedTaskiqDecoratedTask[_Item, ..., _ReturnType], | ||||||||||||||||||||||||||||||||||||
| ]: # pragma: no cover | ||||||||||||||||||||||||||||||||||||
|
Comment on lines
+288
to
+291
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think that we need this decorator class. Since it's used only for type annotations, I see that there is a more straightforward way for deconstructing ParamSpec. You can leverage the power of
Suggested change
This way, we ensure that first argument of the function is actually a list of items by deconstructing its type signature. It's kinda like pattern matching on types. However, this implementation allows for more than one argument. Which is might be not as preferable. Because it would complicate the batcher's logic a bit. But might be worth looking into. If you want to support your original single item type, it can be easily done with the following signature:
Suggested change
|
||||||||||||||||||||||||||||||||||||
| ... | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| @overload | ||||||||||||||||||||||||||||||||||||
| def task( | ||||||||||||||||||||||||||||||||||||
| self, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -301,6 +337,7 @@ def task( # type: ignore[misc] | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| :returns: decorator function or AsyncTaskiqDecoratedTask. | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| self._validate_batch_labels(labels) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| def make_decorated_task( | ||||||||||||||||||||||||||||||||||||
| inner_labels: dict[str, str | int], | ||||||||||||||||||||||||||||||||||||
|
|
@@ -332,8 +369,11 @@ def inner( | |||||||||||||||||||||||||||||||||||
| if "return" in sign: | ||||||||||||||||||||||||||||||||||||
| return_type = sign["return"] | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| decorator_cls = self.decorator_class | ||||||||||||||||||||||||||||||||||||
| if inner_labels.get("batch"): | ||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need for custom decorator class as shown above. |
||||||||||||||||||||||||||||||||||||
| decorator_cls = AsyncBatchedTaskiqDecoratedTask | ||||||||||||||||||||||||||||||||||||
| decorated_task = wrapper( | ||||||||||||||||||||||||||||||||||||
| self.decorator_class( | ||||||||||||||||||||||||||||||||||||
| decorator_cls( | ||||||||||||||||||||||||||||||||||||
| broker=self, | ||||||||||||||||||||||||||||||||||||
| original_func=func, | ||||||||||||||||||||||||||||||||||||
| labels=inner_labels, | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| ParamSpec, | ||
| TypeVar, | ||
| Union, | ||
| cast, | ||
| overload, | ||
| ) | ||
|
|
||
|
|
@@ -23,6 +24,7 @@ | |
| from taskiq.scheduler.scheduled_task import CronSpec | ||
|
|
||
| _T = TypeVar("_T") | ||
| _Item = TypeVar("_Item") | ||
| _FuncParams = ParamSpec("_FuncParams") | ||
| _ReturnType = TypeVar("_ReturnType") | ||
|
|
||
|
|
@@ -232,3 +234,37 @@ def kicker(self) -> AsyncKicker[_FuncParams, _ReturnType]: | |
|
|
||
| def __repr__(self) -> str: | ||
| return f"AsyncTaskiqDecoratedTask({self.task_name})" | ||
|
|
||
|
|
||
| class AsyncBatchedTaskiqDecoratedTask( | ||
| AsyncTaskiqDecoratedTask[_FuncParams, _ReturnType], | ||
| Generic[_Item, _FuncParams, _ReturnType], | ||
| ): | ||
| """ | ||
| Task that is executed in batches. | ||
|
|
||
| The decorated function receives ``list[_Item]``, but each ``kiq`` call | ||
| sends a single ``_Item``. The worker accumulates these items and invokes | ||
| the function once with the collected list, flushing when ``batch_size`` is | ||
| reached or ``batch_timeout`` seconds elapse since the first buffered item. | ||
| """ | ||
|
|
||
| # The base `kiq` takes the function's full params (`list[_Item]`). | ||
| # A batched task is enqueued one element at a time, so we deliberately | ||
| # narrow the signature to a single `_Item`. mypy flags this as an | ||
| # incompatible override, which is intended here. | ||
| async def kiq( # type: ignore[override] | ||
| self, | ||
| item: _Item, | ||
| ) -> AsyncTaskiqTask[_ReturnType]: | ||
| """ | ||
| Send a single item to be processed as part of a batch. | ||
|
|
||
| :param item: one element that becomes a member of the batched list. | ||
| :returns: taskiq task for this individual message. | ||
| """ | ||
| kicker = cast( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would you need this cast? |
||
| "AsyncKicker[..., _ReturnType]", | ||
| self.kicker(), | ||
| ) | ||
| return await kicker.kiq(item) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You forgot
task_namehere.