Skip to content

Commit 2ac319d

Browse files
committed
Added sample with custom embedding generator.
1 parent 65a2746 commit 2ac319d

File tree

3 files changed

+179
-5
lines changed

3 files changed

+179
-5
lines changed

src/KernelMemory.Extensions.ConsoleTest/KernelMemory.Extensions.ConsoleTest.csproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
<PackageReference Include="CommandDotNet.Spectre" />
1616
<PackageReference Include="Alkampfer.KernelMemory.ElasticSearch" />
1717
<PackageReference Include="Microsoft.KernelMemory.AI.AzureOpenAI" />
18-
<PackageReference Include="PdfPig" />
1918
</ItemGroup>
2019

2120
</Project>

src/KernelMemory.Extensions.ConsoleTest/Samples/CustomParsersSample.cs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using KernelMemory.Extensions.ConsoleTest.Helper;
2+
using KernelMemory.Extensions.ConsoleTest.SpecialHandlers;
23
using Microsoft.Extensions.DependencyInjection;
34
using Microsoft.Extensions.Http.Resilience;
45
using Microsoft.Extensions.Logging;
@@ -13,13 +14,13 @@
1314
using Microsoft.KernelMemory.Pipeline;
1415
using System.Security.Cryptography;
1516
using System.Text;
16-
using System.Linq;
17-
using HandlebarsDotNet.Extensions;
17+
using static Microsoft.KernelMemory.Constants.CustomContext;
1818

1919
namespace SemanticMemory.Samples;
2020

2121
internal class CustomParsersSample : ISample
2222
{
23+
2324
public async Task RunSample(string fileToParse)
2425
{
2526
var services = new ServiceCollection();
@@ -60,6 +61,9 @@ public async Task RunSample(string fileToParse)
6061
GenerateEmbeddingsHandler textEmbedding = new("gen_embeddings", orchestrator);
6162
await orchestrator.AddHandlerAsync(textEmbedding);
6263

64+
CustomizedEmbeddingsHandler questionAnwerParser = new CustomizedEmbeddingsHandler("qa_embeddings", orchestrator, ExtractTextFromChunk);
65+
await orchestrator.AddHandlerAsync(questionAnwerParser);
66+
6367
SaveRecordsHandler saveRecords = new("save_records", orchestrator);
6468
await orchestrator.AddHandlerAsync(saveRecords);
6569

@@ -77,12 +81,14 @@ public async Task RunSample(string fileToParse)
7781
.Then("extract")
7882
//.Then("partition")
7983
.Then("markdownpartition")
80-
.Then("gen_embeddings")
84+
//.Then("gen_embeddings")
85+
.Then("qa_embeddings")
8186
.Then("save_records");
8287

83-
contextProvider.AddLLamaCloudParserOptions(fileName, "This is a manual for Dreame vacuum cleaner, I need you to extract a series of sections that can be useful for an helpdesk to answer user questions. You will create sections where each sections contains a question and an answer taken from the text. Each question will be separated with ---");
88+
contextProvider.AddLLamaCloudParserOptions(fileName, @"This is a manual for Dreame vacuum cleaner, I need you to extract a series of sections that can be useful for an helpdesk to answer user questions. You will create sections where each sections contains a question and an answer taken from the text. Question must be on a single line. Each question will be separated with ---");
8489

8590
var pipeline = pipelineBuilder.Build();
91+
pipeline.GetContext().SetArg(EmbeddingGeneration.BatchSize, 50);
8692
await orchestrator.RunPipelineAsync(pipeline);
8793

8894
// Keep prompting the user for questions until an empty question is entered
@@ -99,6 +105,18 @@ public async Task RunSample(string fileToParse)
99105
} while (!string.IsNullOrWhiteSpace(question));
100106
}
101107

