[feat] Support EAGLE for Qwen2 #21363
Conversation
Code Review

This pull request adds support for EAGLE speculative decoding for Qwen2 models, which shows a significant performance improvement in the provided benchmarks. The changes include a new model file for the Qwen2 EAGLE head and an update to the model registry.

I've found a few critical issues in the implementation of `EagleQwen2ForCausalLM` that will prevent it from working correctly. Specifically, the `lm_head` is not initialized, and the `forward` method returns incorrect values. I've provided suggestions to fix these. There's also a minor issue in the `load_weights` method that should be addressed. Once these issues are resolved, this will be a great addition.
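For context, the registry update mentioned above likely amounts to a one-line addition in vLLM's model registry. The file path and dict name below are assumptions based on how other EAGLE heads are registered, not a quote from this PR:

```python
# vllm/model_executor/models/registry.py (assumed location; mirrors the
# existing EagleLlamaForCausalLM entry)
_SPECULATIVE_DECODING_MODELS = {
    # ... existing entries such as EagleLlamaForCausalLM ...
    "EagleQwen2ForCausalLM": ("qwen2_eagle", "EagleQwen2ForCausalLM"),
}
```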
```python
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    nn.Module.__init__(self)
    self.config = vllm_config. \
        speculative_config.draft_model_config.hf_config
    target_layer_num = vllm_config.model_config.get_num_layers(
        vllm_config.parallel_config)
    self.model = Qwen2Model(vllm_config=vllm_config,
                            prefix="model",
                            start_layer_id=target_layer_num)
    logit_scale = getattr(self.config, "logit_scale", 1.0)
    self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                            scale=logit_scale)
```
The `__init__` method for `EagleQwen2ForCausalLM` is missing the initialization of `self.lm_head`. This attribute is necessary for computing logits from the hidden states in the `forward` method and for loading weights. Without it, the model will raise an `AttributeError` at runtime.

You should initialize `self.lm_head`, either by tying it to the token embeddings if `self.config.tie_word_embeddings` is true, or by creating a new `VocabParallelEmbedding` instance.
Suggested change:

```python
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    nn.Module.__init__(self)
    self.config = vllm_config. \
        speculative_config.draft_model_config.hf_config
    target_layer_num = vllm_config.model_config.get_num_layers(
        vllm_config.parallel_config)
    self.model = Qwen2Model(vllm_config=vllm_config,
                            prefix="model",
                            start_layer_id=target_layer_num)
    if getattr(self.config, "tie_word_embeddings", False):
        self.lm_head = self.model.embed_tokens
    else:
        self.lm_head = VocabParallelEmbedding(
            self.config.vocab_size,
            self.config.hidden_size,
        )
    logit_scale = getattr(self.config, "logit_scale", 1.0)
    self.logits_processor = LogitsProcessor(self.config.vocab_size,
                                            scale=logit_scale)
```
```python
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    return self.model(input_ids, positions, hidden_states)
```
The `forward` method currently returns hidden states directly from the model. For speculative decoding, the draft model's `forward` method must return a tuple of `(logits, hidden_states)`.

You need to compute the logits from the final hidden states using the `logits_processor` and the `lm_head`.
Suggested change:

```python
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    hidden_states, residual = self.model(input_ids, positions, hidden_states)
    logits = self.logits_processor(self.lm_head, hidden_states)
    return logits, residual
```
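To illustrate why this contract matters, here is a rough, self-contained sketch of how an EAGLE-style proposer consumes the draft head's `(logits, hidden_states)` output on each speculation step. The toy `draft_model` and loop below are illustrative only, not vLLM's actual proposer code:

```python
import torch

# Toy stand-in for the draft head: returns (logits, hidden_states), matching
# the contract requested in the review. Shapes and names are illustrative.
def draft_model(input_ids, positions, hidden_states):
    vocab_size, hidden_size = 32, 8
    logits = torch.randn(input_ids.shape[0], vocab_size)
    return logits, torch.randn(input_ids.shape[0], hidden_size)

input_ids = torch.tensor([1])
positions = torch.tensor([0])
prev_hidden = torch.randn(1, 8)

draft_token_ids = []
for _ in range(3):  # e.g. num_speculative_tokens = 3
    # Each step feeds the previous hidden states back into the draft head
    # and samples the next draft token from the returned logits.
    logits, prev_hidden = draft_model(input_ids, positions, prev_hidden)
    next_token = logits[-1].argmax(dim=-1)  # greedy draft token
    draft_token_ids.append(int(next_token))
    input_ids = next_token.reshape(1)
    positions = positions + 1
print(draft_token_ids)
```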
```python
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=None,
    )

    model_weights = {}
    for name, loaded_weight in weights:
        if "lm_head" not in name:
            name = "model." + name
        model_weights[name] = loaded_weight

    loader.load_weights(model_weights.items())
```
The `load_weights` method should return the set of loaded parameter names. This is used by the model loader to verify that all weights have been loaded correctly. The current implementation is missing the return statement.
Suggested change:

```python
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
    loader = AutoWeightsLoader(
        self,
        skip_prefixes=None,
    )

    model_weights = {}
    for name, loaded_weight in weights:
        if "lm_head" not in name:
            name = "model." + name
        model_weights[name] = loaded_weight

    return loader.load_weights(model_weights.items())
```
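As a rough illustration of what the returned set enables (the checking logic below is a sketch, not vLLM's actual loader code), the caller can diff the returned names against the model's parameters to detect anything left uninitialized:

```python
from torch import nn

# Sketch of the verification a model loader can perform when load_weights
# returns the set of loaded parameter names (illustrative, not vLLM's code).
def check_all_weights_loaded(model: nn.Module, loaded_names: set[str]) -> None:
    expected = {name for name, _ in model.named_parameters()}
    missing = expected - loaded_names
    if missing:
        raise ValueError(f"The following weights were not loaded: {missing}")
```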
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only a small subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run the full CI to test the changes comprehensively before merging. 🚀
Essential Elements of an Effective PR Description Checklist

- Update `supported_models.md` and `examples` for a new model.

Purpose
Support EAGLE for Qwen2
Test Plan

Benchmark the QwQ model with a QwQ-EAGLE draft model on 2*A100 GPUs, with and without EAGLE enabled; a sketch of a possible launch configuration is shown below.
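As a rough sketch of how such a benchmark might be launched (the model paths and the exact `speculative_config` keys are assumptions, not taken from this PR; consult the vLLM docs for the current schema):

```python
from vllm import LLM, SamplingParams

# Hypothetical launch config: model paths and config values are placeholders.
llm = LLM(
    model="Qwen/QwQ-32B",            # target model (assumed path)
    tensor_parallel_size=2,          # 2*A100 as in the test plan
    speculative_config={
        "method": "eagle",
        "model": "path/to/qwq-eagle-head",  # draft head (placeholder)
        "num_speculative_tokens": 5,
    },
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```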
Test Result
W/O EAGLE
With EAGLE
Log:
(Optional) Documentation Update