網站首頁 編程語言 正文
邏輯斯蒂回歸模型多分類任務
上節中,我們使用邏輯斯蒂回歸完成了二分類任務,針對多分類任務,我們可以采用以下措施,進行分類。
我們以三分類任務為例,類別分別為a,b,c。
1.ovr策略
我們可以訓練a類別,非a類別的分類器,確認未來的樣本是否為a類; 同理,可以訓練b類別,非b類別的分類器,確認未來的樣本是否為b類; 同理,可以訓練c類別,非c類別的分類器,確認未來的樣本是否為c類;這樣我們通過增加分類器的數量,K類訓練K個分類器,完成多分類任務。
2.one vs one策略
我們將樣本根據類別進行劃分,分別訓練a與b、a與c、b與c之間的分類器,通過多個分類器判斷結果的匯總打分,判斷未來樣本的類別。 同樣使用了增加分類的數量的方法,需要注意訓練樣本的使用方法不同,K類訓練K(K-1)/2個分類器,完成多分類任務
3.softmax策略
通過計算各個類別的概率,比較最高概率后,確定最終的類別。
對于類別互斥的情況,建議使用softmax,而不同類別之間關聯性較強時,建議使用增加多個分類器的策略。
邏輯斯蒂回歸模型多分類案例實現
本例我們使用sklearn數據集,鳶尾花數據。
1.加載數據
- 樣本總量:150組
- 預測類別:山鳶尾,雜色鳶尾,弗吉尼亞鳶尾三類,各50組。
- 樣本特征4種:花萼長度sepal length (cm) 、花萼寬度sepal width (cm)、花瓣長度petal length (cm)、花瓣寬度petal width (cm)。
2.使用seaborn提供的pairplot方法,可視化展示特征與標簽
3.訓練模型
from sklearn.datasets import load_iris
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
#加載數據
data = load_iris()
iris_target = data.target #
iris_df = pd.DataFrame(data=data.data, columns=data.feature_names) #利用Pandas轉化為DataFrame格式
iris_df['target'] = iris_target
## 特征與標簽組合的散點可視化
sns.pairplot(data=iris_df,diag_kind='hist', hue= 'target')
plt.show()
#劃分數據集
X=iris_df.iloc[:,:-1]
y=iris_df.iloc[:,-1]
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)
## 創建邏輯回歸模型
clf = LogisticRegression(random_state=0, solver='lbfgs')
''' 優化算法選擇參數:solver\
solver參數決定了我們對邏輯回歸損失函數的優化方法,有4種算法可以選擇,分別是:
a) liblinear:使用了開源的liblinear庫實現,內部使用了坐標軸下降法來迭代優化損失函數。
b) lbfgs:擬牛頓法的一種,利用損失函數二階導數矩陣即海森矩陣來迭代優化損失函數。
c) newton-cg:也是牛頓法家族的一種,利用損失函數二階導數矩陣即海森矩陣來迭代優化損失函數。
d) sag:即隨機平均梯度下降,是梯度下降法的變種,和普通梯度下降法的區別是每次迭代僅僅用一部分的樣本來計算梯度,適合于樣本數據多的時候。
從上面的描述可以看出,newton-cg, lbfgs和sag這三種優化算法時都需要損失函數的一階或者二階連續導數,因此不能用于沒有連續導數的L1正則化,只能用于L2正則化。而liblinear通吃L1正則化和L2正則化。
同時,sag每次僅僅使用了部分樣本進行梯度迭代,所以當樣本量少的時候不要選擇它,而如果樣本量非常大,比如大于10萬,sag是第一選擇。但是sag不能用于L1正則化,所以當你有大量的樣本,又需要L1正則化的話就要自己做取舍了。要么通過對樣本采樣來降低樣本量,要么回到L2正則化。
從上面的描述,大家可能覺得,既然newton-cg, lbfgs和sag這么多限制,如果不是大樣本,我們選擇liblinear不就行了嘛!錯,因為liblinear也有自己的弱點!我們知道,邏輯回歸有二元邏輯回歸和多元邏輯回歸。對于多元邏輯回歸常見的有one-vs-rest(OvR)和many-vs-many(MvM)兩種。而MvM一般比OvR分類相對準確一些。郁悶的是liblinear只支持OvR,不支持MvM,這樣如果我們需要相對精確的多元邏輯回歸時,就不能選擇liblinear了。也意味著如果我們需要相對精確的多元邏輯回歸不能使用L1正則化了。
'''
clf.fit(x_train, y_train)
## 查看自變量對應的系數w
print('the weight of Logistic Regression:\n',clf.coef_)
## 查看常數項對應的系數w0
print('the intercept(w0) of Logistic Regression:\n',clf.intercept_)
#模型1的變量重要性排序
coef_c1 = pd.DataFrame({'var' : pd.Series(x_test.columns),
'coef_abs' : abs(pd.Series(clf.coef_[0].flatten()))
})
coef_c1 = coef_c1.sort_values(by = 'coef_abs',ascending=False)
print(coef_c1)
#模型2的變量重要性排序
coef_c2 = pd.DataFrame({'var' : pd.Series(x_test.columns),
'coef_abs' : abs(pd.Series(clf.coef_[1].flatten()))
})
coef_c2 = coef_c2.sort_values(by = 'coef_abs',ascending=False)
print(coef_c2)
#模型3的變量重要性排序
coef_c3 = pd.DataFrame({'var' : pd.Series(x_test.columns),
'coef_abs' : abs(pd.Series(clf.coef_[2].flatten()))
})
coef_c3 = coef_c3.sort_values(by = 'coef_abs',ascending=False)
print(coef_c3)
4.對模型進行評價:模型得分、交叉驗證得分、混淆矩陣
from sklearn.metrics import accuracy_score,recall_score
## 模型評價
score = clf.score(x_train,y_train)#Return the mean accuracy on the given test data and labels.
print(score)#0.628125
#模型在訓練集上的得分
train_score = accuracy_score(y_train,clf.predict(x_train))
print(train_score)#0.628125
#模型在測試集上的得分
test_score = clf.score(x_test,y_test)
print(test_score)#0.6
#預測
y_predict = clf.predict(x_test)
#訓練集的召回率
train_recall = recall_score(y_train, clf.predict(x_train), average='macro')
print("訓練集召回率",train_recall)#0.47934382086167804
#測試集的召回率
test_recall = recall_score(y_test, clf.predict(x_test), average='macro')
print("測試集召回率",test_recall)#0.5002736726874658
from sklearn.metrics import classification_report
print('測試數據指標:\n',classification_report(y_test,y_predict,digits=4))
#k-fold交叉驗證得分
from sklearn.model_selection import cross_val_score
scores = cross_val_score(clf,x_train,y_train,cv=10,scoring='accuracy')
print('十折交叉驗證:每一次的得分',scores)
#結果:每一次的得分 [0.59375 0.59375 0.6875 0.59375 0.53125 0.5625 0.65625 0.625 0.71875 0.625 ]
print('十折交叉驗證:平均得分', scores.mean())
#結果:平均得分 0.61875
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix
import pandas as pd
labelEncoder = LabelEncoder()
labelEncoder.fit(y)##對變量y進行硬編碼,將標簽變為數字
cm = confusion_matrix(y_test, y_predict)
cm_pd = pd.DataFrame(data = cm,columns=labelEncoder.classes_, index=labelEncoder.classes_)
print("混淆矩陣")
print(cm_pd)
import matplotlib.pyplot as plt
plt.matshow(confusion_matrix(y_test, y_predict))
plt.title('Confusion matrix')
plt.colorbar()
plt.ylabel('Actual type') #實際類型
plt.xlabel('Forecast type') #預測類型
原文鏈接:https://juejin.cn/post/7140666880540278814
相關推薦
- 2022-08-26 C++類模板實戰之vector容器的實現_C 語言
- 2023-03-01 Shell?$[]對整數進行數學運算實現_linux shell
- 2022-03-16 Linux下安裝軟件包報依賴等相關問題的解決方法_Linux
- 2022-08-27 C#8.0中的模式匹配_C#教程
- 2022-10-24 iOS開發之Objective-c的Runtime理解指南_IOS
- 2022-11-21 詳解Go語言中的內存對齊_Golang
- 2022-06-29 C語言超詳細講解指針與結構體_C 語言
- 2022-11-06 Go+Redis實現延遲隊列實操_Golang
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支