import os
import time
import json5
import json
import yaml
import fnmatch
from pathlib import Path
from typing import List, Optional, Dict, Any
from langchain_community.document_loaders import (
TextLoader,
PyPDFLoader,
DirectoryLoader,
)
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from newspaper import Article, Config
# NOTE(review): hardcoded placeholder token committed to source — prefer
# reading an existing HF_TOKEN from the environment (or a secrets store)
# instead of overwriting it here.
os.environ['HF_TOKEN'] = '你的_token_here'
# ==================== Configuration loading ====================
# Module-wide config: embedding model settings, chroma directory, chunking.
with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# ==================== Embedding model initialisation ====================
embedding_model_name = config["embedding"]["model_name"]
embedding_device = config["embedding"]["device"]
embeddings = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs={"device": embedding_device},
    # Normalised vectors make cosine and dot-product similarity equivalent.
    encode_kwargs={"normalize_embeddings": True},
)
# ==================== Vector store initialisation ====================
persist_dir = config["knowledge_base"]["chroma_persist_dir"]
vectorstore = Chroma(
    persist_directory=persist_dir,
    embedding_function=embeddings,
)
# ==================== File state management ====================
FILE_STATE_PATH = "file_state.json"  # maps file path -> last-modified timestamp
def _load_file_state() -> Dict[str, float]:
    """Load the file-state index mapping file path -> last-modified mtime.

    Returns an empty dict when the state file is missing or unreadable
    (e.g. truncated/corrupted JSON), so a bad state file triggers a full
    re-index on the next run instead of crashing the updater.
    """
    try:
        with open(FILE_STATE_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}
    except (json.JSONDecodeError, OSError) as e:
        # Previously any corrupt state file raised and aborted the update.
        print(f"读取文件状态失败,将重建索引: {e}")
        return {}
def _save_file_state(state: Dict[str, float]) -> None:
    """Persist the file-state index (path -> mtime) to FILE_STATE_PATH."""
    payload = json.dumps(state, indent=2, ensure_ascii=False)
    Path(FILE_STATE_PATH).write_text(payload, encoding="utf-8")
def _get_file_mtime(file_path: str) -> float:
"""获取文件的最后修改时间戳"""
return os.path.getmtime(file_path)
# ==================== Multi-source configuration ====================
SOURCES_FILE = "knowledge_sources.json5"  # source config; JSON5 so comments are allowed
def _load_sources() -> List[Dict[str, Any]]:
    """Read the knowledge-source configuration file (JSON5 format).

    Returns the "sources" list from SOURCES_FILE, or an empty list when
    the file is absent or cannot be parsed.
    """
    if not os.path.exists(SOURCES_FILE):
        print(f"警告: 未找到 {SOURCES_FILE},将使用旧版单一目录模式")
        return []
    try:
        # json5 tolerates comments and trailing commas in the config file.
        with open(SOURCES_FILE, "r", encoding="utf-8") as fh:
            parsed = json5.load(fh)
        return parsed.get("sources", [])
    except Exception as e:
        print(f"读取配置文件失败: {e}")
        return []
def _should_include_file(file_path: str, source_config: Dict[str, Any]) -> bool:
"""
判断文件是否应被纳入知识库
规则:
- 如果 source_config 包含 include_extensions,则文件扩展名必须在其中
- 如果 source_config 包含 exclude_dirs,则文件路径中的任何部分不能匹配这些目录名
- 如果 source_config 包含 exclude_files,则文件名不能匹配任意通配符模式
"""
# 扩展名检查
ext = os.path.splitext(file_path)[1].lower()
allowed_exts = source_config.get("include_extensions", [])
if allowed_exts and ext not in allowed_exts:
return False
# 排除目录检查(精确匹配路径中的目录名)
exclude_dirs = source_config.get("exclude_dirs", [])
path_parts = Path(file_path).parts
for ex_dir in exclude_dirs:
if ex_dir in path_parts:
return False
# 排除文件名模式检查(支持通配符)
exclude_files = source_config.get("exclude_files", [])
base_name = os.path.basename(file_path)
for pattern in exclude_files:
if fnmatch.fnmatch(base_name, pattern):
return False
return True
def _load_documents_from_file(file_path: str) -> List[Document]:
    """Load documents from a local file path or an HTTP(S) URL.

    URLs are fetched with newspaper3k and wrapped in a single Document;
    local files go through PyPDFLoader (.pdf) or TextLoader (everything
    else). Returns [] on any fetch/load failure instead of raising.
    """
    if file_path.startswith(("http://", "https://")):
        try:
            # Scraper settings: browser-like UA plus a timeout so a dead
            # site cannot hang the whole update run.  Named distinctly so
            # it does not shadow the module-level `config` dict.
            crawler_config = Config()
            crawler_config.request_timeout = 10
            crawler_config.browser_user_agent = (
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
            )
            article = Article(file_path, config=crawler_config)
            article.download()
            article.parse()
            doc = Document(
                page_content=article.text,
                metadata={
                    "source": file_path,
                    "title": article.title,
                    "type": "web_page",
                },
            )
            return [doc]
        except Exception as e:
            print(f"抓取网页 {file_path} 失败: {e}")
            return []
    # Local file: choose a loader by extension.
    ext = os.path.splitext(file_path)[1].lower()
    try:
        if ext == ".pdf":
            loader = PyPDFLoader(file_path)
        else:
            loader = TextLoader(file_path, encoding="utf-8")
        docs = loader.load()
        # Normalise "source" metadata so vector-store deletes keyed on the
        # original path keep working (PyPDFLoader may set its own source).
        for doc in docs:
            doc.metadata["source"] = file_path
        return docs
    except Exception as e:
        print(f"加载文件 {file_path} 失败: {e}")
        return []
