PyTorch图像分割模型安卓部署

软件及版本信息

在PyTorch官网中有如下建议

We recommend you to open this project in Android Studio 3.5.1+. At the moment PyTorch Android and demo applications use android gradle plugin of version 3.5.0, which is supported only by Android Studio version 3.5.1 and higher. Using Android Studio you will be able to install Android NDK and Android SDK with Android Studio UI.

建议的Android Studio版本应该在3.5.1及以上,这里推荐使用最新版。

建议的自动化构建工具android gradle plugin of version版本应该为3.5.0,请注意,不要将gradle的版本升级至4.0.1,在开发过程中发现该版本不太稳定。

模型修改及依赖导入

我们通过model.save保存的模型不能直接在Java程序中使用,需要使用torch.jit.trace(model, example)命令将模型从Java程序不可读取转为可读取。请注意这一步如果未处理会导致读取文件失败。

1
2
3
4
5
6
7
8
import torch
import torchvision

model = torchvision.models.resnet18(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("app/src/main/assets/model.pt")

在安卓程序build.gradle配置中加入如下依赖,请注意这里的org.pytorch:pytorch_androidorg.pytorch:pytorch_android_torchvision需要更新到最新版。即下图的设置是无法读取PyTorch版本大于1.4的模型。

1
2
3
4
5
6
7
8
repositories {
jcenter()
}

dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}

预处理图像获取

在安卓程序中我们一般通过两种方式来获取图像,分别是通过调用相机和通过调用相册。

调用相机

通过MediaStore.ACTION_IMAGE_CAPTURE来调用系统相机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
int cameraRequestCode = 001;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Button capture = findViewById(R.id.capture);
capture.setOnClickListener(new View.OnClickListener(){
@Override
public void onClick(View view){
Intent cameraIntent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
startActivityForResult(cameraIntent,cameraRequestCode);
}
});
}

在匹配结果后通过data.getExtras()获取相机返回的android.graphics.Bitmap对象

1
2
3
4
5
6
7
8
9
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data) {
super.onActivityResult(requestCode, resultCode, data);
if (requestCode == cameraRequestCode && resultCode == RESULT_OK) {
Intent resultView = new Intent(this, Result.class);
resultView.putExtra("imagedata", data.getExtras());
startActivity(resultView);
}
}

调用系统相册

根据安卓版本不同,权限管理之间的差异,这一步比较复杂,需要获取照片真实地址,略。

模型读取和图像处理

预处理

第一步、通过上述两种方式获取图像。

第二步、将获取到的图像转为android.graphics.Bitmap对象

As a first step we read image.jpg to android.graphics.Bitmap using the standard Android API.

第三部、模型读取:

1、在Utils工具类中增加如下静态方法assetFilePath,这个方法的作用是通过文件名使用文件流的方式读取模型文件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
public static String assetFilePath(Context context, String assetName) {
File file = new File(context.getFilesDir(), assetName);

try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
} catch (IOException e) {
Log.e("pytorchandroid", "Error process asset " + assetName + " to file path");
}
return null;
}

2、读取模型,请注意,这里的assetNameassert文件夹下的模型文件名称。

1
2
3
4
Module module = null;
final String moduleFileAbsoluteFilePath = new File(
Objects.requireNonNull(Utils.assetFilePath(this, "example.pt"))).getAbsolutePath();
module = Module.load(moduleFileAbsoluteFilePath);

模型计算

第四步、预处理为输入Tensor对象

1
2
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

第五步、调用模型得到输出Tensor对象

1
2
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();

第六步、将输出Tensor对象重新转为android.graphics.Bitmap对象

1
2
3
4
5
6
7
8
9
10
11
12
final float[] scores = outputTensor.getDataAsFloatArray();


Bitmap test = Bitmap.createBitmap(256, 256, Bitmap.Config.ARGB_8888);

int[] pixels = new int[256 * 256];
for (int i = 0; i < 256 * 256; ++i) {
//关键代码,生产灰度图
pixels[i] = (int) (scores[i] * 50);

}
test.setPixels(pixels, 0, 256, 0, 0, 256, 256);

输出UI设计

主要需要注意的部分是android:adjustViewBounds="true"将两个ImageView控件都设置为相同位置,相同大小的正方形,这样结果部分的图像才能重合。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:background="@android:color/black">

<ImageView
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:adjustViewBounds="true"
android:src="@drawable/ic_launcher_background"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent"
app:layout_constraintBottom_toBottomOf="parent"
android:id="@+id/image"
/>

<ImageView
android:id="@+id/image2"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:adjustViewBounds="true"
android:src="@drawable/ic_launcher_background"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.0"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />

</androidx.constraintlayout.widget.ConstraintLayout>