您的位置:首頁>科技>正文

TensorFlow極簡教程:創建、保存和恢復機器學習模型

選自Github

機器之心編譯

參與:Jane W、李澤南

TensorFlow 是一個由穀歌發佈的機器學習框架, 在這篇文章中, 我們將闡述 TensorFlow 的一些本質概念。 相信你不會找到比本文更簡單的介紹。

TensorFlow 機器學習範例——Naked Tensor

連結:https://github.com/jostmey/NakedTensor?bare

在每個例子中, 我們用一條直線擬合一些資料。 使用梯度下降(gradient descent)確定最適合資料的線的斜率和 y 截距的值。 如果你不知道梯度下降, 請查看維琪百科:

https://en.wikipedia.org/wiki/Gradient_descent

創建所需的變數後, 資料和線之間的誤差是可以被定義(計算)的。 定義的誤差被嵌入到優化器(optimizer)中。 然後啟動 TensorFlow, 並重複調用優化器。 通過不斷反覆運算最小化誤差來達到資料與直線的最佳擬合。

按照順序閱讀下列腳本:

serial.py

tensor.py

bigdata.py

Serial.py

這個腳本的目的是說明 TensorFlow 模型的基本要點。 這個腳本使你更容易理解模型是如何組合在一起的。 我們使用 for 迴圈來定義資料與線之間的誤差。 由於定義誤差的方式為迴圈, 該腳本以序列化(串列)計算的方式運行。

Tensor.py

這個腳本比 serial.py 更進一步, 雖然實際上這個腳本的代碼行更少。 代碼的結構與之前相同, 唯一不同的是這次使用張量(tensor)操作來定義誤差。 使用張量可以並行(parallel)運行代碼。

每個數據點被看作是來自獨立同分佈的樣本。 因為每個數據點假定是獨立的, 所以計算也是獨立的。 當使用張量時, 每個數據點都在分隔的計算內核上運行。 我們有 8 個資料點, 所以如果你有一個有八個內核的電腦, 它的運行速度應該快八倍。

BigData.py

你現在距離專業水準僅有一個流行語之遙。 我們現在不需要將一條線擬合到 8 個資料點, 而是將一條線擬合到 800 萬個資料點。 歡迎來到大數據時代。

代碼中有兩處主要的修改。 第一點變化是簿記(bookkeeping), 因為所有資料必須使用預留位置(placeholder)而不是實際資料來定義誤差。

在代碼的後半部分, 資料需要通過預留位置饋送(feed)入模型。 第二點變化是, 因為我們的資料量是巨大的, 在給定的任意時間我們僅將一個樣本數據傳入模型。 每次調用梯度下降操作時, 新的資料樣本將被饋送到模型中。 通過對資料集進行抽樣, TensorFlow 不需要一次處理整個資料集。 這樣抽樣的效果出奇的好, 並有理論支持這種方法:

https://en.wikipedia.org/wiki/Stochastic_gradient_descent

理論上需要滿足一些重要的條件, 如步長(step size)必須隨每次反覆運算而縮短。 不管是否滿足條件, 這種方法至少是有效的。

結論

當你運行腳本時, 你可能看到怎樣定義任何你想要的誤差。 它可能是一組圖像和卷積神經網路(convolutional neural network)之間的誤差。 它可能是古典音樂和迴圈神經網路(recurrent neural network)之間的誤差。

它讓你的想像力瘋狂。 一旦定義了誤差, 你就可以使用 TensorFlow 進行嘗試並最小化誤差。

希望你從這個教程中得到啟發。

需求

