DoWhy 在 MNIST 上的因果预测演示#

本 Notebook 的目标是演示使用 因果自适应约束最小化 (CACM) (https://arxiv.org/abs/2206.07837) [1] 进行因果预测的示例。

多属性分布偏移数据集#

域泛化文献主要关注那些在一个属性上仅具有一种分布偏移的数据集。以 MNIST 为例,可以通过增加旋转等虚假属性的新值来创建域(例如,Rotated-MNIST 数据集 [2]),或者域在类别标签和颜色等虚假属性之间表现出不同的相关值(例如,Colored-MNIST [3])。然而,现实世界的数据通常在不同属性上具有多种分布偏移。例如,卫星图像数据既表现出随时间变化的分布偏移,也表现出随区域变化的分布偏移。

多属性 MNIST#

我们创建了一个 多属性 偏移的 MNIST 变体,其中数字的颜色和旋转角度都可以在数据分布中发生偏移。因此,我们创建了三种 MNIST 变体 – MNISTCausalAttribute (单属性偏移)、MNISTIndAttribute (单属性偏移)、MNISTCausalIndAttribute (多属性偏移)。为了更好地描述 CausalIndCausalInd 数据集,请考虑下面的数据生成过程的因果图

main_fig_mnist.drawio.png

分布偏移的特征基于虚假属性 A 与分类标签 Y 之间的关系。 1. Causal:属性与类别标签具有直接的 因果 关系,即 Y 导致属性(例如,这里的 颜色) 2. Ind:属性与类别标签 独立(例如,这里的 旋转) 3. CausalInd:与 Y 具有 因果独立 关系的不同属性共存于数据中

多属性 MNIST 中的域#

我们描述了我们的 多属性 偏移数据集 MNISTCausalIndAttribute 的域。每个域 Ei 都有一个特定的 旋转 角度 ri 和一个特定的 颜色 C 与标签 Y 之间的相关性 corri。我们的设置包含 3 个域:E1、E2 是训练域,E3 是测试域。我们在 Ei 中定义 corri = P(Y = 1|C = 1) = P(Y = 0|C = 0)。在我们的设置中,r1 = 15◦, r2 = 60◦, r3 = 90◦,corr1 = 0.9, corr2 = 0.8, corr3 = 0.1。所有环境都有 25% 的标签噪声,如 [3] 中所示

其他与数据集相关的详细信息可以在 dowhy.causal_prediction.datasets 中找到。

MNIST_visualize.drawio%20%283%29.png

[1]:
import torch
import pytorch_lightning as pl

初始化数据集#

[2]:
from dowhy.causal_prediction.datasets.mnist import MNISTCausalAttribute

# dataset class initialization requires mandatory param `data_dir`
# `download` is passed to torchvision.datasets.MNIST and downloads data if not present
data_dir = 'data'
dataset = MNISTCausalAttribute(data_dir, download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

初始化数据加载器#

get_loaders 返回用于训练、验证和测试的数据加载器。loaders 返回的是一个包含 train_loadersval_loaderstest_loaders 的字典。目前支持两种场景来初始化验证域

方法 1:当数据集中的一个或多个域被明确指定为验证域时 方法 2:当没有指定的验证域时,使用训练域的一个子集来创建验证集

根据需要运行方法 1 或方法 2 下方的单元格。

[3]:
from dowhy.causal_prediction.dataloaders.get_data_loader import get_loaders

方法 1:明确提供验证域#

将验证域的索引提供给 val_envstest_envs 是一个可选参数。

[4]:
loaders = get_loaders(dataset, train_envs=[0, 1], batch_size=64,
            val_envs=[2], test_envs=[3])
/github/home/.cache/pypoetry/virtualenvs/dowhy-oN2hW5jr-py3.8/lib/python3.8/site-packages/torch/utils/data/dataloader.py:554: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(

方法 2:使用训练数据子集作为验证集#

val_envstest_envs 是可选参数。如果未提供 val_envs,则使用训练数据的一个子集创建验证集。使用的训练数据比例由 holdout_fraction 决定。

[5]:
loaders = get_loaders(dataset, train_envs=[0, 1], batch_size=64,
            holdout_fraction=0.2, test_envs=[3])

如果存在多个验证或测试域,下面的代码会处理它们。无论使用方法 1 还是方法 2,都请运行下面的单元格。

[6]:
# handle multiple validation and test domains if present
from pytorch_lightning.trainer.supporters import CombinedLoader

if len(loaders['val_loaders']) > 1:
    val_loaders = loaders['val_loaders']
    loaders['val_loaders'] = CombinedLoader(val_loaders)

if len(loaders['test_loaders']) > 1:
    test_loaders = loaders['test_loaders']
    loaders['test_loaders'] = CombinedLoader(test_loaders)

初始化模型和算法#

[7]:
from dowhy.causal_prediction.models.networks import MNIST_MLP, Classifier

下面的 model 预期是 torch.nn.Sequential 类型,包含两个 torch.nn.Module 元素(特征提取器和分类器)。我们在 dowhy.causal_prediction.models.networks 中提供了示例网络(MLP, ResNet),但用户可以灵活使用任何模型。

[8]:
featurizer = MNIST_MLP(dataset.input_shape)
classifier = Classifier(
    featurizer.n_outputs,
    dataset.num_classes)

model = torch.nn.Sequential(featurizer, classifier)

初始化算法类:ERM#

我们在 dowhy.causal_prediction.algorithms 中实现了经验风险最小化 (ERM) 作为基准。

[9]:
from dowhy.causal_prediction.algorithms.erm import ERM
[10]:
algorithm = ERM(model, lr=1e-3)

拟合预测器并开始训练#

注意:MNISTCausalAttribute(以及其他引入的 MNIST 变体)的最佳准确率是 75%,因为我们按照先前的工作引入了 25% 的噪声。

[11]:
trainer = pl.Trainer(devices=1, max_epochs=5)

# val_loaders is optional param
trainer.fit(algorithm, loaders['train_loaders'], loaders['val_loaders'])
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/github/home/.cache/pypoetry/virtualenvs/dowhy-oN2hW5jr-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
Missing logger folder: /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 306 K
-------------------------------------
306 K     Trainable params
0         Non-trainable params
306 K     Total params
1.226     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=5` reached.

在测试域上评估#

使用 trainer.test 在测试集上执行一个评估周期。ckpt_path 确定用于评估的模型——‘best’、‘last’ 或特定检查点的路径。如果未传入 ckpt_path,则加载上一次 trainer.fit 的最佳模型检查点 (https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.test)。

我们报告在测试域/测试集上的准确率 (test_acc) 和交叉熵损失 (test_loss)。

[12]:
if 'test_loaders' in loaders:
    trainer.test(dataloaders=loaders['test_loaders'], ckpt_path='best')
Restoring states from the checkpoint path at /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs/version_0/checkpoints/epoch=4-step=1560.ckpt
Loaded model weights from checkpoint at /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs/version_0/checkpoints/epoch=4-step=1560.ckpt
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.19169999659061432
        test_loss           1.6126253604888916
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

使用 CACM 进行预测#

我们现在使用 CACM 训练和评估上述数据集。我们通过作为 CACM 输入提供的列表 attr_types 指定存在的偏移类型。关于在多属性偏移中使用 CACM 的进一步说明将在下一节中提供。

[13]:
from dowhy.causal_prediction.algorithms.cacm import CACM
[14]:
# `attr_types` list contains type of attributes present (supports 'causal', 'conf', ind', and  'sel' currently)
algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['causal'], lambda_causal=100.)
[15]:
trainer = pl.Trainer(devices=1, max_epochs=5)

trainer.fit(algorithm, loaders['train_loaders'], loaders['val_loaders'])
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 306 K
-------------------------------------
306 K     Trainable params
0         Non-trainable params
306 K     Total params
1.226     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=5` reached.
[16]:
if 'test_loaders' in loaders:
    trainer.test(dataloaders=loaders['test_loaders'], ckpt_path='best')
Restoring states from the checkpoint path at /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs/version_1/checkpoints/epoch=4-step=1560.ckpt
Loaded model weights from checkpoint at /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs/version_1/checkpoints/epoch=4-step=1560.ckpt
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.6495000123977661
        test_loss           0.6850733757019043
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

扩展到不同的数据集和算法#

MNIST Independent 和 Causal+Independent 数据集#

我们展示了如何对 MNISTIndAttributeMNISTCausalIndAttribute 数据集执行上述评估。需要向 CACM 算法提供额外的 attr_types 以处理多种偏移。我们目前支持数据中的 因果混杂独立选定 分布偏移。

MNISTIndAttribute:单属性 独立 偏移#

[17]:
from dowhy.causal_prediction.datasets.mnist import MNISTIndAttribute

data_dir = 'data'
dataset = MNISTIndAttribute(data_dir)
[18]:
algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['ind'], lambda_ind=10., E_eq_A=[0])

MNISTCausalIndAttribute:多属性 因果+独立 偏移#

[19]:
from dowhy.causal_prediction.datasets.mnist import MNISTCausalIndAttribute

data_dir = 'data'
dataset = MNISTCausalIndAttribute(data_dir)
[20]:
# `attr_types` should be ordered consistent with the attribute order in dataset class
algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['causal', 'ind'], lambda_causal=100., lambda_ind=10., E_eq_A=[1])

额外的数据集和算法#

我们提供了使用 ERM 和 CACM 算法在 MNIST 上的演示。可以将评估扩展到新的数据集和算法进行评估。

可以将新数据集添加到 dowhy.causal_prediction.datasets 并在此导入,就像我们对 MNIST 所做的那样。我们在 dowhy.causal_prediction.datasets.mnist 中提供了 MNIST 数据集(及其变体)的描述,这将有助于创建新的数据集类。我们目前支持数据中的 因果混杂独立选定 分布偏移。

我们在 dowhy.causal_prediction.algorithms 中实现了 ERM 作为基准。可以通过重写基类 PredictionAlgorithm 中的 training_step 函数来添加其他算法。

参考文献#

[1] Kaur, J.N., Kıcıman, E., & Sharma, A. (2022). Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization. ArXiv, abs/2206.07837.

[2] Ghifary, M., Kleijn, W., Zhang, M., & Balduzzi, D. (2015). Domain Generalization for Object Recognition with Multi-task Autoencoders. 2015 IEEE International Conference on Computer Vision (ICCV), 2551-2559.

[3] Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2019). Invariant Risk Minimization. ArXiv, abs/1907.02893.