dowhy.causal_prediction.models 包#

子模块#

dowhy.causal_prediction.models.networks 模块#

MNIST_MLP 架构借鉴自 OoD-Bench
@inproceedings{ye2022ood,

title={OoD-Bench: Quantifying and Understanding Two Dimensions of Out-of-Distribution Generalization}, author={Ye, Nanyang and Li, Kaican and Bai, Haoyue and Yu, Runpeng and Hong, Lanqing and Zhou, Fengwei and Li, Zhenguo and Zhu, Jun}, booktitle={CVPR}, year={2022}

}

dowhy.causal_prediction.models.networks.Classifier(in_features, out_features, is_nonlinear=False)[source]#
class dowhy.causal_prediction.models.networks.ContextNet(input_shape)[source]#

基类: Module

初始化内部 Module 状态,由 nn.Module 和 ScriptModule 共享。

forward(x)[source]#

定义每次调用时执行的计算。

应被所有子类重写。

注意

尽管前向传播的逻辑需要在此函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者会处理注册的钩子,而后者则会默默忽略它们。

training: bool#
class dowhy.causal_prediction.models.networks.Identity[source]#

基类: Module

一个恒等层

初始化内部 Module 状态,由 nn.Module 和 ScriptModule 共享。

forward(x)[source]#

定义每次调用时执行的计算。

应被所有子类重写。

注意

尽管前向传播的逻辑需要在此函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者会处理注册的钩子,而后者则会默默忽略它们。

training: bool#
class dowhy.causal_prediction.models.networks.MLP(n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout)[source]#

基类: Module

一个简单的 MLP

初始化内部 Module 状态,由 nn.Module 和 ScriptModule 共享。

forward(x)[source]#

定义每次调用时执行的计算。

应被所有子类重写。

注意

尽管前向传播的逻辑需要在此函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者会处理注册的钩子,而后者则会默默忽略它们。

training: bool#
class dowhy.causal_prediction.models.networks.MNIST_CNN(input_shape)[source]#

基类: Module

为 MNIST 手工调优的架构。到目前为止使用此架构注意到的一些奇怪之处: - 在特征的均值池化后添加线性层会严重损害

RotatedMNIST-100 的泛化能力。

初始化内部 Module 状态,由 nn.Module 和 ScriptModule 共享。

forward(x)[source]#

定义每次调用时执行的计算。

应被所有子类重写。

注意

尽管前向传播的逻辑需要在此函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者会处理注册的钩子,而后者则会默默忽略它们。

n_outputs = 128#
training: bool#
class dowhy.causal_prediction.models.networks.MNIST_MLP(input_shape)[source]#

基类: Module

初始化内部 Module 状态,由 nn.Module 和 ScriptModule 共享。

forward(x)[source]#

定义每次调用时执行的计算。

应被所有子类重写。

注意

尽管前向传播的逻辑需要在此函数中定义,但之后应调用 Module 实例而不是直接调用此函数,因为前者会处理注册的钩子,而后者则会默默忽略它们。

training: bool#
class dowhy.causal_prediction.models.networks.ResNet(input_shape, resnet18=True, resnet_dropout=0.0)[source]#

基类: Module

ResNet 移除了 softmax 并冻结了 batchnorm

初始化内部 Module 状态,由 nn.Module 和 ScriptModule 共享。

forward(x)[source]#

将 x 编码为大小为 n_outputs 的特征向量。

freeze_bn()[source]#
train(mode=True)[source]#

覆盖默认的 train() 方法以冻结 BN 参数

training: bool#

模块内容#