"""Study Knowledge Base — YAML-driven ground truth for variable mappings."""
from __future__ import annotations

import copy
from pathlib import Path
from typing import Any

import yaml

from scripts.ai_assistant.file_access import validate_agent_read
_DEFAULT_YAML = Path(__file__).resolve().parents[2] / "config" / "study_knowledge.yaml"
class StudyKnowledge:
    """Provides deterministic lookups for variable mappings, value encodings,
    dataset relationships, and outcome definitions from study_knowledge.yaml.

    The YAML is read once in ``__init__``. Every public accessor returns a
    deep copy, so callers may mutate results without corrupting the loaded
    ground truth shared by later lookups.
    """

    # Cohort sections that hold concept definitions, in lookup-priority order.
    _CONCEPT_SECTIONS = ("demographics", "predictors")

    def __init__(self, yaml_path: Path | str | None = None) -> None:
        """Load the study knowledge YAML.

        Args:
            yaml_path: Path to the YAML file (``str`` or ``Path``). Defaults to
                the module-level ``_DEFAULT_YAML``.

        Raises:
            FileNotFoundError: If the path is not an existing file.
            ValueError: If the document's top level is not a mapping.
        """
        path = Path(yaml_path) if yaml_path is not None else _DEFAULT_YAML
        if not path.is_file():
            raise FileNotFoundError(f"Study knowledge YAML not found: {path}")
        validated = validate_agent_read(path)
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) would mis-decode non-ASCII text in the YAML.
        with validated.open(encoding="utf-8") as fh:
            loaded = yaml.safe_load(fh)
        # An empty file parses to None; treat it as an empty document.
        if loaded is None:
            loaded = {}
        if not isinstance(loaded, dict):
            raise ValueError(
                f"Study knowledge YAML must contain a mapping at the top level: {path}"
            )
        self._data: dict[str, Any] = loaded
        # ``key:`` with no value parses to None, so normalise sections to dicts.
        self._cohorts: dict[str, Any] = loaded.get("cohorts") or {}
        self._datasets: dict[str, Any] = loaded.get("dataset_relationships") or {}
        self._study: dict[str, Any] = loaded.get("study") or {}

    # ── internal helpers ────────────────────────────────────────────

    @staticmethod
    def _section(container: dict[str, Any], key: str) -> dict[str, Any]:
        """Return ``container[key]``, treating a missing or null section as empty."""
        return container.get(key) or {}

    def _cohort(self, cohort_id: str) -> dict[str, Any]:
        """Return the internal (uncopied) cohort mapping.

        Raises:
            ValueError: If ``cohort_id`` is not defined.
        """
        if cohort_id not in self._cohorts:
            raise ValueError(f"Unknown cohort '{cohort_id}'. Available: {self.list_cohorts()}")
        return self._cohorts[cohort_id] or {}

    def _find_concept(
        self, cohort: dict[str, Any], concept: str
    ) -> tuple[str, Any] | None:
        """Return ``(section_key, definition)`` for ``concept``, or None if absent.

        Demographics are searched before predictors.
        """
        for section_key in self._CONCEPT_SECTIONS:
            section = self._section(cohort, section_key)
            if concept in section:
                return section_key, section[concept]
        return None

    @staticmethod
    def _add_requirement(needed: dict[str, set[str]], info: dict[str, Any]) -> None:
        """Record ``info``'s dataset (and its column, if any) in ``needed``."""
        dataset = info.get("dataset", "")
        if not dataset:
            return
        columns = needed.setdefault(dataset, set())
        column = info.get("column", "")
        if column:
            columns.add(column)

    # ── public API ──────────────────────────────────────────────────

    @property
    def study_name(self) -> str:
        """The ``study.name`` value as text, or ``""`` when absent or null."""
        name = self._study.get("name")
        return "" if name is None else str(name)

    @property
    def study_description(self) -> str:
        """The ``study.description`` value as text, or ``""`` when absent or null."""
        description = self._study.get("description")
        return "" if description is None else str(description)

    def list_cohorts(self) -> list[str]:
        """Return the cohort identifiers in the order they appear in the YAML."""
        return list(self._cohorts.keys())

    def get_cohort(self, cohort_id: str) -> dict[str, Any]:
        """Return a deep copy of the full definition of ``cohort_id``.

        Raises:
            ValueError: If ``cohort_id`` is not defined.
        """
        return dict(copy.deepcopy(self._cohort(cohort_id)))

    def list_concepts(self, cohort_id: str) -> list[str]:
        """Return demographic then predictor concept names for ``cohort_id``.

        Derived variables are not included.

        Raises:
            ValueError: If ``cohort_id`` is not defined.
        """
        cohort = self._cohort(cohort_id)
        return [
            name
            for section_key in self._CONCEPT_SECTIONS
            for name in self._section(cohort, section_key)
        ]

    def resolve_concept(self, concept: str, cohort_id: str) -> dict[str, Any]:
        """Return a copy of ``concept``'s definition plus the ``section`` it came from.

        Raises:
            ValueError: If ``cohort_id`` is not defined.
            KeyError: If ``concept`` is in neither demographics nor predictors.
        """
        cohort = self._cohort(cohort_id)
        found = self._find_concept(cohort, concept)
        if found is None:
            raise KeyError(
                f"Unknown concept '{concept}' in cohort '{cohort_id}'. "
                f"Available: {self.list_concepts(cohort_id)}"
            )
        section_key, definition = found
        result = dict(copy.deepcopy(definition))
        result["section"] = section_key
        return result

    def get_outcome(self, cohort_id: str, outcome_name: str) -> dict[str, Any]:
        """Return a copy of the outcome definition ``outcome_name``.

        Raises:
            ValueError: If ``cohort_id`` is not defined.
            KeyError: If the cohort has no such outcome.
        """
        outcomes = self._section(self._cohort(cohort_id), "outcomes")
        if outcome_name not in outcomes:
            raise KeyError(
                f"Unknown outcome '{outcome_name}' in cohort '{cohort_id}'. "
                f"Available: {list(outcomes.keys())}"
            )
        return dict(copy.deepcopy(outcomes[outcome_name]))

    def get_value_encoding(self, column: str, cohort_id: str) -> dict[str, Any]:
        """Return encoding metadata for the concept whose ``column`` matches.

        The result always has ``column`` and ``type``. It also has copies of
        ``encoding``, ``binary_map`` and ``valid_range`` when the concept
        defines them. The first match wins, with demographics before predictors.

        Raises:
            ValueError: If ``cohort_id`` is not defined.
            KeyError: If no concept in the cohort maps to ``column``.
        """
        cohort = self._cohort(cohort_id)
        for section_key in self._CONCEPT_SECTIONS:
            for info in self._section(cohort, section_key).values():
                if info.get("column") != column:
                    continue
                result: dict[str, Any] = {"column": column, "type": info.get("type")}
                for field in ("encoding", "binary_map", "valid_range"):
                    if field in info:
                        result[field] = copy.deepcopy(info[field])
                return result
        raise KeyError(f"Column '{column}' not found in cohort '{cohort_id}'")

    def get_join_plan(self, cohort_id: str, concepts: list[str]) -> list[dict[str, Any]]:
        """Determine which datasets and columns must be joined to obtain ``concepts``.

        The cohort's ``screening_dataset`` (if set) is always included first.
        Each name is resolved as a demographic/predictor concept. Failing
        that, it is resolved as a derived variable, whose ``sources`` are
        resolved instead. Names, or derived sources, that resolve to neither
        are skipped, so the plan is best-effort.

        Returns:
            One dict per dataset, in first-needed order. Each dict has the keys
            ``dataset``; ``columns`` (the sorted requested columns, or a copy of
            the dataset's ``key_columns`` when none were requested);
            ``join_key`` (default ``"SUBJID"``); ``form``; and ``description``.

        Raises:
            ValueError: If ``cohort_id`` is not defined.
        """
        cohort = self._cohort(cohort_id)
        derived = self._section(cohort, "derived_variables")
        needed: dict[str, set[str]] = {}

        # Always need screening dataset for demographics
        screening = cohort.get("screening_dataset", "")
        if screening:
            needed[screening] = set()

        for concept in concepts:
            found = self._find_concept(cohort, concept)
            if found is not None:
                self._add_requirement(needed, found[1])
            elif concept in derived:
                # Derived variables pull in the datasets of their source concepts.
                sources = (derived[concept] or {}).get("sources") or []
                for source in sources:
                    source_found = self._find_concept(cohort, source)
                    if source_found is not None:
                        self._add_requirement(needed, source_found[1])

        join_key = self._datasets.get("join_key") or "SUBJID"
        catalog = self._section(self._datasets, "datasets")
        plan: list[dict[str, Any]] = []
        for dataset, columns in needed.items():
            ds_info = catalog.get(dataset) or {}
            plan.append(
                {
                    "dataset": dataset,
                    # Copy key_columns so callers can't mutate the loaded YAML.
                    "columns": sorted(columns)
                    if columns
                    else list(ds_info.get("key_columns") or []),
                    "join_key": join_key,
                    "form": ds_info.get("form", ""),
                    "description": ds_info.get("description", ""),
                }
            )
        return plan

    def get_derived_variable(self, name: str, cohort_id: str) -> dict[str, Any]:
        """Return a copy of the derived-variable definition ``name``.

        Raises:
            ValueError: If ``cohort_id`` is not defined.
            KeyError: If the cohort has no such derived variable.
        """
        derived = self._section(self._cohort(cohort_id), "derived_variables")
        if name not in derived:
            raise KeyError(
                f"Unknown derived variable '{name}' in cohort '{cohort_id}'. "
                f"Available: {list(derived.keys())}"
            )
        return dict(copy.deepcopy(derived[name]))

    def get_default_outcome(self, cohort_id: str) -> tuple[str, dict[str, Any]]:
        """Return ``(name, definition)`` of the cohort's first listed outcome.

        Raises:
            ValueError: If ``cohort_id`` is not defined or has no outcomes.
        """
        outcomes = self._section(self._cohort(cohort_id), "outcomes")
        if not outcomes:
            raise ValueError(f"No outcomes defined for cohort '{cohort_id}'")
        # Dicts keep insertion order, i.e. the order the outcomes were loaded.
        name = next(iter(outcomes))
        return name, dict(copy.deepcopy(outcomes[name]))