網站首頁 編程語言 正文
一、Introduction
線性判別模型(LDA)在模式識別領域(比如人臉識別等圖形圖像識別領域)中有非常廣泛的應用。LDA是一種監督學習的降維技術,也就是說它的數據集的每個樣本是有類別輸出的。這點和PCA不同。PCA是不考慮樣本類別輸出的無監督降維技術。 LDA的思想可以用一句話概括,就是“投影后類內方差最小,類間方差最大”。我們要將數據在低維度上進行投影,投影后希望每一種類別數據的投影點盡可能的接近,而不同類別的數據的類別中心之間的距離盡可能的大。即:將數據投影到維度更低的空間中,使得投影后的點,會形成按類別區分,一簇一簇的情況,相同類別的點,將會在投影后的空間中更接近方法。
1 LDA的優點
- 在降維過程中可以使用類別的先驗知識經驗,而像PCA這樣的無監督學習則無法使用類別先驗知識;
- LDA在樣本分類信息依賴均值而不是方差的時候,比PCA之類的算法較優
2 LDA的缺點
- LDA不適合對非高斯分布樣本進行降維,PCA也有這個問題
- LDA降維最多降到類別數 k-1 的維數,如果我們降維的維度大于 k-1,則不能使用 LDA。當然目前有一些LDA的進化版算法可以繞過這個問題
- LDA在樣本分類信息依賴方差而不是均值的時候,降維效果不好
- LDA可能過度擬合數據
3 LDA在模式識別領域與自然語言處理領域的區別
在自然語言處理領域,LDA是隱含狄利克雷分布,它是一種處理文檔的主題模型。本文討論的是線性判別分析 LDA除了可以用于降維以外,還可以用于分類。一個常見的LDA分類基本思想是假設各個類別的樣本數據符合高斯分布,這樣利用LDA進行投影后,可以利用極大似然估計計算各個類別投影數據的均值和方差,進而得到該類別高斯分布的概率密度函數。當一個新的樣本到來后,我們可以將它投影,然后將投影后的樣本特征分別帶入各個類別的高斯分布概率密度函數,計算它屬于這個類別的概率,最大的概率對應的類別即為預測類別
二、Demo
#%%導入基本庫 # 基礎數組運算庫導入 import numpy as np # 畫圖庫導入 import matplotlib.pyplot as plt # 導入三維顯示工具 from mpl_toolkits.mplot3d import Axes3D # 導入LDA模型 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # 導入demo數據制作方法 from sklearn.datasets import make_classification #%%模型訓練 # 制作四個類別的數據,每個類別100個樣本 X, y = make_classification(n_samples=1000, n_features=3, n_redundant=0, n_classes=4, n_informative=2, n_clusters_per_class=1, class_sep=3, random_state=10) # 將四個類別的數據進行三維顯示 fig = plt.figure() ax = Axes3D(fig, rect=[0, 0, 1, 1], elev=20, azim=20) ax.scatter(X[:, 0], X[:, 1], X[:, 2], marker='o', c=y) plt.show()
#%%建立 LDA 模型 lda = LinearDiscriminantAnalysis() # 進行模型訓練 lda.fit(X, y) #%%查看lda的參數 print(lda.get_params())
#%%數據可視化 #模型預測 X_new = lda.transform(X) # 可視化預測數據 plt.scatter(X_new[:, 0], X_new[:, 1], marker='o', c=y) plt.show()
#%%使用新的數據進行測試 a = np.array([[-1, 0.1, 0.1]]) print(f"{a} 類別是: ", lda.predict(a)) print(f"{a} 類別概率分別是: ", lda.predict_proba(a)) a = np.array([[-12, -100, -91]]) print(f"{a} 類別是: ", lda.predict(a)) print(f"{a} 類別概率分別是: ", lda.predict_proba(a)) a = np.array([[-12, -0.1, -0.1]]) print(f"{a} 類別是: ", lda.predict(a)) print(f"{a} 類別概率分別是: ", lda.predict_proba(a)) a = np.array([[0.1, 90.1, 9.1]]) print(f"{a} 類別是: ", lda.predict(a)) print(f"{a} 類別概率分別是: ", lda.predict_proba(a))
三、基于LDA 手寫數字的分類
#%%導入庫函數 # 導入手寫數據集 MNIST from sklearn.datasets import load_digits # 導入訓練集分割方法 from sklearn.model_selection import train_test_split # 導入LDA模型 from sklearn.discriminant_analysis import LinearDiscriminantAnalysis # 導入預測指標計算函數和混淆矩陣計算函數 from sklearn.metrics import classification_report, confusion_matrix # 導入繪圖包 import seaborn as sns import matplotlib import matplotlib.pyplot as plt #%% 導入MNIST數據集 mnist = load_digits() # 查看數據集信息 print('The Mnist dataeset:\n',mnist) # 分割數據為訓練集和測試集 x, test_x, y, test_y = train_test_split(mnist.data, mnist.target, test_size=0.1, random_state=2)
#%%## 輸出示例圖像 images = range(0,9) plt.figure(dpi=100) for i in images: plt.subplot(330 + 1 + i) plt.imshow(x[i].reshape(8, 8), cmap = matplotlib.cm.binary,interpolation="nearest") # show the plot plt.show()
#%%利用LDA對手寫數字進行訓練與預測 m_lda = LinearDiscriminantAnalysis()# 建立 LDA 模型 # 進行模型訓練 m_lda.fit(x, y) # 進行模型預測 x_new = m_lda.transform(x) # 可視化預測數據 plt.scatter(x_new[:, 0], x_new[:, 1], marker='o', c=y) plt.title('MNIST with LDA Model') plt.show()
#%% 進行測試集數據的類別預測 y_test_pred = m_lda.predict(test_x) print("測試集的真實標簽:\n", test_y) print("測試集的預測標簽:\n", y_test_pred) #%% 進行預測結果指標統計 統計每一類別的預測準確率、召回率、F1分數 print(classification_report(test_y, y_test_pred)) # 計算混淆矩陣 C2 = confusion_matrix(test_y, y_test_pred) # 打混淆矩陣 print(C2) # 將混淆矩陣以熱力圖的防線顯示 sns.set() f, ax = plt.subplots() # 畫熱力圖 sns.heatmap(C2, cmap="YlGnBu_r", annot=True, ax=ax) # 標題 ax.set_title('confusion matrix') # x軸為預測類別 ax.set_xlabel('predict') # y軸實際類別 ax.set_ylabel('true') plt.show()
四、小結
LDA適用于線性可分數據,在非線性數據上要謹慎使用。 886~~~
原文鏈接:https://blog.csdn.net/qq_43368987/article/details/122472515
相關推薦
- 2023-12-13 Excel統計某個關鍵字出現的次數
- 2022-10-22 PyTorch中的CUDA的操作方法_python
- 2022-08-05 C#實現圖形界面的時鐘_C#教程
- 2023-04-10 詳解Go語言中的數據庫操作_Golang
- 2022-06-12 C語言經典順序表真題演練講解_C 語言
- 2022-10-28 Go保證并發安全底層實現詳解_Golang
- 2022-03-16 C++中的Lambda函數詳解_C 語言
- 2023-01-19 flask?post獲取前端請求參數的三種方式總結_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支