Featured image of post Spring AI 流式输出实战:SSE 实现 RAG 智能问答

Spring AI 流式输出实战:SSE 实现 RAG 智能问答

概述

本文介绍如何在 Spring AI 中实现 SSE(Server-Sent Events)流式输出,结合 RAG 知识库问答功能,实现类似 ChatGPT 的流式对话体验。

SSE 简介

SSE 是一种服务端推送技术,允许服务器通过 HTTP 连接向客户端连续发送事件。与 WebSocket 的全双工通信不同,SSE 是单通道的,适合服务器向客户端推送数据场景。

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
┌─────────┐              ┌─────────┐              ┌─────────┐
│  客户端  │              │  服务器  │              │ 知识库  │
└────┬────┘              └────┬────┘              └────┬────┘
     │                         │                        │
     │──── GET /stream ───────▶│                        │
     │                         │                        │
     │◀─── event: chunk 1 ────│                        │
     │◀─── event: chunk 2 ────│                        │
     │◀─── event: chunk 3 ────│─────── RAG ──────────▶│
     │                         │◀────── 检索结果 ───────│
     │◀─── event: chunk N ────│                        │
     │                         │                        │
     │──── GET /sessions ──────▶│                        │

项目架构

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
┌─────────────────────────────────────────────────────────────────┐
│                        前端(浏览器)                            │
│    EventSource 接收 SSE → 实时显示 AI 回答                       │
└───────────────────────────────┬─────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│                     RagChatController                            │
│  @PostMapping(value = "/messages/stream",                        │
│          produces = MediaType.TEXT_EVENT_STREAM_VALUE)           │
│    返回 Flux<ServerSentEvent<String>>                            │
└───────────────────────────────┬─────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│                    RagChatSessionService                         │
│  prepareStreamMessage()  → 保存用户消息,创建 AI 消息占位        │
│  getStreamAnswer()       → 调用 queryService 获取流式回答        │
│  completeStreamMessage() → 流式完成后更新消息内容                 │
└───────────────────────────────┬─────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│                  KnowledgeBaseQueryService                       │
│  answerQuestionStream() → 调用 Spring AI ChatClient 流式生成     │
│  normalizeStreamOutput() → 探测窗口归一化处理                    │
└───────────────────────────────┬─────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│                    KnowledgeBaseVectorService                    │
│  similaritySearch() → 向量数据库相似度检索                       │
└─────────────────────────────────────────────────────────────────┘

核心实现

1. 流式接口设计

RagChatController 提供流式消息接口:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@Slf4j
@RestController
@RequiredArgsConstructor
@RequestMapping("/api/rag-chat")
@Tag(name = "RAG 问答", description = "基于知识库的智能问答会话")
public class RagChatController {

    private final RagChatSessionService sessionService;

    /**
     * 发送消息(流式SSE)
     *
     * 流式响应设计:
     * 1. 先同步保存用户消息和创建 AI 消息占位
     * 2. 返回流式响应
     * 3. 流式完成后通过回调更新消息
     */
    @PostMapping(value = "/sessions/{sessionId}/messages/stream",
            produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<ServerSentEvent<String>> sendMessageStream(
            @PathVariable Long sessionId,
            @Valid @RequestBody RagChatDTO.SendMessageRequest request) {
        log.info("收到 RAG 聊天流式请求: sessionId={}, question={}",
                sessionId, request.question());

        // 1. 准备消息(保存用户消息,创建 AI 消息占位)
        Long messageId = sessionService.prepareStreamMessage(sessionId, request.question());

        // 2. 获取流式响应
        StringBuilder fullContent = new StringBuilder();

        return sessionService.getStreamAnswer(sessionId, request.question())
                // 累积完整内容
                .doOnNext(fullContent::append)
                // 转换为 ServerSentEvent 格式
                .map(chunk -> ServerSentEvent.<String>builder()
                        .data(chunk.replace("\n", "\\n").replace("\r", "\\r"))
                        .build()
                )
                // 流式完成后的回调
                .doOnComplete(() -> {
                    sessionService.completeStreamMessage(messageId, fullContent.toString());
                    log.info("RAG 聊天流式完成: sessionId={}, messageId={}", sessionId, messageId);
                })
                // 错误处理
                .doOnError(e -> {
                    String content = !fullContent.isEmpty()
                            ? fullContent.toString()
                            : "【错误】回答生成失败:" + e.getMessage();
                    sessionService.completeStreamMessage(messageId, content);
                    log.error("RAG 聊天流式错误: sessionId={}", sessionId, e);
                });
    }
}

2. 会话服务

RagChatSessionService 管理会话状态和消息:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
@Service
@Slf4j
@RequiredArgsConstructor
public class RagChatSessionService {

