斯坦福完全可解釋深度神經網絡:你需要用決策樹搞點事

 2018-01-10 13:06:00.0

原標題:斯坦福完全可解釋深度神經網絡:你需要用決策樹搞點事

選自Stanford

近日,斯坦福大學計算機科學博士生 Mike Wu 發表博客介紹了他對深度神經網絡可解釋性的探索,主要提到了樹正則化。其論文《Beyond Sparsity: Tree Regularization of Deep Models for Interpretability》已被 AAAI 2018 接收。

論文地址:https://arxiv.org/abs/1711.06178

近年來,深度學習迅速成爲業界、學界的重要工具。神經網絡再次成爲解決圖像識別、語音識別、文本翻譯以及其他困難問題的先進技術。去年十月,Deepmind 發佈了 AlphaGo 的更強版本,從頭開始訓練即可打敗最優秀的人類選手和機器人,表明 AI 的未來大有可期。在業界,Facebook、谷歌等公司將深度網絡集成在計算 pipeline 中,從而依賴算法處理每天數十億比特的數據。創業公司,如 Spring、Babylon Health 正在使用類似的方法來顛覆醫療領域。深度學習正在影響我們的日常生活。

圖 1:GradCam - 利用目標概念的梯度突出重要像素,從而創建決策的視覺解釋。

但是深度學習是一個黑箱。我第一次聽說它時,就對其工作原理非常費解。幾年過去了,我仍然在探索合理的答案。嘗試解釋現代神經網絡很難,但是至關重要。如果我們打算依賴深度學習製造新的 AI、處理敏感的用戶數據,或者開藥,那麼我們必須理解這些模型的工作原理。

很幸運,學界人士也提出了很多對深度學習的理解。以下是幾個近期論文示例:

  • Grad-Cam(Selvaraju et. al. 2017):使用最後卷積層的梯度生成熱力圖,突出顯示輸入圖像中的重要像素用於分類。

  • LIME(Ribeiro et. al. 2016):使用稀疏線性模型(可輕鬆識別重要特徵)逼近 DNN 的預測。

  • 特徵可視化(Olah 2017):對於帶有隨機噪聲的圖像,優化像素來激活訓練的 DNN 中的特定神經元,進而可視化神經元學到的內容。

  • Loss Landscape(Li et. al. 2017):可視化 DNN 嘗試最小化的非凸損失函數,查看架構/參數如何影響損失情況。

圖 2:特徵可視化:通過優化激活特定神經元或一組神經元,從而生成圖像(Olah 2017)。

從上述示例中可見,學界對如何解釋 DNN 存在不同見解。隔離單個神經元的影響?可視化損失情況?特徵稀疏性?

什麼是可解釋性?

我們應該把可解釋性看作人類模仿性(human simulatability)。如果人類可以在合適時間內採用輸入數據和模型參數,經過每個計算步,作出預測,則該模型具備模仿性(Lipton 2016)。

這是一個嚴格但權威的定義。以醫院生態系統爲例:給定一個模仿性模型,醫生可以輕鬆檢查模型的每一步是否違背其專業知識,甚至推斷數據中的公平性和系統偏差等。這可以幫助從業者利用正向反饋循環改進模型。

決策樹具備模仿性

我們可以很輕鬆地看到決策樹具備模仿性。例如,如果我想預測病人心臟病發作的風險,我可以沿着決策樹的每個節點走下去,理解哪些特徵可用於作出預測。

圖 3:訓練用於分類心臟病發作風險的決策樹。這棵樹最大路徑長度爲 3。

如果我們可以使用決策樹代替 DNN,那麼已經完成了。但是使用 DNN 儘管缺乏可解釋性,但是它的能力遠超過決策樹。所以我們是否可以將決策樹和 DNN 結合起來,構架具備模仿性的強大模型?

我們可以試着做一個類似 LIME 的東西,構建一個模擬決策樹來逼近訓練後的 DNN 的預測結果。但是訓練深度神經網絡時會出現很多局部極小值,其中只有部分極小值容易模仿。因此,用這種方法可能最後會陷於一個難以模仿的極小值(生成一個巨型決策樹,無法在合理時間內走完)。

表 1:決策樹和 RNN 在不同數據集上的性能。我們注意到 RNN 的預測能力比決策樹優秀許多。

直接優化提高模仿性

如果我們想在優化過程中提高模仿性,則可以嘗試找到更具可解釋性的極小值。完美情況是,我們訓練一個行爲非常像(但並不是)決策樹的 DNN,因爲我們仍然想利用神經網絡的非線性。

另一種方式是使用簡單決策樹正則化深度神經網絡。我們稱之爲樹正則化。

樹正則化

