網站首頁 編程語言 正文
1.數據采集和標記
先采集數據,再對數據進行標記。其中采集數據要就有代表性,以確保最終訓練出來模型的準確性。
2.特征選擇
選擇特征的直觀方法:直接使用圖片的每個像素點作為一個特征。
數據保存為樣本個數×特征個數格式的array對象。scikit-learn使用Numpy的array對象來表示數據,所有的圖片數據保存在digits.images里,每個元素都為一個8×8尺寸的灰階圖片。
3.數據清洗
把采集到的、不合適用來做機器學習訓練的數據進行預處理,從而轉換為合適機器學習的數據。
目的:減少計算量,確保模型穩定性。
4.模型選擇
對于不同的數據集,選擇不同的模型有不同的效率。因此在選擇模型要考慮很多的因素,來提高最終選擇模型的契合度。
5.模型訓練
在進行模型訓練之前,要將數據集劃分為訓練數據集和測試數據集,再利用劃分好的數據集進行模型訓練,最后得到我們訓練出來的模型參數。
6.模型測試
模型測試的直觀方法:用訓練出來的模型預測測試數據集,然后將預測出來的結果與真正的結果進行比較,最后比較出來的結果即為模型的準確度。
scikit-learn提供的完成這項工作的方法:
clf . score ( Xtest , Ytest)
除此之外,還可以直接把測試數據集里的部分圖片顯示出來,并且在圖片的左下角顯示預測值,右下角顯示真實值。
7.模型保存與加載
當我們訓練出一個滿意的模型后即可將模型保存下來,這樣當下次需要預測時,可以直接利用此模型進行預測,不用再一次進行模型訓練。
8.實例
數據采集和標記
#導入庫
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
"""
sk-learn庫中自帶了一些數據集
此處使用的就是手寫數字識別圖片的數據
"""
# 導入sklearn庫中datasets模塊
from sklearn import datasets
# 利用datasets模塊中的函數load_digits()進行數據加載
digits = datasets.load_digits()
# 把數據所代表的圖片顯示出來
images_and_labels = list(zip(digits.images, digits.target))
plt.figure(figsize=(8, 6))
for index, (image, label) in enumerate(images_and_labels[:8]):
plt.subplot(2, 4, index + 1)
plt.axis('off')
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.title('Digit: %i' % label, fontsize=20);
特征選擇
# 將數據保存為 樣本個數x特征個數 格式的array對象 的數據格式進行輸出
# 數據已經保存在了digits.data文件中
print("shape of raw image data: {0}".format(digits.images.shape))
print("shape of data: {0}".format(digits.data.shape))
模型訓練
# 把數據分成訓練數據集和測試數據集(此處將數據集的百分之二十作為測試數據集)
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data, digits.target, test_size=0.20, random_state=2);
# 使用支持向量機來訓練模型
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100., probability=True)
# 使用訓練數據集Xtrain和Ytrain來訓練模型
clf.fit(Xtrain, Ytrain);
模型測試
"""
sklearn.metrics.accuracy_score(y_true, y_pred, normalize=True, sample_weight=None)
normalize:默認值為True,返回正確分類的比例;如果為False,返回正確分類的樣本數
"""
# 評估模型的準確度(此處默認為true,直接返回正確的比例,也就是模型的準確度)
from sklearn.metrics import accuracy_score
# predict是訓練后返回預測結果,是標簽值。
Ypred = clf.predict(Xtest);
accuracy_score(Ytest, Ypred)
模型保存與加載
"""
將測試數據集里的部分圖片顯示出來
圖片的左下角顯示預測值,右下角顯示真實值
"""
# 查看預測的情況
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
for i, ax in enumerate(axes.flat):
ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r, interpolation='nearest')
ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,
transform=ax.transAxes,
color='green' if Ypred[i] == Ytest[i] else 'red')
ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,
transform=ax.transAxes,
color='black')
ax.set_xticks([])
ax.set_yticks([])
# 保存模型參數
import joblib
joblib.dump(clf, 'digits_svm.pkl');
保存模型參數過程中出現如下錯誤:
原因:sklearn.externals.joblib函數是用在0.21及以前的版本中,在最新的版本,該函數應被棄用。
解決方法:將 from sklearn.externals import joblib改為 import joblib
# 導入模型參數,直接進行預測
clf = joblib.load('digits_svm.pkl')
Ypred = clf.predict(Xtest);
clf.score(Xtest, Ytest)
原文鏈接:https://blog.csdn.net/m0_65187443/article/details/125923433
相關推薦
- 2022-04-01 Python 中 __name__ == '__main__' 的作用
- 2023-04-10 Android序列化接口Parcelable與Serializable接口對比_Android
- 2023-05-23 Numpy數組轉置的實現_python
- 2022-03-31 C語言值傳遞和地址傳遞詳解_C 語言
- 2022-10-25 Go語言實戰學習之流程控制詳解_Golang
- 2023-06-04 Pandas通過index選擇并獲取行和列_python
- 2023-12-25 Spring 之 @Cacheable 緩存使用教程
- 2022-07-23 Python代碼實現雙鏈表_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支