tio-boot 整合 Paddle Structure
一、依赖
确保下面这些依赖:
<dependency>
<groupId>org.openpnp</groupId>
<artifactId>opencv</artifactId>
<version>4.7.0-0</version>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<!--
<version>1.16.1</version>
-->
<version>1.19.2</version>
<!--
<version>1.23.2</version>
-->
</dependency>
<dependency>
<groupId>com.alibaba.fastjson2</groupId>
<artifactId>fastjson2</artifactId>
<version>2.0.52</version>
</dependency>
<dependency>
<groupId>org.yaml</groupId>
<artifactId>snakeyaml</artifactId>
<version>2.2</version>
</dependency>
二、实体类
1、DetectionResult
如果你项目里已经有这个类,把 setBbox 修好即可。
package nexus.io.cv.model;
/**
 * One detected layout region.
 *
 * <p>The fields stay public so existing field-based access keeps working. The
 * constructors and accessors are provided as well because callers in this
 * project build instances with the four-argument constructor and read the box
 * through {@link #getBbox()}.
 */
public class DetectionResult {
  /** Human-readable class name of the detected region. */
  public String label;
  /** Numeric class id reported by the model. */
  public Integer clsId;
  /** Bounding box as {@code [x1, y1, x2, y2]}. */
  public float[] bbox;
  /** Detection score reported by the model. */
  public float confidence;

  /** Creates an empty result; needed by frameworks that instantiate reflectively. */
  public DetectionResult() {
  }

  /**
   * Creates a fully populated result.
   *
   * @param label      class name of the region
   * @param clsId      numeric class id
   * @param bbox       box as {@code [x1, y1, x2, y2]}; stored by reference, not copied
   * @param confidence detection score
   */
  public DetectionResult(String label, Integer clsId, float[] bbox, float confidence) {
    this.label = label;
    this.clsId = clsId;
    this.bbox = bbox;
    this.confidence = confidence;
  }

  public String getLabel() {
    return label;
  }

  public void setLabel(String label) {
    this.label = label;
  }

  public Integer getClsId() {
    return clsId;
  }

  public void setClsId(Integer clsId) {
    this.clsId = clsId;
  }

  /**
   * Returns the live box array (no defensive copy): callers adjust the
   * coordinates in place through this reference.
   */
  public float[] getBbox() {
    return bbox;
  }

  public void setBbox(float[] bbox) {
    this.bbox = bbox;
  }

  public float getConfidence() {
    return confidence;
  }

  public void setConfidence(float confidence) {
    this.confidence = confidence;
  }
}
2、PreprocessResult
package nexus.io.cv.model;
import ai.onnxruntime.OnnxTensor;
/**
 * Holder for the results of image preprocessing.
 *
 * <p>Plain data carrier: every member is a public field and the class defines
 * no methods. The ONNX tensors it references are not closed by this class.
 */
