The development of a high-quality, high-efficiency implementation of causal machine learning models with a standard, widely adopted API will allow for increased uptake of these methods, enabling more accurate causal prediction of outcomes influenced by covariates.

Adam Li (Columbia) and collaborator Joshua Vogelstein (Neurodata Lab, JHU)

Causal and Generalized Random Forests for Scikit-Learn

PI Adam Li (Columbia) and Josh Vogelstein (Neurodata Lab, JHU)

Classical machine learning seeks to automatically infer a functional relationship between a vector of observable variables and a continuous or categorical outcome. Given enough training data, it can be enormously successful at problems that can reliably be cast as prediction or classification across fixed populations. However, the standard paradigm is inadequate for predicting causal treatment effects at the level of the individual. Such predictions have a wide range of unserved applications in personalization and targeting decisions, such as who should receive a medical treatment, who should get a price discount, who should be targeted by a political campaign, and who should receive benefits from a government program. Because standard methods make no distinction between the covariates describing subgroup traits and the treatment variable whose effect we want to estimate, the treatment effect is absorbed into the prediction of the broader covariate-response pattern. As a result, predictions from standard methods cannot distinguish effects due to traits from effects due to treatment.
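
To make the trait/treatment distinction concrete, here is a toy illustration on hypothetical simulated data; the functions `simulate`, `average_effect`, and `conditional_effects` are illustrative names, not part of the project. A binary trait `x` shifts the outcome on its own, while a randomized binary treatment `t` helps one subgroup (+1) and harms the other (-1). The population-wide treated-minus-control contrast comes out near zero and hides both effects; computing the same contrast separately within each trait subgroup recovers them. Stratifying like this is trivial with one discrete trait, but it does not scale to many continuous covariates.

```python
import random
from statistics import mean


def simulate(n, seed=0):
    """Hypothetical toy data: binary trait x, randomized binary treatment t.

    The trait raises the outcome by 2 on its own; the treatment raises it
    by 1 when x == 1 and lowers it by 1 when x == 0.
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        x = rng.randint(0, 1)
        t = rng.randint(0, 1)  # randomized, independent of x
        tau = 1.0 if x == 1 else -1.0  # true individual-level effect
        y = 2.0 * x + tau * t + rng.gauss(0.0, 0.1)
        rows.append((x, t, y))
    return rows


def average_effect(rows):
    """Treated-minus-control difference in mean outcome over the given rows."""
    treated = [y for _, t, y in rows if t == 1]
    control = [y for _, t, y in rows if t == 0]
    return mean(treated) - mean(control)


def conditional_effects(rows):
    """The same contrast computed separately within each trait subgroup."""
    return {g: average_effect([r for r in rows if r[0] == g]) for g in (0, 1)}


rows = simulate(20_000, seed=1)
print(round(average_effect(rows), 2))  # near 0: the subgroup effects cancel
print({g: round(v, 2) for g, v in conditional_effects(rows).items()})  # roughly -1 and +1
```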

Recently, a number of new machine learning models based on decision trees and random forests have been introduced that formalize this distinction and enable true prediction of individual causal treatment effects. These models, and their ongoing development, are expected to open up a wide new vista of personalization applications that have so far been out of reach of machine learning methods. A significant remaining barrier to the uptake of these methods by the wider community is the lack of a high-quality, high-efficiency implementation that follows a standard, widely adopted API. This project closes that gap.
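
The core idea behind tree-based causal methods is that splits are chosen to separate units whose treatment effects differ, rather than to predict the outcome itself. A minimal sketch of that idea follows, assuming a single split (a "causal stump") scored by a simple heterogeneity criterion, n_left * tau_left**2 + n_right * tau_right**2, in the spirit of causal trees. The class `CausalStump`, its `fit(X, y, t)` signature, and the helpers `effect` and `make_toy_data` are hypothetical; they only mimic scikit-learn conventions (`fit` returns `self`, learned attributes end in an underscore). This is not the project's implementation: practical causal and generalized random forests typically add honest sample splitting, subsampling, deeper trees, and averaging over many trees.

```python
import random
from statistics import mean


def effect(y, t):
    """Treated-minus-control mean outcome, or None if an arm is empty."""
    treated = [yi for yi, ti in zip(y, t) if ti == 1]
    control = [yi for yi, ti in zip(y, t) if ti == 0]
    if not treated or not control:
        return None
    return mean(treated) - mean(control)


class CausalStump:
    """Hypothetical one-split 'causal tree' with scikit-learn-style methods.

    A regression tree splits to reduce error in predicting y. This stump
    instead picks the (feature, threshold) whose two children have the most
    different treatment effects, scoring a split by
    n_left * tau_left**2 + n_right * tau_right**2.
    """

    def __init__(self, min_leaf=10):
        self.min_leaf = min_leaf  # smallest allowed child size

    def fit(self, X, y, t):
        best = None
        for j in range(len(X[0])):
            for thr in sorted({row[j] for row in X}):
                left = [i for i, row in enumerate(X) if row[j] <= thr]
                right = [i for i, row in enumerate(X) if row[j] > thr]
                if min(len(left), len(right)) < self.min_leaf:
                    continue
                taus = [effect([y[i] for i in idx], [t[i] for i in idx])
                        for idx in (left, right)]
                if None in taus:  # a child lacks treated or control units
                    continue
                score = len(left) * taus[0] ** 2 + len(right) * taus[1] ** 2
                if best is None or score > best[0]:
                    best = (score, j, thr, taus[0], taus[1])
        if best is None:
            raise ValueError("no split satisfies min_leaf with both arms")
        (_, self.feature_, self.threshold_,
         self.left_effect_, self.right_effect_) = best
        return self

    def predict(self, X):
        """Estimated treatment effect for each row (its leaf's effect)."""
        return [self.left_effect_ if row[self.feature_] <= self.threshold_
                else self.right_effect_ for row in X]


def make_toy_data(n, seed=0):
    """Hypothetical data: x1 shifts the outcome, x0 flips the effect's sign."""
    rng = random.Random(seed)
    X, y, t = [], [], []
    for _ in range(n):
        x0, x1 = rng.random(), rng.random()
        ti = rng.randint(0, 1)  # randomized treatment
        tau = 1.0 if x0 > 0.5 else -1.0  # true effect depends only on x0
        X.append([x0, x1])
        t.append(ti)
        y.append(x1 + tau * ti + rng.gauss(0.0, 0.1))
    return X, y, t
```

In the toy data, `x1` moves the outcome but not the effect, while `x0` flips the sign of the effect at 0.5. Fitting `CausalStump(min_leaf=20).fit(*make_toy_data(400))` should therefore split on feature 0 near 0.5, and `predict` should return effects near -1 for rows with small `x0` and near +1 for rows with large `x0`.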