網(wǎng)站首頁 編程語言 正文
python 留一交叉驗證
基本原理
K折交叉驗證
簡單來說,K折交叉驗證就是:
- 把數(shù)據(jù)集劃分成K份,取出其中一份作為測試集,另外的K - 1份作為訓(xùn)練集。
- 通過訓(xùn)練集得到回歸方程,再把測試集帶入該回歸方程,得到預(yù)測值。
- 計算預(yù)測值與真實值的差值的平方,得到平方損失函數(shù)(或其他的損失函數(shù))。
- 重復(fù)以上過程,總共得到K個回歸方程和K個損失函數(shù),其中損失函數(shù)最小的回歸方程就是最優(yōu)解。
留一交叉驗證
留一交叉驗證是K折交叉驗證的特殊情況,即:將數(shù)據(jù)集劃分成N份,N為數(shù)據(jù)集總數(shù)。就是只留一個數(shù)據(jù)作為測試集,該特殊情況稱為“留一交叉驗證”。
代碼實現(xiàn)
'''留一交叉驗證'''
import numpy as np
# K折交叉驗證
data = [[12, 1896], [11, 1900], [11, 1904], [10.8, 1908], [10.8, 1912], [10.8, 1920], [10.6, 1924], [10.8, 1928],
[10.3, 1932], [10.3, 1936], [10.3, 1948], [10.4, 1952], [10.5, 1956], [10.2, 1960], [10.0, 1964], [9.95, 1968],
[10.14, 1972], [10.06, 1976], [10.25, 1980], [9.99, 1984], [9.92, 1988], [9.96, 1992], [9.84, 1996],
[9.87, 2000], [9.85, 2004], [9.69, 2008]]
length = len(data)
# 得到訓(xùn)練集和測試集
def Get_test_train(length, data, i):
test_data = data[i] # 測試集
train_data = data[:]
train_data.pop(i) # 訓(xùn)練集
return train_data, test_data
# 得到線性回歸直線
def Get_line(train_data):
time = []
year = []
average_year_time = 0
average_year_year = 0
for i in train_data:
time.append(i[0])
year.append(i[1])
time = np.array(time)
year = np.array(year)
average_year = sum(year) / length # year拔
average_time = sum(time) / length # time拔
for i in train_data:
average_year_time = average_year_time + i[0] * i[1]
average_year_year = average_year_year + i[1] ** 2
average_year_time = average_year_time / length # (year, time)拔
average_year_year = average_year_year / length # (year, year)拔
# 線性回歸:t = w0 + w1 * x
w1 = (average_year_time - average_year * average_time) / (average_year_year - average_year * average_year)
w0 = average_time - w1 * average_year
return w0, w1
# 得到損失函數(shù)
def Get_loss_func(w0, w1, test_data):
time_real = test_data[0]
time_predict = eval('{} + {} * {}'.format(w0, w1, test_data[1]))
loss = (time_predict - time_real) ** 2
dic['t = {} + {}x'.format(w0, w1)] = loss
return dic
if __name__ == '__main__':
dic = {} # 存放建為回歸直線,值為損失函數(shù)的字典
for i in range(length):
train_data, test_data = Get_test_train(length, data, i)
w0, w1 = Get_line(train_data)
Get_loss_func(w0, w1, test_data)
dic = Get_loss_func(w0, w1, test_data)
min_loss = min(dic.values())
best_line = [k for k, v in dic.items() if v == min_loss][0]
print('最佳回歸直線:', best_line)
print('最小損失函數(shù):', min_loss)
留一法交叉驗證 Leave-One-Out Cross Validation
交叉驗證法,就是把一個大的數(shù)據(jù)集分為 k 個小數(shù)據(jù)集,其中 k?1 個作為訓(xùn)練集,剩下的 1 11 個作為測試集,在訓(xùn)練和測試的時候依次選擇訓(xùn)練集和它對應(yīng)的測試集。這種方法也被叫做 k 折交叉驗證法(k-fold cross validation)。最終的結(jié)果是這 k 次驗證的均值。
此外,還有一種交叉驗證方法就是 留一法(Leave-One-Out,簡稱LOO),顧名思義,就是使 k kk 等于數(shù)據(jù)集中數(shù)據(jù)的個數(shù),每次只使用一個作為測試集,剩下的全部作為訓(xùn)練集,這種方法得出的結(jié)果與訓(xùn)練整個測試集的期望值最為接近,但是成本過于龐大。
我們用SKlearn庫來實現(xiàn)一下LOO
from sklearn.model_selection import LeaveOneOut
# 一維示例數(shù)據(jù)
data_dim1 = [1, 2, 3, 4, 5]
# 二維示例數(shù)據(jù)
data_dim2 = [[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4],
[5, 5, 5, 5]]
loo = LeaveOneOut() # 實例化LOO對象
# 取LOO訓(xùn)練、測試集數(shù)據(jù)索引
for train_idx, test_idx in loo.split(data_dim1):
# train_idx 是指訓(xùn)練數(shù)據(jù)在總數(shù)據(jù)集上的索引位置
# test_idx 是指測試數(shù)據(jù)在總數(shù)據(jù)集上的索引位置
print("train_index: %s, test_index %s" % (train_idx, test_idx))
# 取LOO訓(xùn)練、測試集數(shù)據(jù)值
for train_idx, test_idx in loo.split(data_dim1):
# train_idx 是指訓(xùn)練數(shù)據(jù)在總數(shù)據(jù)集上的索引位置
# test_idx 是指測試數(shù)據(jù)在總數(shù)據(jù)集上的索引位置
train_data = [data_dim1[i] for i in train_idx]
test_data = [data_dim1[i] for i in test_idx]
print("train_data: %s, test_data %s" % (train_data, test_data))
data_dim1的輸出:
train_index: [1 2 3 4], test_index [0]
train_index: [0 2 3 4], test_index [1]
train_index: [0 1 3 4], test_index [2]
train_index: [0 1 2 4], test_index [3]
train_index: [0 1 2 3], test_index [4]train_data: [2, 3, 4, 5], test_data [1]
train_data: [1, 3, 4, 5], test_data [2]
train_data: [1, 2, 4, 5], test_data [3]
train_data: [1, 2, 3, 5], test_data [4]
train_data: [1, 2, 3, 4], test_data [5]
data_dim2的輸出:
train_index: [1 2 3 4], test_index [0]
train_index: [0 2 3 4], test_index [1]
train_index: [0 1 3 4], test_index [2]
train_index: [0 1 2 4], test_index [3]
train_index: [0 1 2 3], test_index [4]train_data: [[2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]], test_data [[1, 1, 1, 1]]
train_data: [[1, 1, 1, 1], [3, 3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5]], test_data [[2, 2, 2, 2]]
train_data: [[1, 1, 1, 1], [2, 2, 2, 2], [4, 4, 4, 4], [5, 5, 5, 5]], test_data [[3, 3, 3, 3]]
train_data: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [5, 5, 5, 5]], test_data [[4, 4, 4, 4]]
train_data: [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], test_data [[5, 5, 5, 5]]
原文鏈接:https://blog.csdn.net/qq_43650934/article/details/108672624
相關(guān)推薦
- 2021-12-08 服務(wù)器并發(fā)量估算公式和計算方法_服務(wù)器其它
- 2022-07-10 DHCP服務(wù)配置——CentOS/Windows2003
- 2022-01-16 Jquery+Css+Html實現(xiàn)返選、批量刪除、高亮顯示功能
- 2022-10-08 C#中Timer實現(xiàn)Tick使用精度的問題_C#教程
- 2022-10-10 python?pandas數(shù)據(jù)處理之刪除特定行與列_python
- 2022-04-16 Python數(shù)據(jù)結(jié)構(gòu)與算法之跳表詳解_python
- 2022-03-28 Android實現(xiàn)調(diào)取支付寶健康碼_Android
- 2022-02-27 報錯:Unable to find main class
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支