比DGL快14倍:PyTorch圖神經網絡庫PyG上線了

 2019-03-09 13:00:19.0

圖神經網絡是最近 AI 領域最熱門的方向之一,很多圖神經網絡框架如 graph_netsDGL 已經上線。但看起來這些工具還有很多可以改進的空間。近日,來自德國多特蒙德工業大學的研究者們提出了 PyTorch Geometric,該項目一經上線便在 GitHub 上獲得 1500 多個 star,並得到了 Yann LeCun 的點贊。現在,創建新的 GNN 層更加容易了。

項目鏈接:https://github.com/rusty1s/pytorch_geometric
LeCun評價:一個快速且漂亮的幾何深度學習庫(適用於圖神經網絡與其他不規則結構)

PyTorch Geometric 主要是現有模型快速重新實現的集合(具有定製化的稀疏操作),如果你想嘗試一下特定的已發佈模型,它會是一個很不錯的選擇;如果想實現更爲複雜的結構,其自定義稀疏/分散操作也非常方便。

PyTorch Geometric是基於PyTorch構建的深度學習庫,用於處理不規則結構化輸入數據(如圖、點雲、流形)。除了一般的圖形數據結構和處理方法外,它還包含從關係學習到3D數據處理等領域中最新發布的多種方法。通過利用稀疏 GPU 加速、提供專用的 CUDA 內核以及爲不同大小的輸入樣本引入高效的小批量處理,PyTorch Geometric 實現了很高的數據吞吐量。

PyTorch Geometric(PyG)庫包含易用的小批量加載器(mini-batch loader)、多GPU支持、大量常見基準數據集和有用的變換,適用於任意圖像、三維網格(3D mesh)和點雲。

其作者Matthias Fey 和 Jan E. Lenssen 來自德國多特蒙德工業大學,他們通過實驗證明了該庫已實現方法在圖分類、點雲分類、半監督節點分類等任務上的性能。此外,PyG 速度奇快,大大超過其它深度圖神經網絡庫,如 DGL。

簡單易用

PyTorch Geometric大大簡化了實現圖卷積網絡的過程。比如,它可以用以下幾行代碼實現一個層(如edge convolution layer):

速度快

PyTorch Geometric 速度非常快。下圖展示了這一工具和其它圖神經網絡庫的訓練速度對比情況:
最高比 DGL 快 14 倍!

已實現方法多

PyTorch Geometric 目前已實現以下方法,所有實現方法均支持 CPU 和 GPU 計算:

PyG 概覽

圖神經網絡(GNN)最近成爲在圖、點雲和流形上進行表徵學習的強大方法。與規則域中常用的卷積層和池化層概念類似,GNN通過傳遞、變換和聚合信息來 (層級化地)提取局部嵌入。

但是,實現GNN並不容易,因爲它需要在不同大小的高度稀疏與不規則數據上實現較高的GPU吞吐量。PyTorch Geometric (PyG) 是基於Pytorch構建的幾何深度學習擴展庫。它可以利用專門的CUDA內核實現高性能。在簡單的消息傳遞API之後,它將大多數近期提出的卷積層和池化層捆綁成一個統一的框架。所有的實現方法都支持 CPU 和 GPU 計算,並遵循不變的數據流範式,這種範式可以隨着時間的推移動態改變圖結構。PyG已在MIT許可證下開源,它具備完備的文檔,且提供了附帶教程和示例。

PyG 用節點特徵矩陣 X ∈ ℝ^(N×F) 和稀疏鄰接元組(I,E)來表示圖 G = (X, (I, E)),其中 I ∈ ℕ^(2×E) 以座標(COO)格式編碼邊索引,E ∈ ℝ^(E×D)(可選地)保留D維邊特徵。所有面向用戶的API(如數據加載路徑、多GPU支持、數據增強或模型實例化)都受到了PyTorch的極大啓發,以讓用戶儘可能地熟悉它們。

鄰域聚合(neighborhood aggregation)。將卷積運算泛化至不規則域通常表示爲鄰域聚合或消息傳遞方案(Gilmer et al., 2017)

