refactor(api): refactor refresh logic

This commit is contained in:
Vladislav Syrochkin 2025-06-16 12:46:14 +05:00
parent c87581c9e2
commit ee92428ec3
2 changed files with 16 additions and 24 deletions

View File

@ -4,9 +4,9 @@ from fastapi import (
APIRouter, APIRouter,
Depends, Depends,
HTTPException, HTTPException,
Request,
Response, Response,
status, status,
Request,
) )
from loguru import logger from loguru import logger
@ -22,7 +22,7 @@ from api.services.auth import authenticate_user
from api.db.logic.auth import add_new_refresh_token, upgrade_old_refresh_token from api.db.logic.auth import add_new_refresh_token, upgrade_old_refresh_token
from api.schemas.endpoints.auth import Auth, Access from api.schemas.endpoints.auth import Auth, Tokens
api_router = APIRouter( api_router = APIRouter(
prefix="/auth", prefix="/auth",
@ -33,7 +33,7 @@ api_router = APIRouter(
class Settings(BaseModel): class Settings(BaseModel):
authjwt_secret_key: str = get_settings().SECRET_KEY authjwt_secret_key: str = get_settings().SECRET_KEY
# Configure application to store and get JWT from cookies # Configure application to store and get JWT from cookies
authjwt_token_location: set = {"headers", "cookies"} authjwt_token_location: set = {"headers"}
authjwt_cookie_domain: str = get_settings().DOMAIN authjwt_cookie_domain: str = get_settings().DOMAIN
# Only allow JWT cookies to be sent over https # Only allow JWT cookies to be sent over https
@ -48,7 +48,7 @@ def get_config():
return Settings() return Settings()
@api_router.post("", response_model=Access) @api_router.post("", response_model=Tokens)
async def login_for_access_token( async def login_for_access_token(
user: Auth, user: Auth,
response: Response, response: Response,
@ -69,7 +69,6 @@ async def login_for_access_token(
) )
access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
refresh_token_expires = timedelta(days=get_settings().REFRESH_TOKEN_EXPIRE_DAYS) refresh_token_expires = timedelta(days=get_settings().REFRESH_TOKEN_EXPIRE_DAYS)
logger.debug(f"refresh_token_expires {refresh_token_expires}") logger.debug(f"refresh_token_expires {refresh_token_expires}")
@ -81,24 +80,20 @@ async def login_for_access_token(
await add_new_refresh_token(connection, refresh_token, refresh_token_expires_time, user) await add_new_refresh_token(connection, refresh_token, refresh_token_expires_time, user)
Authorize.set_refresh_cookies(refresh_token) return Tokens(access_token=access_token, refresh_token=refresh_token)
return Access(access_token=access_token)
@api_router.post("/refresh", response_model=Access) @api_router.post("/refresh", response_model=Tokens)
async def refresh( async def refresh(
request: Request, connection: AsyncConnection = Depends(get_connection_dep), Authorize: AuthJWT = Depends() request: Request,
): connection: AsyncConnection = Depends(get_connection_dep),
refresh_token = request.cookies.get("refresh_token_cookie") Authorize: AuthJWT = Depends(),
) -> Tokens:
if not refresh_token:
raise HTTPException(status_code=401, detail="Refresh token is missing")
try: try:
Authorize.jwt_refresh_token_required(refresh_token) Authorize.jwt_refresh_token_required()
current_user = Authorize._verified_token(refresh_token).get("sub") current_user = Authorize.get_jwt_subject()
except Exception: except Exception:
refresh_token = request.headers.get("Authorization").split(" ")[1]
await upgrade_old_refresh_token(connection, refresh_token) await upgrade_old_refresh_token(connection, refresh_token)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -108,4 +103,4 @@ async def refresh(
access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES) access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
new_access_token = Authorize.create_access_token(subject=current_user, expires_time=access_token_expires) new_access_token = Authorize.create_access_token(subject=current_user, expires_time=access_token_expires)
return Access(access_token=new_access_token) return Tokens(access_token=new_access_token)

View File

@ -8,9 +8,6 @@ class Auth(Base):
password: str password: str
class Refresh(Base): class Tokens(Base):
refresh_token: str
class Access(Base):
access_token: str access_token: str
refresh_token: str | None = None