Skip to content

Commit 25be802

Browse files
committed
added disk based prompt store
also added prompt store to query rewriter
1 parent efae48b commit 25be802

File tree

9 files changed

+496
-91
lines changed

9 files changed

+496
-91
lines changed
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
using System;
2+
using System.IO;
3+
using System.Threading.Tasks;
4+
using Microsoft.Extensions.Logging;
5+
using Moq;
6+
using Xunit;
7+
8+
namespace KernelMemory.Extensions.FunctionalTests.Helper;
9+
10+
public class LocalFolderPromptStoreTests : IDisposable
11+
{
12+
private readonly string _testDirectory;
13+
private readonly Mock<ILogger<LocalFolderPromptStore>> _loggerMock;
14+
private readonly LocalFolderPromptStore _store;
15+
16+
public LocalFolderPromptStoreTests()
17+
{
18+
_testDirectory = Path.Combine(Path.GetTempPath(), $"promptstore_tests_{Guid.NewGuid()}");
19+
_loggerMock = new Mock<ILogger<LocalFolderPromptStore>>();
20+
_store = new LocalFolderPromptStore(_testDirectory, _loggerMock.Object);
21+
}
22+
23+
public void Dispose()
24+
{
25+
if (Directory.Exists(_testDirectory))
26+
{
27+
Directory.Delete(_testDirectory, true);
28+
}
29+
}
30+
31+
[Fact]
32+
public async Task GetPromptAsync_NonExistentKey_ReturnsNull()
33+
{
34+
// Act
35+
var result = await _store.GetPromptAsync("nonexistent");
36+
37+
// Assert
38+
Assert.Null(result);
39+
}
40+
41+
[Fact]
42+
public async Task SetAndGetPromptAsync_ValidKey_ReturnsStoredPrompt()
43+
{
44+
// Arrange
45+
const string key = "test-key";
46+
const string expectedPrompt = "This is a test prompt";
47+
48+
// Act
49+
await _store.SetPromptAsync(key, expectedPrompt);
50+
var result = await _store.GetPromptAsync(key);
51+
52+
// Assert
53+
Assert.Equal(expectedPrompt, result);
54+
}
55+
56+
[Fact]
57+
public async Task GetPromptAndSetDefaultAsync_NonExistentKey_SetsAndReturnsDefault()
58+
{
59+
// Arrange
60+
const string key = "default-key";
61+
const string defaultPrompt = "Default prompt value";
62+
63+
// Act
64+
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);
65+
var storedPrompt = await _store.GetPromptAsync(key);
66+
67+
// Assert
68+
Assert.Equal(defaultPrompt, result);
69+
Assert.Equal(defaultPrompt, storedPrompt);
70+
}
71+
72+
[Fact]
73+
public async Task GetPromptAndSetDefaultAsync_ExistingKey_ReturnsExistingPrompt()
74+
{
75+
// Arrange
76+
const string key = "existing-key";
77+
const string existingPrompt = "Existing prompt";
78+
const string defaultPrompt = "Default prompt";
79+
await _store.SetPromptAsync(key, existingPrompt);
80+
81+
// Act
82+
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);
83+
84+
// Assert
85+
Assert.Equal(existingPrompt, result);
86+
}
87+
88+
[Fact]
89+
public async Task SetPromptAsync_KeyWithSpecialCharacters_HandlesCorrectly()
90+
{
91+
// Arrange
92+
const string key = "special/\\*:?\"<>|characters";
93+
const string expectedPrompt = "Prompt with special characters";
94+
95+
// Act
96+
await _store.SetPromptAsync(key, expectedPrompt);
97+
var result = await _store.GetPromptAsync(key);
98+
99+
// Assert
100+
Assert.Equal(expectedPrompt, result);
101+
}
102+
103+
[Fact]
104+
public async Task GetPromptAndSetDefaultAsync_MissingPlaceholder_LogsError()
105+
{
106+
// Arrange
107+
const string key = "test-placeholder";
108+
const string existingPrompt = "A prompt without placeholder";
109+
const string defaultPrompt = "Default prompt with {{$placeholder}}";
110+
await _store.SetPromptAsync(key, existingPrompt);
111+
112+
// Act
113+
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);
114+
115+
// Assert
116+
_loggerMock.Verify(
117+
x => x.Log(
118+
LogLevel.Error,
119+
It.IsAny<EventId>(),
120+
It.Is<It.IsAnyType>((v, t) => v.ToString().Contains("{{$placeholder}}")),
121+
It.IsAny<Exception>(),
122+
It.IsAny<Func<It.IsAnyType, Exception, string>>()
123+
),
124+
Times.Once);
125+
}
126+
127+
[Fact]
128+
public async Task GetPromptAndSetDefaultAsync_MultipleMissingPlaceholders_LogsMultipleErrors()
129+
{
130+
// Arrange
131+
const string key = "test-multiple-placeholders";
132+
const string existingPrompt = "A prompt without any placeholders";
133+
const string defaultPrompt = "Default with {{$first}} and {{$second}}";
134+
await _store.SetPromptAsync(key, existingPrompt);
135+
136+
// Act
137+
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);
138+
139+
// Assert
140+
_loggerMock.Verify(
141+
x => x.Log(
142+
LogLevel.Error,
143+
It.IsAny<EventId>(),
144+
It.Is<It.IsAnyType>((v, t) => true),
145+
It.IsAny<Exception>(),
146+
It.IsAny<Func<It.IsAnyType, Exception, string>>()
147+
),
148+
Times.Exactly(2));
149+
}
150+
151+
[Fact]
152+
public async Task GetPromptAndSetDefaultAsync_ValidPlaceholders_NoErrors()
153+
{
154+
// Arrange
155+
const string key = "test-valid-placeholders";
156+
const string existingPrompt = "A prompt with {{$placeholder}} correctly set";
157+
const string defaultPrompt = "Default with {{$placeholder}}";
158+
await _store.SetPromptAsync(key, existingPrompt);
159+
160+
// Act
161+
var result = await _store.GetPromptAndSetDefaultAsync(key, defaultPrompt);
162+
163+
// Assert
164+
_loggerMock.Verify(
165+
x => x.Log(
166+
LogLevel.Error,
167+
It.IsAny<EventId>(),
168+
It.Is<It.IsAnyType>((v, t) => true),
169+
It.IsAny<Exception>(),
170+
It.IsAny<Func<It.IsAnyType, Exception, string>>()
171+
),
172+
Times.Never);
173+
}
174+
}

