2020年6月23日 星期二

Wasserstein GAN(生成對抗網路 Generative Adversarial Network 簡介)

這篇文章要介紹的是 Wasserstein GAN (WGAN) [1]。WGAN 以及其團隊的前作分析了原始 GAN 的根本問題,並且提出針對於 GAN 的小修改來讓訓練過程變得更穩定,想要複習 GAN 簡介的讀者可以先點生成對抗網路標籤閱讀前文。這篇文章的內容主要來自於以下的參考資料 [2][3]。

原始 GAN 出了什麼問題?

在前文中我們提到了當找出最佳的 discriminator 時,GAN 的 loss function 為以下式子: \[ V(G,D^*)=-2log2 + 2D_{JS}(p_{data} \parallel p_g) \] 這個 loss function 的問題是如果 \(p_{data}\) 與 \(p_g\) 沒有交集時,無論兩個機率分布距離多遠他們的 JS divergence 都一定是 log 2: \[ D_{JS}(p \parallel q) = \frac{1}{2} D_{KL}(p \parallel \frac{p+q}{2}) + \frac{1}{2} D_{KL}(q \parallel \frac{p+q}{2}) \]分別把 p 與 q 分別為零的情況代入上式就會得到 log 2。loss function 是個常數時,代表訓練過程中的梯度為 0,因此在這情形下完全無法訓練 generator。

Wasserstein GAN:用 Wasserstein 距離取代 JS divergence

Wasserstein 距離就是前文所提到的 Earth Mover's Distance 推土機距離。用這個距離當成 loss function 的好處是即使 \(p_{data}\) 與 \(p_g\) 沒有交集,Wasserstein 距離仍然能反映兩個機率分布的遠近,因此在訓練過程中能夠一步步地讓 generator 學得更好。(註:本文只簡述這個想法的概念,關於理論證明請讀 WGAN 的 paper)。

如何讓 loss function 成為 Wasserstein 距離的形式?

經過一連串的理論證明之後,作者給出以下結論: \[ V(G,D) = \underset{D\in 1-Lipschitz}{max} \{ E_{x \sim p_{data}}[D(x)] - E_{x \sim p_g}[D(x)] \} \] 在上式中提到了 Lipschitz 函數的概念。如果一個函數為 Lipschitz 的話,其必須滿足以下條件: \[ \parallel f(x_1) - f(x_2) \parallel \leq K \parallel x_1 - x_2 \parallel \] 而 1-Lipschitz 代表上式中的 K 等於 1。Lipschitz 函數的概念是代表一個函數是否很平滑,因為如果這個函數的變動很大的話那就不會滿足以上不等式的條件。

怎麼實作滿足 1-Lipschitz 條件的函數?

在 WGAN 原文中作者用了一個很直觀的方法:將訓練過程中的參數設定一個最大值 c 及最小值 -c,也就是說當一個參數 w 大於 c 時,就將它設為 c;小於 -c 時就設為 -c。這麼做就可以滿足 1-Lipschitz 的條件了。另外這篇文章還有提到一些實作 WGAN 的細節,有興趣的讀者可以閱讀以下的參考資料。

參考資料

[1] Wasserstein GAN, Martin Arjovsky, Soumith Chintala, and Leon Bottou

沒有留言:

張貼留言