85 lines
2.6 KiB
Python
85 lines
2.6 KiB
Python
from datetime import datetime, timezone
from enum import Enum
from typing import Optional, Tuple

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncConnection

from api.db.tables.account import account_table, account_keyring_table, KeyType, KeyStatus
from api.schemas.account.account import User
from api.schemas.account.account_keyring import AccountKeyring
from api.utils.key_id_gen import KeyIdGenerator
|
|
|
|
|
|
def _row_values(row, table) -> dict:
    """Return ``{column name: value}`` for every column of *table* in *row*.

    Values are looked up by the ``Column`` object (``row._mapping[column]``)
    rather than by attribute name. When two tables are selected together,
    a column name they share would otherwise be ambiguous and could resolve
    to the other table's value. Enum members are replaced by their ``.name``.
    """
    values = {}
    for column in table.columns:
        value = row._mapping[column]
        values[column.name] = value.name if isinstance(value, Enum) else value
    return values


async def get_user(
    connection: AsyncConnection, login: str
) -> Tuple[Optional[User], Optional[AccountKeyring]]:
    """Fetch an account and its password keyring row by login.

    Joins ``account_table`` to ``account_keyring_table`` on
    ``account_table.c.id == account_keyring_table.c.owner_id``. Only keyring
    rows whose ``key_type`` is ``KeyType.PASSWORD`` are considered.

    Args:
        connection: Async connection used to run the query.
        login: Value matched against ``account_table.c.login``.

    Returns:
        A ``(User, AccountKeyring)`` pair built from the matched row, or
        ``(None, None)`` when no row matches. Enum-typed column values are
        handed to the models as their member names.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: If more than one row matches
            (``one_or_none()`` semantics), e.g. if an account has several
            PASSWORD keyring rows.
        Whatever ``User.model_validate`` / ``AccountKeyring.model_validate``
            raise on invalid data (presumably pydantic ``ValidationError``).
    """
    query = (
        select(account_table, account_keyring_table)
        .join(account_keyring_table, account_table.c.id == account_keyring_table.c.owner_id)
        .where(
            account_table.c.login == login,
            account_keyring_table.c.key_type == KeyType.PASSWORD,
        )
    )

    result = await connection.execute(query)
    row = result.one_or_none()
    if row is None:
        return None, None

    user = User.model_validate(_row_values(row, account_table))
    password = AccountKeyring.model_validate(_row_values(row, account_keyring_table))
    return user, password
|
|
|
|
|
|
async def upgrade_old_refresh_token(connection: AsyncConnection, user, refresh_token) -> None:
    """Mark one of *user*'s active refresh tokens as expired, then commit.

    Only keyring rows matching all of the following are updated (their
    ``status`` is set to ``KeyStatus.EXPIRED``):

    - owned by ``user.id``;
    - ``status`` is ``KeyStatus.ACTIVE``;
    - ``key_type`` is ``KeyType.REFRESH_TOKEN``;
    - ``key_value`` equals *refresh_token*.

    If nothing matches, the statement updates no rows and the transaction
    is still committed.

    Args:
        connection: Async connection; ``commit()`` is called on it.
        user: Object exposing ``.id`` (the owning account id).
        refresh_token: Token value compared against ``key_value``.
    """
    expire_stmt = (
        update(account_keyring_table)
        .where(
            # Scope by the keyring's own owner column. Referencing
            # account_table here would add a second, uncorrelated table to
            # the UPDATE (UPDATE ... FROM account), so the user condition
            # would not restrict which keyring rows get expired.
            account_keyring_table.c.owner_id == user.id,
            account_keyring_table.c.status == KeyStatus.ACTIVE,
            account_keyring_table.c.key_type == KeyType.REFRESH_TOKEN,
            account_keyring_table.c.key_value == refresh_token,
        )
        .values(status=KeyStatus.EXPIRED)
    )

    await connection.execute(expire_stmt)
    await connection.commit()
|
|
|
|
|
|
async def add_new_refresh_token(
    connection: AsyncConnection, new_refresh_token, new_refresh_token_expires_time, user
) -> None:
    """Insert *new_refresh_token* as an active refresh-token keyring row, then commit.

    Args:
        connection: Async connection; ``commit()`` is called on it.
        new_refresh_token: Token value stored in ``key_value``.
        new_refresh_token_expires_time: Stored in ``expiry``. Presumably a
            timezone-aware datetime like ``created_at`` -- TODO confirm.
        user: Object exposing ``.id``, stored as ``owner_id``.
    """
    # Named separately from the ``new_refresh_token`` parameter so the token
    # value and the INSERT statement are never confused.
    insert_stmt = account_keyring_table.insert().values(
        owner_id=user.id,
        key_type=KeyType.REFRESH_TOKEN,
        # NOTE(review): KeyIdGenerator is defined elsewhere; this assumes
        # calling it yields a value the key_id column accepts -- confirm.
        key_id=KeyIdGenerator(),
        key_value=new_refresh_token,
        created_at=datetime.now(timezone.utc),
        expiry=new_refresh_token_expires_time,
        status=KeyStatus.ACTIVE,
    )

    await connection.execute(insert_stmt)
    await connection.commit()
|