变分推断推导
简介
变分推断(Variational Inference)是一种概率图模型推断方法,用于估计复杂概率分布的后验分布。它的目标是通过一个简单的参数化分布(通常称为变分分布或近似分布)来近似表示复杂的后验分布,从而使推断问题变得更加可行。
在概率图模型中,我们通常希望找到给定一些观察数据的情况下,未观察到的随机变量的后验分布。这个后验分布通常是难以直接计算的,特别是对于高维复杂模型。变分推断的基本思想是将后验推断问题转化为一个优化问题,其中我们试图找到一个变分分布,使其在某种度量下最接近真实的后验分布。
问题定义
我们以贝叶斯逆向强化学习为例去介绍变分推断的过程。逆向强化学习是一种从观察到的代理智能行为中推断出代理智能目标函数(通常称为奖励函数)的方法。贝叶斯逆向强化学习扩展了这一思想,引入了贝叶斯方法,以处理不确定性,并提供了一种概率框架来估计奖励函数。假设我们一系列观测变量\(x\)(专家轨迹),想要推断出隐变量\(z\)(状态对应的奖励值)的分布,我们可以将这一过程写为后验概率的形式\(p(z|x)\)。
但由于\(p(z|x)\)难以求解,因此我们希望找到一个最优的替代分布\(q(z)\),其是距离\(p(z|x)\)最近的分布。假设\(L\)是空间\(Q\)上的距离函数,那么要寻找的概率分布可以定义为下式:
\[q^{*}\left( z \right) = \underset{q\left( z \right) \in Q}{\text{argmin}}L(q\left( z \right),p(z|x))\]
如果能找到这样的分布,就可以替代原来不好计算的\(p(z|x)\)。
ELBO
当我们取距离函数\(L\)为KL散度时,这一问题就变成了变分推断问题,最优化公式可以写为:
\[{q^{*}\left( z \right) = \underset{q\left( z \right) \in Q}{\text{argmin}}\ KL\left( q\left( z \right),p\left( z \middle| x \right) \right) }{= \underset{q\left( z \right) \in Q}{\text{argmin}} - \int_{}^{}{q(z)log\lbrack\frac{p(z|x)}{q(z)}\rbrack dz}}\]
展开上式
\[\ KL\left( q\left( z \right),p\left( z \middle| x \right) \right) = \int_{}^{}{q\left( z \right)\text{logq}\left( z \right)\text{dz}} - \int_{}^{}{q(z)logp(z|x)dz}\]
这里关于\(q(z)\)对\(z\)积分,其实就是关于\(q(z)\)的期望,即\(\int_{}^{}{q\left( z \right)f(z, \cdot )\text{dz}} = \mathbb{E}_{q}\lbrack f(z, \cdot )\rbrack\),那么上式能写成期望的形式:
\[{KL\left( q\left( z \right),p\left( z \middle| x \right) \right) = \mathbb{E}_{q}\left\lbrack \log q\left( z \right) \right\rbrack - \mathbb{E}_{q}\left\lbrack \text{logp}\left( z \middle| x \right) \right\rbrack }{= \mathbb{E}_{q}\left\lbrack \log q\left( z \right) \right\rbrack - \mathbb{E}_{q}\left\lbrack \log\left\lbrack \frac{p\left( x,z \right)}{p\left( x \right)} \right\rbrack \right\rbrack }{= \mathbb{E}_{q}\left\lbrack \log q\left( z \right) \right\rbrack - \mathbb{E}_{q}\left\lbrack \text{logp}\left( x,z \right) \right\rbrack + \mathbb{E}_{q}\lbrack logp(x)\rbrack}\]
其中最后一项与期望对象\(q(z)\)无关,因此期望符号可以去掉。此时前两项能够被称为\(- ELBO\)。那么关于\(q(z)\)的\(\text{ELBO}\)可以被表示成以下形式计算:
\[{\text{ELBO}\left( q \right) = \mathbb{E}_{q}\left\lbrack \text{logp}\left( x,z \right) \right\rbrack - \mathbb{E}_{q}\left\lbrack \text{logq}\left( z \right) \right\rbrack }{= \mathbb{E}_{q}\left\lbrack \text{logp}\left( x|z \right)p\left( z \right) \right\rbrack - \mathbb{E}_{q}\left\lbrack \text{logq}\left( z \right) \right\rbrack }{= \mathbb{E}_{q}\left\lbrack \text{logp}\left( x \middle| z \right) \right\rbrack + \mathbb{E}_{q}\left\lbrack \frac{\text{logp}\left( z \right)}{\text{logq}\left( z \right)} \right\rbrack }{= \mathbb{E}_{q}\left\lbrack \text{logp}\left( x \middle| z \right) \right\rbrack + \int_{}^{}{q\left( z \right)\frac{\text{logp}\left( z \right)}{\text{logq}\left( z \right)}\text{dz}} }{= \mathbb{E}_{q}\left\lbrack \text{logp}\left( x \middle| z \right) \right\rbrack - KL(q\left( z \right),p(z))}\]
把\(\text{ELBO}\)移到等式左边,此时:
\[KL\left( q\left( z \right),p\left( z \middle| x \right) \right) + E\text{LBO}\left( q \right) = logp(x)\]
因为\(logp(x)\)是常数,目标是最小化\(\text{KL}\)项,那么我们的目的就是最大化\(ELBO(q)\)。由于KL散度永远大于0,因此:
\[\log p\left( x \right) = ELBO\left( q \right) + KL\left( q\left( z \right),p\left( z \middle| x \right) \right) \geq ELBO(q)\]
因此这一项被称为Evidence Lower Bound。
变分推断
我们对\(\text{ELBO}\)的表达式进行进一步分析
\[\text{ELBO}\left( q \right) = \mathbb{E}_{q}\left\lbrack \text{logp}\left( x \middle| z \right) \right\rbrack - KL(q\left( z \right),p(z))\]
第一项是重建loss,第二项是对隐变量空间施加的约束,当假设\(p\left( z \right)\mathcal{= N(}0,1)\),\(q_{\theta}\left( z \right)\mathcal{= N(}z;\mu_{\theta}\left( x \right),\sigma_{\theta}^{2}(x))\)时,上述第二项可以写作
\[\text{KL}\left( q_{\theta}\left( z \right),p\left( z \right) \right) = \frac{1}{2}(\mu_{\theta}^{2}\left( x \right) + \sigma_{\theta}^{2}\left( x \right) - ln\sigma_{\theta}^{2}\left( x \right) - 1)\]
第一项\(p\left( x \middle| z \right)\)的形式并不确定,我们可以假设他的分布,既可以是高斯分布,也可以是伯努利分布。当属于高斯分布,且方差为1时,\(\text{ELBO}\)就简化为了一般的MSE Loss。同样如果是伯努利分布,那么最后将得到一个交叉熵