构建下一代智能问答系统:从检索-生成融合到主动学习闭环
摘要
传统问答系统多采用检索式或生成式单一架构,存在信息陈旧、语境理解不足等固有局限。本文将深入探讨基于检索-生成混合架构的现代问答系统核心组件设计,重点剖析多粒度检索增强、可控文本生成、离线评估框架与主动学习反馈循环四个关键技术模块,并提供可在生产环境扩展的Python实现范式。
1. 系统架构演进:从管道式到协同式
1.1 传统架构的瓶颈
传统问答系统通常采用串行管道:查询解析 → 文档检索 → 答案抽取/生成。这种架构的瓶颈在于:
- 误差传播:上游模块的错误会逐级放大
- 上下文碎片化:检索与生成模块间缺乏深层信息交换
- 静态知识局限:无法有效利用对话历史中的隐式反馈
1.2 协同增强架构设计
我们提出一种协同增强架构,核心思想是让检索与生成模块进行多轮交互:
Query → [语义理解层] → ┌─────────────────┬─────────────────┐
                        ↓                 ↓                 ↓
                   [密集检索]         [稀疏检索]         [实体链接]
                        │                 │                 │
                        └─────→ [多证据融合] ←─────────────┘
                                     ↓
                   [生成控制器] ←─ [用户反馈日志]
                                     ↓
                   [条件生成器] → [置信度校准] → Answer

2. 多粒度检索增强模块
2.1 混合检索策略的实现
import numpy as np
from typing import List, Tuple, Dict

import torch
from sentence_transformers import SentenceTransformer
from rank_bm25 import BM25Okapi
import faiss


class HybridRetriever:
    """Hybrid retriever combining dense, sparse and entity-graph search.

    - Dense retrieval: Sentence-BERT embeddings + FAISS inner-product index
    - Sparse retrieval: BM25 with query expansion
    - Graph retrieval: entity-relation path retrieval (optional)
    """

    def __init__(self, dense_model_name: str = 'all-mpnet-base-v2'):
        self.dense_model = SentenceTransformer(dense_model_name)
        self.bm25_index = None
        self.faiss_index = None
        self.corpus_embeddings = None

    def build_hybrid_index(self, corpus: List[str],
                           entities: List[List[str]] = None) -> None:
        """Build the dense, sparse and (optional) entity-graph indexes."""
        # 1. Dense vector index: L2-normalized embeddings + inner-product
        #    index is equivalent to cosine-similarity search.
        corpus_embeddings = self.dense_model.encode(corpus, show_progress_bar=True)
        dimension = corpus_embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatIP(dimension)
        faiss.normalize_L2(corpus_embeddings)
        self.faiss_index.add(corpus_embeddings)
        self.corpus_embeddings = corpus_embeddings

        # 2. Sparse index (BM25). NOTE(review): str.split() is a poor
        #    tokenizer for Chinese text; use a real tokenizer in production.
        tokenized_corpus = [doc.split() for doc in corpus]
        self.bm25_index = BM25Okapi(tokenized_corpus)

        # 3. Entity-graph index (simplified example).
        # NOTE(review): _build_entity_graph is not defined in this listing.
        self.entity_graph = self._build_entity_graph(entities) if entities else None

    def retrieve(self, query: str, top_k: int = 10,
                 fusion_method: str = 'reciprocal_rank_fusion'
                 ) -> List[Tuple[int, float]]:
        """Hybrid retrieval returning (corpus_index, fused_score) pairs.

        Supported fusion strategies: 'reciprocal_rank_fusion' (RRF) and
        'weighted_score'; any other value raises ValueError.
        """
        # Dense retrieval (over-fetch 3x candidates for better fusion recall).
        query_embedding = self.dense_model.encode([query])
        faiss.normalize_L2(query_embedding)
        dense_scores, dense_indices = self.faiss_index.search(query_embedding,
                                                             top_k * 3)

        # Sparse retrieval.
        tokenized_query = query.split()
        bm25_scores = self.bm25_index.get_scores(tokenized_query)
        bm25_indices = np.argsort(bm25_scores)[::-1][:top_k * 3]

        # Fusion.
        if fusion_method == 'reciprocal_rank_fusion':
            return self._rrf_fusion(dense_indices[0], bm25_indices, top_k)
        elif fusion_method == 'weighted_score':
            # NOTE(review): _weighted_fusion is not defined in this listing.
            return self._weighted_fusion(dense_scores[0], bm25_scores,
                                         dense_indices[0], bm25_indices, top_k)
        # FIX: the original silently returned None for unknown strategies.
        raise ValueError(f"unsupported fusion_method: {fusion_method!r}")

    def _rrf_fusion(self, dense_indices, bm25_indices, top_k, k=60):
        """Reciprocal rank fusion: score(d) = sum over lists of 1/(rank + k)."""
        scores = {}
        for rank, idx in enumerate(dense_indices):
            scores[idx] = scores.get(idx, 0) + 1 / (rank + k)
        for rank, idx in enumerate(bm25_indices):
            scores[idx] = scores.get(idx, 0) + 1 / (rank + k)
        sorted_items = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        return sorted_items[:top_k]


