深度學(xué)習(xí)算子優(yōu)化-FFT

作者:嚴(yán)健文 | 曠視 MegEngine 架構(gòu)師

背景

在數(shù)字信號(hào)和數(shù)字圖像領(lǐng)域, 對(duì)頻域的研究是一個(gè)重要分支。
我們?nèi)粘!凹庸ぁ钡膱D像都是像素級(jí),被稱為是圖像的空域數(shù)據(jù)??沼驍?shù)據(jù)表征我們“可讀”的細(xì)節(jié)。如果我們將同一張圖像視為信號(hào),進(jìn)行頻譜分析,可以得到圖像的頻域數(shù)據(jù)。 觀察下面這組圖 (來源),頻域圖中的亮點(diǎn)為低頻信號(hào),代表圖像的大部分能量,也就是圖像的主體信息。暗點(diǎn)為高頻信號(hào),代表圖像的邊緣和噪聲。從組圖可以看出,Degraded Goofy 與 Goofy 相比,近似的低頻信號(hào)保留住了 Goofy 的“輪廓”,而其高頻信號(hào)的增加使得背景噪點(diǎn)更加明顯。頻域分析使我們可以了解圖像的組成,進(jìn)而做更多的抽象分析和細(xì)節(jié)處理。

Goofy and Degraded Goofy

實(shí)現(xiàn)圖像空域和頻域轉(zhuǎn)換的工具,就是傅立葉變換。由于圖像數(shù)據(jù)在空間上是離散的,我們使用傅立葉變換的離散形式 DFT(Discrete Fourier Transform)及其逆變換 IDFT(Inverse Discrete Fourier Transform)。Cooley-Tuckey 在 DFT 的基礎(chǔ)上,開發(fā)了更快的算法 FFT(Fast Fourier Transform)。


DFT/FFT 在數(shù)字圖像領(lǐng)域還有一些延伸應(yīng)用。比如基于 DFT 的 DCT(Discrete Cosine Transform, 離散余弦變換)就用在了圖像壓縮 JPEG 算法 (來源) 和圖像水印算法(來源)。

JPEG 編碼是通過色彩空間轉(zhuǎn)換、抽樣分塊、DCT 變換、量化編碼實(shí)現(xiàn)的。其中 DCT 變換的使用將圖像低頻信息和高頻信息區(qū)分開,在量化編碼過程中壓縮了少量低頻信息、大量高頻信息從而獲得尺寸上壓縮。從貓臉圖上可看出隨著壓縮比增大畫質(zhì)會(huì)變差,但是主體信息還是得以保留。

貓臉不同 jpeg 畫質(zhì)(壓縮比)

圖像水印算法通過 DCT 將原圖轉(zhuǎn)換至頻域,選取合適的位置嵌入水印圖像信息,并通過 IDCT 轉(zhuǎn)換回原圖。這樣對(duì)原圖像的改變較小不易察覺,且水印通過操作可以被提取。


DFT/FFT 在深度學(xué)習(xí)領(lǐng)域也有延伸應(yīng)用。 比如利用 FFT 可以降低卷積計(jì)算量的特點(diǎn),F(xiàn)FT_Conv 算法也成為常見的深度學(xué)習(xí)卷積算法。本文我們就來探究一下頻域算法的原理和優(yōu)化策略。

DFT 的原理及優(yōu)化

公式

無論是多維的 DFT 運(yùn)算,還是有基于 DFT 的 DCT/FFT_Conv, 底層的計(jì)算單元都是 DFT_1D。 因此,DFT_1D 的優(yōu)化是整個(gè) FFT 類算子優(yōu)化的基礎(chǔ)。
DFT_1D 的計(jì)算公式:
X_{k}=\sum_{n=0}^{\mathrm{N}-1} x_{n} e^{-j 2 \pi k \frac{n}{N}} \quad k=0, \ldots, N-1

其中 x_{n}為長(zhǎng)度為 N 的輸入信號(hào),e^{-j 2 \pi k \frac{n}{N}}是 1 的 N 次根, X_{k}為長(zhǎng)度為 N 的輸出信號(hào)。
該公式的矩陣形式為:

