import logging
import numpy as np
import sympy as sp
class CausalEstimator:
    """Base class for an estimator of causal effect.

    Subclasses implement different estimation methods. All estimation methods
    are in the package "dowhy.causal_estimators".

    Subclasses are expected to override ``_estimate_effect`` (and optionally
    ``_do`` and ``construct_symbolic_estimator``); the base implementations
    raise ``NotImplementedError``.
    """

    def __init__(self, data, identified_estimand, treatment, outcome,
                 control_value=0, treatment_value=1,
                 test_significance=False, evaluate_effect_strength=False,
                 confidence_intervals=False,
                 target_units=None, effect_modifiers=None,
                 params=None):
        """Initializes an estimator with data and names of relevant variables.

        :param data: data frame containing the data
        :param identified_estimand: probability expression
            representing the target identified estimand to estimate.
        :param treatment: name of the treatment variable
        :param outcome: name of the outcome variable; only its first element
            is used (one-dimensional outcome is assumed)
        :param control_value: Value of the treatment in the control group, for effect estimation.  If treatment is multi-variate, this can be a list.
        :param treatment_value: Value of the treatment in the treated group, for effect estimation. If treatment is multi-variate, this can be a list.
        :param test_significance: whether to test significance
        :param evaluate_effect_strength: (Experimental) whether to evaluate the strength of effect
        :param confidence_intervals: (Experimental) Binary flag indicating whether confidence intervals should be computed.
        :param target_units: (Experimental) The units for which the treatment effect should be estimated. This can be a string for common specifications of target units (namely, "ate", "att" and "atc"). It can also be a lambda function that can be used as an index for the data (pandas DataFrame). Alternatively, it can be a new DataFrame that contains values of the effect_modifiers and effect will be estimated only for this new data.
        :param effect_modifiers: variables on which to compute separate effects, or return a heterogeneous effect function. Not all methods support this currently.
        :param params: (optional) additional method parameters; every
            key/value pair is also set as an attribute on the instance
        :returns: an instance of the estimator class.
        """
        self._data = data
        self._target_estimand = identified_estimand
        # Currently estimation methods only support univariate treatment and outcome
        self._treatment_name = treatment
        self._outcome_name = outcome[0]  # assuming one-dimensional outcome
        self._control_value = control_value
        self._treatment_value = treatment_value
        self._significance_test = test_significance
        self._effect_strength_eval = evaluate_effect_strength
        self._target_units = target_units
        self._effect_modifier_names = effect_modifiers
        self._confidence_intervals = confidence_intervals
        self._estimate = None
        self._effect_modifiers = None
        self.method_params = params
        # Method parameters are exposed directly as instance attributes.
        if params is not None:
            for key, value in params.items():
                setattr(self, key, value)
        # Fall back to the default when params did not provide a value.
        if not hasattr(self, 'num_simulations'):
            self.num_simulations = 1000
        self.logger = logging.getLogger(__name__)
        # Cache the treatment and outcome columns when data is available.
        if self._data is not None:
            self._treatment = self._data[self._treatment_name]
            self._outcome = self._data[self._outcome_name]
        # Now saving the effect modifiers
        if self._effect_modifier_names:
            self._effect_modifiers = self._data[self._effect_modifier_names]
            self.logger.debug("Effect modifiers: " +
                              ",".join(self._effect_modifier_names))

    def _estimate_effect(self):
        """Compute the estimate; must be overridden by subclasses."""
        raise NotImplementedError

    def estimate_effect(self):
        """Base estimation method that calls the _estimate_effect method of its calling subclass.

        Can optionally also test significance and estimate effect strength for
        any returned estimate, depending on the flags given at construction.
        The estimate is also stored on the instance as ``_estimate``.

        TODO: Enable methods to return a confidence interval in addition to the point estimate.

        :returns: the estimate object returned by ``_estimate_effect``, with
            significance and effect-strength results attached when requested
        """
        est = self._estimate_effect()
        self._estimate = est
        if self._significance_test:
            signif_dict = self.test_significance(est)
            est.add_significance_test_results(signif_dict)
        if self._effect_strength_eval:
            effect_strength_dict = self.evaluate_effect_strength(est)
            est.add_effect_strength(effect_strength_dict)
        return est

    def estimate_effect_naive(self):
        """Return the difference in mean outcome between treated and untreated rows.

        Rows are selected by comparing the treatment column to the literal
        values 1 and 0.

        :returns: a CausalEstimate holding the difference of means
        """
        # TODO Only works for binary treatment
        treatment_column = self._data[self._treatment_name]
        df_withtreatment = self._data.loc[treatment_column == 1]
        df_notreatment = self._data.loc[treatment_column == 0]
        est = (np.mean(df_withtreatment[self._outcome_name])
               - np.mean(df_notreatment[self._outcome_name]))
        return CausalEstimate(est, None, None)

    def _do(self, x):
        """Implementation hook for the do-operator; must be overridden by subclasses."""
        raise NotImplementedError

    def do(self, x):
        """Method that implements the do-operator.

        Given a value x for the treatment, returns the expected value of the outcome when the treatment is intervened to a value x.

        :param x: Value of the treatment
        :returns: Value of the outcome when treatment is intervened/set to x.
        """
        return self._do(x)

    def construct_symbolic_estimator(self, estimand):
        """Build a symbolic form of the estimator; must be overridden by subclasses."""
        raise NotImplementedError

    def test_significance(self, estimate, num_simulations=None):
        """Test statistical significance of obtained estimate.

        Uses resampling to create a non-parametric significance test: the
        cached outcome is randomly permuted, the effect is re-estimated, and
        the obtained estimate is compared against the resulting null
        distribution. The cached outcome is restored before returning, even
        if estimation raises.

        A general procedure. Individual estimators can override this method.

        :param estimate: obtained estimate (object with a ``value`` attribute)
        :param num_simulations: (optional) number of simulations to run;
            defaults to ``self.num_simulations``
        :returns: dict with keys ``'p_value'`` and ``'sorted_null_estimates'``
        """
        if num_simulations is None:
            num_simulations = self.num_simulations
        null_estimates = np.zeros(num_simulations)
        # Keep a handle on the real outcome so the estimator is not left
        # holding shuffled data once the test finishes.
        original_outcome = self._outcome
        try:
            for i in range(num_simulations):
                self._outcome = np.random.permutation(original_outcome)
                null_estimates[i] = self._estimate_effect().value
        finally:
            self._outcome = original_outcome
        sorted_null_estimates = np.sort(null_estimates)
        self.logger.debug("Null estimates: {0}".format(sorted_null_estimates))
        median_estimate = sorted_null_estimates[int(num_simulations / 2)]
        # Doing a two-sided test. The branches are exhaustive so that p_value
        # is always bound, whatever the comparison result.
        if estimate.value > median_estimate:
            # Being conservative with the p-value reported
            estimate_index = np.searchsorted(sorted_null_estimates, estimate.value, side="left")
            p_value = 1 - (estimate_index / num_simulations)
        else:
            # Being conservative with the p-value reported
            estimate_index = np.searchsorted(sorted_null_estimates, estimate.value, side="right")
            p_value = (estimate_index / num_simulations)
        signif_dict = {
            'p_value': p_value,
            'sorted_null_estimates': sorted_null_estimates
        }
        return signif_dict

    def evaluate_effect_strength(self, estimate):
        """Evaluate the strength of the estimated effect.

        :param estimate: obtained estimate (object with a ``value`` attribute)
        :returns: dict with the single key ``'fraction-effect'``
        """
        fraction_effect_explained = self._evaluate_effect_strength(estimate, method="fraction-effect")
        # NOTE: an 'r-squared' measure needs testing before it can be supported.
        strength_dict = {
            'fraction-effect': fraction_effect_explained
        }
        return strength_dict

    def _evaluate_effect_strength(self, estimate, method="fraction-effect"):
        """Compute one effect-strength measure for the estimate.

        :param estimate: obtained estimate (object with a ``value`` attribute)
        :param method: name of the measure; only "fraction-effect" is supported
        :returns: ratio of the estimate to the naive difference-of-means estimate
        :raises NotImplementedError: if ``method`` is not supported
        """
        supported_methods = ["fraction-effect"]
        if method not in supported_methods:
            raise NotImplementedError("This method is not supported for evaluating effect strength")
        # Only "fraction-effect" can reach this point.
        naive_obs_estimate = self.estimate_effect_naive()
        # Logged instead of printed so library users control the output.
        self.logger.debug("Estimate: %s, naive estimate: %s",
                          estimate.value, naive_obs_estimate.value)
        return estimate.value / naive_obs_estimate.value
