Skip to content

Commit efae48b

Browse files
committed
Added context to cohere.
1 parent 554cd00 commit efae48b

File tree

8 files changed

+150
-13
lines changed

8 files changed

+150
-13
lines changed

src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,15 @@ public async Task RunSample2()
221221
call.TokenCount.CachedTokenRead,
222222
call.TokenCount.CachedTokenWrite);
223223
}
224+
225+
if (call.Warnings.Count > 0)
226+
{
227+
Console.WriteLine("Warnings:");
228+
foreach (var warning in call.Warnings)
229+
{
230+
Console.WriteLine(warning);
231+
}
232+
}
224233
}
225234
}
226235
else
@@ -351,6 +360,7 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
351360

352361
services.AddSingleton<IKernelMemoryBuilder>(kernelMemoryBuilder);
353362
services.AddSingleton<CohereReRanker>();
363+
354364
services.AddSingleton<HandlebarSemanticKernelQueryRewriter>();
355365
services.AddSingleton<SemanticKernelQueryRewriter>();
356366
services.AddSingleton<StandardVectorSearchQueryHandler>();

src/KernelMemory.Extensions.FunctionalTests/Cohere/CohereTests.cs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
using KernelMemory.Extensions.Cohere;
33
using KernelMemory.Extensions.FunctionalTests.TestUtilities;
44
using Microsoft.Extensions.DependencyInjection;
5+
using Microsoft.KernelMemory.Context;
56
using Microsoft.KernelMemory.MemoryStorage;
7+
using Moq;
68

79
namespace KernelMemory.Extensions.FunctionalTests.Cohere;
810

@@ -11,6 +13,8 @@ public class CohereTests
1113
private readonly ServiceProvider _serviceProvider;
1214

1315
private readonly IHttpClientFactory _httpClientFactory;
16+
private Mock<IContextProvider> _contextProvider;
17+
private Mock<IContext> _context;
1418

1519
public CohereTests()
1620
{
@@ -31,6 +35,13 @@ public CohereTests()
3135
// Configure standard resilience options here
3236
});
3337

38+
//Setup contextual provider.
39+
_contextProvider = new Mock<IContextProvider>();
40+
_context = new Moq.Mock<IContext>();
41+
_context.Setup(c => c.Arguments).Returns(new Dictionary<string, object?>());
42+
_contextProvider.Setup(c => c.GetContext()).Returns(_context.Object);
43+
services.AddSingleton(_contextProvider.Object);
44+
3445
var cohereApiKey = Environment.GetEnvironmentVariable("COHERE_API_KEY");
3546

3647
if (string.IsNullOrEmpty(cohereApiKey))

src/KernelMemory.Extensions/Cohere/CohereCommandRQueryExecutor.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Microsoft.Extensions.Logging;
33
using Microsoft.KernelMemory.Diagnostics;
44
using Microsoft.KernelMemory.MemoryStorage;
5+
using Polly.Fallback;
56
using System.Collections.Generic;
67
using System.Linq;
78
using System.Runtime.CompilerServices;
@@ -47,15 +48,18 @@ public class CohereCommandRQueryExecutor : BasicAsyncQueryHandlerWithProgress
4748

4849
private readonly RawCohereClient _rawCohereClient;
4950
private readonly CohereCommandRQueryExecutorConfiguration _config;
51+
private readonly CohereTokenizer _cohereTokenizer;
5052
private readonly ILogger<StandardRagQueryExecutor> _log;
5153

5254
public CohereCommandRQueryExecutor(
5355
RawCohereClient rawCohereClient,
5456
CohereCommandRQueryExecutorConfiguration config,
57+
CohereTokenizer cohereTokenizer,
5558
ILogger<StandardRagQueryExecutor>? log = null)
5659
{
5760
_rawCohereClient = rawCohereClient;
5861
_config = config;
62+
_cohereTokenizer = cohereTokenizer;
5963
_log = log ?? DefaultLogger<StandardRagQueryExecutor>.Instance;
6064
}
6165

src/KernelMemory.Extensions/Cohere/CohereConfiguration.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ public static IServiceCollection ConfigureCohereChat(
111111
BaseUrl = baseUrl,
112112
});
113113

114+
services.AddSingleton<CohereTokenizer>();
115+
114116
return services;
115117
}
116118

src/KernelMemory.Extensions/Cohere/RawCohereChatClient.cs

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,32 @@
1010
using System.Threading.Tasks;
1111
using KernelMemory.Extensions.Helper;
1212
using Microsoft.Extensions.Logging;
13+
using Microsoft.KernelMemory.Context;
1314
using Microsoft.KernelMemory.Diagnostics;
1415

1516
namespace KernelMemory.Extensions.Cohere;
1617

