n階方陣乘法straseen

2022-03-24 12:02:52 字數 2569 閱讀 2768

原理:分塊矩陣乘法,進行8次矩陣乘法,時間複雜度為 $\theta(n^3) = \theta(n^}) $ , 改進後僅需要7次乘法, 時間複雜度為 $\theta(n^})$

具體推到見演算法導論中利用主定理推導時間複雜度

def matrix_divide(a):

rows = len(a)

mid = rows // 2

a11 = [[0]*mid for _ in range(mid)]

a12 = [[0]*mid for _ in range(mid)]

a21 = [[0]*mid for _ in range(mid)]

a22 = [[0]*mid for _ in range(mid)]

for i in range(mid):

for j in range(mid):

a11[i][j] = a[i][j]

a12[i][j] = a[i][mid+j]

a21[i][j] = a[mid+i][j]

a22[i][j] = a[mid+i][mid+j]

return a11, a12, a21, a22

def matrix_add(a, b):

rows = len(a)

c = [[0]*rows for _ in range(rows)]

for i in range(rows):

for j in range(rows):

c[i][j] = a[i][j] + b[i][j]

return c

def matrix_sub(a, b):

rows = len(a)

c = [[0]*rows for _ in range(rows)]

for i in range(rows):

for j in range(rows):

c[i][j] = a[i][j] - b[i][j]

return c

def matrix_merge(c11, c12, c21, c22):

rows = len(c11)

n = rows * 2

c = [[0]*n for _ in range(n)]

for i in range(rows):

for j in range(rows):

c[i][j] = c11[i][j]

c[i][rows+j] = c12[i][j]

c[rows+i][j] = c21[i][j]

c[rows+i][rows+j] = c22[i][j]

return c

def strassen(a, b):

n = len(a)

c = [[0] for _ in range(n)]

if n == 1:

c[0][0] = a[0][0]*b[0][0]

return c

a11, a12, a21, a22 = matrix_divide(a)

b11, b12, b21, b22 = matrix_divide(b)

s1 = matrix_sub(b12, b22)

s2 = matrix_add(a11, a12)

s3 = matrix_add(a21, a22)

s4 = matrix_sub(b21, b11)

s5 = matrix_add(a11, a22)

s6 = matrix_add(b11, b22)

s7 = matrix_sub(a12, a22)

s8 = matrix_add(b21, b22)

s9 = matrix_sub(a11, a21)

s10 = matrix_add(b11, b12)

p1 = strassen(a11, s1)

p2 = strassen(s2, b22)

p3 = strassen(s3, b11)

p4 = strassen(a22, s4)

p5 = strassen(s5, s6)

p6 = strassen(s7, s8)

p7 = strassen(s9, s10)

c11 = matrix_add(p5, matrix_sub(p4, matrix_sub(p2, p6)))

c12 = matrix_add(p1, p2)

c21 = matrix_add(p3, p4)

c22 = matrix_add(p5, matrix_sub(p1, matrix_add(p3, p7)))

return matrix_merge(c11, c12, c21, c22)

def main():

a = [[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]]

b = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]]

c = strassen(a, b)

print(c)

if __name__ == '__main__':

main()

N階魔方陣

寫出程式填寫出n n 魔方陣 的數值。所謂魔方陣是指這樣的方陣,資料是正整數,從1開始,每個遞增1,每個資料不重複出現,它的每一行 每一列和對角線之和均相等 n是奇數 input 3 5 output 8 1 6 3 5 7 4 9 2 17 24 01 08 15 23 05 07 14 16 0...

n階魔方陣

魔方陣 計算規律 1.將1放在第一行中間一列 2.從2開始到nn按如下規律 每乙個數存放的行數比上乙個數的行數減1 每乙個數存放的列數比上乙個數的列數加1 3.當乙個數的行數為1,他的下乙個數行數為n 4.當乙個數的列數為n,他的下乙個數的列數為1,行數減1 5.若按上述規則確定的位置有數字或上乙個...

n階魔方陣

奇數階魔方陣就是指行列數都是吧n n 3 且 n 2 1 的魔方陣 奇數階魔方陣的數字規律 通過對奇數階魔方陣的分析,其中的數字排列有如下的規律 1 自然數1出現在第一行的正中間 2 若填入的數字在第一行 不在第n列 則下乙個數字在第n行 最後一行 且列數加1 列數右移一列 3 若填入的數字在該行的...