\left[\begin{array}{c}X(0) \\ X(1) \\ \vdots \\ X(N-1)\end{array}\right]=\left[W_{N}^{n k}\right]\left[\begin{array}{c} \left.x(0\right) \\ x(1) \\ \vdots \\ x(N-1)\end{array}\right]

單位復(fù)根的性質(zhì)

DFT_1D 中的W_{N}^{nk} = e^{-j 2 \pi k \frac{n}{N}}是 1 的單位復(fù)根。直觀地看,就是將復(fù)平面劃分為 N 份,根據(jù) k * n 的值逆時(shí)針掃過復(fù)平面的圓周。

單位復(fù)根有著周期性和對(duì)稱性,我們依據(jù)這兩個(gè)性質(zhì)可以對(duì) W 矩陣做大量的簡(jiǎn)化,構(gòu)成 DFT_1D 的快速算法的基礎(chǔ)。
周期性:W_{N}^{k +N}=W_{N}^{k}
對(duì)稱性:W_{N}^{k+N / 2}=-W_{N}^{k}

Cooley-Tuckey FFT 算法

DFT_1D 的多種快速算法中,使用最頻繁的是 Cooley-Tuckey FFT 算法。算法采用用分治的思想,將輸入尺寸為 N 的序列,按照不同的基 radix,分解為 N/radix 個(gè)子序列,并對(duì)每個(gè)子序列再劃分,直到不能再被劃分為止。每一次劃分都可以得到一級(jí) stage,將所有的級(jí)自下而上組合在一起,計(jì)算得到最后的輸出序列。
這里以 N = 8, radix=2 為例展示推理過程。
其中x(k)為 N=8 的序列, X^{F}(k)為 DFT 輸出序列。
根據(jù) DFT 的計(jì)算公式
X^{F}(k)=W_{8}^{0} x_{0}+W_{8}^{k} x_{1}+W_{8}^{2 k} x_{2}+W_{8}^{3k} x_{3}+W_{8}^{4k} x_{4} + W_{8}^{5k} x_{5}+W_{8}^{6k} x_{6} +W_{8}^{7k} x_{7}

根據(jù)奇偶項(xiàng)拆開,分成兩個(gè)長(zhǎng)度為 4 的序列G(k), H(k)。

X^{F}(k)= W_{8}^{0} x_{0}+W_{8}^{2 k} x_{2}+W_{8}^{4 k} x_{4}+W_{8}^{6 k} x_{6}

+W_{8}^{k}\left(W_{8}^{0} x_{1}+W_{8}^{2 k} x_{3}+W_{8}^{4 k} x_{5}+W_{8}^{6 k} x_{7}\right)
=G^{F}(k)+W_{8}^{k} H^{F}(k)

X^{F}(k+4)=W_{8}^{0} x_{0}+W_{8}^{2(k+4)} x_{2}+W_{8}^{4(k+4)} x_{4}+W_{8}^{6(k+4)} x_{6}
+W_{8}^{(k+4)}\left(W_{8}^{0} x_{1}+W_{8}^{2(k+4)} x_{3}+W_{8}^{4(k+4)} x_{5}+W_{8}^{6(k+4)} x_{7}\right)
=G^{F}(k)+W_{8}^{k+4} H^{F}(k)
=G^{F}(k)-W_{8}^{k} H^{F}(k)

G^{F}(k)H^{F}(k)G(k)H(k)的 DFT 結(jié)果。G^{F}(k)H^{F}(k)乘以對(duì)應(yīng)的旋轉(zhuǎn)因子W_{8}^{k},進(jìn)行簡(jiǎn)單的加減運(yùn)算可以得到輸出X^{F}(k)
同理,對(duì)G(k)H(k)也做一樣的迭代,A(k),B(k), C(k), D(k) 都是 N=2 的序列,用他們的 DFT 結(jié)果進(jìn)行組合運(yùn)算可以得到G^{F}(k)H^{F}(k)

\begin{aligned} &G^{F}(k)=A^{F}(k) + W_{4}^{k}B^{F}(k)\\ \end{aligned}
\begin{aligned} &G^{F}(k+2)=A^{F}(k)-W_{4}^{k}B^{F}(k)\\ \end{aligned}
\begin{aligned} &H^{F}(k)=C^{F}(k)+W_{4}^{k}D^{F}(k)\\ \end{aligned}
\begin{aligned} &H^{F}(k+2)=C^{F}(k)-W_{4}^{k}D^{F}(k)\\ \end{aligned}

