PyTorch 使用中注意事項

2022-03-18 17:25:19 字數 1991 閱讀 8998

1. 把label要轉成longtensor格式

self.y = torch.longtensor(y)
完整使用**如下:

1

class

imgdataset(dataset):

2def

__init__(self, x, y=none, transform=none):

3 self.x =x4#

label is required to be a longtensor

5 self.y =y

6if y is

notnone:

7 self.y =torch.longtensor(y)

8 self.transform =transform

9def

__len__

(self):

10return

len(self.x)

11def

__getitem__

(self, index):

12 x =self.x[index]

13if self.transform is

notnone:

14 x =self.transform(x)

15if self.y is

notnone:

16 y =self.y[index]

17return

x, y

18else:19

return x

view code

需要保證target型別為torch.cuda.longtensor,需要在資料讀取的迭代其中把target的型別轉換為int64位的:target = target.astype(np.int64),這樣,輸出的target型別為torch.cuda.longtensor。(或者在使用前使用tensor.type(torch.longtensor)進行轉換)。

*longtensor其實就是int64,有符號整型

2. 做**時,沒有y值,從dataloader中傳入給model的直接是data,而不再是data[0]了

model_best.eval()

prediction =

with torch.no_grad():

for i, data in

enumerate(test_loader):

#print(data[0].size())

#特別要注意的是,這裡直接傳入data,因為已經沒有y值了,所以無需data[0]。

#如果傳了data[0]反而導致沒有傳入整個batch,計算錯誤

test_pred =model_best(data.cuda())

test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)

for y in

test_label:

3. 訓練時,要設成model.train(),這樣optimizer就可以更新model的引數

驗證時,要設成model.val(),以此來固定model的引數。例如 去掉dropout、bn引數不變 等。

4. modulenotfounderror:no module named 」classifier「

訓練完儲存模型後,再load模型去做**時,仍然需要原來訓練時的classifier,即整個網路結構。。

有點匪夷所思呀。。那存模型和存引數比有啥區別呢?存了個寂寞?

還需要查一下存模型和存引數的區別

5.pytorch中nn.crossentropyloss 自帶softmax,無需將輸出經softmax層再計算交叉熵損失

其原始碼實現時,將 input 經過 softmax 啟用函式之後,再計算其與 target 的交叉熵損失

未完待續。。。

JS中注意事項

一 判斷中注意事項 一 所有的相對路徑都別拿來做判斷 1.img src 2.href 1.css href html index.html 3.img src 二 顏色值不能拿來做判斷 color red f00 rgb 250,0,0 三 innerhtml 值不能拿來做判斷 解決 設定開關變數...

php foreach中 注意事項

以前用foreach,總喜歡在第二次遍歷時改變value的拼寫,比如 x array a b c foreach x as value echo foreach x as value2 得到結果12 a b c a b c 並沒有什麼不妥.今天寫的時候沒有留神,發現出錯了,示例如下 x array ...

python中注意事項(更新)

1.print a,b print列印多個變數,預設中間會有個空格作為分隔符,預設以換行符結尾 docstring print value,sep end n file sys.stdout,flush false prints the values to a stream,or to sys.st...