public class PreprocessResult {
// Tensors for the model inputs. PPDocLayoutService.preprocess creates them with
// shapes {1, 3, inputH, inputW}, {1, 2} and {1, 2} respectively.
public OnnxTensor imageTensor;
public OnnxTensor imShapeTensor;
public OnnxTensor scaleFactorTensor;
// NOTE(review): not assigned anywhere in the code shown with this class —
// presumably a single-input tensor for other models; confirm before relying on it.
public OnnxTensor tensor;
// Width and height of the source image before resizing.
public int origW;
public int origH;
// Resize factors: target size divided by original size, per axis.
public float scaleX;
public float scaleY;
// NOTE(review): the fields below are not assigned anywhere in the code shown
// with this class. The names suggest image dimensions (rows/cols/channels),
// letterbox parameters (ratio/dw/dh) and a batch size — TODO confirm against callers.
public int rows;
public int cols;
public int channels;
public double ratio;
public double dw;
public double dh;
public int batchSize;
}
三、PP-DocLayout 服务类
- 默认模型目录通过 EnvUtils.get("pp.doclayout.model_dir", "models/PP-DocLayout_plus-L_infer-onnx") 读取
- 默认置信度通过 EnvUtils 读取
- 复用你已经验证通过的推理逻辑
- 单例 service 可直接给 handler 调用
package nexus.io.cv.service;
import java.nio.FloatBuffer;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.opencv.core.Mat;
import org.opencv.core.MatOfByte;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import nexus.io.cv.model.DetectionResult;
import nexus.io.cv.model.PreprocessResult;
import nexus.io.cv.paddle.utils.ModelLabelUtils;
import nexus.io.tio.utils.environment.EnvUtils;
public class PPDocLayoutService {
private final String modelDir = EnvUtils.get("pp.doclayout.model_dir", "models/PP-DocLayout_plus-L_infer-onnx");
private final String modelFileName = EnvUtils.get("pp.doclayout.model_file", "inference.onnx");
private final String modelYmlName = EnvUtils.get("pp.doclayout.model_yml", "inference.yml");
private final float confThreshold = getFloat("pp.doclayout.conf_threshold", 0.3f);
private final int inputW = EnvUtils.getInt("pp.doclayout.input_width", 800);
private final int inputH = EnvUtils.getInt("pp.doclayout.input_height", 800);
/**
* 当前模型通常已经输出原图坐标
*/
private final boolean outputAlreadyOriginalCoords = EnvUtils.getBoolean("pp.doclayout.output_already_original_coords",
true);
private OrtEnvironment env;
private OrtSession session;
private List<String> labels;
public PPDocLayoutService() {
try {
Path modelPath = Paths.get(modelDir, modelFileName);
Path ymlPath = Paths.get(modelDir, modelYmlName);
this.env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
this.session = env.createSession(modelPath.toString(), options);
this.labels = ModelLabelUtils.loadLabels(ymlPath);
System.out.println("PPDocLayoutService init ok");
System.out.println("modelDir=" + modelDir);
System.out.println("inputs=" + session.getInputInfo().keySet());
System.out.println("outputs=" + session.getOutputInfo().keySet());
System.out.println("labels count=" + labels.size());
} catch (Exception e) {
throw new RuntimeException("初始化 PPDocLayoutService 失败", e);
}
}
public List<DetectionResult> detect(String fileName, byte[] imageBytes) throws Exception {
Mat bgr = bytesToMat(imageBytes);
if (bgr == null || bgr.empty()) {
throw new IllegalArgumentException("无法解析图片: " + fileName);
}
try {
PreprocessResult pre = preprocess(bgr);
List<DetectionResult> detections;
try {
detections = infer(pre);
} finally {
closeQuietly(pre.imageTensor);
closeQuietly(pre.imShapeTensor);
closeQuietly(pre.scaleFactorTensor);
}
if (!outputAlreadyOriginalCoords) {
remapBoxesToOriginal(detections, pre.origW, pre.origH, pre.scaleX, pre.scaleY);
} else {
clampBoxes(detections, pre.origW, pre.origH);
}
return detections;
} finally {
bgr.release();
}
}
public List<String> getLabels() {
return labels;
}
public Map<Integer, String> getLabelMap() {
Map<Integer, String> map = new LinkedHashMap<>();
for (int i = 0; i < labels.size(); i++) {
map.put(i, labels.get(i));
}
return map;
}
private Mat bytesToMat(byte[] imageBytes) {
MatOfByte mob = new MatOfByte(imageBytes);
Mat img = Imgcodecs.imdecode(mob, Imgcodecs.IMREAD_COLOR);
mob.release();
return img;
}
private PreprocessResult preprocess(Mat bgr) throws OrtException {
int origH = bgr.rows();
int origW = bgr.cols();
float scaleY = (float) inputH / (float) origH;
float scaleX = (float) inputW / (float) origW;
Mat resized = new Mat();
Imgproc.resize(bgr, resized, new Size(inputW, inputH), 0, 0, Imgproc.INTER_LINEAR);
Mat rgb = new Mat();
Imgproc.cvtColor(resized, rgb, Imgproc.COLOR_BGR2RGB);
float[] imageData = new float[3 * inputH * inputW];
for (int y = 0; y < inputH; y++) {
for (int x = 0; x < inputW; x++) {
double[] px = rgb.get(y, x);
imageData[0 * inputH * inputW + y * inputW + x] = (float) px[0] / 255.0f;
imageData[1 * inputH * inputW + y * inputW + x] = (float) px[1] / 255.0f;
imageData[2 * inputH * inputW + y * inputW + x] = (float) px[2] / 255.0f;
}
}
float[] imShapeData = new float[] { inputH, inputW };
float[] scaleFactorData = new float[] { scaleY, scaleX };
PreprocessResult r = new PreprocessResult();
r.imageTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(imageData), new long[] { 1, 3, inputH, inputW });
r.imShapeTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(imShapeData), new long[] { 1, 2 });
r.scaleFactorTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(scaleFactorData), new long[] { 1, 2 });
r.origW = origW;
r.origH = origH;
r.scaleX = scaleX;
r.scaleY = scaleY;
resized.release();
rgb.release();
return r;
}
private List<DetectionResult> infer(PreprocessResult pre) throws OrtException {
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put("image", pre.imageTensor);
inputs.put("im_shape", pre.imShapeTensor);
inputs.put("scale_factor", pre.scaleFactorTensor);
try (OrtSession.Result outputs = session.run(inputs)) {
Object out0 = outputs.get("fetch_name_0").orElseThrow().getValue();
Object out1 = outputs.get("fetch_name_1").orElseThrow().getValue();
float[][] boxes = as2DFloatArray(out0);
int[] boxesNum = as1DIntArray(out1);
if (boxes == null) {
throw new IllegalStateException("fetch_name_0 不是预期二维数组");
}
List<DetectionResult> detections = new ArrayList<>();
int validRows = (boxesNum != null && boxesNum.length > 0) ? boxesNum[0] : boxes.length;
validRows = Math.min(validRows, boxes.length);
for (int i = 0; i < validRows; i++) {
float[] row = boxes[i];
if (row == null || row.length < 6) {
continue;
}
int clsId = Math.round(row[0]);
float score = row[1];
float x1 = row[2];
float y1 = row[3];
float x2 = row[4];
float y2 = row[5];
if (clsId < 0) {
continue;
}
if (score < confThreshold) {
continue;
}
if (x2 <= x1 || y2 <= y1) {
continue;
}
String label = clsId < labels.size() ? labels.get(clsId) : ("class_" + clsId);
detections.add(new DetectionResult(label, clsId, new float[] { x1, y1, x2, y2 }, score));
}
return detections;
}
}
private void remapBoxesToOriginal(List<DetectionResult> detections, int origW, int origH, float scaleX,
float scaleY) {
for (DetectionResult d : detections) {
float[] b = d.getBbox();
b[0] = clamp(b[0] / scaleX, 0, origW - 1);
b[1] = clamp(b[1] / scaleY, 0, origH - 1);
b[2] = clamp(b[2] / scaleX, 0, origW - 1);
b[3] = clamp(b[3] / scaleY, 0, origH - 1);
}
}
private void clampBoxes(List<DetectionResult> detections, int origW, int origH) {
for (DetectionResult d : detections) {
float[] b = d.getBbox();
b[0] = clamp(b[0], 0, origW - 1);
b[1] = clamp(b[1], 0, origH - 1);
b[2] = clamp(b[2], 0, origW - 1);
b[3] = clamp(b[3], 0, origH - 1);
}
}
private float[][] as2DFloatArray(Object val) {
if (val instanceof float[][]) {
return (float[][]) val;
}
if (val instanceof float[][][]) {
float[][][] arr = (float[][][]) val;
return arr.length > 0 ? arr[0] : null;
}
return null;
}
private int[] as1DIntArray(Object val) {
if (val instanceof long[]) {
long[] src = (long[]) val;
int[] dst = new int[src.length];
for (int i = 0; i < src.length; i++)
dst[i] = (int) src[i];
return dst;
}
if (val instanceof int[]) {
return (int[]) val;
}
if (val instanceof float[]) {
float[] src = (float[]) val;
int[] dst = new int[src.length];
for (int i = 0; i < src.length; i++)
dst[i] = Math.round(src[i]);
return dst;
}
if (val instanceof long[][]) {
long[][] src = (long[][]) val;
if (src.length == 0)
return new int[0];
int[] dst = new int[src[0].length];
for (int i = 0; i < src[0].length; i++)
dst[i] = (int) src[0][i];
return dst;
}
if (val instanceof int[][]) {
int[][] src = (int[][]) val;
if (src.length == 0)
return new int[0];
return src[0];
}
if (val instanceof float[][]) {
float[][] src = (float[][]) val;
if (src.length == 0)
return new int[0];
int[] dst = new int[src[0].length];
for (int i = 0; i < src[0].length; i++)
dst[i] = Math.round(src[0][i]);
return dst;
}
return null;
}
private float clamp(float v, float min, float max) {
return Math.max(min, Math.min(max, v));
}
private void closeQuietly(AutoCloseable c) {
if (c != null) {
try {
c.close();
} catch (Exception ignored) {
}
}
}
public void shutdown() {
closeQuietly(session);
closeQuietly(env);
}
private float getFloat(String string, float f) {
String str = EnvUtils.getStr(string);
if (str == null) {
return f;
}
return Float.parseFloat(str);
}
}
四、Handler
1、检测接口 Handler
使用 fastjson2 输出 JSON。
import java.util.List;
import com.alibaba.fastjson2.JSON;
import nexus.io.cv.model.DetectionResult;
import nexus.io.cv.service.PPDocLayoutService;
import nexus.io.jfinal.aop.Aop;
import nexus.io.model.body.RespBodyVo;
import nexus.io.model.upload.UploadFile;
import nexus.io.tio.boot.http.TioRequestContext;
import nexus.io.tio.http.common.HttpRequest;
import nexus.io.tio.http.common.HttpResponse;
import nexus.io.tio.http.server.handler.HttpRequestHandler;
import nexus.io.tio.http.server.util.CORSUtils;
/**
 * HTTP handler that runs layout detection on an uploaded image.
 *
 * <p>Expects a multipart field named {@code file} and answers with a
 * {@code RespBodyVo} serialized as JSON.
 */
