Skip to content

aimbat.utils

Miscellaneous helpers for AIMBAT.

Covers four areas:

  • JSON — render JSON data as Rich tables (json_to_table).
  • Sample data — download and delete the bundled sample dataset (download_sampledata, delete_sampledata).
  • Styling — shared Rich/table style helpers (make_table).
  • UUIDs — resolve short UUID prefixes to full UUIDs and compute shortest unique prefixes (string_to_uuid, uuid_shortener).

Classes:

Name Description
FormatResult

Container for a formatted value and its display metadata.

Functions:

Name Description
delete_sampledata

Delete sample data.

download_sampledata

Download sample data.

get_title_map

Creates a mapping from field names to their 'title' metadata.

json_to_table

Print a JSON dict or list of dicts as a rich table.

mean_and_sem

Return the mean and standard error of the mean (SEM) for a list of numeric values, ignoring None values.

mean_and_sem_timedelta

Return (mean, sem) for a list of pd.Timedelta values.

string_to_uuid

Determine the full UUID from a string containing its first few characters.

uuid_shortener

Calculates the shortest unique prefix for a UUID, returning it with dashes.

FormatResult

Bases: NamedTuple

Container for a formatted value and its display metadata.

Parameters:

Name Type Description Default
text str
None
justify str
'left'
style str | None
None
Source code in src/aimbat/utils/_table.py
class FormatResult(NamedTuple):
    """Container for a formatted value and its display metadata."""

    text: str
    justify: str = "left"
    style: str | None = None

delete_sampledata

delete_sampledata() -> None

Delete sample data.

Source code in src/aimbat/utils/_sampledata.py
def delete_sampledata() -> None:
    """Remove the local sample-data directory together with all of its contents.

    The directory removed is ``settings.sampledata_dir``. Deletion uses
    :func:`shutil.rmtree` with no error handling, so any error it raises
    (for example when the directory does not exist) propagates to the caller.
    """
    target_dir = settings.sampledata_dir
    logger.info(f"Deleting sample data in {target_dir}.")
    shutil.rmtree(target_dir)

download_sampledata

download_sampledata(force: bool = False) -> None

Download sample data.

Source code in src/aimbat/utils/_sampledata.py
def download_sampledata(force: bool = False) -> None:
    """Download the sample dataset and extract it into the sample-data directory.

    The zip archive at ``settings.sampledata_src`` is read fully into memory and
    extracted into ``settings.sampledata_dir``.

    Args:
        force: If true and the target directory already exists with content,
            that directory is deleted (via `delete_sampledata`) before the
            download. Any truthy value is accepted.

    Raises:
        FileExistsError: If the target directory exists, is non-empty and
            `force` is false.
    """

    logger.info(
        f"Downloading sample data from {settings.sampledata_src} to {settings.sampledata_dir}."
    )

    target_dir = settings.sampledata_dir
    if target_dir.exists() and os.listdir(target_dir):
        # Plain truthiness instead of `force is True`: an identity check would
        # silently ignore truthy flags such as 1 and raise instead of deleting.
        if not force:
            raise FileExistsError(
                f"The directory {settings.sampledata_dir} already exists and is non-empty."
            )
        delete_sampledata()

    with urlopen(settings.sampledata_src) as zipresp:
        logger.debug(f"Extracting sample data to {settings.sampledata_dir}.")
        # The archive is buffered in memory because ZipFile needs a seekable file.
        with ZipFile(BytesIO(zipresp.read())) as zfile:
            zfile.extractall(settings.sampledata_dir)

    logger.info("Sample data downloaded and extracted successfully.")

get_title_map cached

get_title_map(
    model_class: type[BaseModel],
) -> dict[str, str]

Creates a mapping from field names to their 'title' metadata.

Source code in src/aimbat/utils/_pydantic.py
@lru_cache(maxsize=None)
def get_title_map(model_class: type[BaseModel]) -> dict[str, str]:
    """Map every field of a pydantic model class to a human-readable title.

    Regular fields (``model_fields``) come first, followed by computed fields
    (``__pydantic_computed_fields__``, when present). A field without a title
    falls back to its name with underscores replaced by spaces.

    The result is memoised per class by ``lru_cache``, so the same dict object
    is returned on every call. Callers should treat it as read-only.
    """

    def _fallback(field_name: str) -> str:
        return field_name.replace("_", " ")

    titles = {
        field_name: field_info.title or _fallback(field_name)
        for field_name, field_info in model_class.model_fields.items()
    }

    computed: dict[str, Any] = getattr(
        model_class, "__pydantic_computed_fields__", {}
    )
    # Computed-field info objects may lack a `title` attribute entirely.
    titles.update(
        (field_name, getattr(field_info, "title", None) or _fallback(field_name))
        for field_name, field_info in computed.items()
    )
    return titles

json_to_table

json_to_table(
    data: dict[str, Any] | list[dict[str, Any]],
    title: str | None = None,
    formatters: Mapping[
        str, Callable[[Any], str | FormatResult]
    ]
    | None = None,
    skip_keys: list[str] | None = None,
    column_order: list[str] | None = None,
    column_kwargs: Mapping[str, Mapping[str, Any]]
    | None = None,
    common_column_kwargs: Mapping[str, Any] | None = None,
    short: bool = True,
) -> None

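Print a JSON dict or list of dicts as a rich table. Headers are kept as-is from keys, except that ID-like keys (such as id, short_id or <name>_id) are given an automatic "ID" / "<Name> ID" header; headers can also be renamed via column_kwargs.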

Source code in src/aimbat/utils/_table.py
def json_to_table(
    data: dict[str, Any] | list[dict[str, Any]],
    title: str | None = None,
    formatters: Mapping[str, Callable[[Any], str | FormatResult]] | None = None,
    skip_keys: list[str] | None = None,
    column_order: list[str] | None = None,
    column_kwargs: Mapping[str, Mapping[str, Any]] | None = None,
    common_column_kwargs: Mapping[str, Any] | None = None,
    short: bool = True,
) -> None:
    """
    Print a JSON dict or list of dicts as a rich table.
    Headers are kept as-is from keys unless renamed via column_kwargs.
    """
    formatters = formatters or {}
    skip = set(skip_keys or [])
    column_kwargs = column_kwargs or {}
    common_column_kwargs = common_column_kwargs or {}
    console = Console()
    table = Table(title=title)

    styling_keys = {f.name for f in fields(TableStyling)}

    def _sorted_keys(keys: Iterable[str]) -> list[str]:
        keys_list = list(keys)
        if not column_order:
            return keys_list
        ordered = [k for k in column_order if k in keys_list]
        rest = [k for k in keys_list if k not in set(column_order)]
        return ordered + rest

    def _get_formatted(key: str, val: Any) -> FormatResult:
        """Resolves formatting and ensures a FormatResult is always returned."""
        raw_result: str | FormatResult | None = None

        if key in formatters:
            raw_result = formatters[key](val)

        if raw_result is None:
            low_key = key.lower()

            if isinstance(val, bool):
                raw_result = TABLE_STYLING.bool_formatter(val)
            elif isinstance(val, float):
                raw_result = TABLE_STYLING.float_formatter(val, short=short)
            elif isinstance(val, (Timestamp, datetime)):
                raw_result = TABLE_STYLING.timestamp_formatter(val, short=short)
            elif (
                isinstance(val, str)
                and val.strip()
                and any(k in low_key for k in ("time", "date", "modified"))
            ):
                raw_result = TABLE_STYLING.timestamp_formatter(val, short=short)
            else:
                raw_result = TABLE_STYLING.default_formatter(val)

        return (
            raw_result
            if isinstance(raw_result, FormatResult)
            else FormatResult(str(raw_result))
        )

    def _add_column_to_table(key: str, default_header: str) -> None:
        specific_kwargs = dict(column_kwargs.get(key, {}))
        kwargs = {**common_column_kwargs, **specific_kwargs}
        header = kwargs.pop("header", default_header)

        low_key = key.lower()
        is_short_id = "short" in low_key and ("_id" in low_key or " id" in low_key)
        is_exact_id = low_key == "id"

        if is_short_id or is_exact_id:
            for sep in ("_", " "):
                search = f"short{sep}id"
                if low_key == search:
                    header = "ID"
                    kwargs.setdefault("style", "yellow")
                    break
                sep_idx = low_key.find(f"short{sep}")
                id_idx = low_key.find("id")
                if sep_idx >= 0 and id_idx > sep_idx + 6:
                    middle = key[sep_idx + 6 : id_idx].strip("_").strip()
                    if middle:
                        header = f"{middle.title()} ID"
                        kwargs.setdefault("style", "magenta")
                    else:
                        header = "ID"
                        kwargs.setdefault("style", "yellow")
                    break
            else:
                header = "ID"
                kwargs.setdefault("style", "yellow")
            kwargs.setdefault("highlight", False)
            kwargs.setdefault("no_wrap", True)
        elif low_key.endswith("_id") or low_key.endswith(" id"):
            middle = ""
            for suffix in ("_id", " id"):
                if low_key.endswith(suffix):
                    middle = key[: -len(suffix)].strip("_").strip()
                    break
            if middle:
                header = f"{middle.title()} ID"
                kwargs.setdefault("style", "magenta")
            else:
                header = "ID"
                kwargs.setdefault("style", "yellow")
            kwargs.setdefault("highlight", False)
            kwargs.setdefault("no_wrap", True)

        hints = FormatResult("")

        if isinstance(data, dict):
            hints = _get_formatted(key, data.get(key))
        elif data:
            for item in data:
                val = item.get(key)
                if (
                    val is not None
                    and val is not NaT
                    and (not isinstance(val, str) or val.strip() != "")
                ):
                    hints = _get_formatted(key, val)
                    break

        if "style" not in kwargs and key in styling_keys:
            kwargs["style"] = getattr(TABLE_STYLING, key)

        if "highlight" not in kwargs:
            kwargs["highlight"] = not is_short_id

        if "justify" not in kwargs:
            kwargs["justify"] = hints.justify
        if "style" not in kwargs and hints.style:
            kwargs["style"] = hints.style

        table.add_column(header, **kwargs)

    if isinstance(data, dict):
        _add_column_to_table("Key", "Key")
        _add_column_to_table("Value", "Value")
        keys = _sorted_keys([k for k in data if k not in skip])
        for key in keys:
            res = _get_formatted(key, data[key])
            table.add_row(str(key), res.text)
    else:
        if not data:
            console.print(table)
            return
        columns = _sorted_keys([k for k in data[0].keys() if k not in skip])
        for col in columns:
            _add_column_to_table(col, str(col))
        for item in data:
            row_cells = [_get_formatted(col, item.get(col)).text for col in columns]
            table.add_row(*row_cells)

    console.print(table)

mean_and_sem

mean_and_sem(
    data: Sequence[float | None],
) -> tuple[float | None, float | None]

Return the mean and standard error of the mean (SEM) for a list of numeric values, ignoring None values.

Parameters:

Name Type Description Default
data Sequence[float | None]

List of numeric values (float or int) or None.

required

Returns:

Type Description
tuple[float | None, float | None]

A tuple containing the mean and SEM of the input data, both as floats or None if not computable.

Source code in src/aimbat/utils/_maths.py
def mean_and_sem(
    data: Sequence[float | None],
) -> tuple[float | None, float | None]:
    """Compute the mean and the standard error of the mean (SEM) of `data`, skipping ``None`` entries.

    Args:
        data: Numeric values (float or int), possibly interspersed with None.

    Returns:
        ``(mean, sem)`` as floats, where either element is None if it cannot be computed.
    """
    # The work is delegated to `_mean_and_sem_tuple`, which is always handed a
    # tuple copy of the input (presumably so it can be hashed/cached — unverified).
    frozen_values = tuple(data)
    return _mean_and_sem_tuple(frozen_values)

mean_and_sem_timedelta

mean_and_sem_timedelta(
    values: Sequence[Timedelta],
) -> tuple[Timedelta | None, Timedelta | None]

Return (mean, sem) for a list of pd.Timedelta values.

Parameters:

Name Type Description Default
values Sequence[Timedelta]

List of pd.Timedelta values.

required

Returns:

Type Description
tuple[Timedelta | None, Timedelta | None]

(None, None) when empty. SEM is None for fewer than two values.

Source code in src/aimbat/utils/_maths.py
def mean_and_sem_timedelta(
    values: Sequence[pd.Timedelta],
) -> tuple[pd.Timedelta | None, pd.Timedelta | None]:
    """Return (mean, sem) for a list of pd.Timedelta values.

    Args:
        values: List of pd.Timedelta values.

    Returns:
        `(None, None)` when empty. SEM is `None` for fewer than two values.
    """
    if not values:
        return None, None
    ns_vals = [float(td.value) for td in values]
    mean_ns, sem_ns = mean_and_sem(ns_vals)
    return (
        pd.Timedelta(int(mean_ns), unit="ns") if mean_ns is not None else None,
        pd.Timedelta(int(sem_ns), unit="ns") if sem_ns is not None else None,
    )

string_to_uuid

string_to_uuid(
    session: Session,
    id: str,
    aimbat_class: type[AimbatTypes],
    custom_error: str | None = None,
) -> UUID

Determine the full UUID from a string containing its first few characters.

Parameters:

Name Type Description Default
session Session

Database session.

required
id str

Input string to find UUID for.

required
aimbat_class type[AimbatTypes]

Aimbat class to use to find UUID.

required
custom_error str | None

Overrides the default error message.

None

Returns:

Type Description
UUID

The full UUID.

Raises:

Type Description
ValueError

If the UUID could not be determined.

Source code in src/aimbat/utils/_uuid.py
def string_to_uuid(
    session: Session,
    id: str,
    aimbat_class: type[AimbatTypes],
    custom_error: str | None = None,
) -> UUID:
    """Determine a UUID from a string containing the first few characters.

    Dashes in `id` are ignored and the input is lower-cased, so ``"1A2B-3C"``
    and ``"1a2b3c"`` are treated the same. The remaining characters must be
    hexadecimal digits. Anything else cannot be part of a UUID and is reported
    as "not found" without querying the database.

    Args:
        session: Database session.
        id: Input string (the leading characters of a UUID) to find the UUID for.
        aimbat_class: Aimbat class whose ``id`` column is searched.
        custom_error: Overrides the default "not found" error message.

    Returns:
        The full UUID.

    Raises:
        ValueError: If no record matches `id`, or if more than one does.
    """
    # NOTE(review): lower-casing assumes the database's text form of UUIDs is
    # lowercase (true for canonical UUID strings) — confirm for the backend used.
    prefix = id.replace("-", "").lower()
    not_found_message = (
        custom_error or f"Unable to find {aimbat_class.__name__} using id: {id}."
    )

    # Only hex digits occur in a UUID. Checking this up front also stops the
    # SQL LIKE wildcards "%" and "_" in user input from matching arbitrary rows.
    if any(char not in "0123456789abcdef" for char in prefix):
        raise ValueError(not_found_message)

    # Compare the prefix against the dash-free text form of the stored ids.
    statement = select(aimbat_class.id).where(
        func.replace(cast(aimbat_class.id, String), "-", "").like(f"{prefix}%")
    )
    uuid_set = set(session.exec(statement).all())
    if len(uuid_set) == 1:
        resolved = uuid_set.pop()
        logger.debug(f"Resolved {id} to UUID: {resolved}")
        return resolved
    if not uuid_set:
        raise ValueError(not_found_message)
    raise ValueError(f"Found more than one {aimbat_class.__name__} using id: {id}")

uuid_shortener

uuid_shortener(
    session: Session,
    aimbat_obj: T | type[T],
    min_length: int = 2,
    str_uuid: str | None = None,
) -> str

Calculates the shortest unique prefix for a UUID, returning it with dashes.

Parameters:

Name Type Description Default
session Session

An active SQLModel/SQLAlchemy session.

required
aimbat_obj T | type[T]

Either an instance of a SQLModel or the SQLModel class itself.

required
min_length int

The starting character length for the shortened ID.

2
str_uuid str | None

The full UUID string. Required only if aimbat_obj is a class.

None

Returns:

Name Type Description
str str

The shortest unique prefix string, including hyphens where applicable.

Source code in src/aimbat/utils/_uuid.py
def uuid_shortener[T: AimbatTypes](
    session: Session,
    aimbat_obj: T | type[T],
    min_length: int = 2,
    str_uuid: str | None = None,
) -> str:
    """Calculates the shortest unique prefix for a UUID, returning with dashes.

    Args:
        session: An active SQLModel/SQLAlchemy session.
        aimbat_obj: Either an instance of a SQLModel or the SQLModel class itself.
        min_length: The starting character length for the shortened ID.
        str_uuid: The full UUID string. Required only if `aimbat_obj` is a class.

    Returns:
        str: The shortest unique prefix string, including hyphens where applicable.
    """

    if isinstance(aimbat_obj, type):
        model_class = aimbat_obj
        if str_uuid is None:
            raise ValueError("str_uuid must be provided when aimbat_obj is a class.")
        target_full = str(UUID(str_uuid))
    else:
        model_class = type(aimbat_obj)
        target_full = str(aimbat_obj.id)

    prefix_clean = target_full.replace("-", "")[:min_length]

    # select with a WHERE clause that removes dashes and compares the cleaned prefix
    statement = select(model_class.id).where(
        func.replace(cast(model_class.id, String), "-", "").like(f"{prefix_clean}%")
    )

    # Store results as standard hyphenated strings
    results = session.exec(statement).all()
    relevant_pool = [str(uid) for uid in results]

    if target_full not in relevant_pool:
        raise ValueError(f"ID {target_full} not found in table {model_class.__name__}")

    current_length = min_length
    while current_length < len(target_full):
        candidate = target_full[:current_length]
        if candidate.endswith("-"):
            current_length += 1
            candidate = target_full[:current_length]

        matches = [u for u in relevant_pool if u.startswith(candidate)]
        if len(matches) == 1:
            logger.debug(f"Shortened {target_full} to: {candidate}")
            return candidate
        current_length += 1

    return target_full