refactor: middleware and refresh
This commit is contained in:
		@@ -73,6 +73,7 @@ if __name__ == "__main__":
 | 
			
		||||
            log_level="info",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
app.add_middleware(MiddlewareAccessTokenValidadtion)
 | 
			
		||||
app.add_middleware(
 | 
			
		||||
    CORSMiddleware,
 | 
			
		||||
    allow_origins=origins,
 | 
			
		||||
@@ -80,5 +81,3 @@ app.add_middleware(
 | 
			
		||||
    allow_methods=["GET", "POST", "OPTIONS", "DELETE", "PUT"],
 | 
			
		||||
    allow_headers=["*"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
app.add_middleware(MiddlewareAccessTokenValidadtion)
 | 
			
		||||
 
 | 
			
		||||
@@ -50,13 +50,12 @@ async def get_user(connection: AsyncConnection, login: str) -> Optional[User]:
 | 
			
		||||
    return user, password
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def upgrade_old_refresh_token(connection: AsyncConnection, user, refresh_token) -> Optional[User]:
 | 
			
		||||
async def upgrade_old_refresh_token(connection: AsyncConnection, refresh_token) -> Optional[User]:
 | 
			
		||||
    new_status = KeyStatus.EXPIRED
 | 
			
		||||
 | 
			
		||||
    update_query = (
 | 
			
		||||
        update(account_keyring_table)
 | 
			
		||||
        .where(
 | 
			
		||||
            account_table.c.id == user.id,
 | 
			
		||||
            account_keyring_table.c.status == KeyStatus.ACTIVE,
 | 
			
		||||
            account_keyring_table.c.key_type == KeyType.REFRESH_TOKEN,
 | 
			
		||||
            account_keyring_table.c.key_value == refresh_token,
 | 
			
		||||
 
 | 
			
		||||
@@ -91,25 +91,21 @@ async def refresh(
 | 
			
		||||
    request: Request, connection: AsyncConnection = Depends(get_connection_dep), Authorize: AuthJWT = Depends()
 | 
			
		||||
):
 | 
			
		||||
    refresh_token = request.cookies.get("refresh_token_cookie")
 | 
			
		||||
    # print("Refresh Token:", refresh_token)
 | 
			
		||||
 | 
			
		||||
    if not refresh_token:
 | 
			
		||||
        raise HTTPException(status_code=401, detail="Refresh token is missing")
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        Authorize.jwt_refresh_token_required()
 | 
			
		||||
        current_user = Authorize.get_jwt_subject()
 | 
			
		||||
 | 
			
		||||
    except Exception as e:
 | 
			
		||||
        await upgrade_old_refresh_token(connection, current_user, refresh_token)
 | 
			
		||||
 | 
			
		||||
        Authorize.jwt_refresh_token_required(refresh_token)
 | 
			
		||||
        current_user = Authorize._verified_token(refresh_token).get("sub")
 | 
			
		||||
    except Exception:
 | 
			
		||||
        await upgrade_old_refresh_token(connection, refresh_token)
 | 
			
		||||
        raise HTTPException(
 | 
			
		||||
            status_code=status.HTTP_401_UNAUTHORIZED,
 | 
			
		||||
            detail="Invalid refresh token",
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
 | 
			
		||||
 | 
			
		||||
    new_access_token = Authorize.create_access_token(subject=current_user, expires_time=access_token_expires)
 | 
			
		||||
 | 
			
		||||
    return Access(access_token=new_access_token)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,3 +1,4 @@
 | 
			
		||||
from fastapi_jwt_auth import AuthJWT
 | 
			
		||||
from starlette.middleware.base import BaseHTTPMiddleware
 | 
			
		||||
from fastapi import (
 | 
			
		||||
    Request,
 | 
			
		||||
@@ -11,9 +12,6 @@ import re
 | 
			
		||||
from re import escape
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from fastapi_jwt_auth import AuthJWT
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MiddlewareAccessTokenValidadtion(BaseHTTPMiddleware):
 | 
			
		||||
    def __init__(self, app):
 | 
			
		||||
        super().__init__(app)
 | 
			
		||||
@@ -22,40 +20,34 @@ class MiddlewareAccessTokenValidadtion(BaseHTTPMiddleware):
 | 
			
		||||
        self.excluded_routes = [
 | 
			
		||||
            re.compile(r"^" + re.escape(self.prefix) + r"/auth/refresh/?$"),
 | 
			
		||||
            re.compile(r"^" + re.escape(self.prefix) + r"/auth/?$"),
 | 
			
		||||
            re.compile(r"^" + r"/swagger"),
 | 
			
		||||
            re.compile(r"^" + r"/openapi"),
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
    async def dispatch(self, request: Request, call_next):
 | 
			
		||||
        if request.method in ["GET", "POST", "PUT", "DELETE"]:
 | 
			
		||||
            if any(pattern.match(request.url.path) for pattern in self.excluded_routes):
 | 
			
		||||
                return await call_next(request)
 | 
			
		||||
            else:
 | 
			
		||||
                auth_header = request.headers.get("Authorization")
 | 
			
		||||
                if not auth_header:
 | 
			
		||||
                    return JSONResponse(
 | 
			
		||||
                        status_code=status.HTTP_401_UNAUTHORIZED,
 | 
			
		||||
                        content={"detail": "Missing authorization header."},
 | 
			
		||||
                        headers={"WWW-Authenticate": "Bearer"},
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                token = auth_header.split(" ")[1]
 | 
			
		||||
                Authorize = AuthJWT(request)
 | 
			
		||||
 | 
			
		||||
                try:
 | 
			
		||||
                    current_user = Authorize.get_jwt_subject()
 | 
			
		||||
                    request.state.current_user = current_user
 | 
			
		||||
                    return await call_next(request)
 | 
			
		||||
 | 
			
		||||
                except Exception:
 | 
			
		||||
                    return JSONResponse(
 | 
			
		||||
                        status_code=status.HTTP_401_UNAUTHORIZED,
 | 
			
		||||
                        content={"detail": "The access token is invalid or expired."},
 | 
			
		||||
                        headers={"WWW-Authenticate": "Bearer"},
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                # async with get_connection() as connection:
 | 
			
		||||
                #     authorize_user = await get_user_login(connection, current_user)
 | 
			
		||||
                # print(authorize_user)
 | 
			
		||||
                # if authorize_user is None :
 | 
			
		||||
                #     return JSONResponse(
 | 
			
		||||
                #         status_code=status.HTTP_404_NOT_FOUND ,
 | 
			
		||||
                #         detail="User not found.")
 | 
			
		||||
        if request.method not in ["GET", "POST", "PUT", "DELETE"]:
 | 
			
		||||
            return JSONResponse(
 | 
			
		||||
                status_code=status.HTTP_405_METHOD_NOT_ALLOWED,
 | 
			
		||||
                content={"detail": "Method not allowed"},
 | 
			
		||||
            )
 | 
			
		||||
        if any(pattern.match(request.url.path) for pattern in self.excluded_routes):
 | 
			
		||||
            return await call_next(request)
 | 
			
		||||
        auth_header = request.headers.get("Authorization")
 | 
			
		||||
        if not auth_header:
 | 
			
		||||
            return JSONResponse(
 | 
			
		||||
                status_code=status.HTTP_401_UNAUTHORIZED,
 | 
			
		||||
                content={"detail": "Missing authorization header."},
 | 
			
		||||
                headers={"WWW-Authenticate": "Bearer"},
 | 
			
		||||
            )
 | 
			
		||||
        try:
 | 
			
		||||
            token = auth_header.split(" ")[1]
 | 
			
		||||
            Authorize = AuthJWT(request)
 | 
			
		||||
            current_user = Authorize.get_jwt_subject()
 | 
			
		||||
            request.state.current_user = current_user
 | 
			
		||||
        except Exception:
 | 
			
		||||
            return JSONResponse(
 | 
			
		||||
                status_code=status.HTTP_401_UNAUTHORIZED,
 | 
			
		||||
                content={"detail": "The access token is invalid or expired."},
 | 
			
		||||
                headers={"WWW-Authenticate": "Bearer"},
 | 
			
		||||
            )
 | 
			
		||||
        return await call_next(request)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user