計(jì)算 N=2 的序列A^{F}(k), B^{F}(k), C^{F}(k), D^{F}(k), 因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=k%3D0" alt="k=0" mathimg="1">,旋轉(zhuǎn)因子W_{2}^{0}= 1。只要進(jìn)行加減運(yùn)算得到結(jié)果。
\left[\begin{array}{l} A^{F}(0) \\ A^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{0} \\ x_{4} \\ \end{array}\right]

\left[\begin{array}{l} B^{F}(0) \\ B^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{2} \\ x_{6} \\ \end{array}\right]

\left[\begin{array}{l} C^{F}(0) \\ C^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{1} \\ x_{5} \\ \end{array}\right]

\left[\begin{array}{l} D^{F}(0) \\ D^{F}(1) \end{array}\right]=\left[\begin{array}{ll} 1 & 1 \\ 1 & -1 \end{array}\right]\left[\begin{array}{l} x_{3} \\ x_{7} \\ \end{array}\right]

用算法圖形表示,每一層的計(jì)算會(huì)產(chǎn)生多個(gè)蝶形,因此該算法又被稱為蝶形算法。
這里我們要介紹碟形網(wǎng)絡(luò)的基本組成,對(duì)下文的分析有所幫助。


N=8 碟形算法圖

N=8 的計(jì)算序列被分成了 3 級(jí),每一級(jí) (stage) 有一個(gè)或多個(gè)塊 (section),每個(gè)塊中包含了一個(gè)或者多個(gè)蝶形(butterfly), 蝶形的計(jì)算就是 DFT 運(yùn)算的 kernel。
每一個(gè) stage 的計(jì)算順序:

  • 取輸入
  • 乘以轉(zhuǎn)換因子
  • for section_num, for butterfly_num,執(zhí)行 radixN_kernel
  • 寫入輸出。

看 N=8 的蝶形算法圖,stage = 1 時(shí),運(yùn)算被分成了 4 個(gè) section,每個(gè) section 的 butterfly_num = 1。stage = 2 時(shí),section_num = 2,butterfly_num = 2。 stage = 3 時(shí),section_num = 1,butterfly_num = 4。
可以觀察到,從左到右過程中 section_num 不斷減少,butterfly_num 不斷增加,蝶形群在“變大變密”,然而每一級(jí)總的碟形次數(shù)是不變的。
實(shí)際上,對(duì)于長(zhǎng)度為 N,radix = r 的算法,我們可以推得到:

\text { Sec_num }=N / r^{S}
\text { Butterfly_num }= r^{S-1}
\text { Sec_stride }=r^{S}
\text { Butterfly_stride }=1

S 為當(dāng)前的 stage,sec/butterfly_stride 是每個(gè) section/butterfly 的間隔。

這個(gè)算法可以將復(fù)雜度從 O(n^2) 下降到 O(nlogn),顯得高效而優(yōu)雅。我們基于蝶形算法,對(duì)于不同的 radix 進(jìn)行算法的進(jìn)一步劃分和優(yōu)化,主要分為 radix - 2 的冪次的和 radix – 非 2 的冪次兩類。

radix-2 的冪次優(yōu)化

DFT_1D 的 kernel 即為矩陣形式中的W_{N}^{nk}矩陣,我們對(duì) radix_2^n 的 kernel 進(jìn)行分析。

背景里提到, DFT 公式的矩陣形式為:
\left[\begin{array}{c}X(0) \\ X(1) \\ \vdots \\ X(N-1)\end{array}\right]=\left[W_{N}^{n k}\right]\left[\begin{array}{c} \left.x(0\right) \\ x(1) \\ \vdots \\ x(N-1)\end{array}\right]
其中x(0) ~x(N-1)為乘以旋轉(zhuǎn)因子W_{N}^{kn}后的輸入

當(dāng) radix = 2 時(shí),由于W_{2}^1 = -1, W_{2}^2 = 1, radix_2 的 DFT 矩陣形式可以寫為:

