专注、坚持

浅谈小智 xiaozhi-esp32(二)

2025.03.12 by kingcos

概览

小智(xiaozhi-esp32)是一个开源的 AI 聊天机器人项目。小智分为端侧与服务端侧,本篇仅聚焦于服务端侧。

基本信息

  • 代码版本:release 0.1.6
  • 主要语言:Python
  • 开源协议:MIT

整体架构

重点模块

音频消息处理

音频数据处理:

# recieveAudioHandle.py

async def handleAudioMessage(conn, audio):
    """Run ASR over the buffered audio and start a chat turn if speech was recognized."""
    # ...
    # Turn the accumulated audio frames into text via the ASR provider.
    text, file_path = await conn.asr.speech_to_text(conn.asr_audio, conn.session_id)
    effective_len, _ = remove_punctuation_and_length(text)
    if effective_len <= 0:
        # Nothing recognizable was said; re-open the ASR channel for the client.
        conn.asr_server_receive = True
    else:
        # Recognized speech: hand the transcript to the chat pipeline.
        await startToChat(conn, text)
    # ...

async def startToChat(conn, text):
    """Dispatch a recognized utterance: intent handling first, otherwise normal chat."""
    # Intent analysis takes priority over free-form chat.
    if await handle_user_intent(conn, text):
        # The intent layer consumed the utterance; just re-open ASR and stop.
        conn.asr_server_receive = True
        return

    # Intent not handled: fall through to the regular chat flow.
    await send_stt_message(conn, text)
    # Pick the chat entry point matching the configured mode.
    chat_fn = conn.chat_with_function_calling if conn.use_function_call_mode else conn.chat
    conn.executor.submit(chat_fn, text)

意图识别:

# intentHandler.py

async def handle_user_intent(conn, text):
    """
    Handle user intent before starting chat.

    Args:
        conn: Connection object
        text: User's text input

    Returns:
        bool: True if intent was handled, False if should proceed to chat
    """
    # An explicit exit command short-circuits everything else.
    if await check_direct_exit(conn, text):
        return True

    # In function-calling mode the LLM resolves intents itself; skip analysis here.
    if conn.use_function_call_mode:
        return False

    # Ask the LLM to classify the utterance.
    intent = await analyze_intent_with_llm(conn, text)
    if not intent:
        return False

    # Route the detected intent to the matching handler.
    return await process_intent_result(conn, intent, text)

async def check_direct_exit(conn, text):
    """Close the connection when the punctuation-stripped text is exactly an exit command."""
    # Strip punctuation before comparing against the configured commands.
    _, stripped = remove_punctuation_and_length(text)
    # conn.cmd_exit is a list such as ["退出", "关闭"]; membership matches
    # exactly one command, same as the original first-match loop.
    if stripped in conn.cmd_exit:
        logger.bind(tag=TAG).info(f"识别到明确的退出命令: {stripped}")
        await conn.close()
        return True
    return False

async def process_intent_result(conn, intent, original_text):
    """Execute the action for a detected intent; return True when it was consumed."""
    # Farewell intent: echo the transcript, say goodbye, and close the connection.
    if "结束聊天" in intent:
        logger.bind(tag=TAG).info(f"识别到退出意图: {intent}")
        # TODO: stop any music that may currently be playing.
        await send_stt_message(conn, original_text)
        conn.executor.submit(conn.chat_and_close, original_text)
        return True

    # Music intent: delegate to the music handler.
    if "播放音乐" in intent:
        logger.bind(tag=TAG).info(f"识别到音乐播放意图: {intent}")
        await conn.music_handler.handle_music_command(conn, intent)
        return True

    # Further intent kinds can be added above this line.
    # Unrecognized intent: signal the caller to continue with regular chat.
    return False

# intent_llm.py

