from datetime import datetime
from typing import Any, Optional, Sequence, Self
from uuid import UUID

from sqlalchemy import Select, TypeDecorator, event, select
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession, AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
import sqlalchemy as sa
from sqlalchemy.sql.base import ExecutableOption
from sqlalchemy.sql._typing import _ColumnExpressionArgument

from settings import settings

# Module-level async engine built from the configured database URL.
engine = create_async_engine(
    settings.SQLALCHEMY_URL
)

# Factory for AsyncSession objects bound to ``engine``. Sessions do not
# autoflush, and loaded attributes are not expired on commit, so objects stay
# readable after ``commit()`` without another round-trip.
AsyncSessionLocal = async_sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
    expire_on_commit=False,
)


class Base(AsyncAttrs, DeclarativeBase):
    """Declarative base for ORM models, providing shared columns and async query helpers.

    Every mapped subclass inherits an auto-incrementing integer primary key
    ``id`` and the ``created_at`` / ``modified_at`` timestamp columns.
    """

    id: Mapped[int] = mapped_column(sa.Integer, primary_key=True, autoincrement=True, nullable=False)
    # Pass the callable (not ``datetime.now()``) so the timestamp is taken at
    # INSERT time; calling it here would freeze the module-import time into
    # every row ever inserted.
    # NOTE(review): these are naive local-time values (sa.DateTime without a
    # timezone) — confirm that is intended rather than UTC.
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)
    modified_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, default=datetime.now)

    @classmethod
    def _build_select(
        cls,
        options: Optional[list[ExecutableOption]],
        filters: Optional[list[_ColumnExpressionArgument[bool]]],
        order_by: Optional[list[_ColumnExpressionArgument[Any]]],
        limit: Optional[int],
        filter_by: dict[str, Any],
    ) -> Select[Any]:
        """Build the ``SELECT cls`` statement shared by ``find_all`` and ``find_one``."""
        stmt = select(cls)
        if options:
            stmt = stmt.options(*options)
        if filters:
            stmt = stmt.filter(*filters)
        if order_by:
            stmt = stmt.order_by(*order_by)
        # ``is not None`` so that an explicit ``limit=0`` is honoured instead
        # of silently meaning "no limit".
        if limit is not None:
            stmt = stmt.limit(limit)
        return stmt.filter_by(**filter_by)

    @classmethod
    async def find_all(
        cls,
        session: AsyncSession,
        options: Optional[list[ExecutableOption]] = None,
        filter: Optional[list[_ColumnExpressionArgument[bool]]] = None,
        order_by: Optional[list[_ColumnExpressionArgument[Any]]] = None,
        limit: Optional[int] = None,
        **kwargs: Any,
    ) -> Sequence[Self]:
        """Return all rows of this model matching the given criteria.

        Args:
            session: Async session used to execute the query.
            options: Loader options (e.g. eager-loading) applied via ``.options()``.
            filter: SQL boolean expressions, combined with AND via ``.filter()``.
            order_by: Expressions passed to ``.order_by()``.
            limit: Maximum number of rows; ``None`` means unlimited.
            **kwargs: Equality criteria on mapped attributes, applied via ``.filter_by()``.

        Returns:
            The matching model instances, de-duplicated by identity.
        """
        stmt = cls._build_select(options, filter, order_by, limit, kwargs)
        # ``unique()`` de-duplicates entities; SQLAlchemy requires it when the
        # options include joined eager loading of collections.
        return (await session.execute(stmt)).unique().scalars().all()

    @classmethod
    async def find_one(
        cls,
        session: AsyncSession,
        options: Optional[list[ExecutableOption]] = None,
        filter: Optional[list[_ColumnExpressionArgument[bool]]] = None,
        order_by: Optional[list[_ColumnExpressionArgument[Any]]] = None,
        limit: Optional[int] = None,
        **kwargs: Any,
    ) -> Self:
        """Return exactly one row of this model matching the given criteria.

        Accepts the same arguments as ``find_all``. ``limit`` is now applied
        (it was previously ignored); pass ``limit=1`` to take the first match
        instead of failing when several rows match.

        Raises:
            sqlalchemy.exc.NoResultFound: no row matched.
            sqlalchemy.exc.MultipleResultsFound: more than one row matched.
        """
        stmt = cls._build_select(options, filter, order_by, limit, kwargs)
        return (await session.execute(stmt)).unique().scalar_one()


@event.listens_for(Base, "before_update", propagate=True)
def update_modified_at(mapper, connection, target):
    """Set ``modified_at`` to the current time just before an instance is UPDATEd.

    Registered on ``Base`` with ``propagate=True`` so that it fires for every
    mapped subclass.
    """
    target.modified_at = datetime.now()