54 lines
1.9 KiB
Python
54 lines
1.9 KiB
Python
import re
|
|
from re import escape
|
|
|
|
from fastapi import (
|
|
Request,
|
|
status,
|
|
)
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi_jwt_auth import AuthJWT
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from api.config import get_settings
|
|
|
|
|
|
class MiddlewareAccessTokenValidadtion(BaseHTTPMiddleware):
    """Middleware that gates requests behind a JWT access token.

    For every request it:

    1. answers 405 unless the HTTP method is GET, POST, PUT or DELETE;
    2. forwards the request untouched when the URL path matches one of
       ``excluded_routes`` (auth, auth refresh, swagger, openapi);
    3. answers 401 when the ``Authorization`` header is missing or the
       token cannot be processed by ``AuthJWT``;
    4. otherwise stores the JWT subject on ``request.state.current_user``
       and forwards the request.

    Attributes:
        prefix: ``PATH_PREFIX`` from ``get_settings()``, regex-escaped.
        excluded_routes: compiled patterns for paths that skip the token
            check.
    """

    # Methods that are let through; any other method is answered with 405.
    # NOTE(review): this also rejects OPTIONS/PATCH/HEAD -- confirm that is
    # intended (e.g. for CORS preflight requests).
    _ALLOWED_METHODS = frozenset({"GET", "POST", "PUT", "DELETE"})

    def __init__(self, app):
        """Wrap *app* and compile the patterns of the unauthenticated routes.

        Args:
            app: the ASGI application handed to ``BaseHTTPMiddleware``.
        """
        super().__init__(app)

        # Escape exactly once. Escaping an already-escaped prefix would make
        # the patterns require literal backslashes in the URL path whenever
        # the prefix contains a regex-special character, so the auth routes
        # would never be excluded.
        self.prefix = escape(get_settings().PATH_PREFIX)
        self.excluded_routes = [
            re.compile(r"^" + self.prefix + r"/auth/refresh/?$"),
            re.compile(r"^" + self.prefix + r"/auth/?$"),
            # Prefix matches only (no "$"): anything starting with these
            # paths is excluded.
            re.compile(r"^/swagger"),
            re.compile(r"^/openapi"),
        ]

    @staticmethod
    def _unauthorized(detail):
        """Build the 401 JSON response carrying *detail* as the message."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"detail": detail},
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def dispatch(self, request: Request, call_next):
        """Validate the access token of *request* before forwarding it.

        Args:
            request: the incoming request.
            call_next: awaitable that passes the request down the stack.

        Returns:
            The downstream response, or a ``JSONResponse`` with status 405
            (method not allowed) or 401 (missing/invalid token).
        """
        if request.method not in self._ALLOWED_METHODS:
            return JSONResponse(
                status_code=status.HTTP_405_METHOD_NOT_ALLOWED,
                content={"detail": "Method not allowed"},
            )

        path = request.url.path
        if any(pattern.match(path) for pattern in self.excluded_routes):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        if not auth_header:
            return self._unauthorized("Missing authorization header.")

        # A header without a space cannot be "<scheme> <token>"; reject it
        # up front instead of relying on an IndexError from split()[1].
        if " " not in auth_header:
            return self._unauthorized("The access token is invalid or expired.")

        try:
            # AuthJWT is given the request itself; the token is not passed
            # in separately.
            current_user = AuthJWT(request).get_jwt_subject()
        except Exception:
            # Boundary handler: any failure while processing the token is
            # reported to the client as an invalid/expired token.
            return self._unauthorized("The access token is invalid or expired.")

        # NOTE(review): if get_jwt_subject() can return None (no token picked
        # up by AuthJWT), the request is still forwarded -- confirm whether a
        # None subject should be rejected here.
        request.state.current_user = current_user
        return await call_next(request)
|