Skip to content

Commit 554cd00

Browse files
committed
Added ability base to intercept context call.
1 parent 6f75161 commit 554cd00

File tree

13 files changed

+392
-32
lines changed

13 files changed

+392
-32
lines changed

src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using Microsoft.Extensions.Logging;
1010
using Microsoft.KernelMemory;
1111
using Microsoft.KernelMemory.AI;
12+
using Microsoft.KernelMemory.Context;
1213
using Microsoft.KernelMemory.DocumentStorage.DevTools;
1314
using Microsoft.KernelMemory.FileSystem.DevTools;
1415
using Microsoft.KernelMemory.MemoryStorage;
@@ -84,7 +85,7 @@ public async Task RunSample2()
8485
var queryExecutorToUse = AnsiConsole.Prompt(new SelectionPrompt<string>()
8586
.Title("Select the query executor to use")
8687
.AddChoices([
87-
"KernelMemory Default",
88+
"KernelMemory Default",
8889
"Cohere CommandR+",
8990
"OpenAI Tool"]));
9091

@@ -105,6 +106,7 @@ public async Task RunSample2()
105106
var kernel = kernelBuider.Build();
106107

107108
//Add semantic kernel in DI
109+
services.AddSingleton<ISemanticKernelWrapper, SemanticKernelWrapper>();
108110
services.AddSingleton(kernel);
109111

110112
var serviceProvider = services.BuildServiceProvider();
@@ -125,6 +127,7 @@ public async Task RunSample2()
125127
// now ask a question to the user continuously until the user ask an empty question
126128
string? question;
127129
UserQuestion userQuestion = null;
130+
var contextAccessor = serviceProvider.GetRequiredService<IContextProvider>();
128131
do
129132
{
130133
bool shouldDumpRewrittenQuery = false;
@@ -196,6 +199,35 @@ public async Task RunSample2()
196199
Console.WriteLine("Document: {0}", citation.DocumentId);
197200
}
198201
}
202+
203+
//ask if we want details
204+
var details = AnsiConsole.Confirm("Do you want to see the details of the question? (y/n)", false);
205+
if (details)
206+
{
207+
if (userQuestion.CallContext != null)
208+
{
209+
Console.WriteLine("Number of LLM calls: {0}", userQuestion.CallContext.CallLogs.Count);
210+
foreach (var call in userQuestion.CallContext.CallLogs)
211+
{
212+
Console.WriteLine("\nCall Name: {0}", call.CallName);
213+
Console.WriteLine("Prompt: {0}\n\n", call.InputPrompt);
214+
Console.WriteLine("Output: {0}", call.Output);
215+
if (call.TokenCount != null)
216+
{
217+
Console.WriteLine(
218+
"\n******Token Count: Input: {0} Output: {1} CachedRead: {2} CachedWrite {3}*****",
219+
call.TokenCount.InputTokens,
220+
call.TokenCount.OutputTokens,
221+
call.TokenCount.CachedTokenRead,
222+
call.TokenCount.CachedTokenWrite);
223+
}
224+
}
225+
}
226+
else
227+
{
228+
Console.WriteLine("No details available");
229+
}
230+
}
199231
}
200232
} while (!string.IsNullOrWhiteSpace(question));
201233
}

src/KernelMemory.Extensions.FunctionalTests/Helper/OpenaiRagQueryExecutorTests.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Microsoft.SemanticKernel;
44
using Moq;
55
using Fasterflect;
6+
using KernelMemory.Extensions.Helper;
67

78
namespace KernelMemory.Extensions.FunctionalTests.Helper;
89

@@ -12,13 +13,15 @@ public class OpenaiRagQueryExecutorTests
1213
private OpenaiRagQueryExecutor _sut;
1314
private Mock<IPromptStore> _mockPromptStore;
1415
private Mock<ILogger<StandardRagQueryExecutor>> _mockLogger;
16+
private Mock<ISemanticKernelWrapper> _mockKernel;
1517

1618
public OpenaiRagQueryExecutorTests()
1719
{
1820
_kernel = new Kernel();
1921
_mockPromptStore = new Mock<IPromptStore>();
2022
_mockLogger = new Mock<ILogger<StandardRagQueryExecutor>>();
21-
_sut = new OpenaiRagQueryExecutor(_kernel, new OpenAIRagQueryExecutorConfiguration(), _mockLogger.Object, _mockPromptStore.Object);
23+
_mockKernel = new Mock<ISemanticKernelWrapper>();
24+
_sut = new OpenaiRagQueryExecutor(_mockKernel.Object, new OpenAIRagQueryExecutorConfiguration(), _mockLogger.Object, _mockPromptStore.Object);
2225
}
2326