    private final KnowledgeRepository knowledgeRepository;
    private final RagChatSessionRepository sessionRepository;
    private final RagChatMapper chatMapper;
    private final RagChatMessageRepository chatMessageRepository;
    private final KnowledgeBaseMapper knowledgeBaseMapper;
    private final KnowledgeBaseQueryService queryService;

    /**
     * 准备流式消息(保存用户消息,创建 AI 消息占位)
     *
     * @return AI 消息的 ID(用于后续更新)
     */
    @Transactional
    public Long prepareStreamMessage(Long sessionId, @NotBlank(message = "问题不能为空") String question) {
        RagChatSessionEntity session = sessionRepository.findByIdWithKnowledgeBases(sessionId)
                .orElseThrow(() -> new BusinessException(ErrorCode.NOT_FOUND, "会话不存在"));

        Integer nextOrder = session.getMessageCount();

        // 保存用户消息
        RagChatMessageEntity userMessage = new RagChatMessageEntity();
        userMessage.setSession(session);
        userMessage.setType(RagChatMessageEntity.MessageType.USER);
        userMessage.setContent(question);
        userMessage.setMessageOrder(nextOrder);
        userMessage.setCompleted(true);
        chatMessageRepository.save(userMessage);

        // 保存系统占位消息
        RagChatMessageEntity assistantMessage = new RagChatMessageEntity();
        assistantMessage.setSession(session);
        assistantMessage.setType(RagChatMessageEntity.MessageType.ASSISTANT);
        assistantMessage.setContent("");
        assistantMessage.setMessageOrder(++nextOrder);
        assistantMessage.setCompleted(false);
        chatMessageRepository.save(assistantMessage);

        // 更新会话消息数量
        session.setMessageCount(nextOrder + 1);
        sessionRepository.save(session);

        log.info("准备流式消息: sessionId={}, messageId={}", sessionId, assistantMessage.getId());
        return assistantMessage.getId();
    }

    /**
     * 获取流式回答
     */
    public Flux<String> getStreamAnswer(Long sessionId, String question) {
        RagChatSessionEntity session = sessionRepository.findByIdWithKnowledgeBases(sessionId)
                .orElseThrow(() -> new BusinessException(ErrorCode.NOT_FOUND, "会话不存在"));

        List<Long> kbIds = session.getKnowledgeBaseIds();
        return queryService.answerQuestionStream(kbIds, question);
    }

