Skip to content

Commit ece1d91

Browse files
authored
Merge pull request #76 from kense-lab/feat/upgrade-depends
Feat/upgrade versions of some libraries
2 parents b5e878c + ab9f81e commit ece1d91

29 files changed

+907
-838
lines changed

app/api/v1/action.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from app.api.deps import get_async_session, get_token_id
88
from app.libs.paginate import cursor_page, CommonPage
9-
from app.models.action import Action
9+
from app.models.action import Action, ActionRead
1010
from app.models.token_relation import RelationType
1111
from app.providers.auth_provider import auth_policy
1212
from app.schemas.common import DeleteResponse, BaseSuccessDataResponse
@@ -16,44 +16,50 @@
1616
router = APIRouter()
1717

1818

19-
@router.get("", response_model=CommonPage[Action])
19+
@router.get("", response_model=CommonPage[ActionRead])
2020
async def list_actions(*, session: AsyncSession = Depends(get_async_session), token_id=Depends(get_token_id)):
2121
"""
2222
Returns a list of Actions.
2323
"""
2424
statement = auth_policy.token_filter(
2525
select(Action), field=Action.id, relation_type=RelationType.Action, token_id=token_id
2626
)
27-
return await cursor_page(statement, session)
27+
page = await cursor_page(statement, session)
28+
page.data = [ast.model_dump(by_alias=True) for ast in page.data]
29+
return page.model_dump(by_alias=True)
2830

2931

30-
@router.post("", response_model=List[Action])
32+
@router.post("", response_model=List[ActionRead])
3133
async def create_actions(
3234
*, session: AsyncSession = Depends(get_async_session), body: ActionBulkCreateRequest, token_id=Depends(get_token_id)
33-
) -> List[Action]:
35+
):
3436
"""
3537
Create an action with openapi schema.
3638
"""
3739

38-
return await ActionService.create_actions(session=session, body=body, token_id=token_id)
40+
actions = await ActionService.create_actions(session=session, body=body, token_id=token_id)
41+
actions = [item.model_dump(by_alias=True) for item in actions]
42+
return actions
3943

4044

41-
@router.get("/{action_id}", response_model=Action)
42-
async def get_action(*, session: AsyncSession = Depends(get_async_session), action_id: str) -> Action:
45+
@router.get("/{action_id}", response_model=ActionRead)
46+
async def get_action(*, session: AsyncSession = Depends(get_async_session), action_id: str):
4347
"""
4448
Retrieves an action.
4549
"""
46-
return await ActionService.get_action(session=session, action_id=action_id)
50+
action = await ActionService.get_action(session=session, action_id=action_id)
51+
return action.model_dump(by_alias=True)
4752

4853

49-
@router.post("/{action_id}", response_model=Action)
54+
@router.post("/{action_id}", response_model=ActionRead)
5055
async def modify_action(
5156
*, session: AsyncSession = Depends(get_async_session), action_id: str, body: ActionUpdateRequest
52-
) -> Action:
57+
):
5358
"""
5459
Modifies an action.
5560
"""
56-
return await ActionService.modify_action(session=session, action_id=action_id, body=body)
61+
action = await ActionService.modify_action(session=session, action_id=action_id, body=body)
62+
return action.model_dump(by_alias=True)
5763

5864

5965
@router.delete("/{action_id}", response_model=DeleteResponse)

app/api/v1/assistant.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlmodel import select
44

55
from app.api.deps import get_token_id, get_async_session
6-
from app.models.assistant import Assistant, AssistantUpdate, AssistantCreate
6+
from app.models.assistant import Assistant, AssistantUpdate, AssistantCreate, AssistantRead
77
from app.libs.paginate import cursor_page, CommonPage
88
from app.models.token_relation import RelationType
99
from app.providers.auth_provider import auth_policy
@@ -13,7 +13,7 @@
1313
router = APIRouter()
1414

1515

