diff --git a/confluent_server/confluent/tasks.py b/confluent_server/confluent/tasks.py index 06f1bf41..85a66c75 100644 --- a/confluent_server/confluent/tasks.py +++ b/confluent_server/confluent/tasks.py @@ -46,11 +46,13 @@ class TaskHolder: class TaskPile: def __init__(self, pool): self.pool = pool - self._tasks = set() + self._tasks = {} + self._taskholders = set() def spawn(self, coro_func, *args): task = self.pool.schedule(coro_func, *args) - self._tasks.add(task) + self._taskholders.add(task) + self._tasks[task._task] = task return task def __aiter__(self): @@ -59,10 +61,12 @@ class TaskPile: async def __anext__(self): if not self._tasks: raise StopAsyncIteration - done, _ = await asyncio.wait(self._tasks, return_when=asyncio.FIRST_COMPLETED) + done, _ = await asyncio.wait(self._tasks.keys(), return_when=asyncio.FIRST_COMPLETED) for task in done: - self._tasks.discard(task) - return task + taskhlder = self._tasks[task] + self._taskholders.discard(taskhlder) + self._tasks.pop(task, None) + return task.result() class TaskPool: def __init__(self, max_concurrent=128):