# encoding: utf-8
# @Time : 2023/12/21
# @Author : Spike
# @Descr :
import json
import re
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg

proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
                  ' Check that the proxy server is reachable and that the proxy format is [protocol]://[address]:[port]; every part is required.'


def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
                                  console_slience=False):
    # Check that the API key is configured
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("Please configure GEMINI_API_KEY.")

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience: 5 seconds is enough
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*?)"', results, flags=re.DOTALL)
        if match:
            try:
                # re-parse the extracted fragment with json so escape sequences are decoded
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except Exception:
                raise ValueError("Failed to parse the GEMINI message.")
            buffer = paraphrase['text']
            gpt_replying_buffer += buffer
            if observe_window is not None and len(observe_window) >= 1:
                observe_window[0] = gpt_replying_buffer  # pass the partial reply back to the caller
            if observe_window is not None and len(observe_window) >= 2:
                if (time.time() - observe_window[1]) > watch_dog_patience: raise RuntimeError("Terminated by watchdog timeout.")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} conversation error')
    return gpt_replying_buffer
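
# A minimal usage sketch for predict_no_ui_long_connection (illustration only, assuming
# GEMINI_API_KEY is configured and the gemini-pro model is reachable; the prompt and
# sys_prompt strings below are hypothetical):
#
#     observe_window = ["", time.time()]  # slot 0 receives the partial reply, slot 1 feeds the watchdog
#     reply = predict_no_ui_long_connection(
#         inputs="Hello", llm_kwargs={'llm_model': 'gemini-pro'},
#         history=[], sys_prompt="You are a helpful assistant.",
#         observe_window=observe_window)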


def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=None, system_prompt='', stream=True, additional_fn=None):
    if history is None:
        history = []
    # Check that the API key is configured
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("Please configure GEMINI_API_KEY.", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception as e:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f", retrying ({retry}/{MAX_RETRY}) ..." if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="Request timeout" + retry_msg)  # refresh the UI
            if retry > MAX_RETRY: raise TimeoutError
    gpt_replying_buffer = ""
    gpt_security_policy = ""
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")  # decode the raw streamed bytes explicitly
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*)"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))  # let the json parser handle the escaping
            except Exception:
                raise ValueError("Failed to parse the GEMINI message.")
            gpt_replying_buffer += paraphrase['text']
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            history = history[:-2]  # drop the failed turn from the history
            chatbot[-1] = (inputs, gpt_replying_buffer + f"Conversation error, please check the message\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('Conversation error')
    if not gpt_replying_buffer:
        history = history[:-2]  # drop the failed turn from the history
        chatbot[-1] = (inputs, gpt_replying_buffer + f"Google's safety policy was triggered, no reply was returned\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)
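
# Note on the stream parsing above: chunks are scanned with regexes rather than fully parsed,
# on the assumption that a successful chunk carries a "text" field and an error chunk carries a
# "message" field. A hypothetical successful chunk from the streaming endpoint might look like:
#
#     {"candidates": [{"content": {"parts": [{"text": "Once upon a time..."}], "role": "model"}}]}
#
# Only the "text" payload is extracted and routed back through json.loads so escapes are decoded.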


if __name__ == '__main__':
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write a long story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)