\left[\begin{array}{c}\mathrm{X}_{\mathrm{k}} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 2}\end{array}\right] =\left[\begin{array}{cc}1 & 1 \\ 1 & -1\end{array}\right]\left[\begin{array}{l}\mathrm{W}_{\mathrm{N}}^{0} \mathrm{A}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{\mathrm{k}} \mathrm{B}_{\mathrm{k}}\end{array}\right]

當(dāng) radix = 4 時(shí),由于W_{4}^1 = -j, W_{4}^2 = -1, W_{4}^3 = j, W_{4}^4= 1,radix_4 的 DFT 矩陣形式可以寫為:

\left[\begin{array}{c}\mathrm{X}_{\mathrm{k}} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 4} \\ \mathrm{X}_{\mathrm{k}+\mathrm{N} / 2} \\ \mathrm{X}_{\mathrm{k}+3 \mathrm{~N} / 4}\end{array}\right]=\left[\begin{array}{cccc}1 & 1 & 1 & 1 \\ 1 & -\mathrm{j} & -1 & \mathrm{j} \\ 1 & -1 & 1 & -1 \\ 1 & \mathrm{j} & -1 & -\mathrm{j}\end{array}\right]\left[\begin{array}{c}\mathrm{W}_{\mathrm{N}}^{0} \mathrm{A}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{\mathrm{k}} \mathrm{B}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{2 \mathrm{k}} \mathrm{C}_{\mathrm{k}} \\ \mathrm{W}_{\mathrm{N}}^{3 \mathrm{k}} \mathrm{D}_{\mathrm{k}}\end{array}\right]

同理推得到 radix_8 的 kernel 為:

\left[\begin{array}{cccccccc}1 & 1 & 1 & 1 & 1 & 1 & 1 & 1 \\ 1 & \mathrm{~W}_{8}^{1} & -j & \mathrm{~W}_{8}^{3} & -1 & -\mathrm{W}_{8}^{1} & j & -\mathrm{W}_{8}^{3} \\ 1 & -j & -1 & j & 1 & -j & -1 & j \\ 1 & \mathrm{~W}_{8}^{3} & j & \mathrm{~W}_{8}^{1} & -1 & -\mathrm{W}_{8}^{3} & -j & -\mathrm{W}_{8}^{1} \\ 1 & -1 & 1 & -1 & 1 & -1 & 1 & -1 \\ 1 & -\mathrm{W}_{8}^{1} & -j & -\mathrm{W}_{8}^{3} & -1 & \mathrm{~W}_{8}^{1} & j & \mathrm{~W}_{8}^{3} \\ 1 & j & -1 & -j & 1 & j & -1 & -j \\ 1 & -\mathrm{W}_{8}^{3} & j & -\mathrm{W}_{8}^{1} & -1 & \mathrm{~W}_{8}^{3} & -j & \mathrm{~W}_{8}^{1}\end{array}\right]

我們先來看訪存,現(xiàn)代處理器對(duì)于計(jì)算性能的優(yōu)化要優(yōu)于對(duì)于訪存的優(yōu)化,在計(jì)算和訪存相近的場(chǎng)景下, 訪存通常是性能瓶頸。

DFT1D 中,對(duì)于不同基底的算法 r-2/r-4/r-8, 每一個(gè) stage 有著相等的存取量:2 * butterfly_num * radix = 2N, 而不同的基底對(duì)應(yīng)的 stage 數(shù)有著明顯差異(\log_2N vs \log_4N vs \log_8N)。