其中 ⬚ 表示可微置換不變函數(permutation invariant function),如求和、均值或最大值,r 和 表示可微函數 ,如MLP。實踐中,r 和 的逐元素計算可以通過收集和散射節點特徵、利用廣播來實現,如圖1所示。儘管該方案處理的是不規則結構化輸入,但它依然可以通過GPU實現大幅加速。

圖 1:GNN 層計算方法。利用基於邊索引 I 的收集和散射方法,從而在節點並行空間(parallel space)和邊並行空間之間進行交替。 

PyG爲用戶提供通用的MessagePassing接口,以便對新的研究想法進行快速乾淨的原型製作。此外,幾乎所有近期提出的鄰域聚合函數都適用於此接口,其中包括PyG已經集成的方法。

全局池化PyG提供多種readout函數(如global add、mean 或 max pooling),從而支持圖級別輸出,而非節點級別輸出。PyG還提供更加複雜的方法,如 set-to-set (Vinyals et al., 2016)、sort pooling (Zhang et al., 2018) 和全局軟注意力層 (Li et al., 2016)。

層級池化(Hierarchical Pooling)。爲進一步提取層級信息和使用更深層的GNN模型,需要以空間或數據依賴的方式使用多種池化方法。PyG目前提供Graclus、voxel grid pooling、迭代最遠點採樣算法(iterative farthest point sampling algorithm)的實現示例,以及可微池化機制(如DiffPool和top_k pooling)。

小批量處理。PyG 可自動創建單個(稀疏)分塊對角鄰接矩陣(block-diagonal adjacency matrix),並在節點維度中將特徵矩陣級聯起來,從而支持對多個(不同大小)圖實例的小批量處理。正因如此,PyG可在不經修改的情況下應用鄰域聚合方法,因爲不相連的圖之間不會出現信息交流。此外,自動生成的 assignment 向量可確保節點級信息不會跨圖聚合,比如當執行全局聚合運算時。

處理數據集。PyG提供統一的數據格式和易用的接口,方便使用者創建和處理數據集,大型數據集和訓練期間可保存在內存中的數據集皆可適用。要想創建新數據集,用戶只需讀取/下載數據,並轉換爲PyG數據格式即可。此外,用戶可以使用變換(transform,即訪問單獨的多個圖並對其進行變換)方法來修改數據集,比如數據增強、使用合成結構化圖屬性來增強節點特徵等,從而基於點雲自動生成圖,或者從網格中自動採樣點雲。

PyG目前支持大量常見基準數據集,它們均可在第一次初始化時自動下載和處理。具體來講,PyG提供60多個 graph kernel 基準數據集 (Kersting et al., 2016),如 PROTEINS 或 IMDB-BINARY、引用網絡數據集 Cora、CiteSeer、PubMed 和 Cora-Full (Sen et al., 2008; Bojchevski & Günnemann, 2018)、Coauthor CS/Physics 和 Amazon Computers/Photo 數據集 (Shchur et al. (2018)、分子數據集 QM7b (Montavon et al., 2013) 和 QM9 (Ramakrishnan et al., 2014),以及Hamilton 等人 (2017) 創建的蛋白質相互作用圖。此外,PyG還提供嵌入式數據集,如MNIST超像素 (Monti et al., 2017)、FAUST (Bogo et al., 2014)、ModelNet10/40 (Wu et al., 2015)、ShapeNet (Chang et al., 2015)、COMA (Ranjan et al., 2018),以及 PCPNet 數據集 (Guerrero et al., 2018)。

實證評估

半監督節點分類
表 1:多個模型使用固定分割和隨機分割的半監督節點分類結果。

圖分類
表 2:圖分類。

點雲分類
表3:點雲分類。


看起來,圖神經網絡框架的競爭正愈發激烈起來,PyTorch Geometric 也引起了 DGL 創作者的注意,來自 AWS 上海 AI 研究院的 Ye Zihao 對此評論道:「目前 DGL 的速度比 PyG 慢,這是因爲它 PyTorch spmm 的後端速度較慢(相比於 PyG 中的收集+散射)。在 DGL 的下一個版本(0.2)中,我們將報告新的模型訓練速度數據,並提供基準測試腳本。我們還將提供定製內核支持以加速 GAT,敬請期待!」

論文鏈接:https://arxiv.org/abs/1903.02428

文章來源:機器之心