Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 64 additions & 22 deletions tests/integration/_utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

import asyncio
import inspect
import secrets
import string
import time
from collections.abc import AsyncIterator, Iterator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, overload
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast, overload

import pytest

if TYPE_CHECKING:
from collections.abc import Callable, Coroutine
from collections.abc import Awaitable, Callable

# Environment variable names for test configuration
TOKEN_ENV_VAR = 'APIFY_TEST_USER_API_TOKEN'
Expand Down Expand Up @@ -92,22 +93,14 @@ def get_random_resource_name(label: str) -> str:
return name_template.format(label, get_crypto_random_object_id(random_id_length))


@overload
async def maybe_await(value: Coroutine[Any, Any, T]) -> T: ...


@overload
async def maybe_await(value: T) -> T: ...


async def maybe_await(value: T | Coroutine[Any, Any, T]) -> T:
"""Await coroutines, pass through other values.
async def maybe_await(value: Awaitable[T] | T) -> T:
"""Await `value` if it is awaitable, otherwise return it unchanged.

Enables unified test code for both sync and async clients:
result = await maybe_await(client.datasets().list())
"""
if hasattr(value, '__await__'):
return await value # ty: ignore[invalid-await]
if inspect.isawaitable(value):
return await cast('Awaitable[T]', value)
return value


Expand All @@ -119,6 +112,49 @@ async def maybe_sleep(seconds: float, *, is_async: bool) -> None:
time.sleep(seconds) # noqa: ASYNC251


@overload
async def poll_until_condition(
fn: Callable[[], Awaitable[T]],
condition: Callable[[T], bool] = ...,
*,
timeout: float = ...,
poll_interval: float = ...,
) -> T: ...
@overload
async def poll_until_condition(
fn: Callable[[], T],
condition: Callable[[T], bool] = ...,
*,
timeout: float = ...,
poll_interval: float = ...,
) -> T: ...
async def poll_until_condition(
fn: Callable[[], Awaitable[T] | T],
condition: Callable[[T], bool] = bool,
*,
timeout: float = 5,
poll_interval: float = 1,
) -> T:
"""Poll `fn` until `condition(result)` is True or the timeout expires.

Polls `fn` at `poll_interval`-second intervals until `condition` is satisfied or `timeout` seconds have elapsed.
Returns the last polled result regardless of whether the condition was met, so the caller can run its own
assertion. The default condition checks for a truthy result.

Use this instead of a fixed `asyncio.sleep` when waiting for eventually-consistent state (e.g. a freshly
created resource appearing in a listing) that may take a variable amount of time to propagate.
"""
deadline = time.monotonic() + timeout
result = await maybe_await(fn())
while not condition(result):
remaining = deadline - time.monotonic()
if remaining <= 0:
break
await asyncio.sleep(min(poll_interval, remaining))
result = await maybe_await(fn())
return result


async def collect_iterate_until_present(
iterator_factory: Callable[[], Iterator[_HasIdT] | AsyncIterator[_HasIdT]],
expected_ids: set[str],
Expand All @@ -132,7 +168,7 @@ async def collect_iterate_until_present(

Handles eventual consistency on listing endpoints: under parallel load a freshly
created resource may not appear in the listing for a short window. Each attempt
builds a fresh iterator via `iterator_factory`, drains it, and breaks early once
builds a fresh iterator via `iterator_factory`, drains it, and stops early once
`expected_ids` is a subset of the collected items' `.id` values. The most recent
collection is returned regardless of whether the condition was met, so the caller
can run its own assertion with a helpful failure message.
Expand All @@ -142,18 +178,16 @@ async def collect_iterate_until_present(
expected_ids: IDs that must all appear in the collected items.
item_type: Asserted to match the runtime type of each yielded item.
is_async: Whether the iterator is async (and so are sleeps).
max_attempts: Maximum number of polling rounds.
interval: Seconds to sleep before each attempt.
max_attempts: Maximum number of polling rounds, guaranteed regardless of how long each drain takes.
interval: Seconds to sleep between attempts.

Returns:
The most recently collected items.
"""
collected: list[_HasIdT] = []
for attempt in range(max_attempts):
if attempt > 0:
await maybe_sleep(interval, is_async=is_async)

async def drain() -> list[_HasIdT]:
iterator = iterator_factory()
collected = []
collected: list[_HasIdT] = []
if is_async:
assert isinstance(iterator, AsyncIterator)
async for item in iterator:
Expand All @@ -164,8 +198,16 @@ async def collect_iterate_until_present(
for item in iterator:
assert isinstance(item, item_type)
collected.append(item)
return collected

# Loop on attempt count rather than a wall-clock deadline: drains take HTTP time, and charging it
# against a deadline would mean fewer retries under load — exactly when they are needed most.
collected = await drain()
for _ in range(max_attempts - 1):
if expected_ids.issubset(item.id for item in collected):
break
await maybe_sleep(interval, is_async=is_async)
collected = await drain()
return collected


Expand Down
12 changes: 12 additions & 0 deletions tests/integration/test_request_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_random_string,
maybe_await,
maybe_sleep,
poll_until_condition,
)
from apify_client._models import (
BatchAddResult,
Expand Down Expand Up @@ -560,6 +561,17 @@ async def test_request_queue_unlock_requests(client: ApifyClient | ApifyClientAs
assert isinstance(result, LockedRequestQueueHead)
lock_response = result
assert len(lock_response.items) == 3
locked_ids = {item.id for item in lock_response.items}

# Locks are acknowledged before they are visible to subsequent reads, so unlocking immediately can
# see fewer locks than were just acquired. Since locked requests are excluded from the queue head,
# poll `list_head` until the locked IDs disappear from it (best-effort mitigation of the race).
async def all_locks_visible() -> bool:
head = await maybe_await(rq_client.list_head(limit=5))
assert isinstance(head, RequestQueueHead)
return locked_ids.isdisjoint(item.id for item in head.items)

await poll_until_condition(all_locks_visible, timeout=30, poll_interval=1)

# Unlock all requests
unlock_response = await maybe_await(rq_client.unlock_requests())
Expand Down
Loading