1718
public class RawCohereChatClient
1819
{
19-
private readonly HttpClient _httpClient;
20+
private readonly HttpClient _httpClient;
21+
private readonly IContextProvider _contextProvider;
2022
private readonly ILogger<RawCohereChatClient> _log;
2123
private readonly string _apiKey;
2224
private readonly string _baseUrl;
2325

2426
public RawCohereChatClient(
2527
CohereChatConfiguration config,
2628
HttpClient httpClient,
29+
IContextProvider contextProvider,
2730
ILogger<RawCohereChatClient>? log = null)
2831
{
2932
if (String.IsNullOrEmpty(config.ApiKey))
3033
{
3134
throw new ArgumentException("ApiKey is required", nameof(config.ApiKey));
3235
}
3336

34-
this._httpClient = httpClient;
37+
_httpClient = httpClient;
38+
_contextProvider = contextProvider;
3539
_log = log ?? DefaultLogger<RawCohereChatClient>.Instance;
3640
_apiKey = config.ApiKey;
3741
_baseUrl = config.BaseUrl;
@@ -45,10 +49,7 @@ public async Task<CohereRagResponse> RagQueryAsync(
4549
CohereRagRequest cohereRagRequest,
4650
CancellationToken cancellationToken = default)
4751
{
48-
if (cohereRagRequest is null)
49-
{
50-
throw new ArgumentNullException(nameof(cohereRagRequest));
51-
}
52+
ArgumentNullException.ThrowIfNull(cohereRagRequest);
5253

5354
if (cohereRagRequest.Stream)
5455
{
@@ -91,10 +92,9 @@ public async IAsyncEnumerable<CohereRagStreamingResponse> RagQueryStreamingAsync
9192
CohereRagRequest cohereRagRequest,
9293
[EnumeratorCancellation] CancellationToken cancellationToken = default)
9394
{
94-
if (cohereRagRequest is null)
95-
{
96-
throw new ArgumentNullException(nameof(cohereRagRequest));
97-
}
95+
ArgumentNullException.ThrowIfNull(cohereRagRequest);
96+
97+
var context = _contextProvider.GetContext();
9898

9999
var client = _httpClient;
100100
//force streaming
@@ -130,7 +130,7 @@ public async IAsyncEnumerable<CohereRagStreamingResponse> RagQueryStreamingAsync
130130
string line = (await reader.ReadLineAsync(cancellationToken))!;
131131
var data = JsonSerializer.Deserialize<ChatStreamEvent>(line)!;
132132

133-
if (data.EventType == "stream-start" || data.EventType == "stream-end" || data.EventType == "search-results")
133+
if (data.EventType == "stream-start" || data.EventType == "search-results")
134134
{
135135
//not interested in this events
136136
continue;
@@ -152,12 +152,47 @@ public async IAsyncEnumerable<CohereRagStreamingResponse> RagQueryStreamingAsync
152152
ResponseType = CohereRagResponseType.Citations
153153
};
154154
}
155+
else if (data.EventType == "stream-end")
156+
{
157+
//create log
158+
AddLog(context, "CommandR+RAG", cohereRagRequest.Describe(), data);
159+
}
155160
else
156161
{
157162
//not supported.
158163
_log.LogWarning("Cohere stream api receved unknown event data type {0}", data.EventType);
159164
}
160165
}
161166
}
162-
}
167+
}
168+
169+
private void AddLog(
170+
IContext context,
171+
string name,
172+
string input,
173+
ChatStreamEvent data)
174+
{
175+
LLMCallLog callLog = new()
176+
{
177+
CallName = name,
178+
ReturnObject = data,
179+
InputPrompt = input,
180+
Output = data.Response.Text,
181+
TokenCount = new TokenCount()
182+
{
183+
InputTokens = data.Response?.Meta.Tokens.InputTokens ?? 0,
184+
OutputTokens = data.Response?.Meta.Tokens.OutputTokens ?? 0,
185+
}
186+
};
187+
188+
if (data.Response?.Meta.Warnings?.Length > 0)
189+
{
190+
foreach (var warning in data.Response.Meta.Warnings)
191+
{
192+
callLog.AddWarning(warning);
193+
}
194+
}
195+
196+
context.AddCallLog(callLog);
197+
}
163198
}

src/KernelMemory.Extensions/Cohere/RawCohereClientDtos.cs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
using Microsoft.KernelMemory.MemoryStorage;
2+
using System;
23
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
36
using System.Text.Json.Serialization;
47