因此對(duì)于 DFT, 在不顯著增加計(jì)算量的條件下, 選用較大的 kernel 會(huì)在訪存上取得明顯的優(yōu)勢(shì)。觀察推導(dǎo)的 kernel 圖, r-2 的 kernel 每個(gè)蝶形對(duì)應(yīng) 4 次訪存操作和,2 次復(fù)數(shù)浮點(diǎn)加減運(yùn)算。r-4 的 kernel 每個(gè)蝶形算法對(duì)應(yīng) 8 次 load/store、8 次復(fù)數(shù)浮點(diǎn)加減操作(合并相同的運(yùn)算),在計(jì)算量略增加的同時(shí) stage 由 \log_2N 下降到 \log_4N , 降低了總訪存的次數(shù), 因此會(huì)有性能的提升。r-8 的 kernel 每個(gè)蝶形對(duì)應(yīng) 16 次 load/store、24 次復(fù)數(shù)浮點(diǎn)加法和 8 次浮點(diǎn)乘法。浮點(diǎn)乘法的存在使得計(jì)算代價(jià)有所上升, stage 由 \log_4N 進(jìn)一步下降到 \log_8N ,但由于 N 日常并不會(huì)太大, r-4 到 r-8 的 stage 減少不算明顯,所以優(yōu)化有限

我們?cè)賮砜从?jì)算的開銷。減少計(jì)算的開銷通常有兩種辦法:減少多余的運(yùn)算、并行化。

以 r-4 算法為例,kernel 部分的計(jì)算為:

  • radix_4_first_stage(src, dst, sec_num, butterfly_num)
  • radix_4_other_stage(src, dst, sec_num, butterfly_num)
    • for Sec_num
      • for butterfly_num
        • raidx_4_kernel

radix4_first_stage 的數(shù)據(jù)由于 k=0, 旋轉(zhuǎn)因子都為 1,可以省去這部分復(fù)數(shù)乘法運(yùn)算,單獨(dú)優(yōu)化。 radix4_other_stage 部分, 從第 2 個(gè) stage 往后, butterfly_num = 4^(s-1) 都為 4 的倍數(shù),而每個(gè) butterfly 數(shù)組讀取/存儲(chǔ)都是間隔的??梢詫?duì)最里層的循環(huán)做循環(huán)展開加向量化,實(shí)現(xiàn) 4 個(gè)或更多 butterfly 并行運(yùn)算。循環(huán)展開和 SIMD 指令的使用不僅可以提高并行性, 也可以提升 cacheline 利用的效率,可以帶來較大的性能提升。 以 SM8150(armv8) 為例,r-4 的并行優(yōu)化可以達(dá)到 r2 的 1.6x 的性能。

尺寸:1 * 2048(r2c) 環(huán)境:SM8150 大核

總之,對(duì)于 radix-2^n 的優(yōu)化,選用合適的 radix 以減少多 stage 帶來的訪存開銷,并且利用單位復(fù)根性質(zhì)以及并行化降低計(jì)算的開銷,可以帶來較大的性能提升。

radix-非 2 的冪次優(yōu)化

當(dāng)輸入長(zhǎng)度 N = radix1^m1 * radix2^m2... 且 radix 都不為 2 的冪次時(shí),如果使用 naive 的 O(n^2) 算法, 性能就會(huì)急劇下降。 常見的解決辦法對(duì)原長(zhǎng)補(bǔ) 0、使用 radix_N 算法、特殊的 radix_N 算法 (chirp-z transform)。補(bǔ) 0 至 2 的冪次方法對(duì)于大尺寸的輸入要增加很多運(yùn)算量和存儲(chǔ)量, 而 chirp-z transform 是用卷積計(jì)算 DFT, 算法過于復(fù)雜。因此對(duì)非 2 的冪次 radix-N 的優(yōu)化也是必要的。

radix-N 計(jì)算流程和 radix-2 冪次一樣,我們同樣可以利用單位復(fù)根的周期性和對(duì)稱性,對(duì) kernel 進(jìn)行計(jì)算的簡(jiǎn)化。 以 radix-5 為例,radix-5 的 DFT_kernel 為:
\left[\begin{array}{cccc} 1&1&1&1&1\\ 1 &\mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{-1} \\ 1 &\mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{-2} \\ 1 &\mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{1} & \mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{2} \\ 1 &\mathrm{W}_{\mathrm{5}}^{-1} & \mathrm{W}_{\mathrm{5}}^{-2} & \mathrm{W}_{\mathrm{5}}^{2} & \mathrm{W}_{\mathrm{5}}^{1} \\ \end{array}\right]