async def detect_intent(self, conn, dialogue_history: List[Dict], text: str) -> str:
    """Classify the user's intent with the LLM.

    Builds a short prompt from at most the last two turns of history plus the
    new utterance, asks the LLM (non-streaming), and returns the raw intent
    payload, e.g. "{intent: '播放音乐 [中秋月]'}".

    Args:
        conn: connection object; its music_handler.music_files list is embedded
            in the system prompt so the model can name playable tracks.
        dialogue_history: prior turns; only the last two are used. May be empty.
        text: the user's newest utterance.

    Returns:
        str: a "{...}" intent payload, stripped of surrounding whitespace.

    Raises:
        ValueError: if no LLM provider has been configured.
    """
    if not self.llm:
        raise ValueError("LLM provider not set")

    # Keep the prompt small: only the last two history entries matter here.
    # Guard the indexing so an empty history no longer raises IndexError.
    msgStr = ""
    if len(dialogue_history) >= 2:
        msgStr += f"{dialogue_history[-2].role}: {dialogue_history[-2].content}\n"
    if dialogue_history:
        msgStr += f"{dialogue_history[-1].role}: {dialogue_history[-1].content}\n"
    msgStr += f"User: {text}\n"
    user_prompt = f"当前的对话如下:\n{msgStr}"

    # System prompt: intent instructions plus the playable music list.
    # NOTE: `promot` (sic) is the attribute name used by the project.
    prompt_music = f"{self.promot}\n<start>{conn.music_handler.music_files}\n<end>"
    logger.bind(tag=TAG).debug(f"System prompt: {prompt_music}")

    # Non-streaming call; the model is expected to answer with a {...} payload.
    intent = self.llm.response_no_stream(
        system_prompt=prompt_music,
        user_prompt=user_prompt
    )

    # Extract the first {...} group; fall back to "continue chatting".
    # (Removed a leftover debug print of the matched payload.)
    match = re.search(r'\{.*?\}', intent)
    if match:
        intent = match.group(0)
    else:
        intent = "{intent: '继续聊天'}"
    logger.bind(tag=TAG).info(f"Detected intent: {intent}")
    return intent.strip()

# llm/base.py

def response_no_stream(self, system_prompt, user_prompt):
    """Run a one-shot (non-streaming) completion and return the full reply text.

    Wraps the provider's streaming `response()` generator and concatenates all
    chunks. On any provider error, logs and returns a fixed Chinese error
    string instead of raising, so callers always get printable text.

    Args:
        system_prompt: content of the system message.
        user_prompt: content of the user message.

    Returns:
        str: the concatenated reply, or "【LLM服务响应异常】" on failure.
    """
    try:
        # Minimal two-message dialogue: system instructions + user prompt.
        dialogue = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        # Drain the streaming generator into a single string (O(n) join
        # instead of repeated string concatenation).
        return "".join(self.response("", dialogue))

    except Exception as e:
        # Fix: this lives in the provider base class, so the message must not
        # claim a specific backend ("Ollama") that may not be in use.
        logger.bind(tag=TAG).error(f"Error in LLM response generation: {e}")
        return "【LLM服务响应异常】"

聊天处理:

# connection.py