Python3 (https://www.python.org/)

TensorFlow (https://www.tensorflow.org/)

NumPy (http://www.numpy.org/)

TensorFlow:保存/恢復和混合多重模型

在第一個模型成功建立並訓練之後, 你或許需要瞭解如何保存與恢復這些模型。 繼續之前, 也可以閱讀這個 Tensorflow 小入門:

https://blog.metaflow.fr/tensorflow-a-primer-4b3fa0978be3#.wxlmweb8h

你有必要瞭解這些資訊, 因為瞭解如何保存不同級別的代碼是非常重要的, 這可以避免混亂無序。

如何實際保存和載入

保存(saver)物件

可以使用 Saver 物件處理不同會話(session)中任何與檔案系統有持續資料傳輸的交互。 構造函數(constructor)允許你控制以下 3 個事物:

目標(target):在分散式架構的情況下用於處理計算。 可以指定要計算的 TF 伺服器或「目標」。

圖(graph):你希望會話處理的圖。

對於初學者來說, 棘手的事情是:TF 中總存在一個默認的圖, 其中所有操作的設置都是默認的, 所以你的操作範圍總在一個「默認的圖」中。

配置(config):你可以使用 ConfigProto 配置 TF。 查看本文最後的連結資源以獲取更多詳細資訊。

Saver 可以處理圖的中繼資料和變數資料的保存和載入(又稱恢復)。 它需要知道的唯一的事情是:需要使用哪個圖和變數?

預設情況下, Saver 會處理默認的圖及其所有包含的變數, 但是你可以創建盡可能多的 Saver 來控制你想要的任何圖或子圖的變數。 這裡是一個例子:

import tensorflow as tf

import os

dir = os.path.dirname(os.path.realpath(__file__))

# First, you design your mathematical operations

# We are the default graph scope

# Let's design a variable

v1 = tf.Variable(1. , name="v1")

v2 = tf.Variable(2. , name="v2")

# Let's design an operation

a = tf.add(v1, v2)

# Let's create a Saver object

# By default, the Saver handles every Variables related to the default graph

all_saver = tf.train.Saver()

# But you can precise which vars you want to save under which name

v2_saver = tf.train.Saver({"v2": v2})

# By default the Session handles the default graph and all its included variables

with tf.Session() as sess:

# Init v and v2

sess.run(tf.global_variables_initializer())

# Now v1 holds the value 1.0 and v2 holds the value 2.0

# We can now save all those values

all_saver.save(sess, dir + '/data-all.chkp')

# or saves only v2

v2_saver.save(sess, dir + '/data-v2.chkp')

如果查看你的資料夾, 它實際上每創建 3 個檔調用一次保存操作並創建一個檢查點(checkpoint)檔, 我會在附錄中講述更多的細節。 你可以簡單理解為權重被保存到 .chkp.data 檔中, 你的圖和中繼資料被保存到 .chkp.meta 文件中。

恢復操作和其它中繼資料

一個重要的資訊是,Saver 將保存與你的圖相關聯的任何中繼資料。這意味著載入元檢查點還將恢復與圖相關聯的所有空變數、操作和集合(例如,它將恢復訓練優化器)。

當你恢復一個元檢查點時,實際上是將保存的圖載入到當前默認的圖中。現在你可以通過它來載入任何包含的內容,如張量、操作或集合。

import tensorflow as tf

# Let's load a previously saved meta graph in the default graph

# This function returns a Saver

saver = tf.train.import_meta_graph('results/model.ckpt-1000.meta')

# We can now access the default graph where all our metadata has been loaded

graph = tf.get_default_graph()

# Finally we can retrieve tensors, operations, collections, etc.

global_step_tensor = graph.get_tensor_by_name('loss/global_step:0')

train_op = graph.get_operation_by_name('loss/train_op')

hyperparameters = tf.get_collection('hyperparameters')

恢復權重

請記住,實際的權重只存在於一個會話中。這意味著「恢復」操作必須能夠訪問會話以恢復圖內的權重。理解恢復操作的最好方法是將其簡單地當作一種初始化。

with tf.Session() as sess:

# To initialize values with saved data

saver.restore(sess, 'results/model.ckpt.data-1000-00000-of-00001')

print(sess.run(global_step_tensor)) # returns 1000

在新圖中使用預訓練圖

現在你知道了如何保存和載入,你可能已經明白如何去操作。然而,這裡有一些技巧能夠幫助你走得更快。

一個圖的輸出可以是另一個圖的輸入嗎?

是的,但有一個缺點:我還不知道使梯度流(gradient flow)在圖之間容易傳遞的一種方法,因為你將必須評估第一個圖,獲得結果,並將其饋送到下一個圖。

這樣一直下去是可以的,直到你需要重新訓練第一個圖。在這種情況下,你將需要將輸入梯度饋送到第一個圖的訓練步驟……

我可以在一個圖中混合所有這些不同的圖嗎?

是的,但你需要對命名空間(namespace)倍加小心。好的一點是,這種方法簡化了一切:例如,你可以載入預訓練的 VGG-16,訪問圖中的任何節點,嵌入自己的操作和訓練整個圖!

如果你只想微調(fine-tune)節點,你可以在任意地方停止梯度來避免訓練整個圖。

import tensorflow as tf

# Load the VGG-16 model in the default graph

vgg_saver = tf.train.import_meta_graph(dir + 'gg/resultsgg-16.meta')

# Access the graph

vgg_graph = tf.get_default_graph()

# Retrieve VGG inputs

self.x_plh = vgg_graph.get_tensor_by_name('input:0')

# Choose which node you want to connect your own graph

output_conv =vgg_graph.get_tensor_by_name('conv1_2:0')

# output_conv =vgg_graph.get_tensor_by_name('conv2_2:0')

# output_conv =vgg_graph.get_tensor_by_name('conv3_3:0')

# output_conv =vgg_graph.get_tensor_by_name('conv4_3:0')

# output_conv =vgg_graph.get_tensor_by_name('conv5_3:0')

# Stop the gradient for fine-tuning

output_conv_sg = tf.stop_gradient(output_conv) # It's an identity function

# Build further operations

output_conv_shape = output_conv_sg.get_shape().as_list()

W1 = tf.get_variable('W1', shape=[1, 1, output_conv_shape[3], 32], initializer=tf.random_normal_initializer(stddev=1e-1))

b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1))

z1 = tf.nn.conv2d(output_conv_sg, W1, strides=[1, 1, 1, 1], padding='SAME') + b1

a = tf.nn.relu(z1)

附錄:更多關於 TF 資料生態系統的內容

我們在這裡談論穀歌,他們主要使用內部構建的工具來處理他們的工作,所以資料保存的格式為 ProtoBuff 也是不奇怪的。

協定緩衝區

協定緩衝區(Protocol Buffer/簡寫 Protobufs)是 TF 有效存儲和傳輸資料的常用方式。

我不在這裡詳細介紹它,但可以把它當成一個更快的 JSON 格式,當你在存儲/傳輸時需要節省空間/頻寬,你可以壓縮它。簡而言之,你可以使用 Protobufs 作為:

一種未壓縮的、人性化的文本格式,副檔名為 .pbtxt

一種壓縮的、機器友好的二進位格式,副檔名為 .pb 或根本沒有副檔名

這就像在開發設置中使用 JSON,並且在遷移到生產環境時為了提高效率而壓縮資料一樣。用 Protobufs 可以做更多的事情,如果你有興趣可以查看教程

整潔的小技巧:在張量流中處理 protobufs 的所有操作都有這個表示「協議緩衝區定義」的「_def」尾碼。例如,要載入保存的圖的 protobufs,可以使用函數:tf.import_graph_def。要獲取當前圖作為 protobufs,可以使用:Graph.as_graph_def()。

檔的架構

回到 TF,當保存你的資料時,你會得到 5 種不同類型的檔:

「檢查點」檔

「事件(event)」檔

「文本 protobufs」檔

一些「chkp」檔

一些「元 chkp」檔

現在讓我們休息一下。當你想到,當你在做機器學習時可能會保存什麼?你可以保存模型的架構和與其關聯的學習到的權重。你可能希望在訓練或事件整個訓練架構時保存一些訓練特徵,如模型的損失(loss)和準確率(accuracy)。你可能希望保存超參數和其它操作,以便之後重新啟動訓練或重複實現結果。這正是 TensorFlow 的作用。

在這裡,檢查點文件的三種類型用於存儲模型及其權重有關的壓縮後資料。

檢查點檔只是一個簿記檔,你可以結合使用高級輔助程式載入不同時間保存的 chkp 檔。

元 chkp 檔包含模型的壓縮 Protobufs 圖以及所有與之關聯的中繼資料(集合、學習速率、操作等)。

chkp 檔保存資料(權重)本身(這一個通常是相當大的大小)。

如果你想做一些調試,pbtxt 檔只是模型的非壓縮 Protobufs 圖。

最後,事件檔在 TensorBoard 中存儲了所有你需要用來視覺化模型和訓練時測量的所有資料。這與保存/恢復模型本身無關。

下面讓我們看一下結果資料夾的螢幕截圖:

一些隨機訓練的結果資料夾的螢幕截圖

該模型已經在步驟 433,858,1000 被保存了 3 次。為什麼這些數字看起來像隨機?因為我設定每 S 秒保存一次模型,而不是每 T 次反覆運算後保存。

chkp 檔比元 chkp 檔更大,因為它包含我們模型的權重

pbtxt 檔比元 chkp 檔大一點:它被認為是非壓縮版本!

TF 自帶多個方便的幫助方法,如:

在時間和反覆運算中處理模型的不同檢查點。它如同一個救生員,以防你的機器在訓練結束前崩潰。

注意:TensorFlow 現在發展很快,這些文章目前是基於 1.0.0 版本編寫的。

參考資源

http://stackoverflow.com/questions/38947658/tensorflow-saving-into-loading-a-graph-from-a-file

http://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow?rq=1

http://stackoverflow.com/questions/39468640/tensorflow-freeze-graph-py-the-name-save-const0-refers-to-a-tensor-which-doe?rq=1

http://stackoverflow.com/questions/33759623/tensorflow-how-to-restore-a-previously-saved-model-python

http://stackoverflow.com/questions/34500052/tensorflow-saving-and-restoring-session?noredirect=1&lq=1

http://stackoverflow.com/questions/35687678/using-a-pre-trained-word-embedding-word2vec-or-glove-in-tensorflow

https://github.com/jtoy/awesome-tensorflow

你的圖和中繼資料被保存到 .chkp.meta 文件中。

恢復操作和其它中繼資料

一個重要的資訊是,Saver 將保存與你的圖相關聯的任何中繼資料。這意味著載入元檢查點還將恢復與圖相關聯的所有空變數、操作和集合(例如,它將恢復訓練優化器)。

當你恢復一個元檢查點時,實際上是將保存的圖載入到當前默認的圖中。現在你可以通過它來載入任何包含的內容,如張量、操作或集合。

import tensorflow as tf

# Let's load a previously saved meta graph in the default graph

# This function returns a Saver

saver = tf.train.import_meta_graph('results/model.ckpt-1000.meta')

# We can now access the default graph where all our metadata has been loaded

graph = tf.get_default_graph()

# Finally we can retrieve tensors, operations, collections, etc.

global_step_tensor = graph.get_tensor_by_name('loss/global_step:0')

train_op = graph.get_operation_by_name('loss/train_op')

hyperparameters = tf.get_collection('hyperparameters')

恢復權重

請記住,實際的權重只存在於一個會話中。這意味著「恢復」操作必須能夠訪問會話以恢復圖內的權重。理解恢復操作的最好方法是將其簡單地當作一種初始化。

with tf.Session() as sess:

# To initialize values with saved data

saver.restore(sess, 'results/model.ckpt.data-1000-00000-of-00001')

print(sess.run(global_step_tensor)) # returns 1000

在新圖中使用預訓練圖

現在你知道了如何保存和載入,你可能已經明白如何去操作。然而,這裡有一些技巧能夠幫助你走得更快。

一個圖的輸出可以是另一個圖的輸入嗎?

是的,但有一個缺點:我還不知道使梯度流(gradient flow)在圖之間容易傳遞的一種方法,因為你將必須評估第一個圖,獲得結果,並將其饋送到下一個圖。

這樣一直下去是可以的,直到你需要重新訓練第一個圖。在這種情況下,你將需要將輸入梯度饋送到第一個圖的訓練步驟……

我可以在一個圖中混合所有這些不同的圖嗎?

是的,但你需要對命名空間(namespace)倍加小心。好的一點是,這種方法簡化了一切:例如,你可以載入預訓練的 VGG-16,訪問圖中的任何節點,嵌入自己的操作和訓練整個圖!

如果你只想微調(fine-tune)節點,你可以在任意地方停止梯度來避免訓練整個圖。

import tensorflow as tf

# Load the VGG-16 model in the default graph

vgg_saver = tf.train.import_meta_graph(dir + 'gg/resultsgg-16.meta')

# Access the graph

vgg_graph = tf.get_default_graph()

# Retrieve VGG inputs

self.x_plh = vgg_graph.get_tensor_by_name('input:0')

# Choose which node you want to connect your own graph

output_conv =vgg_graph.get_tensor_by_name('conv1_2:0')

# output_conv =vgg_graph.get_tensor_by_name('conv2_2:0')

# output_conv =vgg_graph.get_tensor_by_name('conv3_3:0')

# output_conv =vgg_graph.get_tensor_by_name('conv4_3:0')

# output_conv =vgg_graph.get_tensor_by_name('conv5_3:0')

# Stop the gradient for fine-tuning

output_conv_sg = tf.stop_gradient(output_conv) # It's an identity function

# Build further operations

output_conv_shape = output_conv_sg.get_shape().as_list()

W1 = tf.get_variable('W1', shape=[1, 1, output_conv_shape[3], 32], initializer=tf.random_normal_initializer(stddev=1e-1))

b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1))