src/KernelMemory.Extensions.FunctionalTests/Helper/OpenaiRagQueryExecutorTests.cs

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@ namespace KernelMemory.Extensions.FunctionalTests.Helper;
99

1010
public class OpenaiRagQueryExecutorTests
1111
{
12-
private Kernel _kernel;
13-
private OpenaiRagQueryExecutor _sut;
14-
private Mock<IPromptStore> _mockPromptStore;
15-
private Mock<ILogger<StandardRagQueryExecutor>> _mockLogger;
16-
private Mock<ISemanticKernelWrapper> _mockKernel;
12+
private readonly OpenaiRagQueryExecutor _sut;
13+
private readonly Mock<IPromptStore> _mockPromptStore;
14+
private readonly Mock<ILogger<StandardRagQueryExecutor>> _mockLogger;
15+
private readonly Mock<ISemanticKernelWrapper> _mockKernel;
1716

1817
public OpenaiRagQueryExecutorTests()
1918
{
20-
_kernel = new Kernel();
19+
var kernel = new Kernel();
2120
_mockPromptStore = new Mock<IPromptStore>();
2221
_mockLogger = new Mock<ILogger<StandardRagQueryExecutor>>();
2322
_mockKernel = new Mock<ISemanticKernelWrapper>();
@@ -29,7 +28,7 @@ public async Task GetPromptAsync_ShouldReturnPromptFromMock()
2928
{
3029
// Arrange
3130
var expectedPrompt = "Test Prompt";
32-
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync(expectedPrompt);
31+
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>(), It.IsAny<CancellationToken>())).ReturnsAsync(expectedPrompt);
3332

3433
// Act
3534
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
@@ -44,7 +43,7 @@ public async Task GetPromptAsync_Should_validate_log_called()
4443
{
4544
// Arrange
4645
var invalidPrompt = "Invalid Prompt";
47-
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync(invalidPrompt);
46+
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>(), It.IsAny<CancellationToken>())).ReturnsAsync(invalidPrompt);
4847