def chat(self, query):
    """Run one blocking chat turn: query the LLM, stream the reply, and feed
    sentence-sized segments to TTS as they arrive.

    Args:
        query: the user's utterance (already-transcribed text).

    Returns:
        True on completion (also when only the auth code was broadcast),
        None if the LLM call failed.
    """
    if self.isNeedAuth():
        self.llm_finish_task = True
        # Broadcast the auth code on the connection's event loop and block until done.
        future = asyncio.run_coroutine_threadsafe(self._check_and_broadcast_auth_code(), self.loop)
        future.result()
        return True

    self.dialogue.put(Message(role="user", content=query))

    response_message = []
    processed_chars = 0  # position in the accumulated reply already handed to TTS
    try:
        start_time = time.time()
        # Memory-augmented dialogue: fetch memories relevant to this query.
        future = asyncio.run_coroutine_threadsafe(self.memory.query_memory(query), self.loop)
        memory_str = future.result()

        self.logger.bind(tag=TAG).debug(f"记忆内容: {memory_str}")
        # Ask the LLM for a streaming reply over the memory-augmented dialogue.
        llm_responses = self.llm.response(
            self.session_id,
            self.dialogue.get_llm_dialogue_with_memory(memory_str)
        )
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
        return None

    self.llm_finish_task = False
    text_index = 0
    for content in llm_responses:
        response_message.append(content)
        if self.client_abort:
            break

        end_time = time.time()
        self.logger.bind(tag=TAG).debug(f"大模型返回时间: {end_time - start_time} 秒, 生成token={content}")

        # Merge everything received so far and look only at the unprocessed tail.
        full_text = "".join(response_message)
        current_text = full_text[processed_chars:]  # start from the unprocessed position

        # Find the last sentence-ending punctuation in the tail.
        punctuations = ("。", "?", "!", ";", ":")
        last_punct_pos = -1
        for punct in punctuations:
            pos = current_text.rfind(punct)
            if pos > last_punct_pos:
                last_punct_pos = pos

        # A split point was found: hand the completed sentence(s) to TTS.
        if last_punct_pos != -1:
            segment_text_raw = current_text[:last_punct_pos + 1]
            segment_text = get_string_no_punctuation_or_emoji(segment_text_raw)
            if segment_text:
                # Force an empty segment to test TTS robustness on bad input:
                # if text_index % 2 == 0:
                #     segment_text = " "
                text_index += 1
                self.recode_first_last_text(segment_text, text_index)
                # TTS for this segment runs on the worker pool; the queue keeps order.
                future = self.executor.submit(self.speak_and_play, segment_text, text_index)
                self.tts_queue.put(future)
                processed_chars += len(segment_text_raw)  # advance past the raw segment

    # Flush any trailing text that never ended with punctuation.
    full_text = "".join(response_message)
    remaining_text = full_text[processed_chars:]
    if remaining_text:
        segment_text = get_string_no_punctuation_or_emoji(remaining_text)
        if segment_text:
            text_index += 1
            self.recode_first_last_text(segment_text, text_index)
            future = self.executor.submit(self.speak_and_play, segment_text, text_index)
            self.tts_queue.put(future)

    self.llm_finish_task = True
    self.dialogue.put(Message(role="assistant", content="".join(response_message)))
    self.logger.bind(tag=TAG).debug(json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False))
    return True

def speak_and_play(self, text, text_index=0):
    """Synthesize `text` to audio; returns (tts_file_or_None, text, text_index)."""
    # Nothing to say: skip TTS entirely.
    if not text:
        self.logger.bind(tag=TAG).info(f"无需tts转换,query为空,{text}")
        return None, text, text_index

    # Run the TTS provider; None signals a synthesis failure.
    tts_file = self.tts.to_tts(text)
    if tts_file is None:
        self.logger.bind(tag=TAG).error(f"tts转换失败,{text}")
    else:
        self.logger.bind(tag=TAG).debug(f"TTS 文件生成完毕: {tts_file}")
    return tts_file, text, text_index

# tts/base.py

def to_tts(self, text):
    """Synthesize `text` into a temp audio file, retrying up to 5 times.

    Args:
        text: the text to synthesize.

    Returns:
        The generated file path, or None when synthesis failed (all retries
        exhausted or an exception raised). Previously the nonexistent path
        was returned even after every retry failed, so callers checking
        `is None` could never detect the failure.
    """
    tmp_file = self.generate_filename()
    try:
        max_repeat_time = 5
        while not os.path.exists(tmp_file) and max_repeat_time > 0:
            # One synthesis attempt (provider-specific async implementation).
            asyncio.run(self.text_to_speak(text, tmp_file))
            if not os.path.exists(tmp_file):
                max_repeat_time = max_repeat_time - 1
                logger.bind(tag=TAG).error(f"语音生成失败: {text}:{tmp_file},再试{max_repeat_time}次")

        if max_repeat_time > 0:
            logger.bind(tag=TAG).info(f"语音生成成功: {text}:{tmp_file},重试{5 - max_repeat_time}次")
        else:
            # Fix: all retries failed and the file does not exist — report
            # failure instead of returning a dangling path.
            return None

        return tmp_file
    except Exception as e:
        # Fix: failures were previously logged at info level.
        logger.bind(tag=TAG).error(f"Failed to generate TTS file: {e}")
        return None