2427
[Fact]

src/KernelMemory.Extensions.FunctionalTests/QueryPipeline/AsyncUserQuestionPipelineTests.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using Microsoft.KernelMemory.Context;
2+
using Moq;
13
using System.Runtime.CompilerServices;
24

35
namespace KernelMemory.Extensions.FunctionalTests.QueryPipeline;
@@ -36,7 +38,8 @@ private UserQueryOptions GenerateOptions()
3638

3739
private static UserQuestionPipeline GenerateSut()
3840
{
39-
return new UserQuestionPipeline();
41+
var mock = new Mock<IContextProvider>();
42+
return new UserQuestionPipeline(mock.Object);
4043
}
4144

4245
private class SimpleTextGeneratorAsync : BasicAsyncQueryHandlerWithProgress

src/KernelMemory.Extensions.FunctionalTests/QueryPipeline/UserQuestionPipelineFactoryTests.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Fasterflect;
22
using KernelMemory.Extensions.QueryPipeline;
33
using Microsoft.Extensions.DependencyInjection;
4+
using Microsoft.KernelMemory.Context;
45
using Microsoft.KernelMemory.MemoryStorage;
56
using Moq;
67

@@ -33,6 +34,11 @@ public void Can_configure_and_resolve_pipeline()
3334

3435
var mdb = new Mock<IMemoryDb>();
3536
serviceCollection.AddSingleton(mdb.Object);
37+
serviceCollection.AddSingleton(s =>
38+
{
39+
var mock = new Mock<IContextProvider>();
40+
return mock.Object;
41+
});
3642

