refactor(api): refactor refresh logic
This commit is contained in:
		@@ -4,9 +4,9 @@ from fastapi import (
 | 
			
		||||
    APIRouter,
 | 
			
		||||
    Depends,
 | 
			
		||||
    HTTPException,
 | 
			
		||||
    Request,
 | 
			
		||||
    Response,
 | 
			
		||||
    status,
 | 
			
		||||
    Request,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from loguru import logger
 | 
			
		||||
@@ -22,7 +22,7 @@ from api.services.auth import authenticate_user
 | 
			
		||||
 | 
			
		||||
from api.db.logic.auth import add_new_refresh_token, upgrade_old_refresh_token
 | 
			
		||||
 | 
			
		||||
from api.schemas.endpoints.auth import Auth, Access
 | 
			
		||||
from api.schemas.endpoints.auth import Auth, Tokens
 | 
			
		||||
 | 
			
		||||
api_router = APIRouter(
 | 
			
		||||
    prefix="/auth",
 | 
			
		||||
@@ -33,7 +33,7 @@ api_router = APIRouter(
 | 
			
		||||
class Settings(BaseModel):
 | 
			
		||||
    authjwt_secret_key: str = get_settings().SECRET_KEY
 | 
			
		||||
    # Configure application to store and get JWT from cookies
 | 
			
		||||
    authjwt_token_location: set = {"headers", "cookies"}
 | 
			
		||||
    authjwt_token_location: set = {"headers"}
 | 
			
		||||
    authjwt_cookie_domain: str = get_settings().DOMAIN
 | 
			
		||||
 | 
			
		||||
    # Only allow JWT cookies to be sent over https
 | 
			
		||||
@@ -48,7 +48,7 @@ def get_config():
 | 
			
		||||
    return Settings()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@api_router.post("", response_model=Access)
 | 
			
		||||
@api_router.post("", response_model=Tokens)
 | 
			
		||||
async def login_for_access_token(
 | 
			
		||||
    user: Auth,
 | 
			
		||||
    response: Response,
 | 
			
		||||
@@ -69,7 +69,6 @@ async def login_for_access_token(
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
			
		||||
 | 
			
		||||
    refresh_token_expires = timedelta(days=get_settings().REFRESH_TOKEN_EXPIRE_DAYS)
 | 
			
		||||
 | 
			
		||||
    logger.debug(f"refresh_token_expires {refresh_token_expires}")
 | 
			
		||||
@@ -81,24 +80,20 @@ async def login_for_access_token(
 | 
			
		||||
 | 
			
		||||
    await add_new_refresh_token(connection, refresh_token, refresh_token_expires_time, user)
 | 
			
		||||
 | 
			
		||||
    Authorize.set_refresh_cookies(refresh_token)
 | 
			
		||||
 | 
			
		||||
    return Access(access_token=access_token)
 | 
			
		||||
    return Tokens(access_token=access_token, refresh_token=refresh_token)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@api_router.post("/refresh", response_model=Access)
 | 
			
		||||
@api_router.post("/refresh", response_model=Tokens)
 | 
			
		||||
async def refresh(
 | 
			
		||||
    request: Request, connection: AsyncConnection = Depends(get_connection_dep), Authorize: AuthJWT = Depends()
 | 
			
		||||
):
 | 
			
		||||
    refresh_token = request.cookies.get("refresh_token_cookie")
 | 
			
		||||
 | 
			
		||||
    if not refresh_token:
 | 
			
		||||
        raise HTTPException(status_code=401, detail="Refresh token is missing")
 | 
			
		||||
 | 
			
		||||
    request: Request,
 | 
			
		||||
    connection: AsyncConnection = Depends(get_connection_dep),
 | 
			
		||||
    Authorize: AuthJWT = Depends(),
 | 
			
		||||
) -> Tokens:
 | 
			
		||||
    try:
 | 
			
		||||
        Authorize.jwt_refresh_token_required(refresh_token)
 | 
			
		||||
        current_user = Authorize._verified_token(refresh_token).get("sub")
 | 
			
		||||
        Authorize.jwt_refresh_token_required()
 | 
			
		||||
        current_user = Authorize.get_jwt_subject()
 | 
			
		||||
    except Exception:
 | 
			
		||||
        refresh_token = request.headers.get("Authorization").split(" ")[1]
 | 
			
		||||
        await upgrade_old_refresh_token(connection, refresh_token)
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_401_UNAUTHORIZED,
 | 
			
		||||
@@ -108,4 +103,4 @@ async def refresh(
 | 
			
		||||
    access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
			
		||||
    new_access_token = Authorize.create_access_token(subject=current_user, expires_time=access_token_expires)
 | 
			
		||||
 | 
			
		||||
    return Access(access_token=new_access_token)
 | 
			
		||||
    return Tokens(access_token=new_access_token)
 | 
			
		||||
 
 | 
			
		||||
@@ -8,9 +8,6 @@ class Auth(Base):
 | 
			
		||||
    password: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Refresh(Base):
 | 
			
		||||
    refresh_token: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Access(Base):
 | 
			
		||||
class Tokens(Base):
 | 
			
		||||
    access_token: str
 | 
			
		||||
    refresh_token: str | None = None
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user