# ==================== Knowledge-base update (incremental) ====================
def _persist_vectorstore() -> None:
    """Flush the vector store to disk when the Chroma wrapper supports it.

    langchain_chroma's Chroma persists automatically and has no persist();
    legacy langchain Chroma requires an explicit call.  The hasattr guard
    keeps both working — previously the AttributeError raised here was
    swallowed by a broad except that also skipped recording the file's
    mtime, so unchanged files were re-indexed on every single run.
    """
    if hasattr(vectorstore, "persist"):
        vectorstore.persist()


def _delete_source_chunks(source_key: str) -> int:
    """Remove every stored chunk whose metadata "source" equals *source_key*.

    Returns the number of chunks deleted (0 when none were present).
    Propagates store errors; callers decide how to report them.
    """
    result = vectorstore.get(where={"source": source_key})
    ids = result["ids"]
    if ids:
        vectorstore.delete(ids=ids)
        _persist_vectorstore()
    return len(ids)


def _index_source(source_key: str, splitter: "RecursiveCharacterTextSplitter") -> int:
    """Load *source_key* (path or URL), split it, and add chunks to the store.

    Returns the number of chunks added (0 when loading produced nothing).
    """
    docs = _load_documents_from_file(source_key)
    if not docs:
        return 0
    chunks = splitter.split_documents(docs)
    if chunks:
        vectorstore.add_documents(chunks)
        _persist_vectorstore()
    return len(chunks)


def _update_directory_source(src, splitter, file_state, new_file_state, seen):
    """Walk one "directory" source and (re)index new or modified files."""
    root_dir = src["path"]
    if not os.path.exists(root_dir):
        print(f" 路径不存在,跳过: {root_dir}")
        return
    exclude_dirs = src.get("exclude_dirs", [])
    for root, dirs, files in os.walk(root_dir):
        # Prune excluded directories in place so os.walk never descends.
        dirs[:] = [d for d in dirs if d not in exclude_dirs]
        for file_name in files:
            file_path = os.path.join(root, file_name)
            if not _should_include_file(file_path, src):
                continue
            try:
                current_mtime = os.path.getmtime(file_path)
                seen.add(file_path)
                # New file, or mtime changed since the last run -> re-index.
                if file_state.get(file_path) != current_mtime:
                    print(f" 更新: {file_path}")
                    # Drop stale chunks first so data keyed on this path
                    # stays unique in the store.
                    try:
                        _delete_source_chunks(file_path)
                    except Exception as e:
                        print(f" 删除旧数据失败 (可能首次加载): {e}")
                    added = _index_source(file_path, splitter)
                    if added:
                        print(f" 已添加 {added} 个块")
                # Record the mtime whether or not the file changed.
                new_file_state[file_path] = current_mtime
            except Exception as e:
                print(f"处理文件失败 {file_path}: {e}")


def _update_url_source(src, splitter, new_file_state, seen):
    """Re-fetch and re-index every URL of one "urls" source.

    URLs carry no filesystem mtime, so each run unconditionally replaces
    the stored chunks (could be refined with ETag / Last-Modified).
    """
    for url in src.get("urls", []):
        print(f" 处理网页: {url}")
        seen.add(url)
        try:
            _delete_source_chunks(url)
        except Exception as e:
            print(f" 清理旧网页数据失败: {e}")
        added = _index_source(url, splitter)
        if added:
            print(f" 添加了 {added} 个文档块")
        # Record the fetch time so the URL survives the deletion sweep.
        new_file_state[url] = time.time()


