test
This commit is contained in:
@@ -73,6 +73,7 @@ if __name__ == "__main__":
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
app.add_middleware(MiddlewareAccessTokenValidadtion)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
@@ -80,5 +81,3 @@ app.add_middleware(
|
||||
allow_methods=["GET", "POST", "OPTIONS", "DELETE", "PUT"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.add_middleware(MiddlewareAccessTokenValidadtion)
|
||||
|
@@ -18,7 +18,8 @@ class DbCredentialsSchema(BaseModel):
|
||||
class DefaultSettings(BaseSettings):
|
||||
ENV: str = environ.get("ENV", "local")
|
||||
PATH_PREFIX: str = environ.get("PATH_PREFIX", "/api/v1")
|
||||
APP_HOST: str = environ.get("APP_HOST", "http://127.0.0.1")
|
||||
# APP_HOST: str = environ.get("APP_HOST", "http://127.0.0.1")
|
||||
APP_HOST: str = environ.get("APP_HOST", "http://localhost")
|
||||
APP_PORT: int = int(environ.get("APP_PORT", 8000))
|
||||
APP_ID: uuid.UUID = environ.get("APP_ID", uuid.uuid4())
|
||||
LOGS_STORAGE_PATH: str = environ.get("LOGS_STORAGE_PATH", "storage/logs")
|
||||
|
@@ -50,13 +50,12 @@ async def get_user(connection: AsyncConnection, login: str) -> Optional[User]:
|
||||
return user, password
|
||||
|
||||
|
||||
async def upgrade_old_refresh_token(connection: AsyncConnection, user, refresh_token) -> Optional[User]:
|
||||
async def upgrade_old_refresh_token(connection: AsyncConnection, refresh_token) -> Optional[User]:
|
||||
new_status = KeyStatus.EXPIRED
|
||||
|
||||
update_query = (
|
||||
update(account_keyring_table)
|
||||
.where(
|
||||
account_table.c.id == user.id,
|
||||
account_keyring_table.c.status == KeyStatus.ACTIVE,
|
||||
account_keyring_table.c.key_type == KeyType.REFRESH_TOKEN,
|
||||
account_keyring_table.c.key_value == refresh_token,
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import jwt
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
@@ -8,7 +9,6 @@ from fastapi import (
|
||||
Response,
|
||||
status,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
from fastapi_jwt_auth import AuthJWT
|
||||
|
||||
@@ -30,11 +30,21 @@ api_router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def get_login_from_jwt(token: str):
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
get_settings().SECRET_KEY,
|
||||
algorithms=[get_settings().ALGORITHM],
|
||||
)
|
||||
return payload.get("sub")
|
||||
|
||||
|
||||
class Settings(BaseModel):
|
||||
authjwt_secret_key: str = get_settings().SECRET_KEY
|
||||
# Configure application to store and get JWT from cookies
|
||||
authjwt_token_location: set = {"headers", "cookies"}
|
||||
authjwt_cookie_domain: str = get_settings().DOMAIN
|
||||
authjwt_refresh_cookie_name: str = "refresh_token_cookie"
|
||||
|
||||
# Only allow JWT cookies to be sent over https
|
||||
authjwt_cookie_secure: bool = get_settings().ENV == "prod"
|
||||
@@ -68,7 +78,8 @@ async def login_for_access_token(
|
||||
# headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
# access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token_expires = timedelta(seconds=5)
|
||||
|
||||
refresh_token_expires = timedelta(days=get_settings().REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
@@ -88,28 +99,19 @@ async def login_for_access_token(
|
||||
|
||||
@api_router.post("/refresh", response_model=Access)
|
||||
async def refresh(
|
||||
request: Request, connection: AsyncConnection = Depends(get_connection_dep), Authorize: AuthJWT = Depends()
|
||||
request: Request,
|
||||
connection: AsyncConnection = Depends(get_connection_dep),
|
||||
Authorize: AuthJWT = Depends(),
|
||||
):
|
||||
refresh_token = request.cookies.get("refresh_token_cookie")
|
||||
# print("Refresh Token:", refresh_token)
|
||||
|
||||
if not refresh_token:
|
||||
raise HTTPException(status_code=401, detail="Refresh token is missing")
|
||||
|
||||
try:
|
||||
Authorize.jwt_refresh_token_required()
|
||||
current_user = Authorize.get_jwt_subject()
|
||||
|
||||
except Exception as e:
|
||||
await upgrade_old_refresh_token(connection, current_user, refresh_token)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid refresh token",
|
||||
)
|
||||
|
||||
access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
Authorize.jwt_refresh_token_required(refresh_token)
|
||||
current_user = Authorize.get_jwt_subject()
|
||||
# try:
|
||||
# access_token_expires = timedelta(minutes=get_settings().ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token_expires = timedelta(seconds=5)
|
||||
new_access_token = Authorize.create_access_token(subject=current_user, expires_time=access_token_expires)
|
||||
|
||||
return Access(access_token=new_access_token)
|
||||
|
@@ -22,40 +22,38 @@ class MiddlewareAccessTokenValidadtion(BaseHTTPMiddleware):
|
||||
self.excluded_routes = [
|
||||
re.compile(r"^" + re.escape(self.prefix) + r"/auth/refresh/?$"),
|
||||
re.compile(r"^" + re.escape(self.prefix) + r"/auth/?$"),
|
||||
re.compile(r"^" + r"/swagger"),
|
||||
re.compile(r"^" + r"/openapi"),
|
||||
]
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if request.method in ["GET", "POST", "PUT", "DELETE"]:
|
||||
if any(pattern.match(request.url.path) for pattern in self.excluded_routes):
|
||||
return await call_next(request)
|
||||
else:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Missing authorization header."},
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
if request.method not in ["GET", "POST", "PUT", "DELETE"]:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_405_METHOD_NOT_ALLOWED,
|
||||
content={"detail": "Method not allowed"},
|
||||
)
|
||||
|
||||
token = auth_header.split(" ")[1]
|
||||
Authorize = AuthJWT(request)
|
||||
if any(pattern.match(request.url.path) for pattern in self.excluded_routes):
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
current_user = Authorize.get_jwt_subject()
|
||||
request.state.current_user = current_user
|
||||
return await call_next(request)
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Missing authorization header."},
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "The access token is invalid or expired."},
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
token = auth_header.split(" ")[1]
|
||||
Authorize = AuthJWT(request)
|
||||
current_user = Authorize.get_jwt_subject()
|
||||
request.state.current_user = current_user
|
||||
except Exception:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "The access token is invalid or expired."},
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
# async with get_connection() as connection:
|
||||
# authorize_user = await get_user_login(connection, current_user)
|
||||
# print(authorize_user)
|
||||
# if authorize_user is None :
|
||||
# return JSONResponse(
|
||||
# status_code=status.HTTP_404_NOT_FOUND ,
|
||||
# detail="User not found.")
|
||||
return await call_next(request)
|
||||
|
Reference in New Issue
Block a user