3743
serviceCollection.AddKernelMemoryUserQuestionPipeline(config =>
3844
{
@@ -61,6 +67,11 @@ public void Can_configure_and_resolve_pipeline_with_re_ranker()
6167
ServiceCollection serviceCollection = new ServiceCollection();
6268
serviceCollection.AddSingleton<StandardVectorSearchQueryHandler>();
6369
serviceCollection.AddSingleton<TestReRanker>();
70+
serviceCollection.AddSingleton(s =>
71+
{
72+
var mock = new Mock<IContextProvider>();
73+
return mock.Object;
74+
});
6475

6576
var mdb = new Mock<IMemoryDb>();
6677
serviceCollection.AddSingleton(mdb.Object);
@@ -90,6 +101,11 @@ public void Can_configure_and_resolve_pipeline_with_generics()
90101
ServiceCollection serviceCollection = new ServiceCollection();
91102
serviceCollection.AddSingleton<StandardVectorSearchQueryHandler>();
92103
serviceCollection.AddSingleton<TestReRanker>();
104+
serviceCollection.AddSingleton(s =>
105+
{
106+
var mock = new Mock<IContextProvider>();
107+
return mock.Object;
108+
});
93109
serviceCollection.AddSingleton<TestQueryRewriter>();
94110

95111
var mdb = new Mock<IMemoryDb>();
@@ -128,6 +144,11 @@ public void Supports_keyed_service()
128144
serviceCollection.AddKeyedSingleton("1", mdb1.Object);
129145
var mdb2 = new Mock<IMemoryDb>();
130146
serviceCollection.AddKeyedSingleton("2", mdb2.Object);
147+
serviceCollection.AddSingleton(s =>
148+
{
149+
var mock = new Mock<IContextProvider>();
150+
return mock.Object;
151+
});
131152

132153
//I register handler with key "2" and it depends on IMemoryDb with key "2"
133154
serviceCollection.AddKeyedSingleton("2", (serviceProvider, _) =>

src/KernelMemory.Extensions.FunctionalTests/QueryPipeline/UserQuestionPipelineTests.cs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using KernelMemory.Extensions.QueryPipeline;
33
using Microsoft.Extensions.Logging.Abstractions;
44
using Microsoft.KernelMemory;
5+
using Microsoft.KernelMemory.Context;
56
using Microsoft.KernelMemory.MemoryStorage;
67
using Moq;
78

@@ -10,6 +11,7 @@ namespace KernelMemory.Extensions.FunctionalTests.QueryPipeline;
1011
public class UserQuestionPipelineTests
1112
{
1213
private const string AnswerHandlerValue = "AnswerHandler";
14+
private static Mock<IContextProvider> _mock;
1315

1416
[Fact]
1517
public async Task Null_query_has_no_answer()
@@ -348,16 +350,12 @@ private static Mock<IQueryHandler> GenerateCitationsMock(string sourceName, IEnu
348350

349351
private static UserQuestionPipeline GenerateSut()
350352
{
351-
return new UserQuestionPipeline();
353+
_mock = new Mock<IContextProvider>();
354+
return new UserQuestionPipeline(_mock.Object);
352355
}
353356

354357
private class BaseAnswerSimulator : BasicQueryHandler
355358
{
356-
public BaseAnswerSimulator()
357-
{
358-
359-
}
360-
361359
public override string Name => nameof(BaseAnswerSimulator);
362360

363361
protected override async Task OnHandleAsync(UserQuestion userQuestion, CancellationToken cancellationToken)
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
using Microsoft.KernelMemory.Context;
2+
using Microsoft.SemanticKernel;
3+
using Microsoft.SemanticKernel.Connectors.OpenAI;
4+
using OpenAI.Chat;
5+
using System.Collections.Generic;
6+
using System.Linq;
7+
8+
namespace KernelMemory.Extensions.Helper;
9+
10+
/// <summary>
11+
/// A generic object that represent a call log to a Large Language Model.
12+
/// </summary>
13+
public class LLMCallLog
14+
{
15+
public string CallName { get; set; } = null!;
16+
17+
public string InputPrompt { get; set; } = null!;
18+
19+
public string? Output { get; set; }
20+
21+
public object? ReturnObject { get; set; }
22+
23+
public TokenCount TokenCount { get; set; } = null!;
24+
25+
public void AddOpenaiChatMessageContent(OpenAIChatMessageContent mc)
26+
{
27+
// check if the answer is a tool call or a standard Answer
28+
var isToolAnswer = mc.ToolCalls?.Any() == true;
29+
30+
if (!isToolAnswer)
31+
{
32+
TextContent content = mc.Items.First() as TextContent;
33+
this.Output = content.Text;
34+
}
35+
else
36+
{
37+
var toolCall = mc.ToolCalls!.Single();
38+
39+
this.Output = $"Function Call: Function {toolCall.FunctionName} with arguments {toolCall.FunctionArguments}";
40+
}
41+
42+
// now token usage
43+
if (mc.Metadata?.TryGetValue("Usage", out var usage) == true && usage is ChatTokenUsage ctusage)
44+
{
45+
this.TokenCount = new TokenCount()
46+
{
47+
InputTokens = ctusage.InputTokenCount,
48+
OutputTokens = ctusage.OutputTokenCount,
49+
CachedTokenRead = ctusage.InputTokenDetails?.CachedTokenCount ?? 0,
50+
};
51+
}
52+
}
53+
}
54+
55+
public class TokenCount
56+
{
57+
public int InputTokens { get; set; }
58+
public int OutputTokens { get; set; }
59+
public int CachedTokenRead { get; set; }
60+
61+
public int CachedTokenWrite { get; set; }
62+
}
63+
64+
/// <summary>
65+
/// A collection of call log to a large language model.
66+
/// </summary>
67+
public class LLMCallLogContext
68+
{
69+
public IReadOnlyList<LLMCallLog> CallLogs => _callLogs;
70+
71+
private readonly List<LLMCallLog> _callLogs = new();
72+
73+
public void AddCallLog(LLMCallLog callLog)
74+
{
75+
_callLogs.Add(callLog);
76+
}
77+
}
78+
79+
public static class LLMCallLogExtensions
80+
{
81+
public static LLMCallLogContext InitializeCallLogContext(this IContextProvider contextProvider)
82+
{
83+
LLMCallLogContext lLMCallLogContext = new();
84+
var context = contextProvider.GetContext();
85+
if (context != null)
86+
{
87+
context.Arguments[nameof(LLMCallLogContext)] = lLMCallLogContext;
88+
}
89+
return lLMCallLogContext;
90+
}
91+
92+
public static LLMCallLogContext? GetCallLogContext(this IContextProvider contextProvider)
93+
{
94+
var context = contextProvider.GetContext();
95+
if (context != null)
96+
{
97+
if (context.Arguments.TryGetValue(nameof(LLMCallLogContext), out var llmCallLogContext))
98+
{
99+
return llmCallLogContext as LLMCallLogContext;
100+
}
101+
}
102+
return null;
103+
}
104+
105+
public static void AddCallLog(this IContext context, LLMCallLog callLog)
106+
{
107+
if (!context.Arguments.TryGetValue(nameof(LLMCallLogContext), out var llmCallLogContext))
108+
{
109+
llmCallLogContext = new LLMCallLogContext();
110+
context.Arguments[nameof(LLMCallLogContext)] = llmCallLogContext;
111+
}
112+
113+
((LLMCallLogContext)llmCallLogContext).AddCallLog(callLog);
114+
}
115+
}

0 commit comments

Comments
 (0)