关于pytorch里DataLoader的理解
创始人
2024-01-31 22:47:57
0

目录

一、python迭代器生成器基础讲解

1.1可迭代对象Iterable

1.2迭代器Iterator

1.3for in 的本质流程

1.4 getitem

1.5 yield 生成器

二、DataLoader的基础实现

三、整体框架的讲解


一、python迭代器生成器基础讲解

1.1可迭代对象Iterable

表示该对象可迭代,并不一定是一个数据类型,如字典,字符串,列表等,它也可以是一个实现了__iter__方法的类。

from collections.abc import Iterable, Iteratorclass A(object):def __init__(self):self.a = [1, 2, 3]def __iter__(self):# 此处返回啥无所谓return self.acls_a = A()
#  True
print(isinstance(cls_a, Iterable))

如果对象是Iterable,依然无法用for循环遍历,因为Iterable仅仅是提供了一种抽象规范接口。

1.2迭代器Iterator

如果一个对象是迭代器,那么它肯定是可迭代的,但是如果一个对象是可迭代的,它不一定是迭代器。实现了 __next__ 和 __iter__ 方法的类才能称为迭代器,就可以被 for 遍历了。

class A(object):def __init__(self):self.index = -1self.a = [1, 2, 3]# 必须要返回一个实现了 __next__ 方法的对象,否则后面无法 for 遍历# 因为本类自身实现了 __next__,所以通常都是返回 self 对象即可def __iter__(self):return selfdef __next__(self):self.index += 1if self.index < len(self.a):return self.a[self.index]else:# 抛异常,for 内部会自动捕获,表示迭代完成raise StopIteration("遍历完了")cls_a = A()
print(isinstance(cls_a, Iterable)) # True
print(isinstance(cls_a, Iterator)) # True
print(isinstance(iter(cls_a), Iterator)) # Truefor a in cls_a:print(a)
# 打印 1 2 3

1.3for in 的本质流程

for.....in...被python编译器编译后,如下

# 实际调用了 __iter__ 方法返回自身,包括了 __next__ 方法的对象
cls_a = iter(cls_a)
while True:try:# 然后调用对象的 __next__ 方法,不断返回元素value = next(cls_a)print(value)# 如果迭代完成,则捕获异常即可except StopIteration:break

可见,任何一个对象要能被for遍历,必须实现__iter__和__next__两个方法。

list是可迭代对象,但是没next方法,为什么可以实现for循环遍历。list内部的iter方法的内部实现了next方法。

所以得到:一个对象要能够被 for .. in .. 迭代,那么不管你是直接实现 __iter__ 和 __next__ 方法(对象必然是 Iterator),还是只实现 __iter__(不是 Iterator),但是内部间接返回了具备 __next__ 对象的类,都是可行的

1.4 getitem

上面说过for in本质就是调用__iter__和__next__方法,实际上还有一种更简单的方法,__getitem__方法就可以让对象实现迭代功能。实际上任何一个类,只要实现了__getitem__方法,那么当调用iter(类实例)时候会自动具备__iter__和__next__方法。__getitem__ 实际上是属于 iternext方法的高级封装,也就是我们常说的语法糖,只不过这个转化是通过编译器完成,内部自动转化,非常方便。

class A(object):def __init__(self):self.a = [1, 2, 3]def __getitem__(self, item):return self.a[item]cls_a = A()
print(isinstance(cls_a, Iterable))  # False
print(isinstance(cls_a, Iterator))  # False
print(dir(cls_a))  # 仅仅具备 __getitem__ 方法cls_a = iter(cls_a)
print(dir(cls_a))  # 具备 __iter__ 和 __next__ 方法print(isinstance(cls_a, Iterable))  # True
print(isinstance(cls_a, Iterator))  # True# 等价于 for .. in ..
while True:try:# 然后调用对象的 __next__ 方法,不断返回元素value = next(cls_a)print(value)# 如果迭代完成,则捕获异常即可except StopIteration:break# 输出: 1 2 3

如果你想该对象具备 list 等对象一样的长度属性,则只需要实现 __len__ 方法即可。

此时我们已经知道了第一种高级语法糖实现迭代器功能,下面分析另一个更简单的可以直接作用于函数的语法糖。

1.5 yield 生成器

生成器是一个在行为上和迭代器非常类似的对象,两者功能差不多,但生成器更优雅,只需要用关键字yield来返回。作用于函数上叫生成器函数,调用函数返回一个生成器。

def func():for a in [1, 2, 3]:yield acls_g = func()
print(isinstance(cls_g, Iterator))  # True
print(dir(cls_g))  # 自动具备 __iter__ 和 __next__ 方法for a in cls_g:print(a)# 输出: 1 2 3# 一种更简单的写法是用 ()
cls_g = (i for i in [1,2,3])

