在tensorflow.keras中,callbacks能在fit、evaluate和predict過程中加入伴隨著模型的生命周期運行,目前tensorflow.keras已經(jīng)構(gòu)建了許多種callbacks供用戶使用,用于防止過擬合、可視化訓(xùn)練過程、糾錯、保存模型checkpoints和生成TensorBoard等。通過這篇文章,我們來了解一下如何使用tensorflow.keras里的各種callbacks,以及如何自定義callbacks。
使用callbacks的步驟很簡單,先定義callbacks,然后在model.fit、model.evaluate和model.predict中把定義好的callbacks傳到callbacks參數(shù)里即可。
以最常見的ModelCheckpoint為例,使用過程如下示例:
...
model_checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
filepath=filePath,
save_weights_only=True,
monitor='val_accuracy',
mode='max')
model.fit(x, y, callbacks=model_checkpoint_callback)
這樣在模型訓(xùn)練時,就會將模型checkpoints存儲在對應(yīng)的位置供后續(xù)使用。除了ModelCheckpoint,在Tensorflow 2.0中,還有許多其他類型的callbacks供使用,讓我們一探究竟。
這個callback能監(jiān)控設(shè)定的評價指標(biāo),在訓(xùn)練過程中,評價指標(biāo)不再上升時,訓(xùn)練將會提前結(jié)束,防止模型過擬合,其默認(rèn)參數(shù)如下:
tf.keras.callbacks.EarlyStopping(monitor='val_loss',
min_delta=0,
patience=0,
verbose=0,
mode='auto',
baseline=None,
restore_best_weights=False)
其中各個參數(shù):
這個callback能在模型訓(xùn)練過程中調(diào)整學(xué)習(xí)率,通常而言,隨著訓(xùn)練次數(shù)的變多,適當(dāng)?shù)亟档蛯W(xué)習(xí)率有利于模型收斂在全局最優(yōu)點,因此這個callback需要搭配一個學(xué)習(xí)率調(diào)度器使用,在每個epoch開始時,schedule函數(shù)會獲取最新的學(xué)習(xí)率并用在當(dāng)前的epoch中:
tf.keras.callbacks.LearningRateScheduler(
schedule, verbose=0
)
# 調(diào)度函數(shù)在10個epoch前調(diào)用初始學(xué)習(xí)率,隨后學(xué)習(xí)率呈指數(shù)下降
def scheduler(epoch, lr):
if epoch < 10:
return lr
else:
return lr * tf.math.exp(-0.1)
model=tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss='mse')
callback=tf.keras.callbacks.LearningRateScheduler(scheduler)
history=model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
epochs=15, callbacks=[callback], verbose=0)
相比于LearningRateScheduler,ReduceLROnPlateau不是按照預(yù)先設(shè)定好的調(diào)度調(diào)整學(xué)習(xí)率,它會在評價指標(biāo)停止提升時降低學(xué)習(xí)率。
tf.keras.callbacks.ReduceLROnPlateau(
monitor='val_loss', factor=0.1, patience=10, verbose=0,
mode='auto', min_delta=0.0001, cooldown=0, min_lr=0, **kwargs
)
其中重要參數(shù):
TensorBoard能很方便地展示模型架構(gòu)、訓(xùn)練過程,這個callback能生成TensorBoard的日志,當(dāng)訓(xùn)練結(jié)束后可以在TensorBoard里查看可視化結(jié)果。
tf.keras.callbacks.TensorBoard(
log_dir='logs', histogram_freq=0, write_graph=True,
write_images=False, write_steps_per_second=False, update_freq='epoch',
profile_batch=2, embeddings_freq=0, embeddings_metadata=None, **kwargs
)
其中重要參數(shù):
顧名思義,這個callback能將訓(xùn)練過程寫入CSV文件。
tf.keras.callbacks.CSVLogger(
filename, separator=',', append=False
)
其中重要參數(shù):
在損失變?yōu)镹aN時停止訓(xùn)練。
tf.keras.callbacks.TerminateOnNaN()
除了上述callback外,還有一些callback可以查詢TensorFlow官網(wǎng)[1],在使用多個callbacks時,可以使用列表將多個callbacks傳入、或者使用tf.keras.callbacks.CallbackList[2]。除此之外,也可以自定義callback,需要繼承keras.callbacks.Callback,然后重寫在不同訓(xùn)練階段的方法。
training_finished=False
class MyCallback(tf.keras.callbacks.Callback):
def on_train_end(self, logs=None):
global training_finished
training_finished=True
model=tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(loss='mean_squared_error')
model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
callbacks=[MyCallback()])
assert training_finished==True
本文總結(jié)了若干常用的tf.keras.callbacks,實際工作中,請按需使用,并且查看tf.keras.callbacks的官方文檔確認(rèn)參數(shù)取值。
希望這次的分享對你有幫助,歡迎在評論區(qū)留言討論!
[1] tf.keras.callbacks: 'https://www.tensorflow.org/api_docs/python/tf/keras/callbacks'
[2] tf.keras.callbacks.CallbackList: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/CallbackList
引言
圖像增強(qiáng)是我們在深度學(xué)習(xí)領(lǐng)域中繞不開的一個話題,本文我們將討論什么是圖像增強(qiáng),并在三個不同的 python 庫中實現(xiàn)它,即 Keras、Pytorch 和 augmentation(專門用于圖像增強(qiáng)的一個庫)。所以第一個問題就是什么是圖像增強(qiáng)以及常規(guī)的數(shù)據(jù)增強(qiáng)。
什么是圖像增強(qiáng)?
增強(qiáng)是使規(guī)模或數(shù)量增大的動作或過程。
在深度學(xué)習(xí)中,深度網(wǎng)絡(luò)需要大量的訓(xùn)練數(shù)據(jù)來很好地歸納和達(dá)到良好的準(zhǔn)確性。但在某些情況下,圖像數(shù)據(jù)不夠大。在這種情況下,我們使用一些技術(shù)來增加我們的訓(xùn)練數(shù)據(jù)。它人為地創(chuàng)建訓(xùn)練數(shù)據(jù),使用諸如隨機(jī)旋轉(zhuǎn)、位移、剪切和翻轉(zhuǎn)等技術(shù)處理給定的數(shù)據(jù)(我們將在后面討論其中的一些)。
圖像增強(qiáng)是為了訓(xùn)練我們的深度學(xué)習(xí)模型而產(chǎn)生新圖像的過程。這些新的圖像是使用現(xiàn)有的訓(xùn)練圖像生成的,因此我們不必手動收集它們。
不同的圖像增強(qiáng)技術(shù)
我們可以使用各種技術(shù)來增強(qiáng)圖像。例如:
空間增強(qiáng)
· 縮放
· 翻轉(zhuǎn)
· 旋轉(zhuǎn)
· 剪切
· 平移
像素增強(qiáng)
· 亮度
· 對比度
· 飽和度
· 色調(diào)
深度學(xué)習(xí)中的圖像增強(qiáng)
在深度學(xué)習(xí)中,數(shù)據(jù)增強(qiáng)是一種常見的做法。因此,每個深度學(xué)習(xí)框架都有自己的增強(qiáng)方法,甚至有一個完整的庫。例如,讓我們看看如何使用 Keras、 PyTorch 和 Albumentations 中的內(nèi)置方法應(yīng)用圖像增強(qiáng)。
1. Keras
Keras 的 ImageDataGenerator 類提供了一種快速簡便的方法來增強(qiáng)圖像。它提供了許多不同的增強(qiáng)技術(shù),如標(biāo)準(zhǔn)化、旋轉(zhuǎn)、移位、翻轉(zhuǎn)、亮度變化等等。使用 Keras 的 ImageDataGenerator 類的主要好處是它旨在提供實時數(shù)據(jù)增強(qiáng)。這意味著它會在您的模型處于訓(xùn)練階段時生成增強(qiáng)圖像。
ImageDataGenerator 類確保模型在每個時期接收圖像的新變化。但它只返回轉(zhuǎn)換后的圖像,并沒有將它們添加到原始圖像數(shù)據(jù)集中(如果加入到原始數(shù)據(jù)集,那么模型將多次處理原始圖像,這肯定會使我們的模型過擬合)。 ImageDataGenerator 的另一個優(yōu)點是它的內(nèi)存占用量很低,這是因為不使用此類,我們一次加載所有圖像。但是在使用它時,我們批量加載圖像,這節(jié)省了大量內(nèi)存。
它支持一系列的圖像增強(qiáng)方法,現(xiàn)在我們將專注于五種主要類型的方法,如下所示:
· 通過 width_shift_range 和 height_shift_range 參數(shù)進(jìn)行圖像位移增強(qiáng)。
· 通過 horizontal_flip 和 vertical_flip 參數(shù)進(jìn)行圖像翻轉(zhuǎn)增強(qiáng)。
· 通過 rotation_range 參數(shù)進(jìn)行圖像旋轉(zhuǎn)增強(qiáng)。
· 通過 brightness_range 參數(shù)進(jìn)行圖像亮度增強(qiáng)。
· 通過 zoom_range 參數(shù)進(jìn)行圖像縮放增強(qiáng)。
如下所示,我們可以構(gòu)造 ImageDataGenerator 類的實例。
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
img=load_img('path_directory/img.jpg')
from numpy import expand_dims
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
# load the image
plt.figure(figsize=(45,30))
# convert to numpy array
data=img_to_array(img)
# expand dimension to one sample
samples=expand_dims(data, 0)
# create image data augmentation generator
datagen=ImageDataGenerator(featurewise_center=True,rotation_range=(0-30),width_shift_range=0.2,
height_shift_range=0.2,brightness_range=[0.5,1.5],
shear_range=0.2, zoom_range=0.2,channel_shift_range=0.2,
horizontal_flip=True, vertical_flip=True,fill_mode='nearest')
# prepare iterator
it=datagen.flow(samples, batch_size=1)
# generate samples and plot
for i in range(6):
# define subplot
plt.subplot(330 + 1 + i)
# generate batch of images
batch=it.next()
# convert to unsigned integers for viewing
image=batch[0].astype('uint8')
# plot raw pixel data
plt.imshow(image)
# show the figure
plt.show()
最終將生成如下所示的隨機(jī)增強(qiáng)圖像,并將其提供給模型。
2.Pytorch
PyTorch 是一個基于 Python 的庫,有助于構(gòu)建深度學(xué)習(xí)模型并在各種應(yīng)用程序中使用它們。但它不僅僅是一個深度學(xué)習(xí)庫,還是一個科學(xué)計算庫。
使用 PyTorch 的主要優(yōu)點是我們可以對選定的圖像單獨應(yīng)用圖像增強(qiáng)技術(shù)。
從導(dǎo)入圖像開始,我們將定義 imshow() 函數(shù)來可視化實際和轉(zhuǎn)換后的圖像。
縮放:在縮放或調(diào)整大小時,將圖像調(diào)整為給定的大小。
# scaling
loader_transform=transforms.Resize((500,500))
imshow('path_directory/img.jpg', loader_transform)
裁剪:在裁剪中,選擇圖像的一部分,例如在給定的示例中,返回中心裁剪的圖像。
# cropping
loader_transform=transforms.CenterCrop(size=(600,600))
imshow('path_diectory/img.jpg', loader_transform)
翻轉(zhuǎn):在翻轉(zhuǎn)時,圖像被水平或垂直翻轉(zhuǎn)。
# horizontal flip with probability 1 (default is 0.5)
loader_transform=transforms.RandomHorizontalFlip(p=1)
imshow('path_directory/img.jpg', loader_transform)
像素增強(qiáng):像素增強(qiáng)是通過更改圖像的像素值來改變圖像的顏色屬性。
img=PIL.Image.open('path_directory/img.jpg')
fig, ax=plt.subplots(2, 2, figsize=(16, 10))
# brightness
loader_transform1=transforms.ColorJitter(brightness=2)
img1=loader_transform1(img)
ax[0, 0].set_title(f'brightness')
ax[0, 0].imshow(img1)
# contrast
loader_transform2=transforms.ColorJitter(contrast=2)
img2=loader_transform2(img)
ax[0, 1].set_title(f'contrast')
ax[0, 1].imshow(img2)
# saturation
loader_transform3=transforms.ColorJitter(saturation=2)
img3=loader_transform3(img)
ax[1, 0].set_title(f'saturation')
ax[1, 0].imshow(img3)
fig.savefig('color augmentation', bbox_inches='tight')
# hue
loader_transform4=transforms.ColorJitter(hue=0.2)
img4=loader_transform4(img)
ax[1, 1].set_title(f'hue')
ax[1, 1].imshow(img4)
fig.savefig('color augmentation', bbox_inches='tight')
3. Albumentation
Albumentations 是一種計算機(jī)視覺工具,可提高深度卷積神經(jīng)網(wǎng)絡(luò)的性能。Albumentations 是一個 Python 庫,用于快速靈活的圖像增強(qiáng)。它有效地實現(xiàn)了豐富多樣的圖像變換操作,這些操作針對性能進(jìn)行了優(yōu)化,同時為不同的計算機(jī)視覺任務(wù)提供簡潔而強(qiáng)大的圖像增強(qiáng)接口,包括對象分類、分割和檢測。
#importing all required libraries
import cv2
import random
from matplotlib import pyplot as plt
import albumentations as A
image=cv2.imread('/content/drive/MyDrive/sunil.jpg')
image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
翻轉(zhuǎn)
transform=A.HorizontalFlip(p=0.5)
random.seed(7)
augmented_image=transform(image=image)['image']
plt.imshow(augmented_image)
旋轉(zhuǎn)
transform=A.ShiftScaleRotate(p=0.5)
random.seed(7)
augmented_image=transform(image=image)['image']
plt.imshow(augmented_image)
組合增強(qiáng)
transform=A.Compose([
A.RandomCrop(width=500, height=500),
A.RandomBrightnessContrast(p=0.2),
])
random.seed(7)
augmented_image=transform(image=image)['image']
plt.imshow(augmented_image)
總結(jié)
在本文中,我們了解了如何在訓(xùn)練深度學(xué)習(xí)神經(jīng)網(wǎng)絡(luò)時使用圖像數(shù)據(jù)增強(qiáng)。了解如何將圖像增強(qiáng)技術(shù)應(yīng)用于擴(kuò)展訓(xùn)練數(shù)據(jù)集,以提高模型的性能和泛化能力。并且知道如何使用 Keras、Pytorch 和 Albumentation 庫來對圖像進(jìn)行數(shù)據(jù)增強(qiáng)。