網站首頁 編程語言 正文
數據集劃分方法
train_test_split
train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)
參數包括:
- test_size:可選參數,表示測試集的大小。可以是一個表示比例的浮點數(例如0.2表示20%的數據作為測試集),或者是一個表示樣本數量的整數。默認為None。
- train_size:可選參數,表示訓練集的大小。可以是一個表示比例的浮點數(例如0.8表示80%的數據作為訓練集),或者是一個表示樣本數量的整數。默認為None,表示訓練集的大小由測試集大小決定。
- random_state:可選參數,表示隨機數生成器的種子,用于隨機劃分數據集。設置一個整數值可以保證每次劃分的結果一致。
- shuffle:可選參數,表示是否在劃分數據集之前對數據進行隨機打亂。默認為True,即進行隨機打亂。
- stratify:可選參數,表示根據指定的標簽數組進行分層劃分。標簽數組的長度必須與輸入數據集的第一個維度相同。適用于分類問題中的類別不平衡情況。
from sklearn.model_selection import train_test_split
X, y = load_data() # 加載特征數據 X 和標簽數據 y
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
交叉驗證方法
K折交叉驗證
K折交叉驗證將數據集劃分為K個互不重疊的子集,稱為折(Fold)。模型會進行K次訓練和驗證,每次使用K-1個折作為訓練集,剩下的1個折作為驗證集。K次訓練和驗證的結果會進行平均,得到最終的性能評估。K折交叉驗證可以通過KFold類實現,具體用法如下
from sklearn.model_selection import KFold
X = np.arange(10)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X):
X_train, X_test = X[train_index], X[test_index]
print(X_train, X_test)
print("*"*20)
執行結果
留一交叉驗證LeaveOneOut
留一交叉驗證是一種特殊的K折交叉驗證,其中K等于數據集的樣本數量。每個樣本都作為單獨的驗證集,而其余樣本作為訓練集。這種方法適用于數據集較小的情況。留一交叉驗證可以通過LeaveOneOut類實現,具體用法如下
from sklearn.model_selection import LeaveOneOut
loo = LeaveOneOut()
X = np.arange(10)
for train_index, test_index in loo.split(X):
X_train, X_test = X[train_index], X[test_index]
print(X_train, X_test)
print("*"*20)
# 在訓練集上訓練模型,使用測試集進行評估
分組交叉驗證GroupKFold
分組交叉驗證是一種考慮數據集中樣本之間分組關系的交叉驗證方法。在某些任務中,樣本可能彼此相關或存在依賴關系,例如在自然語言處理中的句子分類任務中,同一篇文章中的句子可能相互影響。為了確保模型在訓練集和驗證集中都包含相同分組的樣本,可以使用GroupKFold類進行分組交叉驗證。具體用法如下
from sklearn.model_selection import GroupKFold
gkf = GroupKFold(n_splits=3)
for train_index, test_index in gkf.split(X, y, groups):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
# 在訓練集上訓練模型,使用測試集進行評估
groups參數是一個表示樣本分組的數組,長度與數據集的樣本數相同。
隨機重復K折交叉驗證RepeatedKFold
隨機重復K折交叉驗證是K折交叉驗證的擴展,通過多次重復執行K折交叉驗證來更穩定地評估模型性能。可以使用RepeatedKFold類進行隨機重復K折交叉驗證。具體用法如下
from sklearn.model_selection import RepeatedKFold
rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=42)
X = np.arange(10)
i = 0
for train_index, test_index in rkf.split(X):
X_train, X_test = X[train_index], X[test_index]
print(X_train, X_test)
print("*"*20)
i +=1
print(i)
層次化交叉驗證cross_val_score
層次化交叉驗證是一種嵌套的交叉驗證方法,用于在模型選擇和性能評估中進行雙重交叉驗證。外層交叉驗證用于評估不同的模型或模型參數,內層交叉驗證用于在每個外層驗證折上進行模型訓練和驗證。可以通過嵌套使用cross_val_score函數來實現層次化交叉驗證。具體用法如下
from sklearn.model_selection import cross_val_score
scores = cross_val_score(estimator, X, y, cv=5)
分層K折交叉驗證StratifiedKFold
分層K折交叉驗證是K折交叉驗證的一種變體,它在劃分數據集時保持了每個類別的樣本比例。這對于類別不平衡的分類問題非常重要。分層K折交叉驗證可以通過StratifiedKFold類實現,具體用法與K折交叉驗證類似。
StratifiedKFold的作用是確保每個折中的樣本比例與原始數據集中的樣本比例相同。這對于處理類別不平衡的分類問題非常重要,因為如果樣本比例不平衡,模型在某些折上可能無法學習到少數類別的有效模式
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for train_index, test_index in skf.split(X, y):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = y[train_index], y[test_index]
# 在訓練集上訓練模型,使用測試集進行評估
參數搜索和模型選擇方法
網格搜索
網格搜索通過遍歷指定的參數組合來尋找最佳的模型參數配置。它通過窮舉搜索所有參數組合,并在交叉驗證中評估每個組合的性能。GridSearchCV類實現了網格搜索的功能。我們需要指定要搜索的參數和其取值范圍,并指定評估指標和交叉驗證的折數。示例代碼如下
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
from sklearn.datasets import load_iris
# 加載數據集
iris = load_iris()
X = iris.data
y = iris.target
# 定義模型和參數網格
model = SVC()
param_grid = {'C': [0.1, 1, 10], 'kernel': ['linear', 'rbf']}
# 執行網格搜索
grid_search = GridSearchCV(model, param_grid, cv=5)
grid_search.fit(X, y)
# 輸出最佳參數配置和得分
print("Best parameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)
執行結果
隨機搜索
隨機搜索通過隨機抽樣一組參數組合來尋找最佳的模型參數配置。與網格搜索不同,隨機搜索不遍歷所有參數組合,而是在指定的參數空間中進行隨機抽樣,并在交叉驗證中評估每個參數組合的性能。RandomizedSearchCV類實現了隨機搜索的功能。示例代碼如下:
from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
# 加載數據集
iris = load_iris()
X = iris.data
y = iris.target
# 定義模型和參數分布
model = RandomForestClassifier()
param_dist = {'n_estimators': [10, 50, 100], 'max_depth': [None, 5, 10]}
# 執行隨機搜索
random_search = RandomizedSearchCV(model, param_distributions=param_dist, cv=5)
random_search.fit(X, y)
# 輸出最佳參數配置和得分
print("Best parameters: ", random_search.best_params_)
print("Best score: ", random_search.best_score_)
執行結果
交叉驗證(Cross-Validation)
將數據集分成多個折(Fold),每次使用其中一部分作為驗證集,剩余部分作為訓練集進行模型訓練和評估。使用cross_val_score函數進行交叉驗證,并指定模型和評估指標。示例代碼:
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
# 定義模型和數據集
model = DecisionTreeClassifier()
X, y = load_iris(return_X_y=True)
# 執行交叉驗證
scores = cross_val_score(model, X, y, cv=5)
# 輸出每折的得分和平均得分
print("Cross-validation scores: ", scores)
print("Average score: ", scores.mean())
執行結果
學習曲線
通過繪制不同訓練集大小下模型的訓練和驗證得分曲線,評估模型的擬合能力和泛化能力。使用learning_curve函數生成學習曲線數據,并繪制曲線圖。示例代碼:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import learning_curve
from sklearn.linear_model import LogisticRegression
# 加載數據集
X, y = load_digits(return_X_y=True)
# 定義模型
model = LogisticRegression()
# 生成學習曲線數據
train_sizes, train_scores, test_scores = learning_curve(model, X, y, cv=5)
# 繪制學習曲線圖
plt.plot(train_sizes, np.mean(train_scores, axis=1), label='Training score')
plt.plot(train_sizes, np.mean(test_scores, axis=1), label='Validation score')
plt.xlabel('Training Set Size')
plt.ylabel('Score')
plt.title('Learning Curve')
plt.legend(loc='best')
plt.show()
執行結果
原文鏈接:https://blog.csdn.net/qq_29983883/article/details/131492781
- 上一篇:沒有了
- 下一篇:沒有了
相關推薦
- 2022-02-28 Chrome控制臺提示“Slow network is detected. Fallback fon
- 2022-04-20 教你python?中如何取出colomap部分的顏色范圍_python
- 2022-12-04 Python?棧實現的幾種方式及優劣詳解_python
- 2023-07-31 elementui中el-tree控件懶加載和局部刷新
- 2023-01-02 Kotlin?fun函數使用方法_Android
- 2022-04-11 socket連接關閉問題分析_python
- 2022-09-06 python實現plt?x軸坐標按1刻度顯示_python
- 2022-08-21 C#?Any()和AII()方法的區別_C#教程
- 欄目分類
-
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支