Skip to content

Commit f963c12

Browse files
committed
feat: commit profiles in one session
1 parent 5760779 commit f963c12

File tree

5 files changed

+124
-73
lines changed

5 files changed

+124
-73
lines changed

src/server/api/memobase_server/controllers/modal/chat/__init__.py

Lines changed: 26 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1+
from ...project import get_project_profile_config
12
from ....connectors import Session
23
from ....env import LOG, ProfileConfig, CONFIG
34
from ....utils import get_blob_str, get_encoded_tokens
45
from ....models.blob import Blob
56
from ....models.utils import Promise, CODE
67
from ....models.response import IdsData, ChatModalResponse
7-
from ...profile import add_user_profiles, update_user_profiles, delete_user_profiles
8+
from ...profile import add_update_delete_user_profiles
89
from ...event import append_user_event
910
from .extract import extract_topics
1011
from .merge import merge_or_valid_new_memos
1112
from .summary import re_summary
1213
from .organize import organize_profiles
1314
from .types import MergeAddResult
1415
from .event_summary import tag_event
15-
from .entry_summary import entry_summary
16+
from .entry_summary import entry_chat_summary
1617

1718

1819
def truncate_chat_blobs(
@@ -43,12 +44,18 @@ async def process_blobs(
4344
return Promise.reject(
4445
CODE.SERVER_PARSE_ERROR, "No blobs to process after truncating"
4546
)
46-
p = await entry_summary(user_id, project_id, blobs)
47+
48+
p = await get_project_profile_config(project_id)
49+
if not p.ok():
50+
return p
51+
project_profiles = p.data()
52+
53+
p = await entry_chat_summary(user_id, project_id, blobs, project_profiles)
4754
if not p.ok():
4855
return p
4956
user_memo_str = p.data()
5057

51-
p = await extract_topics(user_id, project_id, user_memo_str)
58+
p = await extract_topics(user_id, project_id, user_memo_str, project_profiles)
5259
if not p.ok():
5360
return p
5461
extracted_data = p.data()
@@ -59,7 +66,7 @@ async def process_blobs(
5966
fact_contents=extracted_data["fact_contents"],
6067
fact_attributes=extracted_data["fact_attributes"],
6168
profiles=extracted_data["profiles"],
62-
config=extracted_data["config"],
69+
config=project_profiles,
6370
total_profiles=extracted_data["total_profiles"],
6471
)
6572
if not p.ok():
@@ -74,7 +81,7 @@ async def process_blobs(
7481
p = await organize_profiles(
7582
project_id,
7683
profile_options,
77-
config=extracted_data["config"],
84+
config=project_profiles,
7885
)
7986
if not p.ok():
8087
LOG.error(f"Failed to organize profiles: {p.msg()}")
@@ -94,29 +101,21 @@ async def process_blobs(
94101
project_id,
95102
user_memo_str,
96103
delta_profile_data,
97-
extracted_data["config"],
104+
project_profiles,
98105
)
99106
if not p.ok():
100107
return p
101108
eid = p.data()
102-
p = await exe_user_profile_add(user_id, project_id, profile_options)
103-
if not p.ok():
104-
return p
105-
add_profile_ids = p.data().ids
106-
p = await exe_user_profile_update(user_id, project_id, profile_options)
107-
if not p.ok():
108-
return p
109-
update_profile_ids = p.data().ids
110-
p = await exe_user_profile_delete(user_id, project_id, profile_options)
109+
110+
p = await handle_user_profile_db(user_id, project_id, profile_options)
111111
if not p.ok():
112112
return p
113-
delete_profile_ids = p.data().ids
114113
return Promise.resolve(
115114
ChatModalResponse(
116115
event_id=eid,
117-
add_profiles=add_profile_ids,
118-
update_profiles=update_profile_ids,
119-
delete_profiles=delete_profile_ids,
116+
add_profiles=p.data().ids,
117+
update_profiles=[up["profile_id"] for up in profile_options["update"]],
118+
delete_profiles=profile_options["delete"],
120119
)
121120
)
122121

@@ -149,44 +148,21 @@ async def handle_session_event(
149148
return eid
150149

151150

152-
async def exe_user_profile_add(
151+
async def handle_user_profile_db(
153152
user_id: str, project_id: str, profile_options: MergeAddResult
154153
) -> Promise[IdsData]:
155-
if not len(profile_options["add"]):
156-
return Promise.resolve(IdsData(ids=[]))
157154
LOG.info(f"Adding {len(profile_options['add'])} profiles for user {user_id}")
158-
task_add = await add_user_profiles(
155+
LOG.info(f"Updating {len(profile_options['update'])} profiles for user {user_id}")
156+
LOG.info(f"Deleting {len(profile_options['delete'])} profiles for user {user_id}")
157+
158+
p = await add_update_delete_user_profiles(
159159
user_id,
160160
project_id,
161161
[ap["content"] for ap in profile_options["add"]],
162162
[ap["attributes"] for ap in profile_options["add"]],
163-
)
164-
return task_add
165-
166-
167-
async def exe_user_profile_update(
168-
user_id: str, project_id: str, profile_options: MergeAddResult
169-
) -> Promise[IdsData]:
170-
if not len(profile_options["update"]):
171-
return Promise.resolve(IdsData(ids=[]))
172-
LOG.info(f"Updating {len(profile_options['update'])} profiles for user {user_id}")
173-
task_update = await update_user_profiles(
174-
user_id,
175-
project_id,
176163
[up["profile_id"] for up in profile_options["update"]],
177164
[up["content"] for up in profile_options["update"]],
178165
[up["attributes"] for up in profile_options["update"]],
166+
profile_options["delete"],
179167
)
180-
return task_update
181-
182-
183-
async def exe_user_profile_delete(
184-
user_id: str, project_id: str, profile_options: MergeAddResult
185-
) -> Promise[IdsData]:
186-
if not len(profile_options["delete"]):
187-
return Promise.resolve(IdsData(ids=[]))
188-
LOG.info(f"Deleting {len(profile_options['delete'])} profiles for user {user_id}")
189-
task_delete = await delete_user_profiles(
190-
user_id, project_id, profile_options["delete"]
191-
)
192-
return task_delete
168+
return p

src/server/api/memobase_server/controllers/modal/chat/entry_summary.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,16 @@
44
from ....models.blob import Blob, BlobType
55
from ....llms import llm_complete
66
from ....prompts.profile_init_utils import read_out_profile_config
7-
from ...project import get_project_profile_config
7+
from ...project import ProfileConfig
88
from ....prompts.profile_init_utils import read_out_event_tags
99
from ....prompts.utils import tag_chat_blobs_in_order_xml
1010
from .types import FactResponse, PROMPTS
1111

1212

13-
async def entry_summary(
14-
user_id: str, project_id: str, blobs: list[Blob]
13+
async def entry_chat_summary(
14+
user_id: str, project_id: str, blobs: list[Blob], project_profiles: ProfileConfig
1515
) -> Promise[str]:
1616
assert all(b.type == BlobType.chat for b in blobs), "All blobs must be chat blobs"
17-
p = await get_project_profile_config(project_id)
18-
if not p.ok():
19-
return p
20-
project_profiles = p.data()
2117
USE_LANGUAGE = project_profiles.language or CONFIG.language
2218
project_profiles_slots = read_out_profile_config(
2319
project_profiles, PROMPTS[USE_LANGUAGE]["profile"].CANDIDATE_PROFILE_TOPICS

src/server/api/memobase_server/controllers/modal/chat/extract.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from ....prompts.profile_init_utils import read_out_profile_config, UserProfileTopic
1414
from ...profile import get_user_profiles
15-
from ...project import get_project_profile_config
15+
from ...project import ProfileConfig
1616

1717
# from ...project impor
1818
from .types import FactResponse, PROMPTS
@@ -30,16 +30,12 @@ def merge_by_topic_sub_topics(new_facts: list[FactResponse]):
3030

3131

3232
async def extract_topics(
33-
user_id: str, project_id: str, user_memo: str
33+
user_id: str, project_id: str, user_memo: str, project_profiles: ProfileConfig
3434
) -> Promise[dict]:
3535
p = await get_user_profiles(user_id, project_id)
3636
if not p.ok():
3737
return p
3838
profiles = p.data().profiles
39-
p = await get_project_profile_config(project_id)
40-
if not p.ok():
41-
return p
42-
project_profiles = p.data()
4339
USE_LANGUAGE = project_profiles.language or CONFIG.language
4440
STRICT_MODE = (
4541
project_profiles.profile_strict_mode
@@ -113,7 +109,6 @@ async def extract_topics(
113109
"fact_contents": [],
114110
"fact_attributes": [],
115111
"profiles": profiles,
116-
"config": project_profiles,
117112
"total_profiles": project_profiles_slots,
118113
}
119114
)
@@ -145,7 +140,6 @@ async def extract_topics(
145140
"fact_contents": fact_contents,
146141
"fact_attributes": fact_attributes,
147142
"profiles": profiles,
148-
"config": project_profiles,
149143
"total_profiles": project_profiles_slots,
150144
}
151145
)

src/server/api/memobase_server/controllers/profile.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,88 @@ async def refresh_user_profile_cache(user_id: str, project_id: str) -> Promise[N
213213
async with get_redis_client() as redis_client:
214214
await redis_client.delete(f"user_profiles::{project_id}::{user_id}")
215215
return Promise.resolve(None)
216+
217+
218+
async def add_update_delete_user_profiles(
219+
user_id: str,
220+
project_id: str,
221+
add_profiles: list[str],
222+
add_attributes: list[dict],
223+
update_profile_ids: list[str],
224+
update_contents: list[str],
225+
update_attributes: list[dict | None],
226+
delete_profile_ids: list[str],
227+
) -> Promise[IdsData]:
228+
assert len(add_profiles) == len(
229+
add_attributes
230+
), "Length of add_profiles, add_attributes must be equal"
231+
assert len(update_profile_ids) == len(
232+
update_contents
233+
), "Length of update_profile_ids, update_contents must be equal"
234+
assert len(update_profile_ids) == len(
235+
update_attributes
236+
), "Length of update_profile_ids, update_attributes must be equal"
237+
238+
for attr in add_attributes + update_attributes:
239+
if attr is None:
240+
continue
241+
try:
242+
ProfileAttributes.model_validate(attr)
243+
except ValidationError as e:
244+
return Promise.reject(
245+
CODE.SERVER_PARSE_ERROR, f"Invalid profile attributes: {e}"
246+
)
247+
# Sanity Check done
248+
249+
with Session() as session:
250+
try:
251+
# 1. add new profiles
252+
if len(add_profiles):
253+
add_db_profiles = [
254+
UserProfile(
255+
user_id=user_id,
256+
project_id=project_id,
257+
content=content,
258+
attributes=attr,
259+
)
260+
for content, attr in zip(add_profiles, add_attributes)
261+
]
262+
session.add_all(add_db_profiles)
263+
add_profile_ids = [p.id for p in add_db_profiles]
264+
else:
265+
add_profile_ids = []
266+
# 2. update existing profiles
267+
update_db_profiles = []
268+
for profile_id, content, attribute in zip(
269+
update_profile_ids, update_contents, update_attributes
270+
):
271+
db_profile = (
272+
session.query(UserProfile)
273+
.filter_by(id=profile_id, user_id=user_id, project_id=project_id)
274+
.one_or_none()
275+
)
276+
if db_profile is None:
277+
LOG.error(f"Profile {profile_id} not found for user {user_id}")
278+
continue
279+
db_profile.content = content
280+
if attribute is not None:
281+
db_profile.attributes = attribute
282+
update_db_profiles.append(profile_id)
283+
284+
# 3. delete profiles
285+
session.query(UserProfile).filter(
286+
UserProfile.id.in_(delete_profile_ids),
287+
UserProfile.user_id == user_id,
288+
UserProfile.project_id == project_id,
289+
).delete(synchronize_session=False)
290+
291+
session.commit()
292+
except Exception as e:
293+
LOG.error(f"Error merging user profiles: {e}")
294+
session.rollback()
295+
return Promise.reject(
296+
CODE.SERVER_PARSE_ERROR, f"Error merging user profiles: {e}"
297+
)
298+
299+
await refresh_user_profile_cache(user_id, project_id)
300+
return Promise.resolve(IdsData(ids=add_profile_ids))

src/server/api/memobase_server/telemetry/open_telemetry.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,14 @@ def setup_telemetry(self) -> None:
140140

141141
def _construct_attributes(self, **kwargs) -> Dict[str, str]:
142142

143-
if os.environ.get("POD_IP"):
144-
# use k8s downward API to get the pod ip
145-
pod_ip = os.environ.get("POD_IP")
146-
else:
147-
# use the hostname to get the ip address
148-
hostname = socket.gethostname()
149-
pod_ip = socket.gethostbyname(hostname)
150-
143+
# if os.environ.get("POD_IP"):
144+
# # use k8s downward API to get the pod ip
145+
# pod_ip = os.environ.get("POD_IP", None)
146+
# else:
147+
# # use the hostname to get the ip address
148+
# hostname = socket.gethostname()
149+
# pod_ip = socket.gethostbyname(hostname)
150+
pod_ip = os.environ.get("POD_IP", None)
151151
return {
152152
DEPLOYMENT_ENVIRONMENT: self._deployment_environment,
153153
"memobase_server_ip": pod_ip,

0 commit comments

Comments
 (0)