理解EM算法


EM算法是1977年发明的,但是到今天,想要透彻理解并不容易。到底什么是隐变量?Q函数到底是怎么回事?作者怎么想到Jessen不等式来证明的?有没有一种简单直观的理解方式?本文试图解答。

EM中,隐变量是一个比较核心又难以理解的概念。

隐变量

考虑我们有一些身高数据,170,161,172,173,164,175,176,167,,179,180。 此时不能简单假设这些身高数据来自某个正态分布。因为男生和女生身高数据,应该来自两个不同的正态分布。 所以我们这些数据实际上是来自两个正态分布的混合。

请仔细体会,我们现在从一个身高的数据分布中进行抽样,得到了上面的数据。我们知道每条具体抽样的身高数据,要么来自男生,要么来自女生。 这个“身高的数据分布”背后实际上对应着一个男生女生分布。怎么理解这里的对应?我们可以理解,每次抽样一条身高数据,同时也是从男生女生分布中抽样了一个样本。

一条身高数据样本<—>一个男生女生分布中的样本

唯一的不同,作为身高数据样本,我们是可以看到样本的具体值的,也就是上面的身高数字,但是当将该抽样作为男生女生分布中的样本时,我们是看不到样本的具体值的,即不知道该样本到底是来自男生还是女生,这就是“隐”的含义。

我们记身高数据分布为P(x),男生女生分布为P(z)。 将x称为观测变量,z称为隐变量。

观测变量可以由隐变量生成,即:

(1)P(x)=zP(x,z)=zP(z)P(x|z)

理解这个公式,我们可以使用上面男生女生身高的例子。假设男生身高分布为P(x|z=boy),女生身高分布为P(x|z=girl)。 那么,P(x)就是男生女生身高分布的混合。 即:

(2)P(x)=P(z=boy)P(x|z=boy)+P(z=girl)P(x|z=girl)

一般化,P(x)可以表示为:

(3)P(x)=zP(z)P(x|z)

对于连续变量,公式3可以写成:

(4)P(x)=zP(z)P(x|z)dz

上面的公式可以类比全概率公式来理解。为便于理解,下面都使用z的离散形式。

极大似然

有了上面含隐变量公式,我们可以用log-likelihood来求解模型参数。 我们知道log-likelihood的公式是:

(5)θ=argmaxθi=1nlogP(xi|θ)=argmaxθi=1nlogzP(z)P(xi|z)=argmaxθi=1nlogzP(z|θ)P(xi|z,θ)

此时,需要注意一下,θ是模型参数。具体是哪些参数呢?这点有必要先澄清一下。 以上面的男生女生例子,θ包含了三部分参数:

  1. 男生身高分布的参数,μboyσboy
  2. 女生身高分布的参数,μgirlσgirl
  3. 男生女生分布的参数,P(z=boy)P(z=girl)

符号P(z|θ)中的θ特指的是属于z的那部分参数。笔者最初学习的时候,在这一点上困惑过。 log-likelihood外面的对x的求和,表示的是对所有样本的求和。将似然最大化,就是将每个样本的似然最大化。 为了简化推导,后面省略这个求和符号。

对上面的式子进行求梯度,得到:

(6)θ=θlogzP(z|θ)P(x|z,θ)=zP(z|θ)P(x|z,θ)zP(z|θ)P(x|z,θ)=z(P(z|θ)P(x|z,θ)+P(z|θ)P(x|z,θ))zP(z|θ)P(x|z,θ)

有了梯度,就可以使用梯度上升法来优化模型参数。万事大吉? 其实没那么容易。上面基于梯度求解,有个难以忽视的问题:

通过梯度进行更新,难以满足非负约束。

上面的公式中,有两个分布:

P(z|θ)

以及

P(x|z,θ)

因为是概率分布,需要是正数。仅通过梯度更新,很容易破坏这个约束。

至此,我们理解了隐变量,也尝试基于隐变量方案的最大似然求解,由于遇到了计算上的困难,导致无法求解。 EM算法就是来帮我们绕过这个困难的。

EM算法

EM算法的起手式,就是对上面的logP(x|θ)进行新的分解。 如下所示:

(7)logP(x|θ)=logP(x,z|θ)logP(z|x,θ)

其中:

(8)P(x,z|θ)=P(z|θ)P(x|z,θ) (9)P(z|x,θ)=P(x,z|θ)P(x|θ)=P(z|θ)P(x|z,θ)zP(z|θ)P(x|z,θ)

这里我们细细体会一下这个分解和之前的分解的差别。之前的分解是:

(10)logP(x|θ)=logzP(z|θ)P(x|z,θ)

现在我们继续推导公式7。 假设,我们在第n次迭代中,已经有了一个θ(n)。我们可以对公式7进行变形:

(11)zP(z|x,θ(n))logP(x|θ)=zP(z|x,θ(n))logP(x,z|θ)P(z|x,θ(n))zP(z|x,θ(n))logP(z|x,θ)P(z|x,θ(n))

这步变化,是最费解的,就是为什么要这么变形。我们暂且不管,继续推导一步,答案就会揭晓。

(12)logP(x|θ)=EzP(z|x,θ(n))logP(x,z|θ)P(z|x,θ(n))ELBO+KL(P(z|x,θ(n)),P(z|x,θ))

我们知道KL是一个非负的值。 所以上式中的ELBO,实际上是logP(x|θ)的一个下界。这也是ELBO名称的来源,Evidence Lower Bound。

那么,我们只要想办法最大化ELBO,就可以间接最大化logP(x|θ)ELBOlog里面的分母与θ无关,可以直接去掉。于是可以得到:

(13)Q(θ,θ(n))=EzP(z|x,θ(n))logP(x,z|θ)

这个,就是千呼万唤的Q函数了。 EME,就是上式中的求期望,M就是求期望的最大化。 这就是EM算法的全部。

让我们回顾一下,刚才那个看似很奇怪的变形,其本质是通过构造一个KL散度,将第二项处理成非负的,让我们可以清晰得到ELBO这个下界。

等等,刚才说梯度下降难以满足非负约束,这里难道就满足了吗?答案是满足了。 只要我们保证初始化时P(x|z,θ)P(z|θ)是正数,那么在Q函数最大化的过程中,一定是使得P(x|z,θ)P(z|θ)越来越大的。

还有,文章开头提到了Jessen不等式,为什么整个推导过程中,没有看到Jessen不等式呢?

如果用Jessen不等式,推导起来会更加直接。

(14)logP(x|θ)=logzP(x,z|θ)=logzP(z|x,θ(n))P(x,z|θ)P(z|x,θ(n))=logEzP(z|x,θ(n))P(x,z|θ)P(z|x,θ(n))EzP(z|x,θ(n))logP(x,z|θ)+const

最后一步就是使用的Jessen不等式。Jessen不等式,可以参考Jessen不等式

不过,使用Jessen不等式后,整体的启发性就不足了。


原创文章,转载请注明出处,否则拒绝转载!
本文链接:抬头看浏览器地址栏

上篇: MinerU核心代码阅读笔记
下篇: 理解VAE算法

Gitalk 加载中 ...