"""Source of the ``fastapi_injectable.decorator`` module: the ``@injectable`` decorator and its helpers."""

import inspect
import types
from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Generator
from contextlib import aclosing
from functools import wraps
from typing import TYPE_CHECKING, Annotated, Any, ParamSpec, TypeVar, Union, cast, get_args, get_origin, overload

import fastapi
import fastapi.params

from .concurrency import run_coroutine_sync
from .main import resolve_dependencies

T = TypeVar("T")
P = ParamSpec("P")

if TYPE_CHECKING:

    def set_original_func(wrapper: Any, target: Any) -> None:  # noqa: ANN401
        pass
else:

    def set_original_func(wrapper: Any, target: Any) -> None:  # noqa: ANN401
        wrapper.__original_func__ = target


def _needs_injected_default(param: inspect.Parameter) -> bool:  # pragma: no cover
    """Return True if *param* has no default and is ``Annotated[...]`` with a ``Depends`` marker."""
    if param.default is not inspect.Parameter.empty or get_origin(param.annotation) is not Annotated:
        return False
    # isinstance rather than an exact type check, so subclasses of Depends
    # (e.g. ``fastapi.params.Security``) are treated as dependencies too.
    return any(isinstance(metadata, fastapi.params.Depends) for metadata in param.annotation.__metadata__)


def _injected_default_for(param: inspect.Parameter) -> Any:  # noqa: ANN401  # pragma: no cover
    """Build a placeholder default that is an instance of *param*'s annotated class.

    The class is the first argument of the ``Annotated[...]`` annotation. For a
    ``Union``/``X | Y`` it is the first non-``None`` member, and a generic alias
    is reduced to its origin. The placeholder is an instance of a throwaway
    subclass ``Injected_<name>`` whose ``__init__`` and abstract methods are no-ops.

    Args:
        param: A parameter whose annotation is ``Annotated[...]``.

    Returns:
        The placeholder instance.

    Raises:
        TypeError: If the class cannot be subclassed or instantiated without
            arguments (e.g. ``bool``, ``Literal[...]``, ``datetime``).
    """
    actual_type = get_args(param.annotation)[0]
    origin = get_origin(actual_type)

    base_for_class = actual_type
    if origin is Union or origin is types.UnionType:
        union_args = get_args(actual_type)
        base_for_class = next(
            (t for t in union_args if t is not type(None)),
            union_args[0] if union_args else object,
        )

    base_class = get_origin(base_for_class) or base_for_class

    def fill_namespace(ns: dict[str, Any]) -> None:
        # A no-op __init__ lets the placeholder be created without arguments.
        ns["__init__"] = lambda self, *args, **kwargs: None  # noqa: ARG005
        # Stub every abstract method so ABC-based classes become instantiable.
        for method in getattr(base_class, "__abstractmethods__", []):
            ns[method] = lambda *args, **kwargs: None  # noqa: ARG005

    injected_class = types.new_class(
        "Injected_" + getattr(base_class, "__name__", "Injected"),
        (base_class,),
        {},
        fill_namespace,
    )
    return injected_class()


def _override_func_dependency_signature(func: Callable[P, T] | Callable[P, Awaitable[T]]) -> None:  # pragma: no cover
    """Override the function signature to make dependency-injected parameters optional.

    Every parameter annotated as ``Annotated[X, Depends(...)]`` without a default
    gets a placeholder default (an instance of a stub subclass of ``X``). The new
    signature is stored on ``func.__signature__``, which mutates *func* in place.

    A parameter is left required when either of these holds:

    * its annotated type cannot back a placeholder (see ``_injected_default_for``);
    * it is positional and a later positional parameter is still required.
      Defaulting it would make ``inspect.Signature`` raise ``ValueError``.

    Args:
        func: The function whose signature is rewritten.
    """
    signature = inspect.signature(func)
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)

    new_parameters: list[inspect.Parameter] = []
    # Scan right-to-left so we know whether any positional parameter after the
    # current one is still required ("non-default argument follows default").
    later_positional_required = False
    for param in reversed(list(signature.parameters.values())):
        parameter = param
        is_positional = param.kind in positional_kinds
        if not (is_positional and later_positional_required) and _needs_injected_default(param):
            try:
                injected_default = _injected_default_for(param)
            except TypeError:
                # The annotated type cannot back a placeholder; keep the
                # parameter required instead of failing the whole decoration.
                injected_default = inspect.Parameter.empty
            parameter = param.replace(default=injected_default)
        if is_positional and parameter.default is inspect.Parameter.empty:
            later_positional_required = True
        new_parameters.append(parameter)
    new_parameters.reverse()

    func.__signature__ = signature.replace(parameters=new_parameters)  # type: ignore[union-attr]