class CausalEstimate:
    """Class for the estimate object that every causal estimator returns.

    Holds the estimated value, the target and realized estimands, any extra
    keyword parameters, and optional significance-test and effect-strength
    results that can be attached after construction.
    """

    def __init__(self, estimate, target_estimand, realized_estimand_expr, **kwargs):
        """Store the estimate and its associated estimands.

        :param estimate: the estimated value
        :param target_estimand: the estimand that was targeted
        :param realized_estimand_expr: expression of the realized estimand
        :param kwargs: extra parameters; kept in ``self.params`` and also set
            as attributes on the instance
        """
        self.value = estimate
        self.target_estimand = target_estimand
        self.realized_estimand_expr = realized_estimand_expr
        # **kwargs is always a dict, so no None check is needed.
        self.params = kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Assigned after the kwargs so these always start out empty.
        self.significance_test = None
        self.effect_strength = None

    def add_significance_test_results(self, test_results):
        """Attach significance-test results.

        :param test_results: dict; ``__str__`` reads its ``"p_value"`` and
            ``"sorted_null_estimates"`` entries
        """
        self.significance_test = test_results

    def add_effect_strength(self, strength_dict):
        """Attach effect-strength results.

        :param strength_dict: dict; ``__str__`` reads its ``"fraction-effect"`` entry
        """
        self.effect_strength = strength_dict

    def add_params(self, **kwargs):
        """Merge additional keyword parameters into ``self.params``."""
        self.params.update(kwargs)

    def __str__(self):
        """Return a multi-section, human-readable report of the estimate."""
        parts = [
            "*** Causal Estimate ***\n",
            "\n## Target estimand\n{0}".format(self.target_estimand),
            "\n## Realized estimand\n{0}".format(self.realized_estimand_expr),
            "\n## Estimate\n",
            "Value: {0}\n".format(self.value),
        ]
        if self.significance_test is not None:
            parts.append("\n## Statistical Significance\n")
            p_value = self.significance_test["p_value"]
            if p_value > 0:
                parts.append("p-value: {0}\n".format(p_value))
            else:
                # A non-positive p-value is reported as an upper bound of
                # one over the number of null estimates.
                num_null = len(self.significance_test["sorted_null_estimates"])
                parts.append("p-value: <{0}\n".format(1 / num_null))
        if self.effect_strength is not None:
            parts.append("\n## Effect Strength\n")
            parts.append("Change in outcome attributable to treatment: {}\n".format(
                self.effect_strength["fraction-effect"]))
        return "".join(parts)
