變分自編碼器是近些年較火的一個(gè)生成模型,我個(gè)人認(rèn)為其本質(zhì)上仍然是一個(gè)概率圖模型,只是在此基礎(chǔ)上引入了神經(jīng)網(wǎng)絡(luò)。本文將就變分自編碼器(VAE)進(jìn)行簡單的原理講解和數(shù)學(xué)推導(dǎo)。
論文:
視頻:
引入高斯混合模型
生成模型,可以簡單的理解為生成數(shù)據(jù)(不止,但我們暫且就這么理解它)。假如現(xiàn)在我們有樣本數(shù)據(jù),而我們發(fā)現(xiàn)這些樣本符合正態(tài)分布且樣本具有充分的代表性,因此我們計(jì)算出樣本的均值和方差,就能得到樣本的概率分布。然后從正態(tài)分布中抽樣,就能得到樣本。這種生成樣本的過程就是生成過程。
可是,假如我們的數(shù)據(jù)長這樣
很顯然,它的數(shù)據(jù)是由兩個(gè)不同的正態(tài)分布構(gòu)成。我們可以計(jì)算出這些樣本的概率分布。但是一種更為常見的方法就是將其當(dāng)作是兩個(gè)正態(tài)分布。我們引入一個(gè)隱變量z。
假設(shè) z的取值為0,1,如果z為0,我們就從藍(lán)色的概率分布中抽樣;否則為1,則從橙色的概率分布中抽樣。這就是生成過程。
但是這個(gè)隱變量z是什么?它其實(shí)就是隱藏特征訓(xùn)練數(shù)據(jù)x的抽象出來的特征,比如,如果x偏小,我們則認(rèn)為它數(shù)據(jù)藍(lán)色正太分布,否則為橙色。這個(gè)"偏小"就是特征,我們把它的取值為0,1(0代表偏小,1代表偏大)。
那這種模型我們?nèi)绾稳∮?xùn)練它呢?如何去找出這個(gè)z呢?一種很直觀的方法就是重構(gòu)代價(jià)最小,我們希望,給一個(gè)訓(xùn)練數(shù)據(jù)x,由x去預(yù)測隱變量z,再由隱變量z預(yù)測回x,得到的誤差最小。比如假如我們是藍(lán)色正態(tài)分布,去提取特征z,得到的z再返回來預(yù)測x,結(jié)果得到的卻是橙色的正態(tài)分布,這是不可取的。其模型圖如下
這個(gè)模型被稱為GMM高斯混合模型
變分自編碼器(VAE)
那它和VAE有什么關(guān)聯(lián)呢?其實(shí)VAE的模型圖跟這個(gè)原理差不多。只是有些許改變,隱變量Z的維度是連續(xù)且高維的,不再是簡單的離散分布,因?yàn)榧偃缥覀兩傻氖菆D片,我們需要提取出來的特征明顯是要很多的,傳統(tǒng)的GMM無法做到。
在VAE中,(也就是用樣本x預(yù)測變量z),其服從高斯分布,接下來,我們來看模型圖
也就是將訓(xùn)練樣本x給神經(jīng)網(wǎng)絡(luò),讓神經(jīng)網(wǎng)絡(luò)計(jì)算出均值和協(xié)方差矩陣.
取log的原因是傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)輸出值總是有正有負(fù)。有了這兩個(gè)值就可以在對應(yīng)的高斯分布中采樣,得到隱變量z。再讓z經(jīng)過神經(jīng)網(wǎng)絡(luò)重構(gòu)回樣本,得到新樣本。這就是整個(gè)VAE的大致過程了。
再次強(qiáng)調(diào),訓(xùn)練過程我們希望每次重構(gòu)的時(shí)候,新樣本和訓(xùn)練樣本盡可能的相似
如果我們這樣直接去訓(xùn)練的化,可以嗎?可以!但是會(huì)有個(gè)問題,神經(jīng)網(wǎng)絡(luò)會(huì)趨向于將協(xié)方差變0,而讓概率分布 將不再具備隨機(jī)性,概率分布空間會(huì)坍縮成一個(gè)點(diǎn),每次采樣都是均值,這種情況我們俗稱過擬合。
為什么協(xié)方差會(huì)變成0?因?yàn)椴蓸泳哂须S機(jī)性,也就是存在噪聲,噪聲是肯定會(huì)增加重構(gòu)的誤差的。神經(jīng)網(wǎng)絡(luò)為了讓誤差最小,是肯定讓這個(gè)隨機(jī)性越小越好,因?yàn)橹挥羞@樣,才能重構(gòu)誤差最小
但是我們肯定是希望有隨機(jī)性的,為什么?因?yàn)橛须S機(jī)性,我們才可以生成不同的樣本啊!
所以,對于概率分布,我們不希望它的協(xié)方差為0,所以我們需要對其進(jìn)行約束。在論文中,它對其進(jìn)行約束,要求它盡量的往N(0,I)靠近(其實(shí)與先驗(yàn)分布P(z)有關(guān),后續(xù)數(shù)學(xué)推導(dǎo)中可見假設(shè)~ )
所以,有KL散度去衡量兩個(gè)概率分布的相似性
KL散度是大于等于0的值,越小則證明越相似
所以,我們就是兩個(gè)優(yōu)化目標(biāo)①最小化重構(gòu)代價(jià)②最小化上述的散度
依照這兩個(gè)條件,建立目標(biāo)函數(shù),直接梯度下降其實(shí)還需要重參數(shù)化,后面會(huì)講到,刷刷刷地往下降,最終收斂。
下面,我們就對其進(jìn)行簡單的數(shù)學(xué)推導(dǎo),并以此推導(dǎo)出目標(biāo)函數(shù)
原理推導(dǎo)引入目標(biāo)函數(shù)
以VAE的簡略圖為例
設(shè)我們有N個(gè)樣本
定義隱變量先驗(yàn)分布。很自然的想法,我們直接對x求log極大似然,假設(shè)我們有N的樣本,記作X。設(shè)所需求解的參數(shù)為 ,似然函數(shù)記為,為了簡便,以下省略 ,第 i個(gè)樣本記為,某個(gè)樣本的第 j個(gè)維度記作
現(xiàn)在,我們先單獨(dú)看看里面某一個(gè)樣本的似然,某個(gè)樣本記為x
引入一個(gè)分布,是它的參數(shù),為了簡便,后續(xù)省略掉,直接記為
等式左右分別對求積分
所以左邊等于右邊
(式a)到(式b)用到了log 的性質(zhì)。
那么,現(xiàn)在,我們開始極大似然,似然函數(shù)的參數(shù)為
因?yàn)槲覀兯悴怀鰜恚ㄔ蛘埧矗帜傅姆e分計(jì)算不了)
故而,使用去逼近,所以,更新q的參數(shù) ,以最小化第②項(xiàng)。
在給定x跟 的情況下,的值是確定的,所以最小化第②項(xiàng),就等于最大化第①項(xiàng)
舉個(gè)例子,在VAE中,里面的參數(shù),其實(shí)都是用神經(jīng)網(wǎng)絡(luò)去逼近的。所以,如果按照剛剛提到的,步驟就長這樣
按照上面提到的,我們可以把第一步改成
更一般地,我們把它們寫成一起
由于KL散度是大于等于0的,所以第①項(xiàng),就被稱為變分下界。
好,現(xiàn)在我們只需要最大化其變分下界(以下省略掉參數(shù))
發(fā)現(xiàn)了嗎,最大化里面的第一項(xiàng)就期望,不就是從采樣z,再讓概率最大。這不就是重構(gòu)代價(jià)最小嗎;而對于第二項(xiàng),最大化 -KL散度,就相當(dāng)于最小化KL散度。這和我們上面提到的兩個(gè)優(yōu)化目標(biāo)是一樣的。
細(xì)化目標(biāo)函數(shù)
既然得到了目標(biāo)函數(shù),那么我們就對 似然和KL散度都求出具體的表達(dá)。
先來看 KL散度
需要逼近,而 ~ N(0,I)的多維高斯分布,并且各個(gè)維度之間相互獨(dú)立,所以也是如此設(shè)定,那么最小化其KL散度,只需要對每一個(gè)維度求KL最小即可,單獨(dú)看某一個(gè)維度,設(shè)某一個(gè)維度為,設(shè) ~ (后續(xù)為了簡便,也同樣將隱去)
可以分為三部分
如果你熟悉高斯分布的高階矩的話,式A和式C完全就是二階原點(diǎn)矩和中心距,是直接可以的得出答案的。
值得注意的是,我看很多文章中都說此處就直接采用均方差來計(jì)算。這種說法是不準(zhǔn)確的,在論文中提到是服從一個(gè)概率分布的,而不是無端的就計(jì)算其差值。它不像是GAN一樣,對于其隱藏在內(nèi)部的概率分布不作約束,VAE是仍然對進(jìn)行約束。
當(dāng)然了,其實(shí)我們也可以不對其概率分布進(jìn)行約束,歸根究底,其讓然是最小重構(gòu)代價(jià),那么我們的目標(biāo)函數(shù)如果可以充分表達(dá)出“最小重構(gòu)代價(jià)”,那么是什么又有何關(guān)系呢?
在論文中,其假設(shè) ~ ,其中和都是需要使用神經(jīng)網(wǎng)絡(luò)去逼近。
但是,一般地,我們假設(shè) ~ ,也就是其均值用神經(jīng)網(wǎng)絡(luò)去逼近,對于其協(xié)方差矩陣,我們設(shè)定為常數(shù)c和相乘,所以依然是各個(gè)維度之間相互獨(dú)立。我們來看看它的極大似然估計(jì)得什么(假設(shè)采樣n個(gè)樣本)
可以看到,這就是一個(gè)均方差
重參數(shù)化技巧
有了目標(biāo)函數(shù),理論上我們直接梯度下降就可以了。然而,別忘了,我們是從中采樣出z來。可是我們卻是用的神經(jīng)網(wǎng)絡(luò)去計(jì)算的均值和方差,得到的高斯分布再去采樣,這種情況是不可導(dǎo)的。中間都已經(jīng)出現(xiàn)了一個(gè)斷層了。神經(jīng)網(wǎng)絡(luò)是一層套一層的計(jì)算。而采樣計(jì)算了一層之后,從這一層中去采樣新的值,再計(jì)算下一層。因此,采樣本身是不可導(dǎo)的。
所以要引入重參數(shù)化技巧,假定 ~ 。那么可以構(gòu)造一個(gè)概率分布 ~ 。有
我們從采樣,然后利用上述公式,就相當(dāng)于得到了從 采樣的采樣值。
代碼實(shí)現(xiàn)
效果一般,不曉得論文里面用了什么手段,效果看起來比這個(gè)好。(這個(gè)結(jié)果甚至還是我加了一層隱藏層的)
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
class VAE(nn.Module):
def __init__(self,input_dim,hidden_dim,gaussian_dim):
super().__init__()
#編碼器
#隱藏層
self.fc1=nn.Sequential(
nn.Linear(in_features=input_dim,out_features=hidden_dim),
nn.Tanh(),
nn.Linear(in_features=hidden_dim, out_features=256),
nn.Tanh(),
)
#μ和logσ^2
self.mu=nn.Linear(in_features=256,out_features=gaussian_dim)
self.log_sigma=nn.Linear(in_features=256,out_features=gaussian_dim)
#解碼(重構(gòu))
self.fc2=nn.Sequential(
nn.Linear(in_features=gaussian_dim,out_features=256),
nn.Tanh(),
nn.Linear(in_features=256, out_features=512),
nn.Tanh(),
nn.Linear(in_features=512,out_features=input_dim),
nn.Sigmoid() #圖片被轉(zhuǎn)為為0,1的值了,故用此函數(shù)
)
def forward(self,x):
#隱藏層
h=self.fc1(x)
#計(jì)算期望和log方差
mu=self.mu(h)
log_sigma=self.log_sigma(h)
#重參數(shù)化
h_sample=self.reparameterization(mu,log_sigma)
#重構(gòu)
reconsitution=self.fc2(h_sample)
return reconsitution,mu,log_sigma
def reparameterization(self,mu,log_sigma):
#重參數(shù)化
sigma=torch.exp(log_sigma*0.5) #計(jì)算σ
e=torch.randn_like(input=sigma,device=device)
result=mu+e*sigma #依據(jù)重參數(shù)化技巧可得
return result
def predict(self,new_x): #預(yù)測
reconsitution=self.fc2(new_x)
return reconsitution
def train():
transformer = transforms.Compose([
transforms.ToTensor(),
]) #歸一化
data = MNIST("./data", transform=transformer,download=True) #載入數(shù)據(jù)
dataloader = DataLoader(data, batch_size=128, shuffle=True) #寫入加載器
model = VAE(784, 512, 20).to(device) #初始化模型
optimer = torch.optim.Adam(model.parameters(), lr=1e-3) #初始化優(yōu)化器
loss_fn = nn.MSELoss(reduction="sum") #均方差損失
epochs = 100 #訓(xùn)練100輪
for epoch in torch.arange(epochs):
all_loss = 0
dataloader_len = len(dataloader.dataset)
for data in tqdm(dataloader, desc="第{}輪梯度下降".format(epoch)):
sample, label = data
sample = sample.to(device)
sample = sample.reshape(-1, 784) #重塑
result, mu, log_sigma = model(sample) #預(yù)測
loss_likelihood = loss_fn(sample, result) #計(jì)算似然損失
#計(jì)算KL損失
loss_KL = torch.pow(mu, 2) + torch.exp(log_sigma) - log_sigma - 1
#總損失
loss = loss_likelihood + 0.5 * torch.sum(loss_KL)
#梯度歸0并反向傳播和更新
optimer.zero_grad()
loss.backward()
optimer.step()
with torch.no_grad():
all_loss += loss.item()
print("函數(shù)損失為:{}".format(all_loss / dataloader_len))
torch.save(model, "./model/VAE.pth")
if __name__ == '__main__':
#是否有閑置GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#訓(xùn)練
train()
#載入模型,預(yù)測
model=torch.load("./model/VAE (1).pth",map_location="cpu")
#預(yù)測20個(gè)樣本
x=torch.randn(size=(20,20))
result=model.predict(x).detach().numpy()
result=result.reshape(-1,28,28)
#繪圖
for i in range(20):
plt.subplot(4,5,i+1)
plt.imshow(result[i])
plt.gray()
plt.show()
VAE有很多的變種優(yōu)化,感興趣的讀者自行查閱。
結(jié)束
以上,就是VAE的原理和推導(dǎo)過程了。能力有限,過程并不嚴(yán)謹(jǐn),如有問題,還望指出。阿里嘎多
請加小助理加入AIGC技術(shù)交流群
備注公司/學(xué)校+昵稱+研究方向
往期推薦