dowhy.causal_prediction.datasets package#

子模块#

dowhy.causal_prediction.datasets.base_dataset module#

MultipleDomainDataset 类在此文件中借用自 DomainBed:facebookresearch/DomainBed
@inproceedings{gulrajani2021in,

title={In Search of Lost Domain Generalization}, author={Ishaan Gulrajani and David Lopez-Paz}, booktitle={International Conference on Learning Representations}, year={2021},

}

class dowhy.causal_prediction.datasets.base_dataset.MultipleDomainDataset[source]#

基类:object

CHECKPOINT_FREQ = 100#
ENVIRONMENTS = None#
INPUT_SHAPE = None#
N_STEPS = 5001#
N_WORKERS = 8#

dowhy.causal_prediction.datasets.mnist module#

class dowhy.causal_prediction.datasets.mnist.MNISTCausalAttribute(root, download=True)[source]#

基类:MultipleDomainDataset

MNISTCausalAttribute 数据集的类。

参数:
  • root – 数据所在的目录(如果不存在,则为下载数据的目录)。

  • download – 指示是否下载数据的二进制标志

返回:

MultipleDomainDataset 类的一个实例

CHECKPOINT_FREQ = 500#
ENVIRONMENTS = ['+90%', '+80%', '-90%', '-90%']#
INPUT_SHAPE = (2, 14, 14)#
N_STEPS = 5001#
color_dataset(images, labels, environment)[source]#

转换 MNIST 数据集以引入属性(颜色)和标签之间的相关性。标签 Y 和颜色之间存在直接的因果关系。

参数:
  • images – 原始 MNIST 图像

  • labels – 原始 MNIST 标签

  • environment – 颜色和标签之间相关性的值

返回:

包含转换后图像、标签和属性(颜色)的 TensorDataset

torch_bernoulli_(p, size)[source]#
torch_xor_(a, b)[source]#
class dowhy.causal_prediction.datasets.mnist.MNISTCausalIndAttribute(root, download=True)[source]#

基类:MultipleDomainDataset

MNISTIndAttribute 数据集的类。

参数:
  • root – 数据所在的目录(如果不存在,则为下载数据的目录)。

  • download – 指示是否下载数据的二进制标志

返回:

MultipleDomainDataset 类的一个实例

CHECKPOINT_FREQ = 500#
ENVIRONMENTS = ['+90%, 15', '+80%, 16', '-90%, 90', '-90%, 90']#
INPUT_SHAPE = (2, 14, 14)#
N_STEPS = 5001#
color_dataset(images, labels, environment)[source]#

转换 MNIST 数据集以引入属性(颜色)和标签之间的相关性。标签 Y 和颜色之间存在直接的因果关系。

参数:
  • images – 旋转后的 MNIST 图像

  • labels – 原始 MNIST 标签

  • environment – 颜色和标签之间相关性的值

返回:

转换后的图像、标签和属性(颜色)

color_rot_dataset(images, labels, environment, env_id, angle)[source]#

通过 (i) 对图像应用旋转,然后 (ii) 在属性(颜色)和标签之间引入相关性来转换 MNIST 数据集。属性(旋转角度)独立于标签 Y;标签 Y 和颜色之间存在直接的因果关系。

参数:
  • images – 原始 MNIST 图像

  • labels – 原始 MNIST 标签

  • environment – 颜色和标签之间相关性的值

  • angle – 用于转换图像的旋转角度值

返回:

包含转换后图像、标签和属性(颜色、角度)的 TensorDataset

rotate_dataset(images, angle)[source]#

通过对图像应用旋转来转换 MNIST 数据集。属性(旋转角度)独立于标签 Y。

参数:
  • images – 原始 MNIST 图像

  • angle – 用于转换图像的旋转角度值

返回:

转换后的图像

torch_bernoulli_(p, size)[source]#
torch_xor_(a, b)[source]#
class dowhy.causal_prediction.datasets.mnist.MNISTIndAttribute(root, download=True)[source]#

基类:MultipleDomainDataset

MNISTIndAttribute 数据集的类。

参数:
  • root – 数据所在的目录(如果不存在,则为下载数据的目录)。

  • download – 指示是否下载数据的二进制标志

返回:

MultipleDomainDataset 类的一个实例

CHECKPOINT_FREQ = 500#
ENVIRONMENTS = ['15', '60', '90', '90']#
INPUT_SHAPE = (1, 14, 14)#
N_STEPS = 5001#
rotate_dataset(images, labels, env_id, angle)[source]#

通过对图像应用旋转来转换 MNIST 数据集。属性(旋转角度)独立于标签 Y。

参数:
  • images – 原始 MNIST 图像

  • labels – 原始 MNIST 标签

  • angle – 用于转换图像的旋转角度值

返回:

包含转换后图像、标签和属性(角度)的 TensorDataset

torch_bernoulli_(p, size)[source]#
torch_xor_(a, b)[source]#

模块内容#