"""Connect to an SQL server, validate its schema and perform cached analysis queries."""
from collections.abc import Callable, Mapping
from contextvars import ContextVar
from dataclasses import InitVar, dataclass
from datetime import datetime
from functools import partial, wraps
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
ParamSpec,
Protocol,
TypeAlias,
TypeVar,
cast,
overload,
)
import pandas as pd
import sqlalchemy as sqla
import sqlalchemy.orm as orm
import sqlalchemy.sql.roles as sqla_roles
from pandas.api.types import (
is_datetime64_dtype,
is_integer_dtype,
is_numeric_dtype,
is_string_dtype,
)
from pandas.util import hash_pandas_object
from typing_extensions import Self
from py_research.reflect.runtime import get_all_subclasses
def _hash_df(df: pd.DataFrame | pd.Series) -> str:
return hex(abs(sum(hash_pandas_object(df))))[2:12]
# Parameter specification used to carry the signature of a wrapped callable
# through the decorators defined in this module.
Params = ParamSpec("Params")
def _get_sql_ctx() -> "DB":
    """Return the connection stored in the ``active_con`` context variable.

    Raises:
        RuntimeError: If no connection is currently stored there.
    """
    current = active_con.get()
    if current is not None:
        return current
    raise RuntimeError("`SQLContext` object not supplied via context.")
# Base class used as the upper bound for all table-schema type variables below.
SchemaBase = orm.DeclarativeBase

# Mapping of Python types to SQL column types.
# NOTE(review): not referenced in the visible part of this file; presumably
# intended as a declarative type annotation map - confirm with callers.
default_type_map = {
    str: sqla.types.String().with_variant(
        sqla.types.String(50), "oracle"
    ),  # Avoid oracle error when VARCHAR has no size parameter
}

# Generic value type (e.g. the Python type of a column).
V = TypeVar("V")
# Table-schema type variables in invariant, covariant and contravariant flavour.
S = TypeVar("S", bound=SchemaBase)
S_cov = TypeVar("S_cov", bound=SchemaBase, covariant=True)
S_contra = TypeVar("S_contra", bound=SchemaBase, contravariant=True)
# Contravariant counterpart of ``V``.
V_contra = TypeVar("V_contra", contravariant=True)
class ColRef(orm.InstrumentedAttribute[V], Generic[V, S_contra]):
    """Typed reference to a column, carrying its value type and its schema type."""  # type: ignore[override]
class Col(orm.MappedColumn[V]):
    """Column declaration for use inside a table schema class."""

    if TYPE_CHECKING:
        # Typing-only descriptor protocol: class-level access yields a column
        # reference, instance-level access yields the column's value.

        @overload
        def __get__(self, instance: None, owner: type[S]) -> ColRef[V, S]: ...

        @overload
        def __get__(self, instance: object, owner: type) -> V: ...

        def __get__(  # noqa: D105
            self, instance: object | None, owner: type[S]
        ) -> ColRef[V, S] | V: ...

    @classmethod
    def cast_sqla_attr(cls, attr: orm.MappedColumn) -> Self:
        """Present a plain sqlalchemy mapped column as an instance of this type.

        This is a static-typing cast only; the object is returned unchanged.
        """
        typed_attr = cast(Self, attr)
        return typed_attr
def _wrap_mapped_col(
    func: Callable[Params, orm.MappedColumn[V]]
) -> Callable[Params, Col[V]]:
    """Wrap a mapped-column factory so that its result is typed as ``Col``."""

    @wraps(func)
    def wrapper(*args: Params.args, **kwargs: Params.kwargs) -> Col[V]:
        mapped = func(*args, **kwargs)
        return Col.cast_sqla_attr(mapped)

    return wrapper
# ``orm.mapped_column`` with its return type narrowed to ``Col``.
col = _wrap_mapped_col(orm.mapped_column)
# Alias for sqlalchemy's foreign key construct.
Relation = sqla.ForeignKey
class ColumnClause(sqla_roles.ColumnsClauseRole, sqla_roles.ExpressionElementRole):
    """Combine sqlalchemy role classes to mitigate typing mishaps."""
