網(wǎng)站首頁 編程語言 正文
本文實(shí)例為大家分享了python基于numpy的線性回歸的具體代碼,供大家參考,具體內(nèi)容如下
class類中包含:
創(chuàng)建數(shù)據(jù)
參數(shù)初始化
計(jì)算輸出值,損失值,dw,db
預(yù)測函數(shù)
交叉驗(yàn)證函數(shù)
其中用到的數(shù)據(jù)集為sklearn中的糖尿病數(shù)據(jù)集
具體代碼如下:
import numpy as np
from sklearn.utils import shuffle
from sklearn.datasets import load_diabetes
import matplotlib.pyplot as plt
#基于numpy實(shí)現(xiàn)一個(gè)簡單的線性回歸模型
#用class進(jìn)行簡單封裝
class lr_model():
? ? def __init__(self):
? ? ? ? pass
? ? # diabetes 是一個(gè)關(guān)于糖尿病的數(shù)據(jù)集, 該數(shù)據(jù)集包括442個(gè)病人的生理數(shù)據(jù)及一年以后的病情發(fā)展情況。
? ? # 數(shù)據(jù)集中的特征值總共10項(xiàng), 如下:
? ? # 年齡
? ? # 性別
? ? # 體質(zhì)指數(shù)
? ? # 血壓
? ? # s1,s2,s3,s4,s4,s6 ?(六種血清的化驗(yàn)數(shù)據(jù))
? ? # 但請注意,以上的數(shù)據(jù)是經(jīng)過特殊處理, 10個(gè)數(shù)據(jù)中的每個(gè)都做了均值中心化處理,然后又用標(biāo)準(zhǔn)差乘以個(gè)體數(shù)量調(diào)整了數(shù)值范圍。驗(yàn)證就會發(fā)現(xiàn)任何一列的所有數(shù)值平方和為1.
? ? def prepare_data(self):
? ? ? ? data = load_diabetes().data
? ? ? ? target = load_diabetes().target
? ? ? ? #數(shù)據(jù)打亂
? ? ? ? X, y = shuffle(data, target, random_state=42)
? ? ? ? X = X.astype(np.float32)
? ? ? ? y = y.reshape((-1, 1))#標(biāo)簽變成列向量形式
? ? ? ? data = np.concatenate((X, y), axis=1)#橫向變?yōu)閿?shù)據(jù)標(biāo)簽的行向量
? ? ? ? return data
? ? ?#初始化參數(shù),權(quán)值與偏執(zhí)初始化
? ? def initialize_params(self, dims):
? ? ? ? w = np.zeros((dims, 1))
? ? ? ? b = 0
? ? ? ? return w, b
? ? def linear_loss(self, X, y, w, b):
? ? ? ? num_train = X.shape[0]#行數(shù)訓(xùn)練數(shù)目
? ? ? ? num_feature = X.shape[1]#列數(shù)表示特征值數(shù)目
? ? ? ? y_hat = np.dot(X, w) + b#y=w*x+b
? ? ? ? loss = np.sum((y_hat - y) ** 2) / num_train#計(jì)算損失函數(shù)
? ? ? ? dw = np.dot(X.T, (y_hat - y)) / num_train#計(jì)算梯度
? ? ? ? db = np.sum((y_hat - y)) / num_train
? ? ? ? return y_hat, loss, dw, db
? ? def linear_train(self, X, y, learning_rate, epochs):
? ? ? ? w, b = self.initialize_params(X.shape[1])#參數(shù)初始化
? ? ? ? loss_list = []
? ? ? ? for i in range(1, epochs):
? ? ? ? ? ? y_hat, loss, dw, db = self.linear_loss(X, y, w, b)
? ? ? ? ? ? w += -learning_rate * dw
? ? ? ? ? ? b += -learning_rate * db#參數(shù)更新
? ? ? ? ? ? loss_list.append(loss)
? ? ? ? if i % 10000 == 0:#每到一定輪數(shù)進(jìn)行打印輸出
? ? ? ? ? ? print('epoch %d loss %f' % (i, loss))
? ? ? ? #參數(shù)保存
? ? ? ? params = {
? ? ? ? ? ? 'w': w,
? ? ? ? ? ? 'b': b
? ? ? ? }
? ? ? ? grads = {
? ? ? ? ? ? 'dw': dw,
? ? ? ? ? ? 'db': db
? ? ? ? }
? ? ? ? return loss, params, grads,loss_list
? ? #預(yù)測函數(shù)
? ? def predict(self, X, params):
? ? ? ? w = params['w']
? ? ? ? b = params['b']
? ? ? ? y_pred = np.dot(X, w) + b
? ? ? ? return y_pred
? ?#隨機(jī)交叉驗(yàn)證函數(shù),如何選測試集、訓(xùn)練集
? ? def linear_cross_validation(self, data, k, randomize=True):
? ? ? ? if randomize:
? ? ? ? ? ? data = list(data)
? ? ? ? ? ? shuffle(data)
? ? ? ? slices = [data[i::k] for i in range(k)]#k為step
? ? ? ? for i in range(k):
? ? ? ? ? ? validation = slices[i]
? ? ? ? ? ? train = [data for s in slices if s is not validation for data in s]#將不為測試集的數(shù)據(jù)作為訓(xùn)練集
? ? ? ? ? ? train = np.array(train)
? ? ? ? ? ? validation = np.array(validation)
? ? ? ? ? ? yield train, validation#yield 變?yōu)榭傻?每次返回
if __name__ == '__main__':
? ? lr = lr_model()
? ? data = lr.prepare_data()
? ? for train, validation in lr.linear_cross_validation(data, 5):
? ? ? ? X_train = train[:, :10]
? ? ? ? y_train = train[:, -1].reshape((-1, 1))
? ? ? ? X_valid = validation[:, :10]
? ? ? ? y_valid = validation[:, -1].reshape((-1, 1))
? ? ? ? loss5 = []
? ? ? ? loss, params, grads,loss_list = lr.linear_train(X_train, y_train, 0.001, 100000)
? ? ? ? plt.plot(loss_list, color='blue')
? ? ? ? plt.xlabel('epochs')
? ? ? ? plt.ylabel('loss')
? ? ? ? plt.show()
? ? ? ? loss5.append(loss)
? ? ? ? score = np.mean(loss5)
? ? ? ? print('five kold cross validation score is', score)#5類數(shù)據(jù)的測試分?jǐn)?shù)
? ? ? ? y_pred = lr.predict(X_valid, params)
? ? ? ? plt.scatter(range(X_valid.shape[0]),y_valid)
? ? ? ? plt.scatter(range(X_valid.shape[0]),y_pred,color='red')
? ? ? ? plt.xlabel('x')
? ? ? ? plt.ylabel('y')
? ? ? ? plt.show()
? ? ? ? valid_score = np.sum(((y_pred - y_valid) ** 2)) / len(X_valid)
? ? ? ? print('valid score is', valid_score)
結(jié)果如下:
原文鏈接:https://blog.csdn.net/exsolar_521/article/details/108201269
相關(guān)推薦
- 2022-10-07 ASP.NET?MVC使用Knockout獲取數(shù)組元素索引的2種方法_實(shí)用技巧
- 2022-08-15 oracle數(shù)據(jù)庫表實(shí)現(xiàn)自增主鍵的方法實(shí)例_oracle
- 2022-07-22 SQL?Server使用CROSS?APPLY與OUTER?APPLY實(shí)現(xiàn)連接查詢_MsSql
- 2022-11-07 PostgreSQL長事務(wù)概念解析_PostgreSQL
- 2022-07-26 在Pycharm set ops_config=local之后,直接echo %ops_config
- 2022-11-21 Python?Flask實(shí)現(xiàn)圖片驗(yàn)證碼與郵箱驗(yàn)證碼流程詳細(xì)講解_python
- 2022-04-18 html2canvas 畫圖出現(xiàn)空白的情況,引出圖片跨域的相關(guān)問題
- 2022-07-30 os.path模塊下的顯示路徑方法
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡單動態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支