
[feat] Support EAGLE for Qwen2 #21363


Open · wants to merge 3 commits into main
Conversation

Ximingwang-09 (Contributor) commented Jul 22, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing a test command.
  • The test results, such as pasting a before/after results comparison or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Add EAGLE speculative decoding support for Qwen2 models (e.g. QwQ-32B), including a Qwen2 EAGLE draft head and the corresponding model-registry entry.

Test Plan

Serve the QwQ-32B target model with a QwQ EAGLE draft model on 2×A100 GPUs, then run the serving benchmark with and without speculative decoding:

VLLM_USE_V1=1 vllm serve /mnt/Qwen__QwQ-32B \
  --dtype auto \
  --tensor-parallel-size 2 \
  --enable-prefix-caching \
  --port 30000 \
  --max-model-len 32768 \
  --speculative_config '{"model": "/mnt/qwq_eagle", "method": "eagle", "num_speculative_tokens": 7}'
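
For quick offline checks, the same speculative_config can also be passed through the Python API. A minimal sketch, assuming the same local checkpoint paths as in the command above (this is not part of the PR's test plan):

from vllm import LLM, SamplingParams

# Sketch only: offline equivalent of the serve command above.
# The model paths are the local checkpoints from the test plan; adjust as needed.
llm = LLM(
    model="/mnt/Qwen__QwQ-32B",
    tensor_parallel_size=2,
    max_model_len=32768,
    enable_prefix_caching=True,
    speculative_config={
        "model": "/mnt/qwq_eagle",
        "method": "eagle",
        "num_speculative_tokens": 7,
    },
)
outputs = llm.generate(["Explain speculative decoding in one sentence."],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)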

Test Result

Without EAGLE

============ Serving Benchmark Result ============
Successful requests:                     20        
Benchmark duration (s):                  64.18     
Total input tokens:                      20480     
Total generated tokens:                  2433      
Request throughput (req/s):              0.31      
Output token throughput (tok/s):         37.91     
Total Token throughput (tok/s):          357.03    
---------------Time to First Token----------------
Mean TTFT (ms):                          179.51    
Median TTFT (ms):                        188.56    
P99 TTFT (ms):                           204.18    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          25.10     
Median TPOT (ms):                        25.10     
P99 TPOT (ms):                           25.15     
---------------Inter-token Latency----------------
Mean ITL (ms):                           25.10     
Median ITL (ms):                         25.09     
P99 ITL (ms):                            25.78     
==================================================

With EAGLE

============ Serving Benchmark Result ============
Successful requests:                     20        
Benchmark duration (s):                  35.59     
Total input tokens:                      20480     
Total generated tokens:                  2410      
Request throughput (req/s):              0.56      
Output token throughput (tok/s):         67.71     
Total Token throughput (tok/s):          643.14    
---------------Time to First Token----------------
Mean TTFT (ms):                          187.96    
Median TTFT (ms):                        197.10    
P99 TTFT (ms):                           213.35    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          13.37     
Median TPOT (ms):                        11.10     
P99 TPOT (ms):                           30.89     
---------------Inter-token Latency----------------
Mean ITL (ms):                           33.67     
Median ITL (ms):                         33.63     
P99 ITL (ms):                            34.99     
==================================================

Log:

INFO 2025-07-22 16:58:56 8871 [loggers.py:116] Engine 000: Avg prompt throughput: 105.2 tokens/s, Avg generation throughput: 25.2 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 4.6%
INFO 2025-07-22 16:58:56 8871 [metrics.py:86] SpecDecoding metrics: Draft acceptance rate: 29.8%, Mean acceptance length: 3.09, Accepted: 171 tokens, Drafted: 574 tokens, Per-position acceptance rate: 0.768, 0.476, 0.341, 0.195, 0.195, 0.061, 0.049
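
For reference, these numbers imply the following: with num_speculative_tokens=7, the 574 drafted tokens correspond to 574 / 7 = 82 verification steps, and 171 accepted tokens give 171 / 82 ≈ 2.09 accepted draft tokens per step; adding the bonus token from the target model matches the reported mean acceptance length of 3.09. In the benchmark above this shows up as mean TPOT dropping from 25.10 ms to 13.37 ms (roughly 1.9x), while mean ITL rises because several tokens are now emitted per engine step.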

(Optional) Documentation Update

mergify bot added the new-model, qwen, and speculative-decoding labels on Jul 22, 2025
gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for EAGLE speculative decoding for Qwen2 models, which shows a significant performance improvement based on the provided benchmarks. The changes include a new model file for the Qwen2 EAGLE head and an update to the model registry.

I've found a few critical issues in the implementation of EagleQwen2ForCausalLM that will prevent it from working correctly. Specifically, the lm_head is not initialized, and the forward method returns incorrect values. I've provided suggestions to fix these. There's also a minor issue in the load_weights method that should be addressed. Once these issues are resolved, this will be a great addition.
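
For readers unfamiliar with the registry change mentioned above: in vLLM, EAGLE draft heads are looked up by architecture name in the speculative-decoding model table. A sketch of what the new entry might look like (illustrative; the module name qwen2_eagle and surrounding entries are assumptions, not taken from this PR):

# vllm/model_executor/models/registry.py (sketch; exact contents may differ)
_SPECULATIVE_DECODING_MODELS = {
    # ... existing entries, e.g. the Llama EAGLE draft head ...
    "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"),
    # New entry mapping the Qwen2 EAGLE architecture to its module and class.
    "EagleQwen2ForCausalLM": ("qwen2_eagle", "EagleQwen2ForCausalLM"),
}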

Comment on lines +111 to +138
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

critical

The __init__ method for EagleQwen2ForCausalLM is missing the initialization of self.lm_head. This attribute is necessary for computing logits from the hidden states in the forward method and for loading weights. Without it, the model will raise an AttributeError at runtime.

You should initialize self.lm_head, either by tying it to the token embeddings if self.config.tie_word_embeddings is true, or by creating a new VocabParallelEmbedding instance.

Suggested change

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config. \
            speculative_config.draft_model_config.hf_config
        target_layer_num = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.model = Qwen2Model(vllm_config=vllm_config,
                                prefix="model",
                                start_layer_id=target_layer_num)
        if getattr(self.config, "tie_word_embeddings", False):
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = VocabParallelEmbedding(
                self.config.vocab_size,
                self.config.hidden_size,
            )
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                                scale=logit_scale)

Comment on lines +123 to +146
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        return self.model(input_ids, positions, hidden_states)

critical

The forward method currently returns hidden states directly from the model. For speculative decoding, the draft model's forward method must return a tuple of (logits, hidden_states).

You need to compute the logits from the final hidden states using the logits_processor and the lm_head.

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        hidden_states, residual = self.model(input_ids, positions, hidden_states)
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits, residual
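
For context on why the tuple shape matters: the EAGLE proposer samples a draft token from the returned logits and feeds the returned hidden states back into the next draft step. A heavily simplified, illustrative loop (hypothetical names; not vLLM's actual proposer code):

import torch
from torch import nn

def propose_draft_tokens(draft_model: nn.Module,
                         input_ids: torch.Tensor,
                         positions: torch.Tensor,
                         hidden_states: torch.Tensor,
                         num_speculative_tokens: int) -> list[int]:
    """Illustrative only: consume the (logits, hidden_states) pair per draft step."""
    draft_tokens: list[int] = []
    for _ in range(num_speculative_tokens):
        logits, hidden_states = draft_model(input_ids, positions, hidden_states)
        next_token = logits[-1].argmax(dim=-1)  # greedy draft sampling
        draft_tokens.append(int(next_token))
        input_ids = next_token.view(1)          # feed the drafted token back in
        positions = positions[-1:] + 1          # advance the position by one
    return draft_tokens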

Comment on lines 130 to 140
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )
        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        loader.load_weights(model_weights.items())

high

The load_weights method should return the set of loaded parameter names. This is used by the model loader to verify that all weights have been loaded correctly. The current implementation is missing the return statement.

Suggested change

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
        )
        model_weights = {}
        for name, loaded_weight in weights:
            if "lm_head" not in name:
                name = "model." + name
            model_weights[name] = loaded_weight
        return loader.load_weights(model_weights.items())
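
For context, the returned set is what lets the loader detect parameters that never received a checkpoint tensor. An illustrative (not vLLM-specific) version of that kind of check:

from torch import nn

def check_all_weights_loaded(model: nn.Module, loaded_names: set[str]) -> None:
    """Illustrative only: verify every parameter was covered by load_weights()."""
    expected = {name for name, _ in model.named_parameters()}
    missing = expected - loaded_names
    if missing:
        raise ValueError(
            f"Weights not initialized from checkpoint: {sorted(missing)}")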


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

纬杭 added 3 commits July 24, 2025 19:35
  • Signed-off-by: 纬杭 <ximing.wxm@antgroup.com>
  • fix (Signed-off-by: 纬杭 <ximing.wxm@antgroup.com>)
  • fix (Signed-off-by: 纬杭 <ximing.wxm@antgroup.com>)
Labels
new-model · qwen · speculative-decoding