Modules
- class dte_adj.AdjustedDistributionEstimator(base_model, folds=3, is_multi_task=False)
Bases: DistributionEstimatorBase
A class for estimating the adjusted distribution function and computing distributional parameters based on the trained conditional estimator.
- fit(confoundings: ndarray, treatment_arms: ndarray, outcomes: ndarray) → DistributionEstimatorBase
Fit the estimator to the observed data.
- Parameters:
confoundings (np.ndarray) – Pre-treatment covariates.
treatment_arms (np.ndarray) – The treatment arm index of each unit.
outcomes (np.ndarray) – The scalar-valued observed outcome of each unit.
- Returns:
The fitted estimator.
- Return type:
DistributionEstimatorBase
- predict(treatment_arm: int, locations: ndarray) → ndarray
Compute cumulative distribution values.
- Parameters:
treatment_arm (int) – The index of the treatment arm.
locations (np.ndarray) – Scalar values at which to compute the cumulative distribution.
- Returns:
Estimated cumulative distribution values for the input.
- Return type:
np.ndarray
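For intuition about what an "adjusted" distribution function can look like, the sketch below implements one common regression-adjusted CDF estimator: the average of a model's predicted P(Y ≤ y | X) over all units, plus the mean residual 1{Y ≤ y} - prediction over the units in the requested arm. This is an assumption made for illustration, not a statement of what `AdjustedDistributionEstimator` computes internally. The model `mu` is a hypothetical, already-fitted callable, and the cross-fitting hinted at by the `folds` argument is omitted.

```python
from typing import Callable, Sequence


def adjusted_cdf(
    y: float,
    arm: int,
    covariates: Sequence[float],
    arms: Sequence[int],
    outcomes: Sequence[float],
    mu: Callable[[float, float], float],
) -> float:
    """Regression-adjusted estimate of P(Y(arm) <= y) (illustrative sketch).

    `mu(y, x)` is a hypothetical, already-fitted model of P(Y <= y | X = x)
    for units in `arm`. Cross-fitting is omitted for brevity.
    """
    n = len(outcomes)
    in_arm = [i for i in range(n) if arms[i] == arm]
    if not in_arm:
        raise ValueError(f"no observations in arm {arm}")
    # Model term: average prediction over ALL units, whatever their arm.
    model_term = sum(mu(y, covariates[i]) for i in range(n)) / n
    # Correction term: mean residual over the units actually observed in `arm`.
    residual_term = sum(
        (1.0 if outcomes[i] <= y else 0.0) - mu(y, covariates[i]) for i in in_arm
    ) / len(in_arm)
    return model_term + residual_term


# With a constant model the adjustment cancels, so the estimate equals the
# empirical CDF of arm 1 at y = 2.5 (1 of its 2 outcomes is <= 2.5).
covariates = [0.0, 1.0, 2.0, 3.0]
arms = [0, 1, 0, 1]
outcomes = [1.0, 2.0, 3.0, 4.0]
print(adjusted_cdf(2.5, 1, covariates, arms, outcomes, lambda y, x: 0.25))
# 0.5
```

A useful sanity check on any estimator of this form is that an uninformative (constant) model reproduces the unadjusted empirical CDF; variance reduction only comes from a model that actually tracks the outcome.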
- predict_dte(target_treatment_arm: int, control_treatment_arm: int, locations: ndarray, alpha: float = 0.05, variance_type='moment', n_bootstrap=500) → Tuple[ndarray, ndarray, ndarray]
Compute distributional treatment effects (DTEs) based on the estimator for the distribution function.
- Parameters:
target_treatment_arm (int) – The index of the treatment arm of the treatment group.
control_treatment_arm (int) – The index of the treatment arm of the control group.
locations (np.ndarray) – Scalar values at which to compute the cumulative distribution.
alpha (float, optional) – Significance level of the confidence bound. Defaults to 0.05.
variance_type (str, optional) – Variance type used to compute confidence intervals. Available values are "moment", "simple", and "uniform". Defaults to "moment".
n_bootstrap (int, optional) – Number of bootstrap samples. Defaults to 500.
- Returns:
- A tuple containing:
Expected DTEs
Upper bounds
Lower bounds
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
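To make the quantity concrete: a DTE at location y is commonly defined as the difference between the two arms' cumulative distribution functions, F_target(y) - F_control(y). The sketch below computes that difference with plain empirical CDFs (no regression adjustment, no confidence bounds), so it illustrates the definition under that assumption rather than this class's estimator. The helper names `ecdf` and `simple_dte` are hypothetical.

```python
from bisect import bisect_right
from typing import List, Sequence


def ecdf(sample: Sequence[float], y: float) -> float:
    """Empirical CDF: share of observations in `sample` that are <= y."""
    ordered = sorted(sample)
    return bisect_right(ordered, y) / len(ordered)


def simple_dte(
    arms: Sequence[int],
    outcomes: Sequence[float],
    target: int,
    control: int,
    locations: Sequence[float],
) -> List[float]:
    """F_target(y) - F_control(y) at each location, using empirical CDFs."""
    treated = [o for a, o in zip(arms, outcomes) if a == target]
    controls = [o for a, o in zip(arms, outcomes) if a == control]
    return [ecdf(treated, y) - ecdf(controls, y) for y in locations]


arms = [0, 0, 0, 0, 1, 1, 1, 1]
outcomes = [1, 2, 3, 4, 2, 3, 4, 5]
print(simple_dte(arms, outcomes, target=1, control=0, locations=[1, 2, 3, 4, 5]))
# [-0.25, -0.25, -0.25, -0.25, 0.0]
```

Because the treated outcomes are shifted up by one, their CDF lies below the control CDF, so under this sign convention the DTE is negative wherever the two distributions differ.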
- predict_pte(target_treatment_arm: int, control_treatment_arm: int, width: float, locations: ndarray, alpha: float = 0.05, variance_type='moment') → Tuple[ndarray, ndarray, ndarray]
Compute probability treatment effects (PTEs) based on the estimator for the distribution function.
- Parameters:
target_treatment_arm (int) – The index of the treatment arm of the treatment group.
control_treatment_arm (int) – The index of the treatment arm of the control group.
width (float) – The width of each outcome interval.
locations (np.ndarray) – Scalar values at which to compute the cumulative distribution.
alpha (float, optional) – Significance level of the confidence bound. Defaults to 0.05.
variance_type (str, optional) – Variance type used to compute confidence intervals. Available values are "moment", "simple", and "uniform". Defaults to "moment".
- Returns:
- A tuple containing:
Expected PTEs
Upper bounds
Lower bounds
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
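One plausible reading of the `width` parameter is that a PTE compares, at each location y, the probability that an outcome falls in an interval of length `width`. The sketch below assumes the half-open interval (y, y + width], so that PTE(y) = [F_t(y + width) - F_t(y)] - [F_c(y + width) - F_c(y)], and estimates each probability with an empirical frequency. Both the interval convention and the helper names are assumptions, not taken from the library.

```python
from typing import List, Sequence


def interval_share(sample: Sequence[float], low: float, high: float) -> float:
    """Share of observations in the half-open interval (low, high]."""
    return sum(1 for v in sample if low < v <= high) / len(sample)


def simple_pte(
    arms: Sequence[int],
    outcomes: Sequence[float],
    target: int,
    control: int,
    width: float,
    locations: Sequence[float],
) -> List[float]:
    """Difference in interval probabilities between the two arms (sketch)."""
    treated = [o for a, o in zip(arms, outcomes) if a == target]
    controls = [o for a, o in zip(arms, outcomes) if a == control]
    return [
        interval_share(treated, y, y + width) - interval_share(controls, y, y + width)
        for y in locations
    ]


arms = [0, 0, 0, 0, 1, 1, 1, 1]
outcomes = [1, 2, 3, 4, 2, 3, 4, 5]
print(simple_pte(arms, outcomes, target=1, control=0, width=1, locations=[0, 1, 2, 3, 4]))
# [-0.25, 0.0, 0.0, 0.0, 0.25]
```

Counting observations in (y, y + width] is the same as differencing the empirical CDF at the two endpoints, which is why a PTE can be derived from any estimator of the distribution function.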
- predict_qte(target_treatment_arm: int, control_treatment_arm: int, quantiles: ndarray = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], dtype=float32), alpha: float = 0.05, n_bootstrap=500) → Tuple[ndarray, ndarray, ndarray]
Compute quantile treatment effects (QTEs) based on the estimator for the distribution function.
- Parameters:
target_treatment_arm (int) – The index of the treatment arm of the treatment group.
control_treatment_arm (int) – The index of the treatment arm of the control group.
quantiles (np.ndarray, optional) – Quantiles used for QTE. Defaults to [0.1 * i for i in range(1, 10)].
alpha (float, optional) – Significance level of the confidence bound. Defaults to 0.05.
n_bootstrap (int, optional) – Number of bootstrap samples. Defaults to 500.
- Returns:
- A tuple containing:
Expected QTEs
Upper bounds
Lower bounds
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
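A QTE at level τ is usually the difference between the two arms' τ-quantiles, Q_target(τ) - Q_control(τ). The sketch below uses the empirical quantile defined as the smallest observation whose empirical CDF reaches τ. The library may define or interpolate quantiles differently (for instance by inverting the adjusted distribution function), and the helper names are hypothetical.

```python
from typing import List, Sequence


def empirical_quantile(sample: Sequence[float], tau: float) -> float:
    """Smallest observation v whose empirical CDF is >= tau."""
    ordered = sorted(sample)
    n = len(ordered)
    for k, value in enumerate(ordered):
        # (k + 1) / n is the share of observations at or before position k
        # in sorted order; ties simply return the same value on a later k.
        if (k + 1) / n >= tau:
            return value
    return ordered[-1]


def simple_qte(
    arms: Sequence[int],
    outcomes: Sequence[float],
    target: int,
    control: int,
    quantiles: Sequence[float],
) -> List[float]:
    """Q_target(tau) - Q_control(tau) for each quantile level (sketch)."""
    treated = [o for a, o in zip(arms, outcomes) if a == target]
    controls = [o for a, o in zip(arms, outcomes) if a == control]
    return [
        empirical_quantile(treated, q) - empirical_quantile(controls, q)
        for q in quantiles
    ]


arms = [0, 0, 0, 0, 1, 1, 1, 1]
outcomes = [1, 2, 3, 4, 2, 3, 4, 5]
print(simple_qte(arms, outcomes, target=1, control=0, quantiles=[0.25, 0.5, 0.75]))
# [1, 1, 1]
```

The comparison uses `(k + 1) / n >= tau` rather than indexing with `ceil(tau * n)`, because products such as `0.3 * 10` are not exact in floating point and can shift the index by one.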
- class dte_adj.SimpleDistributionEstimator
Bases: DistributionEstimatorBase
A class for computing the empirical distribution function and the distributional parameters based on the distribution function.
- fit(confoundings: ndarray, treatment_arms: ndarray, outcomes: ndarray) → DistributionEstimatorBase
Fit the estimator to the observed data.
- Parameters:
confoundings (np.ndarray) – Pre-treatment covariates.
treatment_arms (np.ndarray) – The treatment arm index of each unit.
outcomes (np.ndarray) – The scalar-valued observed outcome of each unit.
- Returns:
The fitted estimator.
- Return type:
DistributionEstimatorBase
- predict(treatment_arm: int, locations: ndarray) → ndarray
Compute cumulative distribution values.
- Parameters:
treatment_arm (int) – The index of the treatment arm.
locations (np.ndarray) – Scalar values at which to compute the cumulative distribution.
- Returns:
Estimated cumulative distribution values for the input.
- Return type:
np.ndarray
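The empirical distribution function behind this class can be sketched as a small fit/predict pair: `fit` groups the outcomes by treatment arm and sorts them, and `predict` returns the share of that arm's outcomes at or below each location. `EmpiricalCDFSketch` is a hypothetical, stdlib-only stand-in written for illustration, not the library's implementation; it accepts but ignores `confoundings`, on the assumption that an unadjusted estimator does not use covariates.

```python
from bisect import bisect_right
from collections import defaultdict
from typing import Dict, List, Sequence


class EmpiricalCDFSketch:
    """Hypothetical stand-in mirroring the fit/predict shape documented above."""

    def fit(
        self,
        confoundings: Sequence[Sequence[float]],
        treatment_arms: Sequence[int],
        outcomes: Sequence[float],
    ) -> "EmpiricalCDFSketch":
        # Covariates are accepted for API symmetry but unused: no adjustment here.
        by_arm: Dict[int, List[float]] = defaultdict(list)
        for arm, outcome in zip(treatment_arms, outcomes):
            by_arm[arm].append(outcome)
        self._sorted = {arm: sorted(values) for arm, values in by_arm.items()}
        return self

    def predict(self, treatment_arm: int, locations: Sequence[float]) -> List[float]:
        ordered = self._sorted[treatment_arm]
        # bisect_right counts how many sorted outcomes are <= each location.
        return [bisect_right(ordered, y) / len(ordered) for y in locations]


est = EmpiricalCDFSketch().fit([[0.0]] * 5, [0, 0, 1, 1, 1], [3.0, 1.0, 2.0, 6.0, 4.0])
print(est.predict(1, [1.0, 2.0, 4.0, 6.0]))
# [0.0, 0.3333333333333333, 0.6666666666666666, 1.0]
```

Sorting once in `fit` lets each `predict` call answer a location with a binary search instead of a full pass over the arm's outcomes.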
- predict_dte(target_treatment_arm: int, control_treatment_arm: int, locations: ndarray, alpha: float = 0.05, variance_type='moment', n_bootstrap=500) → Tuple[ndarray, ndarray, ndarray]
Compute distributional treatment effects (DTEs) based on the estimator for the distribution function.
- Parameters:
target_treatment_arm (int) – The index of the treatment arm of the treatment group.
control_treatment_arm (int) – The index of the treatment arm of the control group.
locations (np.ndarray) – Scalar values at which to compute the cumulative distribution.
alpha (float, optional) – Significance level of the confidence bound. Defaults to 0.05.
variance_type (str, optional) – Variance type used to compute confidence intervals. Available values are "moment", "simple", and "uniform". Defaults to "moment".
n_bootstrap (int, optional) – Number of bootstrap samples. Defaults to 500.
- Returns:
- A tuple containing:
Expected DTEs
Upper bounds
Lower bounds
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
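The `alpha` and `variance_type` parameters control how the upper and lower bounds are formed. As an illustration only, the sketch below builds a pointwise normal-approximation interval for an empirical-CDF DTE, treating the two arms as independent samples so that Var[F̂_t(y) - F̂_c(y)] ≈ F_t(y)(1 - F_t(y))/n_t + F_c(y)(1 - F_c(y))/n_c. Whether this matches any of the "moment", "simple", or "uniform" options is an assumption that has not been verified against the library; `Band` and `dte_with_normal_bounds` are hypothetical names.

```python
from bisect import bisect_right
from statistics import NormalDist
from typing import List, NamedTuple, Sequence


class Band(NamedTuple):
    estimate: float
    lower: float
    upper: float


def ecdf(sorted_sample: Sequence[float], y: float) -> float:
    """Empirical CDF of an already-sorted sample."""
    return bisect_right(sorted_sample, y) / len(sorted_sample)


def dte_with_normal_bounds(
    treated: Sequence[float],
    controls: Sequence[float],
    locations: Sequence[float],
    alpha: float = 0.05,
) -> List[Band]:
    t, c = sorted(treated), sorted(controls)
    # Two-sided critical value, about 1.96 for alpha = 0.05.
    z = NormalDist().inv_cdf(1 - alpha / 2)
    bands = []
    for y in locations:
        ft, fc = ecdf(t, y), ecdf(c, y)
        # Binomial variance of each empirical CDF, summed across independent arms.
        se = (ft * (1 - ft) / len(t) + fc * (1 - fc) / len(c)) ** 0.5
        est = ft - fc
        bands.append(Band(est, est - z * se, est + z * se))
    return bands


for band in dte_with_normal_bounds([2, 3, 4, 5], [1, 2, 3, 4], [3, 5]):
    print(band)
```

At a location where both empirical CDFs equal 1, the estimated variance is zero and the interval collapses to a point, which is one reason small-sample work often prefers bootstrap or other variance estimators.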
- predict_pte(target_treatment_arm: int, control_treatment_arm: int, width: float, locations: ndarray, alpha: float = 0.05, variance_type='moment') → Tuple[ndarray, ndarray, ndarray]
Compute probability treatment effects (PTEs) based on the estimator for the distribution function.
- Parameters:
target_treatment_arm (int) – The index of the treatment arm of the treatment group.
control_treatment_arm (int) – The index of the treatment arm of the control group.
width (float) – The width of each outcome interval.
locations (np.ndarray) – Scalar values at which to compute the cumulative distribution.
alpha (float, optional) – Significance level of the confidence bound. Defaults to 0.05.
variance_type (str, optional) – Variance type used to compute confidence intervals. Available values are "moment", "simple", and "uniform". Defaults to "moment".
- Returns:
- A tuple containing:
Expected PTEs
Upper bounds
Lower bounds
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
- predict_qte(target_treatment_arm: int, control_treatment_arm: int, quantiles: ndarray = array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], dtype=float32), alpha: float = 0.05, n_bootstrap=500) → Tuple[ndarray, ndarray, ndarray]
Compute quantile treatment effects (QTEs) based on the estimator for the distribution function.
- Parameters:
target_treatment_arm (int) – The index of the treatment arm of the treatment group.
control_treatment_arm (int) – The index of the treatment arm of the control group.
quantiles (np.ndarray, optional) – Quantiles used for QTE. Defaults to [0.1 * i for i in range(1, 10)].
alpha (float, optional) – Significance level of the confidence bound. Defaults to 0.05.
n_bootstrap (int, optional) – Number of bootstrap samples. Defaults to 500.
- Returns:
- A tuple containing:
Expected QTEs
Upper bounds
Lower bounds
- Return type:
Tuple[np.ndarray, np.ndarray, np.ndarray]
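`predict_qte` exposes `n_bootstrap` but no `variance_type`, which suggests its bounds come from resampling. A generic percentile bootstrap for a single QTE is sketched below: each replicate resamples both arms with replacement, recomputes the quantile difference, and the α/2 and 1 - α/2 order statistics of the replicates form the interval. This textbook technique is chosen for illustration; the library's actual resampling scheme is not documented here, and the helper names and `seed` argument are hypothetical.

```python
import random
from typing import Sequence, Tuple


def quantile(sample: Sequence[float], tau: float) -> float:
    """Smallest observation whose empirical CDF is >= tau."""
    ordered = sorted(sample)
    n = len(ordered)
    for k, value in enumerate(ordered):
        if (k + 1) / n >= tau:
            return value
    return ordered[-1]


def bootstrap_qte_interval(
    treated: Sequence[float],
    control: Sequence[float],
    tau: float,
    alpha: float = 0.05,
    n_bootstrap: int = 500,
    seed: int = 0,
) -> Tuple[float, float, float]:
    """Return (point estimate, lower bound, upper bound) for one QTE."""
    rng = random.Random(seed)
    point = quantile(treated, tau) - quantile(control, tau)
    stats = []
    for _ in range(n_bootstrap):
        # Resample each arm separately, keeping the arm sizes fixed.
        t = rng.choices(treated, k=len(treated))
        c = rng.choices(control, k=len(control))
        stats.append(quantile(t, tau) - quantile(c, tau))
    stats.sort()
    # Percentile interval: pick the alpha/2 and 1 - alpha/2 order statistics.
    lo_idx = int((alpha / 2) * n_bootstrap)
    hi_idx = max(lo_idx, int((1 - alpha / 2) * n_bootstrap) - 1)
    return point, stats[lo_idx], stats[hi_idx]


print(bootstrap_qte_interval([2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6], tau=0.5, n_bootstrap=200))
```

Resampling within each arm preserves the arm sizes of a randomized experiment; a fixed `seed` makes the interval reproducible across runs.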
- dte_adj.plot.plot(X: ndarray, means: ndarray, lower_bounds: ndarray, upper_bounds: ndarray, chart_type='line', ax: Axis | None = None, title: str | None = None, xlabel: str | None = None, ylabel: str | None = None)
Visualize distributional parameters and their confidence intervals.
- Parameters:
X (np.ndarray) – Values for the x-axis.
means (np.ndarray) – Expected distributional parameters.
lower_bounds (np.ndarray) – Lower bounds of the distributional parameters.
upper_bounds (np.ndarray) – Upper bounds of the distributional parameters.
chart_type (str, optional) – Chart type. Available values are "line" and "bar". Defaults to "line".
ax (matplotlib.axes.Axes, optional) – Target axes instance. If None, a new figure and axes will be created.
title (str, optional) – Axes title.
xlabel (str, optional) – X-axis label.
ylabel (str, optional) – Y-axis label.
- Returns:
The axes with the plot.
- Return type:
matplotlib.axes.Axes