nn.TransformerEncoderLayer中的src_mask,src_key_padding_mask解析
创始人
2024-06-03 08:08:33
0

注意,不同版本的pytorch,对nn.TransformerEncdoerLayer部分代码差别很大,比如1.8.0版本中没有batch_first参数,而1.10.1版本中就增加了这个参数,笔者这里使用pytorch1.10.1版本实验。

attention mask

要搞清楚src_mask和src_key_padding_mask的区别,关键在于搞清楚在self-attention中attention mask的作用是啥。
attetnionscore=softmax(QKTdk)Vattetnion \ score = softmax({QK^{T} \over \sqrt d_{k} })V attetnion score=softmax(d​k​QKT​)V
上式中,并没有体现出pad的token,认为所有token都是有用的,但是实际写代码时使用batch进行训练,所以要将所有token序列pad到相同的长度。
attention mask的作用就是,在计算注意力分数的时候,告诉模型,哪些token是pad的,不应该分配注意力分数。

针对一条长度为LLL的token序列,其attention mask的矩阵应该是L∗LL*LL∗L,下图是一个attention mask,蓝色的表示不是pad的token,灰色的表示pad的token。

在这里插入图片描述
但是针对attention mask中蓝色位置和灰色位置中的值,目前有两种做法:

  • 在huggingface的transformers中实现是,将蓝色位置填1 ,灰色位置填0,也就是1表示真实序列,不需要被mask,而0表示pad序列,需要被mask。但是为了用户操作,huggingface并没有要求用户输入一个B∗L∗LB*L*LB∗L∗L的mask矩阵,而是输入B∗LB*LB∗L的矩阵即可,然后在forward函数中使用get_extended_attention_mask方法将其扩展为B∗L∗LB*L*LB∗L∗L的mask矩阵。
  • 在pytorch的transformers中的实现是,蓝色的位置填0,灰色的位置填float(“-inf”),但是在实现时,又分为了src_mask和src_key_padding_mask,而最终的attention mask矩阵,是通过这个两个矩阵得到的。
    其中:

src_mask: 必须是2D或者3D的矩阵,形状为[L,S][L,S][L,S]或者[B∗num_heads,L,S][B*num\_heads, L, S][B∗num_heads,L,S],LLL是目标序列长度,SSS是源序列长度(只有涉及到机器翻译这种encoder-decoder框架目标序列和源序列才有意义,如果只是用transformer encoder做编码,则L=SL=SL=S),BBB是batch size,numheadnum\ headnum head表示头数。另外src_mask的取值有三种,

  1. 可以是binary mask,True的位置表示需要被mask,
  2. 可以是byte mask,非零的位置表示需要被mask,
  3. 可以float mask,这时float(“-inf”)的位置需要被mask。

src_key_padding_mask:是一个2D的矩阵,形状为[B,S][B, S][B,S],取值有两种,

  1. 可以是binary mask,True的位置表示key矩阵需要被mask,
  2. 可以是byte mask,非零的位置表示key矩阵需要被mask,

这里的key矩阵应该也是为了涵盖encoder-decoder这样的情况,对于只用transformer encoder的情况,src_key_padding_mask则更像是huggingface 中的attention mask。

其实在pytorch官方代码中,是通过src_mask和src_key_padding_mask二者综合得到最终的attention_mask。对于绝大多数情况,我们只需要使用src_key_padding_mask即可。

相关内容

热门资讯

保存时出现了1个错误,导致这篇... 当保存文章时出现错误时,可以通过以下步骤解决问题:查看错误信息:查看错误提示信息可以帮助我们了解具体...
汇川伺服电机位置控制模式参数配... 1. 基本控制参数设置 1)设置位置控制模式   2)绝对值位置线性模...
不能访问光猫的的管理页面 光猫是现代家庭宽带网络的重要组成部分,它可以提供高速稳定的网络连接。但是,有时候我们会遇到不能访问光...
表格中数据未显示 当表格中的数据未显示时,可能是由于以下几个原因导致的:HTML代码问题:检查表格的HTML代码是否正...
本地主机上的图像未显示 问题描述:在本地主机上显示图像时,图像未能正常显示。解决方法:以下是一些可能的解决方法,具体取决于问...
不一致的条件格式 要解决不一致的条件格式问题,可以按照以下步骤进行:确定条件格式的规则:首先,需要明确条件格式的规则是...
表格列调整大小出现问题 问题描述:表格列调整大小出现问题,无法正常调整列宽。解决方法:检查表格的布局方式是否正确。确保表格使...
Android|无法访问或保存... 这个问题可能是由于权限设置不正确导致的。您需要在应用程序清单文件中添加以下代码来请求适当的权限:此外...
【NI Multisim 14...   目录 序言 一、工具栏 🍊1.“标准”工具栏 🍊 2.视图工具...
银河麒麟V10SP1高级服务器... 银河麒麟高级服务器操作系统简介: 银河麒麟高级服务器操作系统V10是针对企业级关键业务...