4948
// Act
5049
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
@@ -73,15 +72,15 @@ public async Task GetPromptAsync_Should_validate_log_called()
7372

7473
// verify that store method of Ipromptstore is not called
7574
_mockPromptStore.Verify(
76-
store => store.SetPromptAsync(It.IsAny<string>(), It.IsAny<string>()),
75+
store => store.SetPromptAsync(It.IsAny<string>(), It.IsAny<string>(), It.IsAny<CancellationToken>()),
7776
Times.Never);
7877
}
7978

8079
[Fact]
8180
public async Task If_prompt_not_saved_reload()
8281
{
8382
// Arrange
84-
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync((String?) null);
83+
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>(), It.IsAny<CancellationToken>())).ReturnsAsync((String?) null);
8584

8685
// Act
8786
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
@@ -90,7 +89,7 @@ public async Task If_prompt_not_saved_reload()
9089
// Assert
9190
// verify that store method of Ipromptstore is called
9291
_mockPromptStore.Verify(
93-
store => store.SetPromptAsync("OpenaiRagQueryExecutor", It.IsAny<string>()),
92+
store => store.SetPromptAsync("OpenaiRagQueryExecutor", It.IsAny<string>(), It.IsAny<CancellationToken>()),
9493
Times.Once);
9594
}
9695
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
using KernelMemory.Extensions.Helper;
2+
using KernelMemory.Extensions.QueryPipeline;
3+
using Microsoft.SemanticKernel;
4+
using Microsoft.SemanticKernel.ChatCompletion;
5+
using Moq;
6+
7+
namespace KernelMemory.Extensions.FunctionalTests.QueryPipeline;
8+
9+
public class SemanticKernelQueryRewriterTests
10+
{
11+
private readonly Mock<IPromptStore> _promptStoreMock;
12+
private readonly Mock<ISemanticKernelWrapper> _kernelWrapperMock;
13+
private readonly Mock<IChatCompletionService> _chatCompletionServiceMock;
14+
private readonly SemanticKernelQueryRewriterOptions _options;
15+
16+
public SemanticKernelQueryRewriterTests()
17+
{
18+
_promptStoreMock = new Mock<IPromptStore>();
19+
_kernelWrapperMock = new Mock<ISemanticKernelWrapper>();
20+
_chatCompletionServiceMock = new Mock<IChatCompletionService>();
21+
_options = new SemanticKernelQueryRewriterOptions { ModelId = "gpt-4" };
22+
23+
_kernelWrapperMock.Setup(x => x.GetChatCompletionService())
24+
.Returns(_chatCompletionServiceMock.Object);
25+
}
26+
27+
[Fact]
28+
public async Task RewriteAsync_ShouldUsePromptFromStore()
29+
{
30+
// Arrange
31+
var conversation = new Conversation();
32+
var question = "What is the weather?";
33+
var customPrompt = "Custom prompt template {{question}}";
34+
35+
_promptStoreMock.Setup(x => x.GetPromptAndSetDefaultAsync(
36+
"SemanticKernelQueryRewriter",
37+
It.IsAny<string>(),
38+
It.IsAny<CancellationToken>()))
39+
.ReturnsAsync(customPrompt);
40+
41+
_chatCompletionServiceMock.Setup(x => x.GetChatMessageContentsAsync(
42+
It.IsAny<ChatHistory>(),
43+
It.IsAny<PromptExecutionSettings?>(),
44+
It.IsAny<Kernel?>(),
45+
It.IsAny<CancellationToken>()))
46+
.ReturnsAsync([new ChatMessageContent(AuthorRole.Assistant, "Rewritten question")]);
47+
48+
var rewriter = new SemanticKernelQueryRewriter(_options, _promptStoreMock.Object, _kernelWrapperMock.Object);
49+
50+
// Act
51+
var result = await rewriter.RewriteAsync(conversation, question);
52+
53+
// Assert
54+
_promptStoreMock.Verify(x => x.GetPromptAndSetDefaultAsync(
55+
"SemanticKernelQueryRewriter",
56+
It.IsAny<string>(),
57+
It.IsAny<CancellationToken>()),
58+
Times.Once);
59+
60+
Assert.Equal("Rewritten question", result);
61+
}
62+
63+
/// <summary>
64+
/// IF we cannot rewrite, we cannot answer
65+
/// </summary>
66+
/// <returns></returns>
67+
[Fact]
68+
public async Task RewriteAsync_WhenChatCompletionFails_ShouldThrow()
69+
{
70+
// Arrange
71+
var conversation = new Conversation();
72+
var question = "What is the weather?";
73+
74+
_promptStoreMock.Setup(x => x.GetPromptAndSetDefaultAsync(
75+
It.IsAny<string>(),
76+
It.IsAny<string>(),
77+
It.IsAny<CancellationToken>()))
78+
.ReturnsAsync("prompt");
79+
80+
_chatCompletionServiceMock.Setup(x => x.GetChatMessageContentsAsync(
81+
It.IsAny<ChatHistory>(),
82+
It.IsAny<PromptExecutionSettings?>(),
83+
It.IsAny<Kernel?>(),
84+
It.IsAny<CancellationToken>()))
85+
.ReturnsAsync(new List<ChatMessageContent>());
86+
87+
var rewriter = new SemanticKernelQueryRewriter(_options, _promptStoreMock.Object, _kernelWrapperMock.Object);
88+
89+
// Act
90+
await Assert.ThrowsAsync<InvalidOperationException>(() => rewriter.RewriteAsync(conversation, question));
91+
}
92+
}

