AI教我做事之RAG开发-12 处理多用户会话
1 概述
在RAG(检索增强生成)中,处理多用户会话需要确保每个用户的查询历史和上下文是独立的,这样可以为每个用户提供个性化的响应。以下是实现这一目标的简单方法:
- 用户识别:为每个用户分配一个唯一ID,例如用户名或会话令牌。
- 会话管理:使用LangChain的记忆功能(如ConversationBufferMemory)为每个用户维护独立的对话历史。
- 请求处理:通过Web框架(如Flask)接收用户请求,根据用户ID获取对应的对话链(chain),生成响应并更新历史。
例如,你可以通过一个Flask服务器接收POST请求,包含用户ID和查询内容,然后返回个性化的回答。
2 实现步骤
- 设置环境:
- 安装必要的库,如LangChain、Flask和OpenAI。
- 准备一个向量存储(如FAISS)来存储检索文档。
- 创建会话管理器:
- 定义一个类(如UserSessionManager)来管理用户特定的对话链,每个链包含独立的记忆。
- 当新用户首次请求时,为其创建新的记忆和链。
- 处理用户请求:
- 用户通过API发送请求,包含用户ID和查询。
- 根据用户ID获取对应的链,调用LangChain的ConversationalRetrievalChain生成响应。
- 返回响应并更新用户的对话历史。
一个意外的细节是,这种方法不仅适用于实时聊天,还可以扩展到持久化存储(如数据库),以在应用重启后保留用户历史。
3 完整代码示意
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from flask import Flask, request, jsonify
from langchain_ollama import OllamaLLM
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaLLM
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import pipeline
from langchain.prompts import PromptTemplate
from pathlib import Path
import torch
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationSummaryMemory
import time
from datetime import datetime, timedelta
# 为支持多用户,需要一个管理类来跟踪每个用户的对话链(chain)。每个链应包含独立的记忆(memory)
# 同时针对不活跃的链进行定期清理
class UserSessionManager:
def __init__(self, llm, vector_store, inactivity_timeout=300): # 5分钟不活跃超时
self.user_chains = {} # 存储活跃用户链
self.user_last_active = {} # 跟踪最后活跃时间
self.llm = llm
self.vector_store = vector_store
self.inactivity_timeout = inactivity_timeout
def cleanup_inactive(self):
"""清理不活跃用户的链"""
current_time = datetime.now()
inactive_users = [
user_id for user_id, last_active in self.user_last_active.items()
if (current_time - last_active).total_seconds() > self.inactivity_timeout
]
for user_id in inactive_users:
del self.user_chains[user_id]
del self.user_last_active[user_id]
print(f"Removed inactive user {user_id}")
# 5. 从本地维持的列表中查询对应ID的chain信息,没有就新建
def get_chain(self, user_id):
if user_id not in self.user_chains:
# 使用摘要记忆减少内存
memory = ConversationSummaryMemory(llm=self.llm, max_token_limit=1000)
# 自定义提示模板,只要求简洁回答
prompt_template = """根据以下上下文,直接用中文回答问题,不要添加多余内容。
上下文: {context}
问题: {question}
回答: """
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
chain = ConversationalRetrievalChain.from_llm(
llm=self.llm,
retriever=self.vector_store.as_retriever(),
memory=memory,
chain_type_kwargs={"prompt": PROMPT} # 使用自定义提示
)
self.user_chains[user_id] = chain
#记录该用户最后的活跃时间
self.user_last_active[user_id] = datetime.now()
return self.user_chains[user_id]
# 我们自定义的一个基本封装类,用于处理RAG基本行为
class FRAG:
# 1. 加载和处理文档
def load_documents(self,file_path):
# 使用PyPDFLoader加载PDF文档
loader = PyPDFLoader(file_path)
documents = loader.load()
return documents
def split_document(self,documents):
# 分割文档为小块,便于检索,每个块约500字符,适合短文档;对于长文档可调整到1000
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents(documents)
return docs
# 1. 加载该目录下的所有类型文件
def load_documents_from_dir(self,directory_path):
loader = DirectoryLoader(directory_path,glob="**/*", show_progress=True)
# 如果目录包含多种文件类型(如 .txt 和 .pdf),默认的 UnstructuredFileLoader 通常能处理,但对于特定类型,可能希望使用专用加载器以优化性能
# glob 参数支持通配符:
# **/*.txt:匹配所有子目录下的 .txt 文件。
# **/*:匹配所有文件(包括子目录)。
# 多个模式可以通过列表传递,例如 glob=["**/*.txt", "**/*.pdf"],但需注意所有文件仍使用同一加载器。
documents = loader.load()
return documents
def add_docs_to_vector_store(self,docs):
self.vector_store.add_documents(docs)
# 2. 创建向量存储
def create_vector_store(self,docs):
# 使用HuggingFace的嵌入模型生成向量,all-MiniLM-L6-v2是一个轻量高效的嵌入模型,适合快速开发;生产环境可考虑更强的模型如all-mpnet-base-v2。
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# 使用FAISS创建向量存储
vector_store = FAISS.from_documents(docs, embeddings)
return vector_store
# 3. 初始化语言模型
def initialize_llm(self):
# 使用Hugging Face的生成模型(这里用distilgpt2作为示例),Demo中使用distilgpt2(轻量但生成质量一般);建议替换为gpt-neo-1.3B或LLaMA(需本地部署)。
# text_generation_pipeline = pipeline("text-generation", model="distilgpt2", max_length=200)
# llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
myLLM = OllamaLLM(model="llama2")
#myLLM = OllamaLLM(model="qwen2")
return myLLM
# 4. 创建RAG链
def create_rag_chain(self,vector_store, llm):
# 自定义提示模板,只要求简洁回答
prompt_template = """根据以下上下文,直接用中文回答问题,不要添加多余内容。
上下文: {context}
问题: {question}
回答: """
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# 创建RetrievalQA链
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
return_source_documents=False, # 不返回来源文档
chain_type_kwargs={"prompt": PROMPT} # 使用自定义提示
)
return qa_chain
app = Flask(__name__)
# 加载文档和创建Vector store暴露在外面创建,更灵活。
# 创建全局Session管理对象
myBasicRag=FRAG()
dir_path = Path('testdata')
docs1=myBasicRag.load_documents_from_dir(dir_path)
docs2=myBasicRag.split_document(docs1)
session_manager = UserSessionManager(myBasicRag.initialize_llm(), myBasicRag.create_vector_store(docs2))
# 对外暴露的可供http调用的post类型接口
@app.route("/query", methods=["POST"])
def handle_query():
data = request.get_json()
user_id = data["user_id"]
query = data["query"]
chain = session_manager.get_chain(user_id)
response = chain({"question": query})
return jsonify({"answer": response["answer"]})
if __name__ == "__main__":
app.run()
4 技术细节与扩展
- 检索与生成分离:默认情况下,ConversationalRetrievalChain基于当前问题检索文档,然后结合对话历史生成响应。检索仅依赖当前问题,生成则考虑历史。这可能导致跟进问题(如“巴黎的人口是多少?”)的检索不够上下文相关。若需改进,可定制检索器,将历史总结加入查询。
- 并发性:多用户同时请求时,由于每个用户有独立链,理论上不会干扰。但若LLM或向量存储有并发限制,需考虑异步处理或限流。
- 持久化:当前实现内存中的记忆在应用重启后丢失。为支持持久化,可使用LangChain的MongoDBChatMessageHistory等,将对话历史存储到数据库。
- 最佳实践
- 性能优化:对于大量用户,创建多个链可能耗费内存。可考虑按需加载记忆,或使用更高效的记忆实现(如摘要记忆)。
- 安全考虑:确保用户ID验证,防止恶意用户篡改他人会话。
- 扩展性:若用户有不同权限,可为不同用户组配置独立的向量存储。
结论
通过LangChain的ConversationalRetrievalChain和自定义UserSessionManager,可以有效处理RAG中的多用户会话。每个用户有独立的对话历史,确保上下文隔离。结合Flask等Web框架,可构建实用的多用户RAG应用。未来可考虑持久化存储和定制检索以提升性能。