sklearn DecisionTree Source Code Analysis

2021-10-06 15:52:32

sklearn.tree._classes.BaseDecisionTree#fit

y must be at least 1-dimensional, which means multilabel / multi-output data can be handled:

y = np.atleast_1d(y)

if is_classifier(self):
    self.tree_ = Tree(self.n_features_,
                      self.n_classes_, self.n_outputs_)
else:
    self.tree_ = Tree(self.n_features_,
                      # TODO: tree shouldn't need this in this case
                      np.array([1] * self.n_outputs_, dtype=np.intp),
                      self.n_outputs_)

self.n_outputs_ = y.shape[1]
self.n_classes_ = self.n_classes_[0]

Earlier in fit, the labels of each output column are encoded with np.unique:

self.classes_ = []
self.n_classes_ = []
for k in range(self.n_outputs_):
    classes_k, y_encoded[:, k] = np.unique(y[:, k],
                                           return_inverse=True)
    self.classes_.append(classes_k)
    self.n_classes_.append(classes_k.shape[0])
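Since fit treats y as a 2-D array of shape (n_samples, n_outputs), multi-output targets work out of the box. A small illustrative fit (the toy data here is made up, not from the original post):

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.array([[0.0], [1.0], [2.0], [3.0]])
# Two output columns -> n_outputs_ == 2
y = np.array([[0, 1],
              [0, 1],
              [1, 0],
              [1, 0]])

clf = DecisionTreeClassifier(random_state=0).fit(X, y)
print(clf.n_outputs_)   # 2
print(clf.n_classes_)   # [2 2]  -> two classes per output
print(clf.classes_)     # [array([0, 1]), array([0, 1])]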

For reference, this is what np.unique with return_inverse=True does:

np.unique([3, 2, 2, 3, 3, 4], return_inverse=True)
Out[4]: (array([2, 3, 4]), array([1, 0, 0, 1, 1, 2]))

return_inverse behaves much like a LabelEncoder: besides the unique values, it returns the integer-encoded labels.
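To make the analogy concrete, here is a side-by-side comparison (an added illustration, not part of the original post); LabelEncoder produces exactly the same integer codes as the return_inverse output:

import numpy as np
from sklearn.preprocessing import LabelEncoder

y = [3, 2, 2, 3, 3, 4]

# np.unique with return_inverse=True: unique values plus integer codes
classes_k, y_encoded = np.unique(y, return_inverse=True)
print(classes_k)   # [2 3 4]
print(y_encoded)   # [1 0 0 1 1 2]

# LabelEncoder yields the same encoding
le = LabelEncoder()
print(le.fit_transform(y))  # [1 0 0 1 1 2]
print(le.classes_)          # [2 3 4]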

sklearn.tree._tree.Tree

def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes,
              int n_outputs):

n_features is the number of features, n_classes is the number of classes for each output, and n_outputs is the label dimensionality (the number of outputs).
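To see what these fields end up holding, you can inspect the low-level Tree object of a fitted estimator. The snippet below is an added illustration on the iris dataset (attribute names as exposed by current scikit-learn versions):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(random_state=0).fit(X, y)

t = clf.tree_                     # instance of sklearn.tree._tree.Tree
print(t.n_features)               # 4   (number of features)
print(t.n_outputs)                # 1   (label dimensionality)
print(t.n_classes)                # [3] (classes per output)
print(t.node_count, t.max_depth)  # size and depth of the grown tree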

Back in BaseDecisionTree.fit, the tree builder is chosen according to max_leaf_nodes:

# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
    builder = DepthFirstTreeBuilder(splitter, min_samples_split,
                                    min_samples_leaf,
                                    min_weight_leaf,
                                    max_depth,
                                    self.min_impurity_decrease,
                                    min_impurity_split)
else:
    builder = BestFirstTreeBuilder(splitter, min_samples_split,
                                   min_samples_leaf,
                                   min_weight_leaf,
                                   max_depth,
                                   max_leaf_nodes,
                                   self.min_impurity_decrease,
                                   min_impurity_split)
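So max_leaf_nodes is None by default and encoded internally as -1, which selects DepthFirstTreeBuilder; any positive value selects BestFirstTreeBuilder, which expands the most promising nodes first and stops at the given number of leaves. A quick check of the observable effect (an added illustration, not from the original post):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# max_leaf_nodes=None (default) -> depth-first growth, no cap on leaves
full = DecisionTreeClassifier(random_state=0).fit(X, y)

# max_leaf_nodes set -> best-first growth, at most that many leaves
capped = DecisionTreeClassifier(max_leaf_nodes=4, random_state=0).fit(X, y)

print(full.get_n_leaves())    # typically more than 4 on iris
print(capped.get_n_leaves())  # <= 4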

Introduction to the scikit-learn decision tree algorithm library

Maximum number of leaf nodes (max_leaf_nodes)

Limiting the maximum number of leaf nodes helps prevent overfitting. The default is None, i.e. the number of leaf nodes is unrestricted. If a limit is set, the algorithm builds the best decision tree it can within that number of leaves. With few features this value can usually be ignored, but when there are many features it is worth restricting it; a concrete value can be found by cross-validation.
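As a sketch of how such a value could be tuned by cross-validation (an added example, not from the original post; the grid values are arbitrary):

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Search over a few candidate caps with 5-fold cross-validation
param_grid = {"max_leaf_nodes": [None, 4, 8, 16, 32]}
search = GridSearchCV(DecisionTreeClassifier(random_state=0),
                      param_grid, cv=5)
search.fit(X, y)

print(search.best_params_)
print(search.best_score_)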

sklearn.tree._tree.DepthFirstTreeBuilder#build

builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)

cpdef build(self, Tree tree, object X, np.ndarray y,
            np.ndarray sample_weight=None,
            np.ndarray X_idx_sorted=None):

Notice that every parameter you would expect is here, but where has class_weight gone? The suspicion is that it has been converted into sample_weight, and fit indeed does exactly that:

if self.class_weight is not None:
    expanded_class_weight = compute_sample_weight(
        self.class_weight, y_original)

if expanded_class_weight is not None:
    if sample_weight is not None:
        sample_weight = sample_weight * expanded_class_weight
    else:
        sample_weight = expanded_class_weight
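For reference, here is how compute_sample_weight expands class_weight into per-sample weights (an added illustration, not from the original post):

import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

y = np.array([0, 0, 0, 1])

# "balanced" weights each class by n_samples / (n_classes * class_count)
print(compute_sample_weight("balanced", y))
# [0.667 0.667 0.667 2.0] (approximately)

# An explicit class_weight dict is mapped straight onto the samples
print(compute_sample_weight({0: 1.0, 1: 3.0}, y))
# [1. 1. 1. 3.]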

sklearn/tree/_tree.pyx:203

splitter.init(X, y, sample_weight_ptr, X_idx_sorted)

cdef SIZE_t n_node_samples = splitter.n_samples

rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0)

This pushes the root node onto the stack; before any split it contains all the samples. (rc is just the return code of the push; the seven arguments are start, end, depth, parent, is_left, impurity and n_constant_features.)

Stack and StackRecord are both data structures written by sklearn itself.

is_leaf = (depth >= max_depth or
           n_node_samples < min_samples_split or
           n_node_samples < 2 * min_samples_leaf or
           weighted_n_node_samples < 2 * min_weight_leaf)

is_leaf = (is_leaf or
           (impurity <= min_impurity_split))

If any of the above conditions is met, the node stops splitting and becomes a leaf.
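The effect of these stopping conditions is easy to observe from the public API (an added illustration on the iris dataset, not from the original post):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# depth >= max_depth stops splitting
shallow = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)
print(shallow.get_depth())      # 2

# nodes with fewer than min_samples_split samples become leaves
coarse = DecisionTreeClassifier(min_samples_split=60, random_state=0).fit(X, y)
print(coarse.get_depth(), coarse.get_n_leaves())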

sklearn.tree._splitter.BestSplitter

sklearn.tree._splitter.BestSplitter#node_split

scikit-learn uses an optimised version of the CART algorithm; however, the scikit-learn implementation does not support categorical variables for now. (This sentence is from the scikit-learn documentation.)
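A common workaround, shown below as an added illustration (not from the original post), is to one-hot encode categorical columns before fitting the tree:

import pandas as pd
from sklearn.tree import DecisionTreeClassifier

df = pd.DataFrame({"color": ["red", "blue", "red", "green"],
                   "size":  [1.0, 2.0, 3.0, 4.0]})
y = [0, 1, 0, 1]

# One-hot encode the categorical column so the tree only sees numeric input
X = pd.get_dummies(df, columns=["color"])

clf = DecisionTreeClassifier(random_state=0).fit(X, y)
print(clf.predict(X))  # perfectly fits the toy training data: [0 1 0 1]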
