Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 62 additions & 40 deletions taskiq/receiver/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ async def listen(self, finish_event: asyncio.Event) -> None: # pragma: no cover
if self.on_exit is not None:
self.on_exit(self)

async def prefetcher(
async def prefetcher( # noqa: C901
self,
queue: "asyncio.Queue[bytes | AckableMessage]",
finish_event: asyncio.Event,
Expand All @@ -396,48 +396,70 @@ async def prefetcher(
"""
fetched_tasks: int = 0
iterator = self.broker.listen()
current_message: asyncio.Task[bytes | AckableMessage] = asyncio.create_task(
iterator.__anext__(), # type: ignore
)
current_message: asyncio.Task[bytes | AckableMessage] | None = None

while True:
if finish_event.is_set():
break
try:
await self.sem_prefetch.acquire()
if (
self.max_tasks_to_execute
and fetched_tasks >= self.max_tasks_to_execute
):
logger.info("Max number of tasks executed.")
try:
while not finish_event.is_set():
try:
await self.sem_prefetch.acquire()
if (
self.max_tasks_to_execute
and fetched_tasks >= self.max_tasks_to_execute
):
logger.info("Max number of tasks executed.")
break
if current_message is None:
current_message = asyncio.create_task(
iterator.__anext__(), # type: ignore
)
# Here we wait for the message to be fetched,
# but we make it with timeout so it can be interrupted
done, _ = await asyncio.wait({current_message}, timeout=0.3)
# If the message is not fetched, we release the semaphore
# and continue the loop. So it will check if finished event was set.
if not done:
self.sem_prefetch.release()
continue
# We're done, so now we need to check
# whether task has returned an error.
message = current_message.result()
current_message = None
fetched_tasks += 1
await queue.put(message)
# Custom hooks for OTel and any future instrumentations
for middleware in reversed(self.broker.middlewares):
if hasattr(middleware, "on_prefetch_queue_add"):
await maybe_awaitable(
middleware.on_prefetch_queue_add(), # type: ignore
)
except (asyncio.CancelledError, StopAsyncIteration):
break
# Here we wait for the message to be fetched,
# but we make it with timeout so it can be interrupted
done, _ = await asyncio.wait({current_message}, timeout=0.3)
# If the message is not fetched, we release the semaphore
# and continue the loop. So it will check if finished event was set.
if not done:
self.sem_prefetch.release()
except Exception:
logger.exception("Error while prefetching.")
# current_message set => fetch failed before enqueue, so we
# still own the permit and a (possibly broken) iterator.
# Otherwise it's queued and the runner owns the permit;
# releasing here would leak a prefetch slot.
if current_message is not None:
current_message = None
iterator = self.broker.listen()
self.sem_prefetch.release()
continue
# We're done, so now we need to check
# whether task has returned an error.
message = current_message.result()
current_message = asyncio.create_task(iterator.__anext__()) # type: ignore
fetched_tasks += 1
await queue.put(message)
# Custom hooks for OTel and any future instrumentations
for middleware in reversed(self.broker.middlewares):
if hasattr(middleware, "on_prefetch_queue_add"):
await maybe_awaitable(
middleware.on_prefetch_queue_add(), # type: ignore
)
except (asyncio.CancelledError, StopAsyncIteration):
break
# We don't want to fetch new messages if we are shutting down.
logger.info("Stopping prefetching messages...")
current_message.cancel()
await queue.put(QUEUE_DONE)
self.sem_prefetch.release()
finally:
# We don't want to fetch new messages if we are shutting down.
logger.info("Stopping prefetching messages...")
# Short window to deliver, then forward or cancel.
if current_message is not None:
await asyncio.wait({current_message}, timeout=0.3)
if not current_message.done():
current_message.cancel()
elif (
not current_message.cancelled()
and current_message.exception() is None
):
await queue.put(current_message.result())
await queue.put(QUEUE_DONE)
self.sem_prefetch.release()

async def runner(
self,
Expand Down
67 changes: 66 additions & 1 deletion tests/receiver/test_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import random
import time
import unittest.mock
from collections.abc import Generator
from collections.abc import AsyncGenerator, Generator
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any, ClassVar
Expand Down Expand Up @@ -600,3 +600,68 @@ async def test_no_semaphore_without_max_async_tasks() -> None:
"""Test that semaphore is None when max_async_tasks is not set."""
receiver = get_receiver(max_async_tasks=None)
assert receiver.sem is None


async def test_prefetcher_does_not_pop_message_past_max_tasks() -> None:
"""Test not pulling a message without the intention of running it."""
broker = AsyncQueueBroker()

@broker.task
async def noop() -> None:
return None

for _ in range(6):
await noop.kiq()

assert broker.queue.qsize() == 6

receiver = Receiver(
broker,
executor=ThreadPoolExecutor(max_workers=1),
max_async_tasks=1,
max_tasks_to_execute=5,
)

await receiver.listen(asyncio.Event())

assert broker.queue.qsize() == 1


async def test_prefetcher_recovers_from_transient_listen_error() -> None:
"""A transient error mid-prefetch must not kill the prefetcher."""

class FlakyBroker(AsyncQueueBroker):
def __init__(self) -> None:
super().__init__()
self.fail_once = True

async def listen(self) -> AsyncGenerator[AckableMessage, None]:
while True:
data = await self.queue.get()
if self.fail_once:
self.fail_once = False
self.queue.task_done()
raise RuntimeError("transient broker hiccup")
yield AckableMessage(data=data, ack=self.queue.task_done)

broker = FlakyBroker()
ran = 0

@broker.task
async def collector() -> None:
nonlocal ran
ran += 1

await collector.kiq() # consumed by the transient error
await collector.kiq() # prefetcher recovering

receiver = Receiver(
broker,
executor=ThreadPoolExecutor(max_workers=1),
max_async_tasks=1,
max_tasks_to_execute=1,
)

await asyncio.wait_for(receiver.listen(asyncio.Event()), timeout=5)

assert ran == 1