您的位置:首頁>科技>正文

用Keras+TF,實現ImageNet資料集日常物件的識別

王新民 編譯自 Deep Learning Sandbox博客

在電腦視覺領域裡, 有3個最受歡迎且影響非常大的學術競賽:ImageNet ILSVRC(大規模視覺識別挑戰賽), PASCAL VOC(關於模式分析, 統計建模和計算學習的研究)和微軟COCO圖像識別大賽。

這些比賽大大地推動了在電腦視覺研究中的多項發明和創新, 其中很多都是免費開源的。

博客Deep Learning Sandbox作者Greg Chu打算通過一篇文章, 教你用Keras和TensorFlow, 實現對ImageNet資料集中日常物體的識別。

量子位翻譯了這篇文章:

你想識別什麼?

看看ILSVRC競賽中包含的物體物件。 如果你要研究的物體物件是該清單1001個物件中的一個, 運氣真好, 可以獲得大量該類別圖像資料!以下是這個資料集包含的部分類別:

狗熊椅子汽車鍵盤箱子嬰兒床旗杆iPod播放機輪船麵包車項鍊降落傘枕頭桌子錢包球拍步槍校車薩克斯管足球襪子舞臺火爐火把吸塵器自動售貨機眼鏡紅綠燈菜肴盤子西蘭花紅酒

△ 表1 ImageNet ILSVRC的類別摘錄

完整類別列表見:https://gist.github.com/gregchu/134677e041cd78639fea84e3e619415b

如果你研究的物體物件不在該清單中, 或者像醫學圖像分析中具有多種差異較大的背景, 遇到這些情況該怎麼辦?可以借助遷移學習(transfer learning)和微調(fine-tuning), 我們以後再另外寫文章講。

圖像識別

圖像識別, 或者說物體識別是什麼?它回答了一個問題:“這張圖像中描繪了哪幾個物體物件?”如果你研究的是基於圖像內容進行標記, 確定盤子上的食物類型, 對癌症患者或非癌症患者的醫學圖像進行分類, 以及更多的實際應用, 那麼就能用到圖像識別。

Keras和TensorFlow

Keras是一個高級神經網路庫, 能夠作為一種簡單好用的抽象層, 接入到數值計算庫TensorFlow中。 另外, 它可以通過其keras.applications模組獲取在ILSVRC競賽中獲勝的多個卷積網路模型, 如由Microsoft Research開發的ResNet50網路和由Google Research開發的InceptionV3網路,

這一切都是免費和開源的。 具體安裝參照以下說明進行操作:

Keras安裝:https://keras.io/#installation

TensorFlow安裝:https://www.tensorflow.org/install/

實現過程

我們的最終目標是編寫一個簡單的python程式, 只需要輸入本地影像檔的路徑或是圖像的URL連結就能實現物體識別。

以下是輸入非洲大象照片的示例:

1. python classify.py --image African_Bush_Elephant.jpg

2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg

輸入:

輸出將如下所示:

△ 該圖像最可能的前3種預測類別及其相應概率

預測功能

我們接下來要載入ResNet50網路模型。 首先, 要載入keras.preprocessing和keras.applications.resnet50模組, 並使用在ImageNet ILSVRC比賽中已經訓練好的權重。

想瞭解ResNet50的原理, 可以閱讀論文《基於深度殘差網路的圖像識別》。 地址:https://arxiv.org/pdf/1512.03385.pdf

import numpy as np

from keras.preprocessing import image

from keras.applications.resnet50

import ResNet50, preprocess_input, decode_predictions model = ResNet50(weights='imagenet')

接下來定義一個預測函數:

def predict(model, img, target_size, top_n=3): """Run model prediction on image Args: model: keras model img: PIL format image target_size: (width, height) tuple top_n: # of top predictions to return Returns: list of predicted labels and their probabilities """ if img.size != target_size: img = img.resize(target_size) x = image.img_to_array(img) x = np.expand_dims(x, axis=0) x = preprocess_input(x) preds = model.predict(x)

return decode_predictions(preds, top=top_n)[0]

在使用ResNet50網路結構時需要注意, 輸入大小target_size必須等於(224,224)。 許多CNN網路結構具有固定的輸入大小, ResNet50正是其中之一, 作者將輸入大小定為(224,224)。

image.img_to_array:將PIL格式的圖像轉換為numpy陣列。

np.expand_dims:將我們的(3, 224, 224)大小的圖像轉換為(1, 3, 224, 224)。 因為model.predict函數需要4維陣列作為輸入, 其中第4維為每批預測圖像的數量。 這也就是說, 我們可以一次性分類多個圖像。

preprocess_input:使用訓練資料集中的平均通道值對圖像資料進行零值處理, 即使得圖像所有點的和為0。

這是非常重要的步驟, 如果跳過, 將大大影響實際預測效果。 這個步驟稱為資料歸一化。

model.predict:對我們的資料分批次處理並返回預測值。

decode_predictions:採用與model.predict函數相同的編碼標籤, 並從ImageNet ILSVRC集返回可讀的標籤。

keras.applications模組還提供4種結構:ResNet50、InceptionV3、VGG16、VGG19和XCeption, 你可以用其中任何一種替換ResNet50。 更多資訊可以參考https://keras.io/applications/。

繪圖

我們可以使用matplotlib函式程式庫將預測結果做成柱狀圖, 如下所示:

def plot_preds(image, preds):主體部分

為了實現以下從網路中載入圖片的功能:

1. python classify.py --image African_Bush_Elephant.jpg

2. python classify.py --image_url http://i.imgur.com/wpxMwsR.jpg

我們將定義主函數如下:

if __name__=="__main__": a = argparse.ArgumentParser() a.add_argument("--image",

help="path to image") a.add_argument("--image_url",

help="url to image") args = a.parse_args()

if args.image is None and args.image_url is None: a.print_help() sys.exit(1)

if args.image is not None: img = Image.open(args.image) print_preds(predict(model, img, target_size))

if args.image_url is not None: response = requests.get(args.image_url) img = Image.open(BytesIO(response.content)) print_preds(predict(model, img, target_size))

其中在寫入image_url完工

將上述代碼組合起來, 你就創建了一個圖像識別系統。 專案的完整程式和示例圖像請查看GitHub連結:

https://github.com/DeepLearningSandbox/DeepLearningSandbox/tree/master/image_recognition

招聘

我們正在招募編輯記者、運營等崗位, 工作地點在北京中關村, 期待你的到來, 一起體驗人工智慧的風起雲湧。

One More Thing…

一起體驗人工智慧的風起雲湧。

One More Thing…

Next Article
喜欢就按个赞吧!!!
点击关闭提示