# encoding: utf-8
# @Time   : 2023/12/25
# @Author : Spike
# @Descr  : File/markdown helpers and the Google Gemini chat bridge.
5
+ import json
6
+ import os
7
+ import re
8
+ import requests
9
+ from typing import List , Dict , Tuple
10
+ from toolbox import get_conf , encode_image
11
+
12
+ proxies , TIMEOUT_SECONDS = get_conf ('proxies' , 'TIMEOUT_SECONDS' )
13
+
14
+ """
15
+ ========================================================================
16
+ 第五部分 一些文件处理方法
17
+ files_filter_handler 根据type过滤文件
18
+ input_encode_handler 提取input中的文件,并解析
19
+ file_manifest_filter_html 根据type过滤文件, 并解析为html or md 文本
20
+ link_mtime_to_md 文件增加本地时间参数,避免下载到缓存文件
21
+ html_view_blank 超链接
22
+ html_local_file 本地文件取相对路径
23
+ to_markdown_tabs 文件list 转换为 md tab
24
+ """
25
+
26
+
27
def files_filter_handler(file_list):
    """Filter a list of paths down to existing image files.

    Each entry may carry a gradio-style ``file=`` prefix, which is stripped
    before the path is checked.

    Args:
        file_list: Iterable of path strings (possibly prefixed with ``file=``).

    Returns:
        List of paths that exist on disk and have an image extension.
    """
    image_exts = {'png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps'}
    new_list = []
    for file in file_list:
        file = str(file).replace('file=', '')
        if os.path.exists(file):
            # splitext (instead of split('.')[-1]) yields no extension for
            # dot-less names like 'png', and lower() accepts '.PNG' etc.
            ext = os.path.splitext(os.path.basename(file))[1].lstrip('.').lower()
            if ext in image_exts:
                new_list.append(file)
    return new_list
36
+
37
+
38
def input_encode_handler(inputs):
    """Extract markdown file links from *inputs* and base64-encode the images.

    Every markdown link/image (``[x](y)`` or ``![x](y)``) whose target starts
    with ``file=`` and points to an existing image file is encoded; the
    matched markdown is removed from the returned prompt text.

    Args:
        inputs: Raw user input, possibly containing markdown links.

    Returns:
        Tuple of (inputs with encoded links stripped,
        list of ``{"data": <base64>, "type": <extension>}`` dicts).
    """
    md_encode = []
    pattern_md_file = r"(!?\[[^\]]+\]\([^\)]+\))"
    matches_path = re.findall(pattern_md_file, inputs)
    for md_path in matches_path:
        # Fixed: the original rebound `matches_path` here, shadowing the list
        # being iterated — harmless at runtime, but needlessly confusing.
        file_targets = re.findall(r"\((file=.*)\)", md_path)
        encode_file = files_filter_handler(file_list=file_targets)
        if encode_file:
            md_encode.extend([{
                "data": encode_image(path),
                "type": os.path.splitext(path)[1].replace('.', '')
            } for path in encode_file])
            inputs = inputs.replace(md_path, '')
    return inputs, md_encode
53
+
54
+
55
def file_manifest_filter_html(file_list, filter_: list = None, md_type=False):
    """Render each file in *file_list* as an html/markdown fragment.

    Entries with an image extension become inline images (markdown form when
    *md_type* is True), other existing files become timestamped download
    links, and anything else is passed through untouched.
    """
    if not filter_:
        filter_ = ['png', 'jpg', 'jpeg', 'bmp', 'svg', 'webp', 'ico', 'tif', 'tiff', 'raw', 'eps']
    rendered = []
    for file in file_list:
        extension = str(os.path.basename(file)).split('.')[-1]
        if extension in filter_:
            rendered.append(html_local_img(file, md=md_type))
        elif os.path.exists(file):
            rendered.append(link_mtime_to_md(file))
        else:
            rendered.append(file)
    return rendered
67
+
68
+
69
def link_mtime_to_md(file):
    """Build a markdown link to *file* whose URL carries the file's mtime.

    The mtime query string makes each regeneration of the file a distinct
    URL, so browsers cannot serve a stale cached download.
    """
    local_link = html_local_file(file)
    display_name = os.path.basename(file)
    return f"[{display_name}]({local_link}?{os.path.getmtime(file)})"
74
+
75
+
76
def html_local_file(file):
    """Turn an existing absolute path into a relative ``file=`` URL.

    Paths that do not exist on disk are returned unchanged.
    """
    base_path = os.path.dirname(__file__)  # project directory
    if not os.path.exists(str(file)):
        return file
    return f'file={file.replace(base_path, ".")}'
81
+
82
+
83
def html_local_img(__file, layout='left', max_width=None, max_height=None, md=True):
    """Render a local image file as an html ``<img>`` div or a markdown image.

    Args:
        __file: Path to the image file.
        layout: Horizontal alignment of the html wrapper div.
        max_width: Optional CSS max-width for the html form.
        max_height: Optional CSS max-height for the html form.
        md: When True, return markdown image syntax instead of html.

    Returns:
        An html or markdown fragment referencing the (relative) file path.
    """
    style = ''
    if max_width is not None:
        style += f"max-width: {max_width};"
    if max_height is not None:
        style += f"max-height: {max_height};"
    __file = html_local_file(__file)
    a = f'<div align="{layout}"><img src="{__file}" style="{style}"></div>'
    if md:
        # BUG FIX: this branch assigned `a = f''`, so every markdown-mode
        # caller (the default) received an empty string. Emit a markdown
        # image referencing the same local path instead.
        a = f'![{__file}]({__file})'
    return a
