網(wǎng)站首頁(yè) 編程語(yǔ)言 正文
如何將pytorch模型部署到安卓上
這篇文章演示如何將訓(xùn)練好的pytorch模型部署到安卓設(shè)備上。我也是剛開始學(xué)安卓,代碼寫的簡(jiǎn)單。
環(huán)境:
pytorch版本:1.10.0
模型轉(zhuǎn)化
pytorch_android支持的模型是.pt模型,我們訓(xùn)練出來的模型是.pth。所以需要轉(zhuǎn)化才可以用。先看官網(wǎng)上給的轉(zhuǎn)化方式:
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")
這個(gè)模型在安卓對(duì)應(yīng)的包:
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}
注:pytorch_android_lite版本和轉(zhuǎn)化模型用的版本要一致,不一致就會(huì)報(bào)各種錯(cuò)誤。
目前用這種方法有點(diǎn)問題,我采用的另一種方法。
轉(zhuǎn)化代碼如下:
import torch
import torch.utils.data.distributed
# pytorch環(huán)境中
model_pth = 'model_31_0.96.pth' #模型的參數(shù)文件
mobile_pt ='model.pt' # 將模型保存為Android可以調(diào)用的文件
model = torch.load(model_pth)
model.eval() # 模型設(shè)為評(píng)估模式
device = torch.device('cpu')
model.to(device)
# 1張3通道224*224的圖片
input_tensor = torch.rand(1, 3, 224, 224) # 設(shè)定輸入數(shù)據(jù)格式
mobile = torch.jit.trace(model, input_tensor) # 模型轉(zhuǎn)化
mobile.save(mobile_pt) # 保存文件
對(duì)應(yīng)的包:
//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
定義模型文件和轉(zhuǎn)化后的文件路徑。
load模型。這里要注意,如果保存模型
torch.save(model,'models.pth')
加載模型則是
model=torch.load('models.pth')
如果保存模型是
torch.save(model.state_dict(),"models.pth")
加載模型則是
model.load_state_dict(torch.load('models.pth'))
定義輸入數(shù)據(jù)格式。
模型轉(zhuǎn)化,然后再保存模型。
安卓部署
新建項(xiàng)目
新建安卓項(xiàng)目,選擇Empy Activity,然后選擇Next
然后,填寫項(xiàng)目信息,選擇安卓版本,我用的4.4,點(diǎn)擊完成
導(dǎo)入包
導(dǎo)入pytorch_android的包
//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
如果有參數(shù)報(bào)錯(cuò)請(qǐng)參照我的完整的配置,代碼如下:
plugins {
id 'com.android.application'
}
android {
compileSdk 32
defaultConfig {
applicationId "com.example.myapplication"
minSdk 21
targetSdk 32
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation 'androidx.appcompat:appcompat:1.3.0'
implementation 'com.google.android.material:material:1.4.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.3'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}
頁(yè)面文件
頁(yè)面的配置如下:
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<ImageView
android:id="@+id/image"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:scaleType="fitCenter" />
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_gravity="top"
android:textSize="24sp"
android:background="#80000000"
android:textColor="@android:color/holo_red_light" />
</FrameLayout>
這個(gè)頁(yè)面只有兩個(gè)空間,一個(gè)展示圖片,一個(gè)顯示文字。
模型推理
新增assets文件夾,然后將轉(zhuǎn)化的模型和待測(cè)試的圖片放進(jìn)去。
新增ImageNetClasses類,這個(gè)類存放類別名字。
代碼如下:
package com.example.myapplication;
public class ImageNetClasses {
public static String[] IMAGENET_CLASSES = new String[]{
"Black-grass",
"Charlock",
"Cleavers",
"Common Chickweed",
"Common wheat",
"Fat Hen",
"Loose Silky-bent",
"Maize",
"Scentless Mayweed",
"Shepherds Purse",
"Small-flowered Cranesbill",
"Sugar beet",
};
}
在MainActivity類中,增加模型推理的邏輯。完成代碼如下:
package com.example.myapplication;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Bitmap bitmap = null;
Module module = null;
try {
// creating bitmap from packaged into app android asset 'image.jpg',
// app/src/main/assets/image.jpg
bitmap = BitmapFactory.decodeStream(getAssets().open("1.png"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module = Module.load(assetFilePath(this, "models.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
}
// showing image on UI
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
// running the model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// getting tensor content as java array of floats
final float[] scores = outputTensor.getDataAsFloatArray();
// searching for the index with maximum score
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
System.out.println(maxScoreIdx);
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI
TextView textView = findViewById(R.id.text);
textView.setText(className);
}
/**
* Copies specified asset to the file in /files app directory and returns this file absolute path.
*
* @return absolute file path
*/
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
然后運(yùn)行。
原文鏈接:https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/122860445
相關(guān)推薦
- 2023-04-19 nginx: [error] CreateFile() “D:\nginx-1.21.6/logs/
- 2022-10-21 Go錯(cuò)誤和異常CGO?fallthrough處理教程詳解_Golang
- 2022-04-10 C#實(shí)現(xiàn)泛型動(dòng)態(tài)循環(huán)數(shù)組隊(duì)列的方法_C#教程
- 2022-01-14 函數(shù)的防抖和節(jié)流&&深淺克隆
- 2022-11-12 python中validators庫(kù)的使用方法詳解_python
- 2023-03-17 python?函數(shù)、變量中單下劃線和雙下劃線的區(qū)別詳解_python
- 2022-01-17 報(bào)錯(cuò):是否需要更改目標(biāo)庫(kù)?請(qǐng)嘗試將lib編譯器選項(xiàng)更改為es2015或更高版本
- 2022-10-13 python中arrow庫(kù)用法大全_python
- 最近更新
-
- window11 系統(tǒng)安裝 yarn
- 超詳細(xì)win安裝深度學(xué)習(xí)環(huán)境2025年最新版(
- Linux 中運(yùn)行的top命令 怎么退出?
- MySQL 中decimal 的用法? 存儲(chǔ)小
- get 、set 、toString 方法的使
- @Resource和 @Autowired注解
- Java基礎(chǔ)操作-- 運(yùn)算符,流程控制 Flo
- 1. Int 和Integer 的區(qū)別,Jav
- spring @retryable不生效的一種
- Spring Security之認(rèn)證信息的處理
- Spring Security之認(rèn)證過濾器
- Spring Security概述快速入門
- Spring Security之配置體系
- 【SpringBoot】SpringCache
- Spring Security之基于方法配置權(quán)
- redisson分布式鎖中waittime的設(shè)
- maven:解決release錯(cuò)誤:Artif
- restTemplate使用總結(jié)
- Spring Security之安全異常處理
- MybatisPlus優(yōu)雅實(shí)現(xiàn)加密?
- Spring ioc容器與Bean的生命周期。
- 【探索SpringCloud】服務(wù)發(fā)現(xiàn)-Nac
- Spring Security之基于HttpR
- Redis 底層數(shù)據(jù)結(jié)構(gòu)-簡(jiǎn)單動(dòng)態(tài)字符串(SD
- arthas操作spring被代理目標(biāo)對(duì)象命令
- Spring中的單例模式應(yīng)用詳解
- 聊聊消息隊(duì)列,發(fā)送消息的4種方式
- bootspring第三方資源配置管理
- GIT同步修改后的遠(yuǎn)程分支