Causality

Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning Algorithms

밤 편지 2025. 2. 25. 07:29

https://arxiv.org/abs/2101.10943

https://github.com/AliciaCurth/CATENets

(AISTATS 2021)

 

https://www.youtube.com/watch?v=CkQCwh50SEk

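★ You really have to watch the video!! It clearly explains concepts that the paper itself does not explain, which I therefore could not understand from the paper alone. ★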

e.g. pseudo-outcomes, etc.


Abstract

The need to evaluate treatment effectiveness is ubiquitous in most of empirical science, and interest in flexibly investigating effect heterogeneity is growing rapidly. To do so, a multitude of model-agnostic, nonparametric meta-learners have been proposed in recent years. Such learners decompose the treatment effect estimation problem into separate sub-problems, each solvable using standard supervised learning methods. Choosing between different meta-learners in a data-driven manner is difficult, as it requires access to counterfactual information. Therefore, with the ultimate goal of building better understanding of the conditions under which some learners can be expected to perform better than others a priori, we theoretically analyze four broad meta-learning strategies which rely on plug-in estimation and pseudo-outcome regression. We highlight how this theoretical reasoning can be used to guide principled algorithm design and translate our analyses into practice by considering a variety of neural network architectures as base-learners for the discussed meta-learning strategies. In a simulation study, we showcase the relative strengths of the learners under different data-generating processes.


1. Introduction

Many empirical scientists ultimately aim to assess the causal effects of interventions, policies and treatments by analyzing experimental or observational data using tools from applied statistics. Due to the impressive performance of machine learning (ML) methods on prediction tasks, recent years have seen exciting developments incorporating ML into the estimation of average treatment effects (van der Laan and Rose, 2011; Chernozhukov et al., 2017). While average treatment effects (ATEs) have been the main estimand of interest thus far, data-adaptive, ML-based estimators have even more potential to shape our ability to flexibly investigate heterogeneity of effects across populations. As interest moves towards personalized policy and treatment design in fields such as econometrics and medicine, the need to accurately estimate the full conditional average treatment effect (CATE) function becomes ubiquitous.

 

The causal inference communities across disciplines have produced a rapidly growing number of algorithms for CATE estimation in recent years (see e.g. Bica et al. (2020) for an overview). In practice, this leads to the need to select the best model – which is notoriously difficult in treatment effect studies because the ground truth is unobserved. While recent literature has presented promising solutions using data-driven strategies (Rolling and Yang, 2014; Alaa and Van Der Schaar, 2019), we believe that it is equally important to reduce the complexity of the selection task a priori by building greater systematic understanding of the strengths and weaknesses of different algorithms from a theoretical viewpoint.

 

Here, we put our focus on comparing different so-called meta-learners for binary treatment effect estimation, which are model-agnostic algorithms that decompose the task of estimating CATE into multiple sub-problems, each solvable using any supervised learning / regression method (Künzel et al., 2019). In the theoretical part of this paper (Sections 3 and 4), we consider estimation within a generic nonparametric regression framework, i.e. we assume no known parametric structure, and derive theoretical arguments why one learner may outperform others. In the more practical part (Sections 5 and 6), we compare the empirical performance of the different learners using the same underlying machine learning method, and consider a variety of neural network (NN) architectures for CATE estimation. Throughout, instead of arguing that one learning algorithm is superior to all others, we aim to highlight how expert knowledge on the underlying data-generating process (DGP) can narrow the choice of algorithms a priori and guide model design.

Contributions

Our contributions are three-fold: First, we provide theoretical insights into nonparametric CATE estimation using meta-learners. We propose a new classification of meta-learners inspired by the ATE estimator taxonomy, categorizing algorithms into four broad classes: one-step plug-in learners and three types of two-step learners, which use unbiased pseudo-outcomes based on regression adjustment (RA), propensity weighting (PW) or both (DR), as illustrated in Figure 1. We present an analysis of the theoretical properties of the learners and discuss resulting theoretical criteria for choosing between them. While both plug-in and DR-learner have been previously analyzed, our analysis and discussion of RA- and PW-learner are – to the best of our knowledge – new.

 

Second, we compare four existing model architectures for CATE estimation using NNs and propose a new architecture which generalizes existing approaches. These architectures allow for different degrees of information sharing between nuisance parameter estimators, and we highlight the (dis-)advantages of different architectures. We also provide a suite of sklearn-style implementations for all architectures and meta-learners we consider.

 

Third, we illustrate our theoretical arguments in simulation experiments, demonstrating how differences in DGPs and sample size influence the relative performance of different learners empirically. By considering how to best combine model architectures and meta-learner strategies, we also attempt to bridge the gap between the relatively disjoint literatures on meta-learners and end-to-end CATE estimation using NNs.


1.1. Related Work

We restrict our attention to so-called ‘meta-learners’ for CATE: model-agnostic algorithms that can be implemented using any arbitrary ML method. Künzel et al. (2019) appear to be the first paper explicitly discussing meta-learning strategies for CATE estimation, naming and considering three of them in detail: the S-learner (single learner), in which the treatment indicator is simply included as an additional feature in otherwise standard regression; the T-learner (two learners), which fits separate regression functions for each treatment group and then takes differences; and the X-learner, a two-step regression estimator that uses each observation twice (see discussion in Section 3). Nie and Wager (2020) propose a two-step algorithm that estimates CATE using orthogonalization with respect to the nuisance functions, which they dub R-learner as it is based on Robinson (1988). Finally, Kennedy (2020) proposes the DR-learner (doubly robust learner), a two-step algorithm that uses the expression for the doubly robust augmented inverse propensity weighted (AIPW) estimator (Robins and Rotnitzky, 1995) as a pseudo-outcome in a two-step regression set-up. With the exception of the DR-learner, the current naming strategy of meta-learners is surprisingly disjoint from the naming of ATE estimators, resulting in names that do not necessarily reflect the statistical concepts the learners are based on. To build better intuition and to facilitate principled theoretical analyses, we re-categorize meta-learning strategies in Section 3.
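As a minimal sketch of the S- and T-learner described above (our own illustration, not code from the CATENets repository), the hypothetical helpers `fit_ols`, `s_learner` and `t_learner` below use ordinary least squares as the base learner, with each covariate vector given as a list of floats. With a linear base learner and no interaction terms, the S-learner can only return a constant effect, whereas the T-learner fits each treatment arm separately and can recover a heterogeneous τ(x).

```python
from typing import Callable, List, Sequence

Features = Sequence[float]
Predictor = Callable[[Features], float]


def fit_ols(features: List[Features], y: List[float]) -> Predictor:
    """Hypothetical base learner: least squares with intercept via the normal equations."""
    rows = [[1.0, *f] for f in features]  # prepend an intercept column
    p = len(rows[0])
    a = [[sum(r[i] * r[j] for r in rows) for j in range(p)] for i in range(p)]  # X^T X
    b = [sum(r[i] * yi for r, yi in zip(rows, y)) for i in range(p)]  # X^T y
    # Gaussian elimination with partial pivoting, then back-substitution.
    for col in range(p):
        piv = max(range(col, p), key=lambda r: abs(a[r][col]))
        a[col], a[piv] = a[piv], a[col]
        b[col], b[piv] = b[piv], b[col]
        for r in range(col + 1, p):
            factor = a[r][col] / a[col][col]
            for c in range(col, p):
                a[r][c] -= factor * a[col][c]
            b[r] -= factor * b[col]
    beta = [0.0] * p
    for i in reversed(range(p)):
        beta[i] = (b[i] - sum(a[i][j] * beta[j] for j in range(i + 1, p))) / a[i][i]
    return lambda f: beta[0] + sum(bj * fj for bj, fj in zip(beta[1:], f))


def s_learner(x: List[Features], w: List[int], y: List[float], fit=fit_ols) -> Predictor:
    # One model with the treatment indicator appended as an extra feature.
    mu = fit([[*xi, float(wi)] for xi, wi in zip(x, w)], y)
    return lambda xi: mu([*xi, 1.0]) - mu([*xi, 0.0])


def t_learner(x: List[Features], w: List[int], y: List[float], fit=fit_ols) -> Predictor:
    # Separate models per treatment arm; CATE is the difference of the two fits.
    mu1 = fit([xi for xi, wi in zip(x, w) if wi == 1], [yi for yi, wi in zip(y, w) if wi == 1])
    mu0 = fit([xi for xi, wi in zip(x, w) if wi == 0], [yi for yi, wi in zip(y, w) if wi == 0])
    return lambda xi: mu1(xi) - mu0(xi)


# Noise-free toy data with a heterogeneous effect tau(x) = 3 + x.
x = [[i / 10] for i in range(20)]
w = [i % 2 for i in range(20)]
y = [1 + 2 * xi[0] + wi * (3 + xi[0]) for xi, wi in zip(x, w)]

tau_t = t_learner(x, w, y)
tau_s = s_learner(x, w, y)
print(round(tau_t([0.5]), 6))  # T-learner recovers tau(0.5) = 3.5
print(round(tau_s([0.0]), 6) == round(tau_s([1.0]), 6))  # linear S-learner: constant effect
```

With a flexible base learner (e.g. a neural network or random forest) the S-learner is no longer restricted to a constant effect; the linear case only makes the difference between the two strategies easy to see.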

 

While our theoretical analyses of meta-learners apply to generic estimators, in the practical part of this paper we focus on instantiations using standard feed-forward neural networks (NNs) and NN-based representation learning.


2. Problem Definition


3. Categorizing CATE Meta-Learners

To improve our ability to analyze the meta-learners' theoretical performance in a structured manner in the following section, we propose a high-level classification of CATE meta-learners that follows the well-known ATE taxonomy of estimators (see e.g. Imbens (2004)). We prefer this naming strategy because it builds on existing intuition and classifies learners by the characteristics that reflect their statistical properties. Therefore, we suggest dividing meta-learners into one-step plug-in learners (learners that output two regression functions, which can then be differenced) and two-step learners based on a regression adjustment (RA), propensity weighting (PW) or doubly robust (DR) strategy, which output a CATE function directly. These strategies are illustrated in Figure 1.
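To make the three two-step strategies concrete, here is a minimal Python sketch of the pseudo-outcomes we understand them to regress on, written in the standard regression-adjustment, inverse-propensity-weighting and AIPW forms. The function names are hypothetical, and `mu0_hat`, `mu1_hat` and `pi_hat` stand for first-step estimates of µ0(x), µ1(x) and π(x) at a single observation. In the second step, each learner would regress its pseudo-outcome on X with any supervised method.

```python
def ra_pseudo_outcome(y: float, w: int, mu0_hat: float, mu1_hat: float) -> float:
    # Regression adjustment: impute the unobserved potential outcome with its regression estimate.
    return w * (y - mu0_hat) + (1 - w) * (mu1_hat - y)


def pw_pseudo_outcome(y: float, w: int, pi_hat: float) -> float:
    # Propensity weighting: inverse-propensity-weighted observed outcome.
    return (w / pi_hat - (1 - w) / (1 - pi_hat)) * y


def dr_pseudo_outcome(y: float, w: int, mu0_hat: float, mu1_hat: float, pi_hat: float) -> float:
    # Doubly robust (AIPW): regression difference plus an inverse-propensity-weighted residual correction.
    return (mu1_hat - mu0_hat
            + w * (y - mu1_hat) / pi_hat
            - (1 - w) * (y - mu0_hat) / (1 - pi_hat))


# Toy unit with potential outcomes y0 = 1.0, y1 = 3.0 (true effect 2.0) and propensity 0.25.
y0, y1, pi = 1.0, 3.0, 0.25
# With exact nuisances and no outcome noise, RA and DR return the effect for either arm.
print(ra_pseudo_outcome(y1, 1, y0, y1), ra_pseudo_outcome(y0, 0, y0, y1))
print(dr_pseudo_outcome(y1, 1, y0, y1, pi), dr_pseudo_outcome(y0, 0, y0, y1, pi))
# PW is only unbiased on average over treatment assignment W ~ Bernoulli(pi).
print(pi * pw_pseudo_outcome(y1, 1, pi) + (1 - pi) * pw_pseudo_outcome(y0, 0, pi))
```

Note that setting `mu0_hat = mu1_hat = 0` turns the DR pseudo-outcome into the PW one, and the correction terms vanish when the regression estimates are exact, which is one way to see that the DR strategy combines the other two.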


5. CATE Estimation Using Neural Networks

In the practical part of this paper, we use feed-forward NNs as nuisance estimators for each meta-learner. The simplest NN-based implementation of each learner consists of using a separate network for each regression task (a TNet), which would be a good choice asymptotically, allowing for arbitrarily different regression surfaces.

 

However, as we alluded to in Section 3, it can be useful to share information between nuisance estimation tasks in finite samples. That is, it may be more efficient to share data between the two regression tasks if µ1(x) and µ0(x) are similar, which would be the case if they were supported on similar covariates and τ(x) is not too complex. Therefore, next to a simple T-learner (TNet), we consider a class of model architectures we refer to as SNets because they are based on sharing information between nuisance estimation tasks using representation learning. The resulting one-step architectures for nuisance estimation are visualized in Fig. 2, and we discuss implementation details and loss functions in the supplement.

Existing SNet Architectures

Building on the success of representation learning on a variety of learning tasks (Bengio et al., 2013), Shalit et al. (2017) introduce the idea of learning a shared input representation for the two potential outcome regressions. Formally, this entails jointly learning a map Φ : X → R, representing all data in a new space, and two regression heads µ_w : R → Y, fit using only the data of the corresponding treatment group. This results in TARNet (Shalit et al., 2017), which we will also refer to as SNet-1 because it results in the simplest way of sharing information between tasks. Shi et al. (2019)’s DragonNet (SNet-2) takes this idea one step further and learns a representation space from which both π(x) and µ_w(x) can be learned, ensuring that confounders are sufficiently controlled for. While Shi et al. (2019) used this architecture, combined with ideas from the targeted maximum likelihood estimation framework, to estimate average treatment effects, we use only their model architecture for CATE estimation. Finally, Hassanpour and Greiner (2020)’s DR-CFR (SNet-3) learns three representations which are used to model either propensity score, potential outcome regressions or both, respectively.
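A minimal forward-pass sketch of the TARNet (SNet-1) idea described above, in plain Python: one shared ReLU layer plays the role of Φ, two linear heads model µ0 and µ1, and the factual loss uses only the head of each unit's observed treatment. The class name, layer sizes and weights are illustrative assumptions, and training (e.g. gradient descent on the factual loss) is omitted.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def dense(weights: Matrix, bias: List[float], x: Sequence[float]) -> List[float]:
    # Affine map: weights @ x + bias.
    return [sum(wij * xj for wij, xj in zip(row, x)) + bj for row, bj in zip(weights, bias)]


class TinyTARNet:
    """Hypothetical one-layer TARNet-style model: shared Phi plus two outcome heads."""

    def __init__(self, phi_w: Matrix, phi_b: List[float], head_w: Matrix, head_b: List[float]):
        self.phi_w, self.phi_b = phi_w, phi_b      # shared representation Phi: X -> R
        self.head_w, self.head_b = head_w, head_b  # head_w[t], head_b[t] define mu_t: R -> Y

    def representation(self, x: Sequence[float]) -> List[float]:
        return [max(0.0, v) for v in dense(self.phi_w, self.phi_b, x)]  # ReLU(Phi(x))

    def mu(self, x: Sequence[float], t: int) -> float:
        r = self.representation(x)
        return sum(a * b for a, b in zip(self.head_w[t], r)) + self.head_b[t]

    def cate(self, x: Sequence[float]) -> float:
        return self.mu(x, 1) - self.mu(x, 0)

    def factual_loss(self, xs, ws, ys) -> float:
        # Each unit contributes only through the head of its observed treatment arm,
        # while units from both arms would update the shared Phi during training.
        return sum((self.mu(x, w) - y) ** 2 for x, w, y in zip(xs, ws, ys)) / len(ys)


net = TinyTARNet(phi_w=[[1.0, 0.0], [0.0, 1.0]], phi_b=[0.0, 0.0],
                 head_w=[[1.0, 1.0], [2.0, 0.0]], head_b=[0.0, 0.5])
print(net.mu([1.0, 2.0], 0), net.mu([1.0, 2.0], 1), net.cate([1.0, 2.0]))
print(net.factual_loss([[1.0, 2.0], [1.0, 2.0]], [0, 1], [3.0, 2.0]))
```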

Underlying Assumptions and a New Architecture

Each existing architecture reflects implicit assumptions about the underlying structure of the CATE estimation problem. Therefore, characterizing these assumptions explicitly can help when choosing between architectures in practice, and can guide improvements by identifying shortcomings. SNet-1 builds on the assumption that there exists a common feature space underlying both {µ_w(x)}w∈{0,1}, while SNet-2 assumes that π(x) can also be represented in the same space. SNet-3 is built on the assumption that there exist three separate sets of features, determining µ_w(x) and/or π(x).

 

Reflecting on these assumptions, we note that one important case is missing: we wish to explicitly allow for the existence of features that affect only one of the potential outcome functions. This is motivated by medical applications, where one distinguishes between markers that are prognostic (of outcome) regardless of treatment status and markers that are predictive (of treatment effectiveness) (Ballman, 2015). Therefore, we propose a final model architecture that learns five representations, also allowing the potential-outcome regressions to depend on only a subset of shared features. We simply refer to this architecture as SNet because it encompasses all existing SNet architectures, and even TNet, as special cases. That is, if we increase the width of the µ_w(x)-specific representations while reducing the shared representations, SNet approaches TNet. If we do the reverse, the proposed architecture approaches SNet-3, which can in turn become either SNet-1 or SNet-2 by changing what is shared between µ_w(x) and π(x). We expect such a general architecture to perform best on average in the absence of knowledge about the underlying problem structure, but it may be more difficult to fit than simpler models.
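One way to read the five-representation SNet and its special cases is as a routing table from representations to prediction heads. The sketch below encodes an assumed routing; the names and the exact assignment are our interpretation of the description above, not taken from the paper's code. A confounder representation feeds all three heads, a propensity-only representation feeds π(x), one representation is shared by µ0 and µ1, and one is specific to each potential outcome. Setting a representation's width to zero removes it, which recovers the other architectures.

```python
from typing import Dict, List

# Assumed routing: which representations feed which prediction head.
ROUTING: Dict[str, List[str]] = {
    "mu0": ["confounders", "outcome_shared", "outcome_0"],
    "mu1": ["confounders", "outcome_shared", "outcome_1"],
    "pi": ["confounders", "propensity"],
}


def active_inputs(head: str, widths: Dict[str, int]) -> List[str]:
    # Representations with zero width are effectively switched off.
    return [rep for rep in ROUTING[head] if widths.get(rep, 0) > 0]


def shared_between(a: str, b: str, widths: Dict[str, int]) -> List[str]:
    return sorted(set(active_inputs(a, widths)) & set(active_inputs(b, widths)))


def input_dim(head: str, widths: Dict[str, int]) -> int:
    return sum(widths[rep] for rep in active_inputs(head, widths))


# Hypothetical widths illustrating the special cases discussed above.
snet = dict(confounders=50, outcome_shared=50, outcome_0=50, outcome_1=50, propensity=50)
tnet_like = dict(outcome_0=100, outcome_1=100)   # nothing shared (outcome part only)
snet3 = dict(confounders=50, outcome_shared=50, propensity=50)  # DR-CFR-style
snet2 = dict(confounders=100)                    # DragonNet-style: one space for all heads
snet1 = dict(outcome_shared=100)                 # TARNet-style: no propensity head

for name, widths in [("SNet", snet), ("TNet-like", tnet_like), ("SNet-3", snet3),
                     ("SNet-2", snet2), ("SNet-1", snet1)]:
    print(name, shared_between("mu0", "mu1", widths), input_dim("pi", widths))
```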

Further Practical Considerations

Two-step learners are most commonly implemented using independently trained “vanilla” nuisance estimators, yet, as we investigate in the experiments, all SNet architectures could also be used to estimate nuisance parameters in the first step. Further, while our theoretical analyses rely on sample splitting for two-step learners, we observed that using all data for both steps can work better in practice (particularly in small samples). If estimation with theoretical guarantees is desired yet data is scarce, it can be useful to rely not on sample splitting, but on cross-fitting (Chernozhukov et al., 2018) to obtain valid nuisance function estimates for all observations in the sample, which can then all be used for a second stage regression. Which strategy to use involves a trade-off between precision in estimation and computational complexity. While we implement all strategies in our code base, we do not use any form of sample splitting in our experiments. Finally, a convenient by-product of splitting the causal inference task into a series of multiple supervised learning tasks is that hyper-parameters can be tuned using factual hold-out data only – which can be done in both regression stages, if data is split appropriately.
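Cross-fitting, as referenced above, can be sketched in a few lines: split the sample into K folds, fit each nuisance model on all folds but one, and predict on the held-out fold, so that every observation receives an out-of-fold nuisance estimate usable in the second-stage regression. The helper names and the deterministic interleaved fold assignment are illustrative assumptions (in practice folds are usually assigned at random), and `fit_mean` is a deliberately trivial stand-in for any regression method.

```python
from typing import Callable, List, Sequence

Fit = Callable[[List[Sequence[float]], List[float]], Callable[[Sequence[float]], float]]


def cross_fit(xs: List[Sequence[float]], ys: List[float], fit: Fit, n_folds: int = 2) -> List[float]:
    """Return out-of-fold predictions: each unit is predicted by a model not trained on it."""
    n = len(ys)
    folds = [list(range(k, n, n_folds)) for k in range(n_folds)]  # interleaved fold assignment
    preds = [0.0] * n
    for held_out in folds:
        held = set(held_out)
        train = [i for i in range(n) if i not in held]
        model = fit([xs[i] for i in train], [ys[i] for i in train])
        for i in held_out:
            preds[i] = model(xs[i])
    return preds


def fit_mean(xs, ys):
    # Trivial stand-in learner: predicts the training mean everywhere.
    m = sum(ys) / len(ys)
    return lambda x: m


print(cross_fit([[0.0]] * 4, [1.0, 2.0, 3.0, 4.0], fit_mean, n_folds=2))
```

For a DR-learner, one would cross-fit µ0, µ1 and π in this way (µ_w on the units of arm w within the training folds) and then compute a pseudo-outcome for every observation, so that no data is discarded in the second stage.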


6. Experiments

In this section we supplement our theoretical analyses with experimental evidence to demonstrate the empirical performance of all learners under different DGPs, using both fully synthetic data and the well-known semi-synthetic IHDP benchmark. In addition to verifying the theoretical properties of the different learners empirically, we consider it of particular interest to gain insight into how to best use the different NN architectures as nuisance estimators for the two-step learners.

 

Throughout the experiments, we fixed equivalent hyper-parameters across all model architectures (based on those used in Shalit et al. (2017)), ensuring that every estimator (output head) has access to the same total number of hidden layers and units, and effectively used each learner ‘off-the-shelf’. To ensure fair comparison across learners, we implemented every architecture in our own Python code base; for implementation details, refer to the supplement. Throughout, we consider performance in terms of the Root Mean Squared Error (RMSE) of estimating τ(x), also sometimes referred to as the precision of estimating heterogeneous effects (PEHE) criterion (Hill, 2011).
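For reference, the evaluation criterion fits in a one-line function: the root mean squared error between true and estimated CATE on test points, which is only computable in (semi-)synthetic settings where τ(x) is known. The function name is our own.

```python
import math
from typing import Sequence


def pehe_rmse(tau_true: Sequence[float], tau_hat: Sequence[float]) -> float:
    # sqrt( (1/n) * sum_i (tau(x_i) - tau_hat(x_i))^2 )
    if len(tau_true) != len(tau_hat) or not tau_true:
        raise ValueError("need two non-empty sequences of equal length")
    return math.sqrt(sum((t - h) ** 2 for t, h in zip(tau_true, tau_hat)) / len(tau_true))


print(pehe_rmse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # sqrt(4/3)
```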


6.1. Synthetic Experiments

We simulate data to investigate the relative performance of the different learners across different underlying DGPs and sample sizes. Throughout, we consider d = 25 multivariate normal covariates, of which we let subsets determine µ_w(x) and π(x). To highlight scenarios under which different learners can be expected to perform well, we consider three (highly stylized) settings: (i) τ(x) = 0, and µ0(x) depends on 5 covariates influencing the outcome and 5 confounders (which also influence π(x)); (ii) the same set-up as (i), but with τ(x) nonzero and supported on 5 additional covariates; and (iii) µ1(x) and µ0(x) depend on disjoint covariate sets, making τ(x) the most difficult function to estimate (no confounders). All DGPs are discussed in detail in the supplement. In all simulations, we evaluate performance on 500 independently generated test observations and average across 10 runs.

 

When comparing the 5 plug-in architectures (Fig. 3), we note that S-architectures always improve upon the TNet in small samples when {µ_w}w∈{0,1} share some structure. Even when µ1(x) and µ0(x) are very different, shared layers can help by filtering out noise covariates. Further, the flexibility of the respective SNet architecture seems to indeed drive performance. As expected, the general SNet performs best on average, and is the only architecture to outperform TNet when µ1(x) and µ0(x) are very different. SNet performs well also in the absence of any treatment effect, which we attribute to the architecture’s ability to learn that there are no predictive features to represent using the µ_w(x)-specific representations. Further, as expected, the strength of SNet relative to all variants becomes most apparent in relatively larger sample sizes.

 

When comparing the four learning strategies (Fig. 4), all based on TNet in the first step, we observe that the DR-learner substantially outperforms the others when there is confounding but no treatment effect (setting (i)). Conversely, when CATE is more complex than either potential outcome function (setting (iii)) – making the task of learning CATE directly more difficult than learning the potential outcome functions separately – the plug-in learner performs best, as expected. The RA-learner performed best when there is both confounding and a non-trivial treatment effect (setting (ii)). The PW-learner performed poorly in general, which is caused by the very low signal-to-noise ratio and high variance in the associated pseudo-outcome. While all other learning strategies performed very well using equivalent architectures, optimizing the PW-learner would require much stronger regularization and less flexibility (smaller networks) than what we considered here.
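A stylized calculation helps explain why the PW pseudo-outcome is so noisy. For a single covariate value with fixed potential outcomes y0, y1 (no outcome noise) and known propensity π, the PW pseudo-outcome equals y1/π with probability π and −y0/(1−π) otherwise, so its variance is y1²/π + y0²/(1−π) − (y1 − y0)². This grows with the scale of the outcomes themselves rather than of the effect, and explodes as π approaches 0 or 1; under the same idealized assumptions, the RA pseudo-outcome with exact nuisances would have zero variance. This is our own illustration, not an analysis from the paper.

```python
def pw_variance(y0: float, y1: float, pi: float) -> float:
    """Exact variance over W ~ Bernoulli(pi) of (W/pi - (1-W)/(1-pi)) * Y with Y = y_W."""
    second_moment = pi * (y1 / pi) ** 2 + (1 - pi) * (y0 / (1 - pi)) ** 2
    mean = y1 - y0  # the PW pseudo-outcome is unbiased for the effect
    return second_moment - mean ** 2


# No treatment effect (y0 = y1 = 10), yet the pseudo-outcome is far from constant.
for pi in (0.5, 0.2, 0.05):
    print(pi, round(pw_variance(10.0, 10.0, pi), 2))
```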

 

To gain further insight into the relative performance of different learning strategies, we interpolate between settings (i) and (ii) in Fig. 5 by gradually increasing the number of predictive features (features determining τ(x)) at n = 2000. We observe that the performance of the different S-architectures degrades relative to TNet as the number of predictive features increases, which is to be expected as µ0(x) and µ1(x) become less similar. We also observe that while the performance gap between TNet and RA-learner remains virtually constant, the DR-learner loses its advantage as CATE becomes less sparse.

 

In Fig. 6 we reconsider setting (i) but with imbalance in addition to confounding, by gradually changing the proportion of treated individuals. Comparing the different plug-in architectures, we observe that information sharing has a much larger added value when samples are highly imbalanced. Comparing the different learners, we observe that the performance gap between T- and DR-learner is not impacted by the proportion of treated individuals, but that the RA-learner outperforms the T-learner only for moderate to no imbalance.

 

Finally, we consider how to best combine nuisance estimators and two-step learners (Fig. 7). We reconsider settings (i) and (ii), where DR- and RA-learner performed best, and use as nuisance estimators SNet-1, which does not estimate π(x), and the more general SNet, which does. For the RA-learner, we observe that using SNet, which has better performance on its own, leads to the best RA-learner, and conclude that in practice the RA-learner should be combined with an architecture that is expected to best capture the underlying DGP. For the DR-learner, we note that strong dependence between π^(x) and µ_w^(x) theoretically leads to a more slowly decaying remainder term, which manifests in poor empirical performance of the DR-learner on top of SNet relative to SNet-1 in smaller samples. Hence, the DR-learner is best combined with a nuisance estimator that does not share information between the estimation of π(x) and µ_w(x).


6.2. Semi-synthetic Benchmark: IHDP

Additionally, we deploy all algorithms on the well-known IHDP benchmark, based on real covariates with simulated outcomes. We use an adapted version of the 100 realizations provided by Shalit et al. (2017). The data-set is small (n = 747, of which 90% are used in training), imbalanced (19% treated), and there is only partial overlap (Hill, 2011). While the two potential outcome functions are supported on the same covariates, their functional forms are different, making their difference – CATE – the most difficult function to estimate. A more detailed description of the data-set can be found in the supplement.

 

In Table 1, we observe that information sharing in plug-in learners indeed significantly improves performance relative to the simple TNet, and that more complex models (SNet-3 & SNet) underperform simpler models (SNet-1 & 2). Neither observation is surprising given the small sample size, the high treatment group imbalance, and the fact that the DGP contains no true separation of covariates into different adjustment sets. Further, we observe that, with the exception of using the RA-learner on top of the best-performing plug-in model, the two-step learners underperform the plug-in learners on the IHDP data-set, a direct consequence of the complexity of the simulated τ(x). Finally, potentially because overlap is incomplete in this data-set, the RA-learner outperforms the DR-learner substantially.


7. Conclusion

In this paper we considered meta-learning strategies for nonparametric CATE estimation, theoretically analyzed their properties, and implemented them using a range of neural network architectures. We demonstrated that while the DR-learner is asymptotically optimal in theory, both the RA-learner and plug-in learners sharing information between nuisance estimation tasks can outperform it in finite samples. We also showed that using sophisticated architectures as nuisance estimators for two-step learners instead of vanilla NNs can boost their small sample performance. In addition, we highlighted that the relative performance of different learners and architectures depends both on the underlying DGP and the amount of data at hand, such that the choice of learner in practice should incorporate an expert’s assessment of the most likely DGP. While we considered only the choice between meta-learners using the same underlying method throughout, investigating the optimal choice of ML method – e.g. NNs versus random forests – would be an interesting next step.


C. Learning Algorithms and Implementation