演算法導論(三)

2021-07-22 16:59:10 字數 3566 閱讀 8627

對於矩陣乘法

c = a×b

,通常的做法是將矩陣進行分塊相乘,如下圖所示:

從上圖可以看出這種分塊相乘總共用了

8次乘法,當然對於子矩陣相乘(如a0×

b0),還可以繼續遞迴使用分塊相乘。對於中小矩陣來說,很適合使用這種分塊乘法,但是對於大矩陣來說,遞迴的次數較多,如果能減少每次分塊乘法的次數,那麼效能將可以得到很好的提高。

strassen

矩陣乘法就是採用了乙個簡單的運算技巧,將上面的

8次矩陣相乘變成了

7次乘法,看別小看這減少的

1次乘法,因為每遞迴

1次,效能就提高了

1/8,比如對於

1024*1024

的矩陣,第

1次先分解成7次

512*512

的矩陣相乘,對於

512*512

的矩陣,又可以繼續遞迴分解成

256*256

的矩陣相乘,

…,一直遞迴下去,假設分解到

64*64

的矩陣大小後就不再遞迴,那麼所花的時間將是分塊矩陣乘法的

(7/8) * (7/8) * (7/8) * (7/8) = 0.586

倍,提高了快接近一倍。當然這是理論上的值,因為實際上

strassen

乘法增加了其他運算開銷,實際效能會略低一點。

由上可見,strassen矩陣乘法是通過遞迴實現的,它將一般情況下二階矩陣乘法(可擴充套件到n階,但strassen矩陣乘法要求n是2的冪)所需的8次乘法降低為7次,其c++實現**如下:

下面就是

strassen

矩陣乘法的實現方法,

m1 = (a0 + a3) × (b0 + b3)

m2 = (a2 + a3) × b0

m3 = a0 × (b1 - b3)

m4 = a3 × (b2 - b0)

m5 = (a0 + a1) × b3

m6 = (a2 - a0) × (b0 + b1)

m7 = (a1 - a3) × (b2 + b3)

c0 = m1 + m4 - m5 + m7

c1 = m3 + m5

c2 = m2 + m4

c3 = m1 - m2 + m3 + m6

在求解m1,m2,m3,m4,m5,m6,m7

時需要使用

7次矩陣乘法,其他都是矩陣加法和減法。

下面看看

strassen

矩陣乘法的序列實現偽**:

serial_strassenmultiply(a, b, c)

#include 

using

namespace std;

const

int n = 6; //

define the size of the matrix

template

void strassen(int n, t a[n], t b[n], t c[n]);

template

void input(int n, t p[n]);

template

void output(int n, t c[n]);

int main()

template

void input(int n, t p[n])

}

}

template

void output(int n, t c[n])

}

}

}

template

void matrix_add(int n, t x[n], t y[n], t z[n]) else

}

//calculate m1 = (a0 + a3) × (b0 + b3)

matrix_add(n/2, a11, a22, aa);

matrix_add(n/2, b11, b22, bb);

strassen(n/2, aa, bb, m1);

//calculate m2 = (a2 + a3) × b0

matrix_add(n/2, a21, a22, aa);

strassen(n/2, aa, b11, m2);

//calculate m3 = a0 × (b1 - b3)

matrix_sub(n/2, b12, b22, bb);

strassen(n/2, a11, bb, m3);

//calculate m4 = a3 × (b2 - b0)

matrix_sub(n/2, b21, b11, bb);

strassen(n/2, a22, bb, m4);

//calculate m5 = (a0 + a1) × b3

matrix_add(n/2, a11, a12, aa);

strassen(n/2, aa, b22, m5);

//calculate m6 = (a2 - a0) × (b0 + b1)

matrix_sub(n/2, a21, a11, aa);

matrix_add(n/2, b11, b12, bb);

strassen(n/2, aa, bb, m6);

//calculate m7 = (a1 - a3) × (b2 + b3)

matrix_sub(n/2, a12, a22, aa);

matrix_add(n/2, b21, b22, bb);

strassen(n/2, aa, bb, m7);

//calculate c0 = m1 + m4 - m5 + m7

matrix_add(n/2, m1, m4, aa);

matrix_sub(n/2, m7, m5, bb);

matrix_add(n/2, aa, bb, c11);

//calculate c1 = m3 + m5

matrix_add(n/2, m3, m5, c12);

//calculate c2 = m2 + m4

matrix_add(n/2, m2, m4, c21);

//calculate c3 = m1 - m2 + m3 + m6

matrix_sub(n/2, m1, m2, aa);

matrix_add(n/2, m3, m6, bb);

matrix_add(n/2, aa, bb, c22);

//set the result to c[n]

for(int i=0; i2; i++)

}

}

}

演算法導論之演算法基礎(三)

插入排序 模擬 如果你會玩鬥地主,那麼摸牌後按從小到大插入,你這樣插入的過程就是插入排序 程式 在程式中的玩法就像有乙個人發牌,發齊了再拿牌,也就是一開始你就有17張牌,這17張牌對應17個元素的陣列。你從第二種牌開始進行調動,如果第二張牌比第一張牌小,那麼就把第二張牌抽出來,然後把第一張牌放入到第...

重溫演算法導論(三) 氣泡排序

氣泡排序原理簡單,從最後的元素與前面的元素比較,小於則交換,最後最小的在最左邊 偽 實現如下 for i 1 to length a 實際陣列的下標從0開始 do for j length a downto i 1 do if a j a j 1 then exchange a j a j 1 實際...

演算法導論 隨機演算法

一.概率分布 對於有些問題本身是屬於概率問題,如僱傭問題 對於此類問題,我們需要利用概率分析來得到演算法的執行時間,有時也用來分析其他的量。例如,僱傭問題中的費用問題也需要結合概率分析來計算得到。為了使用概率分析,我們必須使用或者假設已知關於輸入的概率分布,然後通過分析該演算法計算出平均情況下的執行...