AI 從頭學(三七):Weight Decay
2017/06/22
前言:
你用一小撮資料來建立一個模型,以便用於預測大量資料。模型建的太好(符合這一小撮資料),在大量資料的表現,難免變的較差。錯一點,但不要錯太多,這就是 weight decay 的精神。
-----
Summary:
Regularization 是一些避免 overfitting 的方法的總稱,weight decay 是最簡單的方法之一 [1]。 本文蒐集了一些相關資料用來輔助理解 [2]-[5],最後是 noise 及一些相關的討論 [6], [7]。
-----
本文藉著下列問題來說明 weight decay。
Q1:Overfitting 是什麼?
Q2:Regularization 有那些方法?
Q3:多項式的 fitting 與類神經網路的 fitting 有何不同?
Q4:Weight decay 裡的懲罰項是什麼意思?
Q5:梯度不出現是什麼意思?
Q6:Weight decay 的微分方程版?
Q7:Weight decay 的差分方程版?
Q8:Noise 的影響?
-----
Q1:Overfitting 是什麼?
機器學習有很大的比例是用一小群資料建立模型,再用模型來處理大規模的資料。在訓練模型時,如果考慮的太細,會使訓練結果很好,但測試結果不理想,這個就叫 overfitting。以圖1a為例。訓練不足變成直線,訓練過度是高次方的多項式,理想的結果是一個二次的拋物線。
Fig. 1a. A high-degree polynomial regression model, p. 119 [1].
-----
Q2:Regularization 有那些方法?
本文要介紹的 weight decay 是一種,early stopping 是一種,dropout 也是一種,參考圖1b。這些在 [4], [5] 裡有清楚的簡介。
Fig. 1b. Regularization for deep learning [1].
-----
Q3:多項式的 fitting 與類神經網路的 fitting 有何不同?
其實是很類似的,不過類神經網路更複雜一些。
圖2a是一個多項式,要注意的是,如果你的 weights 很大,就會產生一堆曲線跟複雜的結構,因為你把不重要的 noise 資料都考慮到了,參考圖2b。
Fig. 2a. Polynomial curve fitting, p. 9 [2].
Fig. 2b. A polynomial with an excess of free coefficients, p. 338 [2].
-----
Q4:Weight decay 裡的懲罰項是什麼意思?
要做 weight decay,首先在原來的 loss function 加入一個懲罰項 Ω ,另外還有調節的係數 ν 。參考圖2c。
理論上,loss function 的值越小越好,因為你希望模型越準確越好,可惜在訓練時,你的資料只是一小撮。error 非常小,你就 overfit 、無法處理未來的一大堆測試資料集。Loss function 加入懲罰項,你就沒那麼準。懲罰項的係數是重點。係數太小,甚至是0,你等於麼都沒做,如同圖1a的右邊。係數太大,懲罰過度,如圖1a的左邊。適合的係數,經驗上譬如0.0001,會像是圖1a的中間。
懲罰項一般取權重的平方和,如圖2d。係數1/2是方便微分後跟二次方相乘變成1。懲罰項為 w^2,會避免 w 太大。為什麼?因為 Loss function 的值要越來越小,w 當然也要傾向於越來越小。
最後的結果是 fitting the data 與 minimizing Ω的 妥協,為什麼?因為你希望 Ω 越小越好,可是太小又會 overfitting 了,參考圖2e。
Fig. 2c. Adding a penalty to the error function, p. 338 [2].
Fig. 2d. Sum of the squares, p. 338 [2].
Fig. 2e. Compromise, p. 338 [2].
-----
Q5:梯度不出現是什麼意思?
圖3a講解了 weight decay 的原理。這個解釋不大清楚,但可以先看一下。
最後它提到:如果梯度不出現,則這個方程式會導致權重成指數式的下降,參考圖3b。
太不清楚了,但是這個問題可以好好想想。
Fig. 3a. Adding a term to the cost function that penalizes large weights, p. 2 [3].
Fig. 3b. If the gradient were not present, p. 2 [3].
-----
Q6:Weight decay 的微分方程版?
圖4a又介紹了一次 regularization 跟 weight decay,重點在圖4b的運算。這邊把式子列出,解微分方程後,得到 w 在連續時間 τ 上是一個指數函數,所以如果 τ 越來越大,w 就呈指數下降,在梯度不見(也就是原來的 loss function 無法藉著梯度更新時)時,還是可以持續更新 w。 這個也就是 weight decay。
Fig. 4a. Regularization, p.353 [2].
Fig. 4b. Weight decay, p. 354 [2].
-----
Q7:Weight decay 的差分方程版?
這邊又把 weight decay 整個講過一遍,參考圖4c。
做完後,在原來的梯度更新之前,w 會先乘上一個小於 1 的係數。這個就是 weight decay。
梯度等於0時,也就是無法藉由更新梯度而有更好的結果時,權重就指數式變小,以致於消失,類似人的神經系統,參考圖4d。
Fig. 4c. Closer to zero, p. 34 [4].
Fig. 4d. Pruning out useless links, p.36 [4].
-----
Q7:Goodfellow 的版本?
圖5是 Goodfellow 的版本,相信已經不用再講一遍了!
Fig. 5a. Mean squared error, p. 108 [1].
Fig. 5b. Adding a regularization term, p. 231 [1].
Fig. 5c. Total objective function, p. 119 [1].
Fig. 5d. Weight decay, p. 231 [1].
-----
Q8:Noise 的影響?
圖6a列出四種 noise。有特別的權重表示可以讓小的 weights 非常快變小,對比比較大的 weights,這也就相等於 reinforcing large weights and weakening small weights,參考圖6b。
圖7則是一些關於 weight decay 較經典的研究。
Fig. 6a. Four types of noise, p. 1127 [6].
Fig. 6b. Reinforcing large weights and weakening small weights, p. 1128 [6].
Fig. 6c. Decaying small weights more rapidly than large weights, p. 1128 [6].
Fig. 7. Historical references of weight decay, p. 92 [7].
-----
結論:
Weight decay 與 early stopping、dropout,都是避免 overfitting 常用的方法,而 weight decay 算是最基本的。
-----
References
[1] 2016_Deep Learning
[2] 1995_Neural Networks for Pattern Recognition
[3] 1992_A simple weight decay can improve generalization
[4] DNN tip
http://speech.ee.ntu.edu.tw/~tlkagk/courses/ML_2016/Lecture/DNN%20tip.pdf
[5] ML Lecture 9 Tips for Training DNN - YouTube
https://www.youtube.com/watch?v=xki61j7z-30
[6] 1998_Weight decay backpropagation for noisy data
[7] 2015_Deep learning in neural networks, An overview