1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
|
@Service
@Slf4j
public class KnowledgeBaseQueryService {
/** Fixed reply used whenever no usable answer can be produced. */
private static final String NO_RESULT_RESPONSE = "抱歉,在选定的知识库中未检索到相关信息。请换一个更具体的关键词或补充上下文后再试。";
/** Matches a single token of 2 to 20 letters, digits, underscores or hyphens. */
private static final Pattern SHORT_TOKEN_PATTERN = Pattern.compile("^[\\p{L}\\p{N}_-]{2,20}$");
/**
 * Chinese question prefixes; group 1 captures the core term that follows the prefix.
 * A longer alternative must be listed before any alternative that is its own prefix
 * ("什么叫做" before "什么叫"): the regex engine takes the first alternative that lets the
 * whole match succeed, so with the shorter one first the extra character would leak
 * into group 1.
 */
private static final Pattern ZH_QUESTION_PREFIX = Pattern.compile(
        "^(?:什么是|如何|怎么|怎样|为什么|什么叫做|什么叫|讲一下|解释一下|介绍一下|说一下|谈谈|描述)(.+)$");
/** Chinese question suffixes; group 1 lazily captures the core term in front of the first suffix. */
private static final Pattern ZH_QUESTION_SUFFIX = Pattern.compile(
        "^(.+?)(?:是什么|怎么样|如何|有哪些|有什么|是啥|是干什么的).*$");
/** Number of characters buffered from a streamed answer before it is passed through unchanged. */
private static final int STREAM_PROBE_CHARS = 120;
// Collaborators injected through the constructor.
private final KnowledgeBaseVectorService vectorService;
private final KnowledgeBaseListService listService;
private final ChatClient chatClient;
// Prompt templates loaded from the resource paths given in the query properties.
private final PromptTemplate sysPromptTemplate;
private final PromptTemplate userPromptTemplate;
private final PromptTemplate reWritePromptTemplate;
// Query rewrite configuration
private final boolean reWriteEnable;
private final int shortQueryLength;
// Dynamic topK configuration
private final int topKShort;
private final int topKMedium;
private final int topKLong;
// Minimum similarity scores passed to the vector search.
private final double minScoreShort;
private final double minScoreDefault;
public KnowledgeBaseQueryService(KnowledgeBaseVectorService vectorService,
KnowledgeBaseListService listService,
ChatClient.Builder builder,
KnowledgeBaseQueryProperties queryProperties,
ResourceLoader resourceLoader) {
this.vectorService = vectorService;
this.listService = listService;
this.chatClient = builder.build();
this.sysPromptTemplate = new PromptTemplate(resourceLoader.getResource(queryProperties.getSystemPromptPath()));
this.userPromptTemplate = new PromptTemplate(resourceLoader.getResource(queryProperties.getUserPromptPath()));
this.reWritePromptTemplate = new PromptTemplate(resourceLoader.getResource(queryProperties.getRewritePromptPath()));
this.reWriteEnable = queryProperties.getRewrite().isEnabled();
this.shortQueryLength = queryProperties.getSearch().getShortQueryLength();
this.topKShort = queryProperties.getSearch().getTopkShort();
this.topKMedium = queryProperties.getSearch().getTopkMedium();
this.topKLong = queryProperties.getSearch().getTopkLong();
this.minScoreShort = queryProperties.getSearch().getMinScoreShort();
this.minScoreDefault = queryProperties.getSearch().getMinScoreDefault();
}
/**
 * Answers a question against the given knowledge bases as a stream of text chunks (SSE).
 *
 * @param knowledgeBaseIds ids of the knowledge bases to search
 * @param question         the user's question
 * @return the answer stream; a single fixed no-result reply when the ids or the question
 *         are empty or retrieval has no effective hit; a single error text when anything fails
 */
public Flux<String> answerQuestionStream(List<Long> knowledgeBaseIds, String question) {
    try {
        log.info("收到知识库流式提问: kbIds={}, question={}", knowledgeBaseIds, question);

        boolean noKnowledgeBase = knowledgeBaseIds == null || knowledgeBaseIds.isEmpty();
        if (noKnowledgeBase || normalizeQuestion(question).isBlank()) {
            return Flux.just(NO_RESULT_RESPONSE);
        }

        // 1. Validate the knowledge bases and bump their question counters
        listService.updateQuestionCounts(knowledgeBaseIds);

        // 2. Query rewrite + retrieval with dynamic parameters
        QueryContext queryContext = buildQueryContext(question);
        List<Document> hits = retrieveRelevantDocs(queryContext, knowledgeBaseIds);
        boolean effectiveHit = hasEffectiveHit(question, hits);
        if (!effectiveHit) {
            return Flux.just(NO_RESULT_RESPONSE);
        }

        // 3. Join the retrieved fragments into one context string
        List<String> fragments = hits.stream().map(Document::getText).toList();
        String context = String.join("\n\n---\n\n", fragments);
        log.debug("检索到 {} 个相关文档片段", hits.size());

        // 4. Render the prompts
        String systemPrompt = buildSystemPrompt();
        String primaryQuery = queryContext.candidateQueries().getFirst();
        String userPrompt = buildUserPrompt(context, primaryQuery);

        // 5. Stream the model output through the probe-window normalizer
        Flux<String> rawAnswer = chatClient.prompt()
                .system(systemPrompt)
                .user(userPrompt)
                .stream()
                .content();
        log.info("开始流式输出知识库回答(探测窗口): kbIds={}", knowledgeBaseIds);

        Flux<String> normalized = normalizeStreamOutput(rawAnswer);
        return normalized
                .doOnComplete(() -> log.info("流式输出完成: kbIds={}", knowledgeBaseIds))
                .onErrorResume(failure -> {
                    log.error("流式输出失败: kbIds={}, error={}", knowledgeBaseIds, failure.getMessage(), failure);
                    return Flux.just("【错误】知识库查询失败:AI服务暂时不可用,请稍后重试。");
                });
    } catch (Exception e) {
        log.error("知识库流式问答失败: {}", e.getMessage(), e);
        return Flux.just("【错误】知识库查询失败:" + e.getMessage());
    }
}
/**
 * Answers a question against the given knowledge bases in a single, non-streaming call.
 *
 * @param knowledgeIds ids of the knowledge bases to search
 * @param question     the user's question
 * @return the normalized model answer, or the fixed no-result reply when the ids or the
 *         question are empty or retrieval has no effective hit
 * @throws BusinessException when the model call fails
 */
public String answerQuestion(List<Long> knowledgeIds, String question) {
    log.info("收到知识库提问: kbIds={}, question={}", knowledgeIds, question);

    boolean noKnowledgeBase = knowledgeIds == null || knowledgeIds.isEmpty();
    if (noKnowledgeBase || normalizeQuestion(question).isBlank()) {
        return NO_RESULT_RESPONSE;
    }

    // 1. Validate the knowledge bases and bump their question counters
    listService.updateQuestionCounts(knowledgeIds);

    // 2. Query rewrite + dynamic search parameters
    QueryContext queryContext = buildQueryContext(question);

    // 3. Vector retrieval
    List<Document> hits = retrieveRelevantDocs(queryContext, knowledgeIds);
    boolean effectiveHit = hasEffectiveHit(question, hits);
    if (!effectiveHit) {
        return NO_RESULT_RESPONSE;
    }

    // 4. Join the retrieved fragments into one context string
    List<String> fragments = hits.stream().map(Document::getText).toList();
    String context = String.join("\n\n---\n\n", fragments);
    log.debug("检索到 {} 个相关文档片段", hits.size());

    // 5. Render the prompts
    String systemPrompt = buildSystemPrompt();
    String primaryQuery = queryContext.candidateQueries().getFirst();
    String userPrompt = buildUserPrompt(context, primaryQuery);

    // 6. Call the model and normalize its reply
    try {
        return normalizeAnswer(
                chatClient.prompt()
                        .system(systemPrompt)
                        .user(userPrompt)
                        .call()
                        .content());
    } catch (Exception e) {
        log.error("知识库问答失败: {}", e.getMessage(), e);
        throw new BusinessException(ErrorCode.KNOWLEDGE_BASE_QUERY_FAILED, "知识库查询失败:" + e.getMessage());
    }
}
/**
 * Normalizes a streamed answer.
 * - Buffers incoming chunks and checks the accumulated text for "no information"
 *   phrases after every chunk, until the buffer holds at least STREAM_PROBE_CHARS characters.
 * - On a match: emits the fixed NO_RESULT_RESPONSE, completes, and disposes the upstream subscription.
 * - No match once the buffer is full: flushes the buffered text and forwards all later chunks unchanged.
 * - Upstream completes before the buffer is full: emits the buffered text once via normalizeAnswer.
 *
 * @param rawFlux the raw chunk stream to normalize
 * @return the normalized chunk stream
 */
private Flux<String> normalizeStreamOutput(Flux<String> rawFlux) {
return Flux.create(sink -> {
// Text collected while still probing.
StringBuilder probeBuffer = new StringBuilder();
// true once the probe window passed without a match; chunks are then forwarded as-is.
AtomicBoolean passthrough = new AtomicBoolean(false);
// true once the fixed no-result reply has been emitted and the sink completed.
AtomicBoolean completed = new AtomicBoolean(false);
// One-element array so the lambdas below can reach the upstream subscription handle.
final Disposable[] disposableRef = new Disposable[1];
disposableRef[0] = rawFlux.subscribe(
chunk -> {
// Already finished here, or downstream cancelled: drop the chunk.
if (completed.get() || sink.isCancelled()) {
return;
}
if (passthrough.get()) {
sink.next(chunk);
return;
}
// Still probing: accumulate and re-check the whole buffer.
probeBuffer.append(chunk);
String probeText = probeBuffer.toString();
if (isNoResultLike(probeText)) {
completed.set(true);
sink.next(NO_RESULT_RESPONSE);
sink.complete();
// NOTE(review): disposableRef[0] is assigned only after subscribe(...) returns, so it is
// still null if this callback runs during that call; the upstream is then not disposed
// here and any later chunks are dropped by the completed check above.
if (disposableRef[0] != null) {
disposableRef[0].dispose();
}
return;
}
// Probe window filled without a match: flush the buffered text untrimmed and stop inspecting.
if (probeBuffer.length() >= STREAM_PROBE_CHARS) {
passthrough.set(true);
sink.next(probeText);
probeBuffer.setLength(0);
}
},
// Upstream errors are forwarded to the sink unchanged.
sink::error,
() -> {
if (completed.get() || sink.isCancelled()) {
return;
}
// Upstream ended inside the probe window: emit the buffered text once, normalized
// (a blank or no-result-like buffer becomes the fixed reply).
if (!passthrough.get()) {
sink.next(normalizeAnswer(probeBuffer.toString()));
}
sink.complete();
}
);
// Downstream cancellation stops the upstream subscription.
sink.onCancel(() -> {
if (disposableRef[0] != null) {
disposableRef[0].dispose();
}
});
});
}
/**
 * Normalizes a complete model answer.
 *
 * @param answer raw answer text, may be {@code null}
 * @return the trimmed answer, or the fixed no-result reply when the answer is
 *         {@code null}, blank, or contains a "no information" phrase
 */
private String normalizeAnswer(String answer) {
    boolean missing = answer == null || answer.isBlank();
    if (missing) {
        return NO_RESULT_RESPONSE;
    }
    String trimmed = answer.trim();
    return isNoResultLike(trimmed) ? NO_RESULT_RESPONSE : trimmed;
}
/**
 * Tells whether the text contains one of the known "no information" phrases.
 *
 * @param text text to inspect, must not be {@code null}
 * @return {@code true} if any phrase occurs anywhere in the text
 */
private boolean isNoResultLike(String text) {
    String[] noResultPhrases = {
        "没有找到相关信息",
        "未检索到相关信息",
        "信息不足",
        "超出知识库范围",
        "无法根据提供内容回答"
    };
    for (String phrase : noResultPhrases) {
        if (text.contains(phrase)) {
            return true;
        }
    }
    return false;
}
/**
 * Renders the user prompt template.
 *
 * @param context  text bound to the template variable {@code context}
 * @param question text bound to the template variable {@code question}
 * @return the rendered user prompt
 */
private String buildUserPrompt(String context, String question) {
    return userPromptTemplate.render(Map.of("context", context, "question", question));
}

/**
 * Renders the system prompt template; it is rendered without variables.
 *
 * @return the rendered system prompt
 */
private String buildSystemPrompt() {
    return sysPromptTemplate.render();
}
/**
 * Runs the vector search for each candidate query in order and returns the documents
 * of the first candidate whose result is an effective hit.
 *
 * @param queryContext candidate queries plus the search parameters to use
 * @param knowledgeIds ids of the knowledge bases to search
 * @return the documents of the first effective candidate, or an empty list if none qualifies
 */
private List<Document> retrieveRelevantDocs(QueryContext queryContext, List<Long> knowledgeIds) {
    SearchParams params = queryContext.searchParams();
    for (String candidate : queryContext.candidateQueries()) {
        if (!candidate.isBlank()) {
            List<Document> found = vectorService.similaritySearch(
                    candidate, knowledgeIds, params.topK(), params.minScore());
            log.info("检索候选 query='{}',命中 {} 条", candidate, found.size());
            // Accept the result only if it contains the core term of this candidate
            boolean effective = hasEffectiveHit(candidate, found);
            if (effective) {
                return found;
            }
        }
    }
    return List.of();
}
/**
 * Decides whether the retrieved documents effectively hit the query.
 * The core term is extracted from the (Chinese) question and must occur literally,
 * ignoring case, in the text of at least one document.
 *
 * @param candidateQuery the query whose core term is looked for
 * @param docs           the retrieved documents
 * @return {@code true} if some document text contains the core term; {@code false} when the
 *         query or the document list is empty, or when no core term can be derived
 */
private boolean hasEffectiveHit(String candidateQuery, List<Document> docs) {
    if (Func.isEmpty(candidateQuery) || Func.isEmpty(docs)) {
        return false;
    }
    String question = normalizeQuestion(candidateQuery);
    // Locale.ROOT keeps the case folding independent of the JVM's default locale.
    String coreTerm = extractCoreTerm(question).toLowerCase(java.util.Locale.ROOT);
    if (coreTerm.isBlank()) {
        // contains("") is true for every string, so an empty term would accept any document.
        return false;
    }
    for (Document doc : docs) {
        String text = doc.getText();
        if (!Func.isEmpty(text) && text.toLowerCase(java.util.Locale.ROOT).contains(coreTerm)) {
            return true;
        }
    }
    return false;
}
/**
 * Extracts the core term of a Chinese question by trying the prefix pattern first and
 * the suffix pattern second; the whole question is returned when neither matches.
 * Examples: "什么是进程" → "进程", "进程是什么" → "进程".
 *
 * @param question the normalized question
 * @return the trimmed first capture group of the first matching pattern, or the question itself
 */
private String extractCoreTerm(String question) {
    for (Pattern pattern : List.of(ZH_QUESTION_PREFIX, ZH_QUESTION_SUFFIX)) {
        Matcher matcher = pattern.matcher(question);
        if (matcher.matches()) {
            return matcher.group(1).trim();
        }
    }
    return question;
}
/**
 * Builds the query context: the rewritten question followed by the normalized original
 * (listed once when both are equal), plus search parameters derived from the original.
 *
 * @param question the raw user question
 * @return the query context used for retrieval
 */
private QueryContext buildQueryContext(String question) {
    String original = normalizeQuestion(question);
    String rewritten = rewriteQuestion(original);

    // Rewritten query first; the original is only a fallback when it differs.
    List<String> candidates = new ArrayList<>();
    candidates.add(rewritten);
    if (!rewritten.equals(original)) {
        candidates.add(original);
    }

    SearchParams params = resolveSearchParams(original);
    return new QueryContext(original, candidates, params);
}
/**
 * Chooses the search parameters from the question length, measured with all whitespace removed.
 * - Short query (length <= shortQueryLength): topKShort with the lowered threshold minScoreShort.
 * - Medium query (length <= 12): topKMedium with minScoreDefault.
 * - Long query: topKLong with minScoreDefault.
 *
 * @param question the normalized question
 * @return the topK / minimum-score pair to use for the vector search
 */
private SearchParams resolveSearchParams(String question) {
    int compactLength = question.replaceAll("\\s+", "").length();
    if (compactLength <= shortQueryLength) {
        // Short queries carry little signal, so they get their own (lower) score threshold.
        return new SearchParams(topKShort, minScoreShort);
    }
    if (compactLength <= 12) {
        return new SearchParams(topKMedium, minScoreDefault);
    }
    return new SearchParams(topKLong, minScoreDefault);
}
/**
 * Query rewrite: asks the chat model to rewrite the question using the rewrite prompt template.
 * This is best-effort; the input is returned unchanged when it is blank, when rewriting is
 * disabled, when the model returns nothing usable, or when the call fails.
 *
 * @param normalizeQuestion the normalized question
 * @return the trimmed rewritten question, or the input question as fallback
 */
private String rewriteQuestion(String normalizeQuestion) {
    boolean skipRewrite = normalizeQuestion.isBlank() || !reWriteEnable;
    if (skipRewrite) {
        return normalizeQuestion;
    }
    try {
        String prompt = reWritePromptTemplate.render(Map.of("question", normalizeQuestion));
        String rewritten = chatClient.prompt()
                .user(prompt)
                .call()
                .content();
        boolean usable = rewritten != null && !rewritten.isBlank();
        if (usable) {
            log.info("Query rewrite: origin='{}', rewritten='{}'", normalizeQuestion, rewritten);
            return rewritten.trim();
        }
        return normalizeQuestion;
    } catch (Exception e) {
        log.warn("Query rewrite 失败,使用原问题继续检索: {}", e.getMessage());
        return normalizeQuestion;
    }
}
/**
 * Normalizes a raw question.
 *
 * @param question raw question, may be {@code null}
 * @return the trimmed question, or an empty string for {@code null}
 */
private String normalizeQuestion(String question) {
    if (question == null) {
        return "";
    }
    return question.trim();
}
/** Vector search parameters: number of documents to fetch (topK) and the minimum score (minScore). */
private record SearchParams(int topK, double minScore) {}
/** One query's retrieval input: the original question, the candidate queries in the order they are tried, and the search parameters. */
private record QueryContext(String originalQuestion, List<String> candidateQueries, SearchParams searchParams) {}
}
|