Skip to content

Commit 00e3dee

Browse files
committed
Add Whisper support
1 parent 731ad6f commit 00e3dee

File tree

12 files changed

+899
-123
lines changed

12 files changed

+899
-123
lines changed

DemoApp/App.xaml.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public App()
3737
builder.Services.AddSingleton<IDetectService, DetectService>();
3838
builder.Services.AddSingleton<ITextService, TextService>();
3939
builder.Services.AddSingleton<IInterpolationService, InterpolationService>();
40+
builder.Services.AddSingleton<IWhisperService, WhisperService>();
4041

4142
_appHost = builder.Build();
4243

DemoApp/Common/TextModel.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ public async Task<bool> DownloadAsync(string modelDirectory)
6060
public enum TextModelType
6161
{
6262
Summary = 0,
63-
Phi3 = 1
63+
Phi3 = 1,
64+
Whisper = 2,
65+
Supertonic = 3
6466
}
6567
}

DemoApp/Controls/TextModelControl.xaml.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public partial class TextModelControl : BaseControl
2323
private TextModel _selectedModel;
2424
private Device _currentDevice;
2525
private TextModel _currentModel;
26+
private TextModelType? _modelType;
2627

2728
/// <summary>
2829
/// Initializes a new instance of the <see cref="TextModelControl"/> class.
@@ -78,6 +79,12 @@ public ListCollectionView ModelCollectionView
7879
set { SetProperty(ref _modelCollectionView, value); }
7980
}
8081

82+
/// <summary>
/// Gets or sets the optional model type restriction.
/// Assigning a value stores it through <c>SetProperty</c> and then triggers <c>OnSettingsChanged</c>.
/// </summary>
public TextModelType? ModelType
{
    get => _modelType;
    set
    {
        SetProperty(ref _modelType, value);

        // OnSettingsChanged returns a Task; it is started here and intentionally not awaited.
        OnSettingsChanged();
    }
}
87+
8188

8289
private async Task LoadAsync()
8390
{
@@ -144,6 +151,9 @@ private Task OnSettingsChanged()
144151
if (_selectedDevice == null)
145152
return false;
146153

154+
if (_modelType != null && _modelType != viewModel.Type)
155+
return false;
156+
147157
return viewModel.SupportedDevices?.Contains(_selectedDevice.Type) ?? false;
148158
};
149159

DemoApp/DemoApp.csproj

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010

1111
<!--Common Packages-->
1212
<ItemGroup>
13-
<PackageReference Include="System.Linq.Async" Version="6.0.3" />
14-
<PackageReference Include="Microsoft.Extensions.Hosting" Version="10.0.0" />
15-
<PackageReference Include="TensorStack.Providers.DML" Version="0.1.78" />
16-
<PackageReference Include="TensorStack.WPF" Version="0.1.78" />
17-
<PackageReference Include="TensorStack.Upscaler" Version="0.1.78" />
18-
<PackageReference Include="TensorStack.Extractors" Version="0.1.78" />
19-
<PackageReference Include="TensorStack.TextGeneration" Version="0.1.78" />
20-
<PackageReference Include="TensorStack.StableDiffusion" Version="0.1.78" />
13+
<PackageReference Include="System.Linq.Async" Version="7.0.0" />
14+
<PackageReference Include="Microsoft.Extensions.Hosting" Version="10.0.1" />
15+
<PackageReference Include="TensorStack.Providers.DML" Version="0.2.1" />
16+
<PackageReference Include="TensorStack.WPF" Version="0.2.1" />
17+
<PackageReference Include="TensorStack.Upscaler" Version="0.2.1" />
18+
<PackageReference Include="TensorStack.Extractors" Version="0.2.1" />
19+
<PackageReference Include="TensorStack.TextGeneration" Version="0.2.1" />
20+
<PackageReference Include="TensorStack.StableDiffusion" Version="0.2.1" />
2121
</ItemGroup>
2222

2323

DemoApp/Services/DiffusionService.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
using TensorStack.StableDiffusion.Pipelines.StableDiffusion3;
1919
using TensorStack.StableDiffusion.Pipelines.StableDiffusionXL;
2020
using TensorStack.WPF;
21-
using Windows.Foundation;
2221

2322
namespace DemoApp.Services
2423
{
@@ -157,7 +156,7 @@ public async Task LoadAsync(PipelineModel pipeline)
157156
}
158157
else if (model.PipelineType == PipelineType.Nitro)
159158
{
160-
var nitroPipeline = NitroPipeline.FromFolder(model.Path, model.ModelType, provider);
159+
var nitroPipeline = NitroPipeline.FromFolder(model.Path, 512, model.ModelType, provider);
161160
_diffusionPipeline = nitroPipeline;
162161
_defaultOptions = nitroPipeline.DefaultOptions;
163162
}

