spark mllib原始碼分析之OWLQN

2021-08-07 14:10:25 字數 2766 閱讀 4904

spark mllib原始碼分析之l-bfgs(一)

spark mllib原始碼分析之l-bfgs(二)

線搜尋spark mllib原始碼分析之邏輯回歸彈性網路elasticnet(一)

spark在elastic net中當使用l1正則化時,optimizer使用owlqn,我們結合理論介紹其在spark中的實現。

結合**《scalable training of l1-regularized log-linear models》中的內容,簡單介紹owlqn演算法,以與spark(breeze)中的實現相對照,首先定義目標函式 f(

x)=ℓ

(x)+

c∥x∥

1 其中c是大於0的常數。 定義π

函式,對於向量x⃗ 

與y⃗  ,有 π(

x;y)

=過載了iterations函式,引數f實際是f(

x+λd

) ,是個偏函式,x與d都已經給定,只需要計算不同的

λ 下的值,在owlqn演算法中,這個函式返回的是(loss, gradient*d),這裡的

α 對應演算法中的λ

def iterations(f: difffunction[double], init: double = 1.0): iterator[state] =  else

if (enforcewolfeconditions && (fderiv < cwolfe * df0)) else

if (enforcestrongwolfeconditions && (fderiv > -cwolfe * df0)) else

if(multiplier == 1.0) else else

if(newalpha < minalpha) else

if (newalpha > maxalpha)

//計算新alpha對應的(loss, g*d)

val (fvalnew, fderivnew) = f.calculate(newalpha)

//構造新state

(state(newalpha, fvalnew, fderivnew), false, iter+1)

} //演算法結束條件,multiplier==1.0或者迭代到了最大次數,返回state

}.takewhile(triple => !triple._2 && (triple._3 < maxiterations)).map(_._1)

}

這裡指的是breeze中實現的firstorderminimizer抽象類,結合之前的文章,我們給出其通用的優化步驟

owlqn繼承自l-bfgs,在**中作者甚至說在l-bfgs的基礎上實現owlqn是需要改動大概30行**。其涉及到的改動主要對應上面的演算法。

override

protected

def choosedescentdirection(state: state, fn: difffunction[t]) = )

correcteddir

}

結合演算法1.2節可以知道,correcteddir保證搜尋方向需要與負梯度方向滿足

π 函式,**中是判斷的是與梯度的關係,因此是反的。

override

protected

def determinestepsize(state: state, f: difffunction[t], dir: t) =

} //使用backtracking line search線搜尋計算最優步長,其初始值是前一輪的loss(state.value),f函式根據本輪計算的dir,計算當前的loss和g*d

val search = new backtrackinglinesearch(state.value, shrinkstep= if(iter < 1) 0.1

else

0.5)

val alpha = search.minimize(ff, if(iter < 1) .5/norm(state.grad) else

1.0)

alpha

}

結合演算法1.3.2可知,函式ff中返回值那一步的direction應該是xk

+1−x

k ,但是這裡卻是直接使用了全域性的搜尋方向dir(來自於第一步choosedescentdirection),原始碼注釋提到不確定這樣做對不對,但是在實際中工作的挺好。

主要是調整loss和梯度,loss的調整主要是加上正則化的部分,梯度在這裡則是使用偽梯度(pseudo gradient)

override protected def adjust(newx: t, newgrad: t, newval: double): (double, t) =  else  

case _ => v + math.signum(xv) * l1regvalue

} }

})adjvalue -> res

}

計算ξ

private def computeorthant(x: t, grad: t) = )

orth

}

計算xk

+1,注意對新值的約束

override

protected

def takestep(state: state, dir: t, stepsize: double) = )

}

Spark MLlib原始碼分析 TFIDF原始碼詳解

以下 是我依據sparkmllib 版本1.6 1 hashingtf 是使用雜湊表來儲存分詞,並計算分詞頻數 tf 生成hashmap表。在map中,k為分詞對應索引號,v為分詞的頻數。在宣告hashingtf 時,需要設定numfeatures,該屬性實為設定雜湊表的大小 如果設定numfeat...

spring原始碼分析 spring原始碼分析

1.spring 執行原理 spring 啟動時讀取應用程式提供的 bean 配置資訊,並在 spring 容器中生成乙份相應的 bean 配置登錄檔,然後根據這張登錄檔例項化 bean,裝配好 bean 之間的依賴關係,為上 層應用提供準備就緒的執行環境。二 spring 原始碼分析 1.1spr...

思科VPP原始碼分析(dpo機制原始碼分析)

vpp的dpo機制跟路由緊密結合在一起。路由表查詢 ip4 lookup 的最後結果是乙個load balance t結構。該結構可以看做是乙個hash表,裡面包含了很多dpo,指向為下一步處理動作。每個dpo都是新增路由時的乙個path的結果。dpo標準型別有 dpo drop,dpo ip nu...