# tts/doubao.py

async def text_to_speak(self, text, output_file):
    """Call the Doubao (Volcengine) TTS HTTP API and write the audio to `output_file`.

    Args:
        text: text to synthesize.
        output_file: path the decoded WAV bytes are written to.

    Raises:
        Exception: on HTTP/parse errors or when the response carries no data
            (the original behavior of wrapping every failure is preserved).
    """
    request_json = {
        "app": {
            "appid": f"{self.appid}",
            # The API schema requires a token field; actual auth is done via
            # the request header, so a placeholder is sent here.
            "token": "access_token",
            "cluster": self.cluster
        },
        "user": {
            "uid": "1"
        },
        "audio": {
            "voice_type": self.voice,
            "encoding": "wav",
            "speed_ratio": 1.0,
            "volume_ratio": 1.0,
            "pitch_ratio": 1.0,
        },
        "request": {
            "reqid": str(uuid.uuid4()),
            "text": text,
            "text_type": "plain",
            "operation": "query",
            "with_frontend": 1,
            "frontend_type": "unitTson"
        }
    }

    try:
        resp = requests.post(self.api_url, json.dumps(request_json), headers=self.header)
        # Fix: parse the response body once instead of calling resp.json() twice.
        body = resp.json()
        if "data" in body:
            # Fix: use a context manager so the file handle is always closed
            # (the handle previously leaked on every call).
            with open(output_file, "wb") as file_to_save:
                file_to_save.write(base64.b64decode(body["data"]))
        else:
            raise Exception(f"{__name__} status_code: {resp.status_code} response: {resp.content}")
    except Exception as e:
        raise Exception(f"{__name__} error: {e}")
# connection.py

