
Commit a7c960d

Kilig947 authored and binary-sky committed
Adapt Google Gemini; optimize to extract files from the user input (#1419)
1 parent a96f842 commit a7c960d

File tree

5 files changed: +472 additions, -95 deletions

config.py

Lines changed: 13 additions & 4 deletions
@@ -89,12 +89,14 @@
 LLM_MODEL = "gpt-3.5-turbo"  # options below ↓↓↓
 AVAIL_LLM_MODELS = ["gpt-3.5-turbo-1106","gpt-4-1106-preview","gpt-4-vision-preview",
                     "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "azure-gpt-3.5",
-                    "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
                     "gpt-4", "gpt-4-32k", "azure-gpt-4", "api2d-gpt-4",
-                    "chatglm3", "moss", "claude-2"]
-# P.S. Other available models also include ["zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-3.5-random"
+                    "gemini-pro", "chatglm3", "moss", "claude-2"]
+# P.S. Other available models also include [
+#     "qwen-turbo", "qwen-plus", "qwen-max",
+#     "zhipuai", "qianfan", "deepseekcoder", "llama2", "qwen-local", "gpt-3.5-turbo-0613",
+#     "gpt-3.5-turbo-16k-0613", "gpt-3.5-random", "api2d-gpt-3.5-turbo", 'api2d-gpt-3.5-turbo-16k',
 # "spark", "sparkv2", "sparkv3", "chatglm_onnx", "claude-1-100k", "claude-2", "internlm", "jittorllms_pangualpha", "jittorllms_llama"
-# "qwen-turbo", "qwen-plus", "qwen-max"]
+# ]
 
 
 # Define which models the "ask multiple GPT models" UI plugin should use: pick from AVAIL_LLM_MODELS and join models with `&`, e.g. "gpt-3.5-turbo&chatglm3&azure-gpt-4"

@@ -204,6 +206,10 @@
 CUSTOM_API_KEY_PATTERN = ""
 
 
+# Google Gemini API key
+GEMINI_API_KEY = ''
+
+
 # HuggingFace token, used when downloading LLAMA  https://huggingface.co/docs/hub/security-tokens
 HUGGINGFACE_ACCESS_TOKEN = "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"
 

@@ -292,6 +298,9 @@
 ├── "qwen-turbo" and the other Tongyi Qianwen models
 │       └── DASHSCOPE_API_KEY
+├── "Gemini"
+│       └── GEMINI_API_KEY
+
 └── "newbing"  the Newbing API is no longer stable and is not recommended
         ├── NEWBING_STYLE
         └── NEWBING_COOKIES
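
With these hunks in place, enabling the new model comes down to editing the same values the diff touches. A minimal sketch of the relevant config.py lines (the key is a placeholder, not a real credential):

```python
# config.py: minimal settings to enable the Gemini bridge (placeholder values)
LLM_MODEL = "gemini-pro"       # or keep your current default and switch models in the UI
AVAIL_LLM_MODELS = ["gpt-3.5-turbo", "gemini-pro", "gemini-pro-vision"]
GEMINI_API_KEY = "AIza..."     # issued by Google AI Studio
```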

request_llms/bridge_all.py

Lines changed: 19 additions & 0 deletions
@@ -28,6 +28,9 @@
 from .bridge_qianfan import predict_no_ui_long_connection as qianfan_noui
 from .bridge_qianfan import predict as qianfan_ui
 
+from .bridge_google_gemini import predict as genai_ui
+from .bridge_google_gemini import predict_no_ui_long_connection as genai_noui
+
 colors = ['#FF00FF', '#00FFFF', '#FF0000', '#990099', '#009999', '#990044']
 
 class LazyloadTiktoken(object):

@@ -246,6 +249,22 @@ def decode(self, *args, **kwargs):
         "tokenizer": tokenizer_gpt35,
         "token_cnt": get_token_num_gpt35,
     },
+    "gemini-pro": {
+        "fn_with_ui": genai_ui,
+        "fn_without_ui": genai_noui,
+        "endpoint": None,
+        "max_token": 1024 * 32,
+        "tokenizer": tokenizer_gpt35,
+        "token_cnt": get_token_num_gpt35,
+    },
+    "gemini-pro-vision": {
+        "fn_with_ui": genai_ui,
+        "fn_without_ui": genai_noui,
+        "endpoint": None,
+        "max_token": 1024 * 32,
+        "tokenizer": tokenizer_gpt35,
+        "token_cnt": get_token_num_gpt35,
+    },
 }
 
 # -=-=-=-=-=-=- api2d alignment support -=-=-=-=-=-=-
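
These two registry entries are what make the new bridge reachable from the rest of the app. The hunk does not show the enclosing dict's name; assuming it is the module's model registry (called `model_info` below), dispatch reduces to a lookup along these lines:

```python
# Hypothetical lookup; `model_info` stands for the registry dict this hunk extends
entry = model_info["gemini-pro"]
predict_fn = entry["fn_with_ui"]    # genai_ui, the streaming generator for the UI path
raw_fn = entry["fn_without_ui"]     # genai_noui, for plugins that run without a UI
budget = entry["max_token"]         # 32768-token context budget
# "endpoint" is None because GoogleChatInit builds its own request URL
```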

request_llms/bridge_google_gemini.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@ (new file)
# encoding: utf-8
# @Time   : 2023/12/21
# @Author : Spike
# @Descr  :
import json
import re
import time
from request_llms.com_google import GoogleChatInit
from toolbox import get_conf, update_ui, update_ui_lastest_msg

proxies, TIMEOUT_SECONDS, MAX_RETRY = get_conf('proxies', 'TIMEOUT_SECONDS', 'MAX_RETRY')
timeout_bot_msg = '[Local Message] Request timeout. Network error. Please check proxy settings in config.py. ' + \
                  'Check that the proxy server is available and that the proxy is formatted as [protocol]://[address]:[port]; every part is required.'


def predict_no_ui_long_connection(inputs, llm_kwargs, history=[], sys_prompt="", observe_window=None,
                                  console_slience=False):
    # Check the API key
    if get_conf("GEMINI_API_KEY") == "":
        raise ValueError("Please configure GEMINI_API_KEY.")

    genai = GoogleChatInit()
    watch_dog_patience = 5  # watchdog patience: 5 seconds is enough
    gpt_replying_buffer = ''
    stream_response = genai.generate_chat(inputs, llm_kwargs, history, sys_prompt)
    for response in stream_response:
        results = response.decode()
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*?)"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except:
                raise ValueError("Failed to parse the Gemini message.")
            buffer = paraphrase['text']
            gpt_replying_buffer += buffer
            if observe_window is not None:  # guard added: the default is None, and len(None) would raise
                if len(observe_window) >= 1:
                    observe_window[0] = gpt_replying_buffer
                if len(observe_window) >= 2:
                    if (time.time() - observe_window[1]) > watch_dog_patience:
                        raise RuntimeError("Program terminated.")
        if error_match:
            raise RuntimeError(f'{gpt_replying_buffer} conversation error')
    return gpt_replying_buffer


def predict(inputs, llm_kwargs, plugin_kwargs, chatbot, history=[], system_prompt='', stream=True, additional_fn=None):
    # Check the API key
    if get_conf("GEMINI_API_KEY") == "":
        yield from update_ui_lastest_msg("Please configure GEMINI_API_KEY.", chatbot=chatbot, history=history, delay=0)
        return

    chatbot.append((inputs, ""))
    yield from update_ui(chatbot=chatbot, history=history)
    genai = GoogleChatInit()
    retry = 0
    while True:
        try:
            stream_response = genai.generate_chat(inputs, llm_kwargs, history, system_prompt)
            break
        except Exception as e:
            retry += 1
            chatbot[-1] = (chatbot[-1][0], timeout_bot_msg)
            retry_msg = f", retrying ({retry}/{MAX_RETRY}) ..." if MAX_RETRY > 0 else ""
            yield from update_ui(chatbot=chatbot, history=history, msg="Request timed out" + retry_msg)  # refresh the UI
            if retry > MAX_RETRY:
                raise TimeoutError
    gpt_replying_buffer = ""
    gpt_security_policy = ""
    history.extend([inputs, ''])
    for response in stream_response:
        results = response.decode("utf-8")  # this decode tripped things up before
        gpt_security_policy += results
        match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
        error_match = re.search(r'"message":\s*"(.*)"', results, flags=re.DOTALL)
        if match:
            try:
                paraphrase = json.loads('{"text": "%s"}' % match.group(1))
            except:
                raise ValueError("Failed to parse the Gemini message.")
            gpt_replying_buffer += paraphrase['text']  # let the json library handle escape sequences
            chatbot[-1] = (inputs, gpt_replying_buffer)
            history[-1] = gpt_replying_buffer
            yield from update_ui(chatbot=chatbot, history=history)
        if error_match:
            history = history[:-2]  # drop the failed turn (fixed: `history[-2]` would replace the list with a single element)
            chatbot[-1] = (inputs, gpt_replying_buffer + f"Conversation error; see the message below\n\n```\n{error_match.group(1)}\n```")
            yield from update_ui(chatbot=chatbot, history=history)
            raise RuntimeError('Conversation error')
    if not gpt_replying_buffer:
        history = history[:-2]  # drop the failed turn from the history
        chatbot[-1] = (inputs, gpt_replying_buffer + f"Google's safety policy was triggered and no answer was returned\n\n```\n{gpt_security_policy}\n```")
        yield from update_ui(chatbot=chatbot, history=history)


if __name__ == '__main__':
    llm_kwargs = {'llm_model': 'gemini-pro'}
    result = predict('Write a long story about a magic backpack.', llm_kwargs, llm_kwargs, [])
    for i in result:
        print(i)
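
The parsing strategy above pulls each "text" field out of the streamed lines with a regex, then round-trips the captured span through json.loads so escape sequences are decoded properly. A self-contained sketch with a hand-written chunk shaped like Gemini's streamGenerateContent output (illustrative, not a captured response):

```python
import json
import re

# Illustrative line from a streamed response; real chunks carry more fields
chunk = b'{"candidates": [{"content": {"parts": [{"text": "Once upon a \\"time\\""}]}}]}'
results = chunk.decode()

match = re.search(r'"text":\s*"((?:[^"\\]|\\.)*)"', results, flags=re.DOTALL)
if match:
    # Re-wrapping the captured span lets json.loads resolve \" and \n escapes
    piece = json.loads('{"text": "%s"}' % match.group(1))['text']
    print(piece)  # Once upon a "time"
```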

request_llms/com_google.py

Lines changed: 198 additions & 0 deletions
@@ -0,0 +1,198 @@ (new file)
# encoding: utf-8
# @Time   : 2023/12/25
# @Author : Spike
# @Descr  :
import json
import os
import re
import requests
from typing import List, Dict, Tuple
from toolbox import get_conf, encode_image

proxies, TIMEOUT_SECONDS = get_conf('proxies', 'TIMEOUT_SECONDS')

"""
========================================================================
Part 5: file-handling helpers
files_filter_handler         filter a file list by type
input_encode_handler         extract files from the input and encode them
file_manifest_filter_html    filter a file list by type and render it as html or md text
link_mtime_to_md             append the file's local mtime to its link, to avoid downloading a cached copy
html_view_blank              hyperlink helper
html_local_file              relative path for a local file
to_markdown_tabs             convert file lists into a markdown table
========================================================================
"""


def files_filter_handler(file_list):
    new_list = []
    filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    for file in file_list:
        file = str(file).replace('file=', '')
        if os.path.exists(file):
            if str(os.path.basename(file)).split('.')[-1] in filter_:
                new_list.append(file)
    return new_list


def input_encode_handler(inputs):
    md_encode = []
    pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
    matches_path = re.findall(pattern_md_file, inputs)
    for md_path in matches_path:
        pattern_file = r"\((file=.*)\)"
        matches_file = re.findall(pattern_file, md_path)  # renamed from matches_path, which shadowed the outer list
        encode_file = files_filter_handler(file_list=matches_file)
        if encode_file:
            md_encode.extend([{
                "data": encode_image(i),
                "type": os.path.splitext(i)[1].replace('.', '')
            } for i in encode_file])
            inputs = inputs.replace(md_path, '')
    return inputs, md_encode
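# Example (hypothetical paths): with inputs = 'look ![img](file=/tmp/cat.png)' and
# /tmp/cat.png present on disk, this returns
# ('look ', [{'data': '<base64 from encode_image>', 'type': 'png'}]);
# links whose target is missing, or is not an image, stay in the text untouched.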

def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    new_list = []
    if not filter_:
        filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    for file in file_list:
        if str(os.path.basename(file)).split('.')[-1] in filter_:
            new_list.append(html_local_img(file, md=md_type))
        elif os.path.exists(file):
            new_list.append(link_mtime_to_md(file))
        else:
            new_list.append(file)
    return new_list


def link_mtime_to_md(file):
    link_local = html_local_file(file)
    link_name = os.path.basename(file)
    a = f"[{link_name}]({link_local}?{os.path.getmtime(file)})"
    return a


def html_local_file(file):
    base_path = os.path.dirname(__file__)  # project directory
    if os.path.exists(str(file)):
        file = f'file={file.replace(base_path, ".")}'
    return file


def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    style = ''
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
        style += f"max-height: {max_height};"
    __file = html_local_file(__file)
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
    if md:
        a = f'![{__file}]({__file})'
    return a


def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """
    Args:
        head: table header, e.g. ['name', 'size']
        tabs: table values, e.g. [[col1], [col2], [col3], [col4]]
        alignment: ':---' left-aligned, ':---:' centered, '---:' right-aligned
        column: True to keep data in columns, False to keep data in rows (default).
    Returns:
        A string representation of the markdown table.
    """
    if column:
        transposed_tabs = list(map(list, zip(*tabs)))
    else:
        transposed_tabs = tabs
    # Find the maximum length among the columns
    max_len = max(len(column) for column in transposed_tabs)

    tab_format = "| %s "
    tabs_list = "".join([tab_format % i for i in head]) + '|\n'
    tabs_list += "".join([tab_format % alignment for i in head]) + '|\n'

    for i in range(max_len):
        row_data = [tab[i] if i < len(tab) else '' for tab in transposed_tabs]
        row_data = file_manifest_filter_html(row_data, filter_=None)
        tabs_list += "".join([tab_format % i for i in row_data]) + '|\n'

    return tabs_list
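# Example (hypothetical values): to_markdown_tabs(['name', 'count'],
# [['apple', 'banana'], ['3', '5']], column=True) renders:
#   | name | count |
#   | :---: | :---: |
#   | apple | 3 |
#   | banana | 5 |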

class GoogleChatInit:

    def __init__(self):
        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __conversation_user(self, user_input):
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            input_, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': input_})
        if encode_img:
            for data in encode_img:
                what_i_have_asked['parts'].append(
                    {'inline_data': {
                        "mime_type": f"image/{data['type']}",
                        "data": data['data']
                    }})
        return what_i_have_asked

    def __conversation_history(self, history):
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index])
                what_gpt_answer = {
                    "role": "model",
                    "parts": [{"text": history[index + 1]}]
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        messages = [
            # {"role": "system", "parts": [{"text": system_prompt}]},  # Gemini rejects an even number of turns, so the system role is unused for now; revisit if support lands
            # {"role": "user", "parts": [{"text": ""}]},
            # {"role": "model", "parts": [{"text": ""}]}
        ]
        self.url_gemini = self.url_gemini.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        header = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # only non-vision models carry the history
            messages.extend(self.__conversation_history(history))  # process the history
        messages.append(self.__conversation_user(inputs))  # process the current user turn
        payload = {
            "contents": messages,
            "generationConfig": {
                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return header, payload


if __name__ == '__main__':
    google = GoogleChatInit()
    # print(google.generate_message_payload('你好呀', {},
    #                                       ['123123', '3123123'], ''))
    # google.input_encode_handle('123123[123123](./123123), ![53425](./asfafa/fff.jpg)')
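
For orientation, a sketch of the request that generate_message_payload would produce for a single-turn gemini-pro call, following the defaults above (the key in the URL is a placeholder):

```python
# Hypothetical result of generate_message_payload("Hi", {'llm_model': 'gemini-pro'}, [], "")
header = {'Content-Type': 'application/json'}
payload = {
    "contents": [
        {"role": "user", "parts": [{"text": "Hi"}]},
    ],
    "generationConfig": {
        "stopSequences": [''],   # str('').split(' ') yields a single empty string
        "temperature": 1,
        "topP": 0.8,
        "topK": 10,
    },
}
# POSTed with stream=True to:
# https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:streamGenerateContent?key=<GEMINI_API_KEY>
```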
