Skip to content

Commit 5a12f6e

Browse files
Kilig947binary-sky
authored and committed
适配 google gemini 优化为从用户input中提取文件 (#1419)
适配 google gemini 优化为从用户input中提取文件
1 parent a96f842 commit 5a12f6e

File tree

5 files changed

+360
-95
lines changed

5 files changed

+360
-95
lines changed

config.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,14 @@
8989
LLM_MODEL = "gpt-3.5-turbo" # 可选 ↓↓↓
9090
AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview",
9191
"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
92-
"api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
9392
"gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4",
94-
"chatglm3", "moss", "claude-2"]
95-
# P.S. 其他可用的模型还包括 ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random"
93+
"gemini-pro", "chatglm3", "moss", "claude-2"]
94+
# P.S. 其他可用的模型还包括 [
95+
# "qwen-turbo", "qwen-plus", "qwen-max"
96+
# "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
97+
# "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
9698
# "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
97-
# “qwen-turbo", "qwen-plus", "qwen-max"]
99+
# ]
98100

99101

100102
# 定义界面上“询问多个GPT模型”插件应该使用哪些模型,请从AVAIL_LLM_MODELS中选择,并在不同模型之间用`&`间隔,例如"gpt-3.5-turbo&chatglm3&azure-gpt-4"
@@ -204,6 +206,10 @@
204206
CUSTOM_API_KEY_PATTERN = ""
205207

206208

209+
# Google Gemini API-Key
210+
GEMINI_API_KEY = ''
211+
212+
207213
# HUGGINGFACE的TOKEN,下载LLAMA时起作用 https://huggingface.co/docs/hub/security-tokens
208214
HUGGINGFACE_ACCESS_TOKEN = "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"
209215

@@ -292,6 +298,9 @@
292298
├── "qwen-turbo" 等通义千问大模型
293299
│ └── DASHSCOPE_API_KEY
294300
301+
├── "Gemini"
302+
│ └── GEMINI_API_KEY
303+
295304
└── "newbing" Newbing接口不再稳定,不推荐使用
296305
├── NEWBING_STYLE
297306
└── NEWBING_COOKIES

request_llms/bridge_all.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
2929
from .bridge_qianfan import predict as qianfan_ui
3030

31+
from .bridge_google_gemini import predict as genai_ui
32+
from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui
33+
3134
colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']
3235

3336
class LazyloadTiktoken(object):
@@ -246,6 +249,22 @@ def decode(self, *args, **kwargs):
246249
"tokenizer": tokenizer_gpt35,
247250
"token_cnt": get_token_num_gpt35,
248251
},
252+
"gemini-pro": {
253+
"fn_with_ui": genai_ui,
254+
"fn_without_ui": genai_noui,
255+
"endpoint": None,
256+
"max_token": 1024 * 32,
257+
"tokenizer": tokenizer_gpt35,
258+
"token_cnt": get_token_num_gpt35,
259+
},
260+
"gemini-pro-vision": {
261+
"fn_with_ui": genai_ui,
262+
"fn_without_ui": genai_noui,
263+
"endpoint": None,
264+
"max_token": 1024 * 32,
265+
"tokenizer": tokenizer_gpt35,
266+
"token_cnt": get_token_num_gpt35,
267+
},
249268
}
250269

251270
# -=-=-=-=-=-=- api2d 对齐支持 -=-=-=-=-=-=-

request_llms/bridge_google_gemini.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# encoding: utf-8
2+
# @Time : 2023/12/21
3+
# @Author : Spike
4+
# @Descr :
5+
import json
6+
import re
7+
import time
8+
from request_llms.com_google import GoogleChatInit
9+
from toolbox import get_conf, update_ui, update_ui_lastest_msg
10+
11+
proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
12+
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py.' + \
13+
'网络错误,检查代理服务器是否可用,以及代理设置的格式是否正确,格式须是[协议]://[地址]:[端口],缺一不可。'
14+
15+
16+
def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
                                  console_slience=False):
    """
    Stream a Gemini chat completion without driving the UI.

    Args:
        inputs: current user input text.
        llm_kwargs: model parameters (``llm_model``, ``temperature``, ``top_p``, ...).
        history: flat list of alternating [user, model, user, model, ...] turns.
        sys_prompt: system prompt text (forwarded to the payload builder).
        observe_window: optional ``[buffer, last_alive_ts]`` shared with the
            caller's watchdog thread — element 0 receives the partial reply,
            element 1 is a timestamp checked against the watchdog patience.
        console_slience: unused here; kept for interface parity with the
            other bridge modules.

    Returns:
        The accumulated reply text.

    Raises:
        ValueError: GEMINI_API_KEY is not configured, or a streamed chunk
            could not be parsed.
        RuntimeError: watchdog timeout, or the stream carried an API error
            message.
    """
    # Fail fast when the API key is not configured.
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("请配置 GEMINI_API_KEY。")

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience: 5 seconds is enough
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        # Pull the "text" fragment (with escapes intact) out of the raw chunk.
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*?)\"', results, flags=re.DOTALL)
        if match:
            try:
                # Round-trip through json to unescape \n, \", \uXXXX, etc.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError as e:
                raise ValueError("解析GEMINI消息出错。") from e
            gpt_replying_buffer += paraphrase['text']
            # BUG FIX: observe_window defaults to None — guard before len(),
            # otherwise calling with the default raised TypeError.
            if observe_window is not None:
                if len(observe_window) >= 1:
                    observe_window[0] = gpt_replying_buffer
                if len(observe_window) >= 2:
                    if (time.time() - observe_window[1]) > watch_dog_patience:
                        raise RuntimeError("程序终止。")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} 对话错误')
    return gpt_replying_buffer
