This repository contains the metric implementations and released CSV results used in the paper *Diffusing in the Right Space: A Systematic Study of Latent Diffusability*.
Velocity irreducible variance (VIV):

```python
import torch
from metrics.viv import velocity_irreducible_variance

latents = torch.randn(50000, 32, 16, 16)   # (B, C, H, W)
labels = torch.randint(0, 1000, (50000,))  # class labels, shape (B,)

viv = velocity_irreducible_variance(
    latents,
    labels,
    output_folder="outputs/viv_params",
)
```

If `output_folder` is provided, the per-class spectra are saved to `latents_params.pt`.
Latent neighbor consistency (LNC):

```python
import torch
from metrics.lnc import latent_neighbor_consistency

latents = torch.randn(5000, 32, 16, 16)        # (N, C, h, w)
masks = torch.randint(0, 2, (5000, 256, 256))  # binary foreground masks, (N, H, W)
labels = torch.randint(0, 1000, (5000,))       # class labels, shape (N,)

lnc = latent_neighbor_consistency(
    latents,
    masks,
    labels,
    num_images_per_class=50,
    sim_metric="cosine",
    balance_classes=False,
)
```

Spectral energy concentration (SEC):

```python
import torch
from metrics.sec import spectral_energy_concentration

latents = torch.randn(16, 32, 64, 64)  # (B, C, H, W)

sec = spectral_energy_concentration(
    latents,
    thresholds=[0.25, 0.5],
    dist_type="Manhattan",
)
```

All raw data used in the paper are organized in the `data` folder:
- Latent properties: `data/metrics.csv`
- Generation quality: `data/generate.csv`
- ODE straightness: `data/straightness.csv`
`plot_correlation.py` merges CSV files, plots one metric against another, prints the correlation, and saves the scatter plot.

Example: plot generation quality (gFID) against VIV, restricted to rows with `Diffusion=SiT-B`, `CFG = 1.0`, and `n_samples = 50000`:
```bash
python plot_correlation.py \
  --m1_csv data/generate.csv \
  --m1_col FID --m1_name gFID --inv_m1 \
  --m2_csv data/metrics.csv \
  --m2_col VIV \
  --id_col VAE \
  --cls_col Cluster --cls_color_csv misc/cls_color.csv \
  --where "Diffusion=SiT-B AND CFG = 1.0 AND n_samples = 50000" \
  -o results/fid_vs_viv.png
```
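The example above reduces each matched VAE to a pair of numbers (gFID, VIV) and correlates them. Which coefficient `plot_correlation.py` reports is not stated here, so as an assumption this sketch computes both Pearson correlation and rank-based Spearman correlation with the standard library, given two already-aligned lists. The function names `pearson`, `ranks`, and `spearman` are hypothetical.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length numeric sequences."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two equal-length sequences with at least 2 values")
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    if vx == 0 or vy == 0:
        raise ValueError("correlation is undefined for a constant sequence")
    return cov / math.sqrt(vx * vy)


def ranks(values):
    """1-based ranks; tied values share the average of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with values[order[i]].
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = avg_rank
        i = j + 1
    return result


def spearman(xs, ys):
    """Spearman correlation: Pearson correlation of the rank vectors."""
    return pearson(ranks(xs), ranks(ys))
```

Reversing the direction of one metric (for example, negating gFID so that higher is better) flips the sign of both coefficients without changing their magnitude.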