src/KernelMemory.Extensions.FunctionalTests/QueryPipeline/UserQuestionPipelineFactoryTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ public Task<IReadOnlyCollection<MemoryRecord>> ReRankAsync(string question, IRea
178178

179179
private class TestQueryRewriter : IConversationQueryRewriter
180180
{
181-
public Task<string> RewriteAsync(Conversation conversation, string question)
181+
public Task<string> RewriteAsync(Conversation conversation, string question, CancellationToken cancellationToken = default)
182182
{
183183
return Task.FromResult(question);
184184
}

src/KernelMemory.Extensions.FunctionalTests/QueryPipeline/UserQuestionPipelineTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ public async Task Basic_conversation_handling()
175175

176176
//Generate mock for conversation rewriter
177177
var conversationRewriterMock = new Mock<IConversationQueryRewriter>();
178-
conversationRewriterMock.Setup(x => x.RewriteAsync(It.IsAny<Conversation>(), It.IsAny<string>()))
178+
conversationRewriterMock.Setup(x => x.RewriteAsync(It.IsAny<Conversation>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
179179
.Returns(Task.FromResult("New rewritten question"));
180180
sut.SetConversationQueryRewriter(conversationRewriterMock.Object);
181181

@@ -199,7 +199,7 @@ public async Task Basic_conversation_handling_async()
199199

200200
//Generate mock for conversation rewriter
201201
var conversationRewriterMock = new Mock<IConversationQueryRewriter>();
202-
conversationRewriterMock.Setup(x => x.RewriteAsync(It.IsAny<Conversation>(), It.IsAny<string>()))
202+
conversationRewriterMock.Setup(x => x.RewriteAsync(It.IsAny<Conversation>(), It.IsAny<string>(), It.IsAny<CancellationToken>()))
203203
.Returns(Task.FromResult("New rewritten question"));
204204
sut.SetConversationQueryRewriter(conversationRewriterMock.Object);
205205

0 commit comments

Comments
 (0)