class Data(Protocol[S_cov]):
    """Structural interface shared by everything that is shaped like a SQL table."""

    @property
    def columns(self) -> Mapping[str, sqla.ColumnElement]:
        """Mapping from column name to column expression."""
        ...

    def __clause_element__(self) -> ColumnClause:  # noqa: D105
        ...

    def __getitem__(  # noqa: D105
        self, ref: str | ColRef[V, S_cov]
    ) -> sqla.ColumnElement[V]: ...

    @property
    def name(self) -> str | None:
        """Name of the data object, or ``None`` if it has none."""
        ...
@dataclass
class Table(Data[S_cov]):
    """Handle on a table that physically exists in the database."""

    sqla_table: sqla.Table

    @property
    def columns(self) -> dict[str, sqla.Column]:  # noqa: D102
        table_columns = self.sqla_table.columns
        return dict(table_columns)

    def __getitem__(self, ref: str | ColRef[V, S_cov]) -> sqla.Column[V]:  # noqa: D105
        # Column references are resolved through their key.
        if isinstance(ref, str):
            key = ref
        else:
            key = ref.key
        return self.columns[key]

    def __clause_element__(self) -> ColumnClause:  # noqa: D105
        # Static-typing cast only; the sqlalchemy table itself is returned.
        return cast(ColumnClause, self.sqla_table)

    @property
    def name(self) -> str:  # noqa: D102
        return self.sqla_table.name
def _map_df_dtype(c: pd.Series) -> sqla.types.TypeEngine:
    """Choose a SQL column type for the given pandas series.

    Args:
        c: Series whose dtype (and, for strings, maximum length) is inspected.

    Returns:
        ``DATETIME``, ``INTEGER`` or ``FLOAT`` for the respective dtypes,
        ``CHAR`` / ``VARCHAR`` for strings shorter than 16 / 256 characters,
        and ``BLOB`` for everything else.
    """
    if is_datetime64_dtype(c):
        return sqla.types.DATETIME()
    if is_integer_dtype(c):
        return sqla.types.INTEGER()
    if is_numeric_dtype(c):
        return sqla.types.FLOAT()
    if is_string_dtype(c):
        max_len = c.str.len().max()
        # ``max_len`` is NaN for empty / all-missing series; those fall
        # through to BLOB below.
        if pd.notna(max_len):
            # ``str.len()`` yields floats as soon as the series contains
            # missing values, so convert to a plain int before using it as a
            # column length. Clamp to 1 because a zero-length CHAR is invalid.
            length = max(int(max_len), 1)
            if length < 16:
                return sqla.types.CHAR(length)
            if length < 256:
                return sqla.types.VARCHAR(length)
    return sqla.types.BLOB()
def _cols_from_df(df: pd.DataFrame) -> dict[str, sqla.Column]:
    """Derive SQL column definitions from a dataframe's index and columns.

    Args:
        df: Dataframe with a single-level index.

    Returns:
        Mapping from column name to column definition; the index comes first
        and is marked as primary key.

    Raises:
        NotImplementedError: If the dataframe has a multi-level index.
    """
    if df.index.nlevels > 1:
        raise NotImplementedError("Multi-index not supported yet.")

    index_name = df.index.name
    if index_name is None:
        # Use the name ``DataFrame.reset_index`` assigns to an unnamed index,
        # so that the returned keys can be selected from ``df.reset_index()``.
        index_name = "index" if "index" not in df.columns else "level_0"

    index_col = sqla.Column(
        index_name,
        _map_df_dtype(df.index.to_series()),
        primary_key=True,
    )
    data_cols = {
        str(name): sqla.Column(str(series.name), _map_df_dtype(series))
        for name, series in df.items()
    }
    return {index_name: index_col, **data_cols}
@dataclass
class Query(Data[S_cov]):
    """SQL data defined by a query."""

    sel: sqla.Select
    """SQLAlchemy select statement."""

    query_name: InitVar[str]
    """Name of this query."""

    schema: type[S_cov] | None = None
    """Schema of this query's result."""

    _subquery: sqla.Subquery | None = None

    def __post_init__(self, query_name: str):  # noqa: D105
        self.__name = query_name

    @property
    def name(self) -> str:
        """Name of this query."""
        return self.__name

    @property
    def subquery(self) -> sqla.Subquery:
        """Return the subquery wrapping ``sel``, created once and then cached."""
        if self._subquery is None:
            self._subquery = self.sel.subquery()
        return self._subquery

    @property
    def columns(self) -> dict[str, sqla.ColumnElement]:  # noqa: D102
        # Take the columns from the same cached subquery that
        # ``__clause_element__`` returns, so that column references and the
        # FROM element always belong together. The legacy ``Select.columns``
        # accessor (deprecated since SQLAlchemy 1.4) is avoided here.
        return dict(self.subquery.columns)

    def __clause_element__(self) -> ColumnClause:  # noqa: D105
        return cast(ColumnClause, self.subquery)

    def __getitem__(  # noqa: D105
        self, ref: str | ColRef[V, S_cov]
    ) -> sqla.ColumnElement[V]:
        return self.columns[ref if isinstance(ref, str) else ref.key]