16-
@router.get("", response_model=CommonPage[Assistant])
16+
@router.get("", response_model=CommonPage[AssistantRead])
1717
async def list_assistants(*, session: AsyncSession = Depends(get_async_session), token_id=Depends(get_token_id)):
1818
"""
1919
Returns a list of assistants.
@@ -26,30 +26,30 @@ async def list_assistants(*, session: AsyncSession = Depends(get_async_session),
2626
return asts_page
2727

2828

29-
@router.post("", response_model=Assistant, response_model_exclude={"metadata"})
29+
@router.post("", response_model=AssistantRead)
3030
async def create_assistant(
3131
*, session: AsyncSession = Depends(get_async_session), body: AssistantCreate, token_id=Depends(get_token_id)
32-
) -> Assistant:
32+
):
3333
"""
3434
Create an assistant with a model and instructions.
3535
"""
3636
ast = await AssistantService.create_assistant(session=session, body=body, token_id=token_id)
3737
return ast.model_dump(by_alias=True)
3838

3939

40-
@router.get("/{assistant_id}", response_model=Assistant)
41-
async def get_assistant(*, session: AsyncSession = Depends(get_async_session), assistant_id: str) -> Assistant:
40+
@router.get("/{assistant_id}", response_model=AssistantRead)
41+
async def get_assistant(*, session: AsyncSession = Depends(get_async_session), assistant_id: str):
4242
"""
4343
Retrieves an assistant.
4444
"""
4545
ast = await AssistantService.get_assistant(session=session, assistant_id=assistant_id)
4646
return ast.model_dump(by_alias=True)
4747

4848

49-
@router.post("/{assistant_id}", response_model=Assistant)
49+
@router.post("/{assistant_id}", response_model=AssistantRead)
5050
async def modify_assistant(
5151
*, session: AsyncSession = Depends(get_async_session), assistant_id: str, body: AssistantUpdate
52-
) -> Assistant:
52+
):
5353
"""
5454
Modifies an assistant.
5555
"""

app/api/v1/message.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from app.api.deps import get_async_session
88
from app.models import MessageFile
9-
from app.models.message import Message, MessageCreate, MessageUpdate
9+
from app.models.message import Message, MessageCreate, MessageUpdate, MessageRead
1010
from app.libs.paginate import cursor_page, CommonPage
1111
from app.services.message.message import MessageService
1212

@@ -15,7 +15,7 @@
1515

1616
@router.get(
1717
"/{thread_id}/messages",
18-
response_model=CommonPage[Message],
18+
response_model=CommonPage[MessageRead],
1919
)
2020
async def list_messages(
2121
*,
@@ -30,47 +30,53 @@ async def list_messages(
3030
if run_id:
3131
# 根据 run_id 进行过滤
3232
statement = statement.where(Message.run_id == run_id)
33-
return await cursor_page(statement, session)
3433

34+
page = await cursor_page(statement, session)
35+
page.data = [ast.model_dump(by_alias=True) for ast in page.data]
36+
return page
3537

36-
@router.post("/{thread_id}/messages", response_model=Message)
38+
39+
@router.post("/{thread_id}/messages", response_model=MessageRead)
3740
async def create_message(
3841
*, session: AsyncSession = Depends(get_async_session), thread_id: str, body: MessageCreate
39-
) -> Message:
42+
):
4043
"""
4144
Create a message.
4245
"""
43-
return await MessageService.create_message(session=session, thread_id=thread_id, body=body)
46+
message = await MessageService.create_message(session=session, thread_id=thread_id, body=body)
47+
return message.model_dump(by_alias=True)
4448

4549

4650
@router.get(
4751
"/{thread_id}/messages/{message_id}",
48-
response_model=Message,
52+
response_model=MessageRead,
4953
)
5054
async def get_message(
5155
*, session: AsyncSession = Depends(get_async_session), thread_id: str, message_id: str
52-
) -> Message:
56+
):
5357
"""
5458
Retrieve a message.
5559
"""
56-
return await MessageService.get_message(session=session, thread_id=thread_id, message_id=message_id)
60+
message = await MessageService.get_message(session=session, thread_id=thread_id, message_id=message_id)
61+
return message.model_dump(by_alias=True)
5762

5863

5964
@router.post(
6065
"/{thread_id}/messages/{message_id}",
61-
response_model=Message,
66+
response_model=MessageRead,
6267
)
6368
async def modify_message(
6469
*,
6570
session: AsyncSession = Depends(get_async_session),
6671
thread_id: str,
6772
message_id: str = ...,
6873
body: MessageUpdate = ...,
69-
) -> Message:
74+
):
7075
"""
7176
Modifies a message.
7277
"""
73-
return await MessageService.modify_message(session=session, thread_id=thread_id, message_id=message_id, body=body)
78+
message = await MessageService.modify_message(session=session, thread_id=thread_id, message_id=message_id, body=body)
79+
return message.model_dump(by_alias=True)
7480

7581

7682
@router.get(

app/api/v1/runs.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from app.api.deps import get_async_session
77
from app.core.runner import pub_handler
88
from app.libs.paginate import cursor_page, CommonPage
9-
from app.models import RunStep
109
from app.models.run import RunCreate, RunRead, RunUpdate, Run
10+
from app.models.run_step import RunStep, RunStepRead
1111
from app.schemas.runs import SubmitToolOutputsRunRequest
1212
from app.schemas.threads import CreateThreadAndRun
1313
from app.services.run.run import RunService
@@ -19,7 +19,7 @@
1919

2020
@router.get(
2121
"/{thread_id}/runs",
22-
response_model=CommonPage[Run],
22+
response_model=CommonPage[RunRead],
2323
)
2424
async def list_runs(
2525
*,
@@ -30,7 +30,9 @@ async def list_runs(
3030
Returns a list of runs belonging to a thread.
3131
"""
3232
await ThreadService.get_thread(session=session, thread_id=thread_id)
33-
return await cursor_page(select(Run).where(Run.thread_id == thread_id), session)
33+
page = await cursor_page(select(Run).where(Run.thread_id == thread_id), session)
34+
page.data = [ast.model_dump(by_alias=True) for ast in page.data]
35+
return page.model_dump(by_alias=True)
3436

