dowhy.causal_prediction.dataloaders 包#

子模块#

dowhy.causal_prediction.dataloaders.fast_data_loader 模块#

class dowhy.causal_prediction.dataloaders.fast_data_loader.FastDataLoader(dataset, batch_size, num_workers)[源码]#

基类: object

DataLoader 包装器,通过不在每个 epoch 重新生成工作进程来稍微提高速度。

class dowhy.causal_prediction.dataloaders.fast_data_loader.InfiniteDataLoader(dataset, weights, batch_size, num_workers)[源码]#

基类: object

dowhy.causal_prediction.dataloaders.get_data_loader 模块#

dowhy.causal_prediction.dataloaders.get_data_loader.get_eval_loader(dataset, envs, batch_size, class_balanced=False)[源码]#

返回评估数据集加载器(测试/验证)。

参数:
  • dataset – 包含环境列表的数据集类

  • envs – 包含数据集中验证/测试域索引的列表

  • batch_size – 数据集加载器要使用的批大小的值

  • class_balanced – 布尔标志,指示是否在类别之间进行平衡采样

返回:

数据集加载器列表

dowhy.causal_prediction.dataloaders.get_data_loader.get_loaders(dataset, train_envs, batch_size, val_envs=None, test_envs=None, class_balanced=False, holdout_fraction=0.2, trial_seed=0)[源码]#

返回训练、验证和测试数据集加载器。

参数:
  • dataset – 包含环境列表的数据集类

  • train_envs – 包含数据集中训练域索引的列表

  • batch_size – 数据集加载器要使用的批大小的值

  • val_envs – 包含数据集中验证域索引的列表。如果为 None,则使用训练数据的一部分 (holdout_fraction) 来创建验证集。

  • test_envs – 包含数据集中测试域索引的列表

  • class_balanced – 布尔标志,指示是否在类别之间进行平衡采样

  • holdout_fraction – 用于创建验证域的训练数据比例。当 val_envs 为 None 时使用。

  • trial_seed – 从训练数据生成验证拆分时使用的种子。当 val_envs 为 None 时使用。

返回:

包含数据集加载器列表的字典,格式为 {'train_loaders': [train_dataloader_1, train_dataloader_2, ....],

'val_loaders': [val_dataloader_1, val_dataloader_2, ....], 'test_loaders': [test_dataloader_1, test_dataloader_2, ....]。

}

dowhy.causal_prediction.dataloaders.get_data_loader.get_train_eval_loader(dataset, envs, batch_size, class_balanced, holdout_fraction, trial_seed)[源码]#

返回训练和验证数据集加载器。

参数:
  • dataset – 包含环境列表的数据集类

  • envs – 包含数据集中训练域索引的列表

  • batch_size – 数据集加载器要使用的批大小的值

  • class_balanced – 布尔标志,指示是否在类别之间进行平衡采样

  • holdout_fraction – 用于创建验证域的训练数据比例

  • trial_seed – 从训练数据生成验证拆分时使用的种子

返回:

分别用于训练 (train_loaders) 和验证 (val_loaders) 的两个数据集加载器列表

dowhy.causal_prediction.dataloaders.get_data_loader.get_train_loader(dataset, envs, batch_size, class_balanced=False)[源码]#

返回训练数据集加载器。

参数:
  • dataset – 包含环境列表的数据集类

  • envs – 包含数据集中训练域索引的列表

  • batch_size – 数据集加载器要使用的批大小的值

  • class_balanced – 布尔标志,指示是否在类别之间进行平衡采样

返回:

数据集加载器列表

dowhy.causal_prediction.dataloaders.misc 模块#

杂项辅助函数

dowhy.causal_prediction.dataloaders.misc.make_weights_for_balanced_classes(dataset)[源码]#
dowhy.causal_prediction.dataloaders.misc.seed_hash(*args)[源码]#

从所有参数派生一个整数哈希值,用作随机种子。

dowhy.causal_prediction.dataloaders.misc.split_dataset(dataset, n, seed=0)[源码]#

返回与给定数据集随机拆分相对应的两对数据集,第一个数据集包含 n 个数据点,其余数据点在最后一个数据集中,使用给定的随机种子。

模块内容#