Source code for queuebridge.hints

"""Type hints utilities for decoding task arguments and return values."""

from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Any, Callable, get_type_hints

from queuebridge.codec import decode

# Parameter names whose values are never decoded: method receivers
# (``self``/``cls``) and the worker context argument (``ctx``, Arq/Celery).
_SKIP_PARAMS = frozenset({"self", "cls", "ctx"})

# Memoized signatures keyed by the callable object itself rather than by
# ``id(fn)``: ids are recycled once an object is freed, so an id-keyed cache
# can hand a stale signature to an unrelated callable (e.g. two short-lived
# bound methods allocated at the same address). Entries hold strong
# references, so cached callables stay alive for the life of the process.
_signature_cache: dict[Callable[..., Any], TaskSignature] = {}


@dataclass(frozen=True)
class TaskSignature:
    """Cached type hints for a task function.

    Attributes:
        params: Mapping of parameter name to annotation. Parameters named
            ``self``, ``cls`` or ``ctx`` and parameters without an annotation
            are omitted.
        return_type: Return annotation, or ``Any`` if missing.
    """

    params: dict[str, Any]
    return_type: Any


def _build_task_signature(fn: Callable[..., Any]) -> TaskSignature:
    """Compute (without consulting or filling the cache) ``fn``'s hints."""
    hints = get_type_hints(fn, include_extras=True)
    params = {
        name: hints[name]
        for name in inspect.signature(fn).parameters
        if name not in _SKIP_PARAMS and name in hints
    }
    return TaskSignature(params=params, return_type=hints.get("return", Any))


def get_task_signature(fn: Callable[..., Any]) -> TaskSignature:
    """Extract and cache parameter/return type hints from a callable.

    Results are memoized with ``fn`` itself as the dict key, so repeated
    calls with the same callable return the same :class:`TaskSignature`
    object. Bound methods compare equal when they wrap the same function on
    the same instance, so re-accessing ``obj.method`` still hits the cache.
    Unhashable callables are supported but not cached.

    Args:
        fn: Task function (sync, async, or decorated).

    Returns:
        :class:`TaskSignature` with ``params`` and ``return_type``.

    Raises:
        Any exception raised by :func:`typing.get_type_hints` or
        :func:`inspect.signature` for ``fn`` is propagated unchanged.
    """
    try:
        cached = _signature_cache.get(fn)
    except TypeError:
        # Unhashable callable: it cannot be a dict key, so bypass the cache.
        return _build_task_signature(fn)
    if cached is None:
        cached = _build_task_signature(fn)
        _signature_cache[fn] = cached
    return cached
def decode_args(
    fn: Callable[..., Any],
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
    """Decode wire ``args`` and ``kwargs`` using ``fn``'s type hints.

    Positional values are paired with parameters the way Python binds a
    call:
    - Positional-only and positional-or-keyword parameters are filled first.
    - Any remaining values go to ``*args`` when ``fn`` declares it, and each
      is decoded with that parameter's per-element annotation.

    Keyword values are matched by name. A keyword that is not a named
    keyword-capable parameter is decoded with the ``**kwargs`` annotation
    when ``fn`` declares one.

    Skips decoding for ``self``, ``cls``, and ``ctx`` parameters (Arq/Celery),
    whether they arrive positionally or by keyword. Positional values that no
    parameter can receive are returned unchanged.

    Args:
        fn: Task function whose annotations guide decoding.
        args: Positional wire values.
        kwargs: Keyword wire values.

    Returns:
        Tuple of ``(decoded_args, decoded_kwargs)``.
    """
    hints = get_task_signature(fn).params

    # Classify parameters by how a call can bind values to them.
    positional_names: list[str] = []
    keyword_names: set[str] = set()
    var_positional: str | None = None
    var_keyword: str | None = None
    for param in inspect.signature(fn).parameters.values():
        kind = param.kind
        if kind == inspect.Parameter.POSITIONAL_ONLY:
            positional_names.append(param.name)
        elif kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
            positional_names.append(param.name)
            keyword_names.add(param.name)
        elif kind == inspect.Parameter.KEYWORD_ONLY:
            keyword_names.add(param.name)
        elif kind == inspect.Parameter.VAR_POSITIONAL:
            var_positional = param.name
        elif kind == inspect.Parameter.VAR_KEYWORD:
            var_keyword = param.name

    decoded_args: list[Any] = []
    for index, value in enumerate(args):
        if index < len(positional_names):
            name: str | None = positional_names[index]
        else:
            # Values past the named slots are collected by ``*args`` (if any);
            # keyword-only parameters can never receive a positional value.
            name = var_positional
        if name is None or name in _SKIP_PARAMS:
            decoded_args.append(value)
        else:
            decoded_args.append(decode(value, hints.get(name, Any)))

    decoded_kwargs: dict[str, Any] = {}
    for key, value in kwargs.items():
        if key in _SKIP_PARAMS:
            decoded_kwargs[key] = value
            continue
        # Keywords that match no named parameter are collected by ``**kwargs``.
        if key in keyword_names or var_keyword is None:
            target = key
        else:
            target = var_keyword
        decoded_kwargs[key] = decode(value, hints.get(target, Any))

    return tuple(decoded_args), decoded_kwargs
def decode_return(fn: Callable[..., Any], result: Any) -> Any:
    """Turn a wire-format task result back into a Python object.

    The target type is the return annotation that :func:`get_task_signature`
    records for ``fn`` (``Any`` when ``fn`` has none).

    Args:
        fn: Task function.
        result: Wire return value from the result backend.

    Returns:
        Decoded Python object.
    """
    target_type = get_task_signature(fn).return_type
    return decode(result, target_type)