numpy高維陣列獲取top K

2021-10-14 14:04:09 字數 2127 閱讀 4080

理論知識請自行翻閱numpy的argpartition和partition方法的實現原理,該文章僅僅包含使用和效率驗證。此外,numpy版本需要》=1.8.0。

不廢話了,直接放**,一看就懂,看不懂再說,自己跑一下就知道。

import numpy as np

defget_sorted_top_k

(array, top_k=

1, axis=-1

, reverse=

false):

""" 多維陣列排序

args:

array: 多維陣列

top_k: 取數

axis: 軸維度

reverse: 是否倒序

returns:

top_sorted_scores: 值

top_sorted_indexes: 位置

"""if reverse:

# argpartition分割槽排序,在給定軸上找到最小的值對應的idx,partition同理找對應的值

# kth表示在前的較小值的個數,帶來的問題是排序後的結果兩個分區間是仍然是無序的

# kth絕對值越小,分割槽排序效果越明顯

axis_length = array.shape[axis]

partition_index = np.take(np.argpartition(array, kth=

-top_k, axis=axis)

,range

(axis_length - top_k, axis_length)

, axis)

else

: partition_index = np.take(np.argpartition(array, kth=top_k, axis=axis)

,range(0

, top_k)

, axis)

top_scores = np.take_along_axis(array, partition_index, axis)

# 分割槽後重新排序

sorted_index = np.argsort(top_scores, axis=axis)

if reverse:

sorted_index = np.flip(sorted_index, axis=axis)

top_sorted_scores = np.take_along_axis(top_scores, sorted_index, axis)

top_sorted_indexes = np.take_along_axis(partition_index, sorted_index, axis)

return top_sorted_scores, top_sorted_indexes

if __name__ ==

"__main__"

:import time

from sklearn.metrics.pairwise import cosine_similarity

x = np.random.rand(10,

128)

y = np.random.rand(

1000000

,128

) z = cosine_similarity(x, y)

start_time = time.time(

) sorted_index_1 = get_sorted_top_k(z, top_k=

3, axis=

1, reverse=

true)[

1]print

(time.time(

)- start_time)

start_time = time.time(

) sorted_index_2 = np.flip(np.argsort(z, axis=1)

[:,-

3:], axis=1)

print

(time.time(

)- start_time)

print

((sorted_index_1 == sorted_index_2)

.all()

)

不吹比的說一句,這段**看著perfect好吧,效率提公升不少。

numpy 求TOP k 找出陣列都相同的數

numpy使用的方法 快速從一組資料中找到最大的值 掌握在矩陣中索引的使用方法 import numpy as np z np.arange 10000 b np.random.shuffle z n 5print z np.argpartition z,n n z np.random.randin...

C new delete 高維陣列小結

借鑑 高維陣列的動態申請和釋放與二維陣列的類似,所以這裡只演示的是二維陣列的動態申請和釋放。先來個大眾版的 1 include 2 3using namespace std 45 int main void 6 1819 for int i 0 i 3 i 20delete p i 2122 del...

numpy建立陣列元素的獲取

import numpy as np arr np.array np.arange 12 reshape 3,4 print arr print arr 0 獲取二維陣列的第一行 print arr 1 獲取二維陣列的第二行 print arr 3 獲取二維陣列的前三行 print arr 0,2 ...