dowhy.causal_prediction.datasets package#
子模块#
dowhy.causal_prediction.datasets.base_dataset module#
- MultipleDomainDataset 类在此文件中借用自 DomainBed:facebookresearch/DomainBed
- @inproceedings{gulrajani2021in,
title={In Search of Lost Domain Generalization}, author={Ishaan Gulrajani and David Lopez-Paz}, booktitle={International Conference on Learning Representations}, year={2021},
}
dowhy.causal_prediction.datasets.mnist module#
- class dowhy.causal_prediction.datasets.mnist.MNISTCausalAttribute(root, download=True)[source]#
-
MNISTCausalAttribute 数据集的类。
- 参数:
root – 数据所在的目录(如果不存在,则为下载数据的目录)。
download – 指示是否下载数据的二进制标志
- 返回:
MultipleDomainDataset 类的一个实例
- CHECKPOINT_FREQ = 500#
- ENVIRONMENTS = ['+90%', '+80%', '-90%', '-90%']#
- INPUT_SHAPE = (2, 14, 14)#
- N_STEPS = 5001#
- class dowhy.causal_prediction.datasets.mnist.MNISTCausalIndAttribute(root, download=True)[source]#
-
MNISTIndAttribute 数据集的类。
- 参数:
root – 数据所在的目录(如果不存在,则为下载数据的目录)。
download – 指示是否下载数据的二进制标志
- 返回:
MultipleDomainDataset 类的一个实例
- CHECKPOINT_FREQ = 500#
- ENVIRONMENTS = ['+90%, 15', '+80%, 16', '-90%, 90', '-90%, 90']#
- INPUT_SHAPE = (2, 14, 14)#
- N_STEPS = 5001#
- color_dataset(images, labels, environment)[source]#
转换 MNIST 数据集以引入属性(颜色)和标签之间的相关性。标签 Y 和颜色之间存在直接的因果关系。
- 参数:
images – 旋转后的 MNIST 图像
labels – 原始 MNIST 标签
environment – 颜色和标签之间相关性的值
- 返回:
转换后的图像、标签和属性(颜色)
- color_rot_dataset(images, labels, environment, env_id, angle)[source]#
通过 (i) 对图像应用旋转,然后 (ii) 在属性(颜色)和标签之间引入相关性来转换 MNIST 数据集。属性(旋转角度)独立于标签 Y;标签 Y 和颜色之间存在直接的因果关系。
- 参数:
images – 原始 MNIST 图像
labels – 原始 MNIST 标签
environment – 颜色和标签之间相关性的值
angle – 用于转换图像的旋转角度值
- 返回:
包含转换后图像、标签和属性(颜色、角度)的 TensorDataset
- class dowhy.causal_prediction.datasets.mnist.MNISTIndAttribute(root, download=True)[source]#
-
MNISTIndAttribute 数据集的类。
- 参数:
root – 数据所在的目录(如果不存在,则为下载数据的目录)。
download – 指示是否下载数据的二进制标志
- 返回:
MultipleDomainDataset 类的一个实例
- CHECKPOINT_FREQ = 500#
- ENVIRONMENTS = ['15', '60', '90', '90']#
- INPUT_SHAPE = (1, 14, 14)#
- N_STEPS = 5001#