Skip to content

Commit b5e878c

Browse files
authored
Merge pull request #74 from klb3713/main
支持json返回模式
2 parents b4370ad + 52d193c commit b5e878c

File tree

6 files changed

+57
-7
lines changed

6 files changed

+57
-7
lines changed

app/core/runner/llm_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def run(
2525
extra_body=None,
2626
temperature=None,
2727
top_p=None,
28+
response_format=None,
2829
) -> ChatCompletion | Stream[ChatCompletionChunk]:
2930
chat_params = {
3031
"messages": messages,
@@ -44,6 +45,8 @@ def run(
4445
if tools:
4546
chat_params["tools"] = tools
4647
chat_params["tool_choice"] = tool_choice if tool_choice else "auto"
48+
if isinstance(response_format, dict) and response_format.get("type") == "json_object":
49+
chat_params["response_format"] = {"type": "json_object"}
4750
logging.info("chat_params: %s", chat_params)
4851
response = self.client.chat.completions.create(**chat_params)
4952
logging.info("chat_response: %s", response)

app/core/runner/thread_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def __run_step(
133133
extra_body=run.extra_body,
134134
temperature=run.temperature,
135135
top_p=run.top_p,
136+
response_format=run.response_format,
136137
)
137138

138139
# create message callback

app/models/assistant.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Optional, Union
22

33
from sqlalchemy import Column
44
from sqlmodel import Field, JSON, TEXT
@@ -15,7 +15,7 @@ class AssistantBase(BaseModel):
1515
name: Optional[str] = Field(default=None)
1616
tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
1717
extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
18-
response_format: Optional[str] = Field(default=None) # 响应格式
18+
response_format: Union[str, dict] = Field(default="auto", sa_column=Column(JSON)) # 响应格式
1919
tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 工具资源
2020
temperature: Optional[float] = Field(default=None) # 温度
2121
top_p: Optional[float] = Field(default=None) # top_p
@@ -38,7 +38,7 @@ class AssistantUpdate(BaseModel):
3838
name: Optional[str] = Field(default=None)
3939
tools: Optional[list] = Field(default=None, sa_column=Column(JSON))
4040
extra_body: Optional[dict] = Field(default={}, sa_column=Column(JSON))
41-
response_format: Optional[str] = Field(default=None) # 响应格式
41+
response_format: Union[str, dict] = Field(default="auto", sa_column=Column(JSON)) # 响应格式
4242
tool_resources: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 工具资源
4343
temperature: Optional[float] = Field(default=None) # 温度
4444
top_p: Optional[float] = Field(default=None) # top_p

app/models/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class Run(BaseModel, PrimaryKeyMixin, TimeStampMixin, table=True):
5050
incomplete_details: Optional[str] = Field(default=None) # 未完成详情
5151
max_completion_tokens: Optional[int] = Field(default=None) # 最大完成长度
5252
max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
53-
response_format: Optional[str] = Field(default=None) # 返回格式
53+
response_format: Union[str, dict] = Field(default="auto", sa_column=Column(JSON)) # 返回格式
5454
tool_choice: Optional[str] = Field(default=None) # 工具选择
5555
truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 截断策略
5656
usage: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 调用使用情况
@@ -77,7 +77,7 @@ class RunCreate(BaseModel):
7777
max_completion_tokens: Optional[int] = None # 最大完成长度
7878
max_prompt_tokens: Optional[int] = Field(default=None) # 最大提示长度
7979
truncation_strategy: Optional[dict] = Field(default=None, sa_column=Column(JSON)) # 截断策略
80-
response_format: Optional[str] = Field(default=None) # 返回格式
80+
response_format: Union[str, dict] = Field(default="auto", sa_column=Column(JSON)) # 返回格式
8181
tool_choice: Optional[str] = Field(default=None) # 工具选择
8282
temperature: Optional[float] = Field(default=None) # 温度
8383
top_p: Optional[float] = Field(default=None) # top_p
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""empty message
2+
3+
Revision ID: 1c667e62f698
4+
Revises: aa4bda3363e3
5+
Create Date: 2024-05-28 11:35:33.961196
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
import sqlmodel
13+
from sqlalchemy.dialects import mysql
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = '1c667e62f698'
17+
down_revision: Union[str, None] = 'aa4bda3363e3'
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
21+
22+
def upgrade() -> None:
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.alter_column('assistant', 'response_format',
25+
existing_type=mysql.VARCHAR(collation='utf8mb4_unicode_ci', length=255),
26+
type_=sa.JSON(),
27+
existing_nullable=True)
28+
op.alter_column('run', 'response_format',
29+
existing_type=mysql.VARCHAR(collation='utf8mb4_unicode_ci', length=255),
30+
type_=sa.JSON(),
31+
existing_nullable=True)
32+
# ### end Alembic commands ###
33+
34+
35+
def downgrade() -> None:
36+
# ### commands auto generated by Alembic - please adjust! ###
37+
op.alter_column('run', 'response_format',
38+
existing_type=sa.JSON(),
39+
type_=mysql.VARCHAR(collation='utf8mb4_unicode_ci', length=255),
40+
existing_nullable=True)
41+
op.alter_column('assistant', 'response_format',
42+
existing_type=sa.JSON(),
43+
type_=mysql.VARCHAR(collation='utf8mb4_unicode_ci', length=255),
44+
existing_nullable=True)
45+
# ### end Alembic commands ###

tests/e2e/run_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def test_create_run_with_additional_messages_and_other_parmas():
1212
name="Assistant Demo",
1313
instructions="你是一个有用的助手",
1414
model="gpt-4o",
15+
response_format={"type": "json_object"},
1516
)
1617
thread = client.beta.threads.create(
1718
messages=[
@@ -42,7 +43,7 @@ def test_create_run_with_additional_messages_and_other_parmas():
4243
stream = client.beta.threads.runs.create(
4344
thread_id=thread.id,
4445
assistant_id=assistant.id,
45-
instructions="",
46+
instructions="请用 json 格式回答",
4647
additional_messages=[
4748
{
4849
"role": "user",
@@ -75,7 +76,7 @@ def test_create_run_with_additional_messages_and_other_parmas():
7576

7677
query = session.query(Run).filter(Run.thread_id == thread.id)
7778
run = query.one()
78-
assert run.instructions == "你是一个有用的助手"
79+
assert run.instructions == "请用 json 格式回答"
7980
assert run.model == "gpt-4o"
8081
query = session.query(Message).filter(Message.thread_id == thread.id).order_by(Message.created_at)
8182
messages = query.all()

0 commit comments

Comments
 (0)