今天带大伙实现RAG(检索增强生成)里embedding模型处理文本向量化的过程,搞清楚embedding搜索功能到底是咋回事儿!

一、embedding到底是啥?

在RAG架构里,embedding可是实现文本向量化的关键一环。简单来说,它的核心操作就是把自然语言文本变成高维向量。为啥要这么干呢?因为有了这些向量,咱们就能实现基于语义的搜索啦。

打个比方,我们先把资料库中的文本,像文章标题、分类信息啥的,通过embedding模型转换成向量形式。用户提出问题时,也用同样的方法把问题转成向量。然后通过计算这两个向量之间的相似度,就能找到最符合用户意图的文本内容。这可比传统的搜索方式智能多了!要是想深入了解RAG,可以去看看“AI全栈必问的RAG是什么!”这篇文章。

二、简化版embedding实现流程

接下来,我给大家详细讲讲如何快速实现一个简化版的embedding应用,这里面涉及后端环境搭建、模型封装、文件读写,还有跨域处理这些关键步骤。

1. 环境初始化与模型封装

咱先从初始化后端Node.js环境开始,在命令行里敲上npm init -y,就能快速完成初始化。这和之前封装openai的操作有点类似,这次我们要封装的是embedding模型。在这个过程中,推荐大伙用dotenv模块来保护自己的API key,防止信息泄露,安全问题可不能马虎!

// 引入OpenAI和dotenv模块 import OpenAI from 'openai'; import dotenv from 'dotenv'; // 加载环境变量配置文件 dotenv.config({ path: '.env' }); // 创建OpenAI实例,设置API key和baseURL export const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY, baseURL: process.env.OPENAI_API_BASE_URL, }); 

2. 读写文件及调用embedding模型

文件读写这块,我们用fs/promises模块来操作,它能帮我们避免回调地狱,让代码逻辑更清晰。再结合async/await语法,代码读起来就更顺畅了。

我们从posts.json文件里读取那些需要向量化的文章数据,然后调用embedding模型,生成对应的向量,最后把这些结果存到新文件里。这里的数据格式可以自己模拟,像下面这样:

