Test/database/Base.py
2025-04-13 22:48:56 +02:00

84 lines
No EOL
2.6 KiB
Python

from datetime import datetime
from typing import Optional, Sequence, Self
from uuid import UUID
from sqlalchemy import TypeDecorator, event, select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession, AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
import sqlalchemy as sa
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql._typing import _ColumnExpressionArgument
from settings import settings
# Process-wide async engine; the connection URL comes from the project settings.
engine = create_async_engine(settings.SQLALCHEMY_URL)

# Factory for AsyncSession objects bound to the engine above.
# Flushing is left to explicit calls (autoflush off), and instances are not
# expired on commit, so already-loaded attributes stay readable afterwards
# without triggering a new (awaitable) load.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
class Base(AsyncAttrs, DeclarativeBase):
    """Declarative base shared by all models.

    Provides an auto-incrementing integer primary key, ``created_at`` /
    ``modified_at`` timestamps, and two async query helpers
    (:meth:`find_all`, :meth:`find_one`).
    """

    id: Mapped[int] = mapped_column(
        sa.Integer, primary_key=True, autoincrement=True, nullable=False
    )
    # Pass the callable, not its result: ``datetime.now()`` would be evaluated
    # once at import time and every inserted row would share that one value.
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, default=datetime.now
    )
    modified_at: Mapped[datetime] = mapped_column(
        sa.DateTime, nullable=False, default=datetime.now
    )

    @classmethod
    def _build_select(cls, options, criteria, order_by, limit, equals) -> sa.Select:
        """Assemble the SELECT statement shared by the ``find_*`` helpers."""
        stmt = select(cls)
        if options:
            stmt = stmt.options(*options)
        if criteria:
            stmt = stmt.filter(*criteria)
        stmt = stmt.filter_by(**equals)
        if order_by:
            stmt = stmt.order_by(*order_by)
        # Compare against None so that an explicit ``limit=0`` is honoured
        # instead of being treated as "no limit".
        if limit is not None:
            stmt = stmt.limit(limit)
        return stmt

    @classmethod
    async def find_all(
        cls,
        session: AsyncSession,
        options: Optional[list[ExecutableOption]] = None,
        filter: Optional[list[_ColumnExpressionArgument[bool]]] = None,
        order_by: Optional[list[_ColumnExpressionArgument[Self]]] = None,
        limit: Optional[int] = None,
        **kwargs
    ) -> Sequence[Self]:
        """Return every row of this model matching the given criteria.

        Args:
            session: Async session used to execute the query.
            options: Loader/execution options forwarded to ``Select.options``.
            filter: SQL expressions forwarded to ``Select.filter``.
            order_by: Ordering expressions forwarded to ``Select.order_by``.
            limit: Maximum number of rows; ``None`` means no limit.
            **kwargs: Attribute equality filters forwarded to
                ``Select.filter_by``.

        Returns:
            The matching instances, de-duplicated via ``Result.unique()``.
        """
        stmt = cls._build_select(options, filter, order_by, limit, kwargs)
        result = await session.execute(stmt)
        return result.unique().scalars().all()

    @classmethod
    async def find_one(
        cls,
        session: AsyncSession,
        options: Optional[list[ExecutableOption]] = None,
        filter: Optional[list[_ColumnExpressionArgument[bool]]] = None,
        order_by: Optional[list[_ColumnExpressionArgument[Self]]] = None,
        limit: Optional[int] = None,
        **kwargs
    ) -> Self:
        """Return exactly one row of this model matching the given criteria.

        Takes the same arguments as :meth:`find_all`. ``limit`` used to be
        accepted but ignored; it is now applied when given, so
        ``order_by=[...], limit=1`` selects the first match instead of failing
        on multiple rows.

        Raises:
            sqlalchemy.exc.NoResultFound: If no row matches.
            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
        """
        stmt = cls._build_select(options, filter, order_by, limit, kwargs)
        result = await session.execute(stmt)
        return result.unique().scalar_one()
@event.listens_for(Base, "before_update", propagate=True)
def update_modified_at(mapper, connection, target):
    """Stamp ``target.modified_at`` with the current local time.

    Registered for the ``before_update`` mapper event on ``Base`` with
    ``propagate=True``, so it also applies to subclasses. ``mapper`` and
    ``connection`` are part of the event signature and are not used here.
    """
    stamp = datetime.now()
    target.modified_at = stamp