AI教我做事之RAG开发-13 用链池增强处理多用户会话的性能
要增强RAG中多用户会话的链缓存机制,减少内存消耗,可以考虑使用链池方式,类似于内存池或数据库连接池。以下是实现的关键步骤:
- 共享组件:矢量存储和语言模型(LLM)已共享,减少重复资源。
- 优化记忆:使用摘要记忆(ConversationSummaryMemory)替代完整历史,降低每个用户的内存占用。
- 限制内存:为记忆设置最大令牌限制(max_token_limit),控制数据量。
- 管理不活跃用户:定期卸载不活跃用户的链,将历史存储到数据库,仅在需要时加载。
一个意外的细节是,这种方法不仅减少内存,还能支持应用重启后恢复用户会话。
完整代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from flask import Flask, request, jsonify
from langchain_ollama import OllamaLLM
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_ollama import OllamaLLM
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import pipeline
from langchain.prompts import PromptTemplate
from pathlib import Path
import torch
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationSummaryMemory
import time
from datetime import datetime, timedelta
# 为支持多用户,需要一个管理类来跟踪每个用户的对话链(chain)。每个链应包含独立的记忆(memory)
# 同时针对不活跃的链进行定期清理
class UserSessionManager:
def __init__(self, llm, vector_store, inactivity_timeout=300): # 5分钟不活跃超时
self.user_chains = {} # 存储活跃用户链
self.user_last_active = {} # 跟踪最后活跃时间
self.llm = llm
self.vector_store = vector_store
self.inactivity_timeout = inactivity_timeout
def cleanup_inactive(self):
"""清理不活跃用户的链"""
current_time = datetime.now()
inactive_users = [
user_id for user_id, last_active in self.user_last_active.items()
if (current_time - last_active).total_seconds() > self.inactivity_timeout
]
for user_id in inactive_users:
del self.user_chains[user_id]
del self.user_last_active[user_id]
print(f"Removed inactive user {user_id}")
# 5. 从本地维持的列表中查询对应ID的chain信息,没有就新建
def get_chain(self, user_id):
if user_id not in self.user_chains:
# 使用摘要记忆减少内存
memory = ConversationSummaryMemory(llm=self.llm, max_token_limit=1000)
# 自定义提示模板,只要求简洁回答
prompt_template = """根据以下上下文,直接用中文回答问题,不要添加多余内容。
上下文: {context}
问题: {question}
回答: """
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
chain = ConversationalRetrievalChain.from_llm(
llm=self.llm,
retriever=self.vector_store.as_retriever(),
memory=memory,
chain_type_kwargs={"prompt": PROMPT} # 使用自定义提示
)
self.user_chains[user_id] = chain
#记录该用户最后的活跃时间
self.user_last_active[user_id] = datetime.now()
return self.user_chains[user_id]
# 我们自定义的一个基本封装类,用于处理RAG基本行为
class FRAG:
# 1. 加载和处理文档
def load_documents(self,file_path):
# 使用PyPDFLoader加载PDF文档
loader = PyPDFLoader(file_path)
documents = loader.load()
return documents
def split_document(self,documents):
# 分割文档为小块,便于检索,每个块约500字符,适合短文档;对于长文档可调整到1000
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
docs = text_splitter.split_documents(documents)
return docs
# 1. 加载该目录下的所有类型文件
def load_documents_from_dir(self,directory_path):
loader = DirectoryLoader(directory_path,glob="**/*", show_progress=True)
# 如果目录包含多种文件类型(如 .txt 和 .pdf),默认的 UnstructuredFileLoader 通常能处理,但对于特定类型,可能希望使用专用加载器以优化性能
# glob 参数支持通配符:
# **/*.txt:匹配所有子目录下的 .txt 文件。
# **/*:匹配所有文件(包括子目录)。
# 多个模式可以通过列表传递,例如 glob=["**/*.txt", "**/*.pdf"],但需注意所有文件仍使用同一加载器。
documents = loader.load()
return documents
def add_docs_to_vector_store(self,docs):
self.vector_store.add_documents(docs)
# 2. 创建向量存储
def create_vector_store(self,docs):
# 使用HuggingFace的嵌入模型生成向量,all-MiniLM-L6-v2是一个轻量高效的嵌入模型,适合快速开发;生产环境可考虑更强的模型如all-mpnet-base-v2。
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# 使用FAISS创建向量存储
vector_store = FAISS.from_documents(docs, embeddings)
return vector_store
# 3. 初始化语言模型
def initialize_llm(self):
# 使用Hugging Face的生成模型(这里用distilgpt2作为示例),Demo中使用distilgpt2(轻量但生成质量一般);建议替换为gpt-neo-1.3B或LLaMA(需本地部署)。
# text_generation_pipeline = pipeline("text-generation", model="distilgpt2", max_length=200)
# llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
myLLM = OllamaLLM(model="llama2")
#myLLM = OllamaLLM(model="qwen2")
return myLLM
# 4. 创建RAG链
def create_rag_chain(self,vector_store, llm):
# 自定义提示模板,只要求简洁回答
prompt_template = """根据以下上下文,直接用中文回答问题,不要添加多余内容。
上下文: {context}
问题: {question}
回答: """
PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# 创建RetrievalQA链
qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
return_source_documents=False, # 不返回来源文档
chain_type_kwargs={"prompt": PROMPT} # 使用自定义提示
)
return qa_chain
app = Flask(__name__)
# 加载文档和创建Vector store暴露在外面创建,更灵活。
# 创建全局Session管理对象
myBasicRag=FRAG()
dir_path = Path('testdata')
docs1=myBasicRag.load_documents_from_dir(dir_path)
docs2=myBasicRag.split_document(docs1)
session_manager = UserSessionManager(myBasicRag.initialize_llm(), myBasicRag.create_vector_store(docs2))
# 对外暴露的可供http调用的post类型接口
@app.route("/query", methods=["POST"])
def handle_query():
data = request.get_json()
user_id = data["user_id"]
query = data["query"]
chain = session_manager.get_chain(user_id)
response = chain({"question": query})
return jsonify({"answer": response["answer"]})
if __name__ == "__main__":
app.run()
上述的代码中,我们封装了2个类。
- UserSessionManage负责管理多用户状态
- FRAG负责管理基础的RAG行为操作调用