"""On-the-fly, file-based caching of function return values."""
import datetime
import json
import pickle
from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from typing import Any, ParamSpec, TypeVar, cast, overload
import numpy as np
import pandas as pd
import yaml
from bs4 import BeautifulSoup, Tag
from py_research.data import gen_id
from py_research.files import ensure_dir_exists
from py_research.reflect.runtime import (
get_calling_module,
get_full_args_dict,
get_return_type,
)
from py_research.telemetry import get_logger
# Module-level logger, obtained from the project's telemetry helper.
log = get_logger()
# Fallback cache root used by ``get_cache`` when no ``root_path`` is passed:
# a ".cache" directory below the working directory as of import time.
default_root_path = Path.cwd() / ".cache"
# Generic placeholders used to type the caching decorator: ``P`` stands for
# the wrapped function's parameters, ``R`` for its return type.
P = ParamSpec("P")
R = TypeVar("R")
[docs]
@dataclass
class FileCache:
"""Local, directory-based cache for storing function results."""
path: Path
"""Root directory for storing cached results."""
max_cache_time: datetime.timedelta = datetime.timedelta(days=7)
"""After how long to invalidate cached objects and recompute."""
def __post_init__(self): # noqa: D105
now = datetime.datetime.now()
self.__earliest_date = now - self.max_cache_time
self.__now_str = now.strftime("%Y-%m-%d")
@staticmethod
def __get_date_from_filename(f: Path):
return datetime.datetime.strptime(f.name.split(".")[0], "%Y-%m-%d")
def __filter_outdated(self, f: Path):
if self.__get_date_from_filename(f) > self.__earliest_date:
return True
else:
f.unlink()
return False
@overload
def function(self, func: Callable[P, R]) -> Callable[P, R]: ...
@overload
def function(
self,
*,
id_arg_subset: list[int] | list[str] | None = None,
use_raw_arg: bool = False,
id_callback: Callable[P, dict[str, Any] | None] | None = None,
use_json: bool = True,
) -> Callable[[Callable[P, R]], Callable[P, R]]: ...
[docs]
def function( # noqa: C901
self,
func: Callable[P, R] | None = None,
*,
id_arg_subset: list[int] | list[str] | None = None,
use_raw_arg: bool = False,
id_callback: Callable[P, dict[str, Any] | None] | None = None,
use_json: bool = True,
) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]:
"""Decorator to cache wrapped function.
Args:
func: Function to cache.
id_arg_subset:
Number or name of the arguments to base hash id of result on.
use_raw_arg:
If ``True``, use the unhashed, string-formatted value of the id arg
as filename. Only works for single id arg.
id_callback:
Callback function to use for retrieving a unique id from the arguments.
use_json:
Whether to use JSON as the format for caching dicts (instead of YAML).
"""
def inner(func: Callable[P, R]) -> Callable[P, R]: # noqa: C901
path = ensure_dir_exists(self.path / func.__name__)
@wraps(func)
def inner_inner(*args: P.args, **kwargs: P.kwargs) -> R: # noqa: C901
id_args = None
if id_callback is not None:
id_args = id_callback(*args, **kwargs)
if id_args is None:
id_args = get_full_args_dict(func, args, kwargs)
if id_arg_subset is not None:
id_args = {
k: v for k, v in id_args.items() if k in id_arg_subset
}
id_value = gen_id(
id_args if len(id_args) > 1 else list(id_args.values())[0],
raw_str=use_raw_arg,
)
filename_pattern = f"[0-9]*-[0-9]*-[0-9]*.{id_value}.*"
result: R | None = None
all_cached = [
f
for f in path.iterdir()
if f.match(filename_pattern) and self.__filter_outdated(f)
]
if len(all_cached):
log.debug(f"💾 Taking cached result for: '{id_value}'")
last_cached = sorted(all_cached, key=self.__get_date_from_filename)[
-1
]
extension = last_cached.name.split(".")[-1]
return_type = get_return_type(func) or type
cached_result = None
if issubclass(return_type, str) and extension == "txt":
with open(
last_cached,
encoding="utf-8",
) as f:
cached_result = f.read()
elif issubclass(return_type, dict | list) and extension in (
"yaml",
"yml",
):
cached_result = yaml.load(
open(last_cached, encoding="utf-8"), Loader=yaml.CLoader
)
elif issubclass(return_type, dict | list) and extension == "json":
cached_result = json.load(open(last_cached, encoding="utf-8"))
elif (
issubclass(return_type, pd.DataFrame | pd.Series)
and extension == "xlsx"
):
cached_result = pd.read_excel(
last_cached, header=0, index_col=0
)
elif (
issubclass(return_type, pd.DataFrame | pd.Series)
and extension == "csv"
):
cached_result = pd.read_csv(last_cached, header=0, index_col=0)
elif issubclass(return_type, np.ndarray) and extension == "npy":
cached_result = np.load(last_cached)
elif (
issubclass(return_type, BeautifulSoup | Tag)
and extension == "html"
):
cached_result = BeautifulSoup(open(last_cached))
elif extension == "pkl":
cached_result = pickle.load(open(last_cached, mode="rb"))
result = cast(R | None, cached_result)
if result is None:
log.debug(
f"⬇ Performing operation / fetching resource: '{id_value}'"
)
result = func(*args, **kwargs)
match (result):
case str():
with open(
path / f"{self.__now_str}.{id_value}.txt",
"w",
encoding="utf-8",
) as f:
f.write(result)
case dict() | list():
if use_json:
json.dump(
result,
open(
path / f"{self.__now_str}.{id_value}.json",
"w",
encoding="utf-8",
),
indent=2,
)
else:
yaml.dump(
result,
open(
path / f"{self.__now_str}.{id_value}.yaml",
"w",
encoding="utf-8",
),
allow_unicode=True,
)
case pd.DataFrame():
result.to_excel(
path / f"{self.__now_str}.{id_value}.xlsx", index=True
)
case np.ndarray():
np.save(path / f"{self.__now_str}.{id_value}.npy", result)
case BeautifulSoup() | Tag():
with open(
path / f"{self.__now_str}.{id_value}.html", mode="w"
) as file:
file.write(str(result))
case _:
pickle.dump(
result,
open(
path / f"{self.__now_str}.{id_value}.pkl", mode="wb"
),
)
return result
return inner_inner
return inner if func is None else inner(func)
def get_cache(
    name: str | None = None,
    root_path: Path | None = None,
    max_cache_time: datetime.timedelta = datetime.timedelta(days=7),
) -> FileCache:
    """Return a named cache instance private to the calling module.

    The cache directory is ``<root_path>/<calling module name>/<name>`` and
    gets created if it does not exist yet. ``"root"`` is used in place of the
    module name if the calling module cannot be determined.

    Args:
        name: Name of the cache (directory) to create.
        root_path:
            Root directory, where to store cache. Defaults to
            ``default_root_path``.
        max_cache_time: After how long to invalidate cached objects and recompute.

    Returns:
        A cache instance.
    """
    root_path = root_path or default_root_path

    # NOTE(review): presumably this inspects the call stack, so it has to be
    # called directly from this function - confirm before moving the call.
    calling_module = get_calling_module()
    module_name = calling_module.__name__ if calling_module is not None else None

    return FileCache(
        ensure_dir_exists(root_path / (module_name or "root") / (name or "")),
        max_cache_time,
    )