首页 新闻 会员 周边

PyTorch使用Cifar-10数据集时报错:TypeError: default_collate

0
[已解决问题] 解决于 2024-04-09 09:58

报错

TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>

相关代码

from torch.utils.data import DataLoader
import torchvision
import torch
from sampler import *


torch.manual_seed(0)
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=32, sampler=ImbalancedDatasetSampler(train_dataset))


for i, (data, target) in enumerate(train_dataloader):
    print(target)
    if i == 5:
        break

运行环境

pytorch                   2.2.2           py3.12_cuda12.1_cudnn8_0    pytorch
torchvision               0.17.2                   pypi_0    pypi
zh-jp的主页 zh-jp | 菜鸟二级 | 园豆:226
提问于:2024-04-09 09:56
< >
分享
最佳答案
0

cifar10数据集读入的图片没有转为张量导致的,添加将图片转为张量的模块即可:

from torch.utils.data import DataLoader
import torchvision
from sampler import *
from torchvision import transforms
transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

torch.manual_seed(0)
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_dataloader = DataLoader(train_dataset, batch_size=32, sampler=ImbalancedDatasetSampler(train_dataset))


for i, (data, target) in enumerate(train_dataloader):
    print(target)
    if i == 5:
        break
zh-jp | 菜鸟二级 |园豆:226 | 2024-04-09 09:58
清除回答草稿
   您需要登录以后才能回答,未注册用户请先注册