import logging
import numpy as np
import pandas as pd
from dowhy.utils.api import parse_state
class DoSampler:
    """Base class for a sampler from the interventional distribution.

    Subclasses implement ``disrupt_causes`` and either ``_sample_point`` (row-at-a-time sampling, used when
    ``point_sampler`` is ``True``, the default) or provide a ``sampler`` object with a ``sample`` method and set
    ``point_sampler = False``.
    """

    def __init__(self, data,
                 params=None, variable_types=None,
                 num_cores=1, causal_model=None, keep_original_treatment=False):
        """
        Initializes a do sampler with data and names of relevant variables.

        Do sampling implements the do() operation from Pearl (2000). This operation is defined on a causal
        bayesian network, an explicit implementation of which is the basis for the MCMC sampling method.

        We abstract the idea behind the three-step process to allow other methods, as well. The `disrupt_causes`
        method is the means to make treatment assignment ignorable. In the Pearlian framework, this is where we cut the
        edges pointing into the causal state. With other methods, this will typically be by using some approach which
        assumes conditional ignorability (e.g. weighting, or explicit conditioning with Robins G-formula.)

        Next, the `make_treatment_effective` method reflects the assumption that the intervention we impose is
        "effective". Most simply, we fix the causal state to some specific value. This step is skipped when
        `keep_original_treatment` is True, in which case the original treatment values are used instead.

        Finally, we sample from the resulting distribution. This can be either from a `point_sample` method, in the case
        that the inference method doesn't support batch sampling, or the `sample` method in the case that it does. For
        convenience, the `point_sample` method parallelizes with `multiprocessing` using the `num_cores` kwarg to set
        the number of cores to use for parallelization.

        While different methods will have their own class attributes, the `_df` attribute should be common to all
        methods. This is the temporary dataset which starts as a copy of the original data, and is modified to reflect
        the steps of the do operation. Read through the existing methods (weighting is likely the most minimal) to get
        an idea of how this works to implement one yourself.

        :param data: pandas.DataFrame containing the data. It is copied; the original is never modified.
        :param params: (optional) dict of additional method parameters. Each key/value pair is set as an attribute
            on the sampler after the basic defaults (`num_cores`, `point_sampler`, `sampler`, ...), so it may
            override them.
        :param variable_types: dict: A dictionary containing the variable's names and types. 'c' for continuous, 'o'
            for ordered, 'd' for discrete, and 'u' for unordered discrete. If omitted, `_infer_variable_types` is
            called, which raises `NotImplementedError` in this base class.
        :param num_cores: int: number of processes used by `point_sample`; 1 means sample serially.
        :param causal_model: the causal model (required). The identified estimand, the treatment names and the
            outcome names are all derived from it (`identify_effect()`, `_treatment`, `_outcome`).
        :param keep_original_treatment: bool: Whether to use `make_treatment_effective`, or to keep the original
            treatment assignments.
        :raises ValueError: if `causal_model` is not given.
        """
        if causal_model is None:
            # Everything below (estimand, treatments, outcomes) is derived from the model, so fail early and clearly.
            raise ValueError("DoSampler requires a causal_model to identify the estimand, treatments and outcomes.")
        self._data = data.copy()
        self._causal_model = causal_model
        self._target_estimand = self._causal_model.identify_effect()
        self._treatment_names = parse_state(self._causal_model._treatment)
        self._outcome_names = parse_state(self._causal_model._outcome)
        self._estimate = None
        self._variable_types = variable_types
        self.num_cores = num_cores
        self.point_sampler = True
        self.sampler = None
        self.keep_original_treatment = keep_original_treatment
        if params is not None:
            for key, value in params.items():
                setattr(self, key, value)
        self._df = self._data.copy()
        if not self._variable_types:
            self._infer_variable_types()
        self.dep_type = [self._variable_types[var] for var in self._outcome_names]
        self.indep_type = [self._variable_types[var] for var in
                           self._treatment_names + self._target_estimand.backdoor_variables]
        self.density_types = [self._variable_types[var] for var in self._target_estimand.backdoor_variables]
        self.outcome_lower_support = self._data[self._outcome_names].min().values
        self.outcome_upper_support = self._data[self._outcome_names].max().values
        self.logger = logging.getLogger(__name__)

    def _sample_point(self, x_z):
        """
        Override this if your sampling method only allows sampling a point at a time.

        :param x_z: numpy.array: the values of x and z for one row, in the order of the list given by
            self._treatment_names + self._target_estimand.backdoor_variables
        :return: a scalar or numpy.array holding one sampled value per outcome variable
        """
        raise NotImplementedError

    def reset(self):
        """
        Restore `_df` to a fresh copy of the original data.

        If your `DoSampler` has more attributes than the `_df` attribute, you should reset them all to their
        initialization values by overriding this method.

        :return: None
        """
        self._df = self._data.copy()

    def make_treatment_effective(self, x):
        """
        Set the treatment columns of `_df` to `x`, unless `keep_original_treatment` is True.

        This is more likely the implementation you'd like to use, but some methods may require overriding this method
        to make the treatment effective.

        :param x: the value(s) assigned to the treatment columns.
        :return: None
        """
        if not self.keep_original_treatment:
            self._df[self._treatment_names] = x

    def disrupt_causes(self):
        """
        Override this method to render treatment assignment conditionally ignorable.

        :return: None
        """
        raise NotImplementedError

    def _treatment_backdoor_values(self):
        """Return the treatment + backdoor columns of `_df` (in that order) as a numpy array."""
        return self._df[self._treatment_names + self._target_estimand.backdoor_variables].values

    def _store_outcomes(self, sampled_outcomes):
        """
        Write sampled outcomes into the outcome columns of `_df`.

        Values are assigned positionally (one entry per row of `_df`), so the result is correct even if `_df`
        carries a non-default index (e.g. if a subclass reindexes or resamples it).

        :param sampled_outcomes: array-like with one entry per row of `_df`; each entry is a scalar or an array
            holding one value per outcome variable.
        :raises ValueError: if the number of sampled values does not match rows x outcome variables.
        """
        values = np.reshape(np.asarray(sampled_outcomes), (len(self._df), len(self._outcome_names)))
        for position, outcome_name in enumerate(self._outcome_names):
            self._df[outcome_name] = values[:, position]

    def point_sample(self):
        """
        Sample outcomes one row at a time and write them into the outcome columns of `_df`.

        Each row of the treatment + backdoor columns is passed as a numpy array to the point sampler. With
        `num_cores == 1` this is `self._sample_point`. Otherwise the rows are distributed over a
        `multiprocessing.Pool`, using `self.sampler.sample_point` when a `sampler` object is set and falling back to
        `self._sample_point` otherwise; the callable must be picklable in that case.

        :return: None; `_df` is updated in place.
        """
        x_z = self._treatment_backdoor_values()
        if self.num_cores == 1:
            sampled_outcomes = [self._sample_point(row) for row in x_z]
        else:
            from multiprocessing import Pool
            point_fn = self._sample_point if self.sampler is None else self.sampler.sample_point
            # The context manager ensures the worker processes are shut down instead of leaking.
            with Pool(self.num_cores) as pool:
                sampled_outcomes = pool.map(point_fn, x_z)
        self._store_outcomes(sampled_outcomes)

    def sample(self):
        """
        Sample outcomes for all rows at once and write them into the outcome columns of `_df`.

        By default, this expects a sampler to be built on class initialization which contains a `sample` method; it is
        called with the treatment + backdoor columns as a numpy array. Override this method if you want to use a
        different approach to sampling.

        :return: None; `_df` is updated in place.
        """
        self._store_outcomes(self.sampler.sample(self._treatment_backdoor_values()))

    def do_sample(self, x):
        """
        Run the full do() operation and return the resulting dataset.

        Resets `_df`, calls `disrupt_causes`, applies `make_treatment_effective(x)`, then samples outcomes with
        `point_sample` if `self.point_sampler` is truthy, else with `sample`.

        :param x: the value(s) to assign to the treatment columns.
        :return: pandas.DataFrame: the modified `_df` containing the sampled outcomes.
        """
        self.reset()
        self.disrupt_causes()
        self.make_treatment_effective(x)
        if self.point_sampler:
            self.point_sample()
        else:
            self.sample()
        return self._df

    def _infer_variable_types(self):
        """Infer variable types from data; not supported by the base class."""
        raise NotImplementedError('Variable type inference not implemented. Use the variable_types kwarg.')