若我們有包含 N 個序列的時序數據集,每一個序列有 T_n 個時間步。當沒有限制時,我們可以假設它有二元輸出。一般傳統上,訓練循環神經網絡(RNN)可以使用以下損失函數:

其中ψ爲正則化器(即 L1 或 L2 正則化)、λ 爲正則化係數或強度、W 爲一組 RNN 的權重矩陣、y_nt 爲單個時間步上的標註真值、y_nt hat 爲單個時間步上的預測值。此外,損失函數一般可以選爲交叉熵損失函數。

添加樹正則化需要改變兩個地方。第一部分是給定一些帶權重 W 的 RNN,且權重 W 可以是部分已訓練的,我們將 N 個長度爲 T 的數據 X 傳遞到 RNN 中以執行預測。然後我們就能使用這 N 個數據對訓練決策樹算法,並嘗試匹配 RNN 的預測。

圖 4:在優化過程中的任意點,我們能通過一個簡單的決策樹逼近部分訓練的 DNN。

因此,我們現在有了模擬 DT,但我們可以選擇一個十分小或十分大的決策樹,因此我們需要量化樹的大小。

爲了完成量化過程,首先我們需要考慮樹的平均路徑長度(APL)。對於單個樣本,路徑長度就等於遊歷樹並作出預測的長度。例如,如圖 3 所示,若有一個用來預測心臟病的決策樹,那麼假設輸入 x 爲 age=70。該樣本下路徑長度因爲 70>62.5 而等於 2。因此平均路徑長度可以簡單地表示爲 ∑ pathlength(x_n, y_n hat)。

圖 5:給定一棵決策樹與數據集,我們能計算平均路徑長度以作爲模擬、解釋平均樣本的成本。通過把這一項加入到目標函數,我們就能鼓勵 DNN 生成簡單的 DT 樹並懲罰複雜而巨大的決策樹。

因此我們最後能將損失函數改寫爲以下形式:

現在只有一個問題:決策樹是不可微的。但我們可能真的比較希望能用 SGD 以實現更快速和便捷的最優化,因此我們也許可以考慮更具創造性的方法。

我們可以做的是添加一個代理模型,它可能是一個以 RNN 權重作爲輸入的多層感知機(MLP),並期望能輸出平均路徑長度的估計量,就好像我們在訓練一個決策樹一樣。

圖 6:通過使用代理模型,我們可以利用流行的梯度下降算法來訓練 DNN。爲了訓練一個代理模型,我們最小化標註真值和預測 APL 之間的 MSE。

當我們優化 RNN/DNN 時,每一個梯度下降步都會生成一組新的權重 W_i。對於每一個 W_i,我們能訓練一個決策樹並計算平均路徑長度。在訓練幾個 epoch 之後,我們能創建一個大型數據集並訓練代理 MLP。

訓練過程會給定一個固定的代理,我們能定義正則化目標函數,並優化 RNN。若給定一個固定的 RNN,我們將構建一個數據集並優化 MLP。

小測試數據集

檢查新技術有效性的一個好方法是在合成數據及上進行測試,在其中我們可以強調新技術提出的效益。

考慮以下的虛構數據集:給定單位二維座標系統內的點 (x_i,y_i),定義一個拋物線決策函數。

y=5∗(x−0.5)^2+0.4

我們在單位正方形 [0,1]×[0,1] 內均勻地隨機採樣 500 個點,所有在拋物線之上的點設爲正的,在拋物線之下的點設爲負的。我們通過隨機翻轉 10% 的邊界附近(圖 7 的兩條灰色拋物線之間)的點以添加一些噪聲。然後,隨機分離出 30% 的點用作測試集。

我們訓練了一個 3 層 MLP 用作分類器,其中第一層有 100 個節點,第二層有 100 個節點,第三層有 10 個節點。我們有意讓該模型過度表達,以使其過擬合,並強調正則化的作用。

圖 7:虛構的拋物線數據集。我們訓練了一個深度 MLP,結合不同級別的 L1、L2 正則化和樹正則化以測試最終決策邊界之間的視覺差異。這裏的關鍵之處在於,樹正則化生成了座標對齊的邊界。然後我們用改變的正則化(L1、L2、樹)和改變的強度λ訓練了一系列的 MLP。我們可以通過描述單位正方形內所有點的行爲並畫出等高線以評估模型,從而逼近已學習的決策函數。圖 7 展示了在不同參數設置下的已學習決策函數的並行對比。

正如預期,隨着正則化強度增加,得到的決策函數也更簡單(減少過擬合)。更重要的是,這三種正則化方法生成不同形狀的決策函數。L1 正則化傾向於生成凹凸不平的線,L2 正則化傾向於球狀的線,樹正則化傾向於生成座標對齊的決策函數。這爲決策樹的工作方式提供了更多的直覺理解。

