網(wǎng)站首頁 編程語言 正文
本文實(shí)例為大家分享了基于numpy實(shí)現(xiàn)邏輯回歸的具體代碼,供大家參考,具體內(nèi)容如下
交叉熵?fù)p失函數(shù);sigmoid激勵(lì)函數(shù)
基于numpy的邏輯回歸的程序如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets.samples_generator import make_classification
class logistic_regression():
? ? def __init__(self):
? ? ? ? pass
? ? def sigmoid(self, x):
? ? ? ? z = 1 /(1 + np.exp(-x))
? ? ? ? return z
? ? def initialize_params(self, dims):
? ? ? ? W = np.zeros((dims, 1))
? ? ? ? b = 0
? ? ? ? return W, b
? ? def logistic(self, X, y, W, b):
? ? ? ? num_train = X.shape[0]
? ? ? ? num_feature = X.shape[1]
? ? ? ? a = self.sigmoid(np.dot(X, W) + b)
? ? ? ? cost = -1 / num_train * np.sum(y * np.log(a) + (1 - y) * np.log(1 - a))
? ? ? ? dW = np.dot(X.T, (a - y)) / num_train
? ? ? ? db = np.sum(a - y) / num_train
? ? ? ? cost = np.squeeze(cost)#[]列向量,易于plot
? ? ? ? return a, cost, dW, db
? ? def logistic_train(self, X, y, learning_rate, epochs):
? ? ? ? W, b = self.initialize_params(X.shape[1])
? ? ? ? cost_list = []
? ? ? ? for i in range(epochs):
? ? ? ? ? ? a, cost, dW, db = self.logistic(X, y, W, b)
? ? ? ? ? ? W = W - learning_rate * dW
? ? ? ? ? ? b = b - learning_rate * db
? ? ? ? ? ? if i % 100 == 0:
? ? ? ? ? ? ? ? cost_list.append(cost)
? ? ? ? ? ? if i % 100 == 0:
? ? ? ? ? ? ? ? print('epoch %d cost %f' % (i, cost))
? ? ? ? params = {
? ? ? ? ? ? 'W': W,
? ? ? ? ? ? 'b': b
? ? ? ? }
? ? ? ? grads = {
? ? ? ? ? ? 'dW': dW,
? ? ? ? ? ? 'db': db
? ? ? ? }
? ? ? ? return cost_list, params, grads
? ? def predict(self, X, params):
? ? ? ? y_prediction = self.sigmoid(np.dot(X, params['W']) + params['b'])
? ? ? ? #二分類
? ? ? ? for i in range(len(y_prediction)):
? ? ? ? ? ? if y_prediction[i] > 0.5:
? ? ? ? ? ? ? ? y_prediction[i] = 1
? ? ? ? ? ? else:
? ? ? ? ? ? ? ? y_prediction[i] = 0
? ? ? ? return y_prediction
? ? #精確度計(jì)算
? ? def accuracy(self, y_test, y_pred):
? ? ? ? correct_count = 0
? ? ? ? for i in range(len(y_test)):
? ? ? ? ? ? for j in range(len(y_pred)):
? ? ? ? ? ? ? ? if y_test[i] == y_pred[j] and i == j:
? ? ? ? ? ? ? ? ? ? correct_count += 1
? ? ? ? accuracy_score = correct_count / len(y_test)
? ? ? ? return accuracy_score
? ? #創(chuàng)建數(shù)據(jù)
? ? def create_data(self):
? ? ? ? X, labels = make_classification(n_samples=100, n_features=2, n_redundant=0, n_informative=2)
? ? ? ? labels = labels.reshape((-1, 1))
? ? ? ? offset = int(X.shape[0] * 0.9)
? ? ? ? #訓(xùn)練集與測(cè)試集的劃分
? ? ? ? X_train, y_train = X[:offset], labels[:offset]
? ? ? ? X_test, y_test = X[offset:], labels[offset:]
? ? ? ? return X_train, y_train, X_test, y_test
? ? #畫圖函數(shù)
? ? def plot_logistic(self, X_train, y_train, params):
? ? ? ? n = X_train.shape[0]
? ? ? ? xcord1 = []
? ? ? ? ycord1 = []
? ? ? ? xcord2 = []
? ? ? ? ycord2 = []
? ? ? ? for i in range(n):
? ? ? ? ? ? if y_train[i] == 1:#1類
? ? ? ? ? ? ? ? xcord1.append(X_train[i][0])
? ? ? ? ? ? ? ? ycord1.append(X_train[i][1])
? ? ? ? ? ? else:#0類
? ? ? ? ? ? ? ? xcord2.append(X_train[i][0])
? ? ? ? ? ? ? ? ycord2.append(X_train[i][1])
? ? ? ? fig = plt.figure()
? ? ? ? ax = fig.add_subplot(111)
? ? ? ? ax.scatter(xcord1, ycord1, s=32, c='red')
? ? ? ? ax.scatter(xcord2, ycord2, s=32, c='green')#畫點(diǎn)
? ? ? ? x = np.arange(-1.5, 3, 0.1)
? ? ? ? y = (-params['b'] - params['W'][0] * x) / params['W'][1]#畫二分類直線
? ? ? ? ax.plot(x, y)
? ? ? ? plt.xlabel('X1')
? ? ? ? plt.ylabel('X2')
? ? ? ? plt.show()
if __name__ == "__main__":
? ? model = logistic_regression()
? ? X_train, y_train, X_test, y_test = model.create_data()
? ? print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
? ? # (90, 2)(90, 1)(10, 2)(10, 1)
? ? #訓(xùn)練模型
? ? cost_list, params, grads = model.logistic_train(X_train, y_train, 0.01, 1000)
? ? print(params)
? ? #計(jì)算精確度
? ? y_train_pred = model.predict(X_train, params)
? ? accuracy_score_train = model.accuracy(y_train, y_train_pred)
? ? print('train accuracy is:', accuracy_score_train)
? ? y_test_pred = model.predict(X_test, params)
? ? accuracy_score_test = model.accuracy(y_test, y_test_pred)
? ? print('test accuracy is:', accuracy_score_test)
? ? model.plot_logistic(X_train, y_train, params)
結(jié)果如下所示:
原文鏈接:https://blog.csdn.net/exsolar_521/article/details/108206644
相關(guān)推薦
- 2022-01-28 寶塔的定時(shí)任務(wù),如何設(shè)置秒數(shù)級(jí)別執(zhí)行?
- 2022-08-15 Spring之基于注解裝配Bean
- 2022-10-28 React中使用react-file-viewer問題_React
- 2022-04-25 .Net?Core?Aop之IResourceFilter的具體使用_實(shí)用技巧
- 2022-05-10 原生ajax 在服務(wù)器響應(yīng)前撤銷請(qǐng)求
- 2023-03-03 PostgreSQL死鎖了怎么辦及處理方法_PostgreSQL
- 2022-04-18 python?request?post?列表的方法詳解_python
- 2022-06-09 FreeRTOS實(shí)時(shí)操作系統(tǒng)的任務(wù)通知方法_操作系統(tǒng)
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支