$EM$算法是1977年发明的,但是到今天,想要透彻理解并不容易。到底什么是隐变量?$Q$函数到底是怎么回事?作者怎么想到$Jessen$不等式来证明的?有没有一种简单直观的理解方式?本文试图解答。
在$EM$中,隐变量是一个比较核心又难以理解的概念。
考虑我们有一些身高数据,${170, 161, 172, 173, 164, 175, 176, 167, \ldots, 179, 180}$。 此时不能简单假设这些身高数据来自某个正态分布。因为男生和女生身高数据,应该来自两个不同的正态分布。 所以我们这些数据实际上是来自两个正态分布的混合。
请仔细体会,我们现在从一个身高的数据分布中进行抽样,得到了上面的数据。我们知道每条具体抽样的身高数据,要么来自男生,要么来自女生。 这个“身高的数据分布”背后实际上对应着一个男生女生分布。怎么理解这里的对应?我们可以理解,每次抽样一条身高数据,同时也是从男生女生分布中抽样了一个样本。
一条身高数据样本<—>一个男生女生分布中的样本
唯一的不同,作为身高数据样本,我们是可以看到样本的具体值的,也就是上面的身高数字,但是当将该抽样作为男生女生分布中的样本时,我们是看不到样本的具体值的,即不知道该样本到底是来自男生还是女生,这就是“隐”的含义。
我们记身高数据分布为$P(x)$,男生女生分布为$P(z)$。 将$x$称为观测变量,$z$称为隐变量。
观测变量可以由隐变量生成,即:
\[P(x) = \sum_{z} P(x, z) = \sum_{z} P(z)P(x|z) \tag{1}\]理解这个公式,我们可以使用上面男生女生身高的例子。假设男生身高分布为$P(x|z=boy)$,女生身高分布为$P(x|z=girl)$。 那么,$P(x)$就是男生女生身高分布的混合。 即:
\[P(x) = P(z=boy) P(x|z=boy) + P(z=girl) P(x|z=girl) \tag{2}\]一般化,$P(x)$可以表示为:
\[P(x) = \sum_{z} P(z)P(x|z) \tag{3}\]对于连续变量,公式3可以写成:
\[P(x) = \int_{z} P(z)P(x|z) dz \tag{4}\]上面的公式可以类比全概率公式来理解。为便于理解,下面都使用$z$的离散形式。
有了上面含隐变量公式,我们可以用log-likelihood来求解模型参数。 我们知道log-likelihood的公式是:
\[\begin{aligned} \theta &= \arg \max_{\theta} \sum_{i=1}^n \log P(x_i|\theta) \\ &= \arg \max_{\theta} \sum_{i=1}^n \log \sum_{z} P(z)P(x_i|z) \\ &= \arg \max_{\theta} \sum_{i=1}^n \log \sum_{z} P(z|\theta)P(x_i|z, \theta) \end{aligned} \tag{5}\]此时,需要注意一下,$\theta$是模型参数。具体是哪些参数呢?这点有必要先澄清一下。 以上面的男生女生例子,$\theta$包含了三部分参数:
符号$P(z|\theta)$中的$\theta$特指的是属于$z$的那部分参数。笔者最初学习的时候,在这一点上困惑过。 log-likelihood外面的对$x$的求和,表示的是对所有样本的求和。将似然最大化,就是将每个样本的似然最大化。 为了简化推导,后面省略这个求和符号。
对上面的式子进行求梯度,得到:
\[\begin{aligned} \nabla \theta &= \frac{\partial}{\partial \theta} \log \sum_{z} P(z|\theta)P(x|z, \theta) \\ &= \frac{\partial \sum_{z} P(z|\theta)P(x|z, \theta)}{\sum_{z}P(z|\theta)P(x|z, \theta)} \\ &= \frac{\sum_{z} \biggl( \partial P(z|\theta)P(x|z, \theta) + P(z|\theta)\partial P(x|z, \theta) \biggr)}{\sum_{z}P(z|\theta)P(x|z, \theta)} \end{aligned} \tag{6}\]有了梯度,就可以使用梯度上升法来优化模型参数。万事大吉? 其实没那么容易。上面基于梯度求解,有个难以忽视的问题:
通过梯度进行更新,难以满足非负约束。
上面的公式中,有两个分布:
\[P(z|\theta)\]以及
\[P(x|z,\theta)\]因为是概率分布,需要是正数。仅通过梯度更新,很容易破坏这个约束。
至此,我们理解了隐变量,也尝试基于隐变量方案的最大似然求解,由于遇到了计算上的困难,导致无法求解。 EM算法就是来帮我们绕过这个困难的。
EM算法的起手式,就是对上面的$logP(x|\theta)$进行新的分解。 如下所示:
\[\log P(x|\theta) = \log P(x,z|\theta) - \log P(z|x, \theta) \tag{7}\]其中:
\[P(x, z|\theta) = P(z|\theta)P(x|z, \theta) \tag{8}\] \[\begin{aligned} P(z|x, \theta) &= \frac{P(x, z|\theta)}{P(x|\theta)} \\ &= \frac{P(z|\theta)P(x|z, \theta)}{\sum_{z} P(z|\theta)P(x|z, \theta)} \\ \end{aligned} \tag{9}\]这里我们细细体会一下这个分解和之前的分解的差别。之前的分解是:
\[\log P(x|\theta) =\log \sum_{z} P(z|\theta)P(x|z, \theta) \tag{10}\]现在我们继续推导公式7。 假设,我们在第n次迭代中,已经有了一个$\theta^{(n)}$。我们可以对公式7进行变形:
\[\begin{aligned} \sum_{z} P(z|x, \theta^{(n)})\log P(x|\theta) &= \sum_{z} P(z|x, \theta^{(n)}) \log \frac{P(x,z|\theta)}{P(z|x, \theta^{(n)})} \\ &- \sum_{z} P(z|x, \theta^{(n)}) \log \frac{P(z|x, \theta)}{P(z|x,\theta^{(n)})} \end{aligned} \tag{11}\]这步变化,是最费解的,就是为什么要这么变形。我们暂且不管,继续推导一步,答案就会揭晓。
\[\log P(x|\theta) = \underbrace{E_{z \in P(z|x, \theta^{(n)})} \log \frac{P(x,z|\theta)}{P(z|x, \theta^{(n)})}}_{ELBO} + KL(P(z|x, \theta^{(n)}), P(z|x, \theta)) \tag{12}\]我们知道$KL$是一个非负的值。 所以上式中的$ELBO$,实际上是$\log P(x|\theta)$的一个下界。这也是$ELBO$名称的来源,Evidence Lower Bound。
那么,我们只要想办法最大化$ELBO$,就可以间接最大化$\log P(x|\theta)$。 $ELBO$中$\log$里面的分母与$\theta$无关,可以直接去掉。于是可以得到:
\[Q(\theta, \theta^{(n)}) = E_{z \in P(z|x, \theta^{(n)})} \log P(x,z|\theta) \tag{13}\]这个,就是千呼万唤的$Q$函数了。 $EM$的$E$,就是上式中的求期望,$M$就是求期望的最大化。 这就是$EM$算法的全部。
让我们回顾一下,刚才那个看似很奇怪的变形,其本质是通过构造一个$KL$散度,将第二项处理成非负的,让我们可以清晰得到$ELBO$这个下界。
等等,刚才说梯度下降难以满足非负约束,这里难道就满足了吗?答案是满足了。 只要我们保证初始化时$P(x|z, \theta)$和$P(z|\theta)$是正数,那么在Q函数最大化的过程中,一定是使得$P(x|z, \theta)$和$P(z|\theta)$越来越大的。
还有,文章开头提到了$Jessen$不等式,为什么整个推导过程中,没有看到$Jessen$不等式呢?
如果用$Jessen$不等式,推导起来会更加直接。
\[\begin{aligned} \log P(x|\theta) &= \log \sum_{z}P(x,z|\theta) \\ &= \log \sum_{z} P(z|x, \theta^{(n)})\frac{P(x,z|\theta)}{P(z|x, \theta^{(n)})} \\ &= \log E_{z \in P(z|x, \theta^{(n)})} \frac{P(x,z|\theta)}{P(z|x, \theta^{(n)})} \\ &\geq E_{z \in P(z|x, \theta^{(n)})} \log P(x,z|\theta) + \text{const} \end{aligned} \tag{14}\]最后一步就是使用的$Jessen$不等式。$Jessen$不等式,可以参考Jessen不等式。
不过,使用$Jessen$不等式后,整体的启发性就不足了。