Paddle Structure
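I. Background and Goals
Download the model: https://www.paddleocr.ai/latest/version3.x/pipeline_usage/PP-StructureV3.html#1-pp-structurev3
The goal is to convert the PaddleOCR document layout analysis model:
PP-DocLayout_plus-L
from the Paddle inference format:
inference.pdiparams
inference.json
inference.yml
and deliver:
- an ONNX model (for inference from Java)
- inference in Java with ONNX Runtime
- finally, an HTTP service exposed through tio-boot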
II. Overall Pipeline
The full chain is:
Paddle model
↓
Paddle2ONNX
↓
ONNX model
↓
ONNX Runtime (Java)
↓
Preprocessing (aligned with Paddle)
↓
Inference
↓
Postprocessing (parse boxes)
↓
HTTP service (tio-boot)
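III. Model Conversion (Paddle → ONNX)
1. Recommended approach
Use the conversion capability provided by paddlex:
conda create --name paddlex python=3.12
conda activate paddlex
# First install the dev/nightly paddle build, following the official instructions
python -m pip install --pre paddlepaddle -i https://www.paddlepaddle.org.cn/packages/nightly/cpu/
# Then install PaddleX
python -m pip install paddlex
# Then install the plugin
paddlex --install paddle2onnx
This is the path PaddleOCR officially recommends on Windows.
Then run the conversion: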
paddlex --paddle2onnx --paddle_model_dir PP-DocLayout_plus-L_infer --onnx_model_dir PP-DocLayout_plus-L_infer-onnx
2. Signs of success
- inference.onnx is written to the output directory
- inference.yml is copied over automatically
- the log mentions constant folding
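The first two signs can also be checked from Java before loading the model. The following is a minimal sketch, not part of the original project: the class name ConversionOutputCheck and the "non-empty file" rule are assumptions made for illustration, and the directory name is the one passed to --onnx_model_dir above.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: verifies that the conversion wrote both expected files.
class ConversionOutputCheck {

  /** Returns the names of the expected files that are missing or empty in onnxDir. */
  static List<String> missingFiles(Path onnxDir) throws IOException {
    List<String> missing = new ArrayList<>();
    for (String name : new String[] { "inference.onnx", "inference.yml" }) {
      Path p = onnxDir.resolve(name);
      // A zero-byte file is treated as missing (an assumption of this sketch).
      if (!Files.isRegularFile(p) || Files.size(p) == 0) {
        missing.add(name);
      }
    }
    return missing;
  }

  public static void main(String[] args) throws IOException {
    Path dir = Paths.get(args.length > 0 ? args[0] : "PP-DocLayout_plus-L_infer-onnx");
    List<String> missing = missingFiles(dir);
    System.out.println(missing.isEmpty() ? "conversion output looks complete" : "missing: " + missing);
  }
}
```

The third sign (constant folding in the log) only appears in the console output of the conversion, so this sketch does not cover it.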
inference.yml
mode: paddle
draw_threshold: 0.5
metric: COCO
use_dynamic_shape: false
Global:
  model_name: PP-DocLayout_plus-L
arch: DETR
min_subgraph_size: 3
Preprocess:
- interp: 2
  keep_ratio: false
  target_size:
  - 800
  - 800
  type: Resize
- mean:
  - 0.0
  - 0.0
  - 0.0
  norm_type: none
  std:
  - 1.0
  - 1.0
  - 1.0
  type: NormalizeImage
- type: Permute
label_list:
- paragraph_title
- image
- text
- number
- abstract
- content
- figure_title
- formula
- table
- reference
- doc_title
- footnote
- header
- algorithm
- footer
- seal
- chart
- formula_number
- aside_text
- reference_content
Hpi:
  backend_configs:
    paddle_infer:
      trt_dynamic_shapes: &id001
        im_shape:
        - - 1
          - 2
        - - 1
          - 2
        - - 8
          - 2
        image:
        - - 1
          - 3
          - 800
          - 800
        - - 1
          - 3
          - 800
          - 800
        - - 8
          - 3
          - 800
          - 800
        scale_factor:
        - - 1
          - 2
        - - 1
          - 2
        - - 8
          - 2
      trt_dynamic_shape_input_data:
        im_shape:
        - - 800
          - 800
        - - 800
          - 800
        - - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
          - 800
        scale_factor:
        - - 2
          - 2
        - - 1
          - 1
        - - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
          - 0.67
    tensorrt:
      dynamic_shapes: *id001
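3. Key parameters
- opset=16 is used by default
- it usually does not need to be set by hand
IV. Getting the Labels
Read them from:
inference.yml → label_list
Properties:
- the position in the list is the cls_id
- the list must be used to map results back to label names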
package nexus.io.cv.paddle.utils;

import org.yaml.snakeyaml.Yaml;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class ModelLabelUtils {

  /** Reads label_list from inference.yml; returns an empty list when the key is absent. */
  public static List<String> loadLabelsFromYaml(String yamlPath) throws IOException {
    Yaml yaml = new Yaml();
    try (InputStream in = Files.newInputStream(Paths.get(yamlPath))) {
      Map<String, Object> root = yaml.load(in);
      // yaml.load returns null for an empty document
      if (root == null) {
        return new ArrayList<>();
      }
      Object labelsObj = root.get("label_list");
      if (labelsObj instanceof List<?>) {
        return ((List<?>) labelsObj).stream().map(String::valueOf).collect(Collectors.toList());
      }
      return new ArrayList<>();
    }
  }

  /** Same as loadLabelsFromYaml, but fails fast when the file does not exist. */
  public static List<String> loadLabels(Path yamlPath) throws IOException {
    if (!Files.exists(yamlPath)) {
      throw new IllegalArgumentException("inference.yml does not exist: " + yamlPath);
    }
    Yaml yaml = new Yaml();
    try (InputStream in = Files.newInputStream(yamlPath)) {
      Map<String, Object> root = yaml.load(in);
      if (root == null) {
        return new ArrayList<>();
      }
      Object labelsObj = root.get("label_list");
      if (labelsObj instanceof List<?>) {
        List<String> labels = new ArrayList<>();
        for (Object o : (List<?>) labelsObj) {
          labels.add(String.valueOf(o));
        }
        return labels;
      }
      return new ArrayList<>();
    }
  }

  /** Builds cls_id -> label; the index in label_list is the cls_id. */
  public static Map<Integer, String> loadLabelMap(Path yamlPath) throws IOException {
    List<String> labels = loadLabels(yamlPath);
    Map<Integer, String> map = new LinkedHashMap<>();
    for (int i = 0; i < labels.size(); i++) {
      map.put(i, labels.get(i));
    }
    return map;
  }
}
package nexus.io.cv.paddle.utils;

import java.io.IOException;
import java.nio.file.Paths;
import java.util.Map;

import org.junit.Test;

import nexus.io.tio.utils.json.JsonUtils;

public class ModelLabelUtilsTest {

  @Test
  public void test() {
    String path = "F:\\code\\python\\python-study\\11_python_ai-study\\python-pp-structure-study\\PP-DocLayout_plus-L_infer-onnx\\inference.yml";
    try {
      Map<Integer, String> map = ModelLabelUtils.loadLabelMap(Paths.get(path));
      System.out.println(JsonUtils.toJson(map));
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
}
Output:
{
  "0": "paragraph_title",
  "1": "image",
  "2": "text",
  "3": "number",
  "4": "abstract",
  "5": "content",
  "6": "figure_title",
  "7": "formula",
  "8": "table",
  "9": "reference",
  "10": "doc_title",
  "11": "footnote",
  "12": "header",
  "13": "algorithm",
  "14": "footer",
  "15": "seal",
  "16": "chart",
  "17": "formula_number",
  "18": "aside_text",
  "19": "reference_content"
}
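V. Inference
1) Model input and output structure (core)
1. Inputs
The model does not take a single input; it takes three:
image [N, 3, H, W]
im_shape [N, 2]
scale_factor [N, 2]
In short:
im_shape = the size (H, W) of the image fed to the model
scale_factor = the scaling ratio from the original image to the input image (scaleY, scaleX)
In your code this corresponds to:
im_shape = [800, 800]; // size after resize
scale_factor = [scaleY, scaleX]; // scaling ratio
Purpose:
They tell the model:
1. how large the current image is (im_shape)
2. how it was scaled from the original (scale_factor)
The model needs not only the pixels but also how the image was scaled.
Otherwise:
the detection box coordinates come out wrong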
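To make the two auxiliary inputs concrete, here is a small worked sketch for a fixed 800 x 800 input. The class name ScaleFactorSketch and the 1600 x 1200 page size are made up for illustration; the (H, W) and (scaleY, scaleX) ordering follows the description above.

```java
import java.util.Arrays;

// Illustrative only: computes im_shape and scale_factor for a fixed 800 x 800 input.
class ScaleFactorSketch {
  static final int INPUT_H = 800;
  static final int INPUT_W = 800;

  /** im_shape: the (H, W) of the tensor actually fed to the model. */
  static float[] imShape() {
    return new float[] { INPUT_H, INPUT_W };
  }

  /** scale_factor: (scaleY, scaleX) = resized size / original size. */
  static float[] scaleFactor(int origH, int origW) {
    return new float[] { (float) INPUT_H / origH, (float) INPUT_W / origW };
  }

  /** Maps an x coordinate in the 800 x 800 input back to the original image. */
  static float toOriginalX(float xInput, float[] scaleFactor) {
    return xInput / scaleFactor[1];
  }

  public static void main(String[] args) {
    // Hypothetical page that is 1600 pixels wide and 1200 pixels high.
    float[] sf = scaleFactor(1200, 1600);
    System.out.println("im_shape     = " + Arrays.toString(imShape()));
    System.out.println("scale_factor = " + Arrays.toString(sf));
    System.out.println("x=400 in the input maps to x=" + toOriginalX(400f, sf) + " in the original");
  }
}
```

For this page scaleY is 800/1200 (about 0.667) and scaleX is 800/1600 = 0.5, so a point at x=400 in the resized image corresponds to x=800 in the original.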
2. Outputs
fetch_name_0 → detection boxes
fetch_name_1 → number of boxes per image
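2) The most critical issue: aligning preprocessing with Paddle
This is the biggest pitfall in the whole project.
1. What the Paddle inference.yml really means
Key configuration:
NormalizeImage:
  is_scale: true
  mean: [0,0,0]
  std: [1,1,1]
Note that is_scale is not written in the exported inference.yml shown earlier (it only lists norm_type: none, mean and std); it is implied, because NormalizeImage in PaddleDetection scales by 1/255 by default.
2. What it actually means
It does not mean "no processing"; it means:
you must compute: pixel / 255
3. What happens if you do not divide by 255?
You have already seen it:
- boxes = []
- no boxes in the visualization
- the model seems to "run normally but produce nothing"
The underlying cause:
wrong input distribution → the model's confidence scores collapse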
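The scaling step can be isolated from OpenCV and ONNX Runtime to make it easy to check. The sketch below assumes the pixels are already available as interleaved RGB bytes (HWC); the class name NormalizeSketch and the two-pixel sample image are hypothetical.

```java
import java.util.Arrays;

// Illustrative only: HWC uint8 RGB -> CHW float in [0, 1], i.e. pixel / 255 with mean 0 and std 1.
class NormalizeSketch {

  static float[] hwcToChwScaled(byte[] rgb, int height, int width) {
    int plane = height * width;
    if (rgb.length != 3 * plane) {
      throw new IllegalArgumentException("expected " + (3 * plane) + " bytes, got " + rgb.length);
    }
    float[] chw = new float[3 * plane];
    for (int i = 0; i < plane; i++) {
      for (int c = 0; c < 3; c++) {
        // & 0xFF turns the signed Java byte back into 0..255 before scaling.
        chw[c * plane + i] = (rgb[i * 3 + c] & 0xFF) / 255.0f;
      }
    }
    return chw;
  }

  public static void main(String[] args) {
    // A 1 x 2 image: first pixel (255, 0, 128), second pixel (0, 255, 64).
    byte[] rgb = { (byte) 255, 0, (byte) 128, 0, (byte) 255, 64 };
    // Layout of the result: [R0, R1, G0, G1, B0, B1]
    System.out.println(Arrays.toString(hwcToChwScaled(rgb, 1, 2)));
  }
}
```

A quick sanity check when boxes come back empty is to print the maximum value of the image tensor: if it is well above 1.0, the division by 255 was skipped.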
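3) Parsing the inference outputs
1. fetch_name_0
A 2-D array in which each row is:
[cls_id, score, x1, y1, x2, y2]
2. fetch_name_1
The number of detection boxes for each image
For example:
[3, 1, 4]
means:
- image 1: 3 boxes
- image 2: 1 box
- image 3: 4 boxes
3. What you must do
Use a cursor to split the rows by batch
Otherwise:
results from several images get mixed together → boxes end up on the wrong image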
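The cursor-based split can be sketched as follows. This is a standalone illustration that assumes fetch_name_0 has already been read as a float[][] and fetch_name_1 as an int[]; the class name BatchSplitSketch and the dummy rows are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Illustrative only: splits the flat [totalBoxes, 6] output into one list of rows per image.
class BatchSplitSketch {

  static List<List<float[]>> splitByImage(float[][] boxes, int[] boxesNum) {
    List<List<float[]>> perImage = new ArrayList<>();
    int cursor = 0;
    for (int count : boxesNum) {
      // Guard against negative counts and counts that exceed the rows actually returned.
      int end = Math.min(cursor + Math.max(count, 0), boxes.length);
      perImage.add(new ArrayList<>(Arrays.asList(boxes).subList(cursor, end)));
      cursor = end;
    }
    return perImage;
  }

  public static void main(String[] args) {
    // Eight dummy rows; the first column (cls_id) is set to the row index so the split is visible.
    float[][] boxes = new float[8][6];
    for (int i = 0; i < boxes.length; i++) {
      boxes[i][0] = i;
    }
    List<List<float[]>> perImage = splitByImage(boxes, new int[] { 3, 1, 4 });
    for (int i = 0; i < perImage.size(); i++) {
      System.out.println("image " + (i + 1) + ": " + perImage.get(i).size() + " boxes");
    }
  }
}
```

With boxesNum = [3, 1, 4], rows 0 to 2 go to the first image, row 3 to the second, and rows 4 to 7 to the third.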
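4) Verifying the results
Your final result already matches Python:
Java:
score≈0.902
Python:
score≈0.900
This indicates:
the Java pipeline reproduces the Python result (the scores differ by about 0.002)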
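A comparison like this can be automated with a tolerance check. The sketch below is an assumption about how one might do it, not part of the original project: the class name ScoreCompareSketch and the tolerance values are illustrative, and it presumes both score lists are in the same box order.

```java
// Illustrative only: checks that Java and Python scores agree within a tolerance.
class ScoreCompareSketch {

  static boolean scoresMatch(float[] javaScores, float[] pythonScores, float tolerance) {
    if (javaScores.length != pythonScores.length) {
      return false; // a different number of boxes is already a mismatch
    }
    for (int i = 0; i < javaScores.length; i++) {
      if (Math.abs(javaScores[i] - pythonScores[i]) > tolerance) {
        return false;
      }
    }
    return true;
  }

  public static void main(String[] args) {
    float[] javaScores = { 0.902f };
    float[] pythonScores = { 0.900f };
    // 0.005 is an arbitrary tolerance chosen for this example.
    System.out.println("match within 0.005: " + scoresMatch(javaScores, pythonScores, 0.005f));
  }
}
```

Small differences of this size are typically explained by resize interpolation and floating-point details rather than by a logic error.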
Complete inference code
package nexus.io.cv.paddle.demo;

import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.TreeMap;

import org.opencv.core.Mat;
import org.opencv.core.Point;
import org.opencv.core.Scalar;
import org.opencv.core.Size;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import com.alibaba.fastjson2.JSON;
import com.alibaba.fastjson2.JSONWriter;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import nexus.io.cv.model.DetectionResult;
import nexus.io.cv.model.PreprocessResult;
import nexus.io.cv.model.ProcessResult;

public class PPDocLayoutOnnxDemo {
  private static final int INPUT_W = 800;
  private static final int INPUT_H = 800;
  /**
   * The model usually outputs coordinates that are already in original-image space
   */
  private static final boolean OUTPUT_ALREADY_ORIGINAL_COORDS = true;
  private static final boolean DEBUG_LOG = true;

  // ======================================
  // ================= Main flow: single image =================
  public static ProcessResult processSingleImage(OrtEnvironment env, OrtSession session, Path imgPath,
      List<String> labels, float confThres) throws Exception {
    Mat bgr = Imgcodecs.imread(imgPath.toString());
    if (bgr.empty()) {
      throw new IllegalArgumentException("Cannot read image: " + imgPath);
    }
    // Kept for visualization; the caller owns it once it is returned in ProcessResult
    Mat original = bgr.clone();
    PreprocessResult pre = null;
    try {
      pre = preprocess(env, bgr);
      List<DetectionResult> detections = infer(session, pre, labels, confThres);
      if (!OUTPUT_ALREADY_ORIGINAL_COORDS) {
        remapBoxesToOriginal(detections, pre.origW, pre.origH, pre.scaleX, pre.scaleY);
      } else {
        clampBoxes(detections, pre.origW, pre.origH);
      }
      ProcessResult result = new ProcessResult();
      result.imgPath = imgPath;
      result.originalImage = original;
      result.detections = detections;
      return result;
    } finally {
      // Input tensors hold native memory and must be closed explicitly
      if (pre != null) {
        closeQuietly(pre.imageTensor);
        closeQuietly(pre.imShapeTensor);
        closeQuietly(pre.scaleFactorTensor);
      }
      bgr.release();
    }
  }

  // ================= Preprocessing: single image =================
  private static PreprocessResult preprocess(OrtEnvironment env, Mat bgr) throws OrtException {
    int origH = bgr.rows();
    int origW = bgr.cols();
    float scaleY = (float) INPUT_H / (float) origH;
    float scaleX = (float) INPUT_W / (float) origW;
    Mat resized = new Mat();
    // Note: inference.yml sets interp: 2 (INTER_CUBIC in OpenCV's numbering); this demo uses INTER_LINEAR
    Imgproc.resize(bgr, resized, new Size(INPUT_W, INPUT_H), 0, 0, Imgproc.INTER_LINEAR);
    Mat rgb = new Mat();
    Imgproc.cvtColor(resized, rgb, Imgproc.COLOR_BGR2RGB);
    // HWC uint8 -> CHW float, scaled to [0, 1] (the pixel / 255 step)
    float[] imageData = new float[3 * INPUT_H * INPUT_W];
    for (int y = 0; y < INPUT_H; y++) {
      for (int x = 0; x < INPUT_W; x++) {
        double[] px = rgb.get(y, x);
        imageData[0 * INPUT_H * INPUT_W + y * INPUT_W + x] = (float) px[0] / 255.0f;
        imageData[1 * INPUT_H * INPUT_W + y * INPUT_W + x] = (float) px[1] / 255.0f;
        imageData[2 * INPUT_H * INPUT_W + y * INPUT_W + x] = (float) px[2] / 255.0f;
      }
    }
    float[] imShapeData = new float[] { INPUT_H, INPUT_W };
    float[] scaleFactorData = new float[] { scaleY, scaleX };
    if (DEBUG_LOG) {
      System.out.printf(Locale.US, "preprocess file=%s orig=(%d,%d) scaleFactor=[%.6f, %.6f]%n", "single", origW, origH,
          scaleY, scaleX);
    }
    PreprocessResult r = new PreprocessResult();
    r.imageTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(imageData), new long[] { 1, 3, INPUT_H, INPUT_W });
    r.imShapeTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(imShapeData), new long[] { 1, 2 });
    r.scaleFactorTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(scaleFactorData), new long[] { 1, 2 });
    r.origW = origW;
    r.origH = origH;
    r.scaleX = scaleX;
    r.scaleY = scaleY;
    resized.release();
    rgb.release();
    return r;
  }

  // ================= Inference + postprocessing: single image =================
  private static List<DetectionResult> infer(OrtSession session, PreprocessResult pre, List<String> labels,
      float confThres) throws OrtException {
    Map<String, OnnxTensor> inputs = new HashMap<>();
    inputs.put("image", pre.imageTensor);
    inputs.put("im_shape", pre.imShapeTensor);
    inputs.put("scale_factor", pre.scaleFactorTensor);
    try (OrtSession.Result outputs = session.run(inputs)) {
      Object out0 = outputs.get("fetch_name_0").orElseThrow().getValue();
      Object out1 = outputs.get("fetch_name_1").orElseThrow().getValue();
      if (DEBUG_LOG) {
        System.out.println("fetch_name_0 class = " + out0.getClass().getName());
        System.out.println("fetch_name_1 class = " + out1.getClass().getName());
      }
      float[][] boxes = as2DFloatArray(out0);
      int[] boxesNum = as1DIntArray(out1);
      if (boxes == null) {
        throw new IllegalStateException("fetch_name_0 is not the expected 2-D array");
      }
      if (DEBUG_LOG) {
        System.out
            .println("fetch_name_0 rows = " + boxes.length + ", cols = " + (boxes.length > 0 ? boxes[0].length : 0));
        System.out.println("fetch_name_1 = " + Arrays.toString(boxesNum));
        for (int i = 0; i < Math.min(10, boxes.length); i++) {
          System.out.println("row[" + i + "] = " + Arrays.toString(boxes[i]));
        }
      }
      List<DetectionResult> detections = new ArrayList<>();
      // Batch size is 1 here, so boxesNum[0] is the number of valid rows
      int validRows = (boxesNum != null && boxesNum.length > 0) ? boxesNum[0] : boxes.length;
      validRows = Math.min(validRows, boxes.length);
      for (int i = 0; i < validRows; i++) {
        DetectionResult d = parseDetectionRow(boxes[i], labels, confThres);
        if (d != null) {
          detections.add(d);
        }
      }
      return detections;
    }
  }

  // Row layout: [cls_id, score, x1, y1, x2, y2]; returns null for rows that should be dropped
  private static DetectionResult parseDetectionRow(float[] row, List<String> labels, float confThres) {
    if (row == null || row.length < 6) {
      return null;
    }
    int clsId = Math.round(row[0]);
    float score = row[1];
    float x1 = row[2];
    float y1 = row[3];
    float x2 = row[4];
    float y2 = row[5];
    if (clsId < 0) {
      return null;
    }
    if (score < confThres) {
      return null;
    }
    if (x2 <= x1 || y2 <= y1) {
      return null;
    }
    String label = clsId < labels.size() ? labels.get(clsId) : ("class_" + clsId);
    DetectionResult d = new DetectionResult();
    d.clsId = clsId;
    d.label = label;
    d.confidence = score;
    d.bbox = new float[] { x1, y1, x2, y2 };
    return d;
  }

  // ================= Coordinate handling =================
  private static void remapBoxesToOriginal(List<DetectionResult> detections, int origW, int origH, float scaleX,
      float scaleY) {
    for (DetectionResult d : detections) {
      d.bbox[0] = clamp(d.bbox[0] / scaleX, 0, origW - 1);
      d.bbox[1] = clamp(d.bbox[1] / scaleY, 0, origH - 1);
      d.bbox[2] = clamp(d.bbox[2] / scaleX, 0, origW - 1);
      d.bbox[3] = clamp(d.bbox[3] / scaleY, 0, origH - 1);
    }
  }

  private static void clampBoxes(List<DetectionResult> detections, int origW, int origH) {
    for (DetectionResult d : detections) {
      d.bbox[0] = clamp(d.bbox[0], 0, origW - 1);
      d.bbox[1] = clamp(d.bbox[1], 0, origH - 1);
      d.bbox[2] = clamp(d.bbox[2], 0, origW - 1);
      d.bbox[3] = clamp(d.bbox[3], 0, origH - 1);
    }
  }

  // ================= Saving results =================
  // Note: this method releases the Mat passed in as original once the image is written
  public static void saveVisualized(Mat original, List<DetectionResult> detections, Path imgPath, String outDir) {
    Mat vis = original.clone();
    for (DetectionResult d : detections) {
      Scalar color = colorOfClass(d.clsId);
      Point p1 = new Point(d.bbox[0], d.bbox[1]);
      Point p2 = new Point(d.bbox[2], d.bbox[3]);
      Imgproc.rectangle(vis, p1, p2, color, 2);
      String text = d.label + " " + String.format(Locale.US, "%.2f", d.confidence);
      double textY = Math.max(15, d.bbox[1] - 5);
      Imgproc.putText(vis, text, new Point(d.bbox[0], textY), Imgproc.FONT_HERSHEY_SIMPLEX, 0.7, color, 2);
    }
    Path out = Paths.get(outDir, "vis", imgPath.getFileName().toString());
    Imgcodecs.imwrite(out.toString(), vis);
    vis.release();
    original.release();
  }

  public static void saveJson(List<DetectionResult> detections, Path imgPath, String outDir) throws IOException {
    List<Map<String, Object>> boxes = new ArrayList<>();
    for (DetectionResult d : detections) {
      Map<String, Object> m = new LinkedHashMap<>();
      m.put("cls_id", d.clsId);
      m.put("label", d.label);
      m.put("score", d.confidence);
      m.put("coordinate", Arrays.asList(d.bbox[0], d.bbox[1], d.bbox[2], d.bbox[3]));
      boxes.add(m);
    }
    Map<String, Object> res = new LinkedHashMap<>();
    res.put("input_path", imgPath.toString());
    res.put("page_index", null);
    res.put("boxes", boxes);
    Map<String, Object> root = new LinkedHashMap<>();
    root.put("res", res);
    root.put("debug_detection_count", detections.size());
    root.put("debug_saved_at", System.currentTimeMillis());
    Path out = Paths.get(outDir, "json", getStem(imgPath) + ".json");
    String json = JSON.toJSONString(root, JSONWriter.Feature.PrettyFormat);
    Files.writeString(out, json);
  }

  // Writes one "cls_id cx cy w h" line per box, normalized by the image width and height
  public static void saveYoloTxt(List<DetectionResult> detections, Path imgPath, String outDir) throws IOException {
    Mat img = Imgcodecs.imread(imgPath.toString());
    int w = img.cols();
    int h = img.rows();
    img.release();
    List<String> lines = new ArrayList<>();
    for (DetectionResult d : detections) {
      float x1 = clamp(d.bbox[0], 0, w);
      float y1 = clamp(d.bbox[1], 0, h);
      float x2 = clamp(d.bbox[2], 0, w);
      float y2 = clamp(d.bbox[3], 0, h);
      if (x2 <= x1 || y2 <= y1) {
        continue;
      }
      float cx = ((x1 + x2) / 2.0f) / w;
      float cy = ((y1 + y2) / 2.0f) / h;
      float bw = (x2 - x1) / w;
      float bh = (y2 - y1) / h;
      lines.add(String.format(Locale.US, "%d %.6f %.6f %.6f %.6f", d.clsId, cx, cy, bw, bh));
    }
    Path out = Paths.get(outDir, "labels", getStem(imgPath) + ".txt");
    Files.write(out, lines);
  }

  public static void saveClassesTxt(Map<Integer, String> classMap, Path out) throws IOException {
    if (classMap.isEmpty()) {
      Files.write(out, Collections.emptyList());
      return;
    }
    int maxId = classMap.keySet().stream().max(Integer::compareTo).orElse(0);
    List<String> lines = new ArrayList<>();
    for (int i = 0; i <= maxId; i++) {
      lines.add(classMap.getOrDefault(i, "class_" + i));
    }
    Files.write(out, lines);
  }

  public static Map<Integer, String> toClassMap(List<DetectionResult> detections) {
    Map<Integer, String> map = new TreeMap<>();
    for (DetectionResult d : detections) {
      map.put(d.clsId, d.label);
    }
    return map;
  }

  private static String getStem(Path p) {
    String name = p.getFileName().toString();
    int idx = name.lastIndexOf('.');
    return idx >= 0 ? name.substring(0, idx) : name;
  }

  // Accepts [rows, 6] or [1, rows, 6]; returns null for any other type
  private static float[][] as2DFloatArray(Object val) {
    if (val instanceof float[][]) {
      return (float[][]) val;
    }
    if (val instanceof float[][][]) {
      float[][][] arr = (float[][][]) val;
      return arr.length > 0 ? arr[0] : null;
    }
    return null;
  }

  // Normalizes the box-count output to int[], whatever numeric type and rank the runtime returns
  private static int[] as1DIntArray(Object val) {
    if (val instanceof long[]) {
      long[] src = (long[]) val;
      int[] dst = new int[src.length];
      for (int i = 0; i < src.length; i++) {
        dst[i] = (int) src[i];
      }
      return dst;
    }
    if (val instanceof int[]) {
      return (int[]) val;
    }
    if (val instanceof float[]) {
      float[] src = (float[]) val;
      int[] dst = new int[src.length];
      for (int i = 0; i < src.length; i++) {
        dst[i] = Math.round(src[i]);
      }
      return dst;
    }
    if (val instanceof long[][]) {
      long[][] src = (long[][]) val;
      if (src.length == 0) {
        return new int[0];
      }
      int[] dst = new int[src[0].length];
      for (int i = 0; i < src[0].length; i++) {
        dst[i] = (int) src[0][i];
      }
      return dst;
    }
    if (val instanceof int[][]) {
      int[][] src = (int[][]) val;
      if (src.length == 0) {
        return new int[0];
      }
      return src[0];
    }
    if (val instanceof float[][]) {
      float[][] src = (float[][]) val;
      if (src.length == 0) {
        return new int[0];
      }
      int[] dst = new int[src[0].length];
      for (int i = 0; i < src[0].length; i++) {
        dst[i] = Math.round(src[0][i]);
      }
      return dst;
    }
    return null;
  }

  // Deterministic color per class: the same cls_id always gets the same color
  private static Scalar colorOfClass(int clsId) {
    Random random = new Random(clsId * 2027L + 17);
    return new Scalar(50 + random.nextInt(206), 50 + random.nextInt(206), 50 + random.nextInt(206));
  }

  private static float clamp(float v, float min, float max) {
    return Math.max(min, Math.min(max, v));
  }

  private static void closeQuietly(AutoCloseable c) {
    if (c != null) {
      try {
        c.close();
      } catch (Exception ignored) {
      }
    }
  }
}
package nexus.io.cv.paddle.demo;

import java.io.IOException;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import org.junit.Test;

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import nexus.io.cv.model.DetectionResult;
import nexus.io.cv.model.ProcessResult;
import nexus.io.cv.paddle.utils.ModelLabelUtils;

public class PPDocLayoutOnnxDemoTest {
  private static final Set<String> SUPPORTED_EXTS = new HashSet<>(Arrays.asList(".jpg", ".jpeg", ".png", ".bmp"));

  @Test
  public void test() throws Exception {
    nu.pattern.OpenCV.loadLocally();
    // ================= Configuration =================
    String MODEL_DIR = "F:\\code\\python\\python-study\\11_python_ai-study\\python-pp-structure-study\\PP-DocLayout_plus-L_infer-onnx";
    String MODEL_PATH = MODEL_DIR + "\\inference.onnx";
    String MODEL_YML = MODEL_DIR + "\\inference.yml";
    String INPUT_DIR = "F:\\code\\project\\project-videotutor\\mc-qa-cv\\mc-qa-cv-base\\upload";
    String OUTPUT_DIR = "F:\\code\\project\\project-videotutor\\mc-qa-cv\\mc-qa-cv-base\\output\\java\\single";
    float CONF_THRESHOLD = 0.35f;
    String DEBUG_FILE_NAME = "2027202380864421888.png";
    ensureDir(Paths.get(OUTPUT_DIR));
    ensureDir(Paths.get(OUTPUT_DIR, "vis"));
    ensureDir(Paths.get(OUTPUT_DIR, "json"));
    ensureDir(Paths.get(OUTPUT_DIR, "labels"));
    List<String> labels = ModelLabelUtils.loadLabelsFromYaml(MODEL_YML);
    System.out.println("labels count = " + labels.size());
    Path imgPath = resolveSingleImage(INPUT_DIR, DEBUG_FILE_NAME);
    if (imgPath == null) {
      System.out.println("No image to process");
      return;
    }
    try (OrtEnvironment env = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(MODEL_PATH, options)) {
      System.out.println("load ok");
      System.out.println("inputs: " + session.getInputInfo().keySet());
      System.out.println("outputs: " + session.getOutputInfo().keySet());
      ProcessResult result = PPDocLayoutOnnxDemo.processSingleImage(env, session, imgPath, labels, CONF_THRESHOLD);
      System.out.println("final detections size = " + result.detections.size());
      for (DetectionResult d : result.detections) {
        System.out.printf(Locale.US, "det: cls=%d label=%s score=%.6f box=[%.2f, %.2f, %.2f, %.2f]%n", d.clsId, d.label,
            d.confidence, d.bbox[0], d.bbox[1], d.bbox[2], d.bbox[3]);
      }
      PPDocLayoutOnnxDemo.saveVisualized(result.originalImage, result.detections, imgPath, OUTPUT_DIR);
      PPDocLayoutOnnxDemo.saveJson(result.detections, imgPath, OUTPUT_DIR);
      PPDocLayoutOnnxDemo.saveYoloTxt(result.detections, imgPath, OUTPUT_DIR);
      Map<Integer, String> classMap = PPDocLayoutOnnxDemo.toClassMap(result.detections);
      PPDocLayoutOnnxDemo.saveClassesTxt(classMap, Paths.get(OUTPUT_DIR, "classes.txt"));
      System.out.println();
      System.out.println("========== Done ==========");
      System.out.println("JSON directory: " + Paths.get(OUTPUT_DIR, "json"));
      System.out.println("YOLO directory: " + Paths.get(OUTPUT_DIR, "labels"));
      System.out.println("Visualization directory: " + Paths.get(OUTPUT_DIR, "vis"));
    }
  }

  // Returns the named file if given, otherwise the first supported image in the directory (or null)
  private Path resolveSingleImage(String inputDir, String fileName) throws IOException {
    Path root = Paths.get(inputDir);
    if (!Files.exists(root)) {
      throw new IllegalArgumentException("Input directory does not exist: " + inputDir);
    }
    if (fileName != null && !fileName.trim().isEmpty()) {
      Path p = root.resolve(fileName);
      if (!Files.exists(p) || !Files.isRegularFile(p)) {
        throw new IllegalArgumentException("Specified file does not exist: " + p);
      }
      String ext = getExtLower(p);
      if (!SUPPORTED_EXTS.contains(ext)) {
        throw new IllegalArgumentException("Unsupported image extension: " + ext);
      }
      return p;
    }
    try (DirectoryStream<Path> stream = Files.newDirectoryStream(root)) {
      for (Path p : stream) {
        if (!Files.isRegularFile(p)) {
          continue;
        }
        String ext = getExtLower(p);
        if (SUPPORTED_EXTS.contains(ext)) {
          return p;
        }
      }
    }
    return null;
  }

  private static String getExtLower(Path p) {
    String name = p.getFileName().toString();
    int idx = name.lastIndexOf('.');
    return idx >= 0 ? name.substring(idx).toLowerCase(Locale.ROOT) : "";
  }

  private static void ensureDir(Path path) throws IOException {
    if (!Files.exists(path)) {
      Files.createDirectories(path);
    }
  }
}
Common Errors
1) ONNX Runtime version pitfall
1. Typical error
Unsupported model IR version: 10, max supported IR version: 9
2. Cause
- the ONNX model uses a newer IR version
- the Java ONNX Runtime version is too old
3. Solution
Upgrade ONNX Runtime:
>= 1.19.x (recommended)
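To see which IR version a converted file declares without loading it into a session, the first bytes of the file can be inspected. The sketch below is a heuristic, not an ONNX Runtime API: it relies on ir_version being field 1 of the ONNX ModelProto and assumes the serializer wrote it first, which protobuf serializers normally do but do not guarantee. The class name OnnxIrVersionPeek is hypothetical.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

// Heuristic sketch: reads ir_version only if it is the very first field in the file.
class OnnxIrVersionPeek {

  /** Returns ir_version, or -1 if the file does not start with that field. */
  static long peekIrVersion(Path onnxFile) throws IOException {
    try (InputStream in = Files.newInputStream(onnxFile)) {
      // 0x08 is the protobuf tag for field number 1 with wire type 0 (varint).
      if (in.read() != 0x08) {
        return -1;
      }
      long value = 0;
      for (int shift = 0; shift < 64; shift += 7) {
        int b = in.read();
        if (b < 0) {
          return -1; // file ended in the middle of the varint
        }
        value |= (long) (b & 0x7F) << shift;
        if ((b & 0x80) == 0) {
          return value; // last byte of the varint
        }
      }
      return -1; // malformed varint
    }
  }

  public static void main(String[] args) throws IOException {
    Path model = Paths.get(args.length > 0 ? args[0] : "inference.onnx");
    System.out.println("ir_version = " + peekIrVersion(model));
  }
}
```

If this prints 10 while the runtime reports a maximum of 9, the error above is expected and upgrading the runtime is the fix.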
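2) DLL initialization failure on Windows
1. Error
UnsatisfiedLinkError: onnxruntime.dll failed to initialize
2. Root cause
Usually one of:
- the VC++ runtime is missing
- the CPU instruction set is incompatible
- several versions conflict with each other
3. Fix
Upgrade ONNX Runtime:
>= 1.19.x (recommended)
If the error persists after upgrading, go through the causes above: install the Microsoft Visual C++ Redistributable and make sure only one onnxruntime version is on the classpath.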