# Function with arbitrary parameters returning a select statement or a table.
SelFunc: TypeAlias = Callable[Params, sqla.Select | sqla.Table]
# Same as above, but callable without any arguments.
BoundSelFunc: TypeAlias = Callable[[], sqla.Select | sqla.Table]
@dataclass
class DeferredQuery(Data[S_cov]):
    """SQL data defined by a query or table-returning Python function."""

    func: BoundSelFunc
    """Function returning a SQL query or table."""

    schema: type[S_cov] | None = None
    """Schema of this query's result."""

    _subquery: sqla.Subquery | sqla.Table | None = None

    @property
    def columns(self) -> dict[str, sqla.ColumnElement]:  # noqa: D102
        return dict(self.subquery.columns)

    @property
    def subquery(self) -> sqla.Subquery | sqla.Table:
        """Return subquery or table defined by this query.

        ``func`` is called on first access only; the result is cached.
        """
        if self._subquery is None:
            res = self.func()
            self._subquery = res.subquery() if isinstance(res, sqla.Select) else res
        return self._subquery

    def __clause_element__(self) -> ColumnClause:  # noqa: D105
        return cast(ColumnClause, self.subquery)

    def __getitem__(  # noqa: D105
        self, ref: str | ColRef[V, S_cov]
    ) -> sqla.ColumnElement[V]:
        return self.columns[ref if isinstance(ref, str) else ref.key]

    @property
    def name(self) -> str | None:  # noqa: D102
        sub = self.subquery
        if isinstance(sub, sqla.Table):
            return sub.name

        # A subquery created via ``Select.subquery()`` without a name carries
        # an auto-generated anonymous label, which is not a usable stable
        # name. Prefer the name of the defining function in that case.
        target = self.func.func if isinstance(self.func, partial) else self.func
        func_name = getattr(target, "__name__", None)
        return func_name or sub.name
# Type variable for database schema classes (forward reference to ``DBSchema``).
DS = TypeVar("DS", bound="DBSchema")
def _map_foreignkey_schema(
    table: sqla.Table,
    to_schema: str | None,
    constraint: sqla.ForeignKeyConstraint,
    referred_schema: str | None,
    schema_dict: "dict[str | None, Schema | None]",
) -> str | None:
    """Find the schema name under which a foreign key's target table is registered.

    Args:
        table: Table owning the constraint (unused).
        to_schema: Target schema name; must be a key of ``schema_dict``.
        constraint: Foreign key constraint whose referred table is looked up.
        referred_schema: Fallback returned if no registered schema matches.
        schema_dict: Mapping of schema names to schema objects.

    Returns:
        Name of the first schema with a subclass of its ``schema_def`` mapped to
        the referred table, otherwise ``referred_schema``.
    """
    assert to_schema in schema_dict

    for schema_name, schema in schema_dict.items():
        if schema is None or schema.schema_def is None:
            continue
        if any(
            hasattr(schema_class, "__table__")
            and schema_class.__table__ is constraint.referred_table
            for schema_class in get_all_subclasses(schema.schema_def)
        ):
            return schema_name

    return referred_schema