class RealizedEstimand(object):
    """Description of the estimand actually realized by a specific estimator.

    Copies the variable sets and estimand type from an identified estimand and
    records the estimator name; the expression and assumptions start as None
    and are filled in later through the ``update_*`` methods.
    """

    def __init__(self, identified_estimand, estimator_name):
        """Copy the relevant fields from the identified estimand.

        :param identified_estimand: object providing ``treatment_variable``,
            ``outcome_variable``, ``backdoor_variables``,
            ``instrumental_variables`` and ``estimand_type`` attributes
        :param estimator_name: name of the estimator realizing the estimand
        """
        self.treatment_variable = identified_estimand.treatment_variable
        self.outcome_variable = identified_estimand.outcome_variable
        self.backdoor_variables = identified_estimand.backdoor_variables
        self.instrumental_variables = identified_estimand.instrumental_variables
        self.estimand_type = identified_estimand.estimand_type
        self.estimand_expression = None
        self.assumptions = None
        self.estimator_name = estimator_name

    def update_assumptions(self, estimator_assumptions):
        """Set the estimator's assumptions.

        :param estimator_assumptions: mapping of assumption name to its
            description (``__str__`` iterates over its ``items()``)
        """
        self.assumptions = estimator_assumptions

    def update_estimand_expression(self, estimand_expression):
        """Set the realized estimand expression (pretty-printed by ``__str__``)."""
        self.estimand_expression = estimand_expression

    def __str__(self):
        """Return a readable summary: name, type, expression and numbered assumptions."""
        s = "Realized estimand: {0}\n".format(self.estimator_name)
        s += "Realized estimand type: {0}\n".format(self.estimand_type)
        s += "Estimand expression:\n{0}\n".format(sp.pretty(self.estimand_expression))
        # assumptions stays None until update_assumptions is called; treat
        # that as "no assumptions" instead of failing on None.items().
        assumptions = self.assumptions if self.assumptions is not None else {}
        for j, (ass_name, ass_str) in enumerate(assumptions.items(), start=1):
            s += "Estimand assumption {0}, {1}: {2}\n".format(j, ass_name, ass_str)
        return s