"""Utilities for importing different data representations into relational format."""
from collections.abc import Callable, Hashable, Mapping
from dataclasses import dataclass
from datetime import datetime
from itertools import chain, groupby
from typing import Any, Literal, TypeAlias, cast, overload
from uuid import uuid4
import pandas as pd
from py_research.hashing import gen_str_hash
from .base import DB
from .conflicts import DataConflictError, DataConflictPolicy, DataConflicts
Scalar = str | int | float | datetime
_AttrMap = Mapping[str, "str | bool | _AttrMap"]
"""Mapping of hierarchical attributes to table columns"""
_RelationalMap = Mapping[str, "str | bool | _RelationalMap | TableMap | list[TableMap]"]
"""Mapping of hierarchical attributes to table columns or other tables."""
@dataclass
class TableMap:
"""Configuration for how to map (nested) dictionary items to relational tables."""
table: str
"""Name of the table to map to attributes in ``map`` to."""
map: (
_RelationalMap
| set[str]
| str
| Callable[[dict | str], _RelationalMap | set[str] | str]
)
"""Mapping of hierarchical attributes to table columns or other tables."""
ext_maps: "list[TableMap] | None" = None
"""Map attributes on the same level to different tables"""
link_attr: str | None = None
"""Override attribute to use when referencing this table from a parent table."""
link_type: Literal["1-n", "n-1", "n-m"] = "n-m"
"""Type of reference between this table and the parent table.
Only ``n-m`` will create a join table.
"""
join_table_name: str | None = None
"""Name of the join table to use for referencing this table to its parent, if any"""
join_table_map: _AttrMap | None = None
"""Mapping of hierarchical attributes to join table rows."""
id_type: Literal["hash", "attr", "uuid"] = "hash"
"""Type of id to use for this table.
- ``hash``: sha256 hash of a subset of all mapped attributes
- ``attr``: use given attr as id directly, no hashing
- ``uuid``: generate random uuid for each row
"""
id_attr: str | list[str] | None = None
"""Name of unique, mapped attribute to use as id directly or
list of mapped attributes to use for auto-generating row ids via sha256 hashing.
If None, all mapped attributes will be used for hashing.
"""
hash_id_with_path: bool = False
"""If True, the id will be generated based on the data and the full tree path.
"""
match_by_attr: bool | str = False
"""Try to match this mapped data to target table (by given attr)
before creating a new row.
"""
conflict_policy: DataConflictPolicy = "raise"
"""Which policy to use if import conflicts occur for this table."""
_SubMap = tuple[dict | list, TableMap | list[TableMap]]
"""Combination of data and mapping for a subtree."""
def _map_sublayer(
node: dict, mapping: _RelationalMap
) -> tuple[pd.Series, dict[str, _SubMap]]:
"""Extract hierarchical data into set of scalar attributes + ref'd data objects."""
# Split the current mapping level into groups based on type.
target_groups: dict[type, dict] = {
t: dict(g) # type: ignore
for t, g in groupby(
sorted(
mapping.items(),
key=lambda item: str(type(item[1])),
),
key=lambda item: type(item[1]),
)
} # type: ignore
# First list and handle all scalars, hence data attributes on the current level,
# which are to be mapped to table columns.
scalars = dict(
chain(
(target_groups.get(str) or {}).items(),
(target_groups.get(bool) or {}).items(),
)
)
cols = {
(col if isinstance(col, str) else attr): data
for attr, col in scalars.items()
if isinstance(data := node.get(attr), Scalar)
}
refs = {
attr: (data, cast(TableMap | list[TableMap], sub_map))
for attr, sub_map in {
**(target_groups.get(TableMap) or {}),
**(target_groups.get(list) or {}),
}.items()
if isinstance(data := node.get(attr), dict | list)
}
# Handle nested data attributes (which come as dict types).
for attr, sub_map in (target_groups.get(dict) or {}).items():
sub_node = node.get(attr)
if isinstance(sub_node, dict):
sub_row, sub_refs = _map_sublayer(sub_node, cast(dict, sub_map))
cols = {**cols, **sub_row}
refs = {**refs, **sub_refs}
return pd.Series(cols, dtype=object), refs
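
# Worked example for ``_map_sublayer`` (hypothetical node and mapping):
#
#   node = {"title": "Foo", "meta": {"year": 2001}, "tags": [{"label": "a"}]}
#   mapping = {
#       "title": "name",
#       "meta": {"year": "year"},
#       "tags": TableMap(table="tags", map={"label"}),
#   }
#
# returns ``pd.Series({"name": "Foo", "year": 2001})`` as the row, plus
# ``{"tags": ([{"label": "a"}], TableMap(...))}`` as refs for later handling.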
def _gen_row_hash(
row: pd.Series,
context_path: list[str | int] | None = None,
hash_subset: list[str] | None = None,
) -> str:
"""Generate hash for given row."""
    hash_row = (
        row[sorted(set(hash_subset) & set(row.index))]
        if hash_subset is not None
        else row
    )
    row_id = gen_str_hash(
        (hash_row.to_dict(), context_path)
        if context_path is not None
        else hash_row.to_dict()
    )
return row_id
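
# For example, ``_gen_row_hash(pd.Series({"a": 1, "b": 2}), hash_subset=["a"])``
# hashes only ``{"a": 1}``; when a ``context_path`` is given, the path is
# hashed together with the (sub)row so identical data at different tree
# positions yields distinct ids.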
_DictDB: TypeAlias = dict[str, dict[Hashable, dict[str, Any]]]
"""All-dictionary representation of a relational database for performant extending."""
_Rels: TypeAlias = dict[tuple[str, str], tuple[str, str]]
"""Relations between tables, indexed by the referencing table and column."""
_JoinTables: TypeAlias = set[str]
"""Names of all join tables in the database."""
RelDB: TypeAlias = tuple[_DictDB, _Rels, _JoinTables]
"""Full relational database representation."""
def _handle_refs( # noqa: C901
mapping: TableMap,
db: RelDB,
row: pd.Series,
refs: list[tuple[str | None, _SubMap]],
collect_conflicts: bool = False,
_all_conflicts: DataConflicts | None = None,
_path: list[str | int] | None = None,
) -> DataConflicts:
"""Handle references to other tables."""
_all_conflicts = _all_conflicts or {}
_path = _path or []
database, relations, join_tables = db
# Handle nested data, which is to be extracted into separate tables and referenced.
for attr, (sub_data, sub_maps) in refs:
# Get info about the ref table to use from mapping
# (or generate a new for the ref table).
if not isinstance(sub_maps, list):
sub_maps = [sub_maps]
for sub_map in sub_maps:
join_table_name = sub_map.join_table_name
join_table_exists = False
alt_join_table_names = [
f"{mapping.table}_{sub_map.table}",
f"{sub_map.table}_{mapping.table}",
]
if join_table_name is None:
join_table_name = alt_join_table_names[0]
for name in alt_join_table_names:
if name in database:
join_table_name = name
join_table_exists = True
break
else:
join_table_exists = join_table_name in database
if not isinstance(sub_data, list):
sub_data = [sub_data]
for i, sub_data_item in enumerate(sub_data):
                _sub_path = [*_path, attr, i]
rel_row, _all_conflicts = _tree_to_db(
sub_data_item,
sub_map,
db,
collect_conflicts,
_all_conflicts,
_sub_path,
)
if sub_map.link_type == "n-m":
# Map via join table.
if not join_table_exists:
database[join_table_name] = {}
join_table_exists = True
ref_row, _ = (
_map_sublayer(sub_data_item, sub_map.join_table_map)
if isinstance(sub_data_item, dict)
and sub_map.join_table_map is not None
else (pd.Series(dtype=object), None)
)
left_col = f"{attr}_of" if attr is not None else mapping.table
relations[(join_table_name, left_col)] = (
mapping.table,
"_id",
)
ref_row[left_col] = row.name
right_col = attr if attr is not None else sub_map.table
relations[(join_table_name, right_col)] = (
sub_map.table,
"_id",
)
ref_row[right_col] = rel_row.name
ref_row.name = _gen_row_hash(ref_row, _sub_path)
database[join_table_name][ref_row.name] = ref_row.to_dict()
join_tables |= {join_table_name}
elif sub_map.link_type == "1-n":
# Map via direct reference from children to parent.
col = f"{attr}_of" if attr is not None else mapping.table
if col in rel_row.index:
assert rel_row[col] is None or rel_row[col] == row.name
relations[(sub_map.table, col)] = (
mapping.table,
"_id",
)
database[sub_map.table][rel_row.name] = {
**rel_row.to_dict(),
col: row.name,
}
elif sub_map.link_type == "n-1":
# Map via direct reference from parents to child.
col = attr if attr is not None else sub_map.table
if col in row.index:
                        assert row[col] is None or row[col] == rel_row.name
relations[(mapping.table, col)] = (
sub_map.table,
"_id",
)
database[mapping.table][row.name] = {
**row.to_dict(),
col: rel_row.name,
}
return _all_conflicts
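
# Summary of the three link types handled above, using illustrative names for a
# parent table "projects" referencing a child table "users" via attr "owner":
#   "n-m": rows go into a join table (default name "projects_users") with
#          columns "owner_of" -> projects._id and "owner" -> users._id
#   "1-n": the child table "users" gains an "owner_of" column -> projects._id
#   "n-1": the parent table "projects" gains an "owner" column -> users._id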
def _tree_to_db( # noqa: C901
data: dict | str,
mapping: TableMap,
db: RelDB,
collect_conflicts: bool = False,
_all_conflicts: DataConflicts | None = None,
_path: list[str | int] | None = None,
) -> tuple[pd.Series, DataConflicts]:
"""Transform recursive dictionary data into relational format."""
_all_conflicts = _all_conflicts or {}
_path = _path or []
database, _, _ = db
# Initialize new table as defined by mapping, if it doesn't exist yet.
if mapping.table not in database:
database[mapping.table] = {}
resolved_map = (
mapping.map(data) if isinstance(mapping.map, Callable) else mapping.map
)
# Extract row of data attributes and refs to other objects.
row = None
refs: list[tuple[str | None, _SubMap]] = []
# If mapping is only a string, extract the target attr directly.
if isinstance(data, str):
assert isinstance(resolved_map, str)
row = pd.Series({resolved_map: data}, dtype=object)
else:
if isinstance(resolved_map, set):
row = pd.Series(
{k: v for k, v in data.items() if k in resolved_map}, dtype=object
)
elif isinstance(resolved_map, dict):
row, ref_dict = _map_sublayer(data, resolved_map)
refs = [*refs, *ref_dict.items()]
else:
raise TypeError(
f"Unsupported mapping type {type(resolved_map)}"
f" for data of type {type(data)}"
)
row.name = (
row[mapping.id_attr]
if mapping.id_type == "attr" and isinstance(mapping.id_attr, str)
else (
_gen_row_hash(
row,
                _path if mapping.hash_id_with_path else None,
(
[mapping.id_attr]
if isinstance(mapping.id_attr, str)
else mapping.id_attr
),
)
if mapping.id_type == "hash"
else str(uuid4())[-10:]
)
)
if not isinstance(row.name, str | int):
raise ValueError(
f"Value of `'{mapping.id_attr}'` (`TableMap.id_attr`) "
f"must be a string or int for all objects, but received {row.name}"
)
if mapping.ext_maps is not None:
assert isinstance(data, dict)
refs = [*refs, *((m.link_attr, (data, m)) for m in mapping.ext_maps)]
_all_conflicts = _handle_refs(
mapping, db, row, refs, collect_conflicts, _all_conflicts
)
if not row.empty:
existing_row: dict[str, Any] | None = None
if mapping.match_by_attr:
# Make sure any existing data in database is consistent with new data.
match_by = cast(str, mapping.match_by_attr)
if mapping.match_by_attr is True:
assert isinstance(resolved_map, str)
match_by = resolved_map
match_to = row[match_by]
# Match to existing row or create new one.
existing_rows: list[tuple[str, dict[str, Any]]] = list(
filter(
lambda i: i[1][match_by] == match_to,
database[mapping.table].items(),
)
)
if len(existing_rows) > 0:
existing_row_id, existing_row = existing_rows[0]
# Take over the id of the existing row.
row.name = existing_row_id
else:
existing_row = database[mapping.table].get(row.name)
if existing_row is not None:
existing_attrs = set(
str(k) for k, v in existing_row.items() if k and pd.notna(v)
)
new_attrs = set(str(k) for k, v in row.items() if k and pd.notna(v))
            # Resolve overlapping attributes according to the conflict policy.
intersect = existing_attrs & new_attrs
if mapping.conflict_policy == "raise":
conflicts = {
(mapping.table, row.name, c): (existing_row[c], row[c])
for c in intersect
if existing_row[c] != row[c]
}
if len(conflicts) > 0:
if not collect_conflicts:
raise DataConflictError(conflicts)
_all_conflicts = {**_all_conflicts, **conflicts}
if mapping.conflict_policy == "ignore":
row = pd.Series(
{**row.loc[list(new_attrs)], **existing_row}, name=row.name
)
else:
row = pd.Series(
{**existing_row, **row.loc[list(new_attrs)]}, name=row.name
)
# Add row to database table or update it.
database[mapping.table][row.name] = row.to_dict()
# Return row (used for recursion).
return row, _all_conflicts
@overload
def tree_to_db(
data: dict | str,
mapping: TableMap,
    collect_conflicts: Literal[True],
) -> tuple[DB, DataConflicts]: ...
@overload
def tree_to_db(
data: dict | str,
mapping: TableMap,
collect_conflicts: Literal[False] = ...,
) -> DB: ...
def tree_to_db( # noqa: C901
data: dict | str,
mapping: TableMap,
collect_conflicts: bool = False,
) -> DB | tuple[DB, DataConflicts]:
"""Transform recursive dictionary data into relational format.
Args:
data: The data to be transformed
mapping:
            Configuration for how to perform the mapping.
collect_conflicts:
Collect all conflicts and return them, rather than stopping right away.
Returns:
The relational database representation of the data.
If ``collect_conflicts`` is ``True``, a tuple of the database and the conflicts
is returned.
"""
df_dict = {}
rels = {}
join_tables = set()
_, conflicts = _tree_to_db(
data, mapping, (df_dict, rels, join_tables), collect_conflicts
)
db = DB(
table_dfs={
name: pd.DataFrame.from_dict(df, orient="index").rename_axis(
"_id", axis="index"
)
for name, df in df_dict.items()
},
relations=rels,
join_tables=join_tables,
)
return (db, conflicts) if collect_conflicts else db
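
# Usage sketch (hypothetical data; names chosen only for illustration):
#
#   data = {
#       "title": "Relational imports",
#       "authors": [{"name": "Ada"}, {"name": "Grace"}],
#   }
#   mapping = TableMap(
#       table="articles",
#       map={"title": "title", "authors": TableMap(table="authors", map={"name"})},
#   )
#   db = tree_to_db(data, mapping)
#
# ``db.table_dfs`` then holds "articles", "authors" and, because the default
# ``link_type`` is "n-m", an auto-generated "articles_authors" join table.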