    /**
     * 完成流式消息(更新消息内容)
     */
    public void completeStreamMessage(Long messageId, String answer) {
        RagChatMessageEntity message = chatMessageRepository.findById(messageId)
                .orElseThrow(() -> new BusinessException(ErrorCode.NOT_FOUND, "消息不存在"));
        message.setCompleted(true);
        message.setContent(answer);
        chatMessageRepository.save(message);
        log.info("完成流式消息: messageId={}, contentLength={}", messageId, answer.length());
    }
}

3. 流式输出归一化处理

KnowledgeBaseQueryService 中的流式处理逻辑:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
/**
 * 流式输出归一化处理
 *
 * 设计目标:
 * - 探测前 120 字符,快速识别"无信息"模板
 * - 命中无信息:立即输出固定模板并结束,防止长篇拒答
 * - 非无信息:尽快释放缓冲并实时透传
 */
private Flux<String> normalizeStreamOutput(Flux<String> rawFlux) {
    return Flux.create(sink -> {
        StringBuilder probeBuffer = new StringBuilder();
        AtomicBoolean passthrough = new AtomicBoolean(false);
        AtomicBoolean completed = new AtomicBoolean(false);
        final Disposable[] disposableRef = new Disposable[1];

        disposableRef[0] = rawFlux.subscribe(
                chunk -> {
                    // 已完成或已取消,跳过
                    if (completed.get() || sink.isCancelled()) {
                        return;
                    }

                    // 已进入透传模式,直接发送
                    if (passthrough.get()) {
                        sink.next(chunk);
                        return;
                    }

                    // 累积探测文本
                    probeBuffer.append(chunk);
                    String probeText = probeBuffer.toString();

                    // 检测"无信息"模板
                    if (isNoResultLike(probeText)) {
                        completed.set(true);
                        sink.next(NO_RESULT_RESPONSE);
                        sink.complete();
                        if (disposableRef[0] != null) {
                            disposableRef[0].dispose();
                        }
                        return;
                    }

                    // 探测窗口(120字符)后开始透传
                    if (probeBuffer.length() >= STREAM_PROBE_CHARS) {
                        passthrough.set(true);
                        sink.next(probeText);
                        probeBuffer.setLength(0);
                    }
                },
                sink::error,
                () -> {
                    // 流式完成
                    if (completed.get() || sink.isCancelled()) {
                        return;
                    }
                    if (!passthrough.get()) {
                        sink.next(normalizeAnswer(probeBuffer.toString()));
                    }
                    sink.complete();
                }
        );

        // 取消时清理资源
        sink.onCancel(() -> {
            if (disposableRef[0] != null) {
                disposableRef[0].dispose();
            }
        });
    });
}

private boolean isNoResultLike(String text) {
    return text.contains("没有找到相关信息")
            || text.contains("未检索到相关信息")
            || text.contains("信息不足")
            || text.contains("超出知识库范围")
            || text.contains("无法根据提供内容回答");
}

4. 消息实体

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
@Entity
@Table(name = "rag_chat_messages", indexes = {
    @Index(name = "idx_rag_message_session", columnList = "session_id"),
    @Index(name = "idx_rag_message_order", columnList = "session_id, messageOrder")
})
@Getter
@Setter
@NoArgsConstructor
public class RagChatMessageEntity {

    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    @ManyToOne(fetch = FetchType.LAZY)
    @JoinColumn(name = "session_id", nullable = false)
    private RagChatSessionEntity session;

    @Enumerated(EnumType.STRING)
    @Column(nullable = false, length = 20)
    private MessageType type;

    @Column(columnDefinition = "TEXT", nullable = false)
    private String content;

    @Column(nullable = false)
    private Integer messageOrder;

    @Column(nullable = false, updatable = false)
    private LocalDateTime createdAt;

    private LocalDateTime updatedAt;

    private Boolean completed = true;

    public enum MessageType {
        USER,      // 用户消息
        ASSISTANT  // AI 回答
    }

    @PrePersist
    protected void onCreate() {
        createdAt = LocalDateTime.now();
        updatedAt = LocalDateTime.now();
    }

    @PreUpdate
    protected void onUpdate() {
        updatedAt = LocalDateTime.now();
    }
}

5. DTO 定义

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
public class RagChatDTO {

    // ========== 请求 DTO ==========

    /**
     * 创建会话请求
     */
    public record CreateSessionRequest(
            @NotEmpty(message = "至少选择一个知识库")
            List<Long> knowledgeBaseIds,
            String title  // 可选,为空则自动生成
    ) {
    }

