from flask import Flask, request, jsonify
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

app = Flask(__name__)

# Global cache so each SentenceTransformer model is loaded only once
model_cache = {}


# Split text into blocks of at most block_size characters, joining chunks
# on delimiter; each block after the first is prefixed with the last
# overlap_chars characters of the previous block for shared context.
def split_text(text, block_size, overlap_chars, delimiter):
    chunks = text.split(delimiter)
    text_blocks = []
    current_block = ""
    for chunk in chunks:
        if len(current_block) + len(chunk) + 1 <= block_size:
            if current_block:
                current_block += " " + chunk
            else:
                current_block = chunk
        else:
            text_blocks.append(current_block)
            current_block = chunk
    if current_block:
        text_blocks.append(current_block)
    # Build the final list, adding the overlap prefix to every block
    # except the first
    overlap_blocks = []
    for i in range(len(text_blocks)):
        if i > 0:
            overlap_blocks.append(text_blocks[i - 1][-overlap_chars:] + text_blocks[i])
        else:
            overlap_blocks.append(text_blocks[i])
    return overlap_blocks
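
# A minimal sketch of what split_text produces (the inputs are
# illustrative, not from the original file): with block_size=20 and
# overlap_chars=5, each block after the first starts with the last
# 5 characters of its predecessor.
#
#   blocks = split_text("alpha\nbravo\ncharlie\ndelta", 20, 5, "\n")
#   # -> ["alpha bravo charlie", "arliedelta"]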

# Encode a list of text blocks into embedding vectors
def vectorize_text_blocks(text_blocks, model):
    return model.encode(text_blocks)


# Retrieve the k text blocks from the knowledge base that are most
# similar to the query
def retrieve_top_k(query, knowledge_base, k, block_size, overlap_chars, delimiter, model):
    # Split the knowledge base into overlapping text blocks
    text_blocks = split_text(knowledge_base, block_size, overlap_chars, delimiter)
    # Vectorize the text blocks
    knowledge_vectors = vectorize_text_blocks(text_blocks, model)
    # Vectorize the query text
    query_vector = model.encode([query]).reshape(1, -1)
    # Compute cosine similarity between the query and every block
    similarities = cosine_similarity(query_vector, knowledge_vectors)
    # Indices of the k most similar blocks, highest similarity first
    top_k_indices = similarities[0].argsort()[-k:][::-1]
    # Return the blocks together with their vectors
    top_k_texts = [text_blocks[i] for i in top_k_indices]
    top_k_embeddings = [knowledge_vectors[i] for i in top_k_indices]
    return top_k_texts, top_k_embeddings
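
# A minimal sketch of a direct call (assumes the model is already
# downloaded; long_document and the parameters are illustrative only):
#
#   model = SentenceTransformer("msmarco-distilbert-base-tas-b")
#   texts, vecs = retrieve_top_k("what is RAG?", long_document, k=3,
#                                block_size=500, overlap_chars=50,
#                                delimiter="\n", model=model)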


@app.route('/vectorize', methods=['POST'])
def vectorize_text():
    # Read the JSON payload from the request
    data = request.json
    print(f"Received request data: {data}")  # Debug: log the request payload
    text_list = data.get("text", [])
    model_name = data.get("model_name", "msmarco-distilbert-base-tas-b")  # Default model
    delimiter = data.get("delimiter", "\n")  # Default delimiter
    k = int(data.get("k", 3))  # Default number of results
    block_size = int(data.get("block_size", 500))  # Default block size
    overlap_chars = int(data.get("overlap_chars", 50))  # Default overlap length
    if not text_list:
        return jsonify({"error": "Text is required."}), 400
    # Load the model on first use and cache it
    if model_name not in model_cache:
        try:
            model = SentenceTransformer(model_name)
            model_cache[model_name] = model  # Cache the model
        except Exception as e:
            return jsonify({"error": f"Failed to load model: {e}"}), 500
    model = model_cache[model_name]
    top_k_texts_all = []
    top_k_embeddings_all = []
    # A single text serves as both the query and the knowledge base
    if len(text_list) == 1:
        top_k_texts, top_k_embeddings = retrieve_top_k(text_list[0], text_list[0], k, block_size, overlap_chars, delimiter, model)
        top_k_texts_all.append(top_k_texts)
        top_k_embeddings_all.append(top_k_embeddings)
    elif len(text_list) > 1:
        # Multiple texts: run each one as a query against the first
        # text, which acts as the knowledge base
        for query in text_list:
            top_k_texts, top_k_embeddings = retrieve_top_k(query, text_list[0], k, block_size, overlap_chars, delimiter, model)
            top_k_texts_all.append(top_k_texts)
            top_k_embeddings_all.append(top_k_embeddings)
    # Convert the embedding ndarrays into JSON-serializable lists
    top_k_embeddings_all = [[embedding.tolist() for embedding in embeddings] for embeddings in top_k_embeddings_all]
    print(f"Top K texts: {top_k_texts_all}")  # Debug: log the retrieved texts
    print(f"Top K embeddings: {top_k_embeddings_all}")  # Debug: log the retrieved vectors
    # Return the embeddings as JSON
    return jsonify({
        "topKEmbeddings": top_k_embeddings_all  # Only the embeddings are returned
    })


if __name__ == '__main__':
    app.run(host="0.0.0.0", port=5000, debug=True)
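
# A minimal sketch of a client call (requests is an assumed dependency,
# not used by this service; the payload keys mirror what the route reads
# above, and the first "text" entry acts as the knowledge base):
#
#   import requests
#   payload = {
#       "text": ["a long document to index", "first query", "second query"],
#       "model_name": "msmarco-distilbert-base-tas-b",
#       "delimiter": "\n",
#       "k": 3,
#       "block_size": 500,
#       "overlap_chars": 50,
#   }
#   resp = requests.post("http://localhost:5000/vectorize", json=payload)
#   print(resp.json()["topKEmbeddings"])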