# --- 2.2 Query understanding and expansion (2.2 查询理解与扩展) ---
现代检索系统需要理解查询的深层意图:
class QueryUnderstandingModule:
    """Multi-dimensional query parsing: entities, intent, aspects, rewrites."""

    def __init__(self, ner_model, keyword_extractor):
        self.ner = ner_model
        self.keyword_extractor = keyword_extractor

    def parse_query(self, query: str) -> Dict:
        """Parse a query along several dimensions and attach rewritten variants."""
        analysis = {
            'entities': self._extract_entities(query),
            'intent': self._classify_intent(query),
            'aspects': self._extract_aspects(query),
            'temporal_constraints': self._extract_time(query),
            # factoid, how-to, comparison
            'query_type': self._determine_query_type(query),
        }
        # NOTE(review): the _extract_* / _classify_* / _determine_* helpers
        # are not defined in this listing.
        analysis['rewritten_queries'] = self._query_rewriting(query, analysis)
        return analysis

    def _query_rewriting(self, original_query: str, analysis: Dict) -> List[str]:
        """Analysis-driven query rewriting strategies."""
        rewritten = [original_query]

        # 1. Entity-disambiguation expansion: substitute up to two synonyms
        #    per named entity.
        for entity in analysis['entities']:
            if entity['type'] in ['PERSON', 'LOCATION', 'ORGANIZATION']:
                synonyms = self._get_entity_synonyms(entity['text'])
                for syn in synonyms[:2]:
                    rewritten.append(original_query.replace(entity['text'], syn))

        # 2. Intent sharpening: add a comparison cue for comparison queries.
        if analysis['intent'] == 'comparison':
            rewritten.append(original_query + " vs")

        # 3. Temporal constraints: append the first time expression.
        if analysis['temporal_constraints']:
            time_expr = analysis['temporal_constraints'][0]
            rewritten.append(f"{original_query} {time_expr}")

        # FIX: list(set(...)) shuffled the rewrites nondeterministically;
        # dict.fromkeys de-duplicates while keeping the original query first.
        return list(dict.fromkeys(rewritten))


# --- 3. Controllable generation and knowledge integration (3. 可控生成与知识整合) ---
3.1 基于检索内容的约束生成
import torch
from typing import Dict, List, Optional