@overload
def injectable(
    func: Callable[P, Generator[T, Any, Any]],
    *,
    use_cache: bool = True,
) -> Callable[P, Generator[T, Any, Any]]: ...


@overload
def injectable(
    func: Callable[P, AsyncGenerator[T, Any]],
    *,
    use_cache: bool = True,
) -> Callable[P, AsyncGenerator[T, Any]]: ...


@overload
def injectable(
    func: Callable[P, T],
    *,
    use_cache: bool = True,
) -> Callable[P, T]: ...


@overload
def injectable(
    *,
    use_cache: bool = True,
) -> Callable[[Callable[P, T]], Callable[P, T]]: ...


# The overloads above provide the actual type hints for callers;
# ``Callable[..., Any]`` below is just a placeholder.
def injectable(
    func: Callable[..., Any] | None = None,
    *,
    use_cache: bool = True,
) -> Callable[..., Any]:
    """Decorator to inject dependencies into any callable, sync or async.

    Can be applied bare (``@injectable``) or called with options
    (``@injectable(use_cache=False)``).

    **Per call.** Each call of the decorated function works like this:

    * Dependencies are resolved with ``resolve_dependencies`` and passed as
      keyword arguments.
    * Keyword arguments supplied by the caller are forwarded as
      ``provided_kwargs`` and take precedence over the resolved values.

    **Wrapper selection.** The wrapper depends on the kind of target:

    * Async generator functions and coroutine functions get async wrappers.
    * Everything else, including plain generator functions, gets a sync
      wrapper that resolves dependencies through ``run_coroutine_sync``.

    **Side effects.**

    * The target's own signature is rewritten to make dependency-injected
      parameters optional.
    * The returned wrapper exposes the target as ``__original_func__``.

    Args:
        func: The callable to decorate, or ``None`` when the decorator is
            called with keyword options only.
        use_cache: Forwarded to ``resolve_dependencies``.

    Returns:
        The wrapped callable, or, when ``func`` is ``None``, a decorator that
        produces it.
    """

    def decorator(
        target: Callable[P, T] | Callable[P, Awaitable[T]] | Callable[P, AsyncGenerator[T, Any]],
    ) -> Callable[P, T] | Callable[P, Awaitable[T]] | Callable[P, AsyncGenerator[T, Any]]:
        # Override the function signature to make dependency-injected parameters
        # optional for packages like typer, cyclopts, etc.
        _override_func_dependency_signature(target)

        # Note: these are mutually exclusive.
        is_async = inspect.iscoroutinefunction(target)
        is_async_generator = inspect.isasyncgenfunction(target)

        @wraps(target)
        async def async_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            dependencies = await resolve_dependencies(func=target, use_cache=use_cache, provided_kwargs=kwargs)
            return await cast("Callable[..., Coroutine[Any, Any, T]]", target)(*args, **{**dependencies, **kwargs})

        @wraps(target)
        async def async_gen_wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[T, Any]:
            dependencies = await resolve_dependencies(func=target, use_cache=use_cache, provided_kwargs=kwargs)
            inner = cast("Callable[..., AsyncGenerator[T, Any]]", target)(*args, **{**dependencies, **kwargs})
            # aclosing() closes the inner generator as soon as this wrapper is
            # closed (e.g. the consumer stops iterating early), so its cleanup
            # runs deterministically instead of being left to finalization.
            async with aclosing(inner):
                async for item in inner:
                    yield item

        @wraps(target)
        def sync_wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
            dependencies = run_coroutine_sync(
                resolve_dependencies(func=target, use_cache=use_cache, provided_kwargs=kwargs)
            )
            return cast("Callable[..., T]", target)(*args, **{**dependencies, **kwargs})

        if is_async_generator:
            set_original_func(async_gen_wrapper, target)
            return async_gen_wrapper
        if is_async:
            set_original_func(async_wrapper, target)
            return async_wrapper
        set_original_func(sync_wrapper, target)
        return sync_wrapper

    if func is None:
        return decorator
    # decorator() already records ``func`` as ``__original_func__`` on the wrapper.
    return decorator(func)