Measure-Theoretic Anti-Causal Representation Learning

Conceptual Diagram of ACIA Framework

Abstract

Causal representation learning in the anti-causal setting—labels cause features rather than the reverse—presents unique challenges requiring specialized approaches. We propose Anti-Causal Invariant Abstractions (ACIA), a novel measure-theoretic framework for anti-causal representation learning. ACIA employs a two-level design: low-level representations capture how labels generate observations, while high-level representations learn stable causal patterns across environment-specific variations. ACIA addresses key limitations of existing approaches by: (1) accommodating perfect and imperfect interventions through interventional kernels, (2) eliminating dependency on explicit causal structures, (3) handling high-dimensional data effectively, and (4) providing theoretical guarantees for out-of-distribution generalization. Experiments on synthetic and real-world medical datasets demonstrate that ACIA consistently outperforms state-of-the-art methods in both accuracy and invariance metrics. Furthermore, our theoretical results establish tight bounds on performance gaps between training and unseen environments, confirming the efficacy of our approach for robust anti-causal learning.

ACIA Framework Overview

Our approach introduces a novel measure-theoretic framework for learning in anti-causal domains ($Y \rightarrow X \leftarrow E$). The method consists of three key algorithms: Causal Dynamics Algorithm, Causal Abstraction Algorithm, and OOD Optimization Algorithm. The framework decomposes the problem into two key representation spaces:
  • Low-Level Representation ($\mathcal{Z}_L$): Captures the causal latent dynamics of how the label $Y$ generates the features $X$ (as per Theorem 3 in the paper).
  • High-Level Representation ($\mathcal{Z}_H$): Learns the causal invariant abstraction (as per Theorem 4), ensuring the final prediction is independent of environment factors.
This two-level approach enables OOD generalization by learning from both perfect and imperfect interventions; a minimal, illustrative sketch of the two-level design is shown below.
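
The following sketch is purely illustrative and is not the ACIA library implementation: it shows one way the low-level/high-level split could be wired up in PyTorch, with a hypothetical low-level encoder producing Z_L, a high-level head producing Z_H, and a simple environment-invariance penalty on Z_H. All module and function names here are assumptions for exposition.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Hypothetical two-level encoder illustrating the Z_L / Z_H split
    # (not the ACIA library API).
    class TwoLevelEncoder(nn.Module):
        def __init__(self, in_dim=784, z_l_dim=64, z_h_dim=16, num_classes=2):
            super().__init__()
            # Low-level representation Z_L: models how the label generates the features
            self.low = nn.Sequential(nn.Linear(in_dim, 256), nn.ReLU(),
                                     nn.Linear(256, z_l_dim), nn.ReLU())
            # High-level representation Z_H: environment-invariant abstraction
            self.high = nn.Sequential(nn.Linear(z_l_dim, z_h_dim), nn.ReLU())
            self.classifier = nn.Linear(z_h_dim, num_classes)

        def forward(self, x):
            z_l = self.low(x)
            z_h = self.high(z_l)
            return self.classifier(z_h), z_h

    def invariance_penalty(z_h, e):
        """Penalize differences between per-environment means of Z_H,
        a simple stand-in for an invariance constraint on the abstraction."""
        means = [z_h[e == env].mean(dim=0) for env in torch.unique(e)]
        return sum(F.mse_loss(m, means[0]) for m in means[1:])

    # One hypothetical training step: classification loss plus invariance penalty
    model = TwoLevelEncoder()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    x = torch.randn(32, 784)                      # dummy observations
    y = torch.randint(0, 2, (32,))                # dummy labels
    e = torch.randint(0, 2, (32,))                # dummy environment indices
    logits, z_h = model(x)
    loss = F.cross_entropy(logits, y) + 0.1 * invariance_penalty(z_h, e)
    opt.zero_grad(); loss.backward(); opt.step()

The ACIA library itself packages this logic in its own classes (see the Quick Start below); the sketch only conveys the separation of the two representation levels and the invariance pressure applied to the high-level one.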

Results and Performance

Baseline Comparisons Across Datasets

ACIA achieves state-of-the-art performance across four benchmark datasets with superior invariance metrics.

Method      |          CMNIST          |          RMNIST          |        Ball Agent        |        Camelyon17
            |   Acc↑   EI↓  LLI↓   IR↓ |   Acc↑   EI↓  LLI↓   IR↓ |   Acc↑   EI↓  LLI↓   IR↓ |   Acc↑   EI↓  LLI↓   IR↓
GDRO        |  92.00  1.85  0.80  0.91 |  63.00 16.03  4.10  1.53 |  66.00  1.04  0.69  0.75 |  58.00  1.32  0.87  0.83
MMD         |  94.00  1.22  1.73  1.13 |  92.00  6.88 15.62  0.69 |  68.50  1.13  0.87  1.82 |  60.00  4.43  2.16  1.12
CORAL       |  89.00  1.48  2.06  1.30 |  91.00  4.02  9.56  0.31 |  70.50  1.23  1.92  1.84 |  41.00  1.62  2.45  1.01
DANN        |  45.00  0.03  0.86  0.20 |  38.50 12.82  3.85  1.47 |  61.00  1.35  0.96  0.89 |  39.00  0.68  1.40  1.95
IRM         |  85.00  1.43  0.83  1.08 |  85.50 19.03  6.64  3.17 |  56.00  0.89  0.67  1.71 |  52.00  1.95  1.76  2.45
Rex         |  73.00  0.69  1.41  1.80 |  80.50  0.69 10.69  0.96 |  54.50  1.05  0.11  7.65 |  39.00  0.68  1.40  1.95
VREx        |  95.50  1.71  1.09  0.77 |  93.50  2.41  2.77  1.03 |  74.00  0.93  0.78  0.73 |  54.50  1.98  1.78  1.02
ACTIR       |  78.50  0.64  0.97  1.80 |  72.00  0.23 18.79  0.19 |  69.00  0.88  0.02  0.58 |  60.50  0.60  0.63  0.80
CausalDA    |  83.50  0.41  0.85 12.23 |  87.50  0.62  0.91 16.44 |  45.50  1.20  0.85  1.22 |  55.50  0.55  1.60 10.55
LECI        |  70.00  0.83  0.40  0.67 |  82.00  0.29  2.91  0.04 |  71.20  0.46  0.39  0.05 |  65.50  0.23  0.50  0.45
ACIA (Ours) |  99.20  0.00  0.01  0.02 |  99.10  0.00  0.03  0.01 |  99.98  0.52  0.03  0.03 |  84.40  0.28  0.42  0.43

Metrics: Acc = Accuracy, EI = Environment Independence, LLI = Low-Level Invariance, IR = Invariance Ratio. ↑ indicates higher is better; ↓ indicates lower is better.

Installation & Usage (ACIA Library)

The Anti-Causal Invariant Abstractions (ACIA) framework is available as a Python package, authored by Arman Behnam, and requires Python >= 3.8.

Installation

Install directly from GitHub:

    # Clone the repository
    git clone https://github.com/ArmanBehnam/ACIA.git
    cd ACIA

    # Install the package
    pip install -e .

    # Or install with development dependencies
    pip install -e ".[dev]"
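
To verify the installation, a minimal import check in Python (the printed message is just illustrative):

    # Sanity check that the acia package is importable after installation
    import acia
    print("acia imported successfully")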

Quick Start

After installation, train on ColoredMNIST:

    from acia import ColoredMNIST, CausalRepresentationNetwork, CausalOptimizer
    from torch.utils.data import DataLoader, ConcatDataset

    # Load datasets
    train_e1 = ColoredMNIST(env='e1', train=True)
    train_e2 = ColoredMNIST(env='e2', train=True)
    train_loader = DataLoader(ConcatDataset([train_e1, train_e2]), batch_size=128)

    # Train model
    model = CausalRepresentationNetwork()
    optimizer = CausalOptimizer(model, batch_size=128)

    for x, y, e in train_loader:
        metrics = optimizer.train_step(x, y, e)

Or run the complete example:

    python examples/example_colored_mnist.py
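
After training, the model can be evaluated on a held-out environment. The snippet below is a sketch only: the environment name, the train=False split, and the model's return type are assumptions rather than documented behavior of the package, so check the repository for the exact interface.

    import torch
    from torch.utils.data import DataLoader
    from acia import ColoredMNIST, CausalRepresentationNetwork

    # Hypothetical evaluation loop; env name and train=False split are assumptions
    test_set = ColoredMNIST(env='e1', train=False)
    test_loader = DataLoader(test_set, batch_size=128)

    # In practice, reuse the trained model from the Quick Start (or load saved weights)
    model = CausalRepresentationNetwork()
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for x, y, e in test_loader:
            out = model(x)
            # The forward pass may return extra outputs; take the logits if so
            logits = out[0] if isinstance(out, tuple) else out
            correct += (logits.argmax(dim=1) == y).sum().item()
            total += y.size(0)

    print(f"Accuracy on held-out environment: {correct / total:.4f}")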

Acknowledgments

We thank the anonymous reviewers for their valuable and constructive feedback. This work was supported in part by the Cisco Research Award and by the National Science Foundation under Grant Nos. ECCS-2216926, CCF-2331302, CNS-2241713, and CNS-2339686.

Citation

@inproceedings{behnam2025anticausal,
   title={Measure-Theoretic Anti-Causal Representation Learning},
   author={Behnam, Arman and Wang, Binghui},
   booktitle={Advances in Neural Information Processing Systems},
   year={2025}
}

Website adapted for ArmanBehnam/anticausal-neurips2025. Code is available at https://github.com/ArmanBehnam/ACIA.