Strassen矩陣演算法分析及其實現

2021-08-30 09:13:58 字數 3320 閱讀 3685

對於矩陣乘法c =  a × b,通常的做法是將矩陣進行分塊相乘,如下圖所示:

從上圖可以看出這種分塊相乘總共用了8次乘法,當然對於子矩陣相乘(如a0×b0),還可以繼續遞迴使用分塊相乘。對於中小矩陣來說,很適合使用這種分塊乘法,但是對於大矩陣來說,遞迴的次數較多,如果能減少每次分塊乘法的次數,那麼效能將可以得到很好的提高。

strassen矩陣乘法就是採用了乙個簡單的運算技巧,將上面的8次矩陣相乘變成了7次乘法,看別小看這減少的1次乘法,因為每遞迴1次,效能就提高了1/8,比如對於1024*1024的矩陣,第1次先分解成7次512*512的矩陣相乘,對於512*512的矩陣,又可以繼續遞迴分解成256*256的矩陣相乘,…,一直遞迴下去,假設分解到64*64的矩陣大小後就不再遞迴,那麼所花的時間將是分塊矩陣乘法的(7/8) * (7/8) * (7/8) * (7/8) = 0.586倍,提高了快接近一倍。當然這是理論上的值,因為實際上strassen乘法增加了其他運算開銷,實際效能會略低一點。

下面就是strassen矩陣乘法的實現方法,

m1 = (a0 + a3) × (b0 + b3)

m2 = (a2 + a3) × b0

m3 = a0 × (b1 - b3)

m4 = a3 × (b2 - b0)

m5 = (a0 + a1) × b3

m6 = (a2 - a0) × (b0 + b1)

m7 = (a1 - a3) × (b2 + b3)

c0 = m1 + m4 - m5 + m7

c1 = m3 + m5

c2 = m2 + m4

c3 = m1 - m2 + m3 + m6

在求解m1,m2,m3,m4,m5,m6,m7時需要使用7次矩陣乘法,其他都是矩陣加法和減法。

下面看看strassen矩陣乘法的序列實現偽**:

serial_strassenmultiply(a, b, c)

由上可見,strassen矩陣乘法是通過遞迴實現的,它將一般情況下二階矩陣乘法(可擴充套件到n階,但strassen矩陣乘法要求n是2的冪)所需的8次乘法降低為7次,其c++實現**如下:

#include using namespace std;

const int n = 6; //define the size of the matrix

templatevoid strassen(int n, t a[n], t b[n], t c[n]);

templatevoid input(int n, t p[n]);

templatevoid output(int n, t c[n]);

int main()

}}/**the output fanction of matrix*/

templatevoid output(int n, t c[n])

}

}}/**matrix addition*/

template void matrix_add(int n, t x[n], t y[n], t z[n]) else

} //calculate m1 = (a0 + a3) × (b0 + b3)

matrix_add(n/2, a11, a22, aa);

matrix_add(n/2, b11, b22, bb);

strassen(n/2, aa, bb, m1);

//calculate m2 = (a2 + a3) × b0

matrix_add(n/2, a21, a22, aa);

strassen(n/2, aa, b11, m2);

//calculate m3 = a0 × (b1 - b3)

matrix_sub(n/2, b12, b22, bb);

strassen(n/2, a11, bb, m3);

//calculate m4 = a3 × (b2 - b0)

matrix_sub(n/2, b21, b11, bb);

strassen(n/2, a22, bb, m4);

//calculate m5 = (a0 + a1) × b3

matrix_add(n/2, a11, a12, aa);

strassen(n/2, aa, b22, m5);

//calculate m6 = (a2 - a0) × (b0 + b1)

matrix_sub(n/2, a21, a11, aa);

matrix_add(n/2, b11, b12, bb);

strassen(n/2, aa, bb, m6);

//calculate m7 = (a1 - a3) × (b2 + b3)

matrix_sub(n/2, a12, a22, aa);

matrix_add(n/2, b21, b22, bb);

strassen(n/2, aa, bb, m7);

//calculate c0 = m1 + m4 - m5 + m7

matrix_add(n/2, m1, m4, aa);

matrix_sub(n/2, m7, m5, bb);

matrix_add(n/2, aa, bb, c11);

//calculate c1 = m3 + m5

matrix_add(n/2, m3, m5, c12);

//calculate c2 = m2 + m4

matrix_add(n/2, m2, m4, c21);

//calculate c3 = m1 - m2 + m3 + m6

matrix_sub(n/2, m1, m2, aa);

matrix_add(n/2, m3, m6, bb);

matrix_add(n/2, aa, bb, c22);

//set the result to c[n]

for(int i=0; i}}

}

Strassen矩陣相乘演算法

strassen的矩陣相乘方法是一種典型的分治演算法。目前為止,我們已經見過一些分治策略的演算法了,例如歸併排序和karatsuba大數快速乘法。現在,讓我再來看看分治策略的背後是什麼。同動態規劃不同,在動態規劃中,為了得到最終的解決方案,我們經常需要把乙個大的問題 展開 為幾個子問題,但是這裡,我...

strassen矩陣乘法 Strassen矩陣乘法

求矩陣a,b相乘的結果c 直接根據矩陣乘法的定義來遍歷計算。c 語言 void matrixmul int a,int b,int c,int m,int b,int n void test3 int b 3 2 int c 2 2 matrixmul int a,int b,int c,2,3,2...

矩陣乘法 之 strassen 演算法

一般情況下矩陣乘法需要三個for迴圈,時間複雜度為o n 3 現在我們將矩陣分塊如圖 來自mit演算法導論 一般演算法需要八次乘法 r a e b g s a f b h t c e d g u c f d h strassen將其變成7次乘法,因為大家都知道乘法比加減法消耗更多,所有時間複雜更高!...