dowhy.causal_prediction.algorithms package#
子模块#
dowhy.causal_prediction.algorithms.base_algorithm module#
- class dowhy.causal_prediction.algorithms.base_algorithm.PredictionAlgorithm(model, optimizer, lr, weight_decay, betas, momentum)[源代码]#
基类:
LightningModule
此类实现了 Pytorch lightning 模块 pl.LightningModule 的默认方法。当调用 fit() 方法时,会调用其方法。要了解更多关于这些方法的信息,请参考 https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html。
- 参数:
model – 用于训练的神经网络模块
optimizer – 用于训练的优化算法。当前支持“Adam”和“SGD”。
lr – 学习率的值
weight_decay – 优化器的权重衰减值
betas – Adam 配置参数 (beta1, beta2),分别用于第一动量和第二动量估计的指数衰减率。
momentum – SGD 优化器的动量值
dowhy.causal_prediction.algorithms.cacm module#
- class dowhy.causal_prediction.algorithms.cacm.CACM(model, optimizer='Adam', lr=0.001, weight_decay=0.0, betas=(0.9, 0.999), momentum=0.9, kernel_type='gaussian', ci_test='mmd', attr_types=[], E_conditioned=True, E_eq_A=[], gamma=1e-06, lambda_causal=1.0, lambda_conf=1.0, lambda_ind=1.0, lambda_sel=1.0)[源代码]#
-
- 因果自适应约束最小化 (CACM) 算法类。
- @article{Kaur2022ModelingTD,
title={Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization}, author={Jivat Neet Kaur and Emre Kıcıman and Amit Sharma}, journal={ArXiv}, year={2022}, volume={abs/2206.07837}, url={https://arxiv.org/abs/2206.07837}
}
- 参数:
model – 用于训练的网络。model 类型期望是 torch.nn.Sequential(featurizer, classifier),其中 featurizer 和 classifier 的类型是 torch.nn.Module。
optimizer – 用于训练的优化算法。当前支持“Adam”和“SGD”。
lr – CACM 的学习率
weight_decay – 优化器的权重衰减值
betas – Adam 配置参数 (beta1, beta2),分别用于第一动量和第二动量估计的指数衰减率。
momentum – SGD 优化器的动量值
kernel_type – MMD 惩罚项的核类型。目前支持“gaussian” (RBF)。如果为 None,则使用均值和二阶统计量(协方差)之间的距离。
ci_test – 用于正则化惩罚项的条件独立性度量。目前支持 MMD。
attr_types – 属性类型列表(基于与标签 Y 的关系);应按加载数据集中属性的顺序排序。目前支持‘causal’(因果)、‘conf’(混淆)、‘ind’(独立)和‘sel’(选择)。对于单移位数据集,使用:[‘causal’]、[‘ind’]。对于多移位数据集,使用:[‘causal’, ‘ind’]
E_conditioned – 布尔标志,指示是否应用 E 条件正则化
E_eq_A – 属性 (A) 与环境 (E) 定义一致的属性索引列表;默认为空。
gamma – MMD 的核带宽(由于实现原因,实际的核带宽将是 gamma 的倒数,即 gamma=1e-6 意味着核带宽=1e6。参见 utils.py 中的 mmd_compute)
lambda_causal – 因果移位的 MMD 惩罚项超参数
lambda_conf – 混淆移位的 MMD 惩罚项超参数
lambda_ind – 独立移位的 MMD 惩罚项超参数
lambda_sel – 选择移位的 MMD 惩罚项超参数
- 返回:
PredictionAlgorithm 类的一个实例
dowhy.causal_prediction.algorithms.erm module#
- class dowhy.causal_prediction.algorithms.erm.ERM(model, optimizer='Adam', lr=0.001, weight_decay=0.0, betas=(0.9, 0.999), momentum=0.9)[源代码]#
-
此类实现了 Pytorch lightning 模块 pl.LightningModule 的默认方法。当调用 fit() 方法时,会调用其方法。要了解更多关于这些方法的信息,请参考 https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html。
- 参数:
model – 用于训练的神经网络模块
optimizer – 用于训练的优化算法。当前支持“Adam”和“SGD”。
lr – 学习率的值
weight_decay – 优化器的权重衰减值
betas – Adam 配置参数 (beta1, beta2),分别用于第一动量和第二动量估计的指数衰减率。
momentum – SGD 优化器的动量值
dowhy.causal_prediction.algorithms.regularization module#
- class dowhy.causal_prediction.algorithms.regularization.Regularizer(E_conditioned, ci_test, kernel_type, gamma)[源代码]#
基类:
object
实现应用无条件和条件正则化的方法。
- 参数:
E_conditioned – 布尔标志,指示是否应用 E 条件正则化
ci_test – 用于正则化惩罚项的条件独立性度量。目前支持 MMD。
kernel_type – MMD 惩罚项的核类型。目前支持“gaussian” (RBF)。如果为 None,则使用均值和二阶统计量(协方差)之间的距离。
gamma – MMD 的核带宽(由于实现原因,实际的核带宽将是 gamma 的倒数,即 gamma=1e-6 意味着核带宽=1e6。参见 utils.py 中的 mmd_compute)
- conditional_reg(classifs, attribute_labels, conditioning_subset, num_envs, E_eq_A=False)[源代码]#
实现条件正则化 φ(x) ⊥⊥ A_i | A_s
- 参数:
classifs – 分类器层输出的特征表示 (gφ(x))
attribute_labels – 数据集加载的属性 A_i 的属性标签
conditioning_subset – 观测变量子集 A_s(属性 + 目标)的列表,使得 (X_c, A_i) 在此子集上是 d 分离的
num_envs – 环境/域的数量
E_eq_A – 布尔标志,指示属性 (A_i) 是否与环境 (E) 定义一致
根据条件子集找到条件正则化的组索引,方法是获取所有可能的组合,例如:conditioning_subset = [A1, Y],其中 A1 在 {0, 1} 中,Y 在 {0, 1, 2} 中,我们按如下方式分配组:
A1 = 0, Y = 0 -> 组 0 A1 = 1, Y = 0 -> 组 1 A1 = 0, Y = 1 -> 组 2 A1 = 1, Y = 1 -> 组 3 A1 = 0, Y = 2 -> 组 4 A1 = 1, Y = 2 -> 组 5
- 计算组索引的代码片段改编自 WILDS: p-lambda/wilds
- @inproceedings{wilds2021,
title = {{WILDS}: A Benchmark of in-the-Wild Distribution Shifts}, author = {Pang Wei Koh and Shiori Sagawa and Henrik Marklund and Sang Michael Xie and Marvin Zhang and Akshay Balsubramani and Weihua Hu and Michihiro Yasunaga and Richard Lanas Phillips and Irena Gao and Tony Lee and Etienne David and Ian Stavness and Wei Guo and Berton A. Earnshaw and Imran S. Haque and Sara Beery and Jure Leskovec and Anshul Kundaje and Emma Pierson and Sergey Levine and Chelsea Finn and Percy Liang}, booktitle = {International Conference on Machine Learning (ICML)}, year = {2021}
}`
dowhy.causal_prediction.algorithms.utils module#
- 此文件中的函数借用自 DomainBed:facebookresearch/DomainBed
- @inproceedings{gulrajani2021in,
title={In Search of Lost Domain Generalization}, author={Ishaan Gulrajani and David Lopez-Paz}, booktitle={International Conference on Learning Representations}, year={2021},
}