[ { "title": "如何使用Nuxt.js进行服务器端渲染", "category": "前端开发" }, // 其他数据可以仿照这个格式自行补充 ] 

文件目录结构大概是这样:

ai-server ├── data │ ├── posts_with_embeddings.json │ └── posts.json ├── node_modules ├──.env ├── app.service.mjs ├── create-embedding.mjs ├── index.mjs ├── package.json └── pnpm-lock.yaml 

具体代码实现如下:

// 引入fs/promises模块和之前创建的client实例 import fs from 'fs/promises'; import { client } from './app.service.mjs'; // 定义输入输出文件路径 const inputFilePath = './data/posts.json'; const outputFilePath = './data/posts_with_embeddings.json'; // 异步读取数据文件,并将其解析为JSON格式 const data = await fs.readFile(inputFilePath, 'utf8'); const posts = JSON.parse(data); // 用于存储带有embedding向量的文章数据 const postsWithEmbedding = []; // 遍历每篇文章,生成embedding向量 for (const { title, category } of posts) { const response = await client.embeddings.create({ model: 'text-embedding-ada-002', // 将文章标题和分类拼接作为输入 input: `标题:${title};分类:${category}` }); postsWithEmbedding.push({ title, category, // 提取生成的embedding向量 embedding: response.data[0].embedding }); } // 将生成embedding的结果写入到新文件中 await fs.writeFile(outputFilePath, JSON.stringify(postsWithEmbedding)); 

3. 构建后端服务并实现搜索接口

这里我们用Koa框架来搭建服务,借助@koa/cors处理跨域问题。因为前端传值一般是JSON格式,所以还得引入koa-bodyparser来自动解析请求体。

下面这段代码实现了监听3000端口,并且创建了一个/search接口。这个接口的作用就是接收查询关键字,生成向量,计算余弦相似度,最后返回最匹配的结果。

// 引入Koa、cors、Router、bodyParser等模块,以及之前创建的client实例和fs/promises模块 import Koa from 'koa'; import cors from '@koa/cors'; import Router from 'koa-router'; import bodyParser from 'koa-bodyparser'; import { client } from './app.service.mjs'; import fs from 'fs/promises'; // 定义存储带有embedding向量数据的文件路径 const inputFilePath = './data/posts_with_embeddings.json'; // 读取文件数据并解析为JSON格式 const data = await fs.readFile(inputFilePath, 'utf8'); const posts = JSON.parse(data); // 创建Koa应用实例和Router实例 const app = new Koa(); const router = new Router(); // 设置服务监听端口 const port = 3000; // 使用cors和bodyParser中间件 app.use(cors()); app.use(bodyParser()); // 使用路由处理请求 app.use(router.routes()); app.use(router.allowedMethods()); // 监听服务启动,打印提示信息 app.listen(port, () => { console.log(`Server is running on port ${port}`); }); // 计算余弦相似度的函数 function cosineSimilarity(a, b) { if (a.length!== b.length) { throw new Error('向量长度不匹配'); } let dotProduct = 0; let normA = 0; let normB = 0; for (let i = 0; i < a.length; i++) { dotProduct += a[i] * b[i]; normA += a[i] * a[i]; normB += b[i] * b[i]; } return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); } // 定义搜索路由 router.post('/search', async (ctx) => { const { keword } = ctx.request.body; // 从请求体中获取关键字 console.log(keword); // 生成查询关键字的embedding向量 const response = await client.embeddings.create({ model: 'text-embedding-ada-002', input: keword, }); const { embedding } = response.data[0]; // 获取生成的向量 // 计算每篇文章与查询向量的相似度 const results = posts.map(item => ({ ...item, similarity: cosineSimilarity(embedding, item.embedding) })); // 按相似度降序排序,并提取最相似的前三条记录 const topResults = results.sort((a, b) => b.similarity - a.similarity) .slice(0, 3) .map((item, index) => ({ id: index, title: `${index + 1}.${item.title}, ${item.category}` })); ctx.body = { status: 200, data: topResults }; }); 

这里有个小细节要注意,sort方法会返回一个新数组,所以不能直接用data:results。可以在原results上链式调用sort,也可以用topResults接收新值再传给data

三、余弦相似度函数解析

上面代码里的余弦相似度函数,用来衡量两个向量在方向上的相似程度。它的取值范围一般在 -1到1之间,对于正向量来说,通常在0到1之间。这个值越接近1,就表示两个向量在空间中的方向越接近,语义上也就越相关;值越低,说明两个向量在语义上越不相关。

function cosineSimilarity(a, b) { if (a.length!== b.length) { throw new Error('向量长度不匹配'); } let dotProduct = 0; let normA = 0; let normB = 0; for (let i = 0; i < a.length; i++) { dotProduct += a[i] * b[i]; normA += a[i] * a[i]; normB += b[i] * b[i]; } return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB)); } 

具体实现步骤是这样的:

  • 判断长度是否一致:先检查两个向量的长度,如果不一样,就抛出错误,因为长度不同的向量没法计算相似度。
  • 计算点积:遍历向量,把对应位置的元素相乘,再把这些乘积加起来,得到点积。
  • 计算向量模:分别计算向量a和向量b各元素的平方和,然后对平方和开平方,得到两个向量的模。
  • 返回余弦相似度:把点积除以两个向量模的乘积,得到的结果就是两向量之间的相似度。

四、CORS配置扩展

默认情况下,我们允许所有跨域请求。但有时候,我们需要更细致地控制跨域访问,比如设置允许跨域的源、方法、请求头,以及是否允许携带凭据。下面这段代码就展示了具体的配置方法:

// 配置CORS app.use(cors({ origin: (ctx) => { const allowedOrigins = ['http://localhost:3000', 'http://example.com']; const requestOrigin = ctx.request.header.origin; if (allowedOrigins.includes(requestOrigin)) { return requestOrigin; // 允许该来源 } return ''; // 拒绝跨域请求 }, allowMethods: ['GET', 'POST'], // 允许的HTTP方法 allowHeaders: ['Content-Type', 'Authorization'], // 允许的请求头 credentials: true // 允许携带凭据 })); 

五、小结

通过上面这些代码示例,咱们一步步实现了利用embedding模型进行文本向量化,再结合余弦相似度计算,完成了基于自然语义的搜索功能。这种搜索方式比传统的搜索更精准,能处理各种复杂的文本匹配场景。