    /**
     * 发送消息请求
     */
    public record SendMessageRequest(
            @NotBlank(message = "问题不能为空")
            String question
    ) {
    }

    /**
     * 更新标题请求
     */
    public record UpdateTitleRequest(
            @NotNull(message = "请选择会话")
            Long sessionId,
            @NotBlank(message = "标题不能为空")
            String title
    ) {
    }

    /**
     * 更新知识库请求
     */
    public record UpdateKnowledgeBasesRequest(
            @NotNull(message = "请选择会话")
            Long sessionId,
            @NotEmpty(message = "至少选择一个知识库")
            List<Long> knowledgeBaseIds
    ) {
    }

    // ========== 响应 DTO ==========

    /**
     * 会话基础信息
     */
    public record SessionDTO(
            Long id,
            String title,
            List<Long> knowledgeBaseIds,
            LocalDateTime createdAt
    ) {
    }

    /**
     * 会话列表项
     */
    public record SessionListItemDTO(
            Long id,
            String title,
            Integer messageCount,
            List<String> knowledgeBaseNames,
            LocalDateTime updatedAt,
            Boolean isPinned
    ) {
    }

    /**
     * 会话详情(含消息)
     */
    public record SessionDetailDTO(
            Long id,
            String title,
            List<KnowledgeBaseListItemDTO> knowledgeBases,
            List<MessageDTO> messages,
            LocalDateTime createdAt,
            LocalDateTime updatedAt
    ) {
    }

    /**
     * 消息 DTO
     */
    public record MessageDTO(
            Long id,
            String type,  // "user" | "assistant"
            String content,
            LocalDateTime createdAt
    ) {
    }
}

前端对接

EventSource 接收 SSE

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
// 创建 EventSource(注意:GET 请求,参数在 URL 中)
const sessionId = 123;
const eventSource = new EventSource(`/api/rag-chat/sessions/${sessionId}/messages/stream?question=${encodeURIComponent(question)}`);

// 监听消息事件
eventSource.onmessage = (event) => {
    const data = JSON.parse(event.data);
    console.log('Received:', data);
    // data 就是 AI 返回的文本片段
    appendToChat(data);
};

// 监听错误
eventSource.onerror = (error) => {
    console.error('SSE Error:', error);
    eventSource.close();
};

// 关闭连接
function closeChat() {
    eventSource.close();
}

Fetch API 接收 SSE(推荐)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
async function sendMessage(sessionId, question) {
    const response = await fetch(`/api/rag-chat/sessions/${sessionId}/messages/stream`, {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify({ question }),
    });

    const reader = response.body.getReader();
    const decoder = new TextDecoder();

    while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        const chunk = decoder.decode(value);
        // 解析 SSE 格式: data: {"data":"xxx"}\n\n
        const lines = chunk.split('\n');
        for (const line of lines) {
            if (line.startsWith('data: ')) {
                const data = JSON.parse(line.slice(6));
                console.log('Received:', data.data);
                appendToChat(data.data);
            }
        }
    }
}

SSE 数据格式

Spring AI 返回的 SSE 格式:

1
2
3
4
5
6
7
8
9
data: {"data":"今天"}

data: {"data":"天"}

data: {"data":"气"}

data: {"data":"不"}

data: {"data":"错"}

会话管理 API

创建会话

1
2
3
4
5
6
7
POST /api/rag-chat/sessions
Content-Type: application/json

{
    "knowledgeBaseIds": [1, 2, 3],
    "title": "我的智能问答"
}

获取会话列表

1
GET /api/rag-chat/sessions

获取会话详情

1
GET /api/rag-chat/sessions/{sessionId}

删除会话

1
DELETE /api/rag-chat/sessions/{sessionId}

更新会话标题

1
2
3
4
5
6
7
PUT /api/rag-chat/sessions/title
Content-Type: application/json

{
    "sessionId": 123,
    "title": "新标题"
}