圖 8:正則化模型的 APL 性能對比。這裏,決策樹(黃線)是原始的決策樹(沒有 DNN)。我們注意到在 1.0 到 5.0 之間樹正則化 MLP 的性能高於(以及複雜度低於)所有其它的模型。

至少在這個虛構示例中,樹正則化在高度正則化區域(人類可模擬)能得到更好的性能。例如,樹正則化結合λ=9500.0 只需要 3 個分支就可以獲得類似拋物線的決策函數(有更高的 APL)。

真實數據集

現在我們對樹正則化有了一個直觀認識,下面就來看一下真實世界數據集(帶有二分類結果),以及樹正則化與 L1、L2 正則化的對比。以下是對數據集的簡短描述:

  • Sepsis(Johnson et. al. 2016):超過 1.1 萬敗血症 ICU 病人的時序數據。我們在每個時間步可以獲取 35 個生命體徵的數據向量、標籤結果(如含氧量或心率)和 5 個二分類結果的標籤(即是否使用呼吸機或是否死亡)。

  • EuResist(Zazzi et. al. 2012):5 萬 HIV 病人的時序數據。該結構非常類似於 Sepsis,不過它包括 40 個輸入特徵和 15 個輸出特徵。

  • TIMIT(Garofolo et. al. 1993):630 位英語說話人的錄音,每個語句包括 60 個音素。我們專注於區分阻塞音(如 b、g)和非阻塞音。輸入特徵是連續聲係數和導數。

我們對真實世界數據集進行虛擬數據集同樣的操作,除了這次我們訓練的是 GRU-RNN。我們再次用不同的正則化執行一系列實驗,現在還利用針對 GRU 的不同隱藏單元大小進行實驗。

圖 9:正則化模型在 Sepsis(5/5 輸出維度)、EuResist (5/15 輸出維度)和 TIMIT 的 APL 上的性能對比。可以看到在 APL 較小時,性能與圖 8 類似,樹正則化達到更高的性能。更多詳細結果和討論見論文 https://arxiv.org/pdf/1711.06178.pdf。

即使在帶有噪聲的真實世界數據集中,我們仍然可以看到樹正則化在小型 APL 區域中優於 L1 和 L2 正則化。我們尤其關注這些低複雜度的「甜蜜點」(sweet spot),因爲這就是深度學習模型模仿性所在,也是在醫療、法律等注重安全的環境中實際有用之處。

此外,我們已經訓練了一個樹正則化 DNN,還可以訓練一個模仿性決策樹查看最終的決策樹是什麼樣子。這是一次很好的完整性檢查,因爲我們期望模仿性決策樹具備模仿性,且與特定問題領域相關。

下圖展示了針對 Sepsis 中 2 個輸出維度的模仿性決策樹。由於我們不是醫生,因此我們請一位敗血症治療專家檢查這些樹。

圖 10:構建決策樹以仿真已訓練的樹正則化 DNN(包含 Sepsis 的 5 個維度中的兩個)。從視覺上,我們可以確認這些樹的 APL 值較小,並且是可模仿的。

考慮 mechanical ventilation 決策樹,臨牀醫生注意到樹節點上的特徵(FiO2、RR、CO2 和 paO2)以及中斷點上的值是醫學上有效的,這些特徵都是測量呼吸質量的。

對於 hospital mortality 決策樹,他注意到該決策樹上的一些明顯的矛盾:有些無器官衰竭的年輕病人被預測爲高死亡率,而其他的有器官衰竭的年輕病人卻被預測爲低死亡率。然後臨牀醫生開始思考,未捕獲的(潛在的)變量如何影響決策樹過程。而這種思考過程不可能通過對深度模型的簡單敏感度分析而進行。

圖 11:和圖 10 相同,但是是從 EuResist 數據集的其中一個輸出維度(服藥堅持性)。

爲了把事情做到底,我們可以看看一個嘗試解釋病人不能服從 HIV 藥物處方(EuResist)的原因的決策樹。我們再次諮詢了臨牀醫生,他確認出,基礎病毒量(baseline viral load)和事先治療線(prior treatment line)是決策樹中的重要屬性,是有用的決策變量。多項研究(Langford, Ananworanich, and Cooper 2007, Socas et. al. 2011)表明高基線的病毒量會導致更快的病情惡化,因此需要多種藥物雞尾酒療法,太多的處方使得病人更難遵從醫囑。

可解釋性優先

本文的重點是一種鼓勵複雜模型在不犧牲太多預測性能的前提下,逼近人類模仿性功能的技術。我認爲這種可解釋性非常強大,可以允許領域專家理解和近似計算黑箱模型正在做的事情。

