import h5py
import numpy as np
import xml.etree.ElementTree as ET
from pathlib import Path

# ---------------------------------------------------------------
# Parameters - edit these to match your simulation
# ---------------------------------------------------------------
rho = 1.225   # kg/m^3 - fluid density
U_H = 1.6   # m/s    - mean velocity at the building height of interest

# File paths - adjust to the location of your downloaded files
base = Path(".")
body_h5   = base / "bodies.body_Cp body.h5"
ref_h5    = base / "points.point_Cp ref.h5"
body_xdmf = base / "bodies.body_Cp body.xdmf"
# ---------------------------------------------------------------

# Dynamic pressure used to normalise the pressure: Cp = (p - p_inf) / q
q = 0.5 * rho * U_H**2
# Reject a non-positive (or NaN) q before any file is touched: the NumPy array
# division in the Cp step would not raise, it would silently write inf/nan or
# sign-flipped coefficients.
if not q > 0:
    raise ValueError(
        f"Dynamic pressure must be positive, got q={q} (rho={rho}, U_H={U_H}); "
        "check the parameters above."
    )
print(f"Dynamic pressure q = {q:.4f} Pa  (rho={rho}, U_H={U_H})")

# Step 1 - load reference pressure time series
# Timestep keys have the form "t<time>" and are ordered by their numeric time.
# For each key, the first value stored in the reference dataset becomes that
# timestep's p_inf.
print("Reading reference pressure ...")
with h5py.File(ref_h5, "r") as ref_file:
    ref_group = ref_file["pressure"]
    ref_keys = sorted(ref_group.keys(), key=lambda name: float(name[1:]))
    p_inf = {}
    for name in ref_keys:
        p_inf[name] = float(ref_group[name][0])

# Step 2 - compute Cp and write to body h5
# For each timestep key: Cp = (p - p_inf) / q, stored as float32 under
# /cp/<key> in the body file. The rho / U_H / q used are recorded as
# attributes of the /cp group.
print("Computing Cp and writing to body h5 ...")
with h5py.File(body_h5, "a") as f:
    body_keys = sorted(f["pressure"].keys(), key=lambda k: float(k[1:]))

    if set(body_keys) != set(ref_keys):
        # Name the offending keys so the mismatch can be tracked down quickly.
        only_body = sorted(set(body_keys) - set(ref_keys))
        only_ref = sorted(set(ref_keys) - set(body_keys))
        raise ValueError(
            "Body and reference files have different timestep keys. "
            "Ensure both exports were configured with the same time range and interval. "
            f"{len(only_body)} key(s) only in body (e.g. {only_body[:3]}), "
            f"{len(only_ref)} only in reference (e.g. {only_ref[:3]})."
        )

    # Reuse /cp across runs instead of deleting it. HDF5 does not reclaim the
    # space of deleted objects, so deleting and recreating on every rerun makes
    # the file grow until it is repacked (h5repack).
    if "cp" in f and not isinstance(f["cp"], h5py.Group):
        del f["cp"]
    cp_grp = f.require_group("cp")
    for stale in set(cp_grp.keys()) - set(body_keys):
        del cp_grp[stale]   # left over from a run with different timesteps
    cp_grp.attrs["rho"] = rho
    cp_grp.attrs["U_H"] = U_H
    cp_grp.attrs["q"] = q

    n = len(body_keys)
    for i, key in enumerate(body_keys):
        p = f[f"pressure/{key}"][:]
        cp = ((p - p_inf[key]) / q).astype(np.float32)
        existing = cp_grp.get(key)
        if (
            isinstance(existing, h5py.Dataset)
            and existing.shape == cp.shape
            and existing.dtype == cp.dtype
        ):
            existing[...] = cp   # overwrite in place: no new file space used
        else:
            if existing is not None:
                del cp_grp[key]
            cp_grp.create_dataset(key, data=cp)
        if (i + 1) % 100 == 0 or (i + 1) == n:
            print(f"  {i + 1}/{n} timesteps", end="\r")

print(f"\nCp written for {n} timesteps -> {body_h5}")

# Step 3 - patch xdmf to expose the Cp field
print("Patching xdmf ...")
# Register the XInclude prefix commonly used in XDMF files. Otherwise
# ElementTree writes such elements back with an auto-generated "ns0:" prefix.
ET.register_namespace("xi", "http://www.w3.org/2001/XInclude")
tree = ET.parse(body_xdmf)
root = tree.getroot()

domain     = root.find("Domain")
collection = domain.find("Grid") if domain is not None else None   # temporal collection
if collection is None:
    raise ValueError(f"No Domain/Grid element found in {body_xdmf}")
# Normally a temporal collection of per-timestep grids; if the file holds a
# single grid with no child grids, patch that grid itself.
grids      = collection.findall("Grid") or [collection]
h5_name    = body_h5.name

n_patched = 0
for grid in grids:
    # Skip if Cp attribute already present (makes re-running the script safe)
    if grid.find("Attribute[@Name='Cp']") is not None:
        continue

    # Cp is derived from pressure, so only grids that carry pressure get it
    pressure_attr = grid.find("Attribute[@Name='pressure']")
    if pressure_attr is None:
        continue
    pressure_item = pressure_attr.find("DataItem")
    if pressure_item is None or not pressure_item.text:
        continue

    # Extract the timestep key from the pressure dataset path,
    # e.g. "<file>.h5:/pressure/t1999.783813" -> "t1999.783813".
    # rpartition: a "/pressure/" inside the file path cannot break the split.
    _, sep, time_key = pressure_item.text.strip().rpartition("/pressure/")
    if not sep:
        continue

    # Cp is computed element-wise from pressure, so it has pressure's shape and
    # lives on the same mesh entities: mirror pressure's Dimensions and Center
    # instead of assuming one value per cell.
    center = pressure_attr.get("Center")
    dims = pressure_item.get("Dimensions")
    if dims is None:
        # Fallback: one value per topology element, cell-centred.
        topology = grid.find("Topology")
        if topology is None or "NumberOfElements" not in topology.attrib:
            print(f"  Warning: cannot determine Cp size for {time_key}; skipped")
            continue
        dims = topology.attrib["NumberOfElements"]
        center = center or "Cell"

    cp_attr = ET.SubElement(grid, "Attribute",
        Name="Cp", AttributeType="Scalar")
    if center is not None:
        cp_attr.set("Center", center)
    # Precision="4" because the Cp datasets are written as float32
    cp_item = ET.SubElement(cp_attr, "DataItem",
        Format="HDF", DataType="Float", Precision="4", Dimensions=dims)
    cp_item.text = f"{h5_name}:/cp/{time_key}"
    n_patched += 1

if n_patched:
    tree.write(str(body_xdmf), xml_declaration=True, encoding="UTF-8")
    print(f"Patched {body_xdmf} ({n_patched} grids)")
else:
    print(f"No grids needed patching in {body_xdmf}")
print("Done. Open the xdmf file in ParaView. The 'Cp' field will be available alongside 'pressure'.")