W_5^kW_{5}^{-k}在復(fù)平面上根據(jù) x 軸對(duì)稱,有相同的實(shí)部和相反的虛部。根據(jù)這個(gè)性質(zhì)。如下圖所示,對(duì)于每一個(gè) stage,可以合并公共項(xiàng) A,B,C,D,再根據(jù)公共項(xiàng)計(jì)算出該 stage 的輸出。

\begin{array}{l} A=\left(x_{1}+x_{4}\right) * W_{5}^{1} \cdot r+\left(x_{2}+x_{3}\right) * W_{5}^{2} \cdot r\\\end{array}

B=(-j) * \left[\left(x_{1}-x_{4}\right) * W_{5}^{1} \cdot i+\left(x_{2}-x_{3}\right) * W_{5}^{2} \cdot i\right] \

C=\left(x_{1}+x_{4}\right) * W_{5}^{2} \cdot r+\left(x_{2}+x_{3}\right) * W_{5}^{1} \cdot r\

D=j * \left[\left(x_{1}-x_{4}\right) * W_{5}^{2} \cdot i-\left(x_{2}-x_{3}\right) * W_{5}^{1} \cdot i\right] \

\begin{array}{l} X(k)=x_{0}+\left(x_{1}+x_{4}\right)+\left(x_{2}+x_{3}\right)\\ \end{array}
\begin{array}{l} X(k+N/5)=x_{0}+\mathrm{A}-\mathrm{B}\\ X(k+2N/5)=x_{0}+\mathrm{C}+\mathrm{D}\\ X(k+3N/5)=x_{0}+C-D\\ X(k+4N/5)=x_{0}+\mathrm{A}+\mathrm{B}\\ \end{array}

這種算法減少了很多重復(fù)的運(yùn)算。同時(shí),在 stage>=2 的時(shí)候,同樣對(duì) butterfly 做循環(huán)展開加并行化,進(jìn)一步減少計(jì)算的開銷。
radix-5 的優(yōu)化思想可以外推至 radix-N。對(duì)于 radix_N 的每一個(gè) stage, 計(jì)算流程為:

  • 取輸入
  • 乘以對(duì)應(yīng)的轉(zhuǎn)換因子
  • 計(jì)算公共項(xiàng), radix_N 有 N-1 個(gè)公共項(xiàng)
  • 執(zhí)行并行化的 radix_N_kernel
  • 寫入輸出

其他優(yōu)化

上述兩個(gè)章節(jié)描述的是 DFT_1D 的通用優(yōu)化,在此基礎(chǔ)上還可以做更細(xì)致的優(yōu)化,可以參考本文引用的論文。

  • 對(duì)于全實(shí)數(shù)輸入的, 由于輸入的虛部為 0, 進(jìn)行旋轉(zhuǎn)因子以及 radix_N_kernel 的復(fù)數(shù)運(yùn)算時(shí)會(huì)有多余的運(yùn)算和多余的存儲(chǔ), 可以利用 split r2c 算法, 視為長(zhǎng)度為 N/2 的復(fù)數(shù)序列, 計(jì)算 DFT 結(jié)果并進(jìn)行 split 操作得到 N 長(zhǎng)實(shí)數(shù)序列的結(jié)果。
  • 對(duì)于 radix-2 的冪次算法, 重新計(jì)算每個(gè) stage 的輸入/輸出 stride 以取消第一級(jí)的位元翻轉(zhuǎn)可以進(jìn)一步減少訪存的開銷。
  • 對(duì)于 radix-N 算法, 在混合基框架下 N = radix1^m1 * radix2^m2, 合并較小的 radix 為大的 radix 以減少 stage。

DFT 延展算法的原理及優(yōu)化

DCT 和 FFT_conv 兩個(gè)典型的基于 DFT 延展的算法,DFT_1D/2D 的優(yōu)化可以很好的用在這類算法中。

DCT

DCT 算法(Discrete Cosine Transform, 離散余弦變換)可以看作是 DFT 取其正弦分量并經(jīng)過工業(yè)校正的算法。DFT_1D 的計(jì)算公式為:

\begin{aligned} X[k] &=C(k) \sum_{n=0}^{N-1} x[n] \cos \left(\frac{(2 n+1) \pi k}{2 N}\right) \\ &C(k)=\sqrt{\frac{1}{n}} \\&k=1 \\ &C(k)=\sqrt{\frac{2}{n}} \\&k!=1 \\ \end{aligned}