@dataclass
class Schema(Generic[S_cov, DS]):
    """Named group of related tables; used as a class attribute of a DB schema."""

    schema_def: type[S_cov] | None = None
    """Schema base class for this schema's tables."""

    db_schema: type[DS] | None = None
    """DBSchema instance this schema is defined in."""

    default: bool = False
    """Whether this is the default schema for its schema type."""

    name: str | None = None
    """Name of this schema."""

    def __set_name__(self, _: type, name: str):  # noqa: D105
        # The attribute name becomes the schema name, except for the
        # special ``__default__`` attribute, which stands for "no name".
        self.name = None if name == "__default__" else name

    def __get__(  # noqa: D105
        self, instance: Any, owner: type[DS]
    ) -> "Schema[S_cov, DS]":
        # Remember through which class this schema was accessed.
        self.db_schema = owner
        return self

    @property
    def tables(self) -> dict[str, Table]:
        """Tables of the owning DB schema's metadata that belong to this schema."""
        assert self.db_schema is not None
        all_tables = self.db_schema.metadata().tables
        return {
            key: Table(sqla_table)
            for key, sqla_table in all_tables.items()
            if sqla_table.schema == self.name
        }

    def __getitem__(self, table: type[S] | str) -> Table[S]:  # noqa: D105
        # TODO: Find a way to limit table to type[S_cov] without losing type info,
        # maybe via decorating the schema classes.
        assert self.db_schema is not None
        all_tables = self.db_schema.metadata().tables

        # Metadata keys are "<schema>.<table>", or just "<table>" without schema.
        key_parts = [table if isinstance(table, str) else table.__tablename__]
        if self.name:
            key_parts.insert(0, self.name)
        return Table(all_tables[".".join(key_parts)])
class DBSchema(Generic[S_cov]):
    """Description of a whole SQL database as a collection of schemas."""

    __default__: Schema[S_cov, Self] = Schema()
    """Default schema for this DB."""

    @classmethod
    def schema_dict(cls) -> dict[str | None, Schema | None]:
        """Collect the default schema and all schemas declared directly on this class.

        The default schema is stored under the key ``None``, all others under
        their attribute name.
        """
        default_schema = cls.__default__  # type: ignore
        named_schemas = {
            attr: value
            for attr, value in cls.__dict__.items()
            if isinstance(value, Schema) and attr != "__default__"
        }
        return {None: default_schema, **named_schemas}

    @classmethod
    def default(cls) -> Schema[S_cov, Self]:
        """Return the schema stored in ``__default__``."""
        return cls.__default__  # type: ignore

    @classmethod
    def defaults(cls) -> dict[type[SchemaBase], Schema]:
        """Map each schema base class to the schema that serves as its default.

        The first schema found for a base class is used, unless a later one
        is explicitly flagged as ``default``.
        """
        chosen: dict[type[SchemaBase], Schema] = {}
        for candidate in cls.schema_dict().values():
            if candidate is None or candidate.schema_def is None:
                continue
            if candidate.default or candidate.schema_def not in chosen:
                chosen[candidate.schema_def] = candidate
        return chosen
# Additional table-schema type variable, independent of ``S``.
S2 = TypeVar("S2", bound=SchemaBase)
def _remove_external_fk(table: sqla.Table):
    """Strip foreign keys pointing outside the table's own schema (in-place hack).

    The foreign key sets of every column and of the table itself, as well as
    the table's constraint set, are replaced by filtered copies.
    """

    def stays(fk) -> bool:
        # Keep only keys whose constraint refers to a table in the same schema.
        return bool(
            fk.constraint and fk.constraint.referred_table.schema == table.schema
        )

    for column in table.columns.values():
        column.foreign_keys = {fk for fk in column.foreign_keys if stays(fk)}

    table.foreign_keys = {fk for fk in table.foreign_keys if stays(fk)}  # type: ignore

    table.constraints = {
        constraint
        for constraint in table.constraints
        if not isinstance(constraint, sqla.ForeignKeyConstraint)
        or constraint.referred_table.schema == table.schema
    }
