64 lines
1.9 KiB
Python
64 lines
1.9 KiB
Python
import contextlib
|
|
import json
|
|
import os
|
|
from typing import Any, AsyncGenerator
|
|
|
|
import asyncio
|
|
|
|
import sqlalchemy
|
|
from loguru import logger
|
|
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine
|
|
from sqlalchemy import URL,create_engine, text
|
|
|
|
|
|
from api.config import get_settings
|
|
from api.config.default import DbCredentialsSchema
|
|
|
|
|
|
|
|
class SessionManager:
    """Process-wide singleton that caches one async SQLAlchemy engine per database URI.

    Every ``SessionManager(...)`` call returns the same instance. Constructing it
    with a URI (or with no argument, meaning ``get_settings().database_uri``)
    ensures an engine for that URI exists in ``engines``.
    """

    # Maps a database URI to the AsyncEngine created for it by ``refresh``.
    engines: Any

    def __new__(cls, database_uri=None):
        """Return the shared instance, creating it (with an empty engine cache) on first use.

        ``database_uri`` is accepted only so the constructor signature matches
        ``__init__``; it is not used here.
        """
        if not hasattr(cls, "instance"):
            cls.instance = super().__new__(cls)
            cls.instance.engines = {}
        # Python still runs __init__ on the returned instance after every call.
        return cls.instance

    def __init__(self, database_uri=None) -> None:
        """Record ``database_uri`` as the current URI and make sure its engine exists.

        Args:
            database_uri: URI to use. Defaults to ``get_settings().database_uri``.
                The default is resolved on every call, not once at import time,
                so settings loaded after this module is imported are honoured.
        """
        if database_uri is None:
            database_uri = get_settings().database_uri
        self.database_uri = database_uri
        self.refresh(database_uri)

    def refresh(self, database_uri) -> None:
        """Create and cache an async engine for ``database_uri`` if none exists yet.

        An existing engine for the same URI is kept as-is; it is never replaced.
        """
        if database_uri in self.engines:
            return
        settings = get_settings()
        self.engines[database_uri] = create_async_engine(
            database_uri,
            # NOTE(review): echo=True logs every SQL statement the engine emits;
            # consider driving this from settings instead of hard-coding it.
            echo=True,
            future=True,
            # Recycle pooled connections older than 1800 seconds (30 minutes).
            pool_recycle=1800,
            pool_size=settings.CONNECTION_POOL_SIZE,
            max_overflow=settings.CONNECTION_OVERFLOW,
        )

    def get_engine_by_db_uri(self, database_uri) -> AsyncEngine:
        """Return the cached engine for ``database_uri``, creating it on first use.

        Previously a URI that had not gone through ``refresh`` raised ``KeyError``.
        Now the engine is created lazily instead.
        """
        self.refresh(database_uri)
        return self.engines[database_uri]
|
|
|
|
@contextlib.asynccontextmanager
async def get_connection(
    database_uri=None,
) -> AsyncGenerator[AsyncConnection, None]:
    """Yield an ``AsyncConnection`` from the engine cached for ``database_uri``.

    Args:
        database_uri: Target database URI. When falsy (``None`` or ``""``), falls
            back to ``get_settings().database_uri``.

    Yields:
        An open ``AsyncConnection``. It is closed when the ``async with`` block
        wrapping this context manager exits.
    """
    if not database_uri:
        database_uri = get_settings().database_uri
    # Build the singleton once. Constructing it also ensures the engine exists,
    # so there is no second construction just for logging.
    manager = SessionManager(database_uri)
    engine = manager.get_engine_by_db_uri(database_uri)
    # Positional "{}" arguments let loguru skip formatting the engine dict
    # when DEBUG messages are filtered out.
    logger.debug("engine {} {}", engine, manager.engines)
    async with engine.connect() as conn:
        yield conn
|
|
|
|
async def get_connection_dep() -> AsyncGenerator[AsyncConnection, None]:
    """Async-generator wrapper around ``get_connection()`` for the default database URI.

    Yields one ``AsyncConnection``. The connection stays open while the consumer
    uses it and is closed when the generator is finalized.
    NOTE(review): presumably registered as a FastAPI ``Depends`` dependency; confirm.
    The return annotation was corrected from ``AsyncConnection`` because this
    function is an async generator, not a coroutine returning a connection.
    """
    async with get_connection() as conn:
        yield conn
|