58
namespace KernelMemory.Extensions.Cohere;
@@ -43,16 +46,51 @@ public static CohereRagRequest CreateFromMemoryRecord(string question, IEnumerab
4346

4447
foreach (var memory in memoryRecords)
4548
{
49+
//if the text is more than 300 words we need to split it
50+
var text = memory.GetPartitionText();
51+
int start = 0;
52+
int spaceCount = 0;
53+
for (int i = 0; i < text.Length; i++)
54+
{
55+
if (text[i] == ' ')
56+
{
57+
spaceCount++;
58+
}
59+
if (spaceCount > 250)
60+
{
61+
ragRequest.Documents.Add(new RagDocument()
62+
{
63+
DocId = memory.Id,
64+
Text = text[start..i]
65+
});
66+
start = i;
67+
spaceCount = 0;
68+
}
69+
}
70+
4671
ragRequest.Documents.Add(new RagDocument()
4772
{
4873
DocId = memory.Id,
49-
Text = memory.GetPartitionText()
74+
Text = text[start..text.Length]
5075
});
5176
}
5277

5378
return ragRequest;
5479
}
5580

81+
internal string Describe()
82+
{
83+
StringBuilder stringBuilder = new StringBuilder();
84+
stringBuilder.AppendLine($"Message: {Message}");
85+
stringBuilder.AppendLine($"Model: {Model}");
86+
stringBuilder.AppendLine($"Document count: {Documents.Count}");
87+
stringBuilder.AppendLine($"Temperature: {Temperature}");
88+
stringBuilder.AppendLine($"Stream: {Stream}");
89+
stringBuilder.AppendLine($"\n\nFullDocuments\n{string.Join("\n", Documents.Select(d => d.Text))}");
90+
91+
return stringBuilder.ToString();
92+
}
93+
5694
[JsonPropertyName("message")]
5795
public string Message { get; set; }
5896

@@ -321,6 +359,30 @@ public class ChatStreamEvent
321359

322360
[JsonPropertyName("citations")]
323361
public List<CohereRagCitation> Citations { get; set; }
362+
363+
[JsonPropertyName("response")]
364+
public ChatStreamingResponse Response { get; set; }
365+
}
366+
367+
public class ChatStreamingResponse
368+
{
369+
[JsonPropertyName("response_id")]
370+
public string ResponseId { get; set; }
371+
372+
[JsonPropertyName("text")]
373+
public string Text { get; set; }
374+
375+
[JsonPropertyName("generation_id")]
376+
public string GenerationId { get; set; }
377+
378+
[JsonPropertyName("chat_history")]
379+
public List<ChatMessage> ChatHistory { get; set; }
380+
381+
[JsonPropertyName("finish_reason")]
382+
public string FinishReason { get; set; }
383+
384+
[JsonPropertyName("meta")]
385+
public Meta Meta { get; set; }
324386
}
325387

326388
public class CohereRagCitation

src/KernelMemory.Extensions/Helper/LLMCallLog.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using Microsoft.SemanticKernel;
33
using Microsoft.SemanticKernel.Connectors.OpenAI;
44
using OpenAI.Chat;
5+
using System;
56
using System.Collections.Generic;
67
using System.Linq;
78

@@ -20,6 +21,10 @@ public class LLMCallLog
2021

2122
public object? ReturnObject { get; set; }
2223

24+
public IReadOnlyList<string> Warnings => _warnings;
25+
26+
private readonly List<string> _warnings = new();
27+
2328
public TokenCount TokenCount { get; set; } = null!;
2429

2530
public void AddOpenaiChatMessageContent(OpenAIChatMessageContent mc)
@@ -50,6 +55,11 @@ public void AddOpenaiChatMessageContent(OpenAIChatMessageContent mc)
5055
};
5156
}
5257
}
58+
59+
public void AddWarning(string warning)
60+
{
61+
_warnings.Add(warning);
62+
}
5363
}
5464

5565
public class TokenCount
@@ -66,6 +76,8 @@ public class TokenCount
6676
/// </summary>
6777
public class LLMCallLogContext
6878
{
79+
public Guid Id { get; private set; } = Guid.NewGuid();
80+
6981
public IReadOnlyList<LLMCallLog> CallLogs => _callLogs;
7082

7183
private readonly List<LLMCallLog> _callLogs = new();

src/KernelMemory.Extensions/Helper/SemanticKernelWrapper.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace KernelMemory.Extensions.Helper;
1818
public interface ISemanticKernelWrapper
1919
{
2020
KernelFunction CreateFunctionFromMethod(Delegate method, string functionName);
21+
2122
KernelPlugin CreateFromFunctions(string pluginName, IEnumerable<KernelFunction> functions);
2223

2324
KernelFunction CreateFunctionFromPrompt(PromptTemplateConfig config, IPromptTemplateFactory? promptTemplateFactory = null);

0 commit comments

Comments
 (0)