<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.20.0</version>
</dependency>
import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.*;

public class OnnxExample {
    public static void main(String[] args) {
        // The environment is a process-wide singleton; it is released when
        // the JVM exits and does not need to be closed explicitly
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) {
            // Optional: set the graph optimization level
            opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);

            // Load the model
            OrtSession session = env.createSession("path/to/your/model.onnx", opts);

            // Print the model's input and output names
            System.out.println("Model inputs: " + session.getInputNames());
            System.out.println("Model outputs: " + session.getOutputNames());
        } catch (OrtException e) {
            e.printStackTrace();
        }
    }
}
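The shape you pass when creating input tensors has to match what the model declares. If you are unsure, you can query the session's input metadata before building any tensors. A minimal sketch, assuming the inputs are plain tensors (dynamic dimensions are reported as -1):

// Inspect each input's declared name and shape
for (Map.Entry<String, NodeInfo> entry : session.getInputInfo().entrySet()) {
    TensorInfo info = (TensorInfo) entry.getValue().getInfo();
    System.out.println("Input '" + entry.getKey() + "' shape: "
            + Arrays.toString(info.getShape()));
}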
// Create the input data
float[] inputData = {1.0f, 2.0f, 3.0f, 4.0f}; // example values
long[] shape = {1, 4}; // tensor shape; must match what the model expects
OnnxTensor tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(inputData), shape);

// Prepare the input map, keyed by the model's input name
Map<String, OnnxTensor> inputs = new HashMap<>();
inputs.put(session.getInputNames().iterator().next(), tensor);
// Run inference
try (OrtSession.Result results = session.run(inputs)) {
    // Fetch the first output. getValue() mirrors the output shape as
    // nested arrays, e.g. float[][] for a {1, N} output
    OnnxTensor output = (OnnxTensor) results.get(0);
    float[][] outputData = (float[][]) output.getValue();

    // Process the output data
    for (float value : outputData[0]) {
        System.out.println("Output value: " + value);
    }
}
// Release native resources when done (or manage the tensor and session
// with try-with-resources; the singleton environment needs no closing)
tensor.close();
session.close();
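Putting the fragments together, here is a minimal end-to-end sketch. The model path and the {1, 4} shape are placeholders for your own model; the output is read through getFloatBuffer(), which yields a flat copy of the values regardless of the output's rank:

import ai.onnxruntime.*;
import java.nio.FloatBuffer;
import java.util.Map;

public class OnnxEndToEnd {
    public static void main(String[] args) throws OrtException {
        OrtEnvironment env = OrtEnvironment.getEnvironment();
        try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) {
            opts.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
            try (OrtSession session = env.createSession("path/to/your/model.onnx", opts);
                 OnnxTensor tensor = OnnxTensor.createTensor(
                         env, FloatBuffer.wrap(new float[]{1.0f, 2.0f, 3.0f, 4.0f}),
                         new long[]{1, 4})) {
                String inputName = session.getInputNames().iterator().next();
                try (OrtSession.Result results = session.run(Map.of(inputName, tensor))) {
                    // Flat view of the output values, independent of output rank
                    FloatBuffer out = ((OnnxTensor) results.get(0)).getFloatBuffer();
                    while (out.hasRemaining()) {
                        System.out.println("Output value: " + out.get());
                    }
                }
            }
        }
    }
}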
<dependencies>
    <!-- Spring Boot Starter Web -->
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <!-- ONNX Runtime -->
    <dependency>
        <groupId>com.microsoft.onnxruntime</groupId>
        <artifactId>onnxruntime</artifactId>
        <version>1.20.0</version>
    </dependency>
    <!-- Lombok for reducing boilerplate code (optional) -->
    <dependency>
        <groupId>org.projectlombok</groupId>
        <artifactId>lombok</artifactId>
        <optional>true</optional>
    </dependency>
</dependencies>
package com.example.onnxdemo.service;

import ai.onnxruntime.*;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.Map;

@Service
public class OnnxService {

    @Value("${onnx.model.path}")
    private String modelPath;

    private OrtEnvironment environment;
    private OrtSession session;

    @PostConstruct
    public void initialize() throws OrtException {
        // Initialize the ONNX Runtime environment
        environment = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Set the graph optimization level
        sessionOptions.setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT);
        // Load the model (createSession expects a filesystem path)
        session = environment.createSession(modelPath, sessionOptions);
        System.out.println("ONNX model loaded, inputs: " + session.getInputNames());
        System.out.println("ONNX model loaded, outputs: " + session.getOutputNames());
    }

    public float[] runInference(float[] inputData, long[] shape) throws OrtException {
        // Create the input tensor; try-with-resources releases it afterwards
        try (OnnxTensor tensor = OnnxTensor.createTensor(
                environment, FloatBuffer.wrap(inputData), shape)) {
            // Prepare the input map, keyed by the model's input name
            Map<String, OnnxTensor> inputs = new HashMap<>();
            inputs.put(session.getInputNames().iterator().next(), tensor);
            // Run inference
            try (OrtSession.Result results = session.run(inputs)) {
                OnnxTensor output = (OnnxTensor) results.get(0);
                // getValue() returns nested arrays for multi-dimensional outputs;
                // read a flat copy of the data instead so the result shape
                // never breaks the cast
                FloatBuffer buf = output.getFloatBuffer();
                float[] outputData = new float[buf.remaining()];
                buf.get(outputData);
                return outputData;
            }
        }
    }

    @PreDestroy
    public void cleanup() throws OrtException {
        // Close the session; the OrtEnvironment is a process-wide
        // singleton and is released when the JVM exits
        if (session != null) {
            session.close();
        }
    }
}
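One caveat: createSession(String, ...) expects a filesystem path, so a classpath: location will not work as-is. For a model bundled inside the jar, read it into a byte array and use the createSession(byte[], ...) overload instead. A sketch, assuming the model sits at src/main/resources/models/model.onnx; it would replace the createSession(modelPath, ...) line in initialize():

// Alternative: load a model bundled on the classpath
try (java.io.InputStream in = getClass().getResourceAsStream("/models/model.onnx")) {
    if (in == null) {
        throw new IllegalStateException("Model not found on classpath");
    }
    session = environment.createSession(in.readAllBytes(), sessionOptions);
} catch (java.io.IOException e) {
    throw new IllegalStateException("Failed to read model from classpath", e);
}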
package com.example.onnxdemo.controller;

import com.example.onnxdemo.dto.InferenceRequest;
import com.example.onnxdemo.dto.InferenceResponse;
import com.example.onnxdemo.service.OnnxService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/api/inference")
public class InferenceController {

    @Autowired
    private OnnxService onnxService;

    @PostMapping
    public InferenceResponse runInference(@RequestBody InferenceRequest request) {
        try {
            float[] result = onnxService.runInference(
                    request.getInputData(),
                    request.getShape()
            );
            return new InferenceResponse(result, null);
        } catch (Exception e) {
            return new InferenceResponse(null, e.getMessage());
        }
    }
}
// InferenceRequest.java
package com.example.onnxdemo.dto;

import lombok.Data;

@Data
public class InferenceRequest {
    private float[] inputData;
    private long[] shape;
}

// InferenceResponse.java (a separate file: Java allows one public class per file)
package com.example.onnxdemo.dto;

import lombok.Data;

@Data
public class InferenceResponse {
    private float[] result;
    private String error;

    public InferenceResponse(float[] result, String error) {
        this.result = result;
        this.error = error;
    }
}
Add the model path to application.properties (the same keys work in application.yml form):
# Model path (createSession expects a filesystem path; for a model bundled
# inside the jar, use the classpath-loading sketch shown above)
onnx.model.path=/path/to/models/model.onnx

# Server configuration
server.port=8080
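With the application running, the endpoint can be exercised from any HTTP client. A quick smoke test using Java's built-in HttpClient, reusing the example payload from above:

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class InferenceClient {
    public static void main(String[] args) throws Exception {
        // JSON body matching InferenceRequest: input data plus the tensor shape
        String json = "{\"inputData\":[1.0,2.0,3.0,4.0],\"shape\":[1,4]}";
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8080/api/inference"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(json))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.body()); // e.g. {"result":[...],"error":null}
    }
}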
// Example: process a batch of inference requests one by one
// (add to OnnxService; needs java.util.List and java.util.ArrayList imports)
public List<float[]> batchInference(List<float[]> inputBatch, long[] shape) throws OrtException {
    List<float[]> results = new ArrayList<>();
    for (float[] input : inputBatch) {
        results.add(runInference(input, shape));
    }
    return results;
}
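This loop pays the per-call overhead once per input. If your model's first dimension is a dynamic batch axis, you can usually stack the inputs into a single tensor and run inference once. A sketch under that assumption, where batchInferenceStacked is a hypothetical helper and sampleShape is the shape of one sample without the batch dimension:

// Hypothetical helper: one session.run() over a stacked batch.
// Assumes the model's first dimension is dynamic (batch) and every
// input in the list has the same length.
public float[] batchInferenceStacked(List<float[]> inputBatch, long[] sampleShape)
        throws OrtException {
    int batch = inputBatch.size();
    int sampleLen = inputBatch.get(0).length;
    float[] stacked = new float[batch * sampleLen];
    for (int i = 0; i < batch; i++) {
        System.arraycopy(inputBatch.get(i), 0, stacked, i * sampleLen, sampleLen);
    }
    // Prepend the batch dimension to the per-sample shape
    long[] batchedShape = new long[sampleShape.length + 1];
    batchedShape[0] = batch;
    System.arraycopy(sampleShape, 0, batchedShape, 1, sampleShape.length);
    // The returned array holds all outputs back to back; slice it per sample
    return runInference(stacked, batchedShape);
}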