該算法 naive 實(shí)現(xiàn)是 O(n^2) 的,而我們將其轉(zhuǎn)換成 DFT_1D 算法,可以將算法復(fù)雜度降至 O(nlogn)。
基于 DFT 的 DCT 算法流程為:

  • 對(duì)于 DCT 的輸入序列 x[n], 創(chuàng)建長(zhǎng)為 2N 的輸入序列 y[n] 滿足 y[n] = x[n] + x[2N-n-1], 即做一個(gè)鏡像對(duì)稱。
  • 對(duì)輸入序列 y[n] 進(jìn)行 DFT 運(yùn)算,得到輸出序列 Y[K]。
  • 由 Y[K] 計(jì)算得到原輸入序列的輸出 X[K] 。

我們嘗試推導(dǎo)一下這個(gè)算法:

{l} y[n]=x[n]+x [2 N-1-n] \

{l} Y[k]=\sum_{n=0}^{N-1} x[n]\cdot e^{-j \frac{2 \pi k n}{2 N}} +\sum_{n=N}^{2 N-1} x[2 N-1-n] \cdot e^{-j \frac{2 \pi k n}{2 N}}

=\sum_{n=0}^{N-1} x[n]\cdot e^{-j \frac{2 \pi k n}{2 N}} +\sum_{n=0}^{N-1} x[n] \cdot e^{-j \frac{2 \pi k (2N-1-n)}{2 N}}
=e^{-j \frac{2 \pi k }{2 N}} \cdot \sum_{n=0}^{N-1} x[n] (e^{-j \frac{2\pi}{2 N} kn} \cdot e^{-j \frac{\pi}{2 N}k}+e^{j \frac{2\pi}{2 N} kn} \cdot e^{j \frac{\pi}{2 N}k})
=e^{-j \frac{2 \pi k }{2 N}} \cdot \sum_{n=0}^{N-1} x[n] \cdot 2\cdot\cos(\frac{2n+1}{2N} k\pi)
=e^{-j \frac{2 \pi k }{2 N}} \cdot C(u) \cdot X[k]

對(duì) y[n] 依照 DFT 公式展開,整理展開的兩項(xiàng)并提取公共項(xiàng)e^{-j \frac{2 \pi k }{2 N}}, 根據(jù)歐拉公式和誘導(dǎo)函數(shù),整理非公共項(xiàng)(e^{-j \frac{2\pi}{2 N} kn} \cdot e^{-j \frac{\pi}{2 N}k}+e^{j \frac{2\pi}{2 N} kn} \cdot e^{j \frac{\pi}{2 N}k})??梢钥闯龅玫降慕Y(jié)果正是 x[k] 和與 k 有關(guān)的系數(shù)的乘積。這樣就可以通過先計(jì)算Y[k]得到 x[n] 的 DCT 輸出X[k] 。

在理解算法的基礎(chǔ)上,我們對(duì) DFT_1D 的優(yōu)化可以完整地應(yīng)用到 DCT 上。DCT_2D 的計(jì)算過程是依次對(duì)行、列做 DCT_1D, 我們用多線程對(duì) DCT_1D 進(jìn)行并行,可以進(jìn)一步優(yōu)化算法。

FFT_conv

Conv 是深度學(xué)習(xí)最常見的運(yùn)算,計(jì)算 conv 常用的方法有 IMG2COL+GEMM, Winograd, FFT_conv。三種算法都有各自的使用場(chǎng)景。

FFT_conv 的數(shù)學(xué)原理是時(shí)域中的循環(huán)卷積對(duì)應(yīng)于其離散傅里葉變換的乘積。如下圖所示, f 和 g 的卷積等同于將 f 和 g 各自做傅立葉變幻 F,進(jìn)行點(diǎn)乘并通過傅立葉逆變換計(jì)算后的結(jié)果。
f \underset{\text { Circ }}{*} g=\mathcal{F}^{-1}(\mathcal{F}(f) \cdot \mathcal{F}(g))

直觀的理論證明可下圖(來源)。

\mathcal{F}[f * g] \