94
+
95
+
96
def to_markdown_tabs(head: list, tabs: list, alignment=':---:', column=False):
    """Convert a header plus value lists into a markdown table.

    Args:
        head: Table header: [col1_name, col2_name, ...]
        tabs: Table values: [[col1 values], [col2 values], ...]
        alignment: ':---' left-aligned, ':---:' centered, '---:' right-aligned.
        column: True to keep data in columns, False to keep data in rows (default).

    Returns:
        A string representation of the markdown table.
    """
    if column:
        transposed_tabs = list(map(list, zip(*tabs)))
    else:
        transposed_tabs = tabs
    # Longest column decides the row count. `default=0` keeps an empty
    # `tabs` from raising ValueError (the original called max() unguarded).
    # Also avoid shadowing the `column` parameter inside the genexp.
    max_len = max((len(col) for col in transposed_tabs), default=0)

    tab_format = "| %s "
    tabs_list = "".join([tab_format % cell for cell in head]) + '|\n'
    tabs_list += "".join([tab_format % alignment for _ in head]) + '|\n'

    for row_idx in range(max_len):
        # Pad ragged columns with '' so every row has one cell per column.
        row_data = [tab[row_idx] if row_idx < len(tab) else '' for tab in transposed_tabs]
        row_data = file_manifest_filter_html(row_data, filter_=None)
        tabs_list += "".join([tab_format % cell for cell in row_data]) + '|\n'

    return tabs_list
123
+
124
+
125
class GoogleChatInit:
    """Builds request payloads for, and streams replies from, the Google
    Gemini ``streamGenerateContent`` REST endpoint.
    """

    def __init__(self):
        # %m / %k are placeholders later substituted in
        # generate_message_payload with the model name and GEMINI_API_KEY.
        self.url_gemini = 'https://generativelanguage.googleapis.com/v1beta/models/%m:streamGenerateContent?key=%k'

    def __conversation_user(self, user_input):
        """Convert one user turn into a Gemini ``contents`` entry.

        When the endpoint URL contains 'vision', markdown-linked local images
        in *user_input* are stripped from the text and attached as
        ``inline_data`` parts; otherwise the text passes through unchanged.
        """
        what_i_have_asked = {"role": "user", "parts": []}
        if 'vision' not in self.url_gemini:
            input_ = user_input
            encode_img = []
        else:
            input_, encode_img = input_encode_handler(user_input)
        what_i_have_asked['parts'].append({'text': input_})
        if encode_img:
            for data in encode_img:
                what_i_have_asked['parts'].append(
                    {'inline_data': {
                        "mime_type": f"image/{data['type']}",
                        "data": data['data']
                    }})
        return what_i_have_asked

    def __conversation_history(self, history):
        """Convert flat [user, model, user, model, ...] history into Gemini
        turn dicts.

        Only complete pairs are used (len // 2); a trailing unpaired entry
        is silently dropped.
        """
        messages = []
        conversation_cnt = len(history) // 2
        if conversation_cnt:
            for index in range(0, 2 * conversation_cnt, 2):
                what_i_have_asked = self.__conversation_user(history[index])
                what_gpt_answer = {
                    "role": "model",
                    "parts": [{"text": history[index + 1]}]
                }
                messages.append(what_i_have_asked)
                messages.append(what_gpt_answer)
        return messages

    def generate_chat(self, inputs, llm_kwargs, history, system_prompt):
        """POST the chat request and return an iterator over the streamed
        response lines (``requests`` stream mode)."""
        headers, payload = self.generate_message_payload(inputs, llm_kwargs, history, system_prompt)
        response = requests.post(url=self.url_gemini, headers=headers, data=json.dumps(payload),
                                 stream=True, proxies=proxies, timeout=TIMEOUT_SECONDS)
        return response.iter_lines()

    def generate_message_payload(self, inputs, llm_kwargs, history, system_prompt) -> Tuple[Dict, Dict]:
        """Build the (headers, payload) pair for the Gemini API.

        NOTE(review): this mutates self.url_gemini in place, so %m/%k are only
        substituted on the FIRST call; a later call with a different
        ``llm_model`` keeps the old URL — confirm this single-use lifecycle is
        intended. *system_prompt* is currently unused (see comment below).
        """
        messages = [
            # {"role": "system", "parts": [{"text": system_prompt}]},  # gemini rejects an even number of turns, so a system role can't be injected this way; revisit when the API supports it
            # {"role": "user", "parts": [{"text": ""}]},
            # {"role": "model", "parts": [{"text": ""}]}
        ]
        self.url_gemini = self.url_gemini.replace(
            '%m', llm_kwargs['llm_model']).replace(
            '%k', get_conf('GEMINI_API_KEY')
        )
        header = {'Content-Type': 'application/json'}
        if 'vision' not in self.url_gemini:  # only non-vision models receive history
            messages.extend(self.__conversation_history(history))  # convert prior turns
        messages.append(self.__conversation_user(inputs))  # current user turn
        payload = {
            "contents": messages,
            "generationConfig": {
                "stopSequences": str(llm_kwargs.get('stop', '')).split(' '),
                "temperature": llm_kwargs.get('temperature', 1),
                # "maxOutputTokens": 800,
                "topP": llm_kwargs.get('top_p', 0.8),
                "topK": 10
            }
        }
        return header, payload
192
+
193
+
194
if __name__ == '__main__':
    # Ad-hoc smoke check: just construct the client.
    google = GoogleChatInit()
0 commit comments