z1 = tf.nn.conv2d(output_conv_sg, W1, strides=[1, 1, 1, 1], padding='SAME') + b1

a = tf.nn.relu(z1)

附錄:更多關於 TF 資料生態系統的內容

我們在這裡談論穀歌,他們主要使用內部構建的工具來處理他們的工作,所以資料保存的格式為 ProtoBuff 也是不奇怪的。

協定緩衝區

協定緩衝區(Protocol Buffer/簡寫 Protobufs)是 TF 有效存儲和傳輸資料的常用方式。

我不在這裡詳細介紹它,但可以把它當成一個更快的 JSON 格式,當你在存儲/傳輸時需要節省空間/頻寬,你可以壓縮它。簡而言之,你可以使用 Protobufs 作為:

一種未壓縮的、人性化的文本格式,副檔名為 .pbtxt

一種壓縮的、機器友好的二進位格式,副檔名為 .pb 或根本沒有副檔名

這就像在開發設置中使用 JSON,並且在遷移到生產環境時為了提高效率而壓縮資料一樣。用 Protobufs 可以做更多的事情,如果你有興趣可以查看教程

整潔的小技巧:在張量流中處理 protobufs 的所有操作都有這個表示「協議緩衝區定義」的「_def」尾碼。例如,要載入保存的圖的 protobufs,可以使用函數:tf.import_graph_def。要獲取當前圖作為 protobufs,可以使用:Graph.as_graph_def()。