class KnowledgeAwareGenerator:
    """Constrained answer generation conditioned on retrieved contexts."""

    def __init__(self, model_name: str = 't5-base'):
        # Lazy import keeps the module importable when the heavy
        # `transformers` dependency is absent.
        from transformers import T5ForConditionalGeneration, T5Tokenizer
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.max_source_length = 512
        self.max_target_length = 150

    def generate_with_constraints(self, query: str, contexts: List[str],
                                  constraints: Optional[Dict] = None) -> str:
        """Generate an answer from retrieved contexts under soft constraints.

        Example constraints::

            {
                'must_include': ['2023年', 'AI'],
                'must_not_include': ['可能', '大概'],
                'style': 'formal',          # formal, concise, detailed
                'length_limit': 100,
            }
        """
        # Build the retrieval-augmented input.
        # NOTE(review): _construct_input is not defined in this listing.
        augmented_input = self._construct_input(query, contexts)

        # Encode the constraints as a natural-language prompt fragment.
        constraint_prompt = self._encode_constraints(constraints) if constraints else ""

        full_input = f"基于以下信息回答: {augmented_input[:400]}... {constraint_prompt} 问题: {query}"

        generation_config = {
            'max_length': constraints.get('length_limit', self.max_target_length)
                          if constraints else self.max_target_length,
            'num_beams': 4,
            # NOTE(review): temperature has no effect under beam search unless
            # do_sample=True is also passed -- confirm intended behavior.
            'temperature': 0.7,
            'no_repeat_ngram_size': 3,
            'early_stopping': True,
            'repetition_penalty': 2.0  # discourage repetition
        }

        inputs = self.tokenizer.encode(full_input, return_tensors='pt',
                                       max_length=self.max_source_length,
                                       truncation=True)
        with torch.no_grad():
            outputs = self.model.generate(inputs, **generation_config)
        answer = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Post-processing: enforce the declared constraints on the output.
        if constraints:
            answer = self._apply_constraints_postprocessing(answer, constraints)
        return answer

    def _encode_constraints(self, constraints: Dict) -> str:
        """Encode constraints as a natural-language prompt fragment."""
        prompt_parts = []
        if 'must_include' in constraints:
            items = ', '.join(constraints['must_include'])
            prompt_parts.append(f"必须包含: {items}")
        if 'must_not_include' in constraints:
            items = ', '.join(constraints['must_not_include'])
            prompt_parts.append(f"不要包含: {items}")
        if 'style' in constraints:
            style_map = {
                'formal': '使用正式的语言风格',
                'concise': '简洁明了地回答',
                'detailed': '提供详细的解释'
            }
            prompt_parts.append(style_map.get(constraints['style'], ''))
        return "。".join(prompt_parts)

    def _apply_constraints_postprocessing(self, text: str, constraints: Dict) -> str:
        """Best-effort post-hoc enforcement of include/exclude constraints."""
        # Ensure required items appear in the text.
        if 'must_include' in constraints:
            for item in constraints['must_include']:
                if item not in text:
                    sentences = text.split('。')
                    if len(sentences) > 1:
                        # Insert as the second "sentence".
                        sentences.insert(1, item)
                        text = '。'.join(sentences)
                    else:
                        # FIX: single-sentence answers previously dropped the
                        # required item silently; append it instead.
                        text = f"{text}。{item}"
        # Strip forbidden substrings.
        if 'must_not_include' in constraints:
            for forbidden in constraints['must_not_include']:
                text = text.replace(forbidden, '')
        return text