3537

3638
@router.post(
@@ -52,7 +54,7 @@ async def create_run(
5254
if body.stream:
5355
return pub_handler.sub_stream(db_run.id, request)
5456
else:
55-
return db_run
57+
return db_run.model_dump(by_alias=True)
5658

5759

5860
@router.get(
@@ -63,7 +65,8 @@ async def get_run(*, session: AsyncSession = Depends(get_async_session), thread_
6365
"""
6466
Retrieves a run.
6567
"""
66-
return await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
68+
run = await RunService.get_run(session=session, run_id=run_id, thread_id=thread_id)
69+
return run.model_dump(by_alias=True)
6770

6871

6972
@router.post(
@@ -80,7 +83,8 @@ async def modify_run(
8083
"""
8184
Modifies a run.
8285
"""
83-
return await RunService.modify_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
86+
run = await RunService.modify_run(session=session, thread_id=thread_id, run_id=run_id, body=body)
87+
return run.model_dump(by_alias=True)
8488

8589

8690
@router.post(
@@ -93,12 +97,13 @@ async def cancel_run(
9397
"""
9498
Cancels a run that is `in_progress`.
9599
"""
96-
return await RunService.cancel_run(session=session, thread_id=thread_id, run_id=run_id)
100+
run = await RunService.cancel_run(session=session, thread_id=thread_id, run_id=run_id)
101+
return run.model_dump(by_alias=True)
97102

98103

99104
@router.get(
100105
"/{thread_id}/runs/{run_id}/steps",
101-
response_model=CommonPage[RunStep],
106+
response_model=CommonPage[RunStepRead],
102107
)
103108
async def list_run_steps(
104109
*,
@@ -109,14 +114,17 @@ async def list_run_steps(
109114
"""
110115
Returns a list of run steps belonging to a run.
111116
"""
112-
return await cursor_page(
117+
page = await cursor_page(
113118
select(RunStep).where(RunStep.thread_id == thread_id).where(RunStep.run_id == run_id), session
114119
)
120+
page.data = [ast.model_dump(by_alias=True) for ast in page.data]
121+
return page.model_dump(by_alias=True)
122+
115123

116124

117125
@router.get(
118126
"/{thread_id}/runs/{run_id}/steps/{step_id}",
119-
response_model=RunStep,
127+
response_model=RunStepRead,
120128
)
121129
async def get_run_step(
122130
*,
@@ -128,7 +136,8 @@ async def get_run_step(
128136
"""
129137
Retrieves a run step.
130138
"""
131-
return await RunService.get_run_step(thread_id=thread_id, run_id=run_id, step_id=step_id, session=session)
139+
run_step = await RunService.get_run_step(thread_id=thread_id, run_id=run_id, step_id=step_id, session=session)
140+
return run_step.model_dump(by_alias=True)
132141

133142

134143
@router.post(
@@ -156,7 +165,7 @@ async def submit_tool_outputs_to_run(
156165
if body.stream:
157166
return pub_handler.sub_stream(db_run.id, request)
158167
else:
159-
return db_run
168+
return db_run.model_dump(by_alias=True)
160169

161170

162171
@router.post("/runs", response_model=RunRead)
@@ -170,4 +179,4 @@ async def create_thread_and_run(
170179
if body.stream:
171180
return pub_handler.sub_stream(run.id, request)
172181
else:
173-
return run
182+
return run.model_dump(by_alias=True)

app/core/runner/thread_runner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from sqlalchemy.orm import Session
88

9+
from app.models.token_relation import RelationType
910
from config.config import settings
1011
from config.llm import llm_settings, tool_settings
1112

@@ -247,7 +248,7 @@ def __init_llm_backend(self, assistant_id):
247248
if settings.AUTH_ENABLE:
248249
# init llm backend with token id
249250
token_id = TokenRelationService.get_token_id_by_relation(
250-
session=self.session, relation_type="assistant", relation_id=assistant_id
251+
session=self.session, relation_type=RelationType.Assistant, relation_id=assistant_id
251252
)
252253
token = TokenService.get_token_by_id(self.session, token_id)
253254
return LLMBackend(base_url=token.llm_base_url, api_key=token.llm_api_key)

app/libs/paginate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
class CursorParams(BaseModel, AbstractParams):
1919
limit: int = Query(20, ge=1, le=100, description="Page offset")
20-
order: str = "desc"
21-
after: Optional[str] = None
22-
before: Optional[str] = None
20+
order: str = Query(default="desc", description="Sort order")
21+
after: Optional[str] = Query(None, description="Page after")
22+
before: Optional[str] = Query(None, description="Page before")
2323

2424
def to_raw_params(self) -> CursorRawParams:
2525
return CursorRawParams(cursor=None, size=self.limit, include_total=True)

app/libs/types.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1 @@
1-
from datetime import datetime
2-
3-
from app.libs.util import datetime2timestamp
4-
51
from app.libs.bson.objectid import ObjectId as ObjectId # noqa
6-
7-
8-
class Timestamp(datetime):
9-
@classmethod
10-
def __get_validators__(cls):
11-
yield datetime2timestamp

app/libs/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,21 @@ def datetime2timestamp(value: datetime):
88
return value.timestamp()
99

1010

11+
def str2datetime(value: str):
12+
if not value:
13+
return None
14+
return datetime.fromisoformat(value)
15+
16+
17+
def is_valid_datetime(date_str, format="%Y-%m-%d %H:%M:%S"):
18+
if not date_str or not isinstance(date_str, str):
19+
return False
20+
try:
21+
datetime.strptime(date_str, format)
22+
return True
23+
except ValueError:
24+
return False
25+
26+
1127
def random_uuid() -> str:
1228
return "ml-" + str(uuid.uuid4()).replace("-", "")

0 commit comments

Comments
 (0)