def chat_with_function_calling(self, query, tool_call = False):
    """One blocking chat turn over the LLM's streaming function-calling API.

    Like `chat()`, but the stream may interleave plain content with tool
    calls: plain content is segmented at sentence punctuation and queued for
    TTS, while a detected tool call is dispatched via `func_handler`.

    Args:
        query: user utterance (or tool follow-up when `tool_call` is True).
        tool_call: True when re-entering after a tool result, so the query
            is not appended to the dialogue a second time.

    Returns:
        True on completion, None if the LLM call failed.
    """
    self.logger.bind(tag=TAG).debug(f"Chat with function calling start: {query}")
    # NOTE(review): the string below was meant to be the docstring but sits
    # after the first statement, so it is a no-op expression statement.
    """Chat with function calling for intent detection using streaming"""
    if self.isNeedAuth():
        self.llm_finish_task = True
        # Broadcast the auth code on the connection's event loop and block until done.
        future = asyncio.run_coroutine_threadsafe(self._check_and_broadcast_auth_code(), self.loop)
        future.result()
        return True
    
    if not tool_call:
        self.dialogue.put(Message(role="user", content=query))

    # Tool/function schemas advertised to the model.
    functions = self.func_handler.get_functions()

    response_message = []
    processed_chars = 0  # position in the accumulated reply already handed to TTS

    try:
        start_time = time.time()

        # Memory-augmented dialogue: fetch memories relevant to this query.
        future = asyncio.run_coroutine_threadsafe(self.memory.query_memory(query), self.loop)
        memory_str = future.result()

        #self.logger.bind(tag=TAG).info(f"对话记录: {self.dialogue.get_llm_dialogue_with_memory(memory_str)}")

        # Streaming interface that supports function calling.
        llm_responses = self.llm.response_with_functions(
            self.session_id,
            self.dialogue.get_llm_dialogue_with_memory(memory_str),
            functions=functions
        )
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
        return None

    self.llm_finish_task = False
    text_index = 0

    # Accumulators for a tool call whose pieces arrive across stream chunks.
    tool_call_flag = False
    function_name = None
    function_id = None
    function_arguments = ""
    content_arguments = ""
    for response in llm_responses:
        content, tools_call = response
        if content is not None and len(content)>0:
            # Some models emit the tool call as plain text fenced by ``` or a
            # <tool_call> tag; detect that before any normal content arrived.
            if len(response_message)<=0 and (content=="```" or "<tool_call>" in content):
                tool_call_flag = True

        if tools_call is not None:
            # Structured tool-call delta: merge id/name, append argument chunks.
            tool_call_flag = True
            if tools_call[0].id is not None:
                function_id = tools_call[0].id
            if tools_call[0].function.name is not None:
                function_name = tools_call[0].function.name
            if tools_call[0].function.arguments is not None:
                function_arguments += tools_call[0].function.arguments

        if content is not None and len(content) > 0:
            if tool_call_flag:
                # Inside an in-band (text) tool call: buffer for later JSON parsing.
                content_arguments+=content
            else:
                response_message.append(content)

                if self.client_abort:
                    break

                end_time = time.time()
                self.logger.bind(tag=TAG).debug(f"大模型返回时间: {end_time - start_time} 秒, 生成token={content}")

                # Text segmentation + TTS, same scheme as `chat()`:
                # merge everything received so far and look only at the unprocessed tail.
                full_text = "".join(response_message)
                current_text = full_text[processed_chars:]  # start from the unprocessed position

                # Find the last sentence-ending punctuation in the tail.
                punctuations = ("。", "?", "!", ";", ":")
                last_punct_pos = -1
                for punct in punctuations:
                    pos = current_text.rfind(punct)
                    if pos > last_punct_pos:
                        last_punct_pos = pos

                # A split point was found: hand the completed sentence(s) to TTS.
                if last_punct_pos != -1:
                    segment_text_raw = current_text[:last_punct_pos + 1]
                    segment_text = get_string_no_punctuation_or_emoji(segment_text_raw)
                    if segment_text:
                        text_index += 1
                        self.recode_first_last_text(segment_text, text_index)
                        future = self.executor.submit(self.speak_and_play, segment_text, text_index)
                        self.tts_queue.put(future)
                        processed_chars += len(segment_text_raw)  # advance past the raw segment

    # Dispatch the tool call, if one was detected.
    if tool_call_flag:
        bHasError = False
        if function_id is None:
            # In-band tool call: recover the JSON payload from the buffered text.
            a = extract_json_from_string(content_arguments)
            if a is not None:
                try:
                    content_arguments_json = json.loads(a)
                    function_name = content_arguments_json["name"]
                    function_arguments = json.dumps(content_arguments_json["arguments"], ensure_ascii=False)
                    function_id = str(uuid.uuid4().hex)
                except Exception as e:
                    # Unparseable payload: surface it as plain reply text instead.
                    bHasError = True
                    response_message.append(a)
            else:
                bHasError = True
                response_message.append(content_arguments)
            if bHasError:
                self.logger.bind(tag=TAG).error(f"function call error: {content_arguments}")
            else:
                function_arguments = json.loads(function_arguments)
        if not bHasError:
            self.logger.bind(tag=TAG).info(f"function_name={function_name}, function_id={function_id}, function_arguments={function_arguments}")
            function_call_data = {
                "name": function_name,
                "id": function_id,
                "arguments": function_arguments
            }
            # Execute the function and feed its result back into the dialogue.
            result = self.func_handler.handle_llm_function_call(self, function_call_data)
            self._handle_function_result(result, function_call_data, text_index+1)

        # Flush any trailing text that never ended with punctuation.
    full_text = "".join(response_message)
    remaining_text = full_text[processed_chars:]
    if remaining_text:
        segment_text = get_string_no_punctuation_or_emoji(remaining_text)
        if segment_text:
            text_index += 1
            self.recode_first_last_text(segment_text, text_index)
            future = self.executor.submit(self.speak_and_play, segment_text, text_index)
            self.tts_queue.put(future)

    # Persist the assistant turn (only when it produced spoken content).
    if len(response_message)>0:
        self.dialogue.put(Message(role="assistant", content="".join(response_message)))

    self.llm_finish_task = True
    self.logger.bind(tag=TAG).debug(json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False))

    return True