44+
45+
46+
def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    """
    Stream a Gemini chat completion into the gradio chatbot (generator).

    Args:
        inputs: current user input text.
        llm_kwargs: model parameters (``llm_model``, ``temperature``, ``top_p``, ...).
        plugin_kwargs: plugin parameters (unused here; kept for interface parity).
        chatbot: gradio chatbot state; a new (inputs, reply) pair is appended.
        history: flat list of alternating [user, model, ...] turns; the new
            turn is appended and removed again on failure.
        system_prompt: system prompt text (forwarded to the payload builder).
        stream: unused here; kept for interface parity.
        additional_fn: unused here; kept for interface parity.

    Yields:
        UI refresh events via ``update_ui``.

    Raises:
        TimeoutError: request could not be started within MAX_RETRY attempts.
        ValueError: a streamed chunk could not be parsed.
        RuntimeError: the stream carried an API error message.
    """
    # Missing key: show the hint in the UI instead of raising.
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("请配置 GEMINI_API_KEY。", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception as e:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f",正在重试 ({retry}/{MAX_RETRY}) ……" if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="请求超时" + retry_msg)  # refresh UI
            if retry > MAX_RETRY:
                raise TimeoutError from e
    gpt_replying_buffer = ""
    gpt_security_policy = ""  # raw stream kept for the safety-policy diagnostic below
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")  # note: forgetting this decode bites
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'\"message\":\s*\"(.*)\"', results, flags=re.DOTALL)
        if match:
            try:
                # Round-trip through json to unescape \n, \", \uXXXX, etc.
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except json.JSONDecodeError as e:
                raise ValueError("解析GEMINI消息出错。") from e
            gpt_replying_buffer += paraphrase['text']
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            # BUG FIX: drop the failed turn with a slice; the original
            # `history[-2]` replaced the whole list with a single string.
            history = history[:-2]
            chatbot[-1] = (inputs, gpt_replying_buffer + f"对话错误,请查看message\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('对话错误')
    if not gpt_replying_buffer:
        # Empty reply: Google's safety policy blocked the answer.
        # BUG FIX: same slicing fix as above — drop the failed turn.
        history = history[:-2]
        chatbot[-1] = (inputs, gpt_replying_buffer + f"触发了Google的安全访问策略,没有回答\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)
92+
93+
94+
95+
if __name__ == '__main__':
    # Minimal manual smoke test: stream one completion to stdout.
    # (Removed an unused `import sys` left over from debugging.)
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write long a story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)

request_llms/com_google.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# encoding: utf-8
2+
# @Time : 2023/12/25
3+
# @Author : Spike
4+
# @Descr :
5+
import json
6+
import requests
7+
from typing import List, Dict, Tuple
8+
from toolbox import get_conf
9+
10+
proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')
11+
12+
class GoogleChatInit:
    """Builds Gemini REST payloads and streams chat completions.

    The endpoint URL starts as a template with ``%m`` (model) and ``%k``
    (API key) placeholders; ``generate_message_payload`` substitutes them
    per request.
    """

    # Pristine endpoint template; %m = model name, %k = API key.
    _URL_TEMPLATE = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __init__(self):
        # Kept as an instance attribute for backward compatibility; holds the
        # substituted URL after generate_message_payload runs.
        self.url_gemini = self._URL_TEMPLATE

    def __conversation_user(self, user_input):
        """Build one "user" turn; vision models additionally get inline
        image parts extracted from the raw input text."""
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            text, encode_img = user_input, []
        else:
            # Lazy import: only vision requests need the file extractor.
            from toolbox import input_encode_handler
            text, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': text})
        for data in encode_img:
            what_i_have_asked['parts'].append({
                'inline_data': {
                    "mime_type": f"image/{data['type']}",
                    "data": data['data'],
                }})
        return what_i_have_asked

    def __conversation_history(self, history):
        """Convert flat [user, model, user, model, ...] history into Gemini
        message dicts; a trailing unanswered user turn is dropped."""
        messages = []
        conversation_cnt = len(history) // 2
        for index in range(0, 2 * conversation_cnt, 2):
            messages.append(self.__conversation_user(history[index]))
            messages.append({
                "role": "model",
                "parts": [{"text": history[index + 1]}],
            })
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        """POST the chat request and return an iterator of streamed response lines."""
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        """Build the (headers, payload) pair for a streamGenerateContent call.

        NOTE: system_prompt is currently not sent — Gemini rejects an odd
        number of turns, so there is no slot for a system role yet.
        """
        # BUG FIX: substitute into the pristine template instead of mutating
        # the already-substituted URL, so repeated calls (or switching the
        # model on the same instance) still find the %m/%k placeholders.
        self.url_gemini = self._URL_TEMPLATE.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        header = {'Content-Type': 'application/json'}
        messages = []
        if 'vision' not in self.url_gemini:  # vision models do not take history
            messages.extend(self.__conversation_history(history))
        messages.append(self.__conversation_user(inputs))  # current user turn
        stop = str(llm_kwargs.get('stop', ''))
        payload = {
            "contents": messages,
            "generationConfig": {
                # BUG FIX: ''.split(' ') would send [''] as a stop sequence;
                # send an empty list when no stop is configured.
                "stopSequences": stop.split(' ') if stop else [],
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10,
            }
        }
        return header, payload
80+
81+
82+
if __name__ == '__main__':
    # Manual scratchpad for local experimentation; the calls below are kept
    # commented out as usage examples.
    google = GoogleChatInit()
    # print(google.generate_message_payload('你好呀', {},
    #                                       ['123123', '3123123'], ''))
    # google.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')

0 commit comments

Comments
 (0)