caffe 分類原始碼解讀

2021-08-11 04:17:32 字數 3646 閱讀 9988

首先, 新建乙個classifier的c++類,其中標頭檔案classifier.h如下:

其中,classifier函式:根據模型的配置檔案.prototxt,訓練好的模型檔案.caffemodel,建立模型,得到net_;處理均值檔案,得到mean_;讀入labels檔案,得到labels_。classify函式:呼叫predict函式對影象img進行分類,返回std::pair< std::string, float >形式的**結果。私有函式:僅供classifier函式和classify函式使用,包括

#include 

#ifdef use_opencv

#include

#include

#include

#endif // use_opencv

#include

#include

#include

#include

#include

#include

#ifdef use_opencv

#ifdef use_opencv

using

namespace caffe; // nolint(build/namespaces)

using

std::string;

//為std::pair建立乙個名為「prediction」的型別別名

typedef

std::pair prediction;

class classifier ;

c++檔案為:

#include "stdafx.h"

#include "classifier.h"

//在classifier類外定義classifier類的建構函式

classifier::classifier(const

string& model_file,

const

string& trained_file,

const

string& mean_file,

const

string& label_file)

// partial_sort 排序用到的自定義比較函式 => 前者比後者大就返回true

static

bool paircompare(const

std::pair& lhs,

const

std::pair& rhs)

// 函式用於返回向量v的前n個最大值的索引,也就是返回概率最大的五個類別的標籤

// 如果你是二分類問題,那麼這個n直接選擇1 (n要小於等於類別數)

static

std::vector

argmax(const

std::vector

& v, int n)

// classifier類的classify函式的定義,裡面呼叫了classifier類的私有函式predict函式和上面實現的argmax函式

// **函式,輸入一張img,希望**的前n種概率最大的,我們一般取n等於1

// 輸入**結果為std::make_pair,每個對包含這個物體的名字,及其相對於的概率

std::vector

classifier::classify(const cv::mat& img, int n)

return predictions;

}// 載入均值檔案函式的定義

void classifier::setmean(const

string& mean_file)

// 重新合成一張

cv::mat mean;

cv::merge(channels, mean);

// 計算每個通道的均值,得到乙個三維的向量channel_mean,然後把三維的向量擴充套件成一張新的均值

// 這種的每個通道的畫素值是相等的,這張均值的大小將和網路的輸入要求一樣

// 注意: 這裡的去均值,是指對需要處理的影象減去均值影象的平均亮度

cv::scalar channel_mean = cv::mean(mean);

mean_ = cv::mat(input_geometry_, mean.type(), channel_mean);

}//classifier類中predict函式的定義,輸入形參為單張影象

std::vector

classifier::predict(const cv::mat& img) );

return ip2_out;

#endif

}// 這個其實是為了獲得net_網路的輸入層資料的指標,然後後面我們直接把輸入資料拷貝到這個指標裡面

void classifier::wrapinputlayer(std::vector

* input_channels)

}// 預處理函式,包括縮放、歸一化、3通道分開儲存

// 對於三通道輸入cnn,經過該函式返回的是std::vector因為是三通道資料,所以用了vector

void classifier::preprocess(const cv::mat& img, std::vector

* input_channels)

classifier::~classifier()

呼叫:

int main(int argc, char** argv) 

::google::initgooglelogging(argv[0]); // 可以不需要日誌

string model_file = argv[1];

string trained_file = argv[2];

string mean_file = argv[3];

string label_file = argv[4];

// 建立物件並初始化網路、模型、均值、標籤各類物件

classifier classifier(model_file, trained_file, mean_file, label_file);

string file = argv[5];//輸入的待測

// 列印資訊

std::cout

<< "---------- prediction for "

<< file << " ----------"

<< std::endl;

cv::mat img = cv::imread(file, -1);

check(!img.empty()) << "unable to decode image "

<< file;

std::vector

predictions = classifier.classify(img);

// 將測試結果列印 std::pair型別的p變數,p.second代表概率值,p.first代表類別標籤

for (size_t i = 0; i < predictions.size(); ++i)

}#else

int main(int argc, char** argv)

#endif // use_opencv

參考博文:caffe 中classification.cpp的原始碼詳解、改寫

Caffe原始碼解讀 syncedmem類

記憶體同步 syncedmem 類的作用在於管理主機 cpu 和裝置 gpu 之間的記憶體分配和資料同步,封裝了二者之間的互動操作。這個類沒有對應的protobuffer描述,所以直接看.include caffe syncedmem.cpp檔案 ifndef caffe syncedmem hpp...

Caffe原始碼 math functions 解析

math function 定義了caffe 中用到的一些矩陣操作和數值計算的一些函式,這裡以float型別為例做簡單的分析 template void caffe cpu gemm const cblas transpose transa,const cblas transpose transb,...

caffe原始碼解析

目錄目錄 簡單介紹 主要函式readprotofromtextfile 函式 writeprotototextfile 函式 readprotofrombinaryfile 函式 writeprototobinaryfile 函式 readimagetocvmat 函式 matchext 函式 cv...