使用 yield 函数与使用 return 函数,在执行时差别在于:包含 yield 的方法一般用于迭代,每次执行时遇到 yield 就返回 yield 后的结果,但内部会保留上次执行的状态,下次继续迭代时,会继续执行 yield 之后的代码,直到再次遇到 yield 后返回。生成器是懒加载模式,特别适合解决内存占用大的集合问题。

总结:在迭代对象基础上,如果实现了 __next__ 方法则是迭代器对象,该对象在调用 next() 的时             候返回下一个值,如果容器中没有更多元素了,则抛出 StopIteration 异常。

           对于采用语法糖 __getitem__ 实现的迭代器对象,其本身实例既不是可迭代对象,更不是               迭代器,但是其可以被 for in 迭代,原因是对该对象采用 iter(类实例) 操作后就会自动变成             迭代器。

          生成器是一种特殊迭代器,但是不需要像迭代器一样实现__iter____next__方法,只需要            使用关键字 yield 就可以,生成器的构造可以通过生成器表达式 (),或者对函数返回值加入            yield 关键字实现。

          对于在类的 __iter__ 方法中采用语法糖 yield 实现的迭代器对象,其本身实例是可迭代对              象,但不是迭代器,但是其可以被 for .. in .. 迭代,原因是对该对象采用 iter(类实例) 操作后            就会自动变成迭代器。

二、DataLoader的基础实现

首先介绍5个基本的对象:

Dataset提供整个数据集的随机访问功能,每次访问都返回单个对象,例如一个对象和一个target。

Sampler提供整个数据集随机访问的索引列表,每次调用都返回所有列表中的单个索引。常用的子类是SequentialSampler 用于提供顺序输出的索引 和 RandomSampler 用于提供随机输出的索引

BatchSampler内部调用Sampler实列,输出指定batch_size个索引,然后将索引作用于Dataset上从而输出batch_size个数据对象,例如batch_size个数据和索引。

Collate_fn用于将batch_size个数据对象在batch维度进行聚合,生成(batch,.....)格式的数据输出。如果待聚合对象是numpy,则自动转化为tensor,此时就可以输入到网络中了。

迭代一次伪代码如下(非迭代器版本)

class DataLoader(object):def __init__(self):#假设数据长度为100,batch_size是4self.dataset=[[img0,target0],[img1,target1],.....[img99,target99]]self.sampler=[0,1,2,.....,99]self.batch_size=4self.index=0def collate_fn(self,data):#在batch维度聚合数据batch_img=torch.Stack(data[0],0)batch_target=torch.stack(data[1],0)return batch_img,batch_targetdef __next__(self):i=0batch_index=[]while i

以上就是最抽象的 DataLoader 运行流程以及和 Dataset、Sampler、BatchSampler、collate_fn 的关系。

首先需要强调的是 Dataset、Sampler、BatchSampler 和 DataLoader 都直接或间接实现了迭代器。

Dataset通过__getitem__方法使其可迭代

Sample对象是一个可迭代的基类对象,其常用子类 SequentialSampler 在 __iter__ 内部返回迭代器,RandomSampler 在 __iter__ 内部通过 yield 关键字返回迭代器

Batchsampler也是在__iter__内部通过yield关键字返回迭代器

DataLoader通过__iter__和__next__直接实现迭代器

除了DataLoader本身是迭代器外,其余对象本身都不是迭代器,但可以for in迭代

由于 DataLoader 类写的非常通用,故 Dataset、Sampler、BatchSampler 都可以外部传入,除了 Dataset 必须输入外,其余两个类都有默认实现,最典型的 Sampler 就是 SequentialSampler 和 RandomSampler。

需要注意的是 Sampler 对象其实在大部分时候都不需要传入 Dataset 实例对象,因为其功能仅仅是返回索引而已,并没有直接接触数据。

三、整体框架的讲解

核心运行逻辑:

def __next__(self):#返回batch个索引index=next(self.batch_sampler)#利用索引去取数据data=[self.dataset[idx] for idx in index]#batch维度聚合data=self.collate_fn(data)return data

整体流程:

1.self.batch_sampler=iter(batch_sampler)。在DataLoader的类初始化,需要得到BatchSampler的迭代器对象。

2.index=next(self.batch_sampler)。对于每次迭代,DataLoader对象首先会调用BatchSampler的迭代器进行下一次迭代,具体是调用BatchSampler对象的__iter__方法

3.而BatchSampler对象的__iter__方法实际上是需要依靠Sampler对象进行迭代输出索引,Sampler对象也是一个迭代器,当迭代batch_size次后就可以得到batch_size个数据索引。

4.data=[self.dataset[idx] for idx in index]。有了batch个索引就可以通过不断调用dataset的__getitem__方法返回数据对象,此时data就包含了batch个对象。

5.data=self.collate_fn(data)。将batch个对象输入给聚合函数,在第0个维度也就是batch维度进行聚合,得到类似(batch,....)的对象。

6.重复上面的操作,就可以不断输出一个一个的batch数据