def update_knowledge_base():
    """Incrementally sync the vector store with knowledge_sources.json5.

    Detects added/modified files (by mtime) under each configured directory
    source, re-fetches every configured URL, and removes chunks belonging
    to files that disappeared since the previous run.  Per-file state is
    stored via _save_file_state / _load_file_state.
    """
    sources = _load_sources()
    if not sources:
        print("未找到任何知识库源,请创建 knowledge_sources.json5")
        return

    file_state = _load_file_state()  # previous run: path -> mtime
    new_file_state = {}              # state this run will persist
    seen = set()                     # every path/URL processed this run

    # One splitter for the whole run — its settings do not vary per file.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=config["knowledge_base"]["chunk_size"],
        chunk_overlap=config["knowledge_base"]["chunk_overlap"],
        separators=["\n\n", "\n", " ", ""],
    )

    for src in sources:
        print(f"处理源: {src['name']} ({src['type']})")
        if src["type"] == "directory":
            _update_directory_source(src, splitter, file_state, new_file_state, seen)
        elif src["type"] == "urls":
            _update_url_source(src, splitter, new_file_state, seen)

    # Anything indexed last run but never seen this run was deleted on disk.
    for old_path in file_state:
        if old_path in seen:
            continue
        print(f"检测到文件已删除: {old_path}")
        try:
            removed = _delete_source_chunks(old_path)
            if removed:
                print(f" 已从向量库中移除 {removed} 个块")
        except Exception as e:
            print(f" 删除向量块失败: {e}")
        # Intentionally not copied into new_file_state -> dropped from state.

    _save_file_state(new_file_state)
    print("知识库增量更新完成")
# ==================== Legacy functions (kept for compatibility) ====================
def load_knowledge_base(reload: bool = False):
    """[Legacy] Load every document under the configured root dir into the store.

    With reload=True the existing collection is cleared first and rebuilt;
    otherwise documents are appended.  Kept only for backward compatibility —
    prefer update_knowledge_base() for incremental, multi-source updates.
    """
    root_dir = config["knowledge_base"]["root_dir"]
    if not os.path.exists(root_dir):
        print(f"知识库目录不存在: {root_dir}")
        return

    # Supported document extensions; PDFs need their own loader class.
    extensions = [".txt", ".md", ".pdf", ".py", ".js", ".c", ".cpp", ".h", ".hpp"]
    all_docs = []
    for ext in extensions:
        loader = DirectoryLoader(
            root_dir,
            glob=f"**/*{ext}",
            loader_cls=PyPDFLoader if ext == ".pdf" else TextLoader,
            loader_kwargs={} if ext == ".pdf" else {"encoding": "utf-8"},
            recursive=True,
            show_progress=True,
        )
        try:
            all_docs.extend(loader.load())
        except Exception as e:
            print(f"加载 {loader} 时出错: {e}")

    if not all_docs:
        print("未找到任何文档")
        return

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=config["knowledge_base"]["chunk_size"],
        chunk_overlap=config["knowledge_base"]["chunk_overlap"],
        separators=["\n\n", "\n", " ", ""],
    )
    chunks = text_splitter.split_documents(all_docs)

    if reload:
        # langchain_chroma: reset_collection() clears AND recreates the
        # collection so the following add_documents() keeps working;
        # delete_collection() alone leaves the store handle pointing at a
        # dead collection and the subsequent add fails.
        if hasattr(vectorstore, "reset_collection"):
            vectorstore.reset_collection()
        else:
            vectorstore.delete_collection()
        if hasattr(vectorstore, "persist"):
            vectorstore.persist()

    vectorstore.add_documents(chunks)
    # Legacy langchain Chroma needs an explicit persist; langchain_chroma
    # persists automatically and removed the method entirely.
    if hasattr(vectorstore, "persist"):
        vectorstore.persist()
    print(f"成功加载 {len(chunks)} 个文档块")
def retrieve(query: str, top_k: int = 5) -> List[Document]:
    """Return the *top_k* document chunks most similar to *query*."""
    return vectorstore.similarity_search(query, k=top_k)
# ==================== Test entry point ====================
if __name__ == "__main__":
    # Smoke-test the embedding model before touching the vector store.
    print("测试嵌入模型...")
    emb = embeddings.embed_query("Hello, world!")
    print(f"嵌入维度: {len(emb)}")

    # Incremental update (requires knowledge_sources.json5 to exist).
    print("\n开始增量更新知识库...")
    update_knowledge_base()

    print("\n测试检索...")
    for i, doc in enumerate(retrieve("代码审查", top_k=2)):
        print(f"\n--- 结果 {i+1} ---")
        print(doc.page_content[:200])