=\int_{-\infty}^{\infty}\left[\left(\int_{-\infty}^{\infty}g(z)f(x-z)dz\right)e^{-i k x}\right]dx

=\int_{-\infty}^{\infty} g(z)\left[\int_{-\infty}^{\infty} f(x-z) e^{-i k x} d x\right] d z
=\int_{-\infty}^{\infty} g(z)\left[\int_{-\infty}^{\infty} f(y) e^{-i k(y+z)} d y\right] d z
=\left[\int_{-\infty}^{\infty} g(z) e^{-i k z} d z\right]\left[\int_{-\infty}^{\infty} f(y) e^{-i k y} d y\right]
=\mathcal{F}[f] \cdot \mathcal{F}[g]

將卷積公式和離散傅立葉變換展開, 改變積分的順序并且替換變量, 可以證明結(jié)論。
注意這里的卷積是循環(huán)卷積, 和我們深度學(xué)習(xí)中常用的線性卷積是有區(qū)別的。 利用循環(huán)卷積計(jì)算線性卷積的條件為循環(huán)卷積長(zhǎng)度 L?| f |+| g |?1。 因此我們要對(duì) Feature Map 和 Kernel 做 zero-padding,并從最終結(jié)果中取有效的線性計(jì)算結(jié)果。

FFT_conv 算法的流程:

  • 將 Feature Map 和 Kernel 都 zero-pad 到同一個(gè)尺寸,進(jìn)行 DFT 轉(zhuǎn)換。
  • 矩陣點(diǎn)乘
  • 將計(jì)算結(jié)果通過 IDFT 計(jì)算出結(jié)果。

該算法將卷積轉(zhuǎn)換成點(diǎn)乘, 算法復(fù)雜度是 O(nlogn), 小于卷積的 O(n^2), 在輸入的尺寸比較大時(shí)可以減少運(yùn)算量,適用于大 kernel 的 conv 算法。

深度學(xué)習(xí)計(jì)算中, Kernel 的尺寸要遠(yuǎn)小于 Feature Map, 因此 FFT_conv 第一步的 zero-padding 會(huì)有很大的開銷,參考論文 2 里提到可以通過對(duì) Feature map 進(jìn)行分塊, 分塊后的 Feature Map 和 Kernel 需要 padding 到的尺寸較小,可以大幅減小這一部分的開銷。 優(yōu)化后 fft_conv 的計(jì)算流程為:

  • 合理安排緩存計(jì)算出合適的 tile 尺寸,對(duì)原圖進(jìn)行分塊
  • 分塊后的小圖和 kernel 進(jìn)行 zero-padding, 并進(jìn)行 DFT 運(yùn)算
  • 小圖矩陣點(diǎn)乘
  • 進(jìn)行逆運(yùn)算并組合成大圖。

同時(shí)我們可以觀察到,F(xiàn)FT_conv 的核心計(jì)算模塊還是針對(duì)小圖的 DFT 運(yùn)算, 因此我們可以將前一章節(jié)對(duì) DFT 的優(yōu)化代入此處,輔以多線程,進(jìn)一步提升 FFT_Conv 的計(jì)算效率。

參考資料

  1. 陳暾,李志豪,賈海鵬,張?jiān)迫?a target="_blank">基于 ARMV8 平臺(tái)的多維 FFT 實(shí)現(xiàn)與優(yōu)化研究
  2. Qinglin Wang,Dongsheng Li. Optimizing FFT-Based Convolutionon ARMv8 Multi-core CPUs
  3. Aleksandar Zlateski, Zhen Jia, Kai Li, Fredo Durand. FFT Convolutions are Faster than Winograd onModern CPUs, Here’s Why

附:

GitHub:MegEngine 天元

官網(wǎng):MegEngine-深度學(xué)習(xí),簡(jiǎn)單開發(fā)

歡迎加入 MegEngine 技術(shù)交流 QQ 群:1029741705

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
【社區(qū)內(nèi)容提示】社區(qū)部分內(nèi)容疑似由AI輔助生成,瀏覽時(shí)請(qǐng)結(jié)合常識(shí)與多方信息審慎甄別。
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

相關(guān)閱讀更多精彩內(nèi)容

友情鏈接更多精彩內(nèi)容