# Source code for scripts.extraction.cleanup_propagation

"""Cleanup propagation — prune dictionary + PDF artifacts after dataset drops.

Runs against the staging workspace (``tmp/{STUDY_NAME}/{datasets,dictionary,pdfs}/``)
after :func:`scripts.extraction.dataset_cleanup.clean_trio_datasets` completes.

Dictionary and PDF legs carry no PHI and therefore emit no audit report — the
prune step is side-effect-only, keeping the dict and PDF schemas aligned with
the surviving dataset schema so the LLM sees no dangling references. The
dataset leg's own audit (``AUDIT_DATASET_REPORT_PATH``) remains the single
source of truth for what was removed.

Pruning rule
------------

A variable ``V`` is pruned from the dictionary and PDF legs **iff** it was
dropped from at least one dataset *and* never survives in any final
surviving dataset JSONL schema. Variables dropped from one dataset but kept
in another are NOT pruned.

Comparisons are case-folded; dataset provenance fields
(``source_file``, ``_provenance``, ``_metadata``) are excluded from
the surviving-vars set.
"""

from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Any

import config
from scripts.extraction._dict_keys import DICT_VAR_KEY as _DICT_VAR_KEY
from scripts.extraction.io import (
    JSONLParseError,
    atomic_write_json,
    atomic_write_jsonl,
    load_json_object_line,
)
from scripts.security.secure_env import assert_write_zone

logger = logging.getLogger(__name__)

# Public API of this module — mirrors the step order of run_propagation().
__all__ = [
    "PROVENANCE_FIELDS",
    "compute_propagation_set",
    "prune_dictionary",
    "prune_pdfs",
    "run_propagation",
]


# ── Constants ───────────────────────────────────────────────────────────────

# Dataset-row metadata keys that are NOT "variables" — they should be excluded
# from the "surviving dataset vars" set so propagation doesn't treat them as
# schema members.
PROVENANCE_FIELDS: frozenset[str] = frozenset({"source_file", "_provenance", "_metadata"})


# ── Step 1: compute_propagation_set ─────────────────────────────────────────


def compute_propagation_set(
    audit_path: Path,
    datasets_dir: Path,
) -> set[str]:
    """Return the case-folded set of variables that should propagate-prune.

    Algorithm:

    1. Load ``audit_path`` (the dataset leg's unified audit). Union all
       ``scope == "dataset-column"`` events' ``name`` into the dropped set
       (case-folded).
    2. Scan every ``datasets_dir/*.jsonl``. Union all row keys (excluding
       :data:`PROVENANCE_FIELDS`) into the surviving set (case-folded).
    3. Return ``dropped - surviving``.

    Variables dropped from one dataset but kept in another → excluded from
    the returned set (they "survive" somewhere). Missing audit or empty
    datasets dir → empty set.
    """
    if not audit_path.exists():
        # A missing audit alongside populated datasets suggests a sequencing
        # mistake by the caller, so escalate the log level in that case.
        datasets_present = datasets_dir.is_dir() and any(datasets_dir.glob("*.jsonl"))
        logger.log(
            logging.WARNING if datasets_present else logging.INFO,
            "Dataset audit not found at %s — propagation set is empty "
            "(caller may have invoked run_propagation before clean_trio_datasets)",
            audit_path,
        )
        return set()

    payload = json.loads(audit_path.read_text(encoding="utf-8"))

    # Step 1: every dataset-column removal recorded in the audit, case-folded.
    dropped = {
        entry.get("name").casefold()
        for entry in payload.get("removed", [])
        if entry.get("scope") == "dataset-column"
        and isinstance(entry.get("name"), str)
        and entry.get("name")
    }
    if not dropped:
        return set()

    # Step 2: every key still present in any surviving dataset row, minus
    # provenance/metadata fields that are not schema members.
    surviving: set[str] = set()
    if datasets_dir.is_dir():
        for dataset_file in sorted(datasets_dir.glob("*.jsonl")):
            with dataset_file.open("r", encoding="utf-8") as handle:
                for line_number, line in enumerate(handle, start=1):
                    stripped = line.strip()
                    if not stripped:
                        continue
                    try:
                        record = load_json_object_line(
                            stripped, source_path=dataset_file, line_number=line_number
                        )
                    except JSONLParseError:
                        logger.debug(
                            "Skipping malformed line %d in %s", line_number, dataset_file.name
                        )
                        continue
                    surviving.update(
                        field.casefold()
                        for field in record
                        if isinstance(field, str) and field not in PROVENANCE_FIELDS
                    )

    # Step 3: prune only what was dropped everywhere.
    return dropped - surviving
