Learning Representations for Counterfactual Inference
Johansson, Fredrik D. and Shalit, Uri and Sontag, David
2016

Paper summary
hlarochelle
This paper presents a method to train a neural network to make predictions for *counterfactual* questions. In short, such questions ask what the result of an intervention would have been had a different choice of intervention been made (e.g. *Would this patient have lower blood sugar had she received a different medication?*).
One approach to tackle this problem is to collect data of the form $(x_i, t_i, y_i^F)$ where $x_i$ describes a situation (e.g. a patient), $t_i$ describes the intervention made (in this paper $t_i$ is binary, e.g. $t_i = 1$ if a new treatment is used while $t_i = 0$ would correspond to using the current treatment) and $y_i^F$ is the factual outcome of the intervention $t_i$ for $x_i$. From this training data, a predictor $h(x,t)$ taking the pair $(x_i, t_i)$ as input and outputting a prediction for $y_i^F$ could be trained.
From this predictor, one could imagine answering counterfactual questions by feeding $(x_i, 1-t_i)$ (i.e. a description of the same situation $x_i$ but with the opposite intervention $1-t_i$) to our predictor and comparing the prediction $h(x_i, 1-t_i)$ with $y_i^F$. This would give us an estimate of the change in the outcome, had a different intervention been made, thus providing an answer to our counterfactual question.
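A minimal sketch of this naive approach, assuming a one-dimensional $x$ and a predictor that is linear in $x$ with a separate intercept and slope for each value of $t$. The names `fit_line`, `fit_h` and `counterfactual_effect` are hypothetical and not the paper's code. The predictor is fit on factual triples only and then queried at $(x_i, 1-t_i)$.

```python
from statistics import mean


def fit_line(xs, ys):
    """Ordinary least squares for y ~ a + b * x with scalar x; returns (a, b)."""
    mx, my = mean(xs), mean(ys)
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    b = sxy / sxx
    return my - b * mx, b


def fit_h(factual):
    """Fit h(x, t) on factual triples (x_i, t_i, y_i^F): one line per treatment value."""
    params = {}
    for t in (0, 1):
        xs = [x for x, ti, _ in factual if ti == t]
        ys = [y for _, ti, y in factual if ti == t]
        params[t] = fit_line(xs, ys)

    def h(x, t):
        a, b = params[t]
        return a + b * x

    return h


def counterfactual_effect(h, x, t, y_f):
    """Estimated outcome change of t=1 over t=0 for one unit, from y_i^F and h(x_i, 1 - t_i)."""
    y_cf = h(x, 1 - t)  # prediction for the intervention that was NOT made
    return y_f - y_cf if t == 1 else y_cf - y_f


# Toy factual data generated (noise-free) from y = x + 2t, so the true effect is 2 everywhere.
factual = [(0.0, 0, 0.0), (1.0, 0, 1.0), (2.0, 0, 2.0),
           (1.0, 1, 3.0), (2.0, 1, 4.0), (3.0, 1, 5.0)]
h = fit_h(factual)
print(counterfactual_effect(h, 2.0, 0, 2.0))  # 2.0
```

Nothing in this sketch accounts for how $t_i$ was assigned, which is the issue raised next.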
The authors point out that this scenario is related to domain adaptation (more specifically, to the special case of covariate shift), in which the distribution of training inputs (here, the pairs $(x_i,t_i)$) differs from the distribution of inputs fed to the predictor at test time (here, the pairs $(x_i, 1-t_i)$). If the intervention $t_i$ is assigned uniformly at random, independently of $x_i$, the two distributions are the same. However, in observational studies, the choice of $t_i$ often depends on $x_i$ through some unknown policy. This is the situation of interest in this paper.
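To make the shift concrete, the toy simulation below draws $x \sim \mathcal{N}(0,1)$ and assigns $t$ either at random or with a hypothetical policy that favors $t=1$ when $x>0$; the policies and helper names are assumptions for illustration. Because the test inputs $(x_i, 1-t_i)$ pair each $x_i$ with the opposite intervention, the $x$ values seen with $t=1$ at test time are exactly those of the $t=0$ group at training time, so unequal per-group means of $x$ indicate that the two input distributions differ.

```python
import random
from statistics import mean


def simulate(n, policy, seed=0):
    """Draw x ~ N(0, 1) and t ~ Bernoulli(policy(x)); returns a list of (x, t) pairs."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        t = 1 if rng.random() < policy(x) else 0
        pairs.append((x, t))
    return pairs


def arm_means(pairs):
    """Mean of x within each intervention group; flipping t swaps the groups."""
    return {t: mean([x for x, ti in pairs if ti == t]) for t in (0, 1)}


randomized = simulate(10_000, lambda x: 0.5)                    # t independent of x
confounded = simulate(10_000, lambda x: 0.9 if x > 0 else 0.1)  # hypothetical unknown policy
print(arm_means(randomized))
print(arm_means(confounded))
```

Under the randomized policy the two group means nearly coincide, while under the confounded one they are pulled apart in opposite directions.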
Thus, the authors propose an approach inspired by the domain adaptation literature. Specifically, they propose that the predictor $h(x,t)$ learn a representation $\Phi(x)$ whose distribution is similar under both interventions, so that it does not discriminate between units that received $t=0$ and $t=1$ (see Figure 2 for the proposed neural network architecture). This notion is [well established][1] in the domain adaptation literature and has previously been exploited using regularization terms based on [adversarial learning][2] and [maximum mean discrepancy][3]. In this paper, the authors instead use a regularizer (denoted in the paper as $disc(\Phi_{t=0},\Phi_{t=1})$) based on the so-called discrepancy distance of [Mansour et al.][4], adapting its use to the case of a neural network.
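A sketch of the discrepancy term for one case that has a closed form: linear hypotheses $w \cdot \Phi(x)$ with $\|w\| \le \Lambda$ under squared loss. Since the difference of two such hypotheses ranges over all vectors of norm at most $2\Lambda$, maximizing over hypothesis pairs gives $disc = 4\Lambda^2 \|M_0 - M_1\|_2$, where $M_t$ is the uncentered second-moment matrix of the representations in group $t$ and $\|\cdot\|_2$ is the largest absolute eigenvalue. The code assumes 2-dimensional representations so the eigenvalues have a closed form; the function names are hypothetical, and the paper's exact formulation (scaling, which samples are compared) may differ.

```python
import math


def second_moment(reps):
    """Empirical uncentered second moment E[phi phi^T] of 2-D representations, as (a, b, d)."""
    n = len(reps)
    a = sum(p[0] * p[0] for p in reps) / n
    b = sum(p[0] * p[1] for p in reps) / n
    d = sum(p[1] * p[1] for p in reps) / n
    return a, b, d  # the symmetric matrix [[a, b], [b, d]]


def linear_discrepancy(phi_t0, phi_t1, lam=1.0):
    """4 * lam**2 * (largest |eigenvalue| of M0 - M1) for linear hypotheses with ||w|| <= lam."""
    a0, b0, d0 = second_moment(phi_t0)
    a1, b1, d1 = second_moment(phi_t1)
    a, b, d = a0 - a1, b0 - b1, d0 - d1
    # Eigenvalues of a symmetric 2x2 matrix: mid +/- rad.
    mid = (a + d) / 2
    rad = math.sqrt(((a - d) / 2) ** 2 + b ** 2)
    spectral = max(abs(mid + rad), abs(mid - rad))
    return 4 * lam ** 2 * spectral


phi_t0 = [(1.0, 0.0), (0.0, 1.0)]
phi_t1 = [(2.0, 0.0), (0.0, 1.0)]
print(linear_discrepancy(phi_t0, phi_t1))  # 6.0
print(linear_discrepancy(phi_t0, phi_t0))  # 0.0
```

In training, a term like this would be added to the factual loss with a weight (a hyper-parameter) and differentiated through $\Phi$ with an autodiff framework; the plain-Python version only shows the quantity being penalized.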
As an example, imagine that in our dataset the new treatment ($t=1$) was used much more frequently than the current one ($t=0$) for men. For men, we should then expect the training data to contain relatively little evidence for counterfactual inference. Intuitively, we would thus want our predictor to rely less on that "feature" of patients when inferring the impact of the treatment.
In addition to this term, the authors propose a second regularizer that pushes the prediction $h(x_i,1-t_i)$ on counterfactual inputs to be as close as possible to the target $y_{j}^F$ of the observation $x_j$ that is closest to $x_i$ **and** actually received the counterfactual intervention $t_j = 1-t_i$.
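A sketch of this nearest-neighbor term, assuming Euclidean distance in input space and a squared-error penalty averaged over units; `nn_counterfactual_targets` and `nn_penalty` are hypothetical names, and the paper's choice of metric and weighting may differ.

```python
import math


def nn_counterfactual_targets(data):
    """For each (x_i, t_i, y_i^F), return y_j^F of the closest x_j with t_j = 1 - t_i."""
    targets = []
    for x_i, t_i, _ in data:
        # Only units that actually received the opposite intervention are candidates.
        candidates = [(math.dist(x_i, x_j), y_j) for x_j, t_j, y_j in data if t_j == 1 - t_i]
        targets.append(min(candidates)[1])
    return targets


def nn_penalty(h, data):
    """Mean squared gap between h(x_i, 1 - t_i) and the nearest-neighbor target."""
    targets = nn_counterfactual_targets(data)
    gaps = [(h(x_i, 1 - t_i) - y_nn) ** 2 for (x_i, t_i, _), y_nn in zip(data, targets)]
    return sum(gaps) / len(gaps)


# Inputs are tuples so math.dist works for any dimension.
data = [((0.0,), 0, 1.0), ((1.0,), 0, 2.0), ((0.1,), 1, 5.0), ((0.9,), 1, 7.0)]
print(nn_counterfactual_targets(data))  # [5.0, 7.0, 1.0, 2.0]
```

The brute-force search is quadratic in the number of units; a real implementation over a large dataset would likely precompute the matches once with a spatial index.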
The paper first derives a bound relating the counterfactual generalization error to the discrepancy distance. It then presents experiments on simulated counterfactual inference tasks, in which performance is measured by comparing the predicted treatment effects (estimated as the difference between the observed outcome $y_i^F$ under the observed treatment and the predicted outcome $h(x_i, 1-t_i)$ under the opposite treatment) with the true effects, which are known here because the data is simulated. The paper shows that the proposed neural network approach outperforms several baselines on this task.
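A sketch of this evaluation on simulated data where the true per-unit effect is known. The effect estimate follows the description above; using root-mean-squared error as the comparison metric, and the helper names, are assumptions for illustration.

```python
import math


def estimated_effects(h, data):
    """Per-unit effect estimate, signed as outcome(t=1) - outcome(t=0)."""
    return [y - h(x, 0) if t == 1 else h(x, 1) - y for x, t, y in data]


def rmse_effect(h, data, true_effects):
    """Root-mean-squared error between estimated and true per-unit effects."""
    est = estimated_effects(h, data)
    return math.sqrt(sum((e - tau) ** 2 for e, tau in zip(est, true_effects)) / len(est))


# Simulated from y = x + 2t, so every unit's true effect is 2.
data = [(0.0, 0, 0.0), (1.0, 1, 3.0)]
true_effects = [2.0, 2.0]
h_perfect = lambda x, t: x + 2 * t  # recovers the simulator exactly
h_naive = lambda x, t: x            # ignores the intervention entirely
print(rmse_effect(h_perfect, data, true_effects))  # 0.0
print(rmse_effect(h_naive, data, true_effects))    # 1.4142135623730951 (sqrt(2))
```

Real observational data never provides `true_effects`, which is exactly why this evaluation, and any hyper-parameter selection built on it, only works in simulation.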
**My two cents**
The connection with domain adaptation presented here is really clever and enlightening. This sounds like a very compelling approach to counterfactual inference, which can exploit a lot of previous work on domain adaptation.
The paper mentions that selecting hyper-parameters (such as the weights of the regularization terms) in this scenario is not trivial. Indeed, measuring performance here requires knowing the true difference in intervention outcomes, which usually cannot be known in practice (e.g. the same patient usually cannot receive both treatments at once). In the paper, the authors somewhat "cheat" by using the ground-truth difference in outcomes to measure out-of-sample performance, which they admit is unrealistic. Thus, an interesting avenue for future work would be to design practical hyper-parameter selection procedures for this scenario. I wonder whether the *reverse cross-validation* approach we used in our work on adversarial domain adaptation (see [Section 5.1.2][5]) could successfully be used here.
Finally, I commend the authors for such a nicely written description of the counterfactual inference problem setup in general; I really enjoyed it!
[1]: https://papers.nips.cc/paper/2983-analysis-of-representations-for-domain-adaptation.pdf
[2]: http://arxiv.org/abs/1505.07818
[3]: http://ijcai.org/Proceedings/09/Papers/200.pdf
[4]: http://www.cs.nyu.edu/~mohri/pub/nadap.pdf
[5]: http://arxiv.org/pdf/1505.07818v4.pdf#page=16
arXiv e-Print archive - 2016 via Local Bibsonomy

Keywords: dblp