public class PPDocLayoutHandler implements HttpRequestHandler {
  private final PPDocLayoutService ppDocLayoutService = Aop.get(PPDocLayoutService.class);

  /**
   * Handles one detection request.
   *
   * @param httpRequest request carrying the uploaded {@code file}
   * @return the response with either the detections or a failure message
   * @throws Exception if detection fails for a reason other than bad input
   */
  @Override
  public HttpResponse handle(HttpRequest httpRequest) throws Exception {
    HttpResponse response = TioRequestContext.getResponse();
    CORSUtils.enableCORS(response);

    UploadFile uploadFile = httpRequest.getUploadFile("file");
    if (uploadFile == null) {
      response.setJson(JSON.toJSONString(RespBodyVo.fail("file不能为空")));
      return response;
    }

    String name = uploadFile.getName();
    byte[] data = uploadFile.getData();
    // An upload without content is treated the same as a missing upload.
    if (data == null || data.length == 0) {
      response.setJson(JSON.toJSONString(RespBodyVo.fail("file不能为空")));
      return response;
    }

    List<DetectionResult> results;
    try {
      results = ppDocLayoutService.detect(name, data);
    } catch (IllegalArgumentException e) {
      // The service reports undecodable images this way; answer with a failure
      // body instead of letting a client error surface as a server error.
      response.setJson(JSON.toJSONString(RespBodyVo.fail(e.getMessage())));
      return response;
    }

    response.setJson(JSON.toJSONString(RespBodyVo.ok(results)));
    return response;
  }
}
2、labels 接口 Handler
单独提供一个获取模型所有标签的接口。
import java.util.Map;
import com.alibaba.fastjson2.JSON;
import nexus.io.cv.service.PPDocLayoutService;
import nexus.io.jfinal.aop.Aop;
import nexus.io.model.body.RespBodyVo;
import nexus.io.tio.boot.http.TioRequestContext;
import nexus.io.tio.http.common.HttpRequest;
import nexus.io.tio.http.common.HttpResponse;
import nexus.io.tio.http.server.handler.HttpRequestHandler;
import nexus.io.tio.http.server.util.CORSUtils;
/**
 * HTTP handler that returns the model's class-id to label mapping as JSON.
 */
