0%

Word Embedding的参数推导和直观理解_DL体会总结_3

最近在看NLP,其中的基础就是word embedding,我也看了cs224n的课,也看了那个关于word2vec的论文,但我看到的就是优化这么一个目标函数

$$ \log{\sigma(v_{w_{O}}^{'T}v_{w_{I}})} + \sum_{i=1}^{n}E_{w_{i}\sim P_{n}(w)}[\log\sigma(-v_{w_{i}}^{'T}v_{w_{I}})] $$

这是什么?原论文基本没有写细节。我特别想知道,这样的函数怎么做back propagation。于是总算找到这个:《word2vec Parameter Learning Explained》。人家确认牛,不仅深入浅出的给出数学推导,还能给出一个直观的,说人话的解释,让人更加能够明白word2vec到底在干什么。我想尝试整理一下,看看都学到了哪些,所以总结如下:

1, Forward过程

训练word embedding的网络结构是这样的,他的隐藏层后面并没有任何非线性函数。

wordvec

为了说明计算过程,模型简化为1对1 的预测,类似bigram。

Input -> Hidden

模型的输入到hidden的计算是

$$h=W^{T}X = W_{(k,.)}^{T}$$

X是one-hot向量,是V*1的列向量,W是V*N的矩阵,就是word embed,每行代表词表中的一个词。h是N*1的列向量。

Hidden -> Output

h向量就是W的第k行,也就是词表中第k个词的向量。$$W^{'}$$ 是N*V的矩阵,可理解为另外一组word embed。 从h预测output时,相当于是输入词的embed和输出词embed做内积,得出一个score u

$$ u= W^{'T}h$$

u是V*1的向量,通过softmax,得出预测的每个词的概率y $$ y_{i} = \frac{exp(u_{i})}{\sum_{k=1}^{V} exp(u_{k})} $$

Loss Func

有了预测概率,有了true target,就可以通过交叉熵来计算损失函数了,经过基本变形就得到了 $$ \begin{align*} E &= - log \space y_{j^{}} \ &= -u_{j^{}} + \log \sum_{j’=1}^{V} exp(u_{j’})
\end{align*} $$

我真正的困惑是从下面开始,不知道怎么去做导数反向传递,好在《word2vec Parameter Learning Explained》给出的推导过程特别详细,我才能勉强看懂。

2, Backward过程

  1. $$h=W^{T}X = W_{(k,.)}^{T}$$
  2. $$ u= W^{'T}h$$
  3. $$ y_{i} = \frac{exp(u_{i})}{\sum_{k=1}^{V} exp(u_{k})} $$
  4. $$ E = -u_{j^{*}} + \log \sum_{j{'}=1}{V} exp(u_{j^{'}}) $$

就是这几个公式依次求导。可是反向求导为啥难理解呢,我觉得主要是因为,前向过程都是用矩阵或向量计算的,求导时需要很多变换,还需要考虑转置的问题,行列的问题,转换步骤一多,思维就乱掉了。

首先求关于$$u_{j}$$ 的导数 $$ \frac{\partial E}{\partial u_{j}} = y_{j} - t_{j} := e_{j} \qquad j\in [1,V] $$ 然后求关于$$W_{i,j}^{‘}$$ 的导数 $$ \begin{align*} \frac{\partial E}{\partial W_{ij}^{’}} &= \frac{\partial E}{\partial u_{j}} . \frac{\partial u_{j}}{\partial W_{ij}^{‘}} \ \&=e {j}.h{i} \qquad\qquad j\in [1,V]\quad i \in [1,N] \end{align*} $$ 这个要理解我觉得最好还是把矩阵画出来,然后一步步去推导比较容易理解。其实最后$$\frac{\partial E}{\partial W^{’}} $$ 会最终变为一个矩阵,参数更新也都是通过矩阵运算的方式。这个公式在原论文中给出了一个直观理解,就是对于输出参数矩阵的每个词,根据预测的概率误差,相应的远离输入词。相当于这次word vector在他们的高维空间,不停的移动,已获得最佳的位置。当训练样本足够多了,每个word vector也就基本稳定不会移动了,这时候就可以把参数矩阵拿出来直接当做word embedding使用了。这些word embedding中包含了很多语义特征。

其他的推导懒得写了,如果以后忘记了,就回看论文好了。