AI 安全逐漸成爲主流。很多會議如 NIPS 開始更多關注現代機器學習中的公平性、可解釋性等重要問題。之前我們認真地將深度學習應用於消費者產品和服務(自動駕駛汽車),我們確實需要更好地瞭解這些模型的工作原理。這意味着我們需要開發更多可解釋性示例(人類專家參與其中)。

Notes:本文將會出現在 AAAI 2018 上(Beyond Sparsity: Tree Regularization of Deep Models for Interpretability),預印版可在 arXiv 上找到:https://arxiv.org/abs/1711.06178。類似的版本已經在 NIP 2017 上進行了 oral 解讀。

問答

代理 MLP 追蹤 APL 表現如何?

讓人吃驚地好。在所有實驗中,我們使用帶有 25 個隱藏節點的單層 MLP(這是相當小的一個網絡)。這必須有一個預測 APL 權重的低維表徵。

圖 12:真節點計數指的是真正訓練決策樹並計算 APL。已預測的節點計數指的是代理 MLP 的輸出。

與原決策樹相比,樹正則化模型的表現如何?

上述的每個對比圖展示了與正則 DNN 對比的決策樹 AUCs。爲了生成這些線,我們在不同決策樹超參數(即定義葉、基尼係數等的最小樣本數)上進行了網格搜索。我們注意到在所有案例中,DT 表現要比所有正則化方法更差。這表明樹正則化不能只複製 DT。

文獻中有與此相似的嗎?

除了在文章開頭提及的相關工作,模型提取/壓縮很可能是最相似的子領域。其主要思想是訓練一個更小模型以模擬一個更深網絡。這裏,我們主要在優化中使用 DT 執行提取。

樹正則化的運行時間如何?

讓我們看一下 TIMIT 數據集(最大的數據集)。L2 正則化 GRU 每 epoch 用時 2116 秒。帶有 10 個狀態的樹正則化 GRU 每個 epoch 用時 3977 秒,這其中包含訓練代理的時間。實際上,我們做的非常謹慎。例如,如果我們每 25 個 epoch 做一次,我們將獲得 2191 秒的一個平均的每 epoch 的成本。

在多個運行中,(最後的)模擬 DT 穩定嗎?

如果樹正則化強大(高λ),最終的 DT 在不同運行中是穩定的(頂多在一些節點上不同)。

DT 對深度模型的預測有多準確?

換言之,這一問題是在問如果訓練期間 DT 的預測與 DNN 預測是否密切匹配。如果沒有,那麼我們無法有效地真正正則化我們的模型。但是我們並不希望匹配很精確。

在上表中,我們測量了保真度(Craven and Shavlik 1996),這是 DT 預測與 DNN 一致的測試實例的百分比。因此 DT 是準確的。

殘差 GRU-HMM 模型

(本節討論一個專爲可解釋性設計的新模型。)

隱馬爾可夫模型(HMM)就像隨機 RNN,它建模潛在變量序列 [z1,…,zT],其中每個潛在變量是 K 離散狀態之一: z_t∈1,⋯,K。狀態序列通常用於生成數據 x_t,並在每個時間步上輸出觀察到的 y_t。值得注意的是,它包含轉化矩陣 A,其中 A_ij=Pr(z_t=i|z_t−1=j),以及一些產生數據的發射參數。HMMs 通常被認爲是一個更可闡釋的模型,因爲聚類數據的 K 潛在變量通常在語義上是有意義的。

當使用 HMM 潛在狀態(換言之,當 HMM 捕獲數據不足時,只使用 GRU)預測二值目標之時,我們把 GRU-HMM 定義爲一個可以建模殘差誤差的 GRU。根據殘差模型的性質,我們可以使用樹正則化只懲罰 GRU 輸出節點的複雜性,從而使得 HMM 不受限制。

圖 13:GRU-HMM 圖解。x_t 表徵時間步 t 上的輸入數據。s_t 表徵時間步 t 的潛在狀態;r_t,h_t,h_t tilde,z_t 表徵 GRU 的變量。最後的 sigmoid(緊挨着橘色三角形)投射在 HMM 狀態和 GRU 潛在狀態的總和之上。橘色三角形表示用於樹正則化的替代訓練的輸出。

總體而言,深度殘差模型比帶有大體相同參數的 GRU-only 模型的表現要好 1%。參見論文附錄獲得更多信息。

圖 14:就像從前,我們可以爲這些殘差模型繪圖並可視化模擬 DT。儘管我們看到相似的「sweet spot」行爲,我們注意到最後得到的樹有清晰的結構,這表明 GRU 在這一殘差設置中表現不同。

原文地址:http://www.shallowmind.co/jekyll/pixyll/2017/12/30/tree-regularization/


文章來源:機器之心