class PropensityBalanceInterpreter(VisualInterpreter):
    """Balance plot for propensity score stratification.

    Plots, for every observed common cause, the standardized mean difference
    between treatment groups on the raw sample ("Unadjusted") and averaged over
    propensity score strata ("PropensityAdjusted").
    """

    # Estimator classes whose estimates this interpreter accepts.
    SUPPORTED_ESTIMATORS = [
        PropensityScoreStratificationEstimator,
    ]

    def __init__(self, estimate, **kwargs):
        """Validate that ``estimate`` was produced by a supported estimator.

        :param estimate: a CausalEstimate whose ``estimator`` attribute is an
            instance of one of ``SUPPORTED_ESTIMATORS``.
        :param kwargs: forwarded to the base class constructor.
        :raises ValueError: if ``estimate`` is not a CausalEstimate, or if its
            estimator is not one of ``SUPPORTED_ESTIMATORS``.
        """
        super().__init__(estimate, **kwargs)
        if not isinstance(estimate, CausalEstimate):
            error_msg = "The interpreter method expects a CausalEstimate object."
            self.logger.error(error_msg)
            raise ValueError(error_msg)
        self.estimator = self.estimate.estimator
        if not isinstance(self.estimator, tuple(PropensityBalanceInterpreter.SUPPORTED_ESTIMATORS)):
            error_msg = "The interpreter method only supports propensity score stratification estimator."
            self.logger.error(error_msg)
            raise ValueError(error_msg)

    def interpret(self):
        """Balance plot that shows the change in standardized mean differences for each covariate after propensity
        score stratification.

        Draws one line per common cause on a new pyplot figure. Note that this
        sets the matplotlib style globally via ``plt.style.use``.

        :returns: DataFrame with columns ``common_cause_id``, ``std_mean_diff``
            and ``sample`` (``"Unadjusted"`` or ``"PropensityAdjusted"``) holding
            the plotted values.
        """
        df_long = self._common_causes_long()
        plot_df = pd.concat(
            [
                self._unadjusted_std_mean_diff(df_long).assign(sample="Unadjusted"),
                self._stratified_std_mean_diff(df_long).assign(sample="PropensityAdjusted"),
            ]
        )

        import matplotlib.pyplot as plt

        # matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*"
        # and 3.8 removed the old names, so use whichever this install provides.
        style = "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white"
        plt.style.use(style)
        _, ax = plt.subplots(1, 1)
        for label, subdf in plot_df.groupby("common_cause_id"):
            subdf.plot(kind="line", x="sample", y="std_mean_diff", ax=ax, label=label)
        plt.legend(title="Common causes")
        plt.ylabel("Standardized mean difference between treatment and control")
        plt.xlabel("")
        plt.xticks(rotation=45)
        return plot_df

    def _common_causes_long(self):
        """Reshape the estimator's data to one row per (unit, common cause).

        Only common causes named ``W<digits>`` are reshaped (``wide_to_long``
        with stub ``"W"`` and its default numeric suffix). The digits become the
        ``common_cause_id`` column and the covariate value the float ``W`` column.
        Treatment, ``strata`` and ``propensity_score`` columns are carried along.
        """
        cols = (
            self.estimator._observed_common_causes_names
            + self.estimator._treatment_name
            + ["strata", "propensity_score"]
        )
        # A fresh positional row id keeps wide_to_long working whatever the
        # original index looks like (named, MultiIndex, or with duplicates).
        df = self.estimator._data[cols].reset_index(drop=True).rename_axis("_row_id").reset_index()
        return (
            pd.wide_to_long(df, stubnames=["W"], i="_row_id", j="common_cause_id")
            .reset_index()
            .astype({"W": "float64"})
        )

    def _stratified_std_mean_diff(self, df_long):
        """Per common cause, the stratum-size-weighted sum of within-stratum SMDs.

        Within each (common cause, stratum), the mean difference is the spread
        (max - min) of the per-treatment-group means of ``W``. It is divided by
        the sample std (ddof=1) of ``W`` over all units in that stratum. The
        result is weighted by the stratum's share of the common cause's rows.
        Strata with no treated row contribute nothing.

        :returns: DataFrame with columns ``common_cause_id`` and ``std_mean_diff``.
        """
        treatment_names = self.estimator._treatment_name
        treatment_col = treatment_names[0]
        cause_strata = ["common_cause_id", "strata"]

        group_means = df_long.groupby(treatment_names + cause_strata).agg(mean_w=("W", "mean"))
        # transform broadcasts the difference to every treatment row of the
        # stratum; keep only the treated row so each stratum appears once.
        mean_diff = group_means.groupby(cause_strata).transform(lambda x: x.max() - x.min()).reset_index()
        mean_diff = mean_diff[mean_diff[treatment_col] == True]  # noqa: E712 - treatment may be bool or 0/1

        stratum_size = df_long.groupby(cause_strata).agg(stratum_size=("propensity_score", "size")).reset_index()
        cause_size = df_long.groupby(["common_cause_id"]).agg(cause_size=("propensity_score", "size")).reset_index()
        stratum_std = df_long.groupby(cause_strata).agg(stddev=("W", "std")).reset_index()

        strata = pd.merge(mean_diff, stratum_size, on=cause_strata)
        strata = pd.merge(strata, cause_size, on="common_cause_id")
        strata = pd.merge(strata, stratum_std, on=cause_strata)
        strata["scaled_mean"] = (strata["mean_w"] / strata["stddev"]) * (strata["stratum_size"] / strata["cause_size"])
        return strata.groupby("common_cause_id").agg(std_mean_diff=("scaled_mean", "sum")).reset_index()

    def _unadjusted_std_mean_diff(self, df_long):
        """Per common cause, the SMD on the whole sample, ignoring strata.

        The mean difference is the spread (max - min) of the per-treatment-group
        means of ``W``, divided by the sample std (ddof=1) of ``W`` over all units.

        :returns: DataFrame with columns ``common_cause_id`` and ``std_mean_diff``.
        """
        treatment_names = self.estimator._treatment_name
        treatment_col = treatment_names[0]

        group_means = df_long.groupby(treatment_names + ["common_cause_id"]).agg(mean_w=("W", "mean"))
        mean_diff = group_means.groupby("common_cause_id").transform(lambda x: x.max() - x.min()).reset_index()
        mean_diff = mean_diff[mean_diff[treatment_col] == True]  # noqa: E712 - treatment may be bool or 0/1

        overall_std = df_long.groupby(["common_cause_id"]).agg(stddev=("W", "std")).reset_index()
        mean_diff = pd.merge(mean_diff, overall_std, on=["common_cause_id"])
        mean_diff["std_mean_diff"] = mean_diff["mean_w"] / mean_diff["stddev"]
        return mean_diff[["common_cause_id", "std_mean_diff"]]