CausalMix: Data Mixture as Causal Inference for Language Model Training
Authors: Zinan Tang, Yukun Zhang, Shaomian Zheng, Zhuoshi Pan, Qizhi Pei, Dingnan Jin, Jun Zhou, Yujun Wang, Biqing Huang
arXiv ID: 2607.01104
Problem: Existing data mixture optimization methods assume static data distributions and require costly retraining from scratch when the data pool shifts, preventing scalable transfer across data pools and model sizes.
Key Methodology:
- Formulates data mixture optimization as a causal inference problem using Double Machine Learning (DML) with CausalForestDML to estimate state-conditioned marginal returns of domain proportions, treating data-state features (Normalized_Loss, Writing_Style, HES) as covariates and domain mixture as treatment.
- Fits a causal model on 512 proxy runs of Qwen2.5-0.5B (100K samples each) to estimate the Conditional Average Treatment Effect (CATE), then extrapolates the optimal mixture to an 800K data pool without requiring new proxy experiments.
- Uses a conservative policy extraction (analytical closed-form and search-based variants) constrained by a trust region to translate estimated marginal returns into feasible mixture weights on the simplex.
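To make the last step concrete, here is a minimal sketch of how estimated marginal returns might be turned into mixture weights inside a trust region. It is not the paper's closed-form rule: the function name `extract_mixture`, the `radius` parameter, the per-domain (max-norm) trust region, and the example numbers are all assumptions chosen for illustration.

```python
def extract_mixture(base, returns, radius=0.1):
    """Shift a base mixture toward domains with higher estimated returns.

    Hypothetical illustration: each domain weight may move by at most
    `radius` (the trust region), and the result stays on the simplex.
    """
    k = len(base)
    mean_r = sum(returns) / k
    # Centering makes the direction sum to zero, so the total weight is preserved.
    direction = [r - mean_r for r in returns]
    biggest = max(abs(d) for d in direction)
    if biggest == 0:
        return list(base)  # all domains look equally good: keep the base mixture
    # Largest step that keeps every per-domain change within the radius ...
    step = radius / biggest
    # ... and keeps every weight non-negative.
    for w, d in zip(base, direction):
        if d < 0:
            step = min(step, w / -d)
    return [w + step * d for w, d in zip(base, direction)]


# Made-up example: four domains, equal starting mixture, invented returns.
base_mix = [0.25, 0.25, 0.25, 0.25]
est_returns = [0.4, 0.1, -0.1, -0.4]
new_mix = extract_mixture(base_mix, est_returns, radius=0.1)
print(new_mix)
```

Capping the step rather than jumping to the domain with the highest estimated return is what makes the policy conservative: the estimates are only trusted near the mixtures the proxy runs actually covered.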
Key Results:
- CausalMix-A achieves AvgDev 33.94 (800K, Qwen2.5-0.5B) vs DMO at 32.04 and Equal at 31.78; CausalMix-S achieves AvgDev 62.28 (800K, Qwen2.5-7B) vs DMO at 60.35 and Equal at 60.02.
- On LongCoT data (Qwen3-4B), CausalMix achieves Avg 66.66 vs Grid 64.74 and RegMix 61.40, demonstrating strong transferability to unseen data pools and model architectures.
- Ablation shows removing DML orthogonalization drops AvgDev by ~1.3 points (0.5B) and ~2.6 points (7B); removing covariates (state-agnostic regression) also degrades performance.
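The orthogonalization that the ablation removes can be illustrated with a small, self-contained sketch of cross-fitted partialling-out, the core idea behind DML. Everything below is synthetic and hypothetical: a single covariate `x` stands in for a data-state feature, `t` for one domain's proportion, and `y` for a benchmark score. It estimates one constant effect with linear nuisance models, whereas CausalForestDML estimates state-conditioned effects with a forest-based final stage.

```python
import random


def fit_line(xs, ys):
    """Ordinary least squares for y = a + b * x; returns (a, b)."""
    n = len(xs)
    mx = sum(xs) / n
    my = sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    b = sxy / sxx
    return my - b * mx, b


def dml_effect(x, t, y, n_folds=2):
    """Cross-fitted residual-on-residual estimate of a constant effect."""
    n = len(x)
    t_res, y_res = [0.0] * n, [0.0] * n
    for k in range(n_folds):
        held = [i for i in range(n) if i % n_folds == k]
        train = [i for i in range(n) if i % n_folds != k]
        # Nuisance models are fitted on the other fold(s) only.
        at, bt = fit_line([x[i] for i in train], [t[i] for i in train])
        ay, by = fit_line([x[i] for i in train], [y[i] for i in train])
        for i in held:
            t_res[i] = t[i] - (at + bt * x[i])
            y_res[i] = y[i] - (ay + by * x[i])
    # Final stage: regress outcome residuals on treatment residuals.
    return sum(a * b for a, b in zip(t_res, y_res)) / sum(a * a for a in t_res)


# Synthetic data in which the covariate drives both treatment and outcome.
rng = random.Random(0)
TRUE_EFFECT = 2.0
x = [rng.gauss(0, 1) for _ in range(2000)]
t = [0.8 * xi + rng.gauss(0, 1) for xi in x]
y = [TRUE_EFFECT * ti + 3.0 * xi + rng.gauss(0, 1) for ti, xi in zip(t, x)]

naive = fit_line(t, y)[1]   # ignores the covariate, so it absorbs confounding
dml = dml_effect(x, t, y)   # orthogonalized estimate
print(f"naive={naive:.2f}  dml={dml:.2f}  true={TRUE_EFFECT}")
```

In this toy setup the naive slope overstates the effect because the covariate raises both the treatment and the outcome, while the residual-on-residual estimate stays close to the true value.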
Applied Context: Builders can use CausalMix to determine optimal SFT data mixtures for new datasets and larger models using only a one-time causal model fit on small proxy runs, eliminating the need to re-run expensive grid searches or retrain proxy models when data pools or model sizes change.