切换置顶状态

1
PUT /api/rag-chat/sessions/{sessionId}/pin

更换会话知识库

1
2
3
4
5
6
7
PUT /api/rag-chat/sessions/knowledge-bases
Content-Type: application/json

{
    "sessionId": 123,
    "knowledgeBaseIds": [4, 5]
}

流式响应完整流程

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
┌─────────────────────────────────────────────────────────────────────┐
│                         请求阶段                                     │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  前端  ──── POST /sessions/1/messages/stream ────▶  Controller      │
│                        │                                             │
│                        ▼                                             │
│               prepareStreamMessage()                                 │
│                        │                                             │
│                        ├── 保存用户消息                               │
│                        ├── 创建 AI 消息占位(completed=false)         │
│                        └── 返回 messageId                            │
│                                                                      │
├─────────────────────────────────────────────────────────────────────┤
│                         流式阶段                                     │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  Controller  ──── getStreamAnswer() ────▶  QueryService             │
│                        │                                             │
│                        ▼                                             │
│              answerQuestionStream()                                  │
│                        │                                             │
│                        ├── RAG 检索(多候选查询)                      │
│                        ├── 构建 Prompt                               │
│                        └── chatClient.stream() → Flux<String>        │
│                                                                      │
│  Flux<String>  ──── normalizeStreamOutput() ────▶ 前端              │
│                        │                                             │
│                        ├── 探测窗口(120字符)                         │
│                        ├── 无结果模板检测                             │
│                        └── 实时透传或终止                             │
│                                                                      │
├─────────────────────────────────────────────────────────────────────┤
│                         完成阶段                                     │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  前端  ──── 流式完成 ────▶  doOnComplete()                           │
│                        │                                             │
│                        ▼                                             │
│              completeStreamMessage(messageId, content)              │
│                        │                                             │
│                        └── 更新消息内容,completed=true                │
│                                                                      │
└─────────────────────────────────────────────────────────────────────┘

提示词配置

系统提示词

1
2
3
4
5
6
7
你是一个专业的知识库问答助手。请根据提供的上下文信息,准确回答用户的问题。

要求:
1. 只回答与上下文相关的问题
2. 如果上下文中没有相关信息,请明确告知用户
3. 回答要条理清晰,简洁明了
4. 对于涉及专业知识的问题,尽量给出准确的解释

用户提示词

1
2
3
4
5
6
上下文信息:
{{context}}

用户问题:{{question}}

请根据上下文信息回答用户问题。

常见问题

1. 流式中断

检查网络连接和服务器日志,确保 SSE 连接保持活跃。

2. 乱序问题

SSE 保证顺序,但如果前端渲染不及时,可能导致显示顺序混乱。建议使用数组累积后统一渲染。

3. 跨域问题

如果前端与后端在不同域名,需要配置 CORS:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
@Configuration
public class WebConfig implements WebFluxConfigurer {
    @Override
    public void addCorsMappings(CorsRegistry registry) {
        registry.addMapping("/api/**")
                .allowedOrigins("*")
                .allowedMethods("GET", "POST", "PUT", "DELETE", "OPTIONS")
                .allowedHeaders("*");
    }
}

总结

本文介绍了基于 Spring AI 实现 SSE 流式输出的完整方案:

  1. Controller 层:返回 Flux<ServerSentEvent<String>>,使用 .map() 转换为 SSE 格式
  2. Service 层
    • prepareStreamMessage() - 预先保存消息占位
    • getStreamAnswer() - 获取流式回答
    • completeStreamMessage() - 流式完成后更新消息
  3. QueryService 层
    • normalizeStreamOutput() - 流式归一化处理
    • 探测窗口快速识别无结果
    • 支持流式透传和异常终止
  4. 前端对接:使用 EventSource 或 Fetch API 接收 SSE
使用 Hugo 构建
主题 StackJimmy 设计

发布了 33 篇文章 | 共 83531 字