
Commit 2120135

integrate IOManager with runner
Differential Revision: D78499456
Pull Request resolved: #12793
1 parent e44da93 commit 2120135

13 files changed (+72, -19 lines)

examples/models/llava/runner/llava_runner.cpp

Lines changed: 2 additions & 3 deletions
@@ -15,9 +15,7 @@
 #include <executorch/examples/models/llava/runner/llava_text_decoder_runner.h>
 #include <pytorch/tokenizers/llama2c_tokenizer.h>
 
-#include <ctime>
 #include <memory>
-#include <sstream>
 #include <vector>
 
 namespace llm = ::executorch::extension::llm;
@@ -49,7 +47,8 @@ Error LlavaRunner::load() {
   // Load the text decoder runner
   text_decoder_runner_ =
       // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-      std::make_unique<LlavaTextDecoderRunner>(module_.get());
+      std::make_unique<LlavaTextDecoderRunner>(
+          module_.get(), io_manager_.get());
   // @lint-ignore CLANGTIDY facebook-hte-Deprecated
   text_decoder_runner_->load();

examples/models/llava/runner/llava_text_decoder_runner.h

Lines changed: 4 additions & 2 deletions
@@ -18,8 +18,10 @@ namespace example {
 class ET_EXPERIMENTAL LlavaTextDecoderRunner
     : public executorch::extension::llm::TextDecoderRunner {
  public:
-  explicit LlavaTextDecoderRunner(executorch::extension::Module* module)
-      : TextDecoderRunner(module) {}
+  explicit LlavaTextDecoderRunner(
+      executorch::extension::Module* module,
+      executorch::extension::llm::IOManager* io_manager)
+      : TextDecoderRunner(module, io_manager) {}
 
   inline executorch::runtime::Result<executorch::aten::Tensor> step(
       executorch::extension::TensorPtr& tokens,
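
Any other subclass of TextDecoderRunner has to forward the new argument the same way LlavaTextDecoderRunner does above. A minimal sketch of the updated pattern, using a hypothetical MyTextDecoderRunner that is not part of this commit:

#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/module/module.h>

// Hypothetical subclass, shown only to illustrate the constructor change:
// both non-owning pointers are now forwarded to the TextDecoderRunner base.
class MyTextDecoderRunner
    : public executorch::extension::llm::TextDecoderRunner {
 public:
  MyTextDecoderRunner(
      executorch::extension::Module* module,
      executorch::extension::llm::IOManager* io_manager)
      : TextDecoderRunner(module, io_manager) {}
};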

extension/llm/runner/multimodal_runner.h

Lines changed: 3 additions & 2 deletions
@@ -16,11 +16,10 @@
 #include <functional>
 #include <memory>
 #include <string>
-#include <type_traits>
-#include <unordered_map>
 
 #include <executorch/extension/llm/runner/image.h>
 #include <executorch/extension/llm/runner/image_prefiller.h>
+#include <executorch/extension/llm/runner/io_manager/io_manager.h>
 #include <executorch/extension/llm/runner/stats.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
 #include <executorch/extension/llm/runner/text_prefiller.h>
@@ -41,6 +40,7 @@ class ET_EXPERIMENTAL MultimodalRunner {
       const float temperature = 0.8f)
       : temperature_(temperature),
         module_(std::make_unique<Module>(model_path, Module::LoadMode::File)),
+        io_manager_(std::make_unique<IOManager>()),
         tokenizer_path_(tokenizer_path) {
     ET_LOG(
         Info,
@@ -127,6 +127,7 @@ class ET_EXPERIMENTAL MultimodalRunner {
   std::unique_ptr<TextDecoderRunner> text_decoder_runner_;
   std::unique_ptr<TextPrefiller> text_prefiller_;
   std::unique_ptr<ImagePrefiller> image_prefiller_;
+  std::unique_ptr<IOManager> io_manager_;
   std::unique_ptr<TextTokenGenerator> text_token_generator_;
   std::string tokenizer_path_;
   std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;

extension/llm/runner/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,7 @@ def define_common_targets():
         ":stats",
         "//executorch/kernels/portable/cpu/util:arange_util" + aten_suffix,
         "//executorch/extension/llm/sampler:sampler" + aten_suffix,
+        "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix,
         "//executorch/extension/module:module" + aten_suffix,
         "//executorch/extension/tensor:tensor" + aten_suffix,
     ],
@@ -102,6 +103,7 @@ def define_common_targets():
         ":text_decoder_runner" + aten_suffix,
         ":text_prefiller" + aten_suffix,
         ":text_token_generator" + aten_suffix,
+        "//executorch/extension/llm/runner/io_manager:io_manager" + aten_suffix,
         "//pytorch/tokenizers:hf_tokenizer",
         "//pytorch/tokenizers:llama2c_tokenizer",
         "//pytorch/tokenizers:sentencepiece",

extension/llm/runner/test/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ runtime.cxx_test(
     srcs = ["test_text_decoder_runner.cpp"],
     deps = [
         "//executorch/extension/llm/runner:runner_lib",
+        "//executorch/extension/llm/runner/io_manager:io_manager",
         "//executorch/kernels/portable:generated_lib",
         "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
     ],

extension/llm/runner/test/test_text_decoder_runner.cpp

Lines changed: 10 additions & 3 deletions
@@ -7,6 +7,7 @@
  * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */
 
+#include <executorch/extension/llm/runner/io_manager/io_manager.h>
 #include <executorch/extension/llm/runner/text_decoder_runner.h>
 #include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
@@ -18,6 +19,7 @@
 using namespace ::testing;
 using executorch::extension::Module;
 using executorch::extension::TensorPtr;
+using executorch::extension::llm::IOManager;
 using executorch::extension::llm::TextDecoderRunner;
 using executorch::runtime::Error;
 using executorch::runtime::EValue;
@@ -34,11 +36,14 @@ class TextDecoderRunnerTest : public Test {
  protected:
   void SetUp() override {
     mock_module_ = std::make_unique<MockModule>();
-    runner_ = std::make_unique<TextDecoderRunner>(mock_module_.get());
+    io_manager_ = std::make_unique<executorch::extension::llm::IOManager>();
+    runner_ = std::make_unique<TextDecoderRunner>(
+        mock_module_.get(), io_manager_.get());
   }
 
   std::unique_ptr<MockModule> mock_module_;
   std::unique_ptr<TextDecoderRunner> runner_;
+  std::unique_ptr<IOManager> io_manager_;
 };
 
 // Test logits_to_token() method with Float tensor
@@ -150,15 +155,17 @@ TEST_F(TextDecoderRunnerTest, StepWithAllModels) {
 
     // Load the model
     auto module = std::make_unique<Module>(model_path);
+
     auto load_result = module->load();
     if (load_result != Error::Ok) {
       ADD_FAILURE() << "Failed to load model " << model_name << " from "
                     << model_path << " with error: " << (int)load_result;
       continue;
     }
-
+    std::unique_ptr<executorch::extension::llm::IOManager> io_manager =
+        std::make_unique<executorch::extension::llm::IOManager>();
     // Create TextDecoderRunner
-    TextDecoderRunner runner(module.get());
+    TextDecoderRunner runner(module.get(), io_manager.get());
     auto runner_load_result = runner.load();
     ASSERT_EQ(runner_load_result, Error::Ok)
         << "Failed to load runner for " << model_name;

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 6 additions & 1 deletion
@@ -7,6 +7,7 @@
  * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */
 
+#include <executorch/extension/llm/runner/io_manager/io_manager.h>
 #include <executorch/extension/llm/runner/irunner.h>
 #include <executorch/extension/llm/runner/text_llm_runner.h>
 #include <executorch/extension/llm/runner/text_prefiller.h>
@@ -63,7 +64,7 @@ class MockModule : public ::executorch::extension::Module {
 
 class MockTextDecoderRunner : public TextDecoderRunner {
  public:
-  MockTextDecoderRunner() : TextDecoderRunner(nullptr) {}
+  MockTextDecoderRunner() : TextDecoderRunner(nullptr, nullptr) {}
   MOCK_METHOD(
       Result<executorch::aten::Tensor>,
       step,
@@ -219,6 +220,7 @@ TEST_F(RunnerTest, GenerateCallsCallbackExactlyMaxNewTokensTimes) {
       std::move(text_decoder_runner),
       std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
           text_prefiller.release()),
+      std::make_unique<executorch::extension::llm::IOManager>(),
       std::move(text_token_generator),
      std::move(stats));
 
@@ -278,6 +280,7 @@ TEST_F(RunnerTest, WarmupCallsGenerateWithWarmingFlag) {
       std::move(text_decoder_runner),
       std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
           text_prefiller.release()),
+      std::make_unique<executorch::extension::llm::IOManager>(),
       std::move(text_token_generator),
       std::move(stats));
 
@@ -312,6 +315,7 @@ TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
       std::move(text_decoder_runner),
       std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
           text_prefiller.release()),
+      std::make_unique<executorch::extension::llm::IOManager>(),
       std::move(text_token_generator),
       std::move(stats));
 
@@ -356,6 +360,7 @@ TEST_F(RunnerTest, GenerateFromPosErrorsWithNegativeMaxNewTokens) {
       std::move(text_decoder_runner),
       std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
          text_prefiller.release()),
+      std::make_unique<executorch::extension::llm::IOManager>(),
       std::move(text_token_generator),
       std::move(stats));

extension/llm/runner/test/test_text_prefiller.cpp

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ using executorch::runtime::testing::TensorFactory;
 // Mock class for TextDecoderRunner
 class MockTextDecoderRunner : public TextDecoderRunner {
  public:
-  MockTextDecoderRunner() : TextDecoderRunner(nullptr) {}
+  MockTextDecoderRunner() : TextDecoderRunner(nullptr, nullptr) {}
   MOCK_METHOD(
       Result<executorch::aten::Tensor>,
       step,

extension/llm/runner/text_decoder_runner.cpp

Lines changed: 17 additions & 2 deletions
@@ -22,7 +22,8 @@ namespace llm {
 // NOTE: we observed ~2x loading performance increase on iPhone 15
 // and a ~5% improvement on Galaxy S22 by switching to
 // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-TextDecoderRunner::TextDecoderRunner(Module* module) : module_(module) {}
+TextDecoderRunner::TextDecoderRunner(Module* module, IOManager* io_manager)
+    : module_(module), io_manager_(io_manager) {}
 
 // This function is functional, meaning it shouldn't modify any state of the
 // input. It should be safe to call multiple times with the same inputs. The
@@ -66,8 +67,22 @@ ::executorch::runtime::Result<executorch::aten::Tensor> TextDecoderRunner::step(
     start_pos_tensor = from_blob(
         &start_pos, sizes_vec, ::executorch::aten::ScalarType::Long);
   }
-  auto outputs_res = module_->forward({tokens, start_pos_tensor});
+
+  std::vector<runtime::EValue> inputs;
+  auto method_err = module_->method("forward");
+  ET_CHECK_OK_OR_RETURN_ERROR(method_err.error());
+  auto& method = *(method_err.get());
+
+  auto inputs_res =
+      io_manager_->prepare_decode(tokens, start_pos_tensor, method);
+  ET_CHECK_OK_OR_RETURN_ERROR(inputs_res.error());
+  inputs = inputs_res.get();
+  auto outputs_res = module_->forward(inputs);
   ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
+
+  auto update_err = io_manager_->update_decode(method, outputs_res.get());
+  ET_CHECK_OK_OR_RETURN_ERROR(update_err);
+
   ET_CHECK_MSG(
       outputs_res.get().size() == 1,
       "More then one output returned from executing LLM.");

extension/llm/runner/text_decoder_runner.h

Lines changed: 6 additions & 4 deletions
@@ -10,6 +10,7 @@
 
 #pragma once
 
+#include <executorch/extension/llm/runner/io_manager/io_manager.h>
 #include <executorch/extension/llm/sampler/sampler.h>
 #include <executorch/extension/module/module.h>
 #include <executorch/extension/tensor/tensor.h>
@@ -21,7 +22,7 @@ namespace llm {
 
 class ET_EXPERIMENTAL TextDecoderRunner {
  public:
-  explicit TextDecoderRunner(Module* module);
+  explicit TextDecoderRunner(Module* module, IOManager* io_manager);
 
   virtual ~TextDecoderRunner() = default;
 
@@ -94,13 +95,14 @@ class ET_EXPERIMENTAL TextDecoderRunner {
 
  protected:
   /**
-   * Note: TextDecoderRunner does not own the Module instance. It is expected
-   * that the outer class (likely Runner) manages the lifecycle of the Module.
-   * This means that the responsibility for creating, maintaining, and
+   * Note: TextDecoderRunner does not own the Module or IOManager instance. It
+   * is expected that the outer class (likely Runner) manages the lifecycle of
+   * them. This means that the responsibility for creating, maintaining, and
    * destroying the Module lies outside of TextDecoderRunner. Ensure that the
   * Module remains valid for the duration of TextDecoderRunner's usage.
    */
   Module* module_;
+  IOManager* io_manager_;
   bool should_stop_{false};
 };
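The doc comment above is the contract the runners in this commit follow: the owning class keeps the Module and the IOManager in unique_ptrs and hands raw pointers to TextDecoderRunner. A minimal caller-side sketch of that pattern, using a hypothetical OwnerRunner class that is not part of this commit:

#include <executorch/extension/llm/runner/io_manager/io_manager.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/module/module.h>

#include <memory>
#include <string>

// Hypothetical owner, illustrating the non-owning-pointer contract:
// the Module and IOManager must stay alive for as long as the
// TextDecoderRunner that borrows them.
class OwnerRunner {
 public:
  explicit OwnerRunner(const std::string& model_path)
      : module_(std::make_unique<executorch::extension::Module>(model_path)),
        io_manager_(
            std::make_unique<executorch::extension::llm::IOManager>()),
        text_decoder_runner_(
            std::make_unique<executorch::extension::llm::TextDecoderRunner>(
                module_.get(), io_manager_.get())) {}

 private:
  // Members are destroyed in reverse declaration order, so the runner
  // (which only borrows) is torn down before the Module and IOManager.
  std::unique_ptr<executorch::extension::Module> module_;
  std::unique_ptr<executorch::extension::llm::IOManager> io_manager_;
  std::unique_ptr<executorch::extension::llm::TextDecoderRunner>
      text_decoder_runner_;
};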
