網站首頁 編程語言 正文
本文實例為大家分享了基于numpy實現邏輯回歸的具體代碼,供大家參考,具體內容如下
交叉熵損失函數;sigmoid激勵函數
基于numpy的邏輯回歸的程序如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets.samples_generator import make_classification
class logistic_regression():
? ? def __init__(self):
? ? ? ? pass
? ? def sigmoid(self, x):
? ? ? ? z = 1 /(1 + np.exp(-x))
? ? ? ? return z
? ? def initialize_params(self, dims):
? ? ? ? W = np.zeros((dims, 1))
? ? ? ? b = 0
? ? ? ? return W, b
? ? def logistic(self, X, y, W, b):
? ? ? ? num_train = X.shape[0]
? ? ? ? num_feature = X.shape[1]
? ? ? ? a = self.sigmoid(np.dot(X, W) + b)
? ? ? ? cost = -1 / num_train * np.sum(y * np.log(a) + (1 - y) * np.log(1 - a))
? ? ? ? dW = np.dot(X.T, (a - y)) / num_train
? ? ? ? db = np.sum(a - y) / num_train
? ? ? ? cost = np.squeeze(cost)#[]列向量,易于plot
? ? ? ? return a, cost, dW, db
? ? def logistic_train(self, X, y, learning_rate, epochs):
? ? ? ? W, b = self.initialize_params(X.shape[1])
? ? ? ? cost_list = []
? ? ? ? for i in range(epochs):
? ? ? ? ? ? a, cost, dW, db = self.logistic(X, y, W, b)
? ? ? ? ? ? W = W - learning_rate * dW
? ? ? ? ? ? b = b - learning_rate * db
? ? ? ? ? ? if i % 100 == 0:
? ? ? ? ? ? ? ? cost_list.append(cost)
? ? ? ? ? ? if i % 100 == 0:
? ? ? ? ? ? ? ? print('epoch %d cost %f' % (i, cost))
? ? ? ? params = {
? ? ? ? ? ? 'W': W,
? ? ? ? ? ? 'b': b
? ? ? ? }
? ? ? ? grads = {
? ? ? ? ? ? 'dW': dW,
? ? ? ? ? ? 'db': db
? ? ? ? }
? ? ? ? return cost_list, params, grads
? ? def predict(self, X, params):
? ? ? ? y_prediction = self.sigmoid(np.dot(X, params['W']) + params['b'])
? ? ? ? #二分類
? ? ? ? for i in range(len(y_prediction)):
? ? ? ? ? ? if y_prediction[i] > 0.5:
? ? ? ? ? ? ? ? y_prediction[i] = 1
? ? ? ? ? ? else:
? ? ? ? ? ? ? ? y_prediction[i] = 0
? ? ? ? return y_prediction
? ? #精確度計算
? ? def accuracy(self, y_test, y_pred):
? ? ? ? correct_count = 0
? ? ? ? for i in range(len(y_test)):
? ? ? ? ? ? for j in range(len(y_pred)):
? ? ? ? ? ? ? ? if y_test[i] == y_pred[j] and i == j:
? ? ? ? ? ? ? ? ? ? correct_count += 1
? ? ? ? accuracy_score = correct_count / len(y_test)
? ? ? ? return accuracy_score
? ? #創建數據
? ? def create_data(self):
? ? ? ? X, labels = make_classification(n_samples=100, n_features=2, n_redundant=0, n_informative=2)
? ? ? ? labels = labels.reshape((-1, 1))
? ? ? ? offset = int(X.shape[0] * 0.9)
? ? ? ? #訓練集與測試集的劃分
? ? ? ? X_train, y_train = X[:offset], labels[:offset]
? ? ? ? X_test, y_test = X[offset:], labels[offset:]
? ? ? ? return X_train, y_train, X_test, y_test
? ? #畫圖函數
? ? def plot_logistic(self, X_train, y_train, params):
? ? ? ? n = X_train.shape[0]
? ? ? ? xcord1 = []
? ? ? ? ycord1 = []
? ? ? ? xcord2 = []
? ? ? ? ycord2 = []
? ? ? ? for i in range(n):
? ? ? ? ? ? if y_train[i] == 1:#1類
? ? ? ? ? ? ? ? xcord1.append(X_train[i][0])
? ? ? ? ? ? ? ? ycord1.append(X_train[i][1])
? ? ? ? ? ? else:#0類
? ? ? ? ? ? ? ? xcord2.append(X_train[i][0])
? ? ? ? ? ? ? ? ycord2.append(X_train[i][1])
? ? ? ? fig = plt.figure()
? ? ? ? ax = fig.add_subplot(111)
? ? ? ? ax.scatter(xcord1, ycord1, s=32, c='red')
? ? ? ? ax.scatter(xcord2, ycord2, s=32, c='green')#畫點
? ? ? ? x = np.arange(-1.5, 3, 0.1)
? ? ? ? y = (-params['b'] - params['W'][0] * x) / params['W'][1]#畫二分類直線
? ? ? ? ax.plot(x, y)
? ? ? ? plt.xlabel('X1')
? ? ? ? plt.ylabel('X2')
? ? ? ? plt.show()
if __name__ == "__main__":
? ? model = logistic_regression()
? ? X_train, y_train, X_test, y_test = model.create_data()
? ? print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
? ? # (90, 2)(90, 1)(10, 2)(10, 1)
? ? #訓練模型
? ? cost_list, params, grads = model.logistic_train(X_train, y_train, 0.01, 1000)
? ? print(params)
? ? #計算精確度
? ? y_train_pred = model.predict(X_train, params)
? ? accuracy_score_train = model.accuracy(y_train, y_train_pred)
? ? print('train accuracy is:', accuracy_score_train)
? ? y_test_pred = model.predict(X_test, params)
? ? accuracy_score_test = model.accuracy(y_test, y_test_pred)
? ? print('test accuracy is:', accuracy_score_test)
? ? model.plot_logistic(X_train, y_train, params)
結果如下所示:
原文鏈接:https://blog.csdn.net/exsolar_521/article/details/108206644
相關推薦
- 2022-10-06 Python?Numpy中數組的集合操作詳解_python
- 2022-05-02 構建及部署jenkins?pipeline實現持續集成持續交付腳本_服務器其它
- 2022-06-01 詳解Pandas中stack()和unstack()的使用技巧_python
- 2023-03-26 數據結構TypeScript之棧和隊列詳解_其它
- 2023-03-25 React錯誤邊界Error?Boundaries_React
- 2022-10-01 Go語言異步API設計的扇入扇出模式詳解_Golang
- 2022-07-29 linux目錄管理方法介紹_linux shell
- 2022-08-03 Django框架中表單的用法_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支