# --- 3.2 Confidence estimation and uncertainty quantification (3.2 可信度评估与不确定性量化) ---
class ConfidenceEstimator:
    """Multi-dimensional confidence estimation for generated answers."""

    # FIX: the original constructed a SentenceTransformer on every call,
    # which is prohibitively slow; encoders are now loaded once per process.
    _encoder_cache = {}  # model_name -> SentenceTransformer

    @classmethod
    def _get_encoder(cls, model_name: str = 'all-MiniLM-L6-v2'):
        """Load (and memoize) a sentence encoder."""
        if model_name not in cls._encoder_cache:
            cls._encoder_cache[model_name] = SentenceTransformer(model_name)
        return cls._encoder_cache[model_name]

    @staticmethod
    def _cos_sim(a, b) -> float:
        """Cosine similarity between two (1, d) embedding batches.

        FIX: the original called an undefined ``cosine_similarity``
        (presumably sklearn's); this numpy version removes that hidden
        dependency while producing the same scalar.
        """
        u, v = a[0], b[0]
        return float(np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v)))

    def estimate_confidence(self, query: str, answer: str,
                            source_contexts: List[str],
                            generator_logits: Optional[torch.Tensor] = None
                            ) -> Dict[str, float]:
        """Return per-dimension confidence scores plus a weighted overall score."""
        scores = {}

        # 1. Semantic consistency between query, answer and contexts.
        scores['semantic_consistency'] = self._calc_semantic_consistency(
            query, answer, source_contexts)

        # 2. Generator confidence derived from output probabilities.
        if generator_logits is not None:
            # NOTE(review): _calc_generation_confidence is not defined here.
            scores['generation_confidence'] = self._calc_generation_confidence(
                generator_logits)

        # 3. Factual consistency against the retrieved evidence.
        # NOTE(review): _calc_factual_consistency is not defined here.
        scores['factual_consistency'] = self._calc_factual_consistency(
            answer, source_contexts)

        # 4. Self-consistency via repeated sampling.
        scores['self_consistency'] = self._self_consistency_check(query, answer)

        # Weighted aggregate; dimensions absent from `scores` contribute nothing.
        weights = {
            'semantic_consistency': 0.3,
            'generation_confidence': 0.2,
            'factual_consistency': 0.4,
            'self_consistency': 0.1
        }
        scores['overall_confidence'] = sum(
            scores[k] * weights.get(k, 0) for k in scores if k in weights)
        return scores

    def _calc_semantic_consistency(self, query: str, answer: str,
                                   contexts: List[str]) -> float:
        """Blend answer-query and answer-context embedding similarity."""
        encoder = self._get_encoder()
        query_emb = encoder.encode([query])
        answer_emb = encoder.encode([answer])
        # Only the first three context passages are considered.
        context_emb = encoder.encode([" ".join(contexts[:3])])

        sim_query_answer = self._cos_sim(query_emb, answer_emb)
        sim_answer_context = self._cos_sim(answer_emb, context_emb)
        # Weighted combination.
        return 0.6 * sim_query_answer + 0.4 * sim_answer_context

    def _self_consistency_check(self, query: str, answer: str,
                                n_samples: int = 3) -> float:
        """Mean similarity between the answer and re-sampled alternatives."""
        # FIX: n_samples was previously ignored; it now caps how many
        # sampling temperatures are actually used.
        alternative_answers = []
        for temp in [0.6, 0.8, 1.0][:n_samples]:
            # NOTE(review): _generate_alternative is not defined here.
            alt = self._generate_alternative(query, temperature=temp)
            alternative_answers.append(alt)

        encoder = self._get_encoder()
        orig_emb = encoder.encode([answer])
        similarities = []
        for alt in alternative_answers:
            alt_emb = encoder.encode([alt])
            similarities.append(self._cos_sim(orig_emb, alt_emb))
        return np.mean(similarities)


# --- 4. Offline evaluation framework design (4. 离线评估框架设计) ---
4.1 多维度评估指标
class QAEvaluator: """问答系统多维度评估框架""" def __init__(self): self.metrics = { 'retrieval': self.evaluate_retrieval, 'generation': self.evaluate_generation, 'end_to_end': self.evaluate_end_to_end } def comprehensive_evaluation(self, test_dataset: List[Dict], system_predictions: List[Dict]) -> Dict: """ 综合评估报告 test_dataset: [{'query': ..., 'gold_answer': ..., 'gold_contexts': [...]}] system_predictions: [{'answer': ..., 'retrieved_contexts': [...], 'confidence': ...}] """ report = {} # 1. 检索质量评估 report['retrieval_metrics'] = self._evaluate_retrieval_quality( [p['retrieved_contexts'] for p in system_predictions], [d['gold_contexts'] for d in test_dataset] ) # 2. 生成质量评估 report['generation_metrics'] = self._evaluate_generation_quality( [p['answer'] for p in system_predictions], [d['gold_answer'] for d in test_dataset] ) # 3. 端到端评估 report['end_to_end_metrics'] = self._evaluate_end_to_end( test_dataset, system_predictions ) # 4. 错误分析 report['error_analysis'] = self._analyze_errors(test_dataset, system_predictions) return report def _evaluate_retrieval_quality(self, retrieved_contexts_list, gold_contexts_list): """检索质量评估""" scores = { 'recall@k': [], 'precision@k': [], 'mrr': [], # 平均倒数排名 'ndcg@