public class PPDocLayoutLabelsHandler implements HttpRequestHandler {
  private final PPDocLayoutService ppDocLayoutService = Aop.get(PPDocLayoutService.class);

  /**
   * Writes the label map, wrapped in a successful {@code RespBodyVo}, to the
   * current response.
   *
   * @param httpRequest the incoming request; not inspected
   * @return the response for the current request context
   */
  @Override
  public HttpResponse handle(HttpRequest httpRequest) throws Exception {
    HttpResponse resp = TioRequestContext.getResponse();
    CORSUtils.enableCORS(resp);

    Map<Integer, String> idToLabel = ppDocLayoutService.getLabelMap();
    String body = JSON.toJSONString(RespBodyVo.ok(idToLabel));
    resp.setJson(body);
    return resp;
  }
}
3、路由注册
// Create both handlers first, then bind each one to its URL path.
PPDocLayoutHandler ppDocLayoutHandler = new PPDocLayoutHandler();
PPDocLayoutLabelsHandler ppDocLayoutLabelsHandler = new PPDocLayoutLabelsHandler();

r.add("/pp/doc/layout", ppDocLayoutHandler);
r.add("/pp/doc/layout/labels", ppDocLayoutLabelsHandler);
4、建议的配置项
你可以放到环境变量或配置文件里:
pp.doclayout.model_dir=models/PP-DocLayout_plus-L_infer
pp.doclayout.model_file=inference.onnx
pp.doclayout.model_yml=inference.yml
pp.doclayout.conf_threshold=0.3
pp.doclayout.input_width=800
pp.doclayout.input_height=800
pp.doclayout.output_already_original_coords=true
五、接口效果
1. 检测接口
请求:
curl -X POST http://localhost:8080/pp/doc/layout -F "file=@test.png"
返回示例:
{
"ok": true,
"code": 1,
"data": [
{
"label": "image",
"clsId": 1,
"bbox": [156.85117, 0.0, 776.67596, 430.62967],
"confidence": 0.9023752
},
{
"label": "formula",
"clsId": 7,
"bbox": [351.67267, 487.3784, 511.15216, 525.1032],
"confidence": 0.7211184
}
]
}
2. labels 接口
请求:
curl http://localhost:8080/pp/doc/layout/labels
返回示例:
{
"ok": true,
"code": 1,
"data": {
"0": "paragraph_title",
"1": "image",
"2": "text",
"3": "table",
"4": "title",
"5": "header",
"6": "figure_title",
"7": "formula"
}
}
