dowhy.causal_prediction.dataloaders 包#
子模块#
dowhy.causal_prediction.dataloaders.fast_data_loader 模块#
dowhy.causal_prediction.dataloaders.get_data_loader 模块#
- dowhy.causal_prediction.dataloaders.get_data_loader.get_eval_loader(dataset, envs, batch_size, class_balanced=False)[源码]#
返回评估数据集加载器(测试/验证)。
- 参数:
dataset – 包含环境列表的数据集类
envs – 包含数据集中验证/测试域索引的列表
batch_size – 数据集加载器要使用的批大小的值
class_balanced – 布尔标志,指示是否在类别之间进行平衡采样
- 返回:
数据集加载器列表
- dowhy.causal_prediction.dataloaders.get_data_loader.get_loaders(dataset, train_envs, batch_size, val_envs=None, test_envs=None, class_balanced=False, holdout_fraction=0.2, trial_seed=0)[源码]#
返回训练、验证和测试数据集加载器。
- 参数:
dataset – 包含环境列表的数据集类
train_envs – 包含数据集中训练域索引的列表
batch_size – 数据集加载器要使用的批大小的值
val_envs – 包含数据集中验证域索引的列表。如果为 None,则使用训练数据的一部分 (holdout_fraction) 来创建验证集。
test_envs – 包含数据集中测试域索引的列表
class_balanced – 布尔标志,指示是否在类别之间进行平衡采样
holdout_fraction – 用于创建验证域的训练数据比例。当 val_envs 为 None 时使用。
trial_seed – 从训练数据生成验证拆分时使用的种子。当 val_envs 为 None 时使用。
- 返回:
包含数据集加载器列表的字典,格式为 {'train_loaders': [train_dataloader_1, train_dataloader_2, ....],
'val_loaders': [val_dataloader_1, val_dataloader_2, ....], 'test_loaders': [test_dataloader_1, test_dataloader_2, ....]。
}
- dowhy.causal_prediction.dataloaders.get_data_loader.get_train_eval_loader(dataset, envs, batch_size, class_balanced, holdout_fraction, trial_seed)[源码]#
返回训练和验证数据集加载器。
- 参数:
dataset – 包含环境列表的数据集类
envs – 包含数据集中训练域索引的列表
batch_size – 数据集加载器要使用的批大小的值
class_balanced – 布尔标志,指示是否在类别之间进行平衡采样
holdout_fraction – 用于创建验证域的训练数据比例
trial_seed – 从训练数据生成验证拆分时使用的种子
- 返回:
分别用于训练 (train_loaders) 和验证 (val_loaders) 的两个数据集加载器列表
dowhy.causal_prediction.dataloaders.misc 模块#
杂项辅助函数