# ── Step 2: prune_dictionary ────────────────────────────────────────────────
def prune_dictionary(drop_set: set[str], dict_dir: Path) -> int:
    """Walk ``dict_dir/**/*.jsonl`` and drop rows in ``drop_set``.

    Each row's :data:`_DICT_VAR_KEY` value is compared case-folded against
    ``drop_set`` (which callers pass pre-folded — see
    :func:`compute_propagation_set`). Matching rows are removed and the file
    is rewritten atomically.

    Returns the total number of rows removed across all files (for logging).
    No audit artifact is written — the dictionary leg carries no PHI.
    """
    total_removed = 0
    if not dict_dir.is_dir():
        return 0
    assert_write_zone(dict_dir)

    for dict_file in sorted(dict_dir.rglob("*.jsonl")):
        survivors: list[dict[str, Any]] = []
        dropped_here = 0
        with dict_file.open("r", encoding="utf-8") as handle:
            for line_number, line in enumerate(handle, start=1):
                stripped = line.strip()
                if not stripped:
                    continue
                try:
                    record = load_json_object_line(
                        stripped, source_path=dict_file, line_number=line_number
                    )
                except JSONLParseError:
                    # NOTE(review): malformed lines are skipped here, so a file
                    # that gets rewritten loses them silently — confirm intended.
                    logger.debug(
                        "Skipping malformed line %d in %s", line_number, dict_file.name
                    )
                    continue
                var_name = record.get(_DICT_VAR_KEY)
                if isinstance(var_name, str) and var_name.casefold() in drop_set:
                    dropped_here += 1
                else:
                    survivors.append(record)

        # Only touch the file when something was actually removed.
        if dropped_here:
            atomic_write_jsonl(dict_file, survivors)
            total_removed += dropped_here
            logger.info("Pruned %d rows from %s", dropped_here, dict_file.name)

    return total_removed
# ── Step 3: prune_pdfs ──────────────────────────────────────────────────────
def prune_pdfs(drop_set: set[str], pdf_dir: Path) -> int:
    """Walk ``pdf_dir/*_variables.json`` and drop matching entries.

    For each JSON file:

    1. Remove keys from the top-level ``variables: dict`` whose key
       (case-folded) is in ``drop_set``.
    2. For each section in ``sections: dict``, remove matching entries from
       ``sections[name]["variables"]: list``.

    Modified files are rewritten atomically; unmodified files are left alone.
    Returns the total number of entries removed (vars + section refs).
    No audit artifact is written — the PDF leg carries no PHI.
    """
    total_removed = 0
    if not pdf_dir.is_dir():
        return 0
    assert_write_zone(pdf_dir)

    for meta_path in sorted(pdf_dir.glob("*_variables.json")):
        try:
            payload = json.loads(meta_path.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError) as exc:
            logger.warning("Skipping malformed PDF JSON %s: %s", meta_path.name, exc)
            continue
        if not isinstance(payload, dict):
            logger.warning("Unexpected top-level type in %s — skipping", meta_path.name)
            continue

        removed_here = 0

        # Leg 1: top-level variables dict — collect matches first, then delete.
        variables = payload.get("variables")
        if isinstance(variables, dict):
            doomed = [
                key
                for key in variables
                if isinstance(key, str) and key.casefold() in drop_set
            ]
            for key in doomed:
                del variables[key]
            removed_here += len(doomed)

        # Leg 2: per-section variable reference lists.
        sections = payload.get("sections")
        if isinstance(sections, dict):
            for section_body in sections.values():
                if not isinstance(section_body, dict):
                    continue
                refs = section_body.get("variables")
                if not isinstance(refs, list):
                    continue
                surviving_refs = [
                    ref
                    for ref in refs
                    if not (isinstance(ref, str) and ref.casefold() in drop_set)
                ]
                removed_here += len(refs) - len(surviving_refs)
                section_body["variables"] = surviving_refs

        # Rewrite only when at least one entry was pruned.
        if removed_here:
            atomic_write_json(meta_path, payload)
            total_removed += removed_here
            logger.info("Pruned %d entries in %s", removed_here, meta_path.name)

    return total_removed
# ── Step 4: run_propagation ─────────────────────────────────────────────────
def run_propagation() -> None:
    """Orchestrate the propagation: compute drop set, prune dict + PDF legs.

    All paths resolved from ``config.STAGING_*`` and ``config.AUDIT_*`` —
    never touches the promoted trio bundle directly. Dict + PDF legs emit no
    audit report (no PHI); only their prune counts are logged.
    """
    # Step 1: derive the set of variables dropped everywhere from the dataset
    # leg's audit + the surviving dataset schemas.
    drop_set = compute_propagation_set(
        config.AUDIT_DATASET_REPORT_PATH,
        config.STAGING_DATASETS_DIR,
    )
    logger.info("Propagation drop-set (%d vars): %s", len(drop_set), sorted(drop_set))

    # Steps 2-3: prune both non-PHI legs; counts are logged, not audited.
    dict_count = prune_dictionary(drop_set, config.STAGING_DICTIONARY_DIR)
    pdf_count = prune_pdfs(drop_set, config.STAGING_PDFS_DIR)
    logger.info(
        "Propagation complete: %d dictionary rows + %d PDF entries pruned",
        dict_count,
        pdf_count,
    )