您的位置:首頁>育兒>正文

如何用神經網路“尋找威利”(附代碼實現)

《威利在哪裡?》(Where’s Wally)是由英國插畫家馬丁·漢德福特(Martin Handford)創作的一套兒童繪本。 這個書的目標就是在一張人山人海的圖片中找出一個特定的人物——威利(Wally)。 “Where’s Wally”的商標已在28個國家進行了註冊, 為方便語言翻譯, 每一個國家都會給威利起一個新名字, 最成功的是北美版的“Where’s Waldo”, 在這裡, 威利改名成了沃爾多(Waldo)。

現在, 機器學習博主Tadej Magajna另闢蹊徑, 利用深度學習解開“威利在哪裡”的問題。 與傳統的電腦視覺影像處理方法不同的是, 它只使用了少數幾個標記出威利位置的圖片樣本, 就訓練成了一套“尋找威利”的系統。

訓練過的圖像評估模型和檢測腳本發佈在作者的GitHub repo上。

本文介紹了用TensorFlow物體檢測API訓練神經網路、並用相應的Python腳本尋找威利的過程。 大致分為以下幾步:

將圖片打標籤後創建資料集, 其中標籤注明了威利在圖片中的位置, 用x, y表示;

用TensorFlow物體檢測API獲取並配置神經網路模型;

在資料集上訓練模型;

用匯出的圖像測試模型;

開始前, 請確保你已經按照說明安裝了TensorFlow物體檢測API。

創建資料集

雖說深度學習中最重要的環節是處理神經網路, 但不幸的是, 資料科學家們總要花費大量時間準備訓練資料。

最簡單的機器學習問題最終得到的通常是一個標量(如數字檢測器)或是一個分類字串。 TensorFlow物體檢測API在訓練資料是則將上述兩個結果結合了起來。 它由一系列圖像組成, 並包含目標物件的標籤和他們在圖像中的位置。 由於在二維圖像中, 兩個點足以在物件周圍繪製邊界框, 所以圖像的定位只有兩個點。

為了創建訓練集, 我們需要準備一組Where’s Wally的插畫, 並標出威利的位置。 在此之前已經有人做出了一套解出威利在哪裡的訓練集。

最右邊的四列描述了威利所在的位置

創建資料集的最後一步就是將標籤(.csv)和圖片(.jpeg)打包, 存入單一二分類檔中(.tfrecord)。 詳細過程可參考這裡, 訓練和評估過程也可以在作者的GitHub上找到。

準備模型

TensorFlow物體檢測API提供了一組性能不同的模型, 它們要麼精度高, 但速度慢, 要麼速度快, 但精度低。 這些模型都在公開資料集上經過了預訓練。

雖然模型可以從頭開始訓練, 隨機初始化網路權重, 但這可能需要幾周的時間。 相反, 這裡作者採用了一種稱為遷移學習(Transfer Learning)的方法。

這種方法是指, 用一個經常訓練的模型解決一般性問題, 然後再將它重新訓練, 用於解決我們的問題。 也就是說, 與其從頭開始訓練新模型,

不如從預先訓練過的模型中獲取知識, 將其轉移到新模型的訓練中, 這是一種非常節省時間的方法。

作者使用了在COCO資料集上訓練過的搭載Inception v2模型的RCNN。 該模型包含一個.ckpycheckpoint檔, 可以利用它開始訓練。

設定檔下載完成後, 請確保將“PATHTOBE_CONFIGURED”欄位替換成指向checkpoint文件、訓練和評估的.tfrecord檔和標籤映射檔的路徑。

最後需要配置的檔是labels.txt映射檔, 其中包含我們所有不同物件的標籤。 由於我們尋找的都是同一個類型的物件(威利), 所以標籤檔如下:

item { id: 1 name: 'waldo'}

最終應該得到:

一個有著checkpoint檔的預訓練模型;

經過訓練並評估的.tfrecord資料集;

標籤映射文件;

指向上述檔的設定檔。

然後就可以開始訓練啦。

訓練

TensorFlow物體檢測API提供了一個十分容易上手的Python腳本, 可以在本地訓練模型。

它位於models/research/object_detection中, 可以通過以下命令運行:

python train.py --logtostderr --pipeline_config_path= PATH_TO_PIPELINE_CONFIG --train_dir=PATH_TO_TRAIN_DIR

PATH_TO_PIPELINE_CONFIG是通往設定檔的路徑, PATH_TO_TRAIN_DIR是新創建的directory, 用來儲存checkpoint和模型。

train.py的輸出看起來是這樣:

用最重要的資訊查看是否有損失, 這是各個樣本在訓練或驗證時出現錯誤的總和。 當然, 你肯定希望它降得越低越好, 因為如果它在緩慢地下降, 就意味著你的模型正在學習(要麼就是過擬合了你的資料……)。

你還可以用Tensorboard顯示更詳細的訓練資料。

腳本將在一定時間後自動存儲checkpoint檔, 萬一電腦半路崩潰, 你還可以恢復這些檔。 也就是說, 當你想完成模型的訓練時, 隨時都可以終止腳本。

但是什麼時候停止學習呢?一般是當我們的評估集損失停止減少或達到非常低的時候(在這個例子中低於0.01)。

測試

現在, 我們可以將模型用於實際測試啦。

首先,我們需要從儲存的checkpoint中輸出一個推理圖(interference graph),利用的腳本如下:

python export_inference_graph.py — pipeline_config_path PATH_TO_PIPELINE_CONFIG --trained_checkpoint_prefix PATH_TO_CHECPOINT --output_directory OUTPUT_PATH

產生的推理圖就是用來Python腳本用來找到威利的工具。

作者寫了幾個簡單目標定位的腳本,其中find_wally.py和find_wally_pretty.py都可以在他的GitHub上找到,並且運行起來也很簡單:

python find_wally.py

或者

python find_wally_pretty.py

不過當你在自己的模型或圖像上運行腳本時,記得改變model-path和image-path的變數。

結語

模型的表現出乎意料地好。它不僅從資料集中成功地找到了威利,還能在隨機從網上找的圖片中找到威利。

但是如果威利在圖中特別大,模型就找不到了。我們總覺得,不應該是目標物體越大越好找嗎?這樣的結果表明,作者用於訓練的圖像並不多,模型可能對訓練資料過度擬合了。

首先,我們需要從儲存的checkpoint中輸出一個推理圖(interference graph),利用的腳本如下:

python export_inference_graph.py — pipeline_config_path PATH_TO_PIPELINE_CONFIG --trained_checkpoint_prefix PATH_TO_CHECPOINT --output_directory OUTPUT_PATH

產生的推理圖就是用來Python腳本用來找到威利的工具。

作者寫了幾個簡單目標定位的腳本,其中find_wally.py和find_wally_pretty.py都可以在他的GitHub上找到,並且運行起來也很簡單:

python find_wally.py

或者

python find_wally_pretty.py

不過當你在自己的模型或圖像上運行腳本時,記得改變model-path和image-path的變數。

結語

模型的表現出乎意料地好。它不僅從資料集中成功地找到了威利,還能在隨機從網上找的圖片中找到威利。

但是如果威利在圖中特別大,模型就找不到了。我們總覺得,不應該是目標物體越大越好找嗎?這樣的結果表明,作者用於訓練的圖像並不多,模型可能對訓練資料過度擬合了。

Next Article
喜欢就按个赞吧!!!
点击关闭提示