如何使用Spring AI和GenAI实现关系型数据库查询
本文将介绍如何借助GenAI(生成式人工智能)和Spring AI实现这一功能,利用大语言模型(LLMs)根据用户问题和数据库表结构信息生成SQL查询语句,并在Spring Boot应用中执行这些语句,获取查询结果。
一、项目搭建
首先要创建一个Spring Boot应用。此应用使用OpenAI的GPT模型来生成SQL查询语句,因此需要添加spring-ai-openai-spring-boot-starter
模块作为依赖。关于Spring AI的详细入门指南,可以参考相关资料进一步了解。
为了执行数据库查询操作,这里选择JdbcTemplate和H2数据库,所以也要添加相应的依赖。在pom.xml
文件中添加如下依赖配置:
<dependency> <groupId>org.springframework.ai</groupId> <artifactId>spring-ai-openai-spring-boot-starter</artifactId> </dependency> <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-jdbc</artifactId> </dependency> <dependency> <groupId>com.h2database</groupId> <artifactId>h2</artifactId> <scope>runtime</scope> </dependency>
别忘记在属性文件或环境变量中添加OpenAI API密钥。在application.properties
文件中添加:
spring.ai.openai.api-key=${OPENAI_API_KEY}
二、DDL语句和初始数据
本演示使用H2数据库,Spring Boot的自动配置功能会自动对其进行配置。只要将schema.sql
和data.sql
脚本放置在/src/main/resources
目录下,Spring Boot就会自动执行它们。
schema.sql
文件用于创建数据库表结构,内容如下:
create table TBL_USER ( id int not null auto_increment, username varchar(255) not null, email varchar(255) not null, password varchar(255) not null, primary key (id) ); create table TBL_ACCOUNT ( id int not null auto_increment, accountNumber varchar(255) not null, user_id int not null, balance decimal(10, 2) not null, openDate date not null, primary key (id), foreign key (user_id) references TBL_USER(id) );
data.sql
文件用于插入初始数据,内容如下:
INSERT INTO TBL_USER (username, email, password) VALUES ('user1', 'user1@example.com', 'password1'), ('user2', 'user2@example.com', 'password2'), ('user3', 'user3@example.com', 'password3'), ('user4', 'user4@example.com', 'password4'), ('user5', 'user5@example.com', 'password5'), ('user6', 'user6@example.com', 'password6'), ('user7', 'user7@example.com', 'password7'), ('user8', 'user8@example.com', 'password8'), ('user9', 'user9@example.com', 'password9'), ('user10', 'user10@example.com', 'password10'); INSERT INTO TBL_ACCOUNT (accountNumber, user_id, balance, openDate) VALUES ('ACC001', 1, 1000.00, '2024-07-09'), ('ACC002', 1, 500.00, '2024-07-10'), ('ACC003', 2, 1500.00, '2024-07-09'), ('ACC004', 2, 200.00, '2024-07-10'), ('ACC005', 3, 800.00, '2024-07-09'), ('ACC006', 4, 3000.00, '2024-07-09'), ('ACC007', 4, 100.00, '2024-07-10'), ('ACC008', 5, 250.00, '2024-07-09'), ('ACC009', 6, 1800.00, '2024-07-09'), ('ACC010', 6, 700.00, '2024-07-10'), ('ACC011', 7, 500.00, '2024-07-09'), ('ACC012', 8, 1200.00, '2024-07-09'), ('ACC013', 9, 900.00, '2024-07-09'), ('ACC014', 9, 300.00, '2024-07-10'), ('ACC015', 10, 2000.00, '2024-07-09');
三、构建提示信息
编写详细、清晰的提示信息至关重要,这有助于生成可直接在应用中执行的正确SQL语句。下面的提示信息用于请求生成一条SQL SELECT
语句,以便从数据库中获取所需数据。这里不支持执行其他创建、更新或删除操作,如果需要支持这些操作,可以选择仅返回生成的查询语句作为API输出,而不实际执行它们。
sql-prompt-template.st Given the DDL in the DDL section, write an SQL query that answers the asked question in the QUESTION section. Only produce select queries. Do not append any text or markup in the start or end of response. Remove the markups such as ``` , sql , n as well. If the question would result in an insert, update, or delete, or if the query would alter the DDL in any way, say that the operation isn't supported. If the question can't be answered, say that the DDL doesn't support answering that question. QUESTION {question} DDL {ddl}
这段提示信息的意思是:根据“DDL部分”提供的数据库表结构定义(DDL),编写一条SQL查询语句来回答“QUESTION部分”提出的问题。只生成SELECT
查询语句,并且在响应的开头和结尾不要添加任何文本或标记,还要去除诸如“` 、sql 、n这类标记。如果问题会导致插入、更新、删除操作,或者查询会以任何方式改变DDL结构,就返回“操作不支持”。要是问题无法回答,就返回“DDL不支持回答该问题” 。在实际使用时,{question}
和{ddl}
会被具体的用户问题和数据库表结构定义所替代。
四、SQL控制器
接下来编写API,该API使用ChatClient API将输入的提示信息发送给LLM,并接收生成的SQL查询语句。如果生成的查询语句有效,就使用JdbcTemplate执行该查询并返回响应结果。
import java.io.IOException; import java.nio.charset.Charset; import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RestController; @RestController public class SqlController { // 注入数据库表结构定义文件的资源对象 @Value("classpath:/schema.sql") private Resource ddlResource; // 注入提示信息模板文件的资源对象 @Value("classpath:/sql-prompt-template.st") private Resource sqlPromptTemplateResource; // 注入ChatClient实例,用于与LLM交互 private final ChatClient aiClient; // 注入JdbcTemplate实例,用于执行SQL查询 private final JdbcTemplate jdbcTemplate; public SqlController(ChatClient.Builder aiClientBuilder, JdbcTemplate jdbcTemplate) { // 构建ChatClient实例 this.aiClient = aiClientBuilder.build(); // 初始化JdbcTemplate实例 this.jdbcTemplate = jdbcTemplate; } @PostMapping(path = "/sql") public AiResponse sql(@RequestBody AiRequest request) throws IOException { // 读取数据库表结构定义文件的内容 String schema = ddlResource.getContentAsString(Charset.defaultCharset()); // 使用ChatClient发送提示信息给LLM,获取生成的SQL查询语句 String query = aiClient.prompt() .advisors(new SimpleLoggerAdvisor()) .user(userSpec -> userSpec // 设置提示信息模板文件 .text(sqlPromptTemplateResource) // 设置用户问题参数 .param("question", request.text()) // 设置数据库表结构定义参数 .param("ddl", schema) ) .call() .content(); // 判断生成的查询语句是否以select开头 if (query.toLowerCase().startsWith("select")) { // 如果是,执行查询并返回结果 return new AiResponse(query, jdbcTemplate.queryForList(query)); } // 如果不是,抛出异常 throw new AiException(query); } } // 定义请求数据结构,包含用户输入的文本 public record AiRequest(String text) { } // 定义响应数据结构,包含生成的SQL查询语句和查询结果 public record AiResponse(String sqlQuery, List<Map<String, Object>> results) { }
五、异常处理
并非所有用户请求都是有效的,有些用户可能会尝试生成修改数据库模式或存储数据的查询语句。对于这类情况,我们使用Spring ProblemDetail API以标准格式返回错误信息。
// 自定义异常类,继承自RuntimeException public class AiException extends RuntimeException { public AiException(String response) { super(response); } } import org.springframework.http.HttpStatus; import org.springframework.http.ProblemDetail; import org.springframework.web.bind.annotation.ControllerAdvice; import org.springframework.web.bind.annotation.ExceptionHandler; // 全局异常处理类,用于处理AiException异常 @ControllerAdvice public class CustomExceptionHandler { @ExceptionHandler(AiException.class) public ProblemDetail handle(AiException ex) { // 返回包含错误状态码和详细信息的ProblemDetail对象 return ProblemDetail.forStatusAndDetail(HttpStatus.EXPECTATION_FAILED, ex.getMessage()); } }
六、演示
启动嵌入式服务器中的应用,默认监听端口为8080
。接下来,发送一些有效的和无效的SQL生成请求,应用会根据情况做出相应响应。
- 用户请求:“Find the count of accounts.”
- API响应:
{ "sqlQuery": "select count(*) as account_count from TBL_ACCOUNT;", "results": [ { "ACCOUNT_COUNT": 15 } ] }
- 用户请求:“Sum of all accounts for all users.”
- API响应:
{ "sqlQuery": "select sum(balance) as total_balancernfrom TBL_ACCOUNT;", "results": [ { "TOTAL_BALANCE": 14750.00 } ] }
- 无效的用户请求:“Empty the account balance of user1.”
- API响应:
{ "type": "about:blank", "title": "Expectation Failed", "status": 417, "detail": "The operation isn't supported.", "instance": "/sql" }
七、总结
在本教程中,我们学习了如何利用Spring AI从LLM生成SQL语句,并借助Spring JDBC执行这些语句。这种SQL生成功能对编写基础以及涉及连接操作的复杂SQL语句的开发者来说非常实用。但要注意,出于安全和审计的考虑,务必对生成的SQL进行全面验证。由于LLM生成的SQL有时可能存在错误,所以在生产环境中使用前,一定要仔细检查这些SQL语句。