"""Unit tests for src/my_deepagent/middleware/fallback.py.""" from __future__ import annotations from typing import Any from unittest.mock import AsyncMock, MagicMock import httpx import openai import pytest from my_deepagent.middleware.fallback import FallbackModelMiddleware # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_request(has_model_attr: bool = True) -> MagicMock: request = MagicMock() if not has_model_attr: del request.model return request # --------------------------------------------------------------------------- # Fallback on RateLimitError # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_fallback_on_rate_limit_error_calls_handler_with_fallback() -> None: primary = MagicMock(name="primary-model") fallback = MagicMock(name="fallback-model") mw = FallbackModelMiddleware(primary=primary, fallback=fallback) call_count = 0 fallback_model_seen: Any = None async def handler(request: Any) -> str: nonlocal call_count, fallback_model_seen call_count += 1 if call_count == 1: raise openai.RateLimitError( "rate limit", response=MagicMock(status_code=429, headers={}), body={}, ) fallback_model_seen = getattr(request, "model", None) return "fallback-response" request = _make_request() result = await mw.awrap_model_call(request, handler) assert result == "fallback-response" assert call_count == 2 assert fallback_model_seen is fallback @pytest.mark.asyncio async def test_fallback_on_api_connection_error() -> None: primary = MagicMock() fallback = MagicMock() mw = FallbackModelMiddleware(primary=primary, fallback=fallback) call_count = 0 async def handler(request: Any) -> str: nonlocal call_count call_count += 1 if call_count == 1: raise openai.APIConnectionError(request=MagicMock()) return "connection-fallback" result = await mw.awrap_model_call(_make_request(), handler) assert 
result == "connection-fallback" assert call_count == 2 @pytest.mark.asyncio async def test_fallback_on_httpx_error() -> None: primary = MagicMock() fallback = MagicMock() mw = FallbackModelMiddleware(primary=primary, fallback=fallback) call_count = 0 async def handler(request: Any) -> str: nonlocal call_count call_count += 1 if call_count == 1: raise httpx.ConnectError("connect failed") return "httpx-fallback" result = await mw.awrap_model_call(_make_request(), handler) assert result == "httpx-fallback" assert call_count == 2 # --------------------------------------------------------------------------- # No fallback — exception propagates # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_no_fallback_raises_original_error() -> None: mw = FallbackModelMiddleware(primary=MagicMock(), fallback=None) handler = AsyncMock( side_effect=openai.RateLimitError( "rate limit", response=MagicMock(status_code=429, headers={}), body={}, ) ) with pytest.raises(openai.RateLimitError): await mw.awrap_model_call(_make_request(), handler) # --------------------------------------------------------------------------- # AuthenticationError — never retried # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_auth_error_is_not_retried() -> None: primary = MagicMock() fallback = MagicMock() mw = FallbackModelMiddleware(primary=primary, fallback=fallback) call_count = 0 async def handler(request: Any) -> str: nonlocal call_count call_count += 1 raise openai.AuthenticationError( "bad api key", response=MagicMock(status_code=401, headers={}), body={}, ) with pytest.raises(openai.AuthenticationError): await mw.awrap_model_call(_make_request(), handler) # Handler should only be called once (no retry for auth errors) assert call_count == 1 # --------------------------------------------------------------------------- # _with_fallback_model # 
# ---------------------------------------------------------------------------


def test_with_fallback_model_swaps_model_attribute() -> None:
    """``_with_fallback_model`` yields a request whose ``model`` is the fallback model."""
    primary_model = MagicMock(name="primary")
    fallback_model = MagicMock(name="fallback")
    middleware = FallbackModelMiddleware(primary=primary_model, fallback=fallback_model)

    original_request = MagicMock()
    original_request.model = primary_model

    swapped = middleware._with_fallback_model(original_request)
    assert swapped.model is fallback_model


def test_with_fallback_model_no_model_attr_does_not_crash() -> None:
    """A request without a ``model`` attribute is accepted, and the same object is returned."""
    middleware = FallbackModelMiddleware(primary=MagicMock(), fallback=MagicMock())
    # spec=[] builds a mock with no auto-created attributes, so there is no ``model``.
    bare_request = MagicMock(spec=[])

    # Must not raise.
    returned = middleware._with_fallback_model(bare_request)
    assert returned is bare_request