class Dataset(object):#只要实现了__getitem__方法就可以变成迭代器def __getitem__(self,index):raise NotImplementedErrordef __len__(self):raise NotImplementedError
class Sampler(object):def __init__(self,data_source):passdef __iter__(self):raise NotImplementedErrordef __len__(self):raise NotImplementedError
#一般出现raise NotImplementedError这个错误,就是子类没有重写父类中的成员函数,然后子类对象调用此函数会报这个错误class SequentialSampler(sampler):def __init__(self,data_source):super(SequentialSampler,self).__init__(data_source)self.data_source=data_sourcedef __iter__(self):#返回迭代器,不然无法for  inreturn iter(range(len(self.data_source))def __len__(self):return len(self.data_source)class BatchSampler(Sampler):def __init__(self,sampler,batch_size,drop_last):self.sampler=samplerself.batch_size=batch_sizeself.dorp_last=drop_lastdef __iter__(self):batch=[]for idx in self.sampler:batch.append(idx)#如果得到了batch个索引,则可以通过yield关键字生成生成器返回,得到迭代器对象if len(batch)==self.batch_size:yield batchbatch=[]if len(batch)>0 and not self.drop_last:yield batchdef __len__(self):if self.drop_last:#如果最后的索引数不等于一个batch,抛弃return len(self.sampler)//self.batch_sizeelse:return (len(self.sampler)+self.batch_size-1)//self.batch_size
class DataLoader(object):def __init__(self,dataset,batch_size=1,shuffle=False,sample=None,batch_sampler=None,collate_fn=None,drop_last=False):self.dataset=dataset#因为这两个功能是冲突的if sampler is not None and shuffle:raise ValueError('sampler option is ..')if batch_sampler is not None:# 一旦设置了 batch_sampler,那么 batch_size、shuffle、sampler# 和 drop_last 四个参数就不能传入# 因为这4个参数功能和 batch_sampler 功能冲突了if batch_size != 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and ''drop_last')batch_size = Nonedrop_last = Falseif sampler is None:if shuffle:sampler = RandomSampler(dataset)else:sampler = SequentialSampler(dataset)# 也就是说 batch_sampler 必须要存在,你如果没有设置,那么采用默认类if batch_sampler is None:batch_sampler = BatchSampler(sampler, batch_size, drop_last)self.batch_size = batch_sizeself.drop_last = drop_lastself.sampler = samplerself.batch_sampler = iter(batch_sampler)if collate_fn is None:collate_fn = default_collateself.collate_fn = collate_fn#核心代码def __next__(self):index=next(self.batch_sampler)data=[self.dataset[idx] for idx in index]data=self.collate_fn(data)return data#返回自身,因为自身实现了nextdef __iter__(self):return self
def default_collate(batch):elem=batch[0]elem_type=type(elem)if isinstance(elem,torch.Tensor):return torch.stack(batch,0)elif elem_type.__module__=='numpy':return default_collate([torch.as_tensor(b) for b in batch])else:raise NotImplementedError

完整调用例子

class Simplev1Dataset(Dataset):def __init__(self):#伪造数据self.imgs=np.arange(0,16).reshape(8,2)def __getitem__(self,index):return self.imgs[index]def __len__(self):return self.imgs.shape[0]from simplev1_dataset import Simplev1Dataset
simple_dataset=Simplev1Dataset()
dataloader=DataLoader(simple_dataset,batch_size=2,collate_fn=default_collate)
for data in dataloader:print(data)

四、Reference

https://zhuanlan.zhihu.com/p/340465632

相关内容

热门资讯

AWSECS:访问外部网络时出... 如果您在AWS ECS中部署了应用程序,并且该应用程序需要访问外部网络,但是无法正常访问,可能是因为...
AWSElasticBeans... 在Dockerfile中手动配置nginx反向代理。例如,在Dockerfile中添加以下代码:FR...
银河麒麟V10SP1高级服务器... 银河麒麟高级服务器操作系统简介: 银河麒麟高级服务器操作系统V10是针对企业级关键业务...
北信源内网安全管理卸载 北信源内网安全管理是一款网络安全管理软件,主要用于保护内网安全。在日常使用过程中,卸载该软件是一种常...
AWR报告解读 WORKLOAD REPOSITORY PDB report (PDB snapshots) AW...
AWS管理控制台菜单和权限 要在AWS管理控制台中创建菜单和权限,您可以使用AWS Identity and Access Ma...
​ToDesk 远程工具安装及... 目录 前言 ToDesk 优势 ToDesk 下载安装 ToDesk 功能展示 文件传输 设备链接 ...
群晖外网访问终极解决方法:IP... 写在前面的话 受够了群晖的quickconnet的小水管了,急需一个新的解决方法&#x...
不能访问光猫的的管理页面 光猫是现代家庭宽带网络的重要组成部分,它可以提供高速稳定的网络连接。但是,有时候我们会遇到不能访问光...
Azure构建流程(Power... 这可能是由于配置错误导致的问题。请检查构建流程任务中的“发布构建制品”步骤,确保正确配置了“Arti...