Measure-Theoretic Anti-Causal Representation Learning
Abstract
ACIA Framework Overview
- Low-Level Representation ($\mathcal{Z}_L$): Captures the causal latent dynamics of how the label $Y$ generates the features $X$ (as per Theorem 3 in the paper).
- High-Level Representation ($\mathcal{Z}_H$): Learns the causal invariant abstraction (as per Theorem 4), ensuring the final prediction is independent of environment factors.
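As a rough illustration only (not the library's implementation; the function names and the penalty form below are assumptions), the two-level structure can be read as a pair of maps: a low-level encoder from features $X$ to $\mathcal{Z}_L$, a high-level map from $\mathcal{Z}_L$ to $\mathcal{Z}_H$, and a penalty on $\mathcal{Z}_H$ that is zero when its per-environment statistics match, which is the kind of environment independence the framework targets:

```python
import numpy as np

def low_level(x, W_l):
    # Illustrative low-level map X -> Z_L (the anti-causal Y -> X dynamics)
    return np.tanh(x @ W_l)

def high_level(z_l, W_h):
    # Illustrative high-level map Z_L -> Z_H (the invariant abstraction)
    return np.tanh(z_l @ W_h)

def env_invariance_penalty(z_h, env):
    # Toy proxy for environment independence: squared deviation of
    # per-environment mean representations from the overall mean.
    # Zero when Z_H has identical means across environments.
    means = np.stack([z_h[env == e].mean(axis=0) for e in np.unique(env)])
    return float(((means - means.mean(axis=0)) ** 2).sum())

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 10))          # batch of features
env = rng.integers(0, 2, size=64)      # environment label per sample
W_l = rng.normal(size=(10, 8))
W_h = rng.normal(size=(8, 4))

z_h = high_level(low_level(x, W_l), W_h)
penalty = env_invariance_penalty(z_h, env)
```

In an actual training loop, a penalty of this kind would be added to the prediction loss so that minimizing it pushes $\mathcal{Z}_H$ toward environment-independent predictions.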
Results and Performance
Baseline Comparisons Across Datasets
ACIA achieves state-of-the-art accuracy on all four benchmark datasets, and obtains the strongest invariance scores (EI, LLI, IR) on most of them.
| Method | CMNIST Acc↑ | CMNIST EI↓ | CMNIST LLI↓ | CMNIST IR↓ | RMNIST Acc↑ | RMNIST EI↓ | RMNIST LLI↓ | RMNIST IR↓ | Ball Agent Acc↑ | Ball Agent EI↓ | Ball Agent LLI↓ | Ball Agent IR↓ | Camelyon17 Acc↑ | Camelyon17 EI↓ | Camelyon17 LLI↓ | Camelyon17 IR↓ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| GDRO | 92.00 | 1.85 | 0.80 | 0.91 | 63.00 | 16.03 | 4.10 | 1.53 | 66.00 | 1.04 | 0.69 | 0.75 | 58.00 | 1.32 | 0.87 | 0.83 |
| MMD | 94.00 | 1.22 | 1.73 | 1.13 | 92.00 | 6.88 | 15.62 | 0.69 | 68.50 | 1.13 | 0.87 | 1.82 | 60.00 | 4.43 | 2.16 | 1.12 |
| CORAL | 89.00 | 1.48 | 2.06 | 1.30 | 91.00 | 4.02 | 9.56 | 0.31 | 70.50 | 1.23 | 1.92 | 1.84 | 41.00 | 1.62 | 2.45 | 1.01 |
| DANN | 45.00 | 0.03 | 0.86 | 0.20 | 38.50 | 12.82 | 3.85 | 1.47 | 61.00 | 1.35 | 0.96 | 0.89 | 39.00 | 0.68 | 1.40 | 1.95 |
| IRM | 85.00 | 1.43 | 0.83 | 1.08 | 85.50 | 19.03 | 6.64 | 3.17 | 56.00 | 0.89 | 0.67 | 1.71 | 52.00 | 1.95 | 1.76 | 2.45 |
| Rex | 73.00 | 0.69 | 1.41 | 1.80 | 80.50 | 0.69 | 10.69 | 0.96 | 54.50 | 1.05 | 0.11 | 7.65 | 39.00 | 0.68 | 1.40 | 1.95 |
| VREx | 95.50 | 1.71 | 1.09 | 0.77 | 93.50 | 2.41 | 2.77 | 1.03 | 74.00 | 0.93 | 0.78 | 0.73 | 54.50 | 1.98 | 1.78 | 1.02 |
| ACTIR | 78.50 | 0.64 | 0.97 | 1.80 | 72.00 | 0.23 | 18.79 | 0.19 | 69.00 | 0.88 | 0.02 | 0.58 | 60.50 | 0.60 | 0.63 | 0.80 |
| CausalDA | 83.50 | 0.41 | 0.85 | 12.23 | 87.50 | 0.62 | 0.91 | 16.44 | 45.50 | 1.20 | 0.85 | 1.22 | 55.50 | 0.55 | 1.60 | 10.55 |
| LECI | 70.00 | 0.83 | 0.40 | 0.67 | 82.00 | 0.29 | 2.91 | 0.04 | 71.20 | 0.46 | 0.39 | 0.05 | 65.50 | 0.23 | 0.50 | 0.45 |
| ACIA (Ours) | 99.20 | 0.00 | 0.01 | 0.02 | 99.10 | 0.00 | 0.03 | 0.01 | 99.98 | 0.52 | 0.03 | 0.03 | 84.40 | 0.28 | 0.42 | 0.43 |
Metrics: Acc = Accuracy, EI = Environment Independence, LLI = Low-Level Invariance, IR = Invariance Ratio. ↑ indicates higher is better; ↓ indicates lower is better.
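To make the headline accuracy claim easy to check, the accuracy column of the table can be verified programmatically (the numbers below are copied directly from the table; this snippet is not part of the ACIA library):

```python
# Test accuracies (%) from the table, ordered CMNIST, RMNIST, Ball Agent, Camelyon17
acc = {
    "GDRO":        [92.00, 63.00, 66.00, 58.00],
    "MMD":         [94.00, 92.00, 68.50, 60.00],
    "CORAL":       [89.00, 91.00, 70.50, 41.00],
    "DANN":        [45.00, 38.50, 61.00, 39.00],
    "IRM":         [85.00, 85.50, 56.00, 52.00],
    "Rex":         [73.00, 80.50, 54.50, 39.00],
    "VREx":        [95.50, 93.50, 74.00, 54.50],
    "ACTIR":       [78.50, 72.00, 69.00, 60.50],
    "CausalDA":    [83.50, 87.50, 45.50, 55.50],
    "LECI":        [70.00, 82.00, 71.20, 65.50],
    "ACIA (Ours)": [99.20, 99.10, 99.98, 84.40],
}
datasets = ["CMNIST", "RMNIST", "Ball Agent", "Camelyon17"]

# Highest-accuracy method per dataset
best = {d: max(acc, key=lambda m: acc[m][i]) for i, d in enumerate(datasets)}
```

Running this confirms that ACIA has the highest accuracy on every dataset; note that on the invariance metrics LECI is occasionally lower (e.g. EI on Ball Agent and Camelyon17).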
Installation & Usage (ACIA Library)
The Anti-Causal Invariant Abstractions (ACIA) framework is available as a Python package authored by Arman Behnam, requiring Python >=3.8.
Installation
Install directly from GitHub:
```bash
# Clone the repository
git clone https://github.com/ArmanBehnam/ACIA.git
cd ACIA

# Install the package
pip install -e .

# Or install with development dependencies
pip install -e ".[dev]"
```
Quick Start
After installation, train on ColoredMNIST:
```python
from acia import ColoredMNIST, CausalRepresentationNetwork, CausalOptimizer
from torch.utils.data import DataLoader, ConcatDataset

# Load datasets from two training environments
train_e1 = ColoredMNIST(env='e1', train=True)
train_e2 = ColoredMNIST(env='e2', train=True)
train_loader = DataLoader(ConcatDataset([train_e1, train_e2]), batch_size=128)

# Train model
model = CausalRepresentationNetwork()
optimizer = CausalOptimizer(model, batch_size=128)
for x, y, e in train_loader:
    metrics = optimizer.train_step(x, y, e)
```
Or run the complete example:
```bash
python examples/example_colored_mnist.py
```
Acknowledgments
We thank the anonymous reviewers for their valuable and constructive feedback. This work was supported in part by the Cisco Research Award and by the National Science Foundation under Grant Nos. ECCS-2216926, CCF-2331302, CNS-2241713, and CNS-2339686.
Citation
```bibtex
@inproceedings{behnam2025measure,
  title={Measure-Theoretic Anti-Causal Representation Learning},
  author={Behnam, Arman and Wang, Binghui},
  booktitle={Advances in Neural Information Processing Systems},
  year={2025}
}
```
Website adapted for ArmanBehnam/anticausal-neurips2025. Code is available at https://github.com/ArmanBehnam/ACIA.