一、引言:表格检测的价值与技术选型
在文档数字化、数据自动化录入等场景中,表格作为信息承载的重要形式,其结构(行、列)的精准检测是后续表格识别、数据提取的核心前提。传统的表格检测方法(如基于规则的边缘检测)对复杂表格(如倾斜、模糊、合并单元格)适应性较差,而基于深度学习的目标检测技术能有效解决这一问题。
YOLO(You Only Look Once)系列模型作为实时目标检测的标杆,其最新版本YOLOv8在速度与精度上均有显著提升,支持灵活的模型导出与跨平台部署。而 Java 作为企业级应用的主流开发语言,在稳定性、可维护性上具备天然优势。本文将详细讲解如何基于 Java 结合 YOLOv8 的 ONNX 模型,实现表格图像中行与列的高效检测。
二、环境准备:搭建开发与运行环境
在开始编码前,需完成依赖配置、模型准备等基础工作,确保后续流程顺畅。
2.1 开发工具与基础环境
- JDK:推荐 JDK 11 及以上(兼容 ONNX Runtime 与 OpenCV)
- IDE:IntelliJ IDEA(或 Eclipse)
- 依赖管理:Maven(简化依赖引入与版本控制)
- 图像处理:OpenCV 4.7.0(负责图像读取、预处理)
- 模型推理:ONNX Runtime 1.16.3(加载 YOLOv8 ONNX 模型并执行推理)
- JSON 解析:FastJSON2(处理检测结果的序列化)
2.2 Maven 依赖配置
见项目中的pom.xml
2.3 模型与标签文件准备
(1)YOLOv8 ONNX 模型
需先训练一个用于 “表格行列检测” 的 YOLOv8 模型,再导出为 ONNX 格式:
- 使用 Ultralytics 库训练 YOLOv8:标注表格图像中的 “行(row)” 和 “列(column)” 作为目标类别,生成训练数据集。
- 导出 ONNX 模型:通过 Ultralytics 的
export方法,指定format="onnx",示例代码(Python):
from ultralytics import YOLO# 加载训练好的YOLOv8模型model = YOLO("runs/detect/train/weights/best.pt")# 导出为ONNX格式(输入尺寸与代码中保持一致,如640x640)model.export(format="onnx", imgsz=640, nms=False) # nms=False:代码中自行实现NMS
3.将导出的 ONNX 模型(如table_det_yolov8.onnx)放入项目resources/model目录。
(2)标签文件
创建标签文件table_labels.txt,定义检测类别(与训练时一致),内容如下:
rowcolumn
将该文件放入项目resources/labels目录。
(3)配置文件
创建config.properties配置文件,统一管理路径(避免硬编码):
# 模型路径(相对于resources目录)model_path=model/table_det_yolov8.onnx# 标签文件路径(相对于resources目录)table_det_labels_path=labels/table_labels.txt
三、核心代码解析:从模型加载到结果输出
核心代码分为两个类:ModelDet(模型封装与检测逻辑)和MainTest(主函数调用与流程控制)。以下逐模块解析关键逻辑。完成项目:https://github.com/jiangnanboy/table_row_column_detection
3.1 核心检测类:ModelDet
ModelDet是整个检测流程的核心,负责模型初始化、图像预处理、推理执行、输出解析等功能。
3.1.1 类结构与核心属性
package sy;import ai.onnxruntime.*;import org.opencv.core.CvType;import org.opencv.core.Mat;import org.opencv.core.Size;import org.opencv.imgproc.Imgproc;import utils.CollectionUtil;import java.io.IOException;import java.nio.FloatBuffer;import java.nio.file.Files;import java.nio.file.Paths;import java.util.*;import java.util.stream.Collectors;import java.util.stream.Stream;public class ModelDet {// 检测阈值:置信度阈值(过滤低置信度结果)、IOU阈值(NMS用,过滤重复框)float confThreshold;float iouThreshold;// 模型输入相关:输入尺寸、形状、元素数量、数据类型long inputHeight;long inputWidth;long[] inputShape;int numInputElements;OnnxJavaType inputType;// ONNX Runtime相关:环境、会话、输入输出名称OrtEnvironment env;OrtSession session;String inputName;String outputName;// 原图尺寸(用于将模型输出的坐标缩放回原图)long rawImgHeight;long rawImgWidth;// 标签列表(行、列)public List<String> labelNames;// 省略构造方法与其他方法...}
3.1.2 模型与标签初始化
初始化 ONNX Runtime 环境、加载模型,并读取标签文件:
// 构造方法:默认阈值(置信度0.3,IOU 0.5)public ModelDet(String path, String labelPath) throws OrtException {this(path, labelPath, 0.3f, 0.5f);}// 构造方法:自定义阈值public ModelDet(String path, String labelPath, float confThres, float iouThres) throws OrtException {this.confThreshold = confThres;this.iouThreshold = iouThres;initializeModel(path); // 初始化模型initializeLabel(labelPath); // 初始化标签}// 初始化ONNX模型(支持CPU/GPU)private void initializeModel(String path) throws OrtException {nu.pattern.OpenCV.loadLocally(); // 加载OpenCV本地库this.initializeModel(path, -1); // 默认CPU(-1表示不使用GPU)}// 重载:支持指定GPU设备IDprivate void initializeModel(String path, int gpuDeviceId) throws OrtException {// 1. 创建ONNX环境this.env = OrtEnvironment.getEnvironment();// 2. 配置会话选项(开启所有优化)var sessionOptions = new OrtSession.SessionOptions();sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);// 3. 配置CPU/GPUif (gpuDeviceId >= 0) {sessionOptions.addCPU(false); // 禁用CPUsessionOptions.addCUDA(gpuDeviceId); // 启用指定GPU} else {sessionOptions.addCPU(true); // 启用CPU}// 4. 加载模型并创建会话this.session = this.env.createSession(path, sessionOptions);// 5. 获取模型输入信息(输入名称、形状、数据类型)Map<String, NodeInfo> inputMetaMap = this.session.getInputInfo();this.inputName = this.session.getInputNames().iterator().next(); // 输入名称(YOLOv8通常为"images")NodeInfo inputMeta = inputMetaMap.get(this.inputName);this.inputType = ((TensorInfo) inputMeta.getInfo()).type; // 输入数据类型(通常为FLOAT)this.inputShape = ((TensorInfo) inputMeta.getInfo()).getShape(); // 输入形状(如[1,3,640,640])// 6. 计算输入元素总数(用于初始化数据缓冲区)this.numInputElements = (int) (this.inputShape[1] * this.inputShape[2] * this.inputShape[3]);this.inputHeight = this.inputShape[2]; // 模型输入高度(如640)this.inputWidth = this.inputShape[3]; // 模型输入宽度(如640)// 7. 获取模型输出名称(YOLOv8通常为"output0")this.outputName = this.session.getOutputNames().iterator().next();}// 初始化标签(读取label文件)private void initializeLabel(String labelPath) {try (Stream<String> lines = Files.lines(Paths.get(labelPath))) {// 读取每行标签,去除空格,存入列表labelNames = lines.map(line -> line.strip()).collect(Collectors.toList());} catch (IOException e) {e.printStackTrace();throw new RuntimeException("标签文件加载失败:" + labelPath);}}
3.1.3 图像预处理:将原图转换为模型输入格式
YOLOv8 模型对输入图像有严格要求,需经过通道转换、尺寸缩放、归一化、维度转换四步预处理:
// 准备模型输入:将OpenCV的Mat图像转换为ONNX Tensorprivate Map<String, OnnxTensor> prepareInput(Mat img) throws OrtException {// 1. 记录原图尺寸(后续用于缩放检测框)this.rawImgHeight = img.height();this.rawImgWidth = img.width();// 2. 通道转换:BGR(OpenCV默认)→ RGB(YOLOv8训练时使用的格式)Mat inputImg = new Mat();Imgproc.cvtColor(img, inputImg, Imgproc.COLOR_BGR2RGB);// 3. 尺寸缩放:将图像缩放到模型输入尺寸(如640x640)Imgproc.resize(inputImg, inputImg, new Size(this.inputWidth, this.inputHeight));// 4. 归一化:像素值从[0,255]→[0,1](模型训练时的预处理逻辑)inputImg.convertTo(inputImg, CvType.CV_32FC3, 1.0 / 255.0);// 5. 维度转换:HWC(高度、宽度、通道)→ CHW(通道、高度、宽度)// OpenCV读取的Mat格式为HWC,而ONNX模型输入格式为NCHW(N为批量大小,这里为1)float[] whcData = new float[this.numInputElements];inputImg.get(0, 0, whcData); // 将Mat数据读取到HWC数组float[] chwData = ImageUtil.whc2cwh(whcData); // 转换为CHW格式(工具类实现)// 6. 封装为ONNX Tensor(符合模型输入要求)FloatBuffer inputBuffer = FloatBuffer.wrap(chwData);OnnxTensor inputTensor = OnnxTensor.createTensor(this.env, inputBuffer, this.inputShape);// 7. 构建输入映射(键为输入名称,值为Tensor)Map<String, OnnxTensor> inputMap = CollectionUtil.newHashMap();inputMap.put(this.inputName, inputTensor);return inputMap;}
工具类补充(ImageUtil.whc2cwh):
维度转换的核心逻辑是将[H, W, C]的数组重新排列为[C, H, W],示例实现:
public class ImageUtil {// HWC → CHW:height x width x channel → channel x height x widthpublic static float[] whc2cwh(float[] whcData) {// 假设模型输入尺寸为640x640,通道数为3(需根据实际inputShape调整)int inputHeight = 640;int inputWidth = 640;int channels = 3;float[] chwData = new float[channels * inputHeight * inputWidth];int idx = 0;for (int c = 0; c < channels; c++) { // 先遍历通道for (int h = 0; h < inputHeight; h++) { // 再遍历高度for (int w = 0; w < inputWidth; w++) { // 最后遍历宽度// HWC数组索引:h * inputWidth * channels + w * channels + c// CHW数组索引:c * inputHeight * inputWidth + h * inputWidth + wchwData[c * inputHeight * inputWidth + h * inputWidth + w]= whcData[h * inputWidth * channels + w * channels + c];}}}return chwData;}}
3.1.4 模型推理:执行检测并获取输出
推理过程本质是将预处理后的输入传入模型,获取原始输出张量:
// 模型推理:输入映射 → 原始检测结果private float[][] inference(Map<String, OnnxTensor> inputMap) throws OrtException {// 执行推理:session.run()返回推理结果集合OrtSession.Result result = this.session.run(inputMap);// 解析输出:YOLOv8 ONNX输出格式为[1, num_boxes, 4+num_classes]// 这里取第一个批量([0]),得到[num_boxes, 4+num_classes]的二维数组float[][][] outputTensor = (float[][][]) result.get(0).getValue();float[][] predictions = outputTensor[0];return predictions;}
输出格式说明:
predictions是[num_boxes, 4+num_classes]的二维数组,其中:num_boxes:模型预测的检测框总数;4:检测框坐标(x, y, w, h,对应框的中心坐标和宽高);num_classes:类别置信度(本文中为 2 类:row 和 column)。
3.1.5 输出处理:从原始结果到可读检测信息
原始输出需经过类别筛选、坐标转换、非极大值抑制(NMS) 三步处理,得到最终的检测结果:
// 处理输出:原始预测结果 → 可读检测列表(Detection)private List<Detection> processOutput(float[][] predictions) {// 1. 转置矩阵:[num_boxes, 4+num_classes] → [4+num_classes, num_boxes]// 便于按类别维度处理置信度predictions = transposeMatrix(predictions);// 2. 按类别分组检测框(key:类别ID,value:该类别的检测框列表)Map<Integer, List<float[]>> class2Bbox = CollectionUtil.newHashMap();for (float[] bbox : predictions) {// 提取类别置信度(跳过前4个坐标值)float[] classProbs = Arrays.copyOfRange(bbox, 4, bbox.length);// 找到置信度最高的类别(即当前检测框的类别)int labelId = predMax(classProbs);// 该类别的置信度float confidence = classProbs[labelId];// 过滤低置信度结果(低于confThreshold的框丢弃)if (confidence < this.confThreshold) {continue;}// 将置信度存入检测框数组(第4位)bbox[4] = confidence;// 坐标缩放:将模型输入尺寸(如640x640)的坐标 → 原图尺寸rescaleBoxes(bbox);// 坐标格式转换:xywh(中心+宽高)→ xyxy(左上角+右下角)// 便于后续绘制检测框ImageUtil.xywh2xyxy(bbox);// 过滤无效框(左上角x ≥ 右下角x 或 左上角y ≥ 右下角y的框丢弃)if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) {continue;}// 按类别分组存入Mapclass2Bbox.putIfAbsent(labelId, CollectionUtil.newArrayList());class2Bbox.get(labelId).add(bbox);}// 3. 非极大值抑制(NMS):去除同一类别中的重复检测框List<Detection> detectionList = CollectionUtil.newArrayList();for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {int labelId = entry.getKey();List<float[]> bboxList = entry.getValue();// 对当前类别执行NMS(保留置信度最高、无重叠的框)bboxList = ImageUtil.nonMaxSuppression(bboxList, this.iouThreshold);// 封装为Detection对象(包含标签名、类别ID、坐标、置信度)for (float[] bbox : bboxList) {String labelName = labelNames.get(labelId);float[] coordinates = Arrays.copyOfRange(bbox, 0, 4); // xyxy坐标float confidence = bbox[4];detectionList.add(new Detection(labelName, labelId, coordinates, confidence));}}return detectionList;}// 辅助方法1:矩阵转置private float[][] transposeMatrix(float[][] matrix) {float[][] transposed = new float[matrix[0].length][matrix.length];for (int i = 0; i < matrix.length; i++) {for (int j = 0; j < matrix[0].length; j++) {transposed[j][i] = matrix[i][j];}}return transposed;}// 辅助方法2:找到置信度最高的类别IDprivate int predMax(float[] probabilities) {float maxProb = Float.NEGATIVE_INFINITY;int maxIndex = 0;for (int i = 0; i < probabilities.length; i++) {if (probabilities[i] > maxProb) {maxProb = probabilities[i];maxIndex = i;}}return maxIndex;}// 辅助方法3:将模型输入尺寸的坐标缩放回原图尺寸public void rescaleBoxes(float[] bbox) {// 缩放x坐标(中心x、宽)bbox[0] = bbox[0] / this.inputWidth * this.rawImgWidth;bbox[2] = bbox[2] / this.inputWidth * this.rawImgWidth;// 缩放y坐标(中心y、高)bbox[1] = bbox[1] / this.inputHeight * this.rawImgHeight;bbox[3] = bbox[3] / this.inputHeight * this.rawImgHeight;}
3.2 主函数调用:MainTest
MainTest负责读取配置、加载模型、调用检测、展示结果(绘制检测框、打印 JSON):
package sy;import ai.onnxruntime.OrtException;import com.alibaba.fastjson2.JSON;import org.opencv.core.Mat;import org.opencv.imgcodecs.Imgcodecs;import utils.PropertiesReader;import java.util.List;public class MainTest {public static void main(String... args) {try {// 1. 读取配置文件(模型、标签路径)String modelPath = getResourcePath(PropertiesReader.get("model_path"));String labelPath = getResourcePath(PropertiesReader.get("table_det_labels_path"));// 2. 测试图像路径(需替换为你的图像路径)String imgPath = "D:\\project\\table_detection\\img\\test_table.jpg";String outputImgPath = "D:\\project\\table_detection\\img\\prediction.jpg";// 3. 加载模型(默认置信度0.3,IOU 0.5)ModelDet modelDet = new ModelDet(modelPath, labelPath);// 4. 读取图像(OpenCV的imread方法)Mat img = Imgcodecs.imread(imgPath);if (img.dataAddr() == 0) { // 检查图像是否成功读取System.err.println("图像读取失败:" + imgPath);System.exit(1);}// 5. 执行检测List<Detection> detectionList = modelDet.detectObjects(img);// 6. 结果展示:绘制检测框到图像ImageUtil.drawPredictions(img, detectionList);// 7. 结果输出:保存图像+打印JSONImgcodecs.imwrite(outputImgPath, img);System.out.println("检测结果(JSON):");System.out.println(JSON.toJSONString(detectionList, true)); // 格式化输出System.out.println("检测完成!结果图像已保存至:" + outputImgPath);} catch (OrtException e) {e.printStackTrace();System.err.println("模型推理失败:" + e.getMessage());}}// 辅助方法:获取资源文件的绝对路径(处理getResource的斜杠问题)private static String getResourcePath(String relativePath) {String path = MainTest.class.getClassLoader().getResource(relativePath).getPath();return path.replaceFirst("/", ""); // 去除路径开头的斜杠(如"/D:..."→"D:...")}}
四、运行效果展示


五、总结
本文详细讲解了基于 Java 与 YOLOv8 实现表格行列检测的完整流程,从环境搭建、代码解析到运行效果展示,覆盖了模型加载、图像预处理、推理执行、结果处理等核心环节。该方案基于 YOLOv8 的高速度与高精度,结合 Java 的企业级优势,可灵活集成到文档处理、数据录入等实际应用中。
读者可根据自身需求调整阈值、优化模型或扩展功能,进一步提升检测性能与适用场景。希望本文能为从事表格识别、计算机视觉相关开发的读者提供有价值的参考。
微信扫描下方的二维码阅读本文

Comments NOTHING