檔的架構

回到 TF,當保存你的資料時,你會得到 5 種不同類型的檔:

「檢查點」檔

「事件(event)」檔

「文本 protobufs」檔

一些「chkp」檔

一些「元 chkp」檔

現在讓我們休息一下。當你想到,當你在做機器學習時可能會保存什麼?你可以保存模型的架構和與其關聯的學習到的權重。你可能希望在訓練或事件整個訓練架構時保存一些訓練特徵,如模型的損失(loss)和準確率(accuracy)。你可能希望保存超參數和其它操作,以便之後重新啟動訓練或重複實現結果。這正是 TensorFlow 的作用。

在這裡,檢查點文件的三種類型用於存儲模型及其權重有關的壓縮後資料。

檢查點檔只是一個簿記檔,你可以結合使用高級輔助程式載入不同時間保存的 chkp 檔。

元 chkp 檔包含模型的壓縮 Protobufs 圖以及所有與之關聯的中繼資料(集合、學習速率、操作等)。

chkp 檔保存資料(權重)本身(這一個通常是相當大的大小)。

如果你想做一些調試,pbtxt 檔只是模型的非壓縮 Protobufs 圖。

最後,事件檔在 TensorBoard 中存儲了所有你需要用來視覺化模型和訓練時測量的所有資料。這與保存/恢復模型本身無關。