[docs]
@dataclass
class DB(Generic[S_cov, DS]):
"""Active connection to a SQL server."""
url: str | sqla.URL
"""Connection URL."""
schema: type[DBSchema[S_cov]] = DBSchema
"""Schema for this DB."""
tag: str = datetime.today().strftime("YYYY-MM-dd")
"""Tag for this DB session. Used to identify automatically created tables."""
validate: bool = True
"""Whether to validate the DB schema on connection."""
def __post_init__(self): # noqa: D105
self.__token = None
self.engine = sqla.create_engine(self.url)
if self.validate:
# Validation.
inspector = sqla.inspect(self.engine)
for table in self.schema.metadata().tables.values():
if table.schema is not None:
assert inspector.has_schema(table.schema)
assert inspector.has_table(table.name, table.schema)
db_columns = {
c["name"]: c
for c in inspector.get_columns(table.name, table.schema)
}
for column in table.columns:
assert column.name in db_columns
db_col = db_columns[column.name]
assert isinstance(db_col["type"], type(column.type))
assert (
db_col["nullable"] == column.nullable or column.nullable is None
)
db_pk = inspector.get_pk_constraint(table.name, table.schema)
if (
len(db_pk["constrained_columns"]) > 0
): # Allow source tbales without pk
assert set(db_pk["constrained_columns"]) == set(
table.primary_key.columns.keys()
)
db_fks = inspector.get_foreign_keys(table.name, table.schema)
for fk in table.foreign_key_constraints:
matches = [
(
set(db_fk["constrained_columns"]) == set(fk.column_keys),
(
db_fk["referred_table"].lower()
== fk.referred_table.name.lower()
),
set(db_fk["referred_columns"])
== set(f.column.name for f in fk.elements),
)
for db_fk in db_fks
]
assert any(all(m) for m in matches)
@overload
def __getitem__(self, key: Schema[S2, DS]) -> Schema[S2, DS]: ...
@overload
def __getitem__(self, key: None) -> Schema[S_cov, DS]: ...
def __getitem__( # noqa: D105
self, key: Schema[S2, DS] | None
) -> Schema[S2, DS] | Schema[S_cov, DS]:
return self.schema.default() if key is None else key # type: ignore
[docs]
def activate(self) -> None:
"""Set this instance as the current SQL context."""
self.__token = active_con.set(self)
def __enter__(self): # noqa: D105
self.activate()
return self
def __exit__(self, *_): # noqa: D105
if self.__token is not None:
active_con.reset(self.__token)
[docs]
def default_table(self, schema: type[S]) -> Table[S]:
"""Return default table for given table schema.
Args:
schema: Schema of the table to return.
Returns:
Default table for given schema.
"""
defaults = self.schema.defaults()
valid_keys = list(filter(lambda s: issubclass(schema, s), defaults.keys()))
if len(valid_keys) == 0:
raise KeyError(schema)
schema_name = defaults[valid_keys[0]].name
return Table(
self.schema.metadata().tables[
".".join(
[
*([schema_name] if schema_name else []),
schema.__tablename__,
]
)
]
)
def _table_from_df(
self,
df: pd.DataFrame,
schema: type[S_cov] | None = None,
name: str | None = None,
) -> Table[S_cov]:
table_name = f"{self.tag}_df_{name + '_' if name else ''}{_hash_df(df)}"
if table_name not in self.schema.metadata().tables and not sqla.inspect(
self.engine
).has_table(table_name, schema=None):
cols = (
{
c.key: sqla.Column(c.key, c.type, primary_key=c.primary_key)
for c in schema.__table__.columns
}
if schema is not None
else _cols_from_df(df)
)
table = Table(
sqla.Table(table_name, self.schema.metadata(), *cols.values())
)
table.sqla_table.create(self.engine)
df.reset_index()[list(cols.keys())].to_sql(
table_name, self.engine, if_exists="append", index=False
)
return (
Table(self.schema.metadata().tables[table_name])
if table_name in self.schema.metadata().tables
else Table(
sqla.Table(
table_name, self.schema.metadata(), autoload_with=self.engine
)
)
)
def _table_from_query(
self, q: Query[S] | DeferredQuery[S], with_external_fk: bool = False
) -> Table[S_cov]:
table_name = f"{self.tag}_query_{q.name}"
if table_name in self.schema.metadata().tables:
return Table(self.schema.metadata().tables[table_name])
elif sqla.inspect(self.engine).has_table(table_name, schema=None):
return Table(
sqla.Table(
table_name,
self.schema.metadata(),
autoload_with=self.engine,
)
)
else:
sel_res = cast(sqla.Subquery | sqla.Table, q.__clause_element__())
sqla_table = (
q.schema.__table__.to_metadata(
self.schema.metadata(),
schema=None, # type: ignore
referred_schema_fn=partial(
_map_foreignkey_schema, schema_dict=self.schema.schema_dict()
),
name=table_name,
)
if q.schema is not None and isinstance(q.schema.__table__, sqla.Table)
else sqla.Table(
table_name,
self.schema.metadata(),
*(
sqla.Column(name, col.type, primary_key=col.primary_key)
for name, col in sel_res.columns.items()
),
)
)
if not with_external_fk:
_remove_external_fk(sqla_table)
sqla_table.create(self.engine)
table = Table(sqla_table)
with self.engine.begin() as con:
con.execute(
sqla.insert(table).from_select(
sel_res.exported_columns.keys(), sel_res
)
)
return table
[docs]
def to_table(
self,
src: pd.DataFrame | Query[S_cov] | DeferredQuery[S_cov],
schema: type[S_cov] | None = None,
name: str | None = None,
with_external_fk: bool = False,
) -> Table[S_cov]:
"""Transfer dataframe or sql query results to manifested SQL table.
Args:
src: Source / definition of data.
schema: Schema of the table to create.
name: Name of the table to create.
with_external_fk: Whether to include foreign keys to external tables.
Returns:
Reference to manifested SQL table.
"""
match src:
case pd.DataFrame():
return self._table_from_df(src, schema, name)
case Query() | DeferredQuery():
return self._table_from_query(src, with_external_fk)
[docs]
def to_df(self, q: Data) -> pd.DataFrame:
"""Transfer manifested table or query results to local dataframe.
Args:
q: Source / definition of data.
Returns:
Dataframe containing the data.
"""
sel = sqla.select(q)
with self.engine.connect() as con:
return pd.read_sql(sel, con)
# Context variable holding the currently active ``DB`` (``None`` if there is none).
active_con: ContextVar[DB | None] = ContextVar("active_sql_con", default=None)

