Strassen s 矩陣乘法 分治法實現

2021-08-13 02:57:22 字數 3687 閱讀 9529

內容會持續更新,有錯誤的地方歡迎指正,謝謝!

題目:

1.比較數學定義的矩陣乘法演算法和strassen』s 矩陣乘法演算法的效率;

2.自主生成兩個16*16的矩陣,輸出strassen』s 矩陣乘法演算法結果。

數學定義的矩陣乘法演算法:利用三個for迴圈來解決,時間複雜度為o(n^3)。

數學定義的矩陣乘法演算法的核心**如下:

//公理:兩個矩陣相乘a*b,a的列數必等於b的行數。

int a[2][3] = ;

int b[3][1] = ;

for (int i = 0; i < 2; ++i)

}

一般演算法需要八次乘法:

試試strassen』s 矩陣乘法演算法:

我們可以推出:

上面只有7次乘法和多次加減法,strassen』s 矩陣乘法演算法將其變成7次乘法。大家都知道乘法比加減法消耗更多的效能!所以,該演算法能將時間複雜度降低到o( n^lg7 ) = o( n^2.81 )。

**實現如下:(其中n必須為2的冪,這裡n=16)

#include using namespace std;

#define n

16//矩陣相加

void plus(int a[n / 2][n / 2], int b[n / 2][n / 2], int c[n / 2][n / 2])

}}//矩陣相減

void minus(int a[n / 2][n / 2], int b[n / 2][n / 2], int c[n / 2][n / 2])

}}//矩陣相乘

void multiply(int a[n / 2][n / 2], int b[n / 2][n / 2], int c[n / 2][n / 2])}}

}int main()

}int i[n / 2][n / 2], j[n / 2][n / 2], k[n / 2][n / 2], l[n / 2][n / 2];

int a[n / 2][n / 2], b[n / 2][n / 2], c[n / 2][n / 2], d[n / 2][n / 2];

int e[n / 2][n / 2], f[n / 2][n / 2], g[n / 2][n / 2], h[n / 2][n / 2];

int s1[n / 2][n / 2], s2[n / 2][n / 2], s3[n / 2][n / 2], s4[n / 2][n / 2];

int s5[n / 2][n / 2], s6[n / 2][n / 2], s7[n / 2][n / 2];

int t1[n / 2][n / 2], t2[n / 2][n / 2];

//將原矩陣m1、m2拆分為a b c

de f g h矩陣

for (i = 0; i < n / 2; i++)

}//s1

minus(i, f, h);

multiply(s1, a, i);

//s2

plus(i, a, b);

multiply(s2, i, h);

//s3

plus(i, c, d);

multiply(s3, i, e);

//s4

minus(i, g, e);

multiply(s4, d, i);

//s5

plus(i, a, d);

plus(j, e, f);

multiply(s5, i, j);

//s6

minus(i, b, d);

plus(j, g, h);

multiply(s6, i, j);

//s7

minus(i, a, c);

plus(j, e, f);

multiply(s7, i, j);

//計算i j k l矩陣

//i = s5 + s4 - s2 + s6

plus(t1, s5, s4);

minus(t2, t1, s2);

plus(i, t2, s6);

//j = s1 + s2

plus(j, s1, s2);

//k = s3 + s4

plus(k, s3, s4);

//l = s5 + s1 - s3 - s7 = s5 + s1 - ( s3 + s7 )

plus(t1, s5, s1);

plus(t2, s3, s7);

minus(l, t1, t2);

//將得到的i j k l矩陣合併到最終結果result矩陣中

int result[n][n] = ;

for (int i = 0; i < n / 2; i++)

}//輸出最終的矩陣

for (i = 0; i < n; ++i)

}cout << endl;

getchar();

return 0;

}

備註:由於博主時間問題,本**並未實現遞迴,也就是並未利用分治法拆分到最小單元再計算再合併,只是闡述了分治法解決該問題的思路,若要實現完整版**,我指明方法:

新建乙個遞迴函式,需將main()裡的部分**移到遞迴函式裡,並需修改遞迴函式裡的所有二維陣列的定義,例如:

int matrixa[n / 2][n / 2];

int matrixb[n / 2][n / 2];

int matrixc[n / 2][n / 2];

應該被修改為如下形式:

//n為遞迴函式傳入的引數

int** matrixa = new

int*[n];

int** matrixb = new

int*[n];

int** matrixc = new

int*[n];

for (int i = 0; i < n; i++)

用完new的二維陣列之後還要記得釋放記憶體,不然,在遞迴中,很容易產生記憶體洩漏:

for (int i = 0; i < n; i++)

delete a;

delete b;

遞迴函式的引數有n,matrixa,matrixb,matrixc

n用於傳遞矩陣維數。

matrixa矩陣就是上方**的m1矩陣。該題是求m1乘以m2矩陣,你就知道m1是什麼了。

matrixb矩陣就是上方**的m2矩陣。該題是求m1乘以m2矩陣,你就知道m2是什麼了。

matrixc矩陣用於記錄結果,最後輸出matrixc即是最終結果。

分治法實現的完整**,能輸出最終結果和每一次遞迴的s1~s7:

分治法 矩陣乘法

問題 給定兩個n階方陣相乘,對求解演算法進行優化。首先,根據傳統演算法,兩個n階矩陣相乘,對於n2個元素,每個元素想要被計算出來,至少要進行n次乘法和n 1次加法,演算法複雜度達到o n3 考慮將矩陣分塊,分為4個n 2的方陣,那麼每個兩個小方陣相乘的複雜度為o n3 8,要想得出最終結果,一共需要...

矩陣乘法 矩陣乘法的基本實現

求解關於兩個矩陣的乘積 參考線性代數裡面的兩個矩陣相乘的規則,我這裡不再贅述,詳情附上了乙個鏈結,我的程式設計也是用了裡面的例子 這裡寫鏈結內容 具體的過程我會在 片裡面加上注釋 矩陣乘法 author seen 2015 09 18 include using namespace std int ...

mysql 矩陣乘法 矩陣乘法高階操作

對於矩陣乘法的一些操作 我們 其實 大部分是 多追加乙個係數 或者和 其他演算法連在一起。至於核心無非就是 先列出dp 方程再優化 或者 直接 對題目進行建模 構建矩陣。至於矩陣乘法的正確性 形狀的正確性 是可以證明的 但是內部最真實的正確性我還無法證明。這道題是 字串型別的題目 求方案數 很煩 大...