diff --git a/api/api/__main__.py b/api/api/__main__.py index 6e57ec2..f00b342 100644 --- a/api/api/__main__.py +++ b/api/api/__main__.py @@ -73,6 +73,7 @@ if __name__ == "__main__": log_level="info", ) +app.add_middleware(MiddlewareAccessTokenValidadtion) app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -80,5 +81,3 @@ app.add_middleware( allow_methods=["GET", "POST", "OPTIONS", "DELETE", "PUT"], allow_headers=["*"], ) - -app.add_middleware(MiddlewareAccessTokenValidadtion) diff --git a/api/api/db/logic/auth.py b/api/api/db/logic/auth.py index 393598d..20e9573 100644 --- a/api/api/db/logic/auth.py +++ b/api/api/db/logic/auth.py @@ -50,13 +50,12 @@ async def get_user(connection: AsyncConnection, login: str) -> Optional[User]: return user, password -async def upgrade_old_refresh_token(connection: AsyncConnection, user, refresh_token) -> Optional[User]: +async def upgrade_old_refresh_token(connection: AsyncConnection, refresh_token) -> Optional[User]: new_status = KeyStatus.EXPIRED update_query = ( update(account_keyring_table) .where( - account_table.c.id == user.id, account_keyring_table.c.status == KeyStatus.ACTIVE, account_keyring_table.c.key_type == KeyType.REFRESH_TOKEN, account_keyring_table.c.key_value == refresh_token, diff --git a/api/api/endpoints/auth.py b/api/api/endpoints/auth.py index bd0cfe8..edad474 100644 --- a/api/api/endpoints/auth.py +++ b/api/api/endpoints/auth.py @@ -91,25 +91,21 @@ async def refresh( request: Request, connection: AsyncConnection = Depends(get_connection_dep), Authorize: AuthJWT = Depends() ): refresh_token = request.cookies.get("refresh_token_cookie") - # print("Refresh Token:", refresh_token) if not refresh_token: raise HTTPException(status_code=401, detail="Refresh token is missing") try: - Authorize.jwt_refresh_token_required() - current_user = Authorize.get_jwt_subject() - - except Exception as e: - await upgrade_old_refresh_token(connection, current_user, refresh_token) - + Authorize.jwt_refresh_token_required(refresh_token) + current_user = Authorize._verified_token(refresh_token).get("sub") + except Exception: + await upgrade_old_refresh_token(connection, refresh_token) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid refresh token", ) access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES) - new_access_token = Authorize.create_access_token(subject=current_user, expires_time=access_token_expires) return Access(access_token=new_access_token) diff --git a/api/api/services/middleware.py b/api/api/services/middleware.py index a44070d..68e511c 100644 --- a/api/api/services/middleware.py +++ b/api/api/services/middleware.py @@ -1,3 +1,4 @@ +from fastapi_jwt_auth import AuthJWT from starlette.middleware.base import BaseHTTPMiddleware from fastapi import ( Request, @@ -11,9 +12,6 @@ import re from re import escape -from fastapi_jwt_auth import AuthJWT - - class MiddlewareAccessTokenValidadtion(BaseHTTPMiddleware): def __init__(self, app): super().__init__(app) @@ -22,40 +20,34 @@ class MiddlewareAccessTokenValidadtion(BaseHTTPMiddleware): self.excluded_routes = [ re.compile(r"^" + re.escape(self.prefix) + r"/auth/refresh/?$"), re.compile(r"^" + re.escape(self.prefix) + r"/auth/?$"), + re.compile(r"^" + r"/swagger"), + re.compile(r"^" + r"/openapi"), ] async def dispatch(self, request: Request, call_next): - if request.method in ["GET", "POST", "PUT", "DELETE"]: - if any(pattern.match(request.url.path) for pattern in self.excluded_routes): - return await call_next(request) - else: - auth_header = request.headers.get("Authorization") - if not auth_header: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": "Missing authorization header."}, - headers={"WWW-Authenticate": "Bearer"}, - ) - - token = auth_header.split(" ")[1] - Authorize = AuthJWT(request) - - try: - current_user = Authorize.get_jwt_subject() - request.state.current_user = current_user - return await call_next(request) - - except Exception: - return JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": "The access token is invalid or expired."}, - headers={"WWW-Authenticate": "Bearer"}, - ) - - # async with get_connection() as connection: - # authorize_user = await get_user_login(connection, current_user) - # print(authorize_user) - # if authorize_user is None : - # return JSONResponse( - # status_code=status.HTTP_404_NOT_FOUND , - # detail="User not found.") + if request.method not in ["GET", "POST", "PUT", "DELETE"]: + return JSONResponse( + status_code=status.HTTP_405_METHOD_NOT_ALLOWED, + content={"detail": "Method not allowed"}, + ) + if any(pattern.match(request.url.path) for pattern in self.excluded_routes): + return await call_next(request) + auth_header = request.headers.get("Authorization") + if not auth_header: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Missing authorization header."}, + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + token = auth_header.split(" ")[1] + Authorize = AuthJWT(request) + current_user = Authorize.get_jwt_subject() + request.state.current_user = current_user + except Exception: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "The access token is invalid or expired."}, + headers={"WWW-Authenticate": "Bearer"}, + ) + return await call_next(request)