你知道嗎?在 iOS 設(shè)備上也可以直接訓(xùn)練 LeNet 卷積神經(jīng)網(wǎng)絡(luò),而且性能一點(diǎn)也不差,iPhone 和 iPad 也能化為實(shí)實(shí)在在的生產(chǎn)力。
機(jī)器學(xué)習(xí)要想在移動(dòng)端上應(yīng)用一般分為如下兩個(gè)階段,第一個(gè)階段是訓(xùn)練模型,第二個(gè)階段是部署模型。常規(guī)的做法是在算力強(qiáng)大的 GPU 或 TPU 上對(duì)模型進(jìn)行訓(xùn)練,之后再使用一系列模型壓縮的方法,將其轉(zhuǎn)換為可在移動(dòng)端上運(yùn)行的模型,并與 APP 連通起來(lái)。Core ML 主要解決的就是最后的模型部署這一環(huán)節(jié),它為開發(fā)者提供了一個(gè)便捷的模型轉(zhuǎn)換工具,可以很方便地將訓(xùn)練好的模型轉(zhuǎn)換為 Core ML 類型的模型文件,實(shí)現(xiàn)模型與 APP 數(shù)據(jù)的互通。
以上是常規(guī)的操作。然而,隨著 iOS 設(shè)備計(jì)算性能的提升,坊間不斷產(chǎn)生一些 iPad Pro 算力超過(guò)普通筆記本的言論。于是乎,就出現(xiàn)了這么一位「勇者」,開源了可以直接在 iOS 設(shè)備上訓(xùn)練神經(jīng)網(wǎng)絡(luò)的項(xiàng)目。
項(xiàng)目作者在 macOS、iOS 模擬器和真實(shí)的 iOS 設(shè)備上進(jìn)行了測(cè)試。用 60000 個(gè) MNIST 樣本訓(xùn)練了 10 個(gè) epoch,在模型架構(gòu)與訓(xùn)練參數(shù)完全相同的前提下,使用 Core ML 在 iPhone 11 上訓(xùn)練大概需要 248 秒,在 i7 MacBook Pro 上使用 TensorFlow 2.0 訓(xùn)練需要 158 秒(僅使用 CPU 的情況下),但準(zhǔn)確率都超過(guò)了 0.98。
當(dāng)然,在 248 秒和 158 秒之間還有非常大的差距,但進(jìn)行此項(xiàng)實(shí)驗(yàn)的目的并不是比速度,而是為了探索用移動(dòng)設(shè)備或可穿戴設(shè)備在本地進(jìn)行訓(xùn)練的可行性,因?yàn)檫@些設(shè)備中的數(shù)據(jù)往往比較敏感,而且涉及隱私,本地訓(xùn)練可以提供更好的安全性。
項(xiàng)目地址:https://github.com/JacopoMangiavacchi/MNIST-CoreML-Training
MNIST 數(shù)據(jù)集
在這篇文章中,作者介紹了如何使用 MNIST 數(shù)據(jù)集部署一個(gè)圖像分類模型,值得注意的是,這個(gè) Core ML 模型是在 iOS 設(shè)備上直接訓(xùn)練的,而無(wú)需提前在其他 ML 框架中進(jìn)行訓(xùn)練。
作者在這里使用了一個(gè)很有名的數(shù)據(jù)集——MNIST 手寫數(shù)字?jǐn)?shù)據(jù)集。它提供了 60000 個(gè)訓(xùn)練樣本和 10000 個(gè)測(cè)試樣本,都是從 0 到 9 的 28x28 手寫數(shù)字黑白圖像。
LeNet CNN 架構(gòu)
如果你想了解 CNN 的細(xì)節(jié)和優(yōu)勢(shì),從 LeNet 架構(gòu)著手是一個(gè)再好不過(guò)的起點(diǎn)。LeNet CNN+MNIST 數(shù)據(jù)集的組合是機(jī)器學(xué)習(xí)「訓(xùn)練」的標(biāo)準(zhǔn)組合,簡(jiǎn)直相當(dāng)于深度學(xué)習(xí)圖像分類的「Hello, World」。
這篇文章主要著眼于如何在 iOS 設(shè)備上直接為 MNIST 數(shù)據(jù)集構(gòu)建和訓(xùn)練一個(gè) LeNet CNN 模型。接下來(lái),研究者將把它與基于著名的 ML 框架(如 TensorFlow)的經(jīng)典「Python」實(shí)現(xiàn)方法進(jìn)行比較。
在 Swift 中為 Core ML 的訓(xùn)練準(zhǔn)備數(shù)據(jù)
在討論如何在 Core ML 中創(chuàng)建及訓(xùn)練 LeNet CNN 網(wǎng)絡(luò)之前,我們可以先看一下如何準(zhǔn)備 MNIST 訓(xùn)練數(shù)據(jù),以將其正確地 batch 至 Core ML 運(yùn)行中去。
在下列 Swift 代碼中,訓(xùn)練數(shù)據(jù)的 batch 是專門為 MNIST 數(shù)據(jù)集準(zhǔn)備的,只需將每個(gè)圖像的「像素」值從 0 到 255 的初始范圍歸一化至 0 到 1 之間的「可理解」范圍即可。
為 Core ML 模型(CNN)訓(xùn)練做準(zhǔn)備
處理好訓(xùn)練數(shù)據(jù)的 batch 并將其歸一化之后,現(xiàn)在就可以使用 SwiftCoreMLTools 庫(kù)在 Swift 的 CNN Core ML 模型中進(jìn)行一系列本地化準(zhǔn)備。
在下列的 SwiftCoreMLTools DSL 函數(shù)構(gòu)建器代碼中,還可以查看在相同的情況中如何傳遞至 Core ML 模型中。同時(shí),也包含了基本的訓(xùn)練信息、超參數(shù)等,如損失函數(shù)、優(yōu)化器、學(xué)習(xí)率、epoch 數(shù)、batch size 等等。
使用 Adam 優(yōu)化器訓(xùn)練神經(jīng)網(wǎng)絡(luò),具體參數(shù)如下:
接下來(lái)是構(gòu)建 CNN 網(wǎng)絡(luò),卷積層、激活與池化層定義如下:
再使用一組與前面相同的卷積、激活與池化操作,之后輸入 Flatten 層,再經(jīng)過(guò)兩個(gè)全連接層后使用 Softmax 輸出結(jié)果。
得到的模型
剛剛構(gòu)建的 Core ML 模型有兩個(gè)卷積和最大池化嵌套層,在將數(shù)據(jù)全部壓平之后,連接一個(gè)隱含層,最后是一個(gè)全連接層,經(jīng)過(guò) Softmax 激活后輸出結(jié)果。
基準(zhǔn) TensorFlow 2.0 模型
為了對(duì)結(jié)果進(jìn)行基準(zhǔn)測(cè)試,尤其是運(yùn)行時(shí)間方面的訓(xùn)練效果,作者還使用 TensorFlow 2.0 重新創(chuàng)建了同一 CNN 模型的精確副本。
下方的的 Python 代碼展示了 TF 中的同一模型架構(gòu)和每層 OutPut Shape 的情況:
可以看到,這里的層、層形狀、卷積過(guò)濾器和池大小與使用 SwiftCoreMLTools 庫(kù)在設(shè)備上創(chuàng)建的 Core ML 模型完全相同。
比較結(jié)果
在查看訓(xùn)練執(zhí)行時(shí)間性能之前,首先確保 Core ML 和 TensorFlow 模型都訓(xùn)練了相同的 epoch 數(shù)(10),用相同的超參數(shù)在相同的 10000 張測(cè)試樣本圖像上獲得非常相似的準(zhǔn)確度度量。
從下面的 Python 代碼中可以看出,TensorFlow 模型使用 Adam 優(yōu)化器和分類交叉熵?fù)p失函數(shù)進(jìn)行訓(xùn)練,測(cè)試用例的最終準(zhǔn)確率結(jié)果大于 0.98。
Core ML 模型的結(jié)果如下圖所示,它使用了和 TensorFlow 相同的優(yōu)化器、損失函數(shù)以及訓(xùn)練集和測(cè)試集,可以看到,其識(shí)別準(zhǔn)確率也超過(guò)了 0.98。
(免責(zé)聲明:本網(wǎng)站內(nèi)容主要來(lái)自原創(chuàng)、合作伙伴供稿和第三方自媒體作者投稿,凡在本網(wǎng)站出現(xiàn)的信息,均僅供參考。本網(wǎng)站將盡力確保所提供信息的準(zhǔn)確性及可靠性,但不保證有關(guān)資料的準(zhǔn)確性及可靠性,讀者在使用前請(qǐng)進(jìn)一步核實(shí),并對(duì)任何自主決定的行為負(fù)責(zé)。本網(wǎng)站對(duì)有關(guān)資料所引致的錯(cuò)誤、不確或遺漏,概不負(fù)任何法律責(zé)任。
任何單位或個(gè)人認(rèn)為本網(wǎng)站中的網(wǎng)頁(yè)或鏈接內(nèi)容可能涉嫌侵犯其知識(shí)產(chǎn)權(quán)或存在不實(shí)內(nèi)容時(shí),應(yīng)及時(shí)向本網(wǎng)站提出書面權(quán)利通知或不實(shí)情況說(shuō)明,并提供身份證明、權(quán)屬證明及詳細(xì)侵權(quán)或不實(shí)情況證明。本網(wǎng)站在收到上述法律文件后,將會(huì)依法盡快聯(lián)系相關(guān)文章源頭核實(shí),溝通刪除相關(guān)內(nèi)容或斷開相關(guān)鏈接。 )