# Function with arbitrary parameters returning a ``Query``.
QueryFunc: TypeAlias = Callable[Params, Query[S_cov]]
# Function with arbitrary parameters returning a ``DeferredQuery``.
DefQueryFunc: TypeAlias = Callable[Params, DeferredQuery[S_cov]]
@overload
def query(func: SelFunc[Params]) -> QueryFunc[Params, S]: ...


@overload
def query(
    *, defer: Literal[True] = ..., schema: type[S] = ...  # type: ignore
) -> Callable[[SelFunc[Params]], DefQueryFunc[Params, S]]: ...


@overload
def query(
    *, defer: Literal[True] = ..., schema: None = ...
) -> Callable[[SelFunc[Params]], DefQueryFunc[Params, Any]]: ...


@overload
def query(
    *, defer: Literal[False] = ..., schema: type[S] = ...  # type: ignore
) -> Callable[[SelFunc[Params]], QueryFunc[Params, S]]: ...


@overload
def query(
    *, defer: Literal[False] = ..., schema: None = ...
) -> Callable[[SelFunc[Params]], QueryFunc[Params, Any]]: ...


def query(
    func: SelFunc[Params] | None = None,
    *,
    defer: bool | None = False,
    schema: type[S] | None = None,
) -> (
    QueryFunc[Params, S]
    | Callable[[SelFunc[Params]], QueryFunc[Params, S] | DefQueryFunc[Params, S]]
):
    """Decorator to transform query-returning function into a referencable query.

    Can be applied directly (``@query``) or with keyword arguments
    (``@query(defer=True, schema=...)``).

    Args:
        func: Function returning a :py:obj:`sqla.Selectable`.
        defer: Whether to defer execution of the query.
        schema: Schema of the query's result.

    Returns:
        Decorated function.
    """

    def inner(
        func: SelFunc[Params],
    ) -> QueryFunc[Params, S] | DefQueryFunc[Params, S]:
        @wraps(func)
        def inner_inner(
            *args: Params.args, **kwargs: Params.kwargs
        ) -> Query[S] | DeferredQuery[S]:
            if defer:
                # Do not call ``func`` here: binding the arguments and
                # leaving the call to the deferred query is the whole point
                # of deferring.
                return DeferredQuery(partial(func, *args, **kwargs), schema)

            res = func(*args, **kwargs)
            return Query(
                res if isinstance(res, sqla.Select) else sqla.select(res),
                func.__name__,
                schema,
            )

        return inner_inner  # type: ignore

    return inner(func) if func is not None else inner  # type: ignore
def default_table(schema: type[S]) -> DeferredQuery[S]:
    """Reference the default table of a table schema without needing a connection yet.

    Args:
        schema: Schema of the table to return.

    Returns:
        Deferred query that looks up the default table in the SQL context
        obtained from ``_get_sql_ctx`` when it is evaluated.
    """

    def get_current_default_table() -> sqla.Table:
        return _get_sql_ctx().default_table(schema).sqla_table

    return DeferredQuery(func=get_current_default_table, schema=schema)