網站首頁 編程語言 正文
前言
metrics用于判斷模型性能。度量函數類似于損失函數,只是度量的結果不用于訓練模型。可以使用任何損失函數作為度量(如logloss等)。在訓練期間監控metrics的最佳方式是通過Tensorboard。
官方提供的metrics最重要的概念就是有狀態(stateful)變量,通過更新狀態變量,可以不斷累積統計數據,并可以隨時輸出狀態變量的計算結果。這是區別于losses的重要特性,losses是無狀態的(stateless)。
本文部分內容參考了:
Keras-Metrics官方文檔
代碼運行環境為:tf.__version__==2.6.2 。
metrics原理解析(以metrics.Mean為例)
metrics是有狀態的(stateful),即Metric 實例會存儲、記錄和返回已經累積的結果,有助于未來事務的信息。下面以tf.keras.metrics.Mean()
為例進行解釋:
創建tf.keras.metrics.Mean
的實例:
m = tf.keras.metrics.Mean()
通過help(m)
可以看到MRO為:
Mean
Reduce
Metric
keras.engine.base_layer.Layer
...
可見Metric和Mean是 keras.layers.Layer
的子類。相比于類Layer,其子類Mean多出了幾個方法:
- result: 計算并返回標量度量值(tensor形式)或標量字典,即狀態變量簡單地計算度量值。例如,
m.result()
,就是計算均值并返回。 - total: 狀態變量
m
目前累積的數字總和 - count: 狀態變量
m
目前累積的數字個數(m.total/m.count
就是m.result()
的返回值) - update_state: 累積統計數字用于計算指標。每次調用
m.update_state
都會更新m.total
和m.count
; - reset_state: 將狀態變量重置到初始化狀態;
- reset_states: 等價于reset_state,參見keras源代碼metrics.py L355
- reduction: 目前來看,沒什么用。
這也決定了Mean的特殊性質。其使用參見如下代碼:
# 創建狀態變量m,由于m未剛初始化, # 所以total,count和result()均為0 m = tf.keras.metrics.Mean() print("m.total:",m.total) print("m.count:",m.count) print("m.result():",m.result())
"""
# 輸出:
m.total: <tf.Variable 'total:0' shape=() dtype=float32, numpy=0.0>
m.count: <tf.Variable 'count:0' shape=() dtype=float32, numpy=0.0>
m.result(): tf.Tensor(0.0, shape=(), dtype=float32)
"""
# 更新狀態變量,可以看到total累加了總和, # count累積了個數,result()返回total/count m.update_state([1,2,3]) print("m.total:",m.total) print("m.count:",m.count) print("m.result():",m.result())
"""
# 輸出:
m.total: <tf.Variable 'total:0' shape=() dtype=float32, numpy=6.0>
m.count: <tf.Variable 'count:0' shape=() dtype=float32, numpy=3.0>
m.result(): tf.Tensor(2.0, shape=(), dtype=float32)
"""
# 重置狀態變量, 重置到初始化狀態 m.reset_state() print("m.total:",m.total) print("m.count:",m.count) print("m.result():",m.result())
"""
# 輸出:
m.total: <tf.Variable 'total:0' shape=() dtype=float32, numpy=0.0>
m.count: <tf.Variable 'count:0' shape=() dtype=float32, numpy=0.0>
m.result(): tf.Tensor(0.0, shape=(), dtype=float32)
"""
創建自定義metrics
創建無狀態 metrics
與損失函數類似,任何帶有類似于metric_fn(y_true, y_pred)
、返回損失數組(如輸入一個batch的數據,會返回一個batch的損失標量)的函數,都可以作為metric傳遞給compile()
:
import tensorflow as tf import numpy as np inputs = tf.keras.Input(shape=(3,)) x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) outputs = tf.keras.layers.Dense(1, activation=tf.nn.softmax)(x) model1 = tf.keras.Model(inputs=inputs, outputs=outputs) def my_metric_fn(y_true, y_pred): squared_difference = tf.square(y_true - y_pred) return tf.reduce_mean(squared_difference, axis=-1) # shape=(None,) model1.compile(optimizer='adam', loss='mse', metrics=[my_metric_fn]) x = np.random.random((100, 3)) y = np.random.random((100, 1)) model1.fit(x, y, epochs=3)
輸出:
Epoch 1/3
4/4 [==============================] - 0s 667us/step - loss: 0.0971 - my_metric_fn: 0.0971
Epoch 2/3
4/4 [==============================] - 0s 667us/step - loss: 0.0958 - my_metric_fn: 0.0958
Epoch 3/3
4/4 [==============================] - 0s 1ms/step - loss: 0.0946 - my_metric_fn: 0.0946
注意,因為本例創建的是無狀態的度量,所以上面跟蹤的度量值(my_metric_fn后面的值)是每個batch的平均度量值,并不是一個epoch(完整數據集)的累積值。(這一點需要理解,這也是為什么要使用有狀態度量的原因!)
值得一提的是,如果上述代碼使用
model1.compile(optimizer='adam', loss='mse', metrics=["mse"])
進行compile,則輸出的結果是累積的,在每個epoch結束時的結果就是整個數據集的結果,因為metrics=["mse"]
是直接調用了標準庫的有狀態度量。
通過繼承Metric創建有狀態metrics
如果想查看整個數據集的指標,就需要傳入有狀態的metrics,這樣就會在一個epoch內累加,并在epoch結束時輸出整個數據集的度量值。
創建有狀態度量指標,需要創建Metric的子類,它可以跨batch維護狀態,步驟如下:
- 在
__init__
中創建狀態變量(state variables) - 更新
update_state()
中y_true
和y_pred
的變量 - 在
result()
中返回標量度量結果 - 在
reset_states()
中清除狀態
class BinaryTruePositives(tf.keras.metrics.Metric): def __init__(self, name='binary_true_positives', **kwargs): super(BinaryTruePositives, self).__init__(name=name, **kwargs) self.true_positives = self.add_weight(name='tp', initializer='zeros') def update_state(self, y_true, y_pred, sample_weight=None): y_true = tf.cast(y_true, tf.bool) y_pred = tf.cast(y_pred, tf.bool) values = tf.logical_and(tf.equal(y_true, True), tf.equal(y_pred, True)) values = tf.cast(values, self.dtype) if sample_weight is not None: sample_weight = tf.cast(sample_weight, self.dtype) values = tf.multiply(values, sample_weight) self.true_positives.assign_add(tf.reduce_sum(values)) def result(self): return self.true_positives def reset_states(self): self.true_positives.assign(0) m = BinaryTruePositives() m.update_state([0, 1, 1, 1], [0, 1, 0, 0]) print('Intermediate result:', float(m.result())) m.update_state([1, 1, 1, 1], [0, 1, 1, 0]) print('Final result:', float(m.result()))
add_metric()方法
add_metric
方法是 tf.keras.layers.Layer
類添加的方法,Layer的父類tf.Module
并沒有這個方法,因此在編寫Layer子類如包括自定義層、官方提供的層(Dense)或模型(tf.keras.Model也是Layer的子類)時,可以使用add_metric()
來與層相關的統計量。比如,將類似Dense的自定義層的激活平均值記錄為metric。可以執行以下操作:
class DenseLike(Layer): """y = w.x + b""" ... def call(self, inputs): output = tf.matmul(inputs, self.w) + self.b self.add_metric(tf.reduce_mean(output), aggregation='mean', name='activation_mean') return output
將在名稱為activation_mean的度量下跟蹤output,跟蹤的值為每個批次度量值的平均值。
更詳細的信息,參閱官方文檔The base Layer class - add_metric method。
參考
Keras-Metrics官方文檔
原文鏈接:https://blog.csdn.net/u012762410/article/details/127609836
相關推薦
- 2022-06-14 jquery實現點擊按鈕顯示與隱藏效果_jquery
- 2023-02-12 Pytorch建模過程中的DataLoader與Dataset示例詳解_python
- 2022-01-17 類組件與函數組件的區別 react中class創建的組件與function創建的組件有什么區別
- 2022-07-25 .Net行為型設計模式之訪問者模式(Visitor)_基礎應用
- 2022-04-19 C#多線程系列之線程的創建和生命周期_C#教程
- 2022-06-22 Git用戶簽名的修改取消及優先級拓展教程_其它綜合
- 2024-07-14 Guava自加載緩存LoadingCache
- 2022-12-26 Python標準庫os常用函數和屬性詳解_python
- 最近更新
-
- window11 系統安裝 yarn
- 超詳細win安裝深度學習環境2025年最新版(
- Linux 中運行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎操作-- 運算符,流程控制 Flo
- 1. Int 和Integer 的區別,Jav
- spring @retryable不生效的一種
- Spring Security之認證信息的處理
- Spring Security之認證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權
- redisson分布式鎖中waittime的設
- maven:解決release錯誤:Artif
- restTemplate使用總結
- Spring Security之安全異常處理
- MybatisPlus優雅實現加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務發現-Nac
- Spring Security之基于HttpR
- Redis 底層數據結構-簡單動態字符串(SD
- arthas操作spring被代理目標對象命令
- Spring中的單例模式應用詳解
- 聊聊消息隊列,發送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠程分支