Source code for scripts.utils.step_cache

"""Pipeline step caching for fast incremental re-runs.

This module provides file-based caching so that each pipeline step can be
skipped when its outputs already exist **and** its inputs have not changed
since the last successful run.

How it works:

1. Before running a step, ``is_step_fresh()`` checks for a manifest file
   (``.<step_name>.manifest.json``) stored inside the step's output
   directory.  The manifest records:
   - SHA-256 content hashes of every input file
   - Artifact version strings that were current when the step ran
   - A timestamp for human convenience

2. If the manifest exists and every recorded input hash still matches the
   file on disk (and artifact versions haven't changed), the step is
   considered *fresh* and can be skipped.

3. After a step completes successfully, ``save_step_manifest()`` writes a
   new manifest capturing the current state of its inputs.

This gives deterministic, content-based cache invalidation with no need for
external databases or lock files; a usage sketch at the end of this docstring
shows the typical caller-side flow.

Design rules:
- Pure-function hashing: only file contents matter, not timestamps.
- Manifests are hidden dotfiles so they don't pollute ``ls`` output.
- ``--force`` in the CLI always bypasses the cache.
- Missing output directories always mean "not fresh".
- If the manifest itself is corrupt or missing, the step runs.
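
Typical caller-side flow (``input_dir``, ``output_dir``, ``force``, and
``run_step`` below are illustrative caller-supplied names, not part of this
module)::

    input_hashes = hash_directory(input_dir, extensions=frozenset({".csv"}))
    if force or not is_step_fresh("dictionary", output_dir, input_hashes):
        run_step()  # the step's actual work
        save_step_manifest("dictionary", output_dir, input_hashes)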
"""

from __future__ import annotations

import contextlib
import json
import logging
import os
import tempfile
from datetime import UTC, datetime
from pathlib import Path
from typing import Any

from .integrity import hash_file

__all__ = [
    "MANIFEST_VERSION",
    "hash_directory",
    "hash_file",
    "is_step_fresh",
    "save_step_manifest",
]

logger = logging.getLogger(__name__)

MANIFEST_VERSION = "1.0.0"
"""Schema version of the manifest file itself."""

_BUF_SIZE = 1 << 16  # 64 KiB read buffer for hashing


# ---------------------------------------------------------------------------
# Hashing helpers
# ---------------------------------------------------------------------------


def hash_directory(
    directory: Path,
    *,
    extensions: frozenset[str] | None = None,
) -> dict[str, str]:
    """Compute per-file SHA-256 hashes for every relevant file in *directory*.

    Files are discovered recursively. Hidden files, ``__pycache__`` dirs, and
    ``.pyc`` files are excluded. If *extensions* is provided, only files whose
    suffix is in the set are included (e.g. ``{".xlsx", ".csv"}``).

    The returned dict maps ``relative_path → hex_sha256``. Keys are sorted so
    that the overall dict is deterministic regardless of filesystem walk
    order.

    Args:
        directory: Root directory to hash.
        extensions: Optional allowlist of file suffixes to include.

    Returns:
        Sorted dict of ``{relative_path: sha256_hex}``.
    """
    if not directory.is_dir():
        return {}

    hashes: dict[str, str] = {}
    for fpath in directory.rglob("*"):
        if not fpath.is_file():
            continue
        # Check path parts relative to *directory* so hidden directories and
        # __pycache__ anywhere in the tree are skipped.
        rel = str(fpath.relative_to(directory))
        parts = Path(rel).parts
        if any(p.startswith(".") or p == "__pycache__" for p in parts):
            continue
        if fpath.name.startswith(".") or fpath.name.endswith(".pyc"):
            continue
        if extensions and fpath.suffix not in extensions:
            continue
        hashes[rel] = hash_file(fpath)

    return dict(sorted(hashes.items()))


# ---------------------------------------------------------------------------
# Manifest I/O
# ---------------------------------------------------------------------------


def _manifest_path(output_dir: Path, step_name: str) -> Path:
    """Return the hidden manifest file path for a given step."""
    return output_dir / f".{step_name}.manifest.json"


def save_step_manifest(
    step_name: str,
    output_dir: Path,
    input_hashes: dict[str, str],
    *,
    artifact_versions: dict[str, str] | None = None,
    extra_metadata: dict[str, Any] | None = None,
) -> Path:
    """Persist a cache manifest after a successful step run.

    Args:
        step_name: Short identifier for the pipeline step
            (e.g. ``"dictionary"``).
        output_dir: Directory where the step wrote its outputs.
        input_hashes: ``{relative_path: sha256}`` of every input file.
        artifact_versions: Optional artifact version strings to record.
        extra_metadata: Optional extra data to store (e.g. counts, flags).

    Returns:
        Path to the written manifest file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    mpath = _manifest_path(output_dir, step_name)

    payload: dict[str, Any] = {
        "manifest_version": MANIFEST_VERSION,
        "step_name": step_name,
        "completed_at": datetime.now(UTC).isoformat(),
        "input_hashes": input_hashes,
    }
    if artifact_versions:
        payload["artifact_versions"] = artifact_versions
    if extra_metadata:
        payload["extra"] = extra_metadata

    # Atomic write to avoid partial manifest files on crash/interruption.
    with tempfile.NamedTemporaryFile(
        "w", dir=mpath.parent, delete=False, encoding="utf-8"
    ) as tmp:
        json.dump(payload, tmp, indent=2, ensure_ascii=False, sort_keys=True)
        tmp.write("\n")
        tmp.flush()
        os.fsync(tmp.fileno())
        tmp_name = tmp.name
    try:
        os.replace(tmp_name, mpath)
    except Exception:
        with contextlib.suppress(OSError):
            os.remove(tmp_name)
        raise

    logger.debug("Step cache manifest saved: %s", mpath)
    return mpath


def is_step_fresh(
    step_name: str,
    output_dir: Path,
    current_input_hashes: dict[str, str],
    *,
    artifact_versions: dict[str, str] | None = None,
    required_outputs: list[str] | None = None,
) -> bool:
    """Check whether a pipeline step can be skipped.

    A step is *fresh* when ALL of the following hold:

    1. The output directory exists.
    2. A valid manifest file exists inside it.
    3. Every input file hash in the manifest matches the current hash.
    4. No new input files have appeared that weren't in the manifest.
    5. If *artifact_versions* is provided, every recorded version matches.
    6. If *required_outputs* is provided, each named file exists under
       *output_dir*.

    Args:
        step_name: Pipeline step identifier.
        output_dir: Directory where the step writes outputs.
        current_input_hashes: Live hashes of current input files.
        artifact_versions: If provided, must match recorded versions.
        required_outputs: Optional list of filenames that must exist under
            *output_dir* for the step to be considered complete.

    Returns:
        ``True`` if the step is fresh and can be safely skipped.
    """
    mpath = _manifest_path(output_dir, step_name)

    # ── 1. Output dir and manifest must exist ──
    if not output_dir.is_dir():
        logger.debug("Cache miss [%s]: output dir does not exist", step_name)
        return False
    if not mpath.is_file():
        logger.debug("Cache miss [%s]: manifest not found", step_name)
        return False

    # ── 2. Parse manifest ──
    try:
        with open(mpath, encoding="utf-8") as fh:
            manifest = json.load(fh)
    except (json.JSONDecodeError, OSError) as exc:
        logger.debug("Cache miss [%s]: manifest corrupt: %s", step_name, exc)
        return False
    if not isinstance(manifest, dict):
        logger.debug("Cache miss [%s]: manifest is not a JSON object", step_name)
        return False

    recorded_hashes = manifest.get("input_hashes", {})
    if not isinstance(recorded_hashes, dict):
        logger.debug("Cache miss [%s]: input_hashes not a dict", step_name)
        return False

    # ── 3. Compare input hashes (both directions) ──
    if set(recorded_hashes.keys()) != set(current_input_hashes.keys()):
        added = set(current_input_hashes.keys()) - set(recorded_hashes.keys())
        removed = set(recorded_hashes.keys()) - set(current_input_hashes.keys())
        logger.debug(
            "Cache miss [%s]: input file set changed (added=%s, removed=%s)",
            step_name,
            added or "∅",
            removed or "∅",
        )
        return False
    for rel_path, current_hash in current_input_hashes.items():
        if recorded_hashes.get(rel_path) != current_hash:
            logger.debug("Cache miss [%s]: hash mismatch for %s", step_name, rel_path)
            return False

    # ── 4. Artifact version check ──
    if artifact_versions:
        recorded_versions = manifest.get("artifact_versions", {})
        for key, expected in artifact_versions.items():
            if recorded_versions.get(key) != expected:
                logger.debug(
                    "Cache miss [%s]: artifact version changed: %s (%s -> %s)",
                    step_name,
                    key,
                    recorded_versions.get(key),
                    expected,
                )
                return False

    # ── 5. Required output files ──
    if required_outputs:
        for name in required_outputs:
            if not (output_dir / name).exists():
                logger.debug(
                    "Cache miss [%s]: required output missing: %s", step_name, name
                )
                return False

    logger.debug("Cache hit [%s]: step is fresh, skipping", step_name)
    return True
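

# ---------------------------------------------------------------------------
# Self-contained demo
# ---------------------------------------------------------------------------
# A minimal round-trip sketch of the cache API, runnable directly. The step
# name, file names, and artifact version below are illustrative only; they
# are not assumptions about any real pipeline.

if __name__ == "__main__":
    import shutil

    logging.basicConfig(level=logging.DEBUG, format="%(message)s")

    work = Path(tempfile.mkdtemp(prefix="step_cache_demo_"))
    try:
        inputs = work / "inputs"
        outputs = work / "outputs"
        inputs.mkdir()
        (inputs / "data.csv").write_text("a,b\n1,2\n", encoding="utf-8")

        hashes = hash_directory(inputs, extensions=frozenset({".csv"}))

        # First run: no output dir or manifest yet, so the step is not fresh.
        assert not is_step_fresh("demo", outputs, hashes)

        # Pretend the step ran and produced an output, then save the manifest.
        outputs.mkdir(parents=True, exist_ok=True)
        (outputs / "result.json").write_text("{}\n", encoding="utf-8")
        save_step_manifest("demo", outputs, hashes, artifact_versions={"model": "1.0"})

        # Unchanged inputs, matching version, output present: fresh, skippable.
        assert is_step_fresh(
            "demo",
            outputs,
            hashes,
            artifact_versions={"model": "1.0"},
            required_outputs=["result.json"],
        )

        # Changing an input's *contents* invalidates the cache.
        (inputs / "data.csv").write_text("a,b\n1,3\n", encoding="utf-8")
        new_hashes = hash_directory(inputs, extensions=frozenset({".csv"}))
        assert not is_step_fresh("demo", outputs, new_hashes)
    finally:
        shutil.rmtree(work, ignore_errors=True)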