Strassen矩陣乘法

2021-05-22 12:06:51 字數 3008 閱讀 6868

strassen矩陣乘法

strassen矩陣乘法是通過遞迴實現的,它將一般情況下二階矩陣乘法(可擴充套件到n階,但strassen矩陣乘法要求n是2的冪)所需的8次乘法降低為7次,將計算時間從o(ne3)降低為o(ne2.81)。

矩陣c = ab,可寫為

c11 = a11b11 + a12b21

c12 = a11b12 + a12b22

c21 = a21b11 + a22b21

c22 = a21b12 + a22b22

如果a、b、c都是二階矩陣,則共需要8次乘法和4次加法。如果階大於2,可以將矩陣分塊進行計算。耗費的時間是o(ne3)。

要改進演算法計算時間的複雜度,必須減少乘法運算次數。按分治法的思想,strassen提出一種新的方法,用7次乘法完成2階矩陣的乘法,演算法如下:

m1 = a11(b12 - b12)

m2 = (a11 + a12)b22

m3 = (a21 + a22)b11

m4 = a22(b21 - b11)

m5 = (a11 + a22)(b11 + b22)

m6 = (a12 - a22)(b21 + b22)

m7 = (a11 - a21)(b11 + b12)

完成了7次乘法,再做如下加法:

c11 = m5 + m4 - m2 + m6

c12 = m1 + m2

c21 = m3 + m4

c22 = m5 + m1 - m3 - m7

全部計算使用了7次乘法和18次加減法,計算時間降低到o(ne2.81)。計算複雜性得到較大改進。

附strassen矩陣乘法**:

#include

const int n = 8; //常量n用來定義矩陣的大小

template

void strassen(int n, t a[n], t b[n], t c[n]);

template

void input(int n, t p[n]);

template

void output(int n, t c[n]); //函式宣告部分

void main()

} strassen(n,a,b,c); //呼叫strassen函式計算

output(n,c); //輸出計算結果

} template

void input(int n, t p[n]) //矩陣輸入函式}}

template

void output(int n, t c[n]) //矩陣輸出函式

} template

void matrix_multiply(t a[n], t b[n], t c[n]) //按通常的矩陣乘法計算c=ab的子演算法(僅做2階)}}

} template

void matrix_add(int n, t x[n], t y[n], t z[n]) //矩陣加法函式x+y—>z}}

template

void matrix_sub(int n, t x[n], t y[n], t z[n]) //矩陣減法函式x-y—>z}}

// fullfill c = a * b

template

void strassen(int n, t a[n], t b[n], t c[n]) //strassen函式(遞迴)

else

}matrix_sub (n / 2, b12, b22, bb);

strassen (n/ 2, a11, bb, m1); //m1=a11(b12-b22)

matrix_add (n / 2, a11, a12, aa);

strassen (n / 2, aa, b22, m2); //m2=(a11+a12)b22

matrix_add (n / 2, a21, a22, aa);

strassen (n / 2, aa, b11, m3); //m3=(a21+a22)b11

matrix_sub (n / 2, b21, b11, bb);

strassen (n / 2, a22, bb, m4); //m4=a22(b21-b11)

matrix_add (n / 2, a11, a22, aa);

matrix_add (n / 2, b11, b22, bb);

strassen (n / 2, aa, bb, m5); //m5=(a11+a22)(b11+b22)

matrix_sub (n / 2, a12, a22, aa);

matrix_sub (n / 2, b21, b22, bb);

strassen (n / 2, aa, bb, m6); //m6=(a12-a22)(b21+b22)

matrix_sub (n / 2, a11, a21, aa);

matrix_sub (n / 2, b11, b12, bb);

strassen (n / 2, aa, bb, m7); //m7=(a11-a21)(b11+b12)

//計算m1,m2,m3,m4,m5,m6,m7(遞迴部分)

matrix_add (n / 2, m5, m4, mm1);

matrix_sub (n / 2, m6, m2, mm2);

matrix_add (n / 2, mm1, mm2, c11); //c11=m5+m4-m2+m6

matrix_add (n / 2, m1, m2, c12); //c12=m1+m2

matrix_add (n / 2, m3, m4, c21); //c21=m3+m4

matrix_add (n / 2, m5, m1, mm1);

matrix_add (n / 2, m3, m7, mm2);

matrix_sub (n / 2, mm1, mm2, c22); //c22=m5+m1-m3-m7

for (i = 0; i < n / 2; i++)

//計算結果送回c[n][n]}}

}

strassen矩陣乘法 Strassen矩陣乘法

求矩陣a,b相乘的結果c 直接根據矩陣乘法的定義來遍歷計算。c 語言 void matrixmul int a,int b,int c,int m,int b,int n void test3 int b 3 2 int c 2 2 matrixmul int a,int b,int c,2,3,2...

Strassen矩陣乘法

矩陣乘法是線性代數中最常見的運算之一,它在數值計算中有廣泛的應用。若a和b是2個n n的矩陣,則它們的乘積c ab同樣是乙個n n的矩陣。a和b的乘積矩陣c中的元素c i,j 定義為 若依此定義來計算a和b的乘積矩陣c,則每計算c的乙個元素c i,j 需要做n個乘法和n 1次加法。因此,求出矩陣c的...

strassen矩陣乘法

出處 矩陣乘法是線性代數中最常見的運算之一,它在數值計算中有廣泛的應用。若a和b是2個n n的矩陣,則它們的乘積c ab同樣是乙個n n的矩陣。a和b的乘積矩陣c中的元素c i,j 定義為 若依此定義來計算a和b的乘積矩陣c,則每計算c的乙個元素c i,j 需要做n個乘法和n 1次加法。因此,求出矩...