Skip to content

Commit 6f75161

Browse files
committed
Introduced IpromptStore
1 parent ab4f26b commit 6f75161

File tree

4 files changed

+225
-12
lines changed

4 files changed

+225
-12
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
using KernelMemory.Extensions.QueryPipeline;
2+
using Microsoft.Extensions.Logging;
3+
using Microsoft.SemanticKernel;
4+
using Moq;
5+
using Fasterflect;
6+
7+
namespace KernelMemory.Extensions.FunctionalTests.Helper;
8+
9+
public class OpenaiRagQueryExecutorTests
10+
{
11+
private Kernel _kernel;
12+
private OpenaiRagQueryExecutor _sut;
13+
private Mock<IPromptStore> _mockPromptStore;
14+
private Mock<ILogger<StandardRagQueryExecutor>> _mockLogger;
15+
16+
public OpenaiRagQueryExecutorTests()
17+
{
18+
_kernel = new Kernel();
19+
_mockPromptStore = new Mock<IPromptStore>();
20+
_mockLogger = new Mock<ILogger<StandardRagQueryExecutor>>();
21+
_sut = new OpenaiRagQueryExecutor(_kernel, new OpenAIRagQueryExecutorConfiguration(), _mockLogger.Object, _mockPromptStore.Object);
22+
}
23+
24+
[Fact]
25+
public async Task GetPromptAsync_ShouldReturnPromptFromMock()
26+
{
27+
// Arrange
28+
var expectedPrompt = "Test Prompt";
29+
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync(expectedPrompt);
30+
31+
// Act
32+
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
33+
var actualPrompt = await task;
34+
35+
// Assert
36+
Assert.Equal(expectedPrompt, actualPrompt);
37+
}
38+
39+
[Fact]
40+
public async Task GetPromptAsync_Should_validate_log_called()
41+
{
42+
// Arrange
43+
var invalidPrompt = "Invalid Prompt";
44+
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync(invalidPrompt);
45+
46+
// Act
47+
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
48+
var actualPrompt = await task;
49+
50+
// Assert
51+
// template is not valid, we expect two log error
52+
// void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter);
53+
_mockLogger.Verify(
54+
logger => logger.Log(
55+
LogLevel.Error,
56+
It.IsAny<EventId>(),
57+
It.Is<It.IsAnyType>((value, type) => value.ToString()!.Contains("{{$question}}")),
58+
It.IsAny<Exception>(),
59+
It.IsAny<Func<It.IsAnyType, Exception?, string>>()),
60+
Times.Once);
61+
62+
_mockLogger.Verify(
63+
logger => logger.Log(
64+
LogLevel.Error,
65+
It.IsAny<EventId>(),
66+
It.Is<It.IsAnyType>((value, type) => value.ToString()!.Contains("{{$documents}}")),
67+
It.IsAny<Exception>(),
68+
It.IsAny<Func<It.IsAnyType, Exception?, string>>()),
69+
Times.Once);
70+
71+
// verify that store method of Ipromptstore is not called
72+
_mockPromptStore.Verify(
73+
store => store.SetPromptAsync(It.IsAny<string>(), It.IsAny<string>()),
74+
Times.Never);
75+
}
76+
77+
[Fact]
78+
public async Task If_prompt_not_saved_reload()
79+
{
80+
// Arrange
81+
_mockPromptStore.Setup(store => store.GetPromptAsync(It.IsAny<string>())).ReturnsAsync((String?) null);
82+
83+
// Act
84+
var task = (Task<string>)_sut.CallMethod("GetPromptAsync");
85+
var actualPrompt = await task;
86+
87+
// Assert
88+
// verify that store method of Ipromptstore is called
89+
_mockPromptStore.Verify(
90+
store => store.SetPromptAsync("OpenaiRagQueryExecutor", It.IsAny<string>()),
91+
Times.Once);
92+
}
93+
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using System.Threading.Tasks;
2+
3+
namespace KernelMemory.Extensions;
4+
5+
/// <summary>
6+
/// <para>
7+
/// To let the user to change prompt used in the various part of extensions
8+
/// we introduce an interface that is capable of providing user prompt given a key.
9+
/// The user can implement its own implementation of this interface and pass it to the extension.
10+
/// </para>
11+
/// <para>
12+
/// We do not use propmtp provider from kernel memory because we need to have a way to set
13+
/// the prompt into the provider for a better experience of the user
14+
/// </para>
15+
/// </summary>
16+
public interface IPromptStore
17+
{
18+
/// <summary>
19+
/// Get the prompt for the given key.
20+
/// </summary>
21+
/// <param name="key">The key for which the prompt is requested.</param>
22+
/// <returns>The prompt for the given key or null if the prompt is not present. If null is returned the
23+
/// various components will use some default prompts.</returns>
24+
Task<string?> GetPromptAsync(string key);
25+
26+
/// <summary>
27+
/// Allow setting prompt value.
28+
/// </summary>
29+
/// <param name="key"></param>
30+
/// <param name="prompt"></param>
31+
Task SetPromptAsync(string key, string prompt);
32+
}
33+
34+
public class NullPromptStore : IPromptStore
35+
{
36+
public static NullPromptStore Instance { get; } = new NullPromptStore();
37+
38+
/// <summary>
39+
/// Get the prompt for the given key.
40+
/// </summary>
41+
/// <param name="key">The key for which the prompt is requested.</param>
42+
/// <returns>An empty prompt.</returns>
43+
public Task<string> GetPromptAsync(string key)
44+
{
45+
return Task.FromResult(string.Empty);
46+
}
47+
48+
/// <summary>
49+
/// Allow setting prompt value.
50+
/// </summary>
51+
/// <param name="key">The key for which the prompt is set.</param>
52+
/// <param name="prompt">The prompt value to set.</param>
53+
public Task SetPromptAsync(string key, string prompt)
54+
{
55+
// No operation as this is a null implementation.
56+
return Task.CompletedTask;
57+
}
58+
}
59+

src/KernelMemory.Extensions/QueryPipeline/OpenaiRagQueryExecutor.cs

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public class OpenAIRagQueryExecutorConfiguration
2424
/// the default <see cref="ModelId"/> automatically
2525
/// we will use a standard gpt3.5 model
2626
/// </summary>
27-
public string ModelName { get; set; } = "gpt35";
27+
public string ModelName { get; set; } = "gpt-35";
2828

2929
/// <summary>
3030
/// This is the modelId configured in Semantic Kernel
@@ -63,16 +63,26 @@ public class OpenaiRagQueryExecutor : BasicQueryHandler
6363
private readonly OpenAIRagQueryExecutorConfiguration _config;
6464
private readonly Tokenizer _tokenizer;
6565
private readonly ILogger<StandardRagQueryExecutor> _log;
66+
private readonly IPromptStore _promptStore;
67+
68+
private const string DefaultPrompt = @"You are an AI assistant that helps users answer questions given a specific context. You will be given a context and asked a question based on that context. Your answer should be as precise as possible and should only come from the context.
69+
Please add all documents used as citations.
70+
Question: {{$question}}
71+
72+
Documents:
73+
{{$documents}}";
6674

6775
public OpenaiRagQueryExecutor(
6876
Kernel kernel,
6977
OpenAIRagQueryExecutorConfiguration? config = null,
70-
ILogger<StandardRagQueryExecutor>? log = null)
78+
ILogger<StandardRagQueryExecutor>? log = null,
79+
IPromptStore? promptStore = null)
7180
{
7281
_kernel = kernel;
7382
_config = config ?? new OpenAIRagQueryExecutorConfiguration();
7483
_tokenizer = TiktokenTokenizer.CreateForModel(_config.ModelName);
7584
_log = log ?? DefaultLogger<StandardRagQueryExecutor>.Instance;
85+
_promptStore = promptStore ?? NullPromptStore.Instance;
7686
}
7787