108+
/// <summary>
/// Selects the text that will be embedded for a chunk: its first non-empty line, with surrounding
/// whitespace and markdown emphasis markers ('*') removed. The parser prompt asks for each question
/// to be on a single line, so this line is expected to hold the question.
/// </summary>
/// <param name="text">Full chunk content.</param>
/// <param name="token">Cancellation token; not used because the work is purely synchronous.</param>
/// <returns>
/// The first line that still has content after trimming. When no line qualifies, the original
/// <paramref name="text"/> is returned unchanged (an empty string when it is null).
/// </returns>
private static Task<string> ExtractTextFromChunk(string text, CancellationToken token)
{
    if (string.IsNullOrEmpty(text))
    {
        return Task.FromResult(text ?? string.Empty);
    }

    // Characters that may surround the question: line-break residue, spaces and markdown bold markers.
    char[] decoration = { '\r', '\n', ' ', '*' };

    // Skip leading blank/decoration-only lines so that an empty string is never sent to the embedder.
    foreach (string line in text.Split('\n'))
    {
        string candidate = line.Trim(decoration);
        if (candidate.Length > 0)
        {
            return Task.FromResult(candidate);
        }
    }

    // Nothing but whitespace/decoration: fall back to the raw chunk.
    return Task.FromResult(text);
}
119+
102120
private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
103121
ServiceCollection services)
104122
{
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
using Microsoft.Extensions.Logging;
2+
using Microsoft.KernelMemory;
3+
using Microsoft.KernelMemory.AI;
4+
using Microsoft.KernelMemory.Context;
5+
using Microsoft.KernelMemory.Diagnostics;
6+
using Microsoft.KernelMemory.Handlers;
7+
using Microsoft.KernelMemory.Pipeline;
8+
using System;
9+
using System.Collections.Generic;
10+
using System.Linq;
11+
using System.Text;
12+
using System.Threading.Tasks;
13+
14+
namespace KernelMemory.Extensions.ConsoleTest.SpecialHandlers;
15+
16+
/// <summary>
/// Pipeline step handler based on the original embedding handler of Kernel Memory. The only addition
/// is a caller-supplied transformer that extracts, from each partition's content, the text that is
/// actually passed to the embedding generator.
/// </summary>
public sealed class CustomizedEmbeddingsHandler : GenerateEmbeddingsHandlerBase, IPipelineStepHandler
{
    private readonly ILogger<CustomizedEmbeddingsHandler> _log;
    private readonly List<ITextEmbeddingGenerator> _embeddingGenerators;
    private readonly bool _embeddingGenerationEnabled;
    private readonly Func<string, CancellationToken, Task<string>> _extractTextToEmbedAsync;

    /// <inheritdoc />
    public string StepName { get; }

    /// <summary>
    /// Handler responsible for generating embeddings and saving them to document storages (not memory db).
    /// </summary>
    /// <param name="stepName">Pipeline step for which the handler will be invoked</param>
    /// <param name="orchestrator">Current orchestrator used by the pipeline, giving access to content and other helps.</param>
    /// <param name="extractTextToEmbedAsync">
    /// Transformer invoked once per partition; receives the partition content and returns the text to embed.
    /// </param>
    /// <param name="loggerFactory">Application logger factory</param>
    /// <exception cref="ArgumentNullException"><paramref name="extractTextToEmbedAsync"/> is null.</exception>
    public CustomizedEmbeddingsHandler(
        string stepName,
        IPipelineOrchestrator orchestrator,
        Func<string, CancellationToken, Task<string>> extractTextToEmbedAsync,
        ILoggerFactory? loggerFactory = null)
        : base(orchestrator, (loggerFactory ?? DefaultLogger.Factory).CreateLogger<CustomizedEmbeddingsHandler>())
    {
        this.StepName = stepName;
        this._extractTextToEmbedAsync = extractTextToEmbedAsync ?? throw new ArgumentNullException(nameof(extractTextToEmbedAsync));

        this._log = (loggerFactory ?? DefaultLogger.Factory).CreateLogger<CustomizedEmbeddingsHandler>();
        this._embeddingGenerationEnabled = orchestrator.EmbeddingGenerationEnabled;
        this._embeddingGenerators = orchestrator.GetEmbeddingGenerators();

        // Exactly one status line is logged: previously "NOT ready" was immediately followed by "ready".
        if (!this._embeddingGenerationEnabled)
        {
            this._log.LogInformation("Handler '{0}' ready, embedding generation DISABLED", stepName);
        }
        else if (this._embeddingGenerators.Count < 1)
        {
            this._log.LogError("Handler '{0}' NOT ready, no embedding generators configured", stepName);
        }
        else
        {
            this._log.LogInformation("Handler '{0}' ready, {1} embedding generators", stepName, this._embeddingGenerators.Count);
        }
    }

    /// <inheritdoc />
    public async Task<(ReturnType returnType, DataPipeline updatedPipeline)> InvokeAsync(
        DataPipeline pipeline, CancellationToken cancellationToken = default)
    {
        if (!this._embeddingGenerationEnabled)
        {
            this._log.LogTrace("Embedding generation is disabled, skipping - pipeline '{0}/{1}'", pipeline.Index, pipeline.DocumentId);
            return (ReturnType.Success, pipeline);
        }

        foreach (ITextEmbeddingGenerator generator in this._embeddingGenerators)
        {
            var subStepName = GetSubStepName(generator);
            var partitions = await this.GetListOfPartitionsToProcessAsync(pipeline, subStepName, cancellationToken).ConfigureAwait(false);

            // The pipeline context may override the batch size; otherwise use the generator's own
            // maximum, or 1 when the generator does not support batching.
            int batchSize = pipeline.GetContext().GetCustomEmbeddingGenerationBatchSizeOrDefault((generator as ITextEmbeddingBatchGenerator)?.MaxBatchSize ?? 1);
            if (batchSize > 1 && generator is ITextEmbeddingBatchGenerator batchGenerator)
            {
                await this.GenerateEmbeddingsWithBatchingAsync(pipeline, batchGenerator, batchSize, partitions, cancellationToken).ConfigureAwait(false);
            }
            else
            {
                await this.GenerateEmbeddingsOneAtATimeAsync(pipeline, generator, partitions, cancellationToken).ConfigureAwait(false);
            }
        }

        return (ReturnType.Success, pipeline);
    }

    protected override IPipelineStepHandler ActualInstance => this;

    // Runs the external transformer on a partition and returns the text to embed.
    // Falls back to the full partition content when the transformer yields null/blank text,
    // so that an empty input is never sent to the embedding generator.
    private async Task<string> GetTextToEmbedAsync(PartitionInfo partition, CancellationToken cancellationToken)
    {
        string content = partition.PartitionContent;
        string extracted = await this._extractTextToEmbedAsync(content, cancellationToken).ConfigureAwait(false);
        if (!string.IsNullOrWhiteSpace(extracted))
        {
            return extracted;
        }

        this._log.LogWarning("Text extraction returned no usable text, embedding the whole partition content instead");
        return content;
    }

    // Generate and save embeddings, one batch at a time
    private async Task GenerateEmbeddingsWithBatchingAsync(
        DataPipeline pipeline,
        ITextEmbeddingBatchGenerator generator,
        int batchSize,
        List<PartitionInfo> partitions,
        CancellationToken cancellationToken)
    {
        PartitionInfo[][] batches = partitions.Chunk(batchSize).ToArray();

        // Log the batch size actually used for chunking, which can differ from generator.MaxBatchSize.
        this._log.LogTrace("Generating embeddings, pipeline '{0}/{1}', batch generator '{2}', batch size {3}, batch count {4}",
            pipeline.Index, pipeline.DocumentId, generator.GetType().FullName, batchSize, batches.Length);

        // One batch at a time
        foreach (PartitionInfo[] partitionsInfo in batches)
        {
            List<string> strings = new(partitionsInfo.Length);
            foreach (PartitionInfo partition in partitionsInfo)
            {
                strings.Add(await this.GetTextToEmbedAsync(partition, cancellationToken).ConfigureAwait(false));
            }

            // Token counting is only needed for the trace message, skip it when tracing is off.
            if (this._log.IsEnabled(LogLevel.Trace))
            {
                int totalTokens = strings.Sum(s => ((ITextEmbeddingGenerator)generator).CountTokens(s));
                this._log.LogTrace("Generating embeddings, pipeline '{0}/{1}', generator '{2}', batch size {3}, total {4} tokens",
                    pipeline.Index, pipeline.DocumentId, generator.GetType().FullName, strings.Count, totalTokens);
            }

            Embedding[] embeddings = await generator.GenerateEmbeddingBatchAsync(strings, cancellationToken).ConfigureAwait(false);
            await this.SaveEmbeddingsToDocumentStorageAsync(
                    pipeline, partitionsInfo, embeddings, GetEmbeddingProviderName(generator), GetEmbeddingGeneratorName(generator), cancellationToken)
                .ConfigureAwait(false);
        }
    }

    // Generate and save embeddings, one chunk at a time
    private async Task GenerateEmbeddingsOneAtATimeAsync(
        DataPipeline pipeline,
        ITextEmbeddingGenerator generator,
        List<PartitionInfo> partitions,
        CancellationToken cancellationToken)
    {
        this._log.LogTrace("Generating embeddings, pipeline '{0}/{1}', generator '{2}', partition count {3}",
            pipeline.Index, pipeline.DocumentId, generator.GetType().FullName, partitions.Count);

        // One partition at a time
        foreach (PartitionInfo partitionInfo in partitions)
        {
            // Transform the partition content first, so the trace below reports the size of the
            // text that is really embedded rather than the size of the whole partition.
            string textToEmbed = await this.GetTextToEmbedAsync(partitionInfo, cancellationToken).ConfigureAwait(false);

            if (this._log.IsEnabled(LogLevel.Trace))
            {
                this._log.LogTrace("Generating embedding, pipeline '{0}/{1}', generator '{2}', content size {3} tokens",
                    pipeline.Index, pipeline.DocumentId, generator.GetType().FullName, generator.CountTokens(textToEmbed));
            }

            var embedding = await generator.GenerateEmbeddingAsync(textToEmbed, cancellationToken).ConfigureAwait(false);
            await this.SaveEmbeddingToDocumentStorageAsync(
                    pipeline, partitionInfo, embedding, GetEmbeddingProviderName(generator), GetEmbeddingGeneratorName(generator), cancellationToken)
                .ConfigureAwait(false);
        }
    }
}

0 commit comments

Comments
 (0)