NLP学习之:BERT 模型复现(4)模型实现
创始人
2024-01-13 04:52:26
0

文章目录

  • 原理
    • 数据集输入
    • Embedding
    • Attention 机制
    • 模型结构
      • 主体结构
      • 激活函数
      • BERT 网络代码

原理

数据集输入

  • 每次两个句子,用 [CLS], [SEP] 分隔,
  • 这两个句子中的每个词都有 15% 的概率被 [MASK]
    • 被选中 MASK 的词中 80% 真的用 [MASK] 符号替换原词
    • 10% 换成其他随机词(引入噪声)
    • 剩余 10% 啥也不干,虚晃一枪

Embedding

  • 需要 embedding 的有:
    • token : 将原本的 tokenembedding 规定的空间维度中表示(例如 embedding 的空间是 768 维空间)
    • position:将位置信息在 embedding 的空间维度中表示
    • segment:将当前词属于第一句话还是第二句话的信息在 embedding 的空间中表示

Attention 机制

  • 单个 attention 是通过 Q,K,VQ, K, VQ,K,V 三个矩阵计算出对于一个单词,其他单词与他的相关程度,并把这些相关关系编码到最终的 QKTVQ K^{T}VQKTV 输出的矩阵中
  • 多头注意力机制就是将上述的 attention 重复了 nheadn_{head}nhead​ 次,这样映射到多个空间中编码词之间的相关关系会更加充分地利用输入信息

模型结构

主体结构

  • 就是 Transformer 网络的编码端;所以借鉴了 Transformer 的基础 block 结构

class LayerNorm(nn.Module):"""Construct a layernorm module (See citation for details).Layer 标准化"""def __init__(self, features, eps=1e-6):super(LayerNorm, self).__init__()self.a_2 = nn.Parameter(torch.ones(features))self.b_2 = nn.Parameter(torch.zeros(features))self.eps = epsdef forward(self, x):mean = x.mean(-1, keepdim=True)std = x.std(-1, keepdim=True)return self.a_2 * (x - mean) / (std + self.eps) + self.b_2class SublayerConnection(nn.Module):"""A residual connection followed by a layer norm.Note for code simplicity the norm is first as opposed to last."""def __init__(self, size, dropout):super(SublayerConnection, self).__init__()self.norm = LayerNorm(size)self.dropout = nn.Dropout(dropout)def forward(self, x, sublayer):"Apply residual connection to any sublayer with the same size."return x + self.dropout(sublayer(self.norm(x)))class PositionwiseFeedForward(nn.Module):"Implements FFN equation."def __init__(self, d_model, d_ff, dropout=0.1):""":param d_model: 词向量的维度:param d_ff::param dropout:"""super(PositionwiseFeedForward, self).__init__()self.w_1 = nn.Linear(d_model, d_ff)self.w_2 = nn.Linear(d_ff, d_model)self.dropout = nn.Dropout(dropout)self.activation = GELU()def forward(self, x):return self.w_2(self.dropout(self.activation(self.w_1(x))))class TransformerBlock(nn.Module):"""Bidirectional Encoder = Transformer (self-attention)Transformer = MultiHead_Attention + Feed_Forward with sublayer connection"""def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):""":param hidden: hidden size of transformer:param attn_heads: head sizes of multi-head attention:param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size:param dropout: dropout rate"""super(TransformerBlock, self).__init__()self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)self.dropout = nn.Dropout(p=dropout)def forward(self, x, mask):x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, mask=mask))x = self.output_sublayer(x, self.feed_forward)return self.dropout(x)

激活函数

  • 使用了作者提出的 GELU 激活函数

class GELU(nn.Module):"""Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU在论文的 3.4 节中,作者重写设计了 GELU 激活函数来代替 RELU"""def forward(self, x):return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

BERT 网络代码


class BERT(nn.Module):"""BERT model : Bidirectional Encoder Representations from Transformers."""def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):""":param vocab_size: vocab_size of total words:param hidden: BERT model hidden size:param n_layers: numbers of Transformer blocks(layers):param attn_heads: number of attention heads:param dropout: dropout rate"""super(BERT, self).__init__()self.hidden = hiddenself.n_layers = n_layersself.attn_heads = attn_heads# paper noted they used 4*hidden_size for ff_network_hidden_sizeself.feed_forward_hidden = hidden * 4# embedding for BERT, sum of positional, segment, token embeddingsself.embedding = BERTEmbedding(vocab_size=vocab_size, d_model=hidden)# multi-layers transformer blocks, deep networkself.transformer_blocks = nn.ModuleList([TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])def forward(self, x, segment_info):# attention masking for padded token# torch.ByteTensor([batch_size, 1, seq_len, seq_len)mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)# embedding the indexed sequence to sequence of vectorsx = self.embedding(x, segment_info)# running over multiple transformer blocksfor transformer in self.transformer_blocks:x = transformer.forward(x, mask)return x

相关内容

热门资讯

保存时出现了1个错误,导致这篇... 当保存文章时出现错误时,可以通过以下步骤解决问题:查看错误信息:查看错误提示信息可以帮助我们了解具体...
汇川伺服电机位置控制模式参数配... 1. 基本控制参数设置 1)设置位置控制模式   2)绝对值位置线性模...
不能访问光猫的的管理页面 光猫是现代家庭宽带网络的重要组成部分,它可以提供高速稳定的网络连接。但是,有时候我们会遇到不能访问光...
不一致的条件格式 要解决不一致的条件格式问题,可以按照以下步骤进行:确定条件格式的规则:首先,需要明确条件格式的规则是...
本地主机上的图像未显示 问题描述:在本地主机上显示图像时,图像未能正常显示。解决方法:以下是一些可能的解决方法,具体取决于问...
表格列调整大小出现问题 问题描述:表格列调整大小出现问题,无法正常调整列宽。解决方法:检查表格的布局方式是否正确。确保表格使...
表格中数据未显示 当表格中的数据未显示时,可能是由于以下几个原因导致的:HTML代码问题:检查表格的HTML代码是否正...
Android|无法访问或保存... 这个问题可能是由于权限设置不正确导致的。您需要在应用程序清单文件中添加以下代码来请求适当的权限:此外...
【NI Multisim 14...   目录 序言 一、工具栏 🍊1.“标准”工具栏 🍊 2.视图工具...
银河麒麟V10SP1高级服务器... 银河麒麟高级服务器操作系统简介: 银河麒麟高级服务器操作系统V10是针对企业级关键业务...