7888
protected override async Task OnHandleAsync(
@@ -170,9 +180,9 @@ protected override async Task OnHandleAsync(
170180

171181
private class GptAnswer
172182
{
173-
public string Answer { get; set; }
183+
public string Answer { get; set; } = null!;
174184

175-
public HashSet<int> Documents { get; set; }
185+
public HashSet<int> Documents { get; set; } = null!;
176186
}
177187

178188
/// <summary>
@@ -194,22 +204,19 @@ private class GptAnswer
194204
[Description("Answer of the question")] string answer,
195205
[Description("Documents used to formulate the answer")] int[] documents
196206
) =>
197-
{
198-
}, "return_result");
207+
{
208+
}, "return_result");
199209
var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function]);
200210
var openAIFunction = plugin.GetFunctionsMetadata().First().ToOpenAIFunction();
201211

212+
string prompt = await GetPromptAsync();
213+
202214
// Create a template for chat with settings
203215
var chat = _kernel.CreateFunctionFromPrompt(new PromptTemplateConfig()
204216
{
205217
Name = "Rag",
206218
Description = "Answer user question with documents.",
207-
Template = @"You are an AI assistant that helps users answer questions given a specific context. You will be given a context and asked a question based on that context. Your answer should be as precise as possible and should only come from the context.
208-
Please add all documents used as citations.
209-
Question: {{$question}}
210-
211-
Documents:
212-
{{$documents}}",
219+
Template = prompt,
213220
TemplateFormat = "semantic-kernel",
214221
InputVariables =
215222
[
@@ -252,4 +259,27 @@ Please add all documents used as citations.
252259

253260
return null;
254261
}
262+
263+
private async Task<string> GetPromptAsync()
264+
{
265+
var prompt = await _promptStore.GetPromptAsync(nameof(OpenaiRagQueryExecutor));
266+
if (prompt == null)
267+
{
268+
//Set the default prompt into the storage so the user can change.
269+
await _promptStore.SetPromptAsync(nameof(OpenaiRagQueryExecutor), DefaultPrompt);
270+
prompt = DefaultPrompt;
271+
}
272+
273+
if (!prompt.Contains("{{$question}}"))
274+
{
275+
_log.LogError("The prompt does not contain {{$question}} placeholder, the prompt will not work correctly");
276+
}
277+
278+
if (!prompt.Contains("{{$documents}}"))
279+
{
280+
_log.LogError("The prompt does not contain {{$documents}} placeholder, the prompt will not work correctly");
281+
}
282+
283+
return prompt;
284+
}
255285
}

src/NullPromptStore.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System.Threading.Tasks;
2+
3+
namespace KernelMemory.Extensions
4+
{
5+
/// <summary>
6+
/// A null implementation of the IPromptStore interface that returns empty prompts.
7+
/// </summary>
8+
public class NullPromptStore : IPromptStore
9+
{
10+
/// <summary>
11+
/// Get the prompt for the given key.
12+
/// </summary>
13+
/// <param name="key">The key for which the prompt is requested.</param>
14+
/// <returns>An empty prompt.</returns>
15+
public Task<string> GetPromptAsync(string key)
16+
{
17+
return Task.FromResult(string.Empty);
18+
}
19+
20+
/// <summary>
21+
/// Allow setting prompt value.
22+
/// </summary>
23+
/// <param name="key">The key for which the prompt is set.</param>
24+
/// <param name="prompt">The prompt value to set.</param>
25+
public Task SetPromptAsync(string key, string prompt)
26+
{
27+
// No operation as this is a null implementation.
28+
return Task.CompletedTask;
29+
}
30+
}
31+
}

0 commit comments

Comments
 (0)