DemoApp/Services/WhisperService.cs

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
using DemoApp.Common;
2+
using System;
3+
using System.Threading;
4+
using System.Threading.Tasks;
5+
using TensorStack.Common;
6+
using TensorStack.Common.Pipeline;
7+
using TensorStack.Common.Tensor;
8+
using TensorStack.Providers;
9+
using TensorStack.TextGeneration.Common;
10+
using TensorStack.TextGeneration.Pipelines.Whisper;
11+
12+
namespace DemoApp.Services
13+
{
14+
/// <summary>
/// Service that owns a single Whisper pipeline and exposes load, execute, cancel and unload
/// operations together with bindable state flags (<see cref="IsLoaded"/>, <see cref="IsLoading"/>,
/// <see cref="IsExecuting"/>, <see cref="CanCancel"/>).
/// </summary>
public class WhisperService : ServiceBase, IWhisperService
{
    private readonly Settings _settings;
    private PipelineModel _currentPipeline;
    private IPipeline _whisperPipeline;
    private CancellationTokenSource _cancellationTokenSource;
    private bool _isLoaded;
    private bool _isLoading;
    private bool _isExecuting;

    /// <summary>
    /// Initializes a new instance of the <see cref="WhisperService"/> class.
    /// </summary>
    /// <param name="settings">The settings.</param>
    public WhisperService(Settings settings)
    {
        _settings = settings;
    }

    /// <summary>
    /// Gets the pipeline model of the most recent load, or null when nothing is loaded.
    /// </summary>
    public PipelineModel Pipeline => _currentPipeline;

    /// <summary>
    /// Gets a value indicating whether a pipeline has been loaded successfully.
    /// </summary>
    public bool IsLoaded
    {
        get { return _isLoaded; }
        private set { SetProperty(ref _isLoaded, value); }
    }

    /// <summary>
    /// Gets a value indicating whether a load is in progress.
    /// Changing this also raises a change notification for <see cref="CanCancel"/>.
    /// </summary>
    public bool IsLoading
    {
        get { return _isLoading; }
        private set { SetProperty(ref _isLoading, value); NotifyPropertyChanged(nameof(CanCancel)); }
    }

    /// <summary>
    /// Gets a value indicating whether an execution is in progress.
    /// Changing this also raises a change notification for <see cref="CanCancel"/>.
    /// </summary>
    public bool IsExecuting
    {
        get { return _isExecuting; }
        private set { SetProperty(ref _isExecuting, value); NotifyPropertyChanged(nameof(CanCancel)); }
    }

    /// <summary>
    /// Gets a value indicating whether there is a running load or execute operation to cancel.
    /// </summary>
    public bool CanCancel => _isLoading || _isExecuting;


    /// <summary>
    /// Load the Whisper pipeline described by <paramref name="pipeline"/>, unloading and disposing
    /// any previously loaded pipeline first.
    /// </summary>
    /// <param name="pipeline">The pipeline model to load.</param>
    /// <exception cref="ArgumentNullException"><paramref name="pipeline"/> is null.</exception>
    /// <exception cref="ArgumentException">The model version does not parse as a <c>WhisperType</c>.</exception>
    /// <exception cref="OperationCanceledException">The load was cancelled.</exception>
    public async Task LoadAsync(PipelineModel pipeline)
    {
        ArgumentNullException.ThrowIfNull(pipeline);

        // Only report IsLoaded = true when the whole load sequence completed.
        var loadSucceeded = false;
        try
        {
            IsLoaded = false;
            IsLoading = true;
            using (_cancellationTokenSource = new CancellationTokenSource())
            {
                var cancellationToken = _cancellationTokenSource.Token;

                // Guard on the pipeline instance itself (not on _currentPipeline) so a previous
                // failed load can never lead to a null dereference here.
                var previousPipeline = _whisperPipeline;
                if (previousPipeline != null)
                {
                    await previousPipeline.UnloadAsync(cancellationToken);
                    previousPipeline.Dispose();
                    _whisperPipeline = null;
                }

                _currentPipeline = pipeline;
                var model = _currentPipeline.TextModel;

                // Resolved for the selected device but currently unused: the pipeline is created
                // on the CPU provider below (see TODO).
                var provider = _currentPipeline.Device.GetProvider();
                var providerCPU = Provider.GetProvider(DeviceType.CPU); // TODO: DirectML not working with decoder

                if (!Enum.TryParse<WhisperType>(model.Version, true, out var whisperType))
                    throw new ArgumentException("Invalid WhisperType Version");

                IPipeline whisperPipeline = WhisperPipeline.Create(providerCPU, model.Path, whisperType);
                _whisperPipeline = whisperPipeline;
                await Task.Run(() => whisperPipeline.LoadAsync(cancellationToken), cancellationToken);
                loadSucceeded = true;
            }
        }
        catch
        {
            // Any failure (cancellation or otherwise) leaves the service in a clean "nothing loaded"
            // state instead of keeping a half-initialized pipeline around.
            _whisperPipeline?.Dispose();
            _whisperPipeline = null;
            _currentPipeline = null;
            throw;
        }
        finally
        {
            IsLoaded = loadSucceeded;
            IsLoading = false;
        }
    }


    /// <summary>
    /// Execute the loaded pipeline. A <c>Beams</c> value of 0 runs greedy search and returns a
    /// single-element array; any other value runs beam search.
    /// </summary>
    /// <param name="options">The request options.</param>
    /// <returns>The generated results.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="options"/> is null.</exception>
    /// <exception cref="InvalidOperationException">
    /// No pipeline is loaded, or the loaded pipeline does not support the requested search mode.
    /// </exception>
    public async Task<GenerateResult[]> ExecuteAsync(WhisperRequest options)
    {
        ArgumentNullException.ThrowIfNull(options);

        var pipeline = _whisperPipeline
            ?? throw new InvalidOperationException("Whisper pipeline is not loaded.");

        try
        {
            IsExecuting = true;
            using (_cancellationTokenSource = new CancellationTokenSource())
            {
                // Capture the token once so the worker does not re-read the mutable field.
                var cancellationToken = _cancellationTokenSource.Token;

                var pipelineOptions = new WhisperOptions
                {
                    Prompt = options.Prompt,
                    Seed = options.Seed,
                    Beams = options.Beams,
                    TopK = options.TopK,
                    TopP = options.TopP,
                    Temperature = options.Temperature,
                    MaxLength = options.MaxLength,
                    MinLength = options.MinLength,
                    NoRepeatNgramSize = options.NoRepeatNgramSize,
                    LengthPenalty = options.LengthPenalty,
                    DiversityLength = options.DiversityLength,
                    EarlyStopping = options.EarlyStopping,
                    AudioInput = options.AudioInput,
                    Language = options.Language,
                    Task = options.Task
                };

                var pipelineResult = await Task.Run(async () =>
                {
                    if (options.Beams == 0)
                    {
                        // Greedy Search
                        if (pipeline is not IPipeline<GenerateResult, WhisperOptions, GenerateProgress> greedyPipeline)
                            throw new InvalidOperationException("The loaded pipeline does not support greedy search.");

                        return [await greedyPipeline.RunAsync(pipelineOptions, cancellationToken: cancellationToken)];
                    }

                    // Beam Search
                    if (pipeline is not IPipeline<GenerateResult[], WhisperSearchOptions, GenerateProgress> beamSearchPipeline)
                        throw new InvalidOperationException("The loaded pipeline does not support beam search.");

                    return await beamSearchPipeline.RunAsync(new WhisperSearchOptions(pipelineOptions), cancellationToken: cancellationToken);
                });

                return pipelineResult;
            }
        }
        finally
        {
            IsExecuting = false;
        }
    }


    /// <summary>
    /// Cancel the running task (Load or Execute)
    /// </summary>
    public async Task CancelAsync()
    {
        await _cancellationTokenSource.SafeCancelAsync();
    }


    /// <summary>
    /// Cancel any running operation, unload and dispose the pipeline, and reset all state flags.
    /// </summary>
    public async Task UnloadAsync()
    {
        if (_currentPipeline != null || _whisperPipeline != null)
        {
            await _cancellationTokenSource.SafeCancelAsync();

            // The pipeline can be null even when a model is recorded, so check it separately.
            var pipeline = _whisperPipeline;
            if (pipeline != null)
            {
                await pipeline.UnloadAsync();
                pipeline.Dispose();
            }

            _whisperPipeline = null;
            _currentPipeline = null;
        }

        IsLoaded = false;
        IsLoading = false;
        IsExecuting = false;
    }
}
196+
197+
198+
public interface IWhisperService
199+
{
200+
PipelineModel Pipeline { get; }
201+
bool IsLoaded { get; }
202+
bool IsLoading { get; }
203+
bool IsExecuting { get; }
204+
bool CanCancel { get; }
205+
Task LoadAsync(PipelineModel pipeline);
206+
Task UnloadAsync();
207+
Task CancelAsync();
208+
Task<GenerateResult[]> ExecuteAsync(WhisperRequest options);
209+
}
210+
211+
212+
public record WhisperRequest : ITransformerRequest
213+
{
214+
public AudioTensor AudioInput { get; set; }
215+
public LanguageType Language { get; set; } = LanguageType.EN;
216+
public TaskType Task { get; set; } = TaskType.Transcribe;
217+
218+
public string Prompt { get; set; }
219+
public int MinLength { get; set; } = 20;
220+
public int MaxLength { get; set; } = 200;
221+
public int NoRepeatNgramSize { get; set; } = 3;
222+
public int Seed { get; set; }
223+
public int Beams { get; set; } = 1;
224+
public int TopK { get; set; } = 1;
225+
public float TopP { get; set; } = 0.9f;
226+
public float Temperature { get; set; } = 1.0f;
227+
public float LengthPenalty { get; set; } = 1.0f;
228+
public EarlyStopping EarlyStopping { get; set; }
229+
public int DiversityLength { get; set; } = 5;
230+
}
231+
232+
}

0 commit comments

Comments
 (0)