下面讓我們看一下結果資料夾的螢幕截圖:

一些隨機訓練的結果資料夾的螢幕截圖

該模型已經在步驟 433,858,1000 被保存了 3 次。為什麼這些數字看起來像隨機?因為我設定每 S 秒保存一次模型,而不是每 T 次反覆運算後保存。

chkp 檔比元 chkp 檔更大,因為它包含我們模型的權重

pbtxt 檔比元 chkp 檔大一點:它被認為是非壓縮版本!

TF 自帶多個方便的幫助方法,如:

在時間和反覆運算中處理模型的不同檢查點。它如同一個救生員,以防你的機器在訓練結束前崩潰。

注意:TensorFlow 現在發展很快,這些文章目前是基於 1.0.0 版本編寫的。

參考資源

http://stackoverflow.com/questions/38947658/tensorflow-saving-into-loading-a-graph-from-a-file

http://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow?rq=1

http://stackoverflow.com/questions/39468640/tensorflow-freeze-graph-py-the-name-save-const0-refers-to-a-tensor-which-doe?rq=1

http://stackoverflow.com/questions/33759623/tensorflow-how-to-restore-a-previously-saved-model-python

http://stackoverflow.com/questions/34500052/tensorflow-saving-and-restoring-session?noredirect=1&lq=1

http://stackoverflow.com/questions/35687678/using-a-pre-trained-word-embedding-word2vec-or-glove-in-tensorflow

https://github.com/jtoy/awesome